Source code for diffusion_models.diffusion_trainer

import pathlib
from typing import Callable
from typing import Dict
from typing import Optional

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
from diffusion_models.utils.schemas import BetaSchedulerConfiguration
from diffusion_models.utils.schemas import Checkpoint
from diffusion_models.utils.schemas import LogConfiguration
from diffusion_models.utils.schemas import TrainingConfiguration
from diffusion_models.utils.tensorboard import TensorboardManager


class DiffusionTrainer:
    def __init__(
        self,
        model: BaseDiffusionModel,
        dataset: Dataset,
        optimizer: torch.optim.Optimizer,
        training_configuration: TrainingConfiguration,
        loss_function: Callable = F.l1_loss,
        scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
        log_configuration: LogConfiguration = LogConfiguration(),
        reverse_transforms: Callable = lambda x: x,
        device: str = "cuda",
    ):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.training_configuration = training_configuration
        self.scheduler = scheduler
        self.device = device

        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=training_configuration.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=16,
            pin_memory=True,
            persistent_workers=True,
        )

        # Infer the per-sample image shape (C, H, W) from the first dataset item.
        self._image_shape = dataset[0][0].shape

        # The scaler becomes a no-op passthrough when mixed precision is disabled.
        self.scaler = torch.amp.GradScaler(
            device=device,
            enabled=training_configuration.mixed_precision_training,
            # init_scale=8192,
        )

        self.log_configuration = log_configuration

        self.checkpoint_path = (
            pathlib.Path("../checkpoints")
            / self.training_configuration.training_name
        )

        # parents=True so the parent checkpoints directory is created if missing.
        self.checkpoint_path.mkdir(parents=True, exist_ok=True)
        self.tensorboard_manager = TensorboardManager(
            log_name=self.training_configuration.training_name,
        )

        self.reverse_transforms = reverse_transforms

        # Let cuDNN benchmark convolution algorithms; input shapes are fixed here.
        torch.backends.cudnn.benchmark = True

    def save_checkpoint(self, epoch: int, checkpoint_name: str):
        checkpoint = Checkpoint(
            epoch=epoch,
            model_state_dict=self.model.state_dict(),
            optimizer_state_dict=self.optimizer.state_dict(),
            scaler=self.scaler.state_dict()
            if self.training_configuration.mixed_precision_training
            else None,
            image_channels=self._image_shape[0],
            beta_scheduler_config=BetaSchedulerConfiguration(
                steps=self.model.diffuser.beta_scheduler.steps,
                betas=self.model.diffuser.beta_scheduler.betas,
                alpha_bars=self.model.diffuser.beta_scheduler.alpha_bars,
            ),
            tensorboard_run_name=self.tensorboard_manager.summary_writer.log_dir,
        )
        checkpoint.to_file(self.checkpoint_path / checkpoint_name)

    def train(self):
        self.model.train()
        for epoch in range(self.training_configuration.number_of_epochs):
            for step, batch in enumerate(
                tqdm(self.dataloader, desc=f"Epoch={epoch}")
            ):
                global_step = epoch * len(self.dataloader) + step

                images, _ = batch
                images = images.to(self.device)

                noisy_images, noise, timesteps = self.model.diffuse(images=images)

                self.optimizer.zero_grad(set_to_none=True)

                with torch.autocast(
                    device_type=self.device,
                    enabled=self.training_configuration.mixed_precision_training,
                ):
                    prediction = self.model(noisy_images, timesteps)
                    # F.l1_loss and related losses expect (input, target).
                    loss = self.loss_function(prediction, noise)

                self.scaler.scale(loss).backward()

                if self.training_configuration.gradient_clip is not None:
                    # Unscale the optimizer's gradients in-place so clipping
                    # operates on the true gradient magnitudes.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.training_configuration.gradient_clip,
                    )

                self.scaler.step(self.optimizer)
                self.scaler.update()

                self.log_to_tensorboard(
                    metrics={
                        "Loss": loss,
                    },
                    global_step=global_step,
                )

            # Advance the learning-rate scheduler once per epoch, if provided.
            if self.scheduler is not None:
                self.scheduler.step()

            if epoch % self.training_configuration.checkpoint_rate == 0:
                self.save_checkpoint(epoch=epoch, checkpoint_name=f"epoch_{epoch}.pt")
        self.save_checkpoint(
            epoch=self.training_configuration.number_of_epochs,
            checkpoint_name="final.pt",
        )

    @torch.no_grad()
    def log_to_tensorboard(self, metrics: Dict[str, float], global_step: int):
        self.model.eval()
        if global_step % self.log_configuration.log_rate == 0:
            self.tensorboard_manager.log_metrics(
                metrics=metrics, global_step=global_step
            )

        if (global_step % self.log_configuration.image_rate == 0) and (
            self.log_configuration.number_of_images > 0
        ):
            image_channels, image_height, image_width = self._image_shape
            # Start from pure noise and record the intermediate denoising steps.
            images = torch.randn(
                (
                    self.log_configuration.number_of_images,
                    image_channels,
                    image_height,
                    image_width,
                ),
                device=self.device,
            )
            denoising_steps = self.model.denoise(images)
            for step, step_images in enumerate(denoising_steps[::-1]):
                self.tensorboard_manager.log_images(
                    tag=f"Images at global step {global_step}",
                    images=self.reverse_transforms(step_images),
                    timestep=step,
                )

        # Restore training mode so evaluation mode does not leak into the
        # remainder of the training loop.
        self.model.train()
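
The training step in `train()` follows PyTorch's standard mixed-precision recipe: run the forward pass under `autocast`, scale the loss before `backward()`, unscale before gradient clipping, then let the scaler drive the optimizer step. Below is a minimal, self-contained sketch of that ordering with a toy model; nothing in it comes from the diffusion_models package, and the model, data, and hyperparameters are placeholders for illustration only.

import torch
from torch import nn
from torch.nn import functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # scaler and autocast become no-ops on CPU

model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler(device=device, enabled=use_amp)

inputs = torch.randn(32, 16, device=device)
targets = torch.randn(32, 16, device=device)

for _ in range(10):
    optimizer.zero_grad(set_to_none=True)

    # 1. Forward pass (and loss) under autocast.
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = F.l1_loss(model(inputs), targets)

    # 2. Scale the loss so small fp16 gradients do not underflow in backward.
    scaler.scale(loss).backward()

    # 3. Unscale in-place before clipping, so max_norm applies to the true
    #    gradient magnitudes rather than the scaled ones.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # 4. scaler.step() skips the update if gradients overflowed;
    #    scaler.update() then adjusts the scale factor for the next step.
    scaler.step(optimizer)
    scaler.update()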