import pathlib
from typing import Callable, Dict

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
from diffusion_models.utils.schemas import BetaSchedulerConfiguration, \
    Checkpoint, LogConfiguration, TrainingConfiguration
from diffusion_models.utils.tensorboard import TensorboardManager


class DiffusionTrainer:
    def __init__(
        self,
        model: BaseDiffusionModel,
        dataset: Dataset,
        optimizer: torch.optim.Optimizer,
        training_configuration: TrainingConfiguration,
        loss_function: Callable = F.l1_loss,
        # scheduler: Optional[torch.optim.lr_scheduler.StepLR] = None,
        log_configuration: LogConfiguration = LogConfiguration(),
        reverse_transforms: Callable = lambda x: x,
        device: str = "cuda",
    ):
        """A diffusion trainer framework.

        This is a simplified framework for training diffusion models.

        Args:
            model: A diffusion model.
            dataset: A dataset to train on.
            optimizer: The optimizer to use.
            training_configuration: The training configuration to use.
            loss_function: The loss function to use.
            log_configuration: The logging configuration to use.
            reverse_transforms: The reverse transforms to use.
            device: The device to use.
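
        Example:
            A minimal usage sketch. The dataset, optimizer settings, and the
            ``TrainingConfiguration`` fields shown below are illustrative
            assumptions, not requirements of this class; ``model`` stands for
            any concrete ``BaseDiffusionModel`` implementation::

                import torch
                from torchvision import datasets, transforms

                dataset = datasets.MNIST(
                    root="data",
                    download=True,
                    transform=transforms.ToTensor(),
                )
                model = ...  # any BaseDiffusionModel subclass
                optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
                trainer = DiffusionTrainer(
                    model=model,
                    dataset=dataset,
                    optimizer=optimizer,
                    training_configuration=TrainingConfiguration(
                        training_name="mnist_run",
                        batch_size=128,
                        number_of_epochs=10,
                    ),
                )
                trainer.train()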
        """
        self.model = model.to(device)
        """The diffusion model to use."""
        self.optimizer = optimizer
        """The optimizer to use."""
        self.loss_function = loss_function
        """The loss function to use."""
        self.training_configuration = training_configuration
        """The training configuration to use."""
        # self.scheduler = scheduler
        self.device = device
        """The device to use."""

        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=training_configuration.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=16,
            pin_memory=True,
            persistent_workers=True,
        )
        """A torch dataloader."""

        # Shape of a single training image: (channels, height, width).
        self._image_shape = dataset[0][0].shape

        self.scaler = torch.amp.GradScaler(
            device=device
            # init_scale=8192,
        )
        """A torch GradScaler object."""

        self.log_configuration = log_configuration
        """A LogConfiguration object."""

        self.checkpoint_path = (
            pathlib.Path("../checkpoints")
            / self.training_configuration.training_name
        )
        """The path to save checkpoints."""

        self.checkpoint_path.mkdir(parents=True, exist_ok=True)
        self.tensorboard_manager = TensorboardManager(
            log_name=self.training_configuration.training_name,
        )
        """A tensorboard manager instance."""

        self.reverse_transforms = reverse_transforms
        """A set of reverse transforms."""

        # Let cuDNN benchmark convolution algorithms for the fixed input size.
        torch.backends.cudnn.benchmark = True

    def save_checkpoint(self, epoch: int, checkpoint_name: str):
        """Save a checkpoint.

        Args:
            epoch: The current epoch.
            checkpoint_name: The name of the checkpoint.
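
        Example:
            Called automatically by the training loop, but it can also be
            invoked directly, e.g.::

                trainer.save_checkpoint(epoch=5, checkpoint_name="epoch_5.pt")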
        """
        checkpoint = Checkpoint(
            epoch=epoch,
            model_state_dict=self.model.state_dict(),
            optimizer_state_dict=self.optimizer.state_dict(),
            scaler=self.scaler.state_dict()
            if self.training_configuration.mixed_precision_training
            else None,
            image_channels=self._image_shape[0],
            beta_scheduler_config=BetaSchedulerConfiguration(
                steps=self.model.diffuser.beta_scheduler.steps,
                betas=self.model.diffuser.beta_scheduler.betas,
                alpha_bars=self.model.diffuser.beta_scheduler.alpha_bars,
            ),
            tensorboard_run_name=self.tensorboard_manager.summary_writer.log_dir,
        )
        checkpoint.to_file(self.checkpoint_path / checkpoint_name)

    def train(self):
        """Start the diffusion training."""
        self.model.train()
        for epoch in range(self.training_configuration.number_of_epochs):
            for step, batch in enumerate(
                tqdm(self.dataloader, desc=f"Epoch={epoch}")
            ):
                global_step = epoch * len(self.dataloader) + step

                images, _ = batch
                images = images.to(self.device)

                # Sample random timesteps and build the corresponding noisy
                # images together with the target noise.
                noisy_images, noise, timesteps = self.model.diffuse(images=images)

                self.optimizer.zero_grad(set_to_none=True)

                # Forward pass runs in mixed precision when enabled.
                with torch.autocast(
                    device_type=self.device,
                    enabled=self.training_configuration.mixed_precision_training,
                ):
                    prediction = self.model(noisy_images, timesteps)
                    loss = self.loss_function(noise, prediction)

                # Backward pass on the scaled loss to avoid fp16 gradient underflow.
                self.scaler.scale(loss).backward()

                if self.training_configuration.gradient_clip is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    self.scaler.unscale_(self.optimizer)

                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.training_configuration.gradient_clip,
                    )

                self.scaler.step(self.optimizer)
                self.scaler.update()

                self.log_to_tensorboard(
                    metrics={
                        "Loss": loss,
                    },
                    global_step=global_step,
                )
            if epoch % self.training_configuration.checkpoint_rate == 0:
                self.save_checkpoint(epoch=epoch, checkpoint_name=f"epoch_{epoch}.pt")
        self.save_checkpoint(
            epoch=self.training_configuration.number_of_epochs,
            checkpoint_name="final.pt",
        )

    @torch.no_grad()
    def log_to_tensorboard(self, metrics: Dict[str, float], global_step: int):
        """Log to tensorboard.

        This method logs some useful metrics and visualizations to tensorboard.

        Args:
            metrics: A dictionary mapping metric names to values.
            global_step: The current global step.
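
        Example:
            Invoked once per training step by the training loop, e.g.::

                trainer.log_to_tensorboard(metrics={"Loss": 0.123}, global_step=0)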
        """
        self.model.eval()
        if global_step % self.log_configuration.log_rate == 0:
            self.tensorboard_manager.log_metrics(
                metrics=metrics, global_step=global_step
            )

        if (global_step % self.log_configuration.image_rate == 0) and (
            self.log_configuration.number_of_images > 0
        ):
            image_channels, image_height, image_width = self._image_shape
            # Start from pure Gaussian noise and denoise it with the model.
            images = torch.randn(
                (
                    self.log_configuration.number_of_images,
                    image_channels,
                    image_height,
                    image_width,
                ),
                device=self.device,
            )
            images = self.model.denoise(images)
            for step, image_batch in enumerate(images[::-1]):
                self.tensorboard_manager.log_images(
                    tag=f"Images at global step {global_step}",
                    images=self.reverse_transforms(image_batch),
                    timestep=step,
                )
        # Switch back to training mode so subsequent steps train normally.
        self.model.train()