Source code for diffusion_models.gaussian_diffusion.base_diffuser

 1import abc
 2from typing import List
 3from typing import Tuple
 4
 5import torch
 6
 7from diffusion_models.gaussian_diffusion.beta_schedulers import (
 8  BaseBetaScheduler,
 9)
10
11
class BaseDiffuser(abc.ABC):
    """Abstract interface for a batch diffuser driven by a beta scheduler.

    The base class itself only keeps a reference to the scheduler handed to
    the constructor. All actual work is delegated to subclasses through the
    three abstract methods below, so this class cannot be instantiated
    directly.
    """

    # Quoted annotation: resolved lazily (e.g. by ``typing.get_type_hints``)
    # against this module's globals, where the scheduler class is imported.
    def __init__(self, beta_scheduler: "BaseBetaScheduler"):
        """Remember the beta scheduler for later use by subclasses.

        Args:
            beta_scheduler: Scheduler object; stored as-is (no validation)
                on ``self.beta_scheduler``.
        """
        self.beta_scheduler = beta_scheduler

    @abc.abstractmethod
    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the diffusion step over a batch of images.

        Args:
            images: Input batch. NOTE(review): expected shape/dtype is not
                fixed here; confirm against subclass implementations.

        Returns:
            A 3-tuple of tensors whose individual meaning is defined by the
            subclass. NOTE(review): plausibly (noisy images, noise,
            timesteps), which is unverified here.
        """

    @abc.abstractmethod
    def denoise_batch(self, images: torch.Tensor, model) -> List[torch.Tensor]:
        """Run the reverse (denoising) procedure on a batch using ``model``.

        Args:
            images: Starting batch for the reverse procedure.
            model: Model object used by the subclass; its required interface
                is not specified by this base class.

        Returns:
            A list of tensors. NOTE(review): presumably intermediate and/or
            final outputs of the reverse procedure, so confirm against
            implementations.
        """

    @abc.abstractmethod
    def to(self, device: str = "cpu"):
        """Move any device-bound state to ``device`` (defaults to ``"cpu"``).

        The return value is not constrained by this interface.
        """