Source code for diffusion_models.gaussian_diffusion.beta_schedulers

import abc
import logging
from typing import Optional

import torch


class BaseBetaScheduler:
    def __init__(self, steps: int, enforce_zero_terminal_snr: bool = False):
        """Initializes a beta scheduler.

        BaseBetaScheduler is an abstract base class for different beta scheduler
        implementations. It defines the interface that all beta schedulers should
        adhere to.

        Args:
            steps: The number of steps for the beta scheduler.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR in line
                with `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
                <https://arxiv.org/abs/2305.08891>`_. Defaults to ``False``.

        Warnings:
            Do not instantiate this class directly. Instead, build your own beta
            scheduler by inheriting from BaseBetaScheduler
            (see :class:`~.LinearBetaScheduler`).
        """
        super().__init__()
        self.steps: int = steps
        """The number of steps for the beta scheduler."""
        self.betas = self.sample_betas()
        """The :math:`\\beta` computed according to :meth:`~.BaseBetaScheduler.sample_betas`."""
        self.alpha_bars = self.compute_alpha_bar()
        """The :math:`\\bar{\\alpha}` computed according to :meth:`~.BaseBetaScheduler.compute_alpha_bar`."""

        if enforce_zero_terminal_snr:
            self.enforce_zero_terminal_snr()

    def enforce_zero_terminal_snr(self):
        """Enforce zero terminal SNR by adjusting :math:`\\beta` and :math:`\\bar{\\alpha}`.

        This method enforces zero terminal SNR according to
        `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
        <https://arxiv.org/abs/2305.08891>`_.
        """
        alpha_bar_length = len(self.alpha_bars)

        # Convert betas to alphas_bar_sqrt.
        alphas = 1 - self.betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_t
        # Scale so the first timestep is back to its old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
            alphas_bar_sqrt_0 - alphas_bar_sqrt_t
        )

        # Convert alphas_bar_sqrt back to betas.
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        if len(alphas) == alpha_bar_length:
            self.betas = betas
            self.alpha_bars = alphas_bar
        else:
            logging.warning(
                "Got a different alpha_bar length after enforcing zero terminal "
                "SNR. Please check your beta scheduler."
            )
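
    # Note: on success, ``alpha_bars[-1]`` is exactly zero (zero terminal SNR)
    # while ``alpha_bars[0]`` keeps its original value.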

    @abc.abstractmethod
    def sample_betas(self) -> torch.Tensor:
        """Compute :math:`\\beta` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\beta` values.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def compute_alpha_bar(self) -> torch.Tensor:
        """Compute :math:`\\bar{\\alpha}` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\bar{\\alpha}` values.
        """
        raise NotImplementedError()

    def to(self, device: str):
        """Moves the beta scheduler to the given device.

        Args:
            device: The device to which the scheduler's tensors should be moved,
                e.g. ``"cpu"`` or ``"cuda"``.

        Returns:
            The beta scheduler itself, with ``betas`` and ``alpha_bars`` moved to
            the given device.
        """
        self.betas = self.betas.to(device)
        self.alpha_bars = self.alpha_bars.to(device)
        return self

    @classmethod
    def from_tensors(
        cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
    ):
        """Instantiate a beta scheduler from tensors.

        This is particularly useful for loading checkpoints.

        Args:
            steps: The number of steps for the beta scheduler.
            betas: The pre-computed beta values for the noise scheduler.
            alpha_bars: The pre-computed alpha bar values for the noise scheduler.

        Returns:
            A beta scheduler whose ``steps``, ``betas``, and ``alpha_bars`` are
            taken from the given tensors.
        """
        # Build a throwaway instance, then overwrite its schedule tensors.
        generic_beta_scheduler = cls(0)
        generic_beta_scheduler.steps = steps
        generic_beta_scheduler.betas = betas
        generic_beta_scheduler.alpha_bars = alpha_bars
        return generic_beta_scheduler
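

# The docstring above recommends building custom schedulers by subclassing
# BaseBetaScheduler. As an illustrative sketch (hypothetical, not part of the
# original module), a constant-beta schedule could look like this:
class ConstantBetaScheduler(BaseBetaScheduler):
    def __init__(self, beta: float = 0.01, steps: int = 1000):
        self.beta = beta
        super().__init__(steps=steps)

    def sample_betas(self) -> torch.Tensor:
        # Every timestep uses the same beta.
        return torch.full((self.steps,), self.beta)

    def compute_alpha_bar(self) -> torch.Tensor:
        # alpha_bar_t is the running product of (1 - beta_i) up to step t.
        return torch.cumprod(1 - self.betas, dim=0)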


class LinearBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        steps: int = 1000,
        enforce_zero_terminal_snr: bool = True,
    ):
        """A linear beta scheduler.

        A simple beta scheduler with betas linearly spaced between
        ``beta_start`` and ``beta_end``.

        Args:
            beta_start: The starting value of the betas.
            beta_end: The end value of the betas.
            steps: The number of steps for the beta scheduler. This is also the
                number of betas.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
        """
        self.beta_start: float = beta_start
        """The starting value of the betas."""
        self.beta_end: float = beta_end
        """The end value of the betas."""
        super().__init__(
            steps=steps,
            enforce_zero_terminal_snr=enforce_zero_terminal_snr,
        )

    def sample_betas(self) -> torch.Tensor:
        """Return linearly spaced betas between ``self.beta_start`` and ``self.beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.steps)

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return :math:`\\bar{\\alpha}` computed from the beta values."""
        alphas = 1 - self.betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        return alpha_bar
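

# Usage sketch (illustrative): with the defaults, the linear scheduler yields
# one beta per step and, since zero terminal SNR is enforced by default, the
# final alpha_bar is exactly zero:
#
#   scheduler = LinearBetaScheduler(steps=1000)
#   assert scheduler.betas.shape == (1000,)
#   assert scheduler.alpha_bars[-1] == 0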


class CosineBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        offset: float = 0.008,
        steps: int = 1000,
        max_beta: Optional[float] = 0.999,
    ):
        """A cosine beta scheduler.

        A cosine beta scheduler based on the following formulas:

        .. math::
            :nowrap:

            \\begin{equation}
            \\left\\{ \\begin{aligned}
            \\bar{\\alpha}_t &= \\frac{f(t)}{f(0)} \\\\
            \\beta_t &= 1 - \\frac{\\bar{\\alpha}_t}{\\bar{\\alpha}_{t-1}}
            \\end{aligned} \\right.
            \\end{equation}

        where

        .. math::

            f(t) = \\cos\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right)^2

        with :math:`s` the offset and :math:`T` the number of steps.

        Args:
            offset: The offset :math:`s` defined above.
            steps: The number of steps for the beta scheduler.
            max_beta: The maximum beta value. Larger values are clipped to it.
        """
        self.offset: float = offset
        """The offset :math:`s` defined above."""
        self.max_beta: Optional[float] = max_beta
        """The maximum beta value. Larger values are clipped to it."""
        self.steps: int = steps
        """The number of steps for the beta scheduler."""
        # ``steps`` must be set before computing the schedule, since ``f`` uses it.
        self._alpha_bars = self._compute_alpha_bar()
        self._betas = self._compute_betas()

        super().__init__(
            steps=steps,
        )

    def f(self, t: torch.Tensor) -> torch.Tensor:
        """A helper function to compute :math:`\\bar{\\alpha}_t`.

        Args:
            t: The timesteps to evaluate.

        Returns:

            .. math::

                f(t) = \\cos\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right)^2

        """
        return (
            torch.cos(
                (((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
            )
            ** 2
        )

    def _compute_betas(self):
        # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, which yields exactly
        # ``steps`` betas from the ``steps + 1`` alpha_bar values.
        betas = 1 - self._alpha_bars[1:] / self._alpha_bars[:-1]
        if self.max_beta is not None:
            betas = torch.clip(betas, max=self.max_beta)
        return betas

    def _compute_alpha_bar(self):
        # Evaluate f at t = 0, 1, ..., T (``steps + 1`` points); using only
        # ``steps`` points would leave the schedule one beta short.
        t = torch.linspace(0, self.steps, self.steps + 1, dtype=torch.float32)
        return self.f(t) / self.f(torch.tensor([0], dtype=torch.float32))

    def sample_betas(self) -> torch.Tensor:
        """Return the pre-computed :math:`\\beta` values."""
        return self._betas

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return the pre-computed :math:`\\bar{\\alpha}` values."""
        # Drop alpha_bar_0 (which is 1) so the exposed tensor aligns with the betas.
        return self._alpha_bars[1:]
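

if __name__ == "__main__":
    # Illustrative smoke test, a minimal sketch (not part of the original
    # module): build both schedulers and check that betas and alpha_bars agree
    # in length.
    for name, scheduler in [
        ("linear", LinearBetaScheduler(steps=1000)),
        ("cosine", CosineBetaScheduler(steps=1000)),
    ]:
        assert scheduler.betas.shape == scheduler.alpha_bars.shape
        print(name, tuple(scheduler.betas.shape), scheduler.alpha_bars[-1].item())

    # Schedulers move between devices like modules do.
    if torch.cuda.is_available():
        scheduler = LinearBetaScheduler().to("cuda")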