Source code for diffusion_models.gaussian_diffusion.beta_schedulers

  1import abc
  2import logging
  3from typing import Optional
  4
  5import torch
  6
  7
# NOTE: the ``abc.abstractmethod`` decorators below are not enforced at
# instantiation time because this class does not derive from ``abc.ABC``;
# ``from_tensors`` relies on being able to instantiate the class directly.
class BaseBetaScheduler:
    def __init__(
        self,
        steps: int,
        enforce_zero_terminal_snr: bool = False,
        initialize: bool = True,
    ):
        """Initializes a beta scheduler.

        BaseBetaScheduler is an abstract base class for different beta scheduler
        implementations. It defines the interface that all beta schedulers should
        adhere to.

        Args:
          steps: The number of steps for the beta.
          enforce_zero_terminal_snr: Whether to enforce zero terminal SNR inline
            with `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
            <https://arxiv.org/abs/2305.08891>`_.\n
            Defaults to ``False``. Only applied when ``initialize`` is ``True``,
            because the adjustment needs ``self.betas`` and ``self.alpha_bars``.
          initialize: Whether to initialize the beta scheduler. If this is
            set to ``False``, you will need to manually set ``self.betas`` and
            ``self.alpha_bars``. Otherwise, they are initialized using your
            ``sample_betas`` and ``compute_alpha_bar`` methods.

        Warnings:
          Do not instantiate this class directly. Instead, build your own Beta
          scheduler by inheriting from BaseBetaScheduler.
          (see :class:`~.LinearBetaScheduler`)
        """
        self.steps: int = steps
        """The number of steps for the beta scheduler."""
        self.betas: Optional[torch.Tensor] = None
        """The :math:`\\beta` computed according to :meth:`~.BaseBetaScheduler.sample_betas`."""
        self.alpha_bars: Optional[torch.Tensor] = None
        """The :math:`\\bar{\\alpha}` computed according to :meth:`~.BaseBetaScheduler.compute_alpha_bar`."""

        if initialize:
            self._initialize()
            # The rescaling reads the tensors created just above, so it is
            # skipped when the caller opted out of initialization.
            if enforce_zero_terminal_snr:
                self.enforce_zero_terminal_snr()

    def _initialize(self):
        """Populate ``betas`` first, then ``alpha_bars`` (which may read ``betas``)."""
        self.betas = self.sample_betas()
        self.alpha_bars = self.compute_alpha_bar()

    def enforce_zero_terminal_snr(self):
        """Enforce terminal SNR by adjusting :math:`\\beta` and :math:`\\bar{\\alpha}`.

        This method enforces zero terminal SNR according to
        `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
        <https://arxiv.org/abs/2305.08891>`_.

        The new values are recomputed from ``self.betas``. They are stored only
        when the result has as many entries as the current ``self.alpha_bars``;
        otherwise a warning is logged and the scheduler is left unchanged.
        """
        alpha_bar_length = len(self.alpha_bars)

        # Convert betas to alphas_bar_sqrt
        alphas = 1 - self.betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_t
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
            alphas_bar_sqrt_0 - alphas_bar_sqrt_t
        )

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        # The first alpha equals the first alpha_bar (there is no predecessor).
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        if len(alphas) == alpha_bar_length:
            self.betas = betas
            self.alpha_bars = alphas_bar
        else:
            logging.warning(
                "Got different alpha_bar length after enforcing zero SNR. Please check your beta scheduler"
            )

    @abc.abstractmethod
    def sample_betas(self) -> torch.Tensor:
        """Compute :math:`\\beta` for noise scheduling.

        Returns:
          A torch tensor of the :math:`\\beta` values.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def compute_alpha_bar(self) -> torch.Tensor:
        """Compute :math:`\\bar{\\alpha}` for noise scheduling.

        Returns:
          A torch tensor of the :math:`\\bar{\\alpha}` values.
        """
        raise NotImplementedError()

    def to(self, device: str):
        """Moves the beta scheduler to the given device.

        Tensors that have not been set yet (scheduler created with
        ``initialize=False``) are left as ``None``.

        Args:
          device: The device to which the method should move the object.

        Returns:
          The scheduler itself, so calls can be chained.
        """
        if self.betas is not None:
            self.betas = self.betas.to(device)
        if self.alpha_bars is not None:
            self.alpha_bars = self.alpha_bars.to(device)
        return self

    @classmethod
    def from_tensors(
        cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
    ):
        """Instantiate a beta scheduler from tensors.

        Instantiate a beta scheduler from tensors. This is particularly useful for
        loading checkpoints.

        Note:
          The instance is created with ``cls(steps, initialize=False)``, so
          ``cls.__init__`` must accept that call signature.

        Args:
          steps: The number of steps for the beta scheduler.
          betas: The pre-computed beta values for the noise scheduler.
          alpha_bars: The pre-computed alpha bar values for the noise scheduler.

        Returns:
          A scheduler whose ``steps``, ``betas`` and ``alpha_bars`` are the
          given values (the tensors are stored as passed, not copied).
        """
        generic_beta_scheduler = cls(steps, initialize=False)
        generic_beta_scheduler.steps = steps
        generic_beta_scheduler.betas = betas
        generic_beta_scheduler.alpha_bars = alpha_bars
        return generic_beta_scheduler
141 142
class LinearBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        steps: int = 1000,
        enforce_zero_terminal_snr: bool = True,
    ):
        """A Linear Beta scheduler.

        A simple linear beta scheduler with betas linearly spaced between
        ``beta_start`` and ``beta_end``.

        Args:
          beta_start: The starting value of the betas.
          beta_end: The end value of the betas.
          steps: The number of steps for the beta scheduler. This is also the number
            of betas.
          enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
        """
        # Both bounds must exist before the base constructor runs, because it
        # calls ``sample_betas``, which reads them.
        self.beta_start: float = beta_start
        """The starting value of the betas."""
        self.beta_end: float = beta_end
        """The end value of the betas."""
        super().__init__(
            steps=steps,
            enforce_zero_terminal_snr=enforce_zero_terminal_snr,
        )

    def sample_betas(self) -> torch.Tensor:
        """Return linearly spaced betas between ``self.beta_start`` and ``self.beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.steps)

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return :math:`\\bar{\\alpha}` computed from the beta values.

        Returns:
          The cumulative product of ``1 - self.betas`` along dimension 0.
        """
        alphas = 1 - self.betas
        return torch.cumprod(alphas, dim=0)
181 182
class CosineBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        offset: float = 0.008,
        steps: int = 1000,
        max_beta: Optional[float] = 0.999,
    ):
        """A Cosine Beta scheduler.

        A Cosine Beta Scheduler based on the following formulas, evaluated for
        :math:`t = 1, \\dots, T` (so that :math:`\\bar{\\alpha}_0 = 1`):

        .. math::
          :nowrap:

          \\begin{equation}
          \\left\\{ \\begin{aligned}
          \\bar{\\alpha}_t &= \\frac{f(t)}{f(0)} \\\\
          \\beta_t &= 1 - \\frac{\\bar{\\alpha}_t}{\\bar{\\alpha}_{t-1}}
          \\end{aligned} \\right.
          \\end{equation}

        where

        .. math::

          f(t) = \\cos(\\frac{t/T + s}{1 + s} * \\frac{\\pi}{2})^2

        where

        .. math::
          :nowrap:

          \\begin{equation}
          \\left\\{ \\begin{aligned}

          s & \\text{ is the offset} \\\\
          T & \\text{ is the number of steps}

          \\end{aligned} \\right.
          \\end{equation}

        Both ``betas`` and ``alpha_bars`` contain exactly ``steps`` values.

        Args:
          offset: The offset :math:`s` defined above.
          steps: The number of steps for the beta scheduler.
          max_beta: The maximum beta values. Higher values will be clipped.
            ``None`` disables clipping.
        """
        self.offset: float = offset
        """The offset :math:`s` defined above."""
        self.max_beta: Optional[float] = max_beta
        """The maximum beta values. Higher values will be clipped."""
        # ``steps`` is assigned here (and again by the base constructor)
        # because the pre-computation below already reads it through ``f``.
        self.steps: int = steps
        """The number of steps for the beta scheduler."""
        self._alpha_bars = self._compute_alpha_bar()
        self._betas = self._compute_betas()

        super().__init__(
            steps=steps,
        )

    def f(self, t: torch.Tensor) -> torch.Tensor:
        """A helper function to compute the :math:`\\bar{\\alpha}_t`.

        Args:
          t: The timestep to compute.

        Returns:

          .. math::

            f(t) = \\cos(\\frac{t/T + s}{1 + s} * \\frac{\\pi}{2})^2

        """
        return (
            torch.cos(
                (((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
            )
            ** 2
        )

    def _alpha_bar_grid(self) -> torch.Tensor:
        """Return ``f(t) / f(0)`` on the integer grid ``t = 0, ..., steps``."""
        t = torch.linspace(0, self.steps, self.steps + 1, dtype=torch.float32)
        return self.f(t) / self.f(torch.tensor([0], dtype=torch.float32))

    def _compute_betas(self) -> torch.Tensor:
        """Return the ``steps`` betas, clipped to ``max_beta`` when it is set."""
        grid = self._alpha_bar_grid()
        # Ratio of consecutive grid points; the leading t = 0 entry supplies
        # the predecessor of the first step, giving one beta per step.
        betas = 1 - grid[1:] / grid[:-1]
        if self.max_beta is not None:
            betas = torch.clip(betas, max=self.max_beta)
        return betas

    def _compute_alpha_bar(self) -> torch.Tensor:
        """Return the ``steps`` values of ``f(t) / f(0)`` for ``t = 1, ..., steps``."""
        # Drop the t = 0 entry (always 1) so the result lines up with the betas.
        return self._alpha_bar_grid()[1:]

    def sample_betas(self) -> torch.Tensor:
        """Return the betas pre-computed in ``__init__``."""
        return self._betas

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return the :math:`\\bar{\\alpha}` values pre-computed in ``__init__``."""
        return self._alpha_bars