Source code for diffusion_models.gaussian_diffusion.gaussian_diffuser

  1from typing import List, Tuple, TYPE_CHECKING
  2
  3import torch
  4from tqdm import tqdm
  5
  6from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
  7from diffusion_models.gaussian_diffusion.beta_schedulers import (
  8  BaseBetaScheduler,
  9)
 10from diffusion_models.utils.schemas import Checkpoint
 11
 12if TYPE_CHECKING:
 13  from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
 14
class GaussianDiffuser(BaseDiffuser):
    """DDPM-style Gaussian diffuser.

    Provides the closed-form forward (noising) process used for training and the
    iterative reverse (denoising) process used for sampling. Both are driven by the
    ``betas`` / ``alpha_bars`` exposed by a :class:`BaseBetaScheduler`.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the class instance.

        Args:
            beta_scheduler (BaseBetaScheduler): The beta scheduler instance to be used.

        """
        super().__init__(beta_scheduler)
        self.device: str = "cpu"
        """The device to use. Defaults to cpu."""

    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "GaussianDiffuser":
        """Instantiate a Gaussian Diffuser from a training checkpoint.

        Args:
            checkpoint: The training checkpoint object containing
                the trained model's parameters and configuration.

        Returns:
            An instance of the GaussianDiffuser class initialized with the parameters
            loaded from the given checkpoint.
        """
        return cls(
            beta_scheduler=BaseBetaScheduler.from_tensors(
                steps=checkpoint.beta_scheduler_config.steps,
                betas=checkpoint.beta_scheduler_config.betas,
                alpha_bars=checkpoint.beta_scheduler_config.alpha_bars,
            )
        )

    def to(self, device: str = "cpu") -> "GaussianDiffuser":
        """Moves the diffuser to the specified device.

        This mirrors the ``to`` method of PyTorch. It records the target device,
        which is used when sampling timesteps and noise, and moves the beta
        scheduler to that device.

        Args:
            device: The device to which the object should be moved.
                Defaults to ``"cpu"``.

        Returns:
            ``self``, so calls can be chained.

        Example:
            >>> gaussian_diffuser = GaussianDiffuser(beta_scheduler=scheduler)
            >>> gaussian_diffuser = gaussian_diffuser.to(device="cuda")
        """
        self.device = device
        self.beta_scheduler = self.beta_scheduler.to(self.device)
        return self

    @staticmethod
    def _per_sample(values: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Reshape a per-sample ``(B,)`` tensor to ``(B, 1, ..., 1)`` matching ``like``'s rank."""
        # Works for any input rank, unlike a hard-coded (-1, 1, 1, 1) which
        # silently mis-broadcasts for non-4-D inputs.
        return values.reshape((-1,) + (1,) * (like.dim() - 1))

    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Diffuse the given batch of images by adding noise based on the beta scheduler.
        One timestep is drawn uniformly per sample, and the sample is moved directly
        to that timestep with the closed-form forward process
        ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            images: Batch of images to diffuse. The first dimension is the batch;
                typically the shape is ``(B, C, H, W)``.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion, shape ``(B,)``.
        """
        timesteps = torch.randint(
            0, self.beta_scheduler.steps, (images.shape[0],), device=self.device
        )
        noise = torch.randn_like(images, device=self.device)

        alpha_bar_t = self._per_sample(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=timesteps), images
        )

        mu = torch.sqrt(alpha_bar_t)
        sigma = torch.sqrt(1 - alpha_bar_t)
        images = mu * images + sigma * noise
        return images, noise, timesteps

    @torch.no_grad()
    def _denoise_step(
        self, images: torch.Tensor, model: torch.nn.Module, timestep: torch.Tensor
    ) -> torch.Tensor:
        """Run a single reverse-diffusion step ``x_t -> x_{t-1}``.

        Args:
            images: Batch of noisy images ``x_t``. The first dimension is the batch.
            model: Noise-prediction model, called as ``model(images, timestep)``.
            timestep: ``(B,)`` integer tensor holding the current timestep of each
                sample. All entries are assumed to be equal, as produced by
                ``denoise_batch``; only ``timestep[0]`` decides whether noise is
                added.

        Returns:
            The batch ``x_{t-1}``. At the final step (``t == 0``) the mean is
            returned without adding noise.
        """
        beta_t = self._per_sample(self.beta_scheduler.betas[timestep], images)
        alpha_t = 1 - beta_t
        alpha_bar_t = self._per_sample(
            self.beta_scheduler.alpha_bars.gather(dim=0, index=timestep), images
        )
        mu = (1 / torch.sqrt(alpha_t)) * (
            images - model(images, timestep) * (beta_t / torch.sqrt(1 - alpha_bar_t))
        )

        if timestep[0] == 0:
            return mu
        sigma = torch.sqrt(beta_t) * torch.randn_like(images)
        return mu + sigma

    def denoise_batch(
        self,
        images: torch.Tensor,
        model: "BaseDiffusionModel",
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        This is the image generation process. It runs the reverse process from the
        last timestep down to ``0``, one ``_denoise_step`` per timestep.

        Args:
            images: A batch of noisy images.
            model: The model to be used for denoising.

        Returns:
            A list with one tensor per timestep, ordered from ``t = steps - 1`` down to
            ``t = 0``. The last element is the final denoised batch. Every
            intermediate batch is kept, so memory grows linearly with the number of
            steps.
        """
        denoised_images = []
        for i in tqdm(range(self.beta_scheduler.steps)[::-1], desc="Denoising"):
            # Explicit int64: the timestep is used as a `gather` index downstream.
            timestep = torch.full(
                (images.shape[0],), i, dtype=torch.long, device=self.device
            )
            images = self._denoise_step(images, model=model, timestep=timestep)
            # NOTE(review): this clamps every intermediate x_t, not only the final
            # sample. When alpha_bar_t is small, x_t is close to unit Gaussian noise,
            # so values outside [-1, 1] are expected there. Behaviour kept as-is.
            images = torch.clamp(images, -1.0, 1.0)
            denoised_images.append(images)
        return denoised_images