1import abc
2from typing import List, TYPE_CHECKING
3from typing import Tuple
4
5import torch
6
7from diffusion_models.gaussian_diffusion.beta_schedulers import (
8 BaseBetaScheduler,
9)
10from diffusion_models.utils.schemas import Timestep
11
12if TYPE_CHECKING:
13 from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
14
15
[docs]
class BaseDiffuser(abc.ABC):
    """Abstract interface that every diffuser implementation adheres to.

    Subclasses provide the ``steps`` property together with the
    ``get_timestep``, ``diffuse_batch``, ``denoise_batch`` and ``to`` methods.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Store the beta scheduler the diffuser works with.

        Args:
          beta_scheduler: The beta scheduler used by the diffuser.

        Warnings:
          This class is abstract and is not meant to be instantiated
          directly. Build your own diffuser by inheriting from BaseDiffuser
          (see :class:`~.gaussian_diffuser.GaussianDiffuser`).
        """
        self.beta_scheduler: BaseBetaScheduler = beta_scheduler
        """The beta scheduler used by the diffuser."""

    @property
    @abc.abstractmethod
    def steps(self) -> List[int]:
        """List of steps used in the denoising process."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_timestep(self, number_of_images: int, idx: int) -> Timestep:
        """Build the timestep information used for denoising.

        Args:
          number_of_images: Presumably the number of images the timestep is
            built for -- confirm against the concrete diffusers.
          idx: Presumably an index identifying the denoising step -- confirm
            against the concrete diffusers.

        Returns:
          The timestep information as a ``Timestep``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply diffusion to a batch of images.

        Args:
          images: A tensor containing a batch of images.

        Returns:
          A tuple of three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def denoise_batch(
        self, images: torch.Tensor, model: "BaseDiffusionModel"
    ) -> List[torch.Tensor]:
        """Remove noise from a batch of images with the given model.

        Args:
          images: A tensor containing a batch of images to denoise.
          model: The model to be used for denoising.

        Returns:
          A list of tensors containing a batch of denoised images.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def to(self, device: str = "cpu"):
        """Move the diffuser's data to the requested device.

        The behaviour is meant to be similar to PyTorch's own ``to`` method.

        Args:
          device: String naming the device the data should be moved to.
            Defaults to ``"cpu"``.
        """
        raise NotImplementedError