1from typing import TYPE_CHECKING, List, Tuple
2
3import torch
4from tqdm import tqdm
5
6from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
7from diffusion_models.gaussian_diffusion.beta_schedulers import (
8 BaseBetaScheduler,
9)
10from diffusion_models.utils.schemas import Checkpoint, Timestep
11
12if TYPE_CHECKING:
13 from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
14
15
[docs]
class GaussianDiffuser(BaseDiffuser):
    """DDPM-style Gaussian diffusion driven by a beta scheduler.

    The forward process (:meth:`diffuse_batch`) noises clean images at random
    timesteps; the reverse process (:meth:`denoise_batch`) iteratively removes
    noise predicted by a model. Both read ``betas`` / ``alpha_bars`` from
    ``self.beta_scheduler``.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the class instance.

        Args:
            beta_scheduler (BaseBetaScheduler): The beta scheduler instance to be used.

        """
        super().__init__(beta_scheduler)
        self.device: str = "cpu"
        """The device to use. Defaults to cpu."""

    @property
    def steps(self) -> List[int]:
        """Timestep indices in denoising order: ``steps - 1`` down to ``0``."""
        return list(reversed(range(self.beta_scheduler.steps)))

    def get_timestep(self, number_of_images: int, idx: int) -> Timestep:
        """Build a :class:`Timestep` that repeats ``idx`` once per image.

        Args:
            number_of_images: Number of images in the batch.
            idx: Timestep index shared by every image in the batch.

        Returns:
            A ``Timestep`` whose ``current`` field is a 1-D int64 tensor of
            length ``number_of_images`` filled with ``idx``, on ``self.device``.
        """
        # int64 explicitly: the tensor is later used as a `gather` index,
        # which requires a long dtype.
        timestep = torch.full(
            (number_of_images,), idx, dtype=torch.long, device=self.device
        )
        return Timestep(current=timestep)

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "GaussianDiffuser":
        """Instantiate a Gaussian Diffuser from a training checkpoint.

        Args:
            checkpoint: The training checkpoint object; its
                ``beta_scheduler_config`` supplies ``steps``, ``betas`` and
                ``alpha_bars``.

        Returns:
            A GaussianDiffuser whose beta scheduler is rebuilt from the
            checkpoint's stored tensors. The device is left at the default
            (``"cpu"``); call :meth:`to` to move it.
        """
        config = checkpoint.beta_scheduler_config
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=config.steps,
                betas=config.betas,
                alpha_bars=config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "GaussianDiffuser":
        """Move the diffuser and its beta scheduler to ``device``.

        Mirrors PyTorch's ``to``: updates ``self.device`` (used for newly
        created noise/timestep tensors) and replaces ``self.beta_scheduler``
        with the result of ``beta_scheduler.to(device)``.

        Args:
            device: Target device. Defaults to ``"cpu"``.

        Returns:
            ``self``, to allow chaining.

        Example:
            Given an existing ``beta_scheduler``::

                gaussian_diffuser = GaussianDiffuser(beta_scheduler)
                gaussian_diffuser = gaussian_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    @staticmethod
    def _expand_like(values: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
        """Reshape per-sample ``values`` to ``(B, 1, ..., 1)`` to broadcast with ``images``."""
        return values.reshape(-1, *([1] * (images.ndim - 1)))

    def _diffuse_batch(
        self, images: torch.Tensor, timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Noise ``images`` at the given per-sample ``timesteps``.

        Computes ``sqrt(alpha_bar_t) * images + sqrt(1 - alpha_bar_t) * noise``
        with ``noise ~ N(0, I)``.

        Returns:
            ``(noisy_images, noise, timesteps)``.
        """
        noise = torch.randn_like(images, device=self.device)

        alpha_bar_t = self._expand_like(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps), images
        )

        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Each image is noised at an independently sampled timestep drawn
        uniformly from ``[0, beta_scheduler.steps)``.

        Args:
            images: Batch of images to diffuse, typically of shape
                ``(B, C, H, W)``; any tensor with a leading batch dimension
                is accepted.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        return self._diffuse_batch(images, timesteps)

    @torch.no_grad()
    def _denoise_step(
        self, images: torch.Tensor, model: torch.nn.Module, timestep: Timestep
    ) -> torch.Tensor:
        """Run one reverse-diffusion step.

        Computes the posterior mean
        ``1/sqrt(alpha_t) * (x - eps * beta_t / sqrt(1 - alpha_bar_t))`` from
        the model's noise prediction ``eps``, then adds ``sqrt(beta_t) * z``
        (``z ~ N(0, I)``) for every sample whose timestep is not ``0``.

        Args:
            images: Current noisy batch; any shape with a leading batch dim.
            model: Called as ``model(images, timesteps)`` to predict noise.
            timestep: Holds the per-sample timestep indices in ``current``.

        Returns:
            The batch after one denoising step.
        """
        current_timestep = timestep.current
        beta_t = self._expand_like(
            self.beta_scheduler.betas[current_timestep], images
        )
        alpha_t = 1 - beta_t
        alpha_bar_t = self._expand_like(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=current_timestep),
            images,
        )
        predicted_noise = model(images, current_timestep)
        mu = (1 / torch.sqrt(alpha_t)) * (
            images - predicted_noise * (beta_t / torch.sqrt(1 - alpha_bar_t))
        )

        is_final = current_timestep == 0
        # The last step (t == 0) returns the mean with no fresh noise. Checking
        # per sample (rather than only the first element) keeps mixed-timestep
        # batches correct and tolerates an empty batch.
        if bool(is_final.all()):
            return mu
        sigma = torch.sqrt(beta_t) * torch.randn_like(images)
        if bool(is_final.any()):
            sigma = sigma * self._expand_like((~is_final).to(sigma.dtype), images)
        return mu + sigma

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        Runs the full reverse process (image generation) from timestep
        ``steps - 1`` down to ``0``, clamping the batch to ``[-1, 1]`` after
        every step.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.

        Returns:
            One tensor per denoising step, in step order; each is the whole
            (clamped) batch after that step, so the last entry is the final
            generated batch. All intermediates are kept in memory.
        """
        batch_size = images.shape[0]
        denoised_images = []
        for step in tqdm(self.steps, desc="Denoising"):
            timestep = self.get_timestep(batch_size, idx=step)
            images = self._denoise_step(images, model=model, timestep=timestep)
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images