Source code for diffusion_models.gaussian_diffusion.beta_schedulers

import abc
import logging
from typing import Optional

import torch


class BaseBetaScheduler:
    """Base class holding the (betas, alpha_bars) tensors of a diffusion noise schedule."""

    def __init__(self, steps: int, enforce_zero_terminal_snr: bool = False):
        super().__init__()
        self.steps = steps
        self.betas = self.sample_betas()
        self.alpha_bars = self.compute_alpha_bar()

        if enforce_zero_terminal_snr:
            self.enforce_zero_terminal_snr()

    def enforce_zero_terminal_snr(self):
        """Rescale the schedule so the signal-to-noise ratio at the final timestep is zero."""
        alpha_bar_length = len(self.alpha_bars)

        # Convert betas to alphas_bar_sqrt.
        alphas = 1 - self.betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so the first timestep is back to its old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert alphas_bar_sqrt back to betas.
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        if len(alphas) == alpha_bar_length:
            self.betas = betas
            self.alpha_bars = alphas_bar
        else:
            logging.warning(
                "Got a different alpha_bar length after enforcing zero terminal "
                "SNR; please check your beta scheduler."
            )
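    # Why the rescaling above works: writing a_t = sqrt(alpha_bar_t), the
    # affine map a_t -> (a_t - a_T) * a_0 / (a_0 - a_T) sends a_T to exactly
    # zero while mapping a_0 to itself, so the start of the schedule is
    # preserved. Since SNR(t) = alpha_bar_t / (1 - alpha_bar_t), alpha_bar_T = 0
    # means the final timestep carries no signal. Squaring recovers alpha_bar,
    # and the ratios alpha_bar_t / alpha_bar_{t-1} recover per-step alphas/betas.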
    @abc.abstractmethod
    def sample_betas(self):
        pass

    @abc.abstractmethod
    def compute_alpha_bar(self):
        pass

    def to(self, device: str):
        self.betas = self.betas.to(device)
        self.alpha_bars = self.alpha_bars.to(device)
        return self
    @classmethod
    def from_tensors(cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor):
        # Build a placeholder instance, then overwrite its schedule tensors.
        generic_beta_scheduler = cls(0)
        generic_beta_scheduler.steps = steps
        generic_beta_scheduler.betas = betas
        generic_beta_scheduler.alpha_bars = alpha_bars
        return generic_beta_scheduler

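# Usage sketch (illustrative; the tensor values are placeholders): rebuild a
# scheduler from precomputed tensors, e.g. a schedule restored from a checkpoint.
#
#     betas = torch.linspace(1e-4, 0.02, 10)
#     alpha_bars = torch.cumprod(1 - betas, dim=0)
#     scheduler = BaseBetaScheduler.from_tensors(
#         steps=10, betas=betas, alpha_bars=alpha_bars
#     )
#     assert len(scheduler.betas) == len(scheduler.alpha_bars) == scheduler.steps
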
class LinearBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        steps: int = 1000,
        enforce_zero_terminal_snr: bool = True,
    ):
        self.beta_start = beta_start
        self.beta_end = beta_end
        super().__init__(
            steps=steps,
            enforce_zero_terminal_snr=enforce_zero_terminal_snr,
        )
    def sample_betas(self):
        return torch.linspace(self.beta_start, self.beta_end, self.steps)

    def compute_alpha_bar(self):
        alphas = 1 - self.betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        return alpha_bar

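# Usage sketch (illustrative): the linear schedule with its defaults, moved to
# a device for training. Because enforce_zero_terminal_snr defaults to True
# here, the terminal alpha_bar is driven to exactly zero.
#
#     scheduler = LinearBetaScheduler(steps=1000).to("cpu")
#     assert scheduler.alpha_bars[-1] == 0
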
class CosineBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        offset: float = 0.008,
        steps: int = 1000,
        max_beta: Optional[float] = 0.999,
    ):
        self.offset = offset
        self.max_beta = max_beta
        # `steps` is needed by f() before the base __init__ runs, so set it here.
        self.steps = steps
        self._alpha_bars = self._compute_alpha_bar()
        self._betas = self._compute_betas()

        super().__init__(
            steps=steps,
        )
    def f(self, t: torch.Tensor):
        # Squared-cosine profile: f(t) = cos^2(((t / steps + offset) / (1 + offset)) * pi / 2).
        return (
            torch.cos(
                (((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
            )
            ** 2
        )
    def _compute_betas(self):
        betas = 1 - self._alpha_bars[1:] / self._alpha_bars[:-1]
        if self.max_beta is not None:
            betas = torch.clip(betas, max=self.max_beta)
        return betas

    def _compute_alpha_bar(self):
        # Evaluate f on steps + 1 grid points (t = 0, ..., steps) so that the
        # pairwise ratio in _compute_betas yields exactly `steps` betas.
        t = torch.linspace(0, self.steps, self.steps + 1, dtype=torch.float32)
        return self.f(t) / self.f(torch.tensor([0], dtype=torch.float32))
    def sample_betas(self):
        return self._betas

    def compute_alpha_bar(self):
        # Drop alpha_bar_0 (== 1 by construction) so that alpha_bars[t] lines
        # up with betas[t] for t = 0, ..., steps - 1.
        return self._alpha_bars[1:]
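# Runnable smoke test (illustrative, not part of the original module): builds
# both schedulers and prints the terminal alpha_bar of each, showing the effect
# of the zero-terminal-SNR rescaling on the linear schedule.

if __name__ == "__main__":
    linear = LinearBetaScheduler(steps=1000)
    cosine = CosineBetaScheduler(steps=1000)
    print(f"linear: {len(linear.betas)} betas, alpha_bar_T = {linear.alpha_bars[-1]:.6f}")
    print(f"cosine: {len(cosine.betas)} betas, alpha_bar_T = {cosine.alpha_bars[-1]:.6f}")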