1from enum import Enum
  2from typing import List
  3from typing import Tuple
  4
  5import numpy as np
  6import torch
  7from tqdm import tqdm
  8
  9from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
 10from diffusion_models.gaussian_diffusion.beta_schedulers import (
 11  BaseBetaScheduler,
 12)
 13from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
 14from diffusion_models.utils.schemas import Checkpoint, Timestep
 15
 16
[docs]
 17class DenoisingMode(str, Enum):
[docs]
 19  Quadratic = "quadratic" 
 
 20
 21
[docs]
 22class DdimDiffuser(BaseDiffuser):
 23  def __init__(
 24    self,
 25    beta_scheduler: BaseBetaScheduler,
 26    mode: DenoisingMode = DenoisingMode.Quadratic,
 27    number_of_steps: int = 20,
 28  ):
 29    """Initializes the class instance.
 30
 31    Args:
 32      beta_scheduler (BaseBetaScheduler): The beta scheduler instance to be used.
 33
 34    """
 35    super().__init__(beta_scheduler)
 36
[docs]
 37    self.number_of_steps = number_of_steps 
 38    """Number of steps to use in the denoising process."""
 39
 41    """Linear or Quadratic sampling."""
 42
[docs]
 43    self.device: str = "cpu" 
 44    """The device to use. Defaults to cpu."""
 45
 46  @property
[docs]
 47  def steps(self) -> List[int]:
 48    if self.mode == DenoisingMode.Linear:
 49      a = self.beta_scheduler.steps // self.number_of_steps
 50      time_steps = np.asarray(list(range(0, self.beta_scheduler.steps, a)))
 51    else:
 52      time_steps = (
 53        np.linspace(
 54          0, np.sqrt(self.beta_scheduler.steps * 0.8), self.number_of_steps
 55        )
 56        ** 2
 57      ).astype(int)
 58    self._time_steps = time_steps + 1
 59    self._time_steps_prev = np.concatenate([[0], time_steps[:-1]])
 60    return list(range(self.number_of_steps))[::-1] 
 61
[docs]
 62  def get_timestep(self, number_of_images: int, idx: int) -> Timestep:
 63    timestep = torch.full(
 64      (number_of_images,), self._time_steps[idx], device=self.device
 65    )
 66    timestep_prev = torch.full(
 67      (number_of_images,), self._time_steps_prev[idx], device=self.device
 68    )
 69    return Timestep(
 70      current=timestep,
 71      previous=timestep_prev,
 72    ) 
 73
 74  @classmethod
[docs]
 75  def from_checkpoint(cls, checkpoint: Checkpoint) -> "DdimDiffuser":
 76    """Instantiate a DDIM Diffuser from a training checkpoint.
 77
 78    Args:
 79      checkpoint: The training checkpoint object containing
 80        the trained model's parameters and configuration.
 81
 82    Returns:
 83      An instance of the DdimDiffuser class initialized with the parameters
 84      loaded from the given checkpoint.
 85    """
 86    return cls(
 87      beta_scheduler=BaseBetaScheduler.from_tensors(
 88        steps=checkpoint.beta_scheduler_config.steps,
 89        betas=checkpoint.beta_scheduler_config.betas,
 90        alpha_bars=checkpoint.beta_scheduler_config.alpha_bars,
 91      )
 92    ) 
 93
[docs]
 94  def to(self, device: str = "cpu"):
 95    """Moves the data to the specified device.
 96
 97    This performs a similar behaviour to the `to` method of PyTorch. moving the
 98    GaussianDiffuser and the BetaScheduler to the specified device.
 99
100    Args:
101      device: The device to which the method should move the object.
102        Default is "cpu".
103
104    Example:
105      >>> ddim_diffuser = DdimDiffuser()
106      >>> ddim_diffuser = ddim_diffuser.to(device="cuda")
107    """
108    self.device = device
109    self.beta_scheduler = self.beta_scheduler.to(self.device)
110    return self 
111
112  def _diffuse_batch(
113    self, images: torch.Tensor, timesteps: torch.Tensor
114  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
115    noise = torch.randn_like(images, device=self.device)
116
117    alpha_bar_t = self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps)
118
119    alpha_bar_t = alpha_bar_t.reshape((-1, *((1,) * (len(images.shape) - 1))))
120
121    mu = torch.sqrt(alpha_bar_t)
122    sigma = torch.sqrt(1 - alpha_bar_t)
123    images = mu * images + sigma * noise
124    return images, noise, timesteps
125
[docs]
126  def diffuse_batch(
127    self, images: torch.Tensor
128  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
129    """Diffuse a batch of images.
130
131    Diffuse the given batch of images by adding noise based on the beta scheduler.
132
133    Args:
134      images: Batch of images to diffuse.\n
135        Shape should be ``(B, C, H, W)``.
136
137    Returns:
138      A tuple containing three tensors
139
140        - images: Diffused batch of images.
141        - noise: Noise added to the images.
142        - timesteps: Timesteps used for diffusion.
143    """
144    timesteps = torch.randint(
145      0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
146    )
147    return self._diffuse_batch(images, timesteps) 
148
149  @torch.no_grad()
150  def _denoise_step(
151    self,
152    images: torch.Tensor,
153    model: torch.nn.Module,
154    timestep: Timestep,
155    eta: float = 0.0,
156  ) -> torch.Tensor:
157    current_timestep = timestep.current
158    previous_timestep = timestep.previous
159
160    epsilon_theta = model(images, current_timestep)
161
162    alpha_bar_t = self.beta_scheduler.alpha_bars.gather(
163      dim=0, index=current_timestep
164    ).reshape(-1, 1, 1, 1)
165
166    alpha_bar_t_prev = self.beta_scheduler.alpha_bars.gather(
167      dim=0, index=previous_timestep
168    ).reshape(-1, 1, 1, 1)
169
170    sigma = eta * torch.sqrt(
171      (1 - alpha_bar_t_prev)
172      / (1 - alpha_bar_t)
173      * (1 - alpha_bar_t / alpha_bar_t_prev)
174    )
175
176    epsilon_t = torch.randn_like(images)
177
178    mu = (
179      torch.sqrt(alpha_bar_t_prev / alpha_bar_t) * images
180      + (
181        torch.sqrt(1 - alpha_bar_t_prev - sigma**2)
182        - torch.sqrt((alpha_bar_t_prev * (1 - alpha_bar_t)) / alpha_bar_t)
183      )
184      * epsilon_theta
185    )
186    return mu + sigma * epsilon_t
187
[docs]
188  def denoise_batch(
189    self,
190    images: torch.Tensor,
191    model: "BaseDiffusionModel",
192  ) -> List[torch.Tensor]:
193    """Denoise a batch of images.
194
195    This denoises a batch images. This is the image generation process.
196
197    Args:
198      images: A batch of noisy images.
199      model: The model to be used for denoising.
200
201    Returns:
202      A list of tensors containing a batch of denoised images.
203    """
204    denoised_images = []
205    for i in tqdm(self.steps, desc="Denoising"):
206      timestep = self.get_timestep(images.shape[0], idx=i)
207
208      images = self._denoise_step(
209        images,
210        model=model,
211        timestep=timestep,
212      )
213      images = torch.clamp(images, -1.0, 1.0)
214      denoised_images.append(images)
215    return denoised_images