1import abc
  2import logging
  3from typing import Optional
  4
  5import torch
  6
  7
class BaseBetaScheduler:
  def __init__(
    self,
    steps: int,
    enforce_zero_terminal_snr: bool = False,
    initialize: bool = True,
  ):
    """Initializes a beta scheduler.

    BaseBetaScheduler is an abstract base class for different beta scheduler
    implementations. It defines the interface that all beta schedulers should
    adhere to.

    Args:
      steps: The number of steps for the beta.
      enforce_zero_terminal_snr: Whether to enforce zero terminal SNR inline
        with `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
        <https://arxiv.org/abs/2305.08891>`_. Only applied when
        ``initialize`` is ``True``. Defaults to ``False``.
      initialize: Whether to initialize the beta scheduler. If this is
        set to ``False``, ``self.betas`` and ``self.alpha_bars`` start as
        ``None`` and you will need to set them manually. Otherwise, they are
        initialized using your ``sample_betas`` and ``compute_alpha_bar``
        methods.

    Warnings:
      Do not instantiate this class directly. Instead, build your own Beta
      scheduler by inheriting from BaseBetaScheduler.
      (see :class:`~.LinearBetaScheduler`)
    """
    self.steps: int = steps
    """The number of steps for the beta scheduler."""
    self.betas: Optional[torch.Tensor] = None
    """The :math:`\\beta` computed according to :meth:`~.BaseBetaScheduler.sample_betas`."""
    self.alpha_bars: Optional[torch.Tensor] = None
    """The :math:`\\bar{\\alpha}` computed according to :meth:`~.BaseBetaScheduler.compute_alpha_bar`."""

    if initialize:
      self._initialize()
      if enforce_zero_terminal_snr:
        self.enforce_zero_terminal_snr()

  def _initialize(self):
    """Populate ``self.betas`` and then ``self.alpha_bars``."""
    # Order matters: compute_alpha_bar implementations may read self.betas.
    self.betas = self.sample_betas()
    self.alpha_bars = self.compute_alpha_bar()

  def enforce_zero_terminal_snr(self):
    """Enforce zero terminal SNR by adjusting :math:`\\beta` and :math:`\\bar{\\alpha}`.

    This method enforces zero terminal SNR according to
    `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
    <https://arxiv.org/abs/2305.08891>`_. It shifts
    :math:`\\sqrt{\\bar{\\alpha}}` (derived from ``self.betas``) so the last
    timestep is exactly zero. It then rescales it so the first timestep keeps
    its original value, and converts the result back to betas.

    ``self.betas`` and ``self.alpha_bars`` are only replaced when the new
    schedule has the same length as the current ``self.alpha_bars``.
    Otherwise a warning is logged and both are left unchanged. If the first
    and last :math:`\\sqrt{\\bar{\\alpha}}` are equal (e.g. a single step),
    the rescale is undefined. In that case a warning is logged and the
    schedule is left unchanged.
    """
    alpha_bar_length = len(self.alpha_bars)

    # Convert betas to sqrt(alpha_bar).
    alphas_bar_sqrt = (1 - self.betas).cumprod(0).sqrt()

    # Store old endpoint values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
    if bool(alphas_bar_sqrt_0 == alphas_bar_sqrt_t):
      # The rescale below would divide by zero and fill the schedule with NaN.
      logging.warning(
        "Cannot enforce zero terminal SNR: the first and last alpha_bar are "
        "equal. Leaving the beta scheduler unchanged."
      )
      return

    # Shift so the last timestep is zero, then scale so the first timestep is
    # back to its old value.
    alphas_bar_sqrt = (alphas_bar_sqrt - alphas_bar_sqrt_t) * (
      alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)
    )

    # Convert sqrt(alpha_bar) back to betas.
    alphas_bar = alphas_bar_sqrt**2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    betas = 1 - alphas
    if len(alphas) == alpha_bar_length:
      self.betas = betas
      self.alpha_bars = alphas_bar
    else:
      logging.warning(
        "Got different alpha_bar length after enforcing zero SNR. Please check your beta scheduler"
      )

  # NOTE: the class does not use abc.ABCMeta, so these decorators are not
  # enforced at instantiation time. from_tensors relies on being able to
  # instantiate the base class directly.
  @abc.abstractmethod
  def sample_betas(self) -> torch.Tensor:
    """Compute :math:`\\beta` for noise scheduling.

    Returns:
      A torch tensor of the :math:`\\beta` values.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def compute_alpha_bar(self) -> torch.Tensor:
    """Compute :math:`\\bar{\\alpha}` for noise scheduling.

    Returns:
      A torch tensor of the :math:`\\bar{\\alpha}` values.
    """
    raise NotImplementedError()

  def to(self, device: str = "cpu"):
    """Moves the beta scheduler to the given device.

    Args:
      device: The device to which the tensors should be moved.
        Defaults to ``"cpu"``.

    Returns:
      ``self``, so calls can be chained. Tensors that are still ``None``
      (uninitialized scheduler) are left as ``None``.
    """
    if self.betas is not None:
      self.betas = self.betas.to(device)
    if self.alpha_bars is not None:
      self.alpha_bars = self.alpha_bars.to(device)
    return self

  @classmethod
  def from_tensors(
    cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
  ):
    """Instantiate a beta scheduler from tensors.

    Instantiate a beta scheduler from tensors. This is particularly useful for
    loading checkpoints. The instance is created with
    ``cls(steps, initialize=False)``, so ``cls`` must accept that call.
    ``BaseBetaScheduler`` itself does.

    Args:
      steps: The number of steps for the beta scheduler.
      betas: The pre-computed beta values for the noise scheduler.
      alpha_bars: The pre-computed alpha bar values for the noise scheduler.

    Returns:
      An instance of ``cls`` whose ``steps``, ``betas`` and ``alpha_bars`` are
      the given values. The tensors are stored as-is, not copied.
    """
    # NOTE(review): the LinearBetaScheduler / CosineBetaScheduler constructors
    # do not take an `initialize` argument, so calling this on them raises
    # TypeError.
    generic_beta_scheduler = cls(steps, initialize=False)
    generic_beta_scheduler.steps = steps
    generic_beta_scheduler.betas = betas
    generic_beta_scheduler.alpha_bars = alpha_bars
    return generic_beta_scheduler
 
142
143
class LinearBetaScheduler(BaseBetaScheduler):
  def __init__(
    self,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    steps: int = 1000,
    enforce_zero_terminal_snr: bool = True,
  ):
    """A Linear Beta scheduler.

    A simple linear beta scheduler with betas linearly spaced between
    ``beta_start`` and ``beta_end``.

    Args:
      beta_start: The starting value of the betas.
      beta_end: The end value of the betas.
      steps: The number of steps for the beta scheduler. This is also the number
        of betas.
      enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
        Defaults to ``True`` for this scheduler.
    """
    # These must be set before super().__init__(), which (with its default
    # initialize=True) samples the betas immediately.
    self.beta_start: float = beta_start
    """The starting value of the betas."""
    self.beta_end: float = beta_end
    """The end value of the betas."""
    super().__init__(
      steps=steps,
      enforce_zero_terminal_snr=enforce_zero_terminal_snr,
    )

  def sample_betas(self) -> torch.Tensor:
    """Return ``self.steps`` betas linearly spaced between ``self.beta_start`` and ``self.beta_end``."""
    return torch.linspace(self.beta_start, self.beta_end, self.steps)

  def compute_alpha_bar(self) -> torch.Tensor:
    """Return :math:`\\bar{\\alpha}` computed from the beta values.

    :math:`\\bar{\\alpha}_t` is the cumulative product of
    :math:`1 - \\beta_s` over :math:`s \\le t`. Requires ``self.betas`` to be
    populated first.
    """
    alphas = 1 - self.betas
    return torch.cumprod(alphas, dim=0)
 
182
183
class CosineBetaScheduler(BaseBetaScheduler):
  def __init__(
    self,
    offset: float = 0.008,
    steps: int = 1000,
    max_beta: Optional[float] = 0.999,
    enforce_zero_terminal_snr: bool = False,
  ):
    """A Cosine Beta scheduler.

    A Cosine Beta Scheduler based on the following formulas:

    .. math::
      :nowrap:

      \\begin{equation}
      \\left\\{ \\begin{aligned}
        \\bar{\\alpha}_t &= \\frac{f(t)}{f(0)} \\\\
        \\beta_t &= 1 - \\frac{\\bar{\\alpha}_t}{\\bar{\\alpha}_{t-1}}
      \\end{aligned} \\right.
      \\end{equation}

    where

    .. math::

      f(t) = \\cos\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right)^2

    with :math:`s` the offset and :math:`T` the number of steps.

    The betas are evaluated at the integer timesteps :math:`t = 1, \\dots, T`,
    so exactly ``steps`` betas are produced, and they are clipped to
    ``max_beta``. ``alpha_bars`` is then the cumulative product of
    :math:`1 - \\beta_t`. It therefore has the same length as the betas and
    stays consistent with them after clipping.

    Args:
      offset: The offset :math:`s` defined above.
      steps: The number of steps for the beta scheduler. This is also the
        number of betas.
      max_beta: The maximum beta value. Higher values will be clipped.
        ``None`` disables clipping.
      enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
        Defaults to ``False``.
    """
    self.offset: float = offset
    """The offset :math:`s` defined above."""
    self.max_beta: Optional[float] = max_beta
    """The maximum beta values. Higher values will be clipped."""
    self.steps: int = steps
    """The number of steps for the beta scheduler."""
    # Precomputed here because sample_betas / compute_alpha_bar just return
    # these cached tensors. Betas come first since alpha_bars derive from them.
    self._betas = self._compute_betas()
    self._alpha_bars = self._compute_alpha_bar()

    super().__init__(
      steps=steps,
      enforce_zero_terminal_snr=enforce_zero_terminal_snr,
    )

  def f(self, t: torch.Tensor) -> torch.Tensor:
    """A helper function to compute the :math:`\\bar{\\alpha}_t`.

    Args:
      t: The timestep to compute.

    Returns:

      .. math::

        f(t) = \\cos\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right)^2

    """
    return (
      torch.cos(
        (((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
      )
      ** 2
    )

  def _compute_betas(self) -> torch.Tensor:
    """Return ``self.steps`` betas from consecutive ratios of f(t)/f(0), clipped to ``max_beta``."""
    # Integer timesteps 0..T (T + 1 points) so that consecutive ratios give
    # exactly T betas.
    t = torch.arange(self.steps + 1, dtype=torch.float32)
    alpha_bar_grid = self.f(t) / self.f(torch.zeros(1, dtype=torch.float32))
    betas = 1 - alpha_bar_grid[1:] / alpha_bar_grid[:-1]
    if self.max_beta is not None:
      betas = torch.clip(betas, max=self.max_beta)
    return betas

  def _compute_alpha_bar(self) -> torch.Tensor:
    """Return the cumulative product of ``1 - self._betas``."""
    # Derived from the (possibly clipped) betas so both stay consistent.
    return torch.cumprod(1 - self._betas, dim=0)

  def sample_betas(self) -> torch.Tensor:
    """Return the precomputed cosine betas (``self.steps`` values)."""
    return self._betas

  def compute_alpha_bar(self) -> torch.Tensor:
    """Return the precomputed :math:`\\bar{\\alpha}` (same length as the betas)."""
    return self._alpha_bars