from typing import List, Tuple, TYPE_CHECKING

import torch
from tqdm import tqdm

from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
from diffusion_models.gaussian_diffusion.beta_schedulers import (
    BaseBetaScheduler,
)
from diffusion_models.utils.schemas import Checkpoint

if TYPE_CHECKING:
    from diffusion_models.models.base_diffusion_model import BaseDiffusionModel

class GaussianDiffuser(BaseDiffuser):
    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the class instance.

        Args:
            beta_scheduler (BaseBetaScheduler): The beta scheduler instance to be used.

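        Example:
            A minimal sketch; ``beta_scheduler`` stands in for an instance of
            any concrete ``BaseBetaScheduler`` subclass:

            >>> diffuser = GaussianDiffuser(beta_scheduler=beta_scheduler)
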
22 """
23 super().__init__(beta_scheduler)
        self.device: str = "cpu"
        """The device to use. Defaults to cpu."""

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "GaussianDiffuser":
        """Instantiate a Gaussian Diffuser from a training checkpoint.

        Args:
            checkpoint: The training checkpoint object containing
                the trained model's parameters and configuration.

        Returns:
            An instance of the GaussianDiffuser class initialized with the parameters
            loaded from the given checkpoint.
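
        Example:
            A minimal sketch, assuming ``checkpoint`` is a previously loaded
            ``Checkpoint`` instance (how it is loaded is outside this class):

            >>> diffuser = GaussianDiffuser.from_checkpoint(checkpoint)
            >>> diffuser = diffuser.to(device="cuda")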
38 """
39 return cls(
40 beta_scheduler=BaseBetaScheduler.from_tensors(
41 steps=checkpoint.beta_scheduler_config.steps,
42 betas=checkpoint.beta_scheduler_config.betas,
43 alpha_bars=checkpoint.beta_scheduler_config.alpha_bars,
44 )
45 )
46
    def to(self, device: str = "cpu"):
        """Moves the GaussianDiffuser to the specified device.

        This mirrors the behaviour of PyTorch's ``to`` method, moving the
        GaussianDiffuser and its BetaScheduler to the specified device.

        Args:
            device: The device to which the method should move the object.
                Default is "cpu".

        Example:
            >>> gaussian_diffuser = GaussianDiffuser(beta_scheduler=beta_scheduler)
            >>> gaussian_diffuser = gaussian_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Diffuse the given batch of images by adding noise based on the beta
        scheduler. Each image in the batch is noised to a timestep sampled
        uniformly at random.

        Args:
            images: Batch of images to diffuse.
                Shape should be ``(B, C, H, W)``.

        Returns:
            A tuple containing three tensors:

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
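
        Example:
            Each image is noised in closed form as
            ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.
            A minimal sketch; ``diffuser`` is an already constructed
            ``GaussianDiffuser`` on the default device and the batch shape
            ``(4, 3, 32, 32)`` is only illustrative:

            >>> images = torch.rand(4, 3, 32, 32) * 2 - 1  # scaled to [-1, 1]
            >>> noisy, noise, timesteps = diffuser.diffuse_batch(images)
            >>> noisy.shape == images.shape == noise.shape
            True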
82 """
83 timesteps = torch.randint(
84 0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
85 )
86 noise = torch.randn_like(images, device=self.device)
87
88 alpha_bar_t = self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps)
89
90 alpha_bar_t = alpha_bar_t.reshape((-1, *((1,) * (len(images.shape) - 1))))
91
92 mu = torch.sqrt(alpha_bar_t)
93 sigma = torch.sqrt(1 - alpha_bar_t)
94 images = mu * images + sigma * noise
95 return images, noise, timesteps

    @torch.no_grad()
    def _denoise_step(
        self, images: torch.Tensor, model: torch.nn.Module, timestep: torch.Tensor
    ) -> torch.Tensor:
        """Perform a single reverse-diffusion step on a batch of images."""
        beta_t = self.beta_scheduler.betas[timestep].reshape(-1, 1, 1, 1)
        alpha_t = 1 - beta_t
        alpha_bar_t = self.beta_scheduler.alpha_bars.gather(
            dim=0, index=timestep
        ).reshape(-1, 1, 1, 1)
        # Posterior mean: remove the model's predicted noise from the images.
        mu = (1 / torch.sqrt(alpha_t)) * (
            images - model(images, timestep) * (beta_t / torch.sqrt(1 - alpha_bar_t))
        )

        if timestep[0] == 0:
            # Final step: return the mean without adding noise.
            return mu
        else:
            # Intermediate steps: add fresh noise scaled by sqrt(beta_t).
            sigma = torch.sqrt(beta_t) * torch.randn_like(images)
            return mu + sigma

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        This denoises a batch of images step by step; it is the image
        generation (sampling) process.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.

        Returns:
            A list of tensors, one per denoising step, each containing the batch
            of images at that step; the final element is the fully denoised batch.
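
        Example:
            A minimal sketch, assuming ``diffuser`` is a ``GaussianDiffuser`` and
            ``model`` is a trained noise-prediction model; the batch shape
            ``(4, 3, 32, 32)`` is only illustrative:

            >>> noise = torch.randn(4, 3, 32, 32, device=diffuser.device)
            >>> steps = diffuser.denoise_batch(noise, model=model)
            >>> final_images = steps[-1]  # fully denoised batch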
131 """
132 denoised_images = []
133 for i in tqdm(range(self.beta_scheduler.steps)[::-1], desc="Denoising"):
134 timestep = torch.full((images.shape[0],), i, device=self.device)
135 images = self._denoise_step(images, model=model, timestep=timestep)
136 images = torch.clamp(images, -1.0, 1.0)
137 denoised_images.append(images)
138 return denoised_images