1from enum import Enum
2from typing import List
3from typing import Tuple
4
5import numpy as np
6import torch
7from tqdm import tqdm
8
9from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
10from diffusion_models.gaussian_diffusion.beta_schedulers import (
11 BaseBetaScheduler,
12)
13from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
14from diffusion_models.utils.schemas import Checkpoint, Timestep
15
16
[docs]
class DenoisingMode(str, Enum):
    """Spacing strategy for the sub-sampled denoising timesteps.

    The members are plain strings as well as enum members, so a mode compares
    equal to its string value (``DenoisingMode.Linear == "linear"``).
    """

    # Evenly spaced timesteps across the training schedule.
    Linear = "linear"
    # Quadratically spaced timesteps (denser near timestep 0).
    Quadratic = "quadratic"
20
21
[docs]
class DdimDiffuser(BaseDiffuser):
    """Diffuser that samples with a sub-sampled (DDIM-style) timestep schedule.

    Forward diffusion draws a random timestep from the full beta schedule,
    while denoising only visits ``number_of_steps`` timesteps selected
    according to ``mode``.
    """

    def __init__(
        self,
        beta_scheduler: BaseBetaScheduler,
        mode: DenoisingMode = DenoisingMode.Quadratic,
        number_of_steps: int = 20,
    ):
        """Initializes the class instance.

        Args:
            beta_scheduler: The beta scheduler instance to be used.
            mode: How the denoising timesteps are spaced over the schedule.
            number_of_steps: Number of timesteps visited while denoising.
        """
        super().__init__(beta_scheduler)

        self.number_of_steps = number_of_steps
        """Number of steps to use in the denoising process."""

        # Required by ``steps``; without it the first access raises
        # AttributeError.
        self.mode = mode
        """Linear or Quadratic sampling."""

        self.device: str = "cpu"
        """The device to use. Defaults to cpu."""

    @staticmethod
    def _match_rank(values: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Reshape per-sample ``values`` to ``(B, 1, ..., 1)`` matching ``reference``'s rank."""
        return values.reshape((-1, *((1,) * (reference.dim() - 1))))

    def _update_time_steps(self) -> None:
        """Recompute ``_time_steps`` / ``_time_steps_prev`` from the current settings."""
        total_steps = self.beta_scheduler.steps
        if self.mode == DenoisingMode.Linear:
            stride = total_steps // self.number_of_steps
            if stride == 0:
                # Same exception type ``range`` would raise for a zero step,
                # but with a message that names the actual cause.
                raise ValueError(
                    "number_of_steps must not exceed the number of steps of the "
                    "beta scheduler when using linear sampling."
                )
            time_steps = np.arange(0, total_steps, stride)
        else:
            # Square an evenly spaced grid that ends at sqrt(0.8 * total_steps).
            upper = np.sqrt(total_steps * 0.8)
            time_steps = (np.linspace(0, upper, self.number_of_steps) ** 2).astype(int)

        # NOTE(review): current[i] is time_steps[i] + 1 while previous[i] is
        # time_steps[i - 1], i.e. previous[i] == current[i - 1] - 1 for i > 0 --
        # confirm this one-step offset is intended.
        self._time_steps = time_steps + 1
        self._time_steps_prev = np.concatenate([[0], time_steps[:-1]])

    @property
    def steps(self) -> List[int]:
        """Indices into the denoising schedule, ordered from last to first.

        Accessing this property also refreshes the cached schedule read by
        ``get_timestep``.

        Raises:
            ValueError: In linear mode, when ``number_of_steps`` is larger than
                the number of steps of the beta scheduler.
        """
        self._update_time_steps()
        return list(range(self.number_of_steps - 1, -1, -1))

    def get_timestep(self, number_of_images: int, idx: int) -> Timestep:
        """Build the current/previous timestep tensors for schedule index ``idx``.

        Args:
            number_of_images: Length of each returned 1-D timestep tensor.
            idx: Index into the denoising schedule (one of the values of
                ``steps``).

        Returns:
            A ``Timestep`` whose ``current`` and ``previous`` tensors are filled
            with the schedule values at ``idx`` and live on ``self.device``.
        """
        # Build the schedule on demand so this method does not depend on
        # ``steps`` having been accessed first.
        if getattr(self, "_time_steps", None) is None:
            self._update_time_steps()

        shape = (number_of_images,)
        # ``gather`` needs int64 indices, so the dtype is stated explicitly
        # instead of being inferred from a NumPy scalar.
        current = torch.full(
            shape, int(self._time_steps[idx]), dtype=torch.long, device=self.device
        )
        previous = torch.full(
            shape,
            int(self._time_steps_prev[idx]),
            dtype=torch.long,
            device=self.device,
        )
        return Timestep(current=current, previous=previous)

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "DdimDiffuser":
        """Instantiate a DDIM Diffuser from a training checkpoint.

        Only the beta scheduler is restored from the checkpoint; ``mode`` and
        ``number_of_steps`` take their default values.

        Args:
            checkpoint: The training checkpoint object containing
                the trained model's parameters and configuration.

        Returns:
            An instance of the DdimDiffuser class initialized with the
            parameters loaded from the given checkpoint.
        """
        scheduler_config = checkpoint.beta_scheduler_config
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=scheduler_config.steps,
                betas=scheduler_config.betas,
                alpha_bars=scheduler_config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "DdimDiffuser":
        """Moves the data to the specified device.

        This performs a similar behaviour to the `to` method of PyTorch, moving
        the DdimDiffuser and its BetaScheduler to the specified device.

        Args:
            device: The device to which the method should move the object.
                Default is "cpu".

        Returns:
            The diffuser itself, so calls can be chained.

        Example:
            >>> ddim_diffuser = DdimDiffuser(beta_scheduler)
            >>> ddim_diffuser = ddim_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    def _diffuse_batch(
        self, images: torch.Tensor, timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Add noise to ``images`` at the given per-sample ``timesteps``.

        Returns:
            The noised images, the noise that was added and the timesteps.
        """
        noise = torch.randn_like(images, device=self.device)

        alpha_bar_t = self._match_rank(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps), images
        )

        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Diffuse the given batch of images by adding noise based on the beta
        scheduler. One timestep is drawn uniformly at random per image.

        Args:
            images: Batch of images to diffuse, batch dimension first,
                e.g. ``(B, C, H, W)``.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        return self._diffuse_batch(images, timesteps)

    @torch.no_grad()
    def _denoise_step(
        self,
        images: torch.Tensor,
        model: torch.nn.Module,
        timestep: Timestep,
        eta: float = 0.0,
    ) -> torch.Tensor:
        """Move ``images`` from ``timestep.current`` to ``timestep.previous``.

        Args:
            images: Batch of noisy images, batch dimension first.
            model: Called as ``model(images, timestep.current)``; its output is
                used as the predicted noise.
            timestep: Current and previous timesteps, one entry per image.
            eta: Scale of the random noise re-injected at this step. With the
                default ``0.0`` the noise term is multiplied by zero.

        Returns:
            The batch of images after one denoising step.
        """
        current_timestep = timestep.current
        previous_timestep = timestep.previous

        epsilon_theta = model(images, current_timestep)

        alpha_bars = self.beta_scheduler.alpha_bars
        # Broadcast over however many trailing dimensions ``images`` has,
        # consistent with ``_diffuse_batch``.
        alpha_bar_t = self._match_rank(
            alpha_bars.gather(dim=0, index=current_timestep), images
        )
        alpha_bar_t_prev = self._match_rank(
            alpha_bars.gather(dim=0, index=previous_timestep), images
        )

        sigma = eta * torch.sqrt(
            (1 - alpha_bar_t_prev)
            / (1 - alpha_bar_t)
            * (1 - alpha_bar_t / alpha_bar_t_prev)
        )

        epsilon_t = torch.randn_like(images)

        # Algebraically equal to
        #   sqrt(alpha_bar_prev) * x0_pred
        #   + sqrt(1 - alpha_bar_prev - sigma^2) * epsilon_theta
        # with x0_pred = (x_t - sqrt(1 - alpha_bar_t) * epsilon_theta)
        #                / sqrt(alpha_bar_t).
        mu = (
            torch.sqrt(alpha_bar_t_prev / alpha_bar_t) * images
            + (
                torch.sqrt(1 - alpha_bar_t_prev - sigma**2)
                - torch.sqrt((alpha_bar_t_prev * (1 - alpha_bar_t)) / alpha_bar_t)
            )
            * epsilon_theta
        )
        return mu + sigma * epsilon_t

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        This denoises a batch of images. This is the image generation process.
        After every step the images are clamped to ``[-1, 1]``.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.

        Returns:
            A list with one tensor per denoising step, each containing the
            batch of images after that step. The last entry is the final
            result.
        """
        denoised_images = []
        batch_size = images.shape[0]
        for i in tqdm(self.steps, desc="Denoising"):
            timestep = self.get_timestep(batch_size, idx=i)

            images = self._denoise_step(
                images,
                model=model,
                timestep=timestep,
            )
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images