Source code for diffusion_models.gaussian_diffusion.beta_schedulers

  1import abc
  2import logging
  3from typing import Optional
  4
  5import torch
  6
  7
class BaseBetaScheduler:
    def __init__(
        self,
        steps: int,
        enforce_zero_terminal_snr: bool = False,
        initialize: bool = True,
    ):
        """Initializes a beta scheduler.

        BaseBetaScheduler is an abstract base class for different beta scheduler
        implementations. It defines the interface that all beta schedulers should
        adhere to.

        Args:
            steps: The number of steps for the beta scheduler.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR inline
                with `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
                <https://arxiv.org/abs/2305.08891>`_. It is only applied when
                ``initialize`` is ``True``, since it rescales the computed betas.
                Defaults to ``False``.
            initialize: Whether to initialize the beta scheduler. If this is
                set to ``False``, you will need to manually set ``self.betas`` and
                ``self.alpha_bars``. Otherwise, they are initialized using your
                ``sample_betas`` and ``compute_alpha_bar`` methods.

        Warnings:
            Do not instantiate this class directly. Instead, build your own Beta
            scheduler by inheriting from BaseBetaScheduler.
            (see :class:`~.LinearBetaScheduler`)
        """
        #: The number of steps for the beta scheduler.
        self.steps: int = steps
        #: The :math:`\beta` computed according to :meth:`~.BaseBetaScheduler.sample_betas`.
        self.betas: Optional[torch.Tensor] = None
        #: The :math:`\bar{\alpha}` computed according to :meth:`~.BaseBetaScheduler.compute_alpha_bar`.
        self.alpha_bars: Optional[torch.Tensor] = None

        if initialize:
            self._initialize()
            # Enforcement rescales self.betas, so it needs an initialized schedule.
            if enforce_zero_terminal_snr:
                self.enforce_zero_terminal_snr()

    def _initialize(self):
        """Populate ``betas`` then ``alpha_bars`` (the latter may read the former)."""
        self.betas = self.sample_betas()
        self.alpha_bars = self.compute_alpha_bar()

    def enforce_zero_terminal_snr(self):
        """Enforce terminal SNR by adjusting :math:`\\beta` and :math:`\\bar{\\alpha}`.

        This method enforces zero terminal SNR according to
        `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
        <https://arxiv.org/abs/2305.08891>`_: the square roots of the cumulative
        alpha products are shifted so the last one is exactly zero, then rescaled
        so the first one keeps its original value. ``self.betas`` and
        ``self.alpha_bars`` are only replaced when the resulting length matches the
        current ``self.alpha_bars``; otherwise a warning is logged and nothing
        changes.
        """
        alpha_bar_length = len(self.alpha_bars)

        # Convert betas to alphas_bar_sqrt
        alphas = 1 - self.betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_t
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2
        # Recover per-step alphas from consecutive ratios; the first step has no
        # predecessor, so its alpha equals the first cumulative product.
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        if len(alphas) == alpha_bar_length:
            self.betas = betas
            self.alpha_bars = alphas_bar
        else:
            logging.warning(
                "Got different alpha_bar length after enforcing zero SNR. Please check your beta scheduler"
            )

    @abc.abstractmethod
    def sample_betas(self) -> torch.Tensor:
        """Compute :math:`\\beta` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\beta` values.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def compute_alpha_bar(self) -> torch.Tensor:
        """Compute :math:`\\bar{\\alpha}` for noise scheduling.

        Called after ``self.betas`` has been set, so implementations may read it.

        Returns:
            A torch tensor of the :math:`\\bar{\\alpha}` values.
        """
        raise NotImplementedError()

    def to(self, device: str = "cpu"):
        """Moves the beta scheduler to the given device.

        Args:
            device: The device to which the method should move the object.
                Default is "cpu".

        Returns:
            ``self``, to allow chaining.
        """
        self.betas = self.betas.to(device)
        self.alpha_bars = self.alpha_bars.to(device)
        return self

    @classmethod
    def from_tensors(
        cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
    ):
        """Instantiate a beta scheduler from tensors.

        Instantiate a beta scheduler from tensors. This is particularly useful for
        loading checkpoints. Works on this base class and on any subclass,
        regardless of the subclass constructor's signature. Only ``steps``,
        ``betas`` and ``alpha_bars`` are set on the returned instance; the
        subclass's own ``__init__`` is not run, so subclass-specific attributes
        are not populated.

        Args:
            steps: The number of steps for the beta scheduler.
            betas: The pre-computed beta values for the noise scheduler.
            alpha_bars: The pre-computed alpha bar values for the noise scheduler.

        Returns:
            An instance of ``cls`` holding the given tensors.
        """
        # Subclass constructors use different signatures (e.g. beta_start first,
        # no ``initialize`` keyword), so ``cls(steps, initialize=False)`` raises
        # TypeError for them. Allocate directly and run only the base initializer.
        generic_beta_scheduler = cls.__new__(cls)
        BaseBetaScheduler.__init__(generic_beta_scheduler, steps, initialize=False)
        generic_beta_scheduler.betas = betas
        generic_beta_scheduler.alpha_bars = alpha_bars
        return generic_beta_scheduler
142 143
class LinearBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        steps: int = 1000,
        enforce_zero_terminal_snr: bool = True,
    ):
        """A Linear Beta scheduler.

        A simple linear beta scheduler with betas linearly spaced between
        ``beta_start`` and ``beta_end``.

        Args:
            beta_start: The starting value of the betas.
            beta_end: The end value of the betas.
            steps: The number of steps for the beta scheduler. This is also the number
                of betas.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
                Defaults to ``True``.
        """
        #: The starting value of the betas.
        self.beta_start: float = beta_start
        #: The end value of the betas.
        self.beta_end: float = beta_end
        # These must be set before the base initializer runs, because it calls
        # sample_betas(), which reads them.
        super().__init__(
            steps=steps,
            enforce_zero_terminal_snr=enforce_zero_terminal_snr,
        )

    def sample_betas(self) -> torch.Tensor:
        """Return ``self.steps`` betas linearly spaced from ``self.beta_start`` to ``self.beta_end`` (both inclusive)."""
        return torch.linspace(self.beta_start, self.beta_end, self.steps)

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return :math:`\\bar{\\alpha}` as the cumulative product of :math:`1 - \\beta`."""
        alphas = 1 - self.betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        return alpha_bar
182 183
class CosineBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        offset: float = 0.008,
        steps: int = 1000,
        max_beta: Optional[float] = 0.999,
        enforce_zero_terminal_snr: bool = False,
    ):
        """A Cosine Beta scheduler.

        A Cosine Beta Scheduler (`"Improved Denoising Diffusion Probabilistic
        Models" <https://arxiv.org/abs/2102.09672>`_) based on the following
        formulas:

        .. math::
            :nowrap:

            \\begin{equation}
            \\left\\{ \\begin{aligned}
            \\bar{\\alpha}(t) &= \\frac{f(t)}{f(0)} \\\\
            \\beta_t &= \\min\\left(1 - \\frac{\\bar{\\alpha}(t)}{\\bar{\\alpha}(t - 1)}, \\beta_{max}\\right)
            \\end{aligned} \\right.
            \\end{equation}

        where

        .. math::

            f(t) = \\cos\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right)^2

        where

        .. math::
            :nowrap:

            \\begin{equation}
            \\left\\{ \\begin{aligned}

            s & \\text{ is the offset} \\\\
            T & \\text{ is the number of steps}

            \\end{aligned} \\right.
            \\end{equation}

        There is one :math:`\\beta_t` for each :math:`t = 1, \\dots, T`, so
        ``betas`` has ``steps`` entries. The stored ``alpha_bars`` are the
        cumulative products of :math:`1 - \\beta_t`, which keeps them consistent
        with the (possibly clipped) betas and gives them the same length.

        Args:
            offset: The offset :math:`s` defined above.
            steps: The number of steps for the beta scheduler.
            max_beta: The maximum beta values. Higher values will be clipped.
                ``None`` disables clipping.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
                Defaults to ``False``.
        """
        #: The offset :math:`s` defined above.
        self.offset: float = offset
        #: The maximum beta values. Higher values will be clipped.
        self.max_beta: Optional[float] = max_beta
        # The base initializer sets self.steps before calling sample_betas().
        super().__init__(
            steps=steps,
            enforce_zero_terminal_snr=enforce_zero_terminal_snr,
        )

    def f(self, t: torch.Tensor) -> torch.Tensor:
        """A helper function to compute the :math:`\\bar{\\alpha}_t`.

        Args:
            t: The timestep to compute.

        Returns:

            .. math::

                f(t) = \\cos\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right)^2

        """
        return (
            torch.cos(
                (((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
            )
            ** 2
        )

    def sample_betas(self) -> torch.Tensor:
        """Compute the ``steps`` cosine-schedule betas.

        Returns:
            A tensor with one :math:`\\beta_t` per step, clipped to ``max_beta``
            unless it is ``None``.
        """
        # Evaluate the schedule at the integer timesteps 0, 1, ..., T (T + 1
        # points) so that the consecutive ratios yield exactly T betas.
        t = torch.arange(self.steps + 1, dtype=torch.float32)
        alpha_bar = self.f(t) / self.f(torch.tensor([0.0], dtype=torch.float32))
        betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
        if self.max_beta is not None:
            # f(T) = cos(pi/2)^2 is ~0, so the last ratio pushes beta toward 1.
            betas = torch.clip(betas, max=self.max_beta)
        return betas

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return :math:`\\bar{\\alpha}` as the cumulative product of :math:`1 - \\beta`."""
        return torch.cumprod(1 - self.betas, dim=0)