import pathlib
from typing import Callable, Dict

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
from diffusion_models.utils.schemas import (
    BetaSchedulerConfiguration,
    Checkpoint,
    LogConfiguration,
    TrainingConfiguration,
)
from diffusion_models.utils.tensorboard import TensorboardManager


class DiffusionTrainer:
    def __init__(
        self,
        model: BaseDiffusionModel,
        dataset: Dataset,
        optimizer: torch.optim.Optimizer,
        training_configuration: TrainingConfiguration,
        loss_function: Callable = F.l1_loss,
        # scheduler: Optional[torch.optim.lr_scheduler.StepLR] = None,
        log_configuration: LogConfiguration = LogConfiguration(),
        reverse_transforms: Callable = lambda x: x,
        device: str = "cuda",
    ):
        """A diffusion trainer framework.

        This is a simplified framework for training diffusion models. A usage
        sketch is given at the end of this module.

        Args:
            model: A diffusion model.
            dataset: A dataset to train on.
            optimizer: The optimizer to use.
            training_configuration: The training configuration to use.
            loss_function: The loss function to use.
            log_configuration: The logging configuration to use.
            reverse_transforms: A callable that maps model-space tensors back
                to displayable images (used when logging samples).
            device: The device to train on.
        """
        self.model = model.to(device)
        """The diffusion model to use."""
        self.optimizer = optimizer
        """The optimizer to use."""
        self.loss_function = loss_function
        """The loss function to use."""
        self.training_configuration = training_configuration
        """The training configuration to use."""
        # self.scheduler = scheduler
        self.device = device
        """The device to use."""

        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=training_configuration.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=16,
            pin_memory=True,
            persistent_workers=True,
        )
        """A torch dataloader."""

        self._image_shape = dataset[0][0].shape

        self.scaler = torch.amp.GradScaler(
            device=device
            # init_scale=8192,
        )
        """A torch GradScaler object."""

        self.log_configuration = log_configuration
        """A LogConfiguration object."""

        self.checkpoint_path = (
            pathlib.Path("../checkpoints")
            / self.training_configuration.training_name
        )
        """The path to save checkpoints."""

        self.checkpoint_path.mkdir(parents=True, exist_ok=True)
        self.tensorboard_manager = TensorboardManager(
            log_name=self.training_configuration.training_name,
        )
        """A tensorboard manager instance."""

        self.reverse_transforms = reverse_transforms
        """A set of reverse transforms."""

        # Let cuDNN benchmark convolution algorithms (inputs have a fixed shape).
        torch.backends.cudnn.benchmark = True

    def save_checkpoint(self, epoch: int, checkpoint_name: str):
        """Save a checkpoint.

        Args:
            epoch: The current epoch.
            checkpoint_name: The name of the checkpoint.
        """
        checkpoint = Checkpoint(
            epoch=epoch,
            model_state_dict=self.model.state_dict(),
            optimizer_state_dict=self.optimizer.state_dict(),
            scaler=self.scaler.state_dict()
            if self.training_configuration.mixed_precision_training
            else None,
            image_channels=self._image_shape[0],
            beta_scheduler_config=BetaSchedulerConfiguration(
                steps=self.model.diffuser.beta_scheduler.steps,
                betas=self.model.diffuser.beta_scheduler.betas,
                alpha_bars=self.model.diffuser.beta_scheduler.alpha_bars,
            ),
            tensorboard_run_name=self.tensorboard_manager.summary_writer.log_dir,
        )
        checkpoint.to_file(self.checkpoint_path / checkpoint_name)

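    # Sketch: restoring state from a checkpoint saved by this method. This
    # assumes Checkpoint exposes a `from_file` counterpart to `to_file`
    # (an assumption; adapt to however checkpoints are deserialized here):
    #
    #     checkpoint = Checkpoint.from_file(trainer.checkpoint_path / "final.pt")
    #     trainer.model.load_state_dict(checkpoint.model_state_dict)
    #     trainer.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
    #     if checkpoint.scaler is not None:
    #         trainer.scaler.load_state_dict(checkpoint.scaler)
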
    def train(self):
        """Start the diffusion training."""
        self.model.train()
        for epoch in range(self.training_configuration.number_of_epochs):
            for step, batch in enumerate(
                tqdm(self.dataloader, desc=f"Epoch={epoch}")
            ):
                global_step = epoch * len(self.dataloader) + step

                images, _ = batch
                images = images.to(self.device)

                # Forward diffusion: sample timesteps and corrupt the clean
                # images with noise; the model is trained to recover the noise.
                noisy_images, noise, timesteps = self.model.diffuse(images=images)

                self.optimizer.zero_grad(set_to_none=True)

                # Predict the added noise under (optional) mixed precision.
                with torch.autocast(
                    device_type=self.device,
                    enabled=self.training_configuration.mixed_precision_training,
                ):
                    prediction = self.model(noisy_images, timesteps)
                    loss = self.loss_function(prediction, noise)

                self.scaler.scale(loss).backward()

                if self.training_configuration.gradient_clip is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    self.scaler.unscale_(self.optimizer)

                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.training_configuration.gradient_clip,
                    )

                self.scaler.step(self.optimizer)
                self.scaler.update()

                self.log_to_tensorboard(
                    metrics={
                        "Loss": loss,
                    },
                    global_step=global_step,
                )
            if epoch % self.training_configuration.checkpoint_rate == 0:
                self.save_checkpoint(epoch=epoch, checkpoint_name=f"epoch_{epoch}.pt")
        self.save_checkpoint(
            epoch=self.training_configuration.number_of_epochs,
            checkpoint_name="final.pt",
        )

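    # The loop above optimizes the standard noise-prediction objective used by
    # DDPM-style models (assuming `diffuse` implements the usual forward
    # process): with eps ~ N(0, I) and
    #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
    # the model eps_theta(x_t, t) is trained to minimize
    #     E[ || eps - eps_theta(x_t, t) || ].
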
    @torch.no_grad()
    def log_to_tensorboard(self, metrics: Dict[str, float], global_step: int):
        """Log to tensorboard.

        This method logs some useful metrics and visualizations to tensorboard.

        Args:
            metrics: A dictionary mapping metric names to values.
            global_step: The current global step.
        """
        self.model.eval()
        if global_step % self.log_configuration.log_rate == 0:
            self.tensorboard_manager.log_metrics(
                metrics=metrics, global_step=global_step
            )

        if (global_step % self.log_configuration.image_rate == 0) and (
            self.log_configuration.number_of_images > 0
        ):
            image_channels, image_height, image_width = self._image_shape
            # Start from pure noise and run the reverse (denoising) process.
            images = torch.randn(
                (
                    self.log_configuration.number_of_images,
                    image_channels,
                    image_height,
                    image_width,
                ),
                device=self.device,
            )
            denoising_steps = self.model.denoise(images)
            for step, step_images in enumerate(denoising_steps[::-1]):
                self.tensorboard_manager.log_images(
                    tag=f"Images at global step {global_step}",
                    images=self.reverse_transforms(step_images),
                    timestep=step,
                )
        # Restore training mode, since train() calls this inside its loop.
        self.model.train()
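

# Usage sketch (illustrative): how this trainer is typically wired up.
# `MyDiffusionModel`, `my_dataset`, and the configuration values below are
# placeholders; substitute a concrete BaseDiffusionModel subclass and a
# TrainingConfiguration with the fields used above (training_name, batch_size,
# number_of_epochs, checkpoint_rate, mixed_precision_training, gradient_clip).
#
#     model = MyDiffusionModel()
#     trainer = DiffusionTrainer(
#         model=model,
#         dataset=my_dataset,
#         optimizer=torch.optim.AdamW(model.parameters(), lr=2e-4),
#         training_configuration=TrainingConfiguration(
#             training_name="example_run",
#             batch_size=128,
#             number_of_epochs=100,
#             checkpoint_rate=10,
#             mixed_precision_training=True,
#             gradient_clip=1.0,
#         ),
#         reverse_transforms=lambda x: (x.clamp(-1, 1) + 1) / 2,  # if data is in [-1, 1]
#         device="cuda",
#     )
#     trainer.train()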