Source code for diffusion_models.gaussian_diffusion.gaussian_diffuser

  1from typing import TYPE_CHECKING, List, Tuple
  2
  3import torch
  4from tqdm import tqdm
  5
  6from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
  7from diffusion_models.gaussian_diffusion.beta_schedulers import (
  8  BaseBetaScheduler,
  9)
 10from diffusion_models.utils.schemas import Checkpoint, Timestep
 11
 12if TYPE_CHECKING:
 13  from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
 14
 15
class GaussianDiffuser(BaseDiffuser):
    """Gaussian diffuser implementing the DDPM forward and reverse processes.

    The forward process (``diffuse_batch``) computes
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps``. The reverse
    step (``_denoise_step``) uses the noise-prediction update
    ``mu = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)``
    plus ``sqrt(beta_t)``-scaled Gaussian noise for every timestep except 0.
    All schedule values are read from ``self.beta_scheduler``.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the class instance.

        Args:
            beta_scheduler (BaseBetaScheduler): The beta scheduler instance to be used.

        """
        super().__init__(beta_scheduler)
        #: The device to use. Defaults to cpu.
        self.device: str = "cpu"

    @property
    def steps(self) -> List[int]:
        """Timestep indices in reverse order: ``[steps - 1, ..., 1, 0]``.

        ``denoise_batch`` iterates over this list, so generation walks from the
        noisiest timestep down to timestep ``0``.
        """
        return list(range(self.beta_scheduler.steps))[::-1]

    def get_timestep(self, number_of_images: int, idx: int) -> Timestep:
        """Build a ``Timestep`` that holds the same index for every image.

        Args:
            number_of_images: Batch size, i.e. the length of the timestep tensor.
            idx: Timestep index to repeat for each image.

        Returns:
            A ``Timestep`` whose ``current`` is a 1-D tensor of shape
            ``(number_of_images,)`` filled with ``idx`` on ``self.device``.
        """
        # Explicit long dtype: this tensor is used as a ``gather`` index, which
        # requires int64.
        timestep = torch.full(
            (number_of_images,), idx, dtype=torch.long, device=self.device
        )
        return Timestep(current=timestep)

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "GaussianDiffuser":
        """Instantiate a Gaussian Diffuser from a training checkpoint.

        Args:
            checkpoint: The training checkpoint object containing
                the trained model's parameters and configuration.

        Returns:
            An instance of the GaussianDiffuser class initialized with the parameters
            loaded from the given checkpoint.
        """
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=checkpoint.beta_scheduler_config.steps,
                betas=checkpoint.beta_scheduler_config.betas,
                alpha_bars=checkpoint.beta_scheduler_config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "GaussianDiffuser":
        """Moves the data to the specified device.

        This behaves like PyTorch's ``to`` method: it moves the
        GaussianDiffuser and its BetaScheduler to the specified device.

        Args:
            device: The device to which the method should move the object.
                Default is "cpu".

        Returns:
            ``self``, so the call can be chained or reassigned.

        Example:
            >>> gaussian_diffuser = GaussianDiffuser(beta_scheduler=scheduler)
            >>> gaussian_diffuser = gaussian_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    @staticmethod
    def _expand_per_sample(values: torch.Tensor, ndim: int) -> torch.Tensor:
        """Reshape per-sample values ``(B,)`` to ``(B, 1, ..., 1)`` with ``ndim`` dims.

        This lets the result broadcast against a batch of any rank.
        """
        return values.reshape((-1, *((1,) * (ndim - 1))))

    def _diffuse_batch(
        self, images: torch.Tensor, timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Noise ``images`` to the given per-sample ``timesteps``.

        Args:
            images: Clean batch, shape ``(B, ...)``.
            timesteps: Integer (int64) timestep per image, shape ``(B,)``.

        Returns:
            ``(noisy_images, noise, timesteps)``.
        """
        noise = torch.randn_like(images, device=self.device)

        alpha_bar_t = self._expand_per_sample(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps),
            images.ndim,
        )

        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Diffuse the given batch of images by adding noise based on the beta
        scheduler. Each image gets its own timestep, drawn uniformly from
        ``[0, beta_scheduler.steps)``.

        Args:
            images: Batch of images to diffuse.
                Shape ``(B, ...)``, typically ``(B, C, H, W)``.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        return self._diffuse_batch(images, timesteps)

    @torch.no_grad()
    def _denoise_step(
        self, images: torch.Tensor, model: torch.nn.Module, timestep: Timestep
    ) -> torch.Tensor:
        """Run one reverse-diffusion step from ``t`` to ``t - 1``.

        Args:
            images: Current noisy batch ``x_t``, shape ``(B, ...)``.
            model: Noise-prediction network, called as ``model(images, t)``.
            timestep: Timestep whose ``current`` holds ``t`` for each image.

        Returns:
            The batch ``x_{t-1}``. No noise is added when ``t == 0``.
        """
        current_timestep = timestep.current
        ndim = images.ndim
        # Per-sample schedule values, reshaped to broadcast over any image rank
        # (not just 4-D inputs).
        beta_t = self._expand_per_sample(
            self.beta_scheduler.betas.gather(dim=0, index=current_timestep), ndim
        )
        alpha_t = 1 - beta_t
        alpha_bar_t = self._expand_per_sample(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=current_timestep),
            ndim,
        )
        mu = (1 / torch.sqrt(alpha_t)) * (
            images
            - model(images, current_timestep)
            * (beta_t / torch.sqrt(1 - alpha_bar_t))
        )

        # ``get_timestep`` fills every entry with the same index, so checking
        # the first entry decides the whole batch.
        if current_timestep[0] == 0:
            return mu
        sigma = torch.sqrt(beta_t) * torch.randn_like(images)
        return mu + sigma

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        This denoises a batch of images, which is the image generation process.
        The batch is stepped through every timestep in ``self.steps``, and the
        result is clamped to ``[-1, 1]`` after each step.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.

        Returns:
            A list with one tensor per denoising step, in step order. The last
            entry is the final, fully denoised batch (timestep 0).
        """
        denoised_images = []
        for i in tqdm(self.steps, desc="Denoising"):
            timestep = self.get_timestep(images.shape[0], idx=i)
            images = self._denoise_step(images, model=model, timestep=timestep)
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images