Source code for diffusion_models.gaussian_diffusion.ddimm_diffuser

  1from enum import Enum
  2from typing import List
  3from typing import Tuple
  4
  5import numpy as np
  6import torch
  7from tqdm import tqdm
  8
  9from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
 10from diffusion_models.gaussian_diffusion.beta_schedulers import (
 11  BaseBetaScheduler,
 12)
 13from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
 14from diffusion_models.utils.schemas import Checkpoint, Timestep
 15
 16
class DenoisingMode(str, Enum):
    """Spacing of the timestep sub-sequence visited when DDIM sampling.

    Because the enum also derives from ``str``, every member compares equal
    to its string value (``DenoisingMode.Linear == "linear"``).
    """

    # Timesteps taken with a constant stride over the schedule.
    Linear = "linear"
    # Timesteps spaced quadratically (denser near t = 0).
    Quadratic = "quadratic"
20 21
class DdimDiffuser(BaseDiffuser):
    """Diffuser that samples with the DDIM update rule.

    Training-time noising (``diffuse_batch``) draws timesteps uniformly from the
    whole beta schedule, whereas generation (``denoise_batch``) only visits a
    sub-sequence of ``number_of_steps`` timesteps whose spacing is set by ``mode``.
    """

    def __init__(
        self,
        beta_scheduler: BaseBetaScheduler,
        mode: DenoisingMode = DenoisingMode.Quadratic,
        number_of_steps: int = 20,
    ):
        """Initializes the class instance.

        Args:
            beta_scheduler (BaseBetaScheduler): The beta scheduler instance to be used.
            mode (DenoisingMode): Spacing of the sampling timesteps. ``Linear`` uses
                a constant stride over the whole schedule; ``Quadratic`` spaces the
                timesteps quadratically over the first 80% of the schedule.
            number_of_steps (int): Number of denoising steps used during generation.
        """
        super().__init__(beta_scheduler)

        self.number_of_steps = number_of_steps
        """Number of steps to use in the denoising process."""

        self.mode = mode
        """Linear or Quadratic sampling."""

        self.device: str = "cpu"
        """The device to use. Defaults to cpu."""

    def _compute_time_steps(self) -> Tuple[np.ndarray, np.ndarray]:
        """Builds the (current, previous) timestep arrays of the sampling schedule.

        Returns:
            Two integer arrays indexed by sampling index: the timestep at which
            the model is evaluated, and the timestep the update steps to.

        Raises:
            ValueError: In linear mode, if ``number_of_steps`` is larger than the
                number of steps of the beta scheduler (the stride would be zero).
        """
        total_steps = self.beta_scheduler.steps
        if self.mode == DenoisingMode.Linear:
            stride = total_steps // self.number_of_steps
            if stride < 1:
                raise ValueError(
                    f"Linear mode requires number_of_steps ({self.number_of_steps}) "
                    f"to be at most the beta scheduler's steps ({total_steps})."
                )
            # The stride can yield more than number_of_steps values when it does
            # not divide total_steps evenly; only the first number_of_steps are used.
            time_steps = np.arange(0, total_steps, stride)[: self.number_of_steps]
        else:
            time_steps = (
                np.linspace(0, np.sqrt(total_steps * 0.8), self.number_of_steps) ** 2
            ).astype(int)
        current = time_steps + 1
        # Each update must land exactly on the timestep the next (less noisy)
        # step evaluates, so "previous" is the shifted *current* array; the
        # first step goes down to timestep 0.
        previous = np.concatenate([[0], current[:-1]])
        return current, previous

    @property
    def steps(self) -> List[int]:
        """Sampling indices to iterate over, from the noisiest to the cleanest step.

        Also caches the timestep schedule in ``_time_steps`` / ``_time_steps_prev``,
        which ``get_timestep`` reads.
        """
        self._time_steps, self._time_steps_prev = self._compute_time_steps()
        return list(range(self.number_of_steps))[::-1]

    def get_timestep(self, number_of_images: int, idx: int) -> Timestep:
        """Returns the current/previous timesteps of sampling index ``idx``.

        Args:
            number_of_images: Batch size; each timestep is repeated this many times.
            idx: Sampling index, as yielded by ``steps``.

        Returns:
            A ``Timestep`` whose ``current`` and ``previous`` fields are int64
            tensors of shape ``(number_of_images,)`` on ``self.device``.
        """
        # Compute the schedule on demand so this method does not depend on
        # the ``steps`` property having been read beforehand.
        if not hasattr(self, "_time_steps"):
            self._time_steps, self._time_steps_prev = self._compute_time_steps()
        timestep = torch.full(
            (number_of_images,),
            int(self._time_steps[idx]),
            dtype=torch.long,  # used as a ``gather`` index, which requires int64
            device=self.device,
        )
        timestep_prev = torch.full(
            (number_of_images,),
            int(self._time_steps_prev[idx]),
            dtype=torch.long,
            device=self.device,
        )
        return Timestep(
            current=timestep,
            previous=timestep_prev,
        )

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "DdimDiffuser":
        """Instantiate a DDIM Diffuser from a training checkpoint.

        Only the beta scheduler is restored from the checkpoint; ``mode`` and
        ``number_of_steps`` take their default values.

        Args:
            checkpoint: The training checkpoint object containing
                the trained model's parameters and configuration.

        Returns:
            An instance of the DdimDiffuser class initialized with the parameters
            loaded from the given checkpoint.
        """
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=checkpoint.beta_scheduler_config.steps,
                betas=checkpoint.beta_scheduler_config.betas,
                alpha_bars=checkpoint.beta_scheduler_config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu"):
        """Moves the data to the specified device.

        This behaves like PyTorch's ``to`` method, moving the DdimDiffuser and
        its BetaScheduler to the specified device.

        Args:
            device: The device to which the method should move the object.
                Default is "cpu".

        Returns:
            ``self``, to allow call chaining.

        Example:
            >>> ddim_diffuser = DdimDiffuser(beta_scheduler)
            >>> ddim_diffuser = ddim_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    def _diffuse_batch(
        self, images: torch.Tensor, timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Noises ``images`` to the given per-sample ``timesteps``.

        Computes ``sqrt(alpha_bar_t) * images + sqrt(1 - alpha_bar_t) * noise``
        with Gaussian ``noise``.

        Returns:
            The noisy images, the noise that was added, and ``timesteps``.
        """
        noise = torch.randn_like(images, device=self.device)

        alpha_bar_t = self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps)

        # Reshape to (B, 1, ..., 1) so it broadcasts over all non-batch dims.
        alpha_bar_t = alpha_bar_t.reshape((-1, *((1,) * (len(images.shape) - 1))))

        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Diffuse the given batch of images by adding noise based on the beta
        scheduler. One timestep per image is drawn uniformly from
        ``[0, beta_scheduler.steps)``.

        Args:
            images: Batch of images to diffuse.
                Shape should be ``(B, C, H, W)``.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        return self._diffuse_batch(images, timesteps)

    @torch.no_grad()
    def _denoise_step(
        self,
        images: torch.Tensor,
        model: torch.nn.Module,
        timestep: Timestep,
        eta: float = 0.0,
    ) -> torch.Tensor:
        """Performs one DDIM update from ``timestep.current`` to ``timestep.previous``.

        Args:
            images: Noisy images at ``timestep.current`` (batch first).
            model: Called as ``model(images, timesteps)``; its output is used as
                the noise prediction epsilon_theta.
            timestep: Current and previous timesteps of this update.
            eta: Scale of the stochastic term; ``0.0`` gives sigma = 0, i.e. a
                deterministic update.

        Returns:
            The images at ``timestep.previous``.
        """
        current_timestep = timestep.current
        previous_timestep = timestep.previous

        epsilon_theta = model(images, current_timestep)

        # (B, 1, ..., 1) so the per-sample scalars broadcast over any image rank,
        # matching ``_diffuse_batch``.
        broadcast_shape = (-1, *((1,) * (images.dim() - 1)))

        alpha_bar_t = self.beta_scheduler.alpha_bars.gather(
            dim=0, index=current_timestep
        ).reshape(broadcast_shape)

        alpha_bar_t_prev = self.beta_scheduler.alpha_bars.gather(
            dim=0, index=previous_timestep
        ).reshape(broadcast_shape)

        sigma = eta * torch.sqrt(
            (1 - alpha_bar_t_prev)
            / (1 - alpha_bar_t)
            * (1 - alpha_bar_t / alpha_bar_t_prev)
        )

        epsilon_t = torch.randn_like(images)

        mu = (
            torch.sqrt(alpha_bar_t_prev / alpha_bar_t) * images
            + (
                torch.sqrt(1 - alpha_bar_t_prev - sigma**2)
                - torch.sqrt((alpha_bar_t_prev * (1 - alpha_bar_t)) / alpha_bar_t)
            )
            * epsilon_theta
        )
        return mu + sigma * epsilon_t

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        This denoises a batch of images. This is the image generation process.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.

        Returns:
            One tensor per denoising step, each clamped to ``[-1, 1]``; the last
            entry is the final generated batch.
        """
        denoised_images = []
        for i in tqdm(self.steps, desc="Denoising"):
            timestep = self.get_timestep(images.shape[0], idx=i)

            images = self._denoise_step(
                images,
                model=model,
                timestep=timestep,
            )
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images