# Source code for diffusion_models.diffusion_trainer

import pathlib
from typing import Callable, Dict, Optional

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
from diffusion_models.utils.schemas import BetaSchedulerConfiguration, \
  Checkpoint, LogConfiguration, TrainingConfiguration
from diffusion_models.utils.tensorboard import TensorboardManager
 13
 14
class DiffusionTrainer:
  """A simplified framework for training a diffusion model."""

  def __init__(
    self,
    model: BaseDiffusionModel,
    dataset: Dataset,
    optimizer: torch.optim.Optimizer,
    training_configuration: TrainingConfiguration,
    loss_function: Callable = F.l1_loss,
    log_configuration: Optional[LogConfiguration] = None,
    reverse_transforms: Callable = lambda x: x,
    device: str = "cuda",
  ):
    """A diffusion trainer framework.

    This is a simplified framework for training a diffusion model.

    Args:
      model: A diffusion model.
      dataset: A dataset to train on.
      optimizer: The optimizer to use.
      training_configuration: The training configuration to use.
      loss_function: The loss function to use.
      log_configuration: The logging configuration to use. When ``None``
        (the default) a fresh ``LogConfiguration()`` is created per
        instance, avoiding a shared mutable default argument.
      reverse_transforms: The reverse transforms to use.
      device: The device to use.
    """
    self.model = model.to(device)
    """The diffusion model to use."""
    self.optimizer = optimizer
    """The optimizer to use."""
    self.loss_function = loss_function
    """The loss function to use."""
    self.training_configuration = training_configuration
    """The training configuration to use."""
    self.device = device
    """The device to use."""

    self.dataloader = DataLoader(
      dataset=dataset,
      batch_size=training_configuration.batch_size,
      shuffle=True,
      drop_last=True,
      num_workers=16,
      pin_memory=True,
      persistent_workers=True,
    )
    """A torch dataloader."""

    # Remember the per-sample (channels, height, width) shape so that
    # log_to_tensorboard can draw starting noise of the right shape.
    # Assumes dataset items are (image, label) pairs — see train().
    self._image_shape = dataset[0][0].shape

    self.scaler = torch.amp.GradScaler(device=device)
    """A torch GradScaler object."""

    self.log_configuration = (
      log_configuration if log_configuration is not None else LogConfiguration()
    )
    """A LogConfiguration object."""

    self.checkpoint_path = (
      pathlib.Path("../checkpoints")
      / self.training_configuration.training_name
    )
    """The path to save checkpoints."""

    # parents=True also creates "../checkpoints" itself when missing,
    # instead of raising FileNotFoundError on a fresh workspace.
    self.checkpoint_path.mkdir(parents=True, exist_ok=True)

    self.tensorboard_manager = TensorboardManager(
      log_name=self.training_configuration.training_name,
    )
    """A tensorboard manager instance."""

    self.reverse_transforms = reverse_transforms
    """A set of reverse transforms."""

    # Input shapes are fixed (drop_last=True keeps batches uniform), so
    # letting cuDNN benchmark kernel choices is a free speed-up.
    torch.backends.cudnn.benchmark = True

  def save_checkpoint(self, epoch: int, checkpoint_name: str):
    """Save a checkpoint.

    Args:
      epoch: The current epoch.
      checkpoint_name: The file name of the checkpoint, created inside
        ``self.checkpoint_path``.
    """
    checkpoint = Checkpoint(
      epoch=epoch,
      model_state_dict=self.model.state_dict(),
      optimizer_state_dict=self.optimizer.state_dict(),
      # The GradScaler state is only meaningful with AMP enabled.
      scaler=self.scaler.state_dict()
      if self.training_configuration.mixed_precision_training
      else None,
      image_channels=self._image_shape[0],
      beta_scheduler_config=BetaSchedulerConfiguration(
        steps=self.model.diffuser.beta_scheduler.steps,
        betas=self.model.diffuser.beta_scheduler.betas,
        alpha_bars=self.model.diffuser.beta_scheduler.alpha_bars,
      ),
      tensorboard_run_name=self.tensorboard_manager.summary_writer.log_dir,
    )
    checkpoint.to_file(self.checkpoint_path / checkpoint_name)

  def train(self):
    """Start the diffusion training."""
    self.model.train()
    for epoch in range(self.training_configuration.number_of_epochs):
      for step, batch in enumerate(
        tqdm(self.dataloader, desc=f"Epoch={epoch}")
      ):
        global_step = epoch * len(self.dataloader) + step

        images, _ = batch
        images = images.to(self.device)

        noisy_images, noise, timesteps = self.model.diffuse(images=images)

        self.optimizer.zero_grad(set_to_none=True)

        with torch.autocast(
          device_type=self.device,
          enabled=self.training_configuration.mixed_precision_training,
        ):
          prediction = self.model(noisy_images, timesteps)
          # Conventional (input, target) argument order; for the default
          # L1 loss the value is identical either way.
          loss = self.loss_function(prediction, noise)

        self.scaler.scale(loss).backward()

        if self.training_configuration.gradient_clip is not None:
          # Unscale the gradients in-place first so clipping operates on
          # true gradient norms rather than scaled ones.
          self.scaler.unscale_(self.optimizer)
          torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.training_configuration.gradient_clip,
          )

        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.log_to_tensorboard(
          metrics={"Loss": loss},
          global_step=global_step,
        )
      if epoch % self.training_configuration.checkpoint_rate == 0:
        self.save_checkpoint(epoch=epoch, checkpoint_name=f"epoch_{epoch}.pt")
    self.save_checkpoint(
      epoch=self.training_configuration.number_of_epochs,
      checkpoint_name="final.pt",
    )

  @torch.no_grad()
  def log_to_tensorboard(self, metrics: Dict[str, float], global_step: int):
    """Log to tensorboard.

    This method logs some useful metrics and visualizations to tensorboard.

    Args:
      metrics: A dictionary mapping metric names to values.
      global_step: The current global step.
    """
    if global_step % self.log_configuration.log_rate == 0:
      self.tensorboard_manager.log_metrics(
        metrics=metrics, global_step=global_step
      )

    if (global_step % self.log_configuration.image_rate == 0) and (
      self.log_configuration.number_of_images > 0
    ):
      # BUGFIX: only switch to eval mode for the sampling pass and restore
      # train mode afterwards. Previously eval() was called on every
      # invocation and never undone, so after the first logging step the
      # model trained permanently in eval mode (dropout disabled,
      # batch-norm statistics frozen).
      self.model.eval()
      try:
        image_channels, image_height, image_width = self._image_shape
        noise = torch.randn(
          (
            self.log_configuration.number_of_images,
            image_channels,
            image_height,
            image_width,
          ),
          device=self.device,
        )
        # denoise() presumably returns the per-timestep image sequence;
        # iterate it newest-first — TODO confirm against BaseDiffusionModel.
        denoising_steps = self.model.denoise(noise)
        for timestep, step_images in enumerate(denoising_steps[::-1]):
          self.tensorboard_manager.log_images(
            tag=f"Images at timestep {global_step}",
            images=self.reverse_transforms(step_images),
            timestep=timestep,
          )
      finally:
        self.model.train()