import abc
import logging
from typing import Optional

import torch


class BaseBetaScheduler:
    def __init__(self, steps: int, enforce_zero_terminal_snr: bool = False):
        """Initializes a beta scheduler.

        BaseBetaScheduler is an abstract base class for different beta scheduler
        implementations. It defines the interface that all beta schedulers should
        adhere to.

        Args:
            steps: The number of steps for the beta scheduler.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR in
                line with `"Common Diffusion Noise Schedules and Sample Steps
                are Flawed" <https://arxiv.org/abs/2305.08891>`_.
                Defaults to ``False``.

        Warnings:
            Do not instantiate this class directly. Instead, build your own beta
            scheduler by inheriting from BaseBetaScheduler
            (see :class:`~.LinearBetaScheduler`).
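
        Example:
            A minimal custom scheduler (an illustrative sketch;
            ``QuadraticBetaScheduler`` is not part of the library)::

                class QuadraticBetaScheduler(BaseBetaScheduler):
                    def sample_betas(self) -> torch.Tensor:
                        return torch.linspace(0.0001**0.5, 0.02**0.5, self.steps) ** 2

                    def compute_alpha_bar(self) -> torch.Tensor:
                        return torch.cumprod(1 - self.betas, dim=0)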
27 """
28 super().__init__()
[docs]
29 self.steps: int = steps
30 """The number of steps for the beta scheduler."""
[docs]
31 self.betas = self.sample_betas()
32 """The :math:`\\beta` computed according to :meth:`~.BaseBetaScheduler.sample_betas`."""
[docs]
33 self.alpha_bars = self.compute_alpha_bar()
34 """The :math:`\\bar{\\alpha}` computed according to :meth:`~.BaseBetaScheduler.compute_alpha_bar`."""
35
36 if enforce_zero_terminal_snr:
37 self.enforce_zero_terminal_snr()
38
    def enforce_zero_terminal_snr(self):
        """Enforce zero terminal SNR by adjusting :math:`\\beta` and :math:`\\bar{\\alpha}`.

        This method enforces zero terminal SNR according to
        `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
        <https://arxiv.org/abs/2305.08891>`_.
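
        Example:
            An illustrative check: after enforcement the terminal
            :math:`\\bar{\\alpha}` is exactly zero (``LinearBetaScheduler`` is
            defined below in this module)::

                scheduler = LinearBetaScheduler(enforce_zero_terminal_snr=True)
                assert float(scheduler.alpha_bars[-1]) == 0.0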
45 """
46 alpha_bar_length = len(self.alpha_bars)
47
48 # Convert betas to alphas_bar_sqrt
49 alphas = 1 - self.betas
50 alphas_bar = alphas.cumprod(0)
51 alphas_bar_sqrt = alphas_bar.sqrt()
52
53 # Store old values.
54 alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
55 alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
56 # Shift so last timestep is zero.
57 alphas_bar_sqrt -= alphas_bar_sqrt_t
58 # Scale so first timestep is back to old value.
59 alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
60 alphas_bar_sqrt_0 - alphas_bar_sqrt_t
61 )
62
63 # Convert alphas_bar_sqrt to betas
64 alphas_bar = alphas_bar_sqrt**2
65 alphas = alphas_bar[1:] / alphas_bar[:-1]
66 alphas = torch.cat([alphas_bar[0:1], alphas])
67 betas = 1 - alphas
68 if len(alphas) == alpha_bar_length:
69 self.betas = betas
70 self.alpha_bars = alphas_bar
71 else:
72 logging.warning(
73 "Got different alpha_bar length after enforcing zero SNR. Please check your beta scheduler"
74 )

    @abc.abstractmethod
    def sample_betas(self) -> torch.Tensor:
        """Compute :math:`\\beta` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\beta` values.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def compute_alpha_bar(self) -> torch.Tensor:
        """Compute :math:`\\bar{\\alpha}` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\bar{\\alpha}` values.
        """
        raise NotImplementedError()

    def to(self, device: str):
        """Moves the beta scheduler to the given device.

        Args:
            device: The device to move the scheduler's tensors to, e.g. "cpu"
                or "cuda".

        Returns:
            The beta scheduler itself, allowing calls to be chained.
        """
        self.betas = self.betas.to(device)
        self.alpha_bars = self.alpha_bars.to(device)
        return self

    @classmethod
    def from_tensors(
        cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
    ):
        """Instantiate a beta scheduler from tensors.

        This is particularly useful for restoring a scheduler from a checkpoint.

        Args:
            steps: The number of steps for the beta scheduler.
            betas: The pre-computed beta values for the noise scheduler.
            alpha_bars: The pre-computed alpha bar values for the noise scheduler.

        Returns:
            A beta scheduler populated with the given ``steps``, ``betas``, and
            ``alpha_bars``.
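
        Example:
            Restoring a scheduler from checkpointed tensors (an illustrative
            sketch; the checkpoint layout is an assumption)::

                state = torch.load("scheduler.pt")
                scheduler = LinearBetaScheduler.from_tensors(
                    steps=state["steps"],
                    betas=state["betas"],
                    alpha_bars=state["alpha_bars"],
                )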
122 """
123 generic_beta_scheduler = cls(0)
124 generic_beta_scheduler.steps = steps
125 generic_beta_scheduler.betas = betas
126 generic_beta_scheduler.alpha_bars = alpha_bars
127 return generic_beta_scheduler
128

class LinearBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        steps: int = 1000,
        enforce_zero_terminal_snr: bool = True,
    ):
        """A linear beta scheduler.

        A simple linear beta scheduler with betas linearly spaced between
        ``beta_start`` and ``beta_end``.

        Args:
            beta_start: The starting value of the betas.
            beta_end: The end value of the betas.
            steps: The number of steps for the beta scheduler. This is also the
                number of betas.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
                Defaults to ``True``.
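
        Example:
            Typical usage (an illustrative sketch)::

                scheduler = LinearBetaScheduler(steps=1000)
                beta_t = scheduler.betas[500]            # noise increment at t = 500
                alpha_bar_t = scheduler.alpha_bars[500]  # cumulative product at t = 500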
149 """
[docs]
150 self.beta_start: int = beta_start
151 """The starting value of the betas."""
[docs]
152 self.beta_end: int = beta_end
153 """The end value of the betas."""
154 super().__init__(
155 steps=steps,
156 enforce_zero_terminal_snr=enforce_zero_terminal_snr,
157 )

    def sample_betas(self) -> torch.Tensor:
        """Return linearly spaced betas between ``self.beta_start`` and ``self.beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.steps)

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return :math:`\\bar{\\alpha}` computed from the beta values."""
        alphas = 1 - self.betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        return alpha_bar


class CosineBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        offset: float = 0.008,
        steps: int = 1000,
        max_beta: Optional[float] = 0.999,
    ):
        """A cosine beta scheduler.

        A cosine beta scheduler based on the following formulas:

        .. math::
            :nowrap:

            \\begin{equation}
            \\left\\{ \\begin{aligned}
            \\bar{\\alpha}_t &= \\frac{f(t)}{f(0)} \\\\
            \\beta_t &= 1 - \\frac{\\bar{\\alpha}_t}{\\bar{\\alpha}_{t-1}}
            \\end{aligned} \\right.
            \\end{equation}

        where

        .. math::

            f(t) = \\cos\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right)^2

        and where :math:`s` is the offset and :math:`T` is the number of steps.

        Args:
            offset: The offset :math:`s` defined above.
            steps: The number of steps for the beta scheduler.
            max_beta: The maximum beta value; betas above it are clipped.
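
        Example:
            Typical usage (an illustrative sketch)::

                scheduler = CosineBetaScheduler(steps=1000)
                betas = scheduler.betas            # cosine-schedule betas, clipped at max_beta
                alpha_bars = scheduler.alpha_bars  # matching cumulative products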
215 """
[docs]
216 self.offset: float = offset
217 """The offset :math:`s` defined above."""
[docs]
218 self.max_beta: Optional[float] = max_beta
219 """The maximum beta values. Higher values will be clipped."""
[docs]
220 self.steps: int = steps
221 """The number of steps for the beta scheduler."""
222 self._alpha_bars = self._compute_alpha_bar()
223 self._betas = self._compute_betas()
224
225 super().__init__(
226 steps=steps,
227 )
228
    def f(self, t: torch.Tensor) -> torch.Tensor:
        """A helper function used to compute :math:`\\bar{\\alpha}_t`.

        Args:
            t: The timestep(s) at which to evaluate the function.

        Returns:

            .. math::

                f(t) = \\cos\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right)^2

        """
        return (
            torch.cos(
                (((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
            )
            ** 2
        )

    def _compute_betas(self):
        # One beta per pair of consecutive alpha_bar values.
        betas = 1 - self._alpha_bars[1:] / self._alpha_bars[:-1]
        if self.max_beta:
            betas = torch.clip(betas, max=self.max_beta)
        return betas

    def _compute_alpha_bar(self):
        # Evaluate f at steps + 1 points (t = 0, ..., T) so that the ratio in
        # _compute_betas yields exactly `steps` betas.
        t = torch.linspace(0, self.steps, self.steps + 1, dtype=torch.float32)
        return self.f(t) / self.f(torch.tensor([0], dtype=torch.float32))

    def sample_betas(self) -> torch.Tensor:
        """Return the pre-computed cosine-schedule betas."""
        return self._betas

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return the pre-computed :math:`\\bar{\\alpha}` values."""
        # Drop the t = 0 entry (which is 1 by construction) so that the length
        # matches the number of betas.
        return self._alpha_bars[1:]