1from enum import Enum
2from typing import List, Tuple
3
4import numpy as np
5import torch
6from tqdm import tqdm
7
8from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
9from diffusion_models.gaussian_diffusion.beta_schedulers import \
10 BaseBetaScheduler
11from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
12from diffusion_models.utils.schemas import Checkpoint
13
class DenoisingMode(str, Enum):
    """Spacing strategy for the timesteps visited while denoising.

    The members are plain strings as well as enum members, so a mode compares
    equal to its string value (e.g. ``DenoisingMode.Linear == "linear"``).
    """

    # Evenly spaced timesteps. This member is the default mode of
    # ``DdimDiffuser.denoise_batch`` and must exist for that default to resolve.
    Linear = "linear"

    # Quadratically spaced timesteps (denser near timestep 0).
    Quadratic = "quadratic"
17
class DdimDiffuser(BaseDiffuser):
    """Diffuser that noises images and samples them back with DDIM updates.

    The forward process (``diffuse_batch``) mixes images with Gaussian noise
    according to the scheduler's cumulative ``alpha_bars``. The reverse process
    (``denoise_batch``) visits only a subset of the training timesteps and
    applies the DDIM update rule at each of them.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the class instance.

        Args:
            beta_scheduler (BaseBetaScheduler): The beta scheduler instance to
                be used.
        """
        super().__init__(beta_scheduler)
        self.device: str = "cpu"
        """The device to use. Defaults to cpu."""

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "DdimDiffuser":
        """Instantiate a DDIM Diffuser from a training checkpoint.

        Args:
            checkpoint: The training checkpoint object. Its
                ``beta_scheduler_config`` supplies the ``steps``, ``betas`` and
                ``alpha_bars`` used to rebuild the beta scheduler.

        Returns:
            An instance of the DdimDiffuser class whose beta scheduler is
            rebuilt from the given checkpoint.
        """
        scheduler_config = checkpoint.beta_scheduler_config
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=scheduler_config.steps,
                betas=scheduler_config.betas,
                alpha_bars=scheduler_config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "DdimDiffuser":
        """Moves the diffuser to the specified device.

        This mirrors the ``to`` method of PyTorch: the diffuser remembers the
        device and its beta scheduler is moved there as well.

        Args:
            device: The device to which the method should move the object.
                Default is "cpu".

        Returns:
            The diffuser itself, so calls can be chained.

        Example:
            >>> ddim_diffuser = DdimDiffuser(beta_scheduler)
            >>> ddim_diffuser = ddim_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    def _alpha_bars_at(
        self, timesteps: torch.Tensor, ndim: int
    ) -> torch.Tensor:
        """Gather ``alpha_bar`` per sample, shaped to broadcast over a batch.

        Args:
            timesteps: One timestep index per sample, shape ``(B,)``.
            ndim: Rank of the image batch the result will be combined with.

        Returns:
            A tensor of shape ``(B, 1, ..., 1)`` with ``ndim`` dimensions.
        """
        alpha_bars = self.beta_scheduler.alpha_bars.gather(
            dim=0, index=timesteps
        )
        # Trailing singleton axes let the per-sample scalar broadcast over
        # every remaining image dimension, whatever the rank of the batch.
        return alpha_bars.reshape((-1, *((1,) * (ndim - 1))))

    def _diffuse_batch(
        self, images: torch.Tensor, timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Noise ``images`` to the given per-sample ``timesteps``.

        Args:
            images: Batch of clean images; the first axis is the batch axis.
            timesteps: One timestep index per sample, shape ``(B,)``.

        Returns:
            A tuple ``(noisy_images, noise, timesteps)``.
        """
        noise = torch.randn_like(images, device=self.device)
        alpha_bar_t = self._alpha_bars_at(timesteps, images.ndim)

        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        A timestep is drawn uniformly at random for every image, and noise is
        added according to the beta scheduler's ``alpha_bars`` at that
        timestep.

        Args:
            images: Batch of images to diffuse.
                Shape should be ``(B, C, H, W)``.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        return self._diffuse_batch(images, timesteps)

    @torch.no_grad()
    def _denoise_step(
        self,
        images: torch.Tensor,
        model: torch.nn.Module,
        timestep: torch.Tensor,
        timestep_prev: torch.Tensor,
        eta: float = 0.0,
    ) -> torch.Tensor:
        """Apply one DDIM update, moving ``images`` to ``timestep_prev``.

        Args:
            images: Current noisy batch; the first axis is the batch axis.
            model: Called as ``model(images, timestep)``; its output is used
                as the predicted noise.
            timestep: Current timestep index per sample, shape ``(B,)``.
            timestep_prev: Target timestep index per sample, shape ``(B,)``.
            eta: Scale of the random noise injected at this step. ``0.0``
                makes the update deterministic.

        Returns:
            The batch after one denoising update.
        """
        epsilon_theta = model(images, timestep)

        alpha_bar_t = self._alpha_bars_at(timestep, images.ndim)
        alpha_bar_t_prev = self._alpha_bars_at(timestep_prev, images.ndim)

        sigma = eta * torch.sqrt(
            (1 - alpha_bar_t_prev)
            / (1 - alpha_bar_t)
            * (1 - alpha_bar_t / alpha_bar_t_prev)
        )

        # Drawn even when eta == 0 so the global RNG stream is consumed the
        # same way regardless of eta.
        epsilon_t = torch.randn_like(images)

        # The coefficient on epsilon_theta merges the "direction pointing to
        # x_t" term with the noise removed when predicting x_0.
        mu = (
            torch.sqrt(alpha_bar_t_prev / alpha_bar_t) * images
            + (
                torch.sqrt(1 - alpha_bar_t_prev - sigma**2)
                - torch.sqrt((alpha_bar_t_prev * (1 - alpha_bar_t)) / alpha_bar_t)
            ) * epsilon_theta
        )
        return mu + sigma * epsilon_t

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
        number_of_steps: int = 50,
        mode: DenoisingMode = DenoisingMode.Linear,
        eta: float = 0.0,
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        This denoises a batch of images. This is the image generation process.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.
            number_of_steps: Number of steps used in the denoising process.
            mode: Linear or Quadratic sampling. Any value other than
                ``DenoisingMode.Linear`` selects quadratic spacing.
            eta: Scale of the random noise injected at every step. The default
                ``0.0`` gives deterministic sampling.

        Returns:
            A list with one batch per denoising step, in the order the steps
            were executed; the last entry is the final result. Every batch is
            clamped to ``[-1, 1]``.

        Raises:
            ValueError: If ``number_of_steps`` is smaller than 1, or exceeds
                the scheduler's number of steps in linear mode.
        """
        total_steps = self.beta_scheduler.steps
        if number_of_steps < 1:
            raise ValueError(
                f"number_of_steps must be at least 1, got {number_of_steps}."
            )

        if mode == DenoisingMode.Linear:
            if number_of_steps > total_steps:
                raise ValueError(
                    f"number_of_steps ({number_of_steps}) cannot exceed the "
                    f"scheduler's steps ({total_steps}) in linear mode."
                )
            stride = total_steps // number_of_steps
            time_steps = np.arange(0, total_steps, stride)
        else:
            # Quadratic spacing that tops out at 80% of the scheduler's steps.
            time_steps = (
                np.linspace(0, np.sqrt(total_steps * 0.8), number_of_steps) ** 2
            ).astype(int)

        # Shift by one so the first visited timestep pairs with index 0 as its
        # "previous" timestep. ``diffuse_batch`` only ever samples indices in
        # ``[0, steps)``, so cap the shifted value to stay inside that range.
        time_steps = np.minimum(time_steps + 1, total_steps - 1)
        time_steps_prev = np.concatenate([[0], time_steps[:-1]])

        batch_size = images.shape[0]
        denoised_images = []
        for i in tqdm(range(number_of_steps)[::-1], desc="Denoising"):
            # ``gather`` needs int64 indices, so fix the dtype explicitly
            # instead of relying on inference from a numpy scalar.
            timestep = torch.full(
                (batch_size,),
                int(time_steps[i]),
                dtype=torch.long,
                device=self.device,
            )
            timestep_prev = torch.full(
                (batch_size,),
                int(time_steps_prev[i]),
                dtype=torch.long,
                device=self.device,
            )

            images = self._denoise_step(
                images,
                model=model,
                timestep=timestep,
                timestep_prev=timestep_prev,
                eta=eta,
            )
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images
187 return denoised_images