# Source code for diffusion_models.gaussian_diffusion.base_diffuser

import abc
from typing import TYPE_CHECKING, List, Tuple

import torch

from diffusion_models.gaussian_diffusion.beta_schedulers import (
    BaseBetaScheduler,
)

if TYPE_CHECKING:
    from diffusion_models.models.base_diffusion_model import BaseDiffusionModel

class BaseDiffuser(abc.ABC):
    """Abstract interface shared by all diffuser implementations.

    Subclasses must implement :meth:`diffuse_batch`, :meth:`denoise_batch`
    and :meth:`to`; because they are declared with ``abc.abstractmethod``,
    a subclass that leaves any of them unimplemented cannot be instantiated.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Initializes the object with the specified beta scheduler.

        BaseDiffuser is an abstract base class for different diffuser
        implementations. It defines the interface that all diffusers should
        adhere to.

        Args:
            beta_scheduler: The beta scheduler used by the diffuser.

        Warnings:
            Do not instantiate this class directly. Instead, build your own
            Diffuser by inheriting from BaseDiffuser.
            (see :class:`~.gaussian_diffuser.GaussianDiffuser`)
        """
        self.beta_scheduler: BaseBetaScheduler = beta_scheduler
        # Attribute docstring kept on purpose: Sphinx autodoc renders it.
        """The beta scheduler used by the diffuser."""

    @abc.abstractmethod
    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Args:
            images: A tensor containing a batch of images.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.

        Raises:
            NotImplementedError: If a subclass calls this base version
                (e.g. via ``super()``) instead of overriding it.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def denoise_batch(
        self, images: torch.Tensor, model: "BaseDiffusionModel"
    ) -> List[torch.Tensor]:
        """Denoise a batch of images.

        Args:
            images: A tensor containing a batch of images to denoise.
            model: The model to be used for denoising.

        Returns:
            A list of tensors containing a batch of denoised images.

        Raises:
            NotImplementedError: If a subclass calls this base version
                (e.g. via ``super()``) instead of overriding it.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def to(self, device: str = "cpu"):
        """Moves the data held by the diffuser to the specified device.

        This performs a similar behaviour to the `to` method of PyTorch.

        Args:
            device: The device to which the method should move the data.
                It should be a string indicating the desired device.
                Default is "cpu".

        Raises:
            NotImplementedError: If a subclass calls this base version
                (e.g. via ``super()``) instead of overriding it.
        """
        raise NotImplementedError()