Source code for diffusion_models.gaussian_diffusion.gaussian_diffuser

  1from typing import TYPE_CHECKING
  2from typing import List
  3from typing import Tuple
  4
  5import numpy as np
  6import torch
  7from tqdm import tqdm
  8
  9from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
 10from diffusion_models.gaussian_diffusion.beta_schedulers import (
 11  BaseBetaScheduler,
 12)
 13from diffusion_models.utils.schemas import Checkpoint
 14
 15
 16if TYPE_CHECKING:
 17  from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
 18
 19
class GaussianDiffuser(BaseDiffuser):
  """Gaussian forward diffusion and iterative reverse denoising.

  The forward process mixes images with Gaussian noise using the cumulative
  ``alpha_bars`` of the beta scheduler. The reverse process repeatedly calls a
  model whose output is used as the noise estimate for the current timestep.
  """

  def __init__(self, beta_scheduler: BaseBetaScheduler):
    """Initializes the class instance.

    Args:
      beta_scheduler (BaseBetaScheduler): The beta scheduler instance to be
        used.

    """
    super().__init__(beta_scheduler)
    self.device: str = "cpu"
    """The device to use. Defaults to cpu."""

  @classmethod
  def from_checkpoint(cls, checkpoint: Checkpoint) -> "GaussianDiffuser":
    """Instantiate a Gaussian Diffuser from a training checkpoint.

    Args:
      checkpoint: The training checkpoint object. Its
        ``beta_scheduler_config`` must provide ``steps``, ``betas`` and
        ``alpha_bars``.

    Returns:
      An instance of the GaussianDiffuser class whose beta scheduler is
      rebuilt from the tensors stored in the given checkpoint.
    """
    scheduler_config = checkpoint.beta_scheduler_config
    return cls(
      beta_scheduler=BaseBetaScheduler.from_tensors(
        steps=scheduler_config.steps,
        betas=scheduler_config.betas,
        alpha_bars=scheduler_config.alpha_bars,
      )
    )

  def to(self, device: str = "cpu") -> "GaussianDiffuser":
    """Moves the data to the specified device.

    This behaves similarly to the ``to`` method of PyTorch, moving the
    GaussianDiffuser and its BetaScheduler to the specified device.

    Args:
      device: The device to which the method should move the object.
        Default is "cpu".

    Returns:
      The same instance, to allow call chaining.

    Example:
      >>> gaussian_diffuser = GaussianDiffuser(beta_scheduler)
      >>> gaussian_diffuser = gaussian_diffuser.to(device="cuda")
    """
    self.device = device
    self.beta_scheduler = self.beta_scheduler.to(self.device)
    return self

  @staticmethod
  def _expand_to_rank(
    values: torch.Tensor, images: torch.Tensor
  ) -> torch.Tensor:
    """Reshape per-sample ``values`` to ``(B, 1, ..., 1)`` matching ``images``.

    The result has as many dimensions as ``images`` so that it broadcasts over
    every non-batch dimension, whatever the rank of ``images`` is.
    """
    return values.reshape((-1, *((1,) * (images.dim() - 1))))

  def _diffuse_batch(
    self, images: torch.Tensor, timesteps: torch.Tensor
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply the forward diffusion to ``images`` at the given ``timesteps``.

    Args:
      images: Batch of clean images; the first dimension is the batch.
      timesteps: One integer timestep index per image, used to look up
        ``alpha_bars`` in the beta scheduler.

    Returns:
      A tuple ``(diffused_images, noise, timesteps)``.
    """
    noise = torch.randn_like(images, device=self.device)

    alpha_bar_t = self._expand_to_rank(
      self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps), images
    )

    mu = torch.sqrt(alpha_bar_t)
    sigma = torch.sqrt(1 - alpha_bar_t)
    images = mu * images + sigma * noise
    return images, noise, timesteps

  def diffuse_batch(
    self, images: torch.Tensor
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Diffuse a batch of images.

    Diffuse the given batch of images by adding noise based on the beta
    scheduler. One timestep is drawn uniformly at random for every image.

    Args:
      images: Batch of images to diffuse.
        Shape is expected to be ``(B, C, H, W)``; the first dimension must be
        the batch dimension.

    Returns:
      A tuple containing three tensors

        - images: Diffused batch of images.
        - noise: Noise added to the images.
        - timesteps: Timesteps used for diffusion.
    """
    timesteps = torch.randint(
      0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
    )
    return self._diffuse_batch(images, timesteps)

  @torch.no_grad()
  def _denoise_step(
    self, images: torch.Tensor, model: torch.nn.Module, timestep: torch.Tensor
  ) -> torch.Tensor:
    """Run a single reverse-diffusion step.

    Args:
      images: Batch of noisy images; the first dimension is the batch.
      model: Callable invoked as ``model(images, timestep)``; its output is
        used as the noise estimate.
      timestep: One integer timestep index per image.

    Returns:
      The batch of images after one denoising step. Samples whose timestep is
      0 receive the posterior mean only; all others additionally receive
      Gaussian noise scaled by ``sqrt(beta_t)``.
    """
    beta_t = self._expand_to_rank(self.beta_scheduler.betas[timestep], images)
    alpha_t = 1 - beta_t
    alpha_bar_t = self._expand_to_rank(
      self.beta_scheduler.alpha_bars.gather(dim=0, index=timestep), images
    )
    mu = (1 / torch.sqrt(alpha_t)) * (
      images - model(images, timestep) * (beta_t / torch.sqrt(1 - alpha_bar_t))
    )

    # Decide per sample whether noise is added, instead of inspecting only the
    # first element: this also covers empty batches and mixed timesteps.
    needs_noise = timestep != 0
    if not bool(needs_noise.any()):
      # Final step (or empty batch): no random numbers are drawn.
      return mu

    noise_mask = self._expand_to_rank(needs_noise, images).to(mu.dtype)
    sigma = torch.sqrt(beta_t) * torch.randn_like(images)
    return mu + noise_mask * sigma

  def denoise_batch(
    self,
    images: torch.Tensor,
    model: "BaseDiffusionModel",
  ) -> List[torch.Tensor]:
    """Denoise a batch of images.

    This denoises a batch of images. This is the image generation process:
    the timesteps are visited from the last one down to 0 and, after every
    step, the images are clamped to ``[-1, 1]``.

    Args:
      images: A batch of noisy images.
      model: The model to be used for denoising.

    Returns:
      A list with one tensor per timestep, in the order the steps were
      executed. The last element is the fully denoised batch.
    """
    denoised_images = []
    # A sliced range still has a length, so tqdm can display the total.
    for i in tqdm(range(self.beta_scheduler.steps)[::-1], desc="Denoising"):
      # Explicit integer dtype: the timestep is used as a gather index.
      timestep = torch.full(
        (images.shape[0],), i, dtype=torch.long, device=self.device
      )
      images = self._denoise_step(images, model=model, timestep=timestep)
      images = torch.clamp(images, -1.0, 1.0)
      denoised_images.append(images)
    return denoised_images