Source code for diffusion_models.gaussian_diffusion.base_diffuser

 1import abc
 2from typing import List, TYPE_CHECKING
 3from typing import Tuple
 4
 5import torch
 6
 7from diffusion_models.gaussian_diffusion.beta_schedulers import (
 8  BaseBetaScheduler,
 9)
10from diffusion_models.utils.schemas import Timestep
11
12if TYPE_CHECKING:
13  from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
14
15
class BaseDiffuser(abc.ABC):
    """Abstract interface shared by all diffuser implementations.

    A diffuser adds noise to a batch of images (:meth:`diffuse_batch`) and
    removes it again with a diffusion model (:meth:`denoise_batch`).
    Concrete diffusers must implement every abstract member of this class.

    Warnings:
        Do not instantiate this class directly. Instead, build your own
        Diffuser by inheriting from BaseDiffuser
        (see :class:`~.gaussian_diffuser.GaussianDiffuser`).
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the object with the specified beta scheduler.

        Args:
            beta_scheduler: The beta scheduler used by the diffuser.
        """
        self.beta_scheduler: BaseBetaScheduler = beta_scheduler
        """The beta scheduler used by the diffuser."""

    @property
    @abc.abstractmethod
    def steps(self) -> List[int]:
        """Returns the list of steps used in the denoising process.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_timestep(self, number_of_images: int, idx: int) -> Timestep:
        """Get timestep information used for denoising.

        Args:
            number_of_images: Number of images the timestep information is
                built for (presumably the batch size; confirm against the
                concrete implementations).
            idx: Index of the requested denoising step (presumably an index
                into :attr:`steps`; confirm against the concrete
                implementations).

        Returns:
            The timestep information for the requested step.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Args:
            images: A tensor containing a batch of images.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def denoise_batch(
        self, images: torch.Tensor, model: "BaseDiffusionModel"
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        Args:
            images: A tensor containing a batch of images to denoise.
            model: The model to be used for denoising.

        Returns:
            A list of tensors containing a batch of denoised images.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def to(self, device: str = "cpu"):
        """Moves the data to the specified device.

        This performs a similar behaviour to the `to` method of PyTorch.

        Args:
            device: The device to which the method should move the data.
                It should be a string indicating the desired device.
                Default is "cpu".

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError