import pathlib
from typing import Callable
from typing import Dict
from typing import Optional

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
from diffusion_models.utils.schemas import BetaSchedulerConfiguration
from diffusion_models.utils.schemas import Checkpoint
from diffusion_models.utils.schemas import LogConfiguration
from diffusion_models.utils.schemas import TrainingConfiguration
from diffusion_models.utils.tensorboard import TensorboardManager


class DiffusionTrainer:
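    """Training loop for diffusion models.

    Wraps batching, mixed-precision training with gradient scaling,
    optional gradient clipping, periodic checkpointing, and TensorBoard
    logging of losses and sampled images.
    """
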
    def __init__(
        self,
        model: BaseDiffusionModel,
        dataset: Dataset,
        optimizer: torch.optim.Optimizer,
        training_configuration: TrainingConfiguration,
        loss_function: Callable = F.l1_loss,
        scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
        log_configuration: LogConfiguration = LogConfiguration(),
        reverse_transforms: Callable = lambda x: x,
        device: str = "cuda",
    ):
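        """
        Args:
            model: Diffusion model to train; moved to ``device``.
            dataset: Dataset yielding ``(image, label)`` pairs; labels are ignored.
            optimizer: Optimizer over the model's parameters.
            training_configuration: Batch size, epoch count, checkpoint rate,
                mixed-precision and gradient-clipping settings.
            loss_function: Loss between the predicted and the true noise.
            scheduler: Optional learning-rate scheduler, stepped once per epoch.
            log_configuration: Rates and image counts for TensorBoard logging.
            reverse_transforms: Maps normalized tensors back to displayable
                images before logging.
            device: Device to train on, e.g. ``"cuda"``.
        """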
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.training_configuration = training_configuration
        self.scheduler = scheduler
        # Needed later by log_to_tensorboard to undo dataset normalization.
        self.reverse_transforms = reverse_transforms
        self.device = device

        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=training_configuration.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=16,
            pin_memory=True,
            persistent_workers=True,
        )

        self._image_shape = dataset[0][0].shape

        # Scaling is a no-op when mixed-precision training is disabled.
        self.scaler = torch.amp.GradScaler(
            device=device,
            enabled=training_configuration.mixed_precision_training,
            # init_scale=8192,
        )

        self.log_configuration = log_configuration

        self.checkpoint_path = (
            pathlib.Path("../checkpoints")
            / self.training_configuration.training_name
        )

        # parents=True: create ../checkpoints itself on the first run.
        self.checkpoint_path.mkdir(parents=True, exist_ok=True)

        self.tensorboard_manager = TensorboardManager(
            log_name=self.training_configuration.training_name,
        )

        # Let cuDNN auto-tune conv kernels for the fixed input shape.
        torch.backends.cudnn.benchmark = True

    def save_checkpoint(self, epoch: int, checkpoint_name: str):
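        """Serialize model, optimizer, scaler (when mixed precision is
        enabled), and beta-scheduler state to the checkpoint directory."""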
        checkpoint = Checkpoint(
            epoch=epoch,
            model_state_dict=self.model.state_dict(),
            optimizer_state_dict=self.optimizer.state_dict(),
            scaler=self.scaler.state_dict()
            if self.training_configuration.mixed_precision_training
            else None,
            image_channels=self._image_shape[0],
            beta_scheduler_config=BetaSchedulerConfiguration(
                steps=self.model.diffuser.beta_scheduler.steps,
                betas=self.model.diffuser.beta_scheduler.betas,
                alpha_bars=self.model.diffuser.beta_scheduler.alpha_bars,
            ),
            tensorboard_run_name=self.tensorboard_manager.summary_writer.log_dir,
        )
        checkpoint.to_file(self.checkpoint_path / checkpoint_name)

    def train(self):
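        """Run the full training loop, saving a checkpoint every
        ``checkpoint_rate`` epochs and a final one afterwards."""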
        self.model.train()
        for epoch in range(self.training_configuration.number_of_epochs):
            for step, batch in enumerate(
                tqdm(self.dataloader, desc=f"Epoch={epoch}")
            ):
                global_step = epoch * len(self.dataloader) + step

                images, _ = batch
                images = images.to(self.device)

                # Sample random timesteps and produce the matching noisy images.
                noisy_images, noise, timesteps = self.model.diffuse(images=images)

                self.optimizer.zero_grad(set_to_none=True)

                with torch.autocast(
                    device_type=self.device,
                    enabled=self.training_configuration.mixed_precision_training,
                ):
                    prediction = self.model(noisy_images, timesteps)
                    # (input, target) order: prediction first, true noise second.
                    loss = self.loss_function(prediction, noise)

                self.scaler.scale(loss).backward()

                if self.training_configuration.gradient_clip is not None:
                    # Unscale the gradients in-place so clipping sees their
                    # true magnitudes, then clip as usual.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.training_configuration.gradient_clip,
                    )

                self.scaler.step(self.optimizer)
                self.scaler.update()

                self.log_to_tensorboard(
                    metrics={
                        "Loss": loss.item(),
                    },
                    global_step=global_step,
                )

            # Step the LR scheduler once per epoch, if one was provided.
            if self.scheduler is not None:
                self.scheduler.step()

            if epoch % self.training_configuration.checkpoint_rate == 0:
                self.save_checkpoint(
                    epoch=epoch, checkpoint_name=f"epoch_{epoch}.pt"
                )

        self.save_checkpoint(
            epoch=self.training_configuration.number_of_epochs,
            checkpoint_name="final.pt",
        )

    @torch.no_grad()
    def log_to_tensorboard(self, metrics: Dict[str, float], global_step: int):
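        """Log scalar metrics every ``log_rate`` steps; every ``image_rate``
        steps, additionally sample images from noise and log the denoising
        trajectory."""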
        self.model.eval()
        if global_step % self.log_configuration.log_rate == 0:
            self.tensorboard_manager.log_metrics(
                metrics=metrics, global_step=global_step
            )

        if (global_step % self.log_configuration.image_rate == 0) and (
            self.log_configuration.number_of_images > 0
        ):
            image_channels, image_height, image_width = self._image_shape
            # Sample pure Gaussian noise and run the full reverse process.
            images = torch.randn(
                (
                    self.log_configuration.number_of_images,
                    image_channels,
                    image_height,
                    image_width,
                ),
                device=self.device,
            )
            images = self.model.denoise(images)
            # `denoise` returns the intermediate images for every timestep;
            # iterate from the final, fully denoised step backwards.
            for step, step_images in enumerate(images[::-1]):
                self.tensorboard_manager.log_images(
                    tag=f"Images at step {global_step}",
                    images=self.reverse_transforms(step_images),
                    timestep=step,
                )
        # Restore training mode after sampling.
        self.model.train()
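

# Minimal usage sketch (illustrative only): `SimpleUNet`, the optimizer
# hyperparameters, and the configuration values below are hypothetical
# stand-ins; only the DiffusionTrainer interface above is real, and
# TrainingConfiguration is assumed to accept the fields the trainer reads.
#
#     from torchvision import datasets, transforms
#
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(0.5, 0.5),  # scale images to [-1, 1]
#     ])
#     dataset = datasets.MNIST("data", download=True, transform=transform)
#
#     model = SimpleUNet()  # any BaseDiffusionModel subclass
#     trainer = DiffusionTrainer(
#         model=model,
#         dataset=dataset,
#         optimizer=torch.optim.Adam(model.parameters(), lr=2e-4),
#         training_configuration=TrainingConfiguration(
#             training_name="mnist_ddpm",
#             batch_size=128,
#             number_of_epochs=50,
#             checkpoint_rate=10,
#             mixed_precision_training=True,
#             gradient_clip=1.0,
#         ),
#         reverse_transforms=lambda x: (x + 1) / 2,  # undo the normalization
#     )
#     trainer.train()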