1import abc
2from typing import List, Tuple
3
4import torch
5from torch import nn
6
7from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
8
9
[docs]
class BaseDiffusionModel(nn.Module, abc.ABC):
    def __init__(self, diffuser: BaseDiffuser):
        """Initializes the object with the specified diffuser.

        BaseDiffusionModel is an abstract base class for different diffusion models
        implementations. It defines the interface that all diffusion models should
        adhere to.

        Args:
            diffuser: The diffuser to use for the diffusion model.

        Warnings:
            Do not instantiate this class directly. Instead, build your own diffusion
            model by inheriting from BaseDiffusionModel.
            (see :class:`~.SimpleUnet.SimpleUnet`)

        """
        super().__init__()
        self.diffuser: BaseDiffuser = diffuser
        """A diffuser to be used by the diffusion model."""

    def diffuse(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Diffuse a batch of images.

        Delegates to ``self.diffuser.diffuse_batch``.

        Args:
            images: A tensor containing a batch of images.

        Returns:
            A tuple containing three tensors

            - images: Diffused batch of images.
            - noise: Noise added to the images.
            - timesteps: Timesteps used for diffusion.
        """
        return self.diffuser.diffuse_batch(images=images)

    def denoise(self, images: torch.Tensor) -> List[torch.Tensor]:
        """Denoise a batch of images.

        Delegates to ``self.diffuser.denoise_batch``, passing this model so the
        diffuser can call it to predict noise.

        Args:
            images: A tensor containing a batch of images to denoise.

        Returns:
            A list of tensors containing a batch of denoised images.
        """
        return self.diffuser.denoise_batch(images=images, model=self)

    @abc.abstractmethod
    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        """Forward pass of the diffusion model.

        The forward pass of the diffusion model, predicting the noise at a single
        step. Subclasses must implement this.

        Args:
            x: A batch of noisy images.
            timestep: The timesteps of each image in the batch.

        Returns:
            A tensor representing the noise predicted for each image.
        """
        raise NotImplementedError

    def to(self, device: str = "cpu") -> "BaseDiffusionModel":
        """Moves the model to the specified device.

        This performs a similar behaviour to the `to` method of PyTorch, moving the
        DiffusionModel and all related artifacts (including the diffuser) to the
        specified device.

        Args:
            device: The device to which the method should move the object.
                Default is "cpu".

        Returns:
            The model returned by ``nn.Module.to``, with its diffuser replaced by
            the result of ``diffuser.to(device)``.
        """
        new_self = super().to(device)
        # The diffuser is moved separately via its own ``to`` method and the
        # returned object is stored back on the model.
        new_self.diffuser = new_self.diffuser.to(device)
        return new_self

    def compile(self, *args, **kwargs):
        """Compiles the diffusion model.

        This performs a similar behaviour to the `compile` method of PyTorch: the
        model is wrapped with :func:`torch.compile`, forwarding all arguments.

        The wrapper's ``to`` is replaced so that it moves this model and its
        diffuser (via :meth:`BaseDiffusionModel.to`) and then returns the compiled
        wrapper itself. Otherwise ``compiled.to(device)`` would hand back the
        original, uncompiled module and drop the compilation.

        Returns:
            A compiled diffusion model.
        """
        compiled = torch.compile(self, *args, **kwargs)

        def _to(*to_args, **to_kwargs):
            # Move parameters and the diffuser through the original module (which
            # the compiled wrapper shares), then return the wrapper so that
            # ``model = model.to(device)`` keeps the compiled model.
            self.to(*to_args, **to_kwargs)
            return compiled

        compiled.to = _to
        return compiled