1import abc
2from typing import List
3from typing import Tuple
4
5import torch
6
7from diffusion_models.gaussian_diffusion.beta_schedulers import (
8 BaseBetaScheduler,
9)
10
11
class BaseDiffuser(abc.ABC):
    """Abstract interface shared by diffuser implementations.

    A concrete diffuser is constructed with a beta scheduler, which is kept
    on the instance as ``self.beta_scheduler``. It must override all three
    abstract methods: ``diffuse_batch``, ``denoise_batch`` and ``to``.
    Because every one of them is an ``abc.abstractmethod``, this class (and
    any subclass that leaves one unimplemented) cannot be instantiated.
    """

    def __init__(self, beta_scheduler: BaseBetaScheduler):
        """Keep *beta_scheduler* on the instance for use by subclasses.

        Args:
            beta_scheduler: Scheduler object, stored unchanged as
                ``self.beta_scheduler``.
        """
        self.beta_scheduler = beta_scheduler

    @abc.abstractmethod
    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the diffusion step to a batch of ``images``.

        Subclasses must return a tuple of exactly three tensors. Their
        meaning is defined by the concrete implementation.
        NOTE(review): presumably (noised images, sampled noise, timesteps);
        confirm against the subclasses before relying on that order.
        """

    @abc.abstractmethod
    def denoise_batch(self, images: torch.Tensor, model) -> List[torch.Tensor]:
        """Denoise a batch of ``images`` using ``model``.

        Subclasses must return a list of tensors.
        NOTE(review): the type of ``model`` and the meaning of each list
        entry (e.g. one tensor per denoising step?) are not visible here;
        confirm against the concrete implementations.
        """

    @abc.abstractmethod
    def to(self, device: str = "cpu"):
        """Move the diffuser to ``device`` (``"cpu"`` when omitted).

        NOTE(review): what "moving" entails (e.g. relocating scheduler
        tensors) and whether ``self`` is returned is up to each subclass.
        """