1from typing import TYPE_CHECKING
2from typing import List
3from typing import Tuple
4
5import numpy as np
6import torch
7from tqdm import tqdm
8
9from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
10from diffusion_models.gaussian_diffusion.beta_schedulers import (
11 BaseBetaScheduler,
12)
13from diffusion_models.utils.schemas import Checkpoint
14
15
16if TYPE_CHECKING:
17 from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
18
19
[docs]
class GaussianDiffuser(BaseDiffuser):
    """Gaussian diffusion process driven by a beta scheduler.

    ``diffuse_batch`` adds noise to a batch of images at randomly drawn
    timesteps; ``denoise_batch`` walks the timesteps backwards and uses a model
    to remove the noise step by step. Noise levels are read from the
    ``betas`` and ``alpha_bars`` tensors of the beta scheduler.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the class instance.

        Args:
            beta_scheduler: The beta scheduler instance to be used.
        """
        # NOTE(review): the methods below read ``self.beta_scheduler``, which
        # is presumably stored by ``BaseDiffuser.__init__`` -- confirm.
        super().__init__(beta_scheduler)
        self.device: str = "cpu"
        """The device to use. Defaults to cpu."""

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "GaussianDiffuser":
        """Instantiate a Gaussian Diffuser from a training checkpoint.

        Args:
            checkpoint: The training checkpoint object. Only its
                ``beta_scheduler_config`` (``steps``, ``betas`` and
                ``alpha_bars``) is read here.

        Returns:
            An instance of the class whose beta scheduler is rebuilt from the
            tensors stored in the given checkpoint.
        """
        scheduler_config = checkpoint.beta_scheduler_config
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=scheduler_config.steps,
                betas=scheduler_config.betas,
                alpha_bars=scheduler_config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "GaussianDiffuser":
        """Moves the data to the specified device.

        This performs a similar behaviour to the `to` method of PyTorch, moving
        the GaussianDiffuser and the BetaScheduler to the specified device.

        Args:
            device: The device to which the method should move the object.
                Default is "cpu".

        Returns:
            This same instance, so that calls can be chained.

        Example:
            >>> gaussian_diffuser = GaussianDiffuser(beta_scheduler)
            >>> gaussian_diffuser = gaussian_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    @staticmethod
    def _expand_like(values: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
        """Reshape per-sample ``values`` so they broadcast against ``images``.

        ``values`` holds one entry per batch element; the result has shape
        ``(B, 1, ..., 1)`` with as many dimensions as ``images``.
        """
        return values.reshape(-1, *((1,) * (images.dim() - 1)))

    def _diffuse_batch(
        self, images: torch.Tensor, timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Add noise to ``images`` at the given per-sample ``timesteps``.

        Args:
            images: Batch of images; the first dimension is the batch.
            timesteps: One integer index per batch element, used to look up
                ``alpha_bars`` in the beta scheduler.

        Returns:
            A tuple ``(noisy_images, noise, timesteps)``.
        """
        noise = torch.randn_like(images, device=self.device)

        alpha_bar_t = self._expand_like(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps), images
        )

        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        return mu * images + sigma * noise, noise, timesteps

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Diffuse the given batch of images by adding noise based on the beta
        scheduler. One timestep per image is drawn uniformly at random from
        ``[0, beta_scheduler.steps)``.

        Args:
            images: Batch of images to diffuse. Typically of shape
                ``(B, C, H, W)``; the first dimension must be the batch.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        return self._diffuse_batch(images, timesteps)

    @torch.no_grad()
    def _denoise_step(
        self, images: torch.Tensor, model: torch.nn.Module, timestep: torch.Tensor
    ) -> torch.Tensor:
        """Run one reverse-diffusion step on a batch.

        Args:
            images: Current batch of noisy images; first dimension is the batch.
            model: Callable as ``model(images, timestep)``; its output is used
                as the noise estimate to subtract from ``images``.
            timestep: One integer timestep per batch element.

        Returns:
            The batch after one denoising step. Samples whose timestep is 0
            receive the mean only; all others get fresh Gaussian noise scaled
            by ``sqrt(beta_t)`` added on top.
        """
        beta_t = self._expand_like(self.beta_scheduler.betas[timestep], images)
        alpha_t = 1 - beta_t
        alpha_bar_t = self._expand_like(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=timestep), images
        )
        predicted_noise = model(images, timestep)
        mu = (1 / torch.sqrt(alpha_t)) * (
            images - predicted_noise * (beta_t / torch.sqrt(1 - alpha_bar_t))
        )

        is_last_step = timestep == 0
        # ``all()`` is True for an empty batch as well, so no element is ever
        # indexed; for a uniform batch at step 0 no random noise is drawn.
        if bool(is_last_step.all()):
            return mu

        sigma = torch.sqrt(beta_t) * torch.randn_like(images)
        # Decide per sample so that a batch mixing step 0 with later steps is
        # handled correctly instead of following the first element only.
        return torch.where(self._expand_like(is_last_step, images), mu, mu + sigma)

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        This denoises a batch of images. This is the image generation process.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.

        Returns:
            A list with one tensor per timestep, in the order the steps were
            run (from ``steps - 1`` down to ``0``); the last element is the
            final denoised batch.
        """
        denoised_images = []
        batch_size = images.shape[0]
        # A sliced range still supports len(), unlike reversed(range(...)).
        descending_steps = range(self.beta_scheduler.steps)[::-1]
        for i in tqdm(descending_steps, desc="Denoising"):
            timestep = torch.full((batch_size,), i, device=self.device)
            images = self._denoise_step(images, model=model, timestep=timestep)
            # The clamped batch is both recorded and fed into the next step.
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images
148 return denoised_images