Source code for diffusion_models.diffusion_trainer

import pathlib
from typing import Callable, Dict

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
from diffusion_models.utils.schemas import (
  BetaSchedulerConfiguration,
  Checkpoint,
  LogConfiguration,
  TrainingConfiguration,
)
from diffusion_models.utils.tensorboard import TensorboardManager


class DiffusionTrainer:
  def __init__(
    self,
    model: BaseDiffusionModel,
    dataset: Dataset,
    optimizer: torch.optim.Optimizer,
    training_configuration: TrainingConfiguration,
    loss_function: Callable = F.l1_loss,
    # scheduler: Optional[torch.optim.lr_scheduler.StepLR] = None,
    log_configuration: LogConfiguration = LogConfiguration(),
    reverse_transforms: Callable = lambda x: x,
    device: str = "cuda",
  ):
    """A diffusion trainer framework.

    This is a simplified framework for training diffusion models.

    Args:
      model: A diffusion model.
      dataset: A dataset to train on, yielding (image, label) pairs.
      optimizer: The optimizer to use.
      training_configuration: The training configuration to use.
      loss_function: The loss function comparing the noise prediction to
        the true noise.
      log_configuration: The logging configuration to use.
      reverse_transforms: Transforms mapping model-space tensors back to
        image space for logging.
      device: The device to train on.
    """
    self.model = model.to(device)
    """The diffusion model to use."""
    self.optimizer = optimizer
    """The optimizer to use."""
    self.loss_function = loss_function
    """The loss function to use."""
    self.training_configuration = training_configuration
    """The training configuration to use."""
    # self.scheduler = scheduler
    self.device = device
    """The device to use."""

    self.dataloader = DataLoader(
      dataset=dataset,
      batch_size=training_configuration.batch_size,
      shuffle=True,
      drop_last=True,
      num_workers=16,  # tune to the available CPU cores
      pin_memory=True,
      persistent_workers=True,
    )
    """A torch dataloader."""

    self._image_shape = dataset[0][0].shape

    self.scaler = torch.amp.GradScaler(
      device=device,
      # init_scale=8192,
      # Disabled when mixed precision is off, so scale/unscale/step
      # become pass-throughs.
      enabled=training_configuration.mixed_precision_training,
    )
    """A torch GradScaler object."""

    self.log_configuration = log_configuration
    """A LogConfiguration object."""

    self.checkpoint_path = (
      pathlib.Path("../checkpoints")
      / self.training_configuration.training_name
    )
    """The path to save checkpoints."""

    # parents=True so a missing ../checkpoints directory is created too.
    self.checkpoint_path.mkdir(parents=True, exist_ok=True)
    self.tensorboard_manager = TensorboardManager(
      log_name=self.training_configuration.training_name,
    )
    """A tensorboard manager instance."""

    self.reverse_transforms = reverse_transforms
    """A set of reverse transforms."""

    # Let cuDNN benchmark convolution algorithms for the fixed input shape.
    torch.backends.cudnn.benchmark = True

  def save_checkpoint(self, epoch: int, checkpoint_name: str):
    """Save a checkpoint.

    Args:
      epoch: The current epoch.
      checkpoint_name: The name of the checkpoint.
    """
    checkpoint = Checkpoint(
      epoch=epoch,
      model_state_dict=self.model.state_dict(),
      optimizer_state_dict=self.optimizer.state_dict(),
      scaler=self.scaler.state_dict()
      if self.training_configuration.mixed_precision_training
      else None,
      image_channels=self._image_shape[0],
      beta_scheduler_config=BetaSchedulerConfiguration(
        steps=self.model.diffuser.beta_scheduler.steps,
        betas=self.model.diffuser.beta_scheduler.betas,
        alpha_bars=self.model.diffuser.beta_scheduler.alpha_bars,
      ),
      tensorboard_run_name=self.tensorboard_manager.summary_writer.log_dir,
    )
    checkpoint.to_file(self.checkpoint_path / checkpoint_name)

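  # Resuming is the mirror image of `save_checkpoint`. A hypothetical sketch,
  # assuming `Checkpoint` exposes a `from_file` counterpart to `to_file`
  # (not shown in this module):
  #
  #   checkpoint = Checkpoint.from_file(trainer.checkpoint_path / "final.pt")
  #   trainer.model.load_state_dict(checkpoint.model_state_dict)
  #   trainer.optimizer.load_state_dict(checkpoint.optimizer_state_dict)
  #   if checkpoint.scaler is not None:
  #     trainer.scaler.load_state_dict(checkpoint.scaler)
  #   start_epoch = checkpoint.epoch + 1
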
  def train(self):
    """Start the diffusion training."""
    self.model.train()
    for epoch in range(self.training_configuration.number_of_epochs):
      for step, batch in enumerate(
        tqdm(self.dataloader, desc=f"Epoch={epoch}")
      ):
        global_step = epoch * len(self.dataloader) + step

        images, _ = batch
        images = images.to(self.device)

        noisy_images, noise, timesteps = self.model.diffuse(images=images)

        self.optimizer.zero_grad(set_to_none=True)

        with torch.autocast(
          device_type=self.device,
          enabled=self.training_configuration.mixed_precision_training,
        ):
          prediction = self.model(noisy_images, timesteps)
          # Conventional (input, target) argument order for torch losses.
          loss = self.loss_function(prediction, noise)

        self.scaler.scale(loss).backward()

        if self.training_configuration.gradient_clip is not None:
          # Unscale the optimizer's gradients in-place so clipping
          # operates on the true gradient magnitudes.
          self.scaler.unscale_(self.optimizer)
          torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.training_configuration.gradient_clip,
          )

        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.log_to_tensorboard(
          metrics={
            "Loss": loss,
          },
          global_step=global_step,
        )
      if epoch % self.training_configuration.checkpoint_rate == 0:
        self.save_checkpoint(epoch=epoch, checkpoint_name=f"epoch_{epoch}.pt")
    self.save_checkpoint(
      epoch=self.training_configuration.number_of_epochs,
      checkpoint_name="final.pt",
    )
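
  # `model.diffuse` is assumed to implement the standard DDPM forward
  # process, an inference drawn from the `betas`/`alpha_bars` attributes
  # referenced in `save_checkpoint`. A minimal sketch of that step, with
  # `alpha_bars` the cumulative products of (1 - beta):
  #
  #   timesteps = torch.randint(0, steps, (images.shape[0],), device=images.device)
  #   noise = torch.randn_like(images)
  #   a_bar = alpha_bars[timesteps].view(-1, 1, 1, 1)
  #   noisy_images = a_bar.sqrt() * images + (1 - a_bar).sqrt() * noise
  #   return noisy_images, noise, timesteps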

  @torch.no_grad()
  def log_to_tensorboard(self, metrics: Dict[str, float], global_step: int):
    """Log to tensorboard.

    This method logs some useful metrics and visualizations to tensorboard.

    Args:
      metrics: A dictionary mapping metric names to values.
      global_step: The current global step.
    """
    self.model.eval()
    if global_step % self.log_configuration.log_rate == 0:
      self.tensorboard_manager.log_metrics(
        metrics=metrics, global_step=global_step
      )

    if (global_step % self.log_configuration.image_rate == 0) and (
      self.log_configuration.number_of_images > 0
    ):
      image_channels, image_height, image_width = self._image_shape
      images = torch.randn(
        (
          self.log_configuration.number_of_images,
          image_channels,
          image_height,
          image_width,
        ),
        device=self.device,
      )
      # `denoise` returns the intermediate images of the reverse process.
      denoising_steps = self.model.denoise(images)
      for step, step_images in enumerate(denoising_steps[::-1]):
        self.tensorboard_manager.log_images(
          tag=f"Images at global step {global_step}",
          images=self.reverse_transforms(step_images),
          timestep=step,
        )
    # Restore training mode; `train` relies on train-time dropout/batch-norm.
    self.model.train()
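
# Putting it together: a hypothetical usage sketch (not part of the module).
# `MyDiffusionModel`, `my_dataset`, and the configuration values are
# illustrative assumptions; only the attribute names exercised above are
# relied on, and the dataset is expected to yield (image, label) pairs.
#
#   model = MyDiffusionModel(...)  # any BaseDiffusionModel subclass
#   trainer = DiffusionTrainer(
#     model=model,
#     dataset=my_dataset,
#     optimizer=torch.optim.AdamW(model.parameters(), lr=2e-4),
#     training_configuration=TrainingConfiguration(
#       training_name="mnist_ddpm",
#       batch_size=128,
#       number_of_epochs=100,
#       mixed_precision_training=True,
#       gradient_clip=1.0,
#       checkpoint_rate=10,
#     ),
#   )
#   trainer.train()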