Source code for diffusion_models.gaussian_diffusion.ddimm_diffuser

  1from enum import Enum
  2from typing import List, Tuple
  3
  4import numpy as np
  5import torch
  6from tqdm import tqdm
  7
  8from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
  9from diffusion_models.gaussian_diffusion.beta_schedulers import \
 10  BaseBetaScheduler
 11from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
 12from diffusion_models.utils.schemas import Checkpoint
 13
class DenoisingMode(str, Enum):
    """How ``DdimDiffuser.denoise_batch`` spaces the timesteps it visits.

    Because the enum also derives from ``str``, each member compares equal to
    its lowercase string value (e.g. ``DenoisingMode.Linear == "linear"``).
    """

    # Timesteps taken at a fixed integer stride through the schedule.
    Linear = "linear"
    # Timesteps taken from a squared linspace (denser at small t).
    Quadratic = "quadratic"
17
class DdimDiffuser(BaseDiffuser):
    """Gaussian diffuser with a DDIM (Denoising Diffusion Implicit Models) sampler.

    The forward process noises images in closed form using the beta scheduler's
    cumulative ``alpha_bars``. The reverse process applies the DDIM update,
    which can skip timesteps and is deterministic when ``eta == 0``.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the class instance.

        Args:
            beta_scheduler (BaseBetaScheduler): The beta scheduler instance to be used.

        """
        super().__init__(beta_scheduler)
        self.device: str = "cpu"
        """The device to use. Defaults to cpu."""

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "DdimDiffuser":
        """Instantiate a DDIM Diffuser from a training checkpoint.

        The returned diffuser is on the default ``"cpu"`` device; call
        :meth:`to` to move it elsewhere.

        Args:
            checkpoint: The training checkpoint object containing
                the trained model's parameters and configuration.

        Returns:
            An instance of the DdimDiffuser class initialized with the parameters
            loaded from the given checkpoint.
        """
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=checkpoint.beta_scheduler_config.steps,
                betas=checkpoint.beta_scheduler_config.betas,
                alpha_bars=checkpoint.beta_scheduler_config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "DdimDiffuser":
        """Moves the diffuser and its beta scheduler to the specified device.

        This mirrors the behaviour of PyTorch's ``to`` method: the diffuser's
        ``device`` attribute is updated and the BetaScheduler is moved as well.

        Args:
            device: The device to which the object should be moved.
                Default is "cpu".

        Returns:
            This same instance, to allow chaining.

        Example:
            >>> ddim_diffuser = DdimDiffuser(beta_scheduler)
            >>> ddim_diffuser = ddim_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    def _gather_alpha_bars(self, timesteps: torch.Tensor, ndim: int) -> torch.Tensor:
        """Look up ``alpha_bar`` per sample, shaped to broadcast against a batch.

        Args:
            timesteps: One timestep index per sample. This must be an int64
                tensor, as ``torch.gather`` requires.
            ndim: Rank of the batch the result must broadcast against
                (e.g. 4 for ``(B, C, H, W)``).

        Returns:
            Tensor of shape ``(B, 1, ..., 1)`` with ``ndim`` dimensions.
        """
        alpha_bars = self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps)
        return alpha_bars.reshape((-1, *((1,) * (ndim - 1))))

    def _diffuse_batch(
        self, images: torch.Tensor, timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Noise ``images`` to the given ``timesteps`` in closed form.

        Computes ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``
        with ``noise ~ N(0, I)``.

        Returns:
            The noised images, the noise that was added, and ``timesteps``.
        """
        noise = torch.randn_like(images, device=self.device)
        alpha_bar_t = self._gather_alpha_bars(timesteps, images.ndim)

        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Diffuse the given batch of images by adding noise based on the beta
        scheduler. Each image gets its own timestep drawn uniformly from
        ``[0, beta_scheduler.steps)``.

        Args:
            images: Batch of images to diffuse.\n
                Shape should be ``(B, C, H, W)``.

        Returns:
            A tuple containing three tensors

                - images: Diffused batch of images.
                - noise: Noise added to the images.
                - timesteps: Timesteps used for diffusion.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        return self._diffuse_batch(images, timesteps)

    @torch.no_grad()
    def _denoise_step(
        self,
        images: torch.Tensor,
        model: torch.nn.Module,
        timestep: torch.Tensor,
        timestep_prev: torch.Tensor,
        eta: float = 0.0,
    ) -> torch.Tensor:
        """Apply one DDIM update from ``timestep`` to ``timestep_prev``.

        Let ``a_t`` and ``a_prev`` be the cumulative alphas at the two
        timesteps, and let ``x0_hat = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)``.
        The update is::

            x_prev = sqrt(a_prev) * x0_hat + sqrt(1 - a_prev - sigma^2) * eps + sigma * z

        The code below uses this formula expanded algebraically.

        Args:
            images: Current noisy batch ``x_t``.
            model: Noise-prediction network, called as ``model(images, timestep)``.
            timestep: Per-sample current timestep indices (int64).
            timestep_prev: Per-sample target timestep indices (int64).
            eta: Stochasticity; 0 gives deterministic DDIM. Values in
                ``[0, 1]`` keep ``1 - a_prev - sigma^2`` non-negative.

        Returns:
            The batch ``x_prev`` at ``timestep_prev``.
        """
        epsilon_theta = model(images, timestep)

        alpha_bar_t = self._gather_alpha_bars(timestep, images.ndim)
        alpha_bar_t_prev = self._gather_alpha_bars(timestep_prev, images.ndim)

        sigma = eta * torch.sqrt(
            (1 - alpha_bar_t_prev)
            / (1 - alpha_bar_t)
            * (1 - alpha_bar_t / alpha_bar_t_prev)
        )

        epsilon_t = torch.randn_like(images)

        mu = (
            torch.sqrt(alpha_bar_t_prev / alpha_bar_t) * images
            + (
                torch.sqrt(1 - alpha_bar_t_prev - sigma**2)
                - torch.sqrt((alpha_bar_t_prev * (1 - alpha_bar_t)) / alpha_bar_t)
            ) * epsilon_theta
        )
        return mu + sigma * epsilon_t

    def _get_time_steps(
        self, number_of_steps: int, mode: DenoisingMode
    ) -> np.ndarray:
        """Choose the sub-sequence of training timesteps visited while sampling.

        Args:
            number_of_steps: How many denoising updates to perform.
            mode: Linear (fixed stride) or Quadratic (squared linspace) spacing.

        Returns:
            Non-decreasing integer array of length ``number_of_steps``. Every
            entry is a valid index into ``beta_scheduler.alpha_bars``.

        Raises:
            ValueError: If ``number_of_steps`` is less than 1, or if it exceeds
                ``beta_scheduler.steps`` in linear mode.
        """
        steps = self.beta_scheduler.steps
        if number_of_steps < 1:
            raise ValueError(
                f"number_of_steps must be at least 1, got {number_of_steps}."
            )

        if mode == DenoisingMode.Linear:
            if number_of_steps > steps:
                raise ValueError(
                    f"number_of_steps ({number_of_steps}) cannot exceed the "
                    f"scheduler's {steps} steps in linear mode."
                )
            stride = steps // number_of_steps
            # Exactly `number_of_steps` multiples of `stride`, all < steps.
            # `range(0, steps, stride)` can yield more entries than requested.
            time_steps = np.arange(number_of_steps) * stride
        else:
            # Evenly spaced in sqrt(t), so sampled timesteps are denser near
            # t = 0 and reach at most 0.8 * steps.
            time_steps = (
                np.linspace(0, np.sqrt(steps * 0.8), number_of_steps) ** 2
            ).astype(int)

        # Shift by one so the final update goes to timestep 0. Clip so an index
        # can never run past the end of alpha_bars; without the clip,
        # number_of_steps == steps in linear mode yields index `steps`.
        max_index = self.beta_scheduler.alpha_bars.shape[0] - 1
        return np.minimum(time_steps + 1, max_index)

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
        number_of_steps: int = 50,
        mode: DenoisingMode = DenoisingMode.Linear,
        eta: float = 0.0,
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        This denoises a batch of images. This is the image generation process.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.
            number_of_steps: Number of steps used in the denoising process.
                Must be at least 1. In linear mode it must not exceed
                ``beta_scheduler.steps``.
            mode: Linear or Quadratic sampling.
            eta: DDIM stochasticity passed to every step; 0 (default) is
                deterministic DDIM.

        Returns:
            A list of tensors containing a batch of denoised images. There is
            one entry per step, ordered from the noisiest to the final result.

        Raises:
            ValueError: If ``number_of_steps`` is out of range (see above).
        """
        time_steps = self._get_time_steps(number_of_steps, mode)
        time_steps_prev = np.concatenate([[0], time_steps[:-1]])

        batch_size = images.shape[0]
        denoised_images = []
        for i in tqdm(range(number_of_steps)[::-1], desc="Denoising"):
            # gather() requires int64 indices; build them explicitly.
            timestep = torch.full(
                (batch_size,), int(time_steps[i]), dtype=torch.long, device=self.device
            )
            timestep_prev = torch.full(
                (batch_size,),
                int(time_steps_prev[i]),
                dtype=torch.long,
                device=self.device,
            )

            images = self._denoise_step(
                images,
                model=model,
                timestep=timestep,
                timestep_prev=timestep_prev,
                eta=eta,
            )
            # NOTE(review): this clamps the intermediate x_t itself, not the
            # predicted x_0. The forward process above (Gaussian noise added to
            # scaled images) is not confined to [-1, 1], so confirm this is intended.
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images