from typing import List, Tuple

import torch
from tqdm import tqdm

from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
from diffusion_models.gaussian_diffusion.beta_schedulers import (
    BaseBetaScheduler,
)
from diffusion_models.utils.schemas import Checkpoint


class GaussianDiffuser(BaseDiffuser):
    """Gaussian forward (noising) and reverse (denoising) process.

    The per-timestep coefficients are read from the ``betas`` and
    ``alpha_bars`` tensors of the wrapped beta scheduler; the number of
    timesteps is the scheduler's ``steps`` attribute.

    Tensors created by this class are allocated on ``self.device``. Use
    :meth:`to` to move the scheduler and change that device together.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Wrap ``beta_scheduler``.

        Args:
            beta_scheduler: Scheduler exposing ``steps``, ``betas``,
                ``alpha_bars`` and ``to()``. It is not moved here; call
                :meth:`to` to place it on a specific device.
        """
        super().__init__(beta_scheduler)
        # Prefer CUDA as before, but fall back to CPU so the diffuser is
        # usable on hosts without a GPU instead of failing on first use.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "GaussianDiffuser":
        """Build a diffuser from the scheduler tensors stored in a checkpoint.

        Args:
            checkpoint: Checkpoint whose ``beta_scheduler_config`` carries
                ``steps``, ``betas`` and ``alpha_bars``.

        Returns:
            A new instance wrapping ``BaseBetaScheduler.from_tensors(...)``.
        """
        scheduler_config = checkpoint.beta_scheduler_config
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=scheduler_config.steps,
                betas=scheduler_config.betas,
                alpha_bars=scheduler_config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "GaussianDiffuser":
        """Move the beta scheduler to ``device`` and remember it.

        Args:
            device: Target device, passed to the scheduler's ``to()``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    @staticmethod
    def _per_sample_shape(images: torch.Tensor) -> Tuple[int, ...]:
        """Return ``(-1, 1, ..., 1)`` with one entry per dimension of ``images``.

        Reshaping a per-sample coefficient vector to this shape lets it
        broadcast against ``images`` along every non-batch dimension.
        """
        return (-1,) + (1,) * (images.dim() - 1)

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Noise each image to an independently sampled random timestep.

        Computes ``sqrt(alpha_bar_t) * images + sqrt(1 - alpha_bar_t) * noise``
        with ``noise`` drawn from a standard normal and ``t`` drawn uniformly
        from ``[0, steps)`` for every sample.

        Args:
            images: Batch of clean images; the first dimension is the batch.
                Any number of trailing dimensions is supported.

        Returns:
            ``(noisy_images, noise, timesteps)``; ``timesteps`` has one entry
            per sample.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        noise = torch.randn_like(images, device=self.device)

        alpha_bar_t = self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps)
        alpha_bar_t = alpha_bar_t.reshape(self._per_sample_shape(images))

        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    @torch.no_grad()
    def _denoise_step(
        self, images: torch.Tensor, model: torch.nn.Module, timestep: torch.Tensor
    ) -> torch.Tensor:
        """Run one reverse step from ``timestep`` to the previous one.

        The mean is
        ``(images - beta_t / sqrt(1 - alpha_bar_t) * model(images, timestep))
        / sqrt(alpha_t)``; noise scaled by ``sqrt(beta_t)`` is added for every
        sample whose timestep is not 0.

        Args:
            images: Current noisy batch; the first dimension is the batch.
            model: Called as ``model(images, timestep)``; its output is used
                as the noise estimate (presumably a noise-prediction network).
            timestep: Integer tensor with one timestep index per sample.

        Returns:
            The batch after one denoising step.
        """
        # Rank-agnostic broadcast shape, consistent with diffuse_batch
        # (a fixed (-1, 1, 1, 1) would only be correct for 4-D inputs).
        shape = self._per_sample_shape(images)
        beta_t = self.beta_scheduler.betas[timestep].reshape(shape)
        alpha_t = 1 - beta_t
        alpha_bar_t = self.beta_scheduler.alpha_bars.gather(
            dim=0, index=timestep
        ).reshape(shape)
        mu = (1 / torch.sqrt(alpha_t)) * (
            images - model(images, timestep) * (beta_t / torch.sqrt(1 - alpha_bar_t))
        )

        # Decide per sample instead of looking only at timestep[0]: samples
        # at t == 0 get the mean, all others get mean + noise. Skipping the
        # draw when nothing needs noise keeps RNG usage as before.
        needs_noise = timestep != 0
        if not needs_noise.any():
            return mu
        sigma = torch.sqrt(beta_t) * torch.randn_like(images)
        return torch.where(needs_noise.reshape(shape), mu + sigma, mu)

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: torch.nn.Module,
    ) -> List[torch.Tensor]:
        """Run the full reverse process, from the last timestep down to 0.

        After every step the batch is clamped to ``[-1, 1]`` and recorded.

        Args:
            images: Starting batch, expected on ``self.device``.
            model: Passed to each denoising step as the noise estimator.

        Returns:
            One tensor per timestep in the order they were produced, i.e.
            the last element is the final result.
        """
        batch_size = images.shape[0]
        denoised_images = []
        # A descending range (not reversed(range(...))) still supports len(),
        # so tqdm can display the total number of steps.
        descending_steps = range(self.beta_scheduler.steps - 1, -1, -1)
        for step in tqdm(descending_steps, desc="Denoising"):
            timestep = torch.full(
                (batch_size,), step, dtype=torch.long, device=self.device
            )
            images = self._denoise_step(images, model=model, timestep=timestep)
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images