Source code for diffusion_models.gaussian_diffusion.gaussian_diffuser

 1from typing import List
 2
 3import torch
 4from tqdm import tqdm
 5
 6from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
 7from diffusion_models.gaussian_diffusion.beta_schedulers import (
 8  BaseBetaScheduler,
 9)
10from diffusion_models.utils.schemas import Checkpoint
11
12
class GaussianDiffuser(BaseDiffuser):
    """Gaussian (DDPM-style) diffuser: noises clean batches and denoises them step by step.

    The noise schedule (``steps``, ``betas``, ``alpha_bars``) is read from the
    beta scheduler handed to the constructor.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Create a diffuser around ``beta_scheduler``.

        Args:
            beta_scheduler: Schedule exposing ``steps``, ``betas`` and
                ``alpha_bars`` (presumably stored by ``BaseDiffuser`` as
                ``self.beta_scheduler``, which the methods below read).
        """
        super().__init__(beta_scheduler)
        # Keep CUDA as the default, but fall back to CPU when no GPU exists so
        # the tensors created in diffuse_batch/denoise_batch can still be
        # allocated. Use ``to(device)`` to also move the scheduler tensors.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "GaussianDiffuser":
        """Rebuild a diffuser from the beta-schedule tensors stored in ``checkpoint``."""
        config = checkpoint.beta_scheduler_config
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=config.steps,
                betas=config.betas,
                alpha_bars=config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "GaussianDiffuser":
        """Move the diffuser and its beta scheduler to ``device``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    @staticmethod
    def _broadcast_shape(images: torch.Tensor) -> tuple:
        """Return ``(-1, 1, ..., 1)`` so a per-sample vector broadcasts over ``images``."""
        return (-1, *((1,) * (images.dim() - 1)))

    def diffuse_batch(self, images):
        """Apply the forward (noising) process at a random timestep per sample.

        For each sample, a timestep ``t`` is drawn uniformly from
        ``[0, steps)`` and the sample becomes
        ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            images: Batch tensor whose first dimension is the batch size; any
                number of trailing dimensions is supported.

        Returns:
            Tuple ``(noisy_images, noise, timesteps)``: the noised batch, the
            standard-normal noise that was added, and the sampled timesteps.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        noise = torch.randn_like(images, device=self.device)

        alpha_bar_t = self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps)
        # Reshape to (B, 1, ..., 1) so the per-sample scale broadcasts.
        alpha_bar_t = alpha_bar_t.reshape(self._broadcast_shape(images))

        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    @torch.no_grad()
    def _denoise_step(
        self, images: torch.Tensor, model: torch.nn.Module, timestep: torch.Tensor
    ) -> torch.Tensor:
        """Run one reverse-diffusion step ``x_t -> x_{t-1}``.

        Computes
        ``1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * model(x_t, t))``
        and, unless ``t == 0``, adds Gaussian noise with standard deviation
        ``sqrt(beta_t)``. ``model`` is therefore expected to predict the noise
        that was added in the forward process.

        Args:
            images: Current batch ``x_t``; any number of trailing dimensions.
            model: Network called as ``model(images, timestep)``.
            timestep: Per-sample step indices. Only ``timestep[0]`` decides
                whether noise is added, so all entries are assumed equal.

        Returns:
            The batch ``x_{t-1}``.
        """
        shape = self._broadcast_shape(images)
        beta_t = self.beta_scheduler.betas[timestep].reshape(shape)
        alpha_t = 1 - beta_t
        alpha_bar_t = self.beta_scheduler.alpha_bars.gather(
            dim=0, index=timestep
        ).reshape(shape)
        mu = (1 / torch.sqrt(alpha_t)) * (
            images - model(images, timestep) * (beta_t / torch.sqrt(1 - alpha_bar_t))
        )

        if timestep[0] == 0:
            # The final step returns the mean without adding noise.
            return mu
        sigma = torch.sqrt(beta_t) * torch.randn_like(images)
        return mu + sigma

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: torch.nn.Module,
    ) -> List[torch.Tensor]:
        """Run the full reverse process from ``t = steps - 1`` down to ``t = 0``.

        Args:
            images: Starting batch, typically pure noise.
            model: Network passed to each reverse step.

        Returns:
            One tensor per step (``steps`` in total); the last entry is the
            final sample.
        """
        denoised_images = []
        for i in tqdm(range(self.beta_scheduler.steps)[::-1], desc="Denoising"):
            timestep = torch.full(
                (images.shape[0],), i, dtype=torch.long, device=self.device
            )
            images = self._denoise_step(images, model=model, timestep=timestep)
            # NOTE(review): every intermediate x_t is clamped to [-1, 1], not
            # only the final sample; confirm this is intended.
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images