import abc
import logging
from typing import Optional

import torch


class BaseBetaScheduler:
    def __init__(
        self,
        steps: int,
        enforce_zero_terminal_snr: bool = False,
        initialize: bool = True,
    ):
        """Initializes a beta scheduler.

        BaseBetaScheduler is an abstract base class for different beta scheduler
        implementations. It defines the interface that all beta schedulers should
        adhere to.

        Args:
            steps: The number of steps for the beta scheduler.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR in line
                with `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
                <https://arxiv.org/abs/2305.08891>`_.
                Defaults to ``False``.
            initialize: Whether to initialize the beta scheduler. If this is
                set to ``False``, you will need to manually set ``self.betas`` and
                ``self.alpha_bars``. Otherwise, they are initialized using your
                ``sample_betas`` and ``compute_alpha_bar`` methods.

        Warnings:
            Do not instantiate this class directly. Instead, build your own beta
            scheduler by inheriting from BaseBetaScheduler
            (see :class:`~.LinearBetaScheduler`).
        """
        self.steps: int = steps
        """The number of steps for the beta scheduler."""
        self.betas = None
        """The :math:`\\beta` computed according to :meth:`~.BaseBetaScheduler.sample_betas`."""
        self.alpha_bars = None
        """The :math:`\\bar{\\alpha}` computed according to :meth:`~.BaseBetaScheduler.compute_alpha_bar`."""

        if initialize:
            self._initialize()
            if enforce_zero_terminal_snr:
                self.enforce_zero_terminal_snr()

    def _initialize(self):
        self.betas = self.sample_betas()
        self.alpha_bars = self.compute_alpha_bar()

    def enforce_zero_terminal_snr(self):
        """Enforce zero terminal SNR by adjusting :math:`\\beta` and :math:`\\bar{\\alpha}`.

        This method enforces zero terminal SNR according to
        `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
        <https://arxiv.org/abs/2305.08891>`_.
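
        Example:
            A minimal usage sketch (assuming a concrete subclass such as
            :class:`~.LinearBetaScheduler`, constructed without terminal-SNR
            enforcement so the adjustment is visible)::

                scheduler = LinearBetaScheduler(enforce_zero_terminal_snr=False)
                scheduler.enforce_zero_terminal_snr()
                # The terminal cumulative alpha is now zero.
                assert torch.isclose(scheduler.alpha_bars[-1], torch.tensor(0.0))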
        """
        alpha_bar_length = len(self.alpha_bars)

        # Convert betas to alphas_bar_sqrt
        alphas = 1 - self.betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_t
        # Scale so the first timestep is back to its old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
            alphas_bar_sqrt_0 - alphas_bar_sqrt_t
        )

        # Convert alphas_bar_sqrt back to betas
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        if len(alphas) == alpha_bar_length:
            self.betas = betas
            self.alpha_bars = alphas_bar
        else:
            logging.warning(
                "Got different alpha_bar length after enforcing zero SNR. Please check your beta scheduler"
            )

    @abc.abstractmethod
    def sample_betas(self) -> torch.Tensor:
        """Compute :math:`\\beta` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\beta` values.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def compute_alpha_bar(self) -> torch.Tensor:
        """Compute :math:`\\bar{\\alpha}` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\bar{\\alpha}` values.
        """
        raise NotImplementedError()

    def to(self, device: str):
        """Moves the beta scheduler to the given device.

        Args:
            device: The device to which the betas and alpha bars should be
                moved, e.g. ``"cpu"`` or ``"cuda"``.

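        Example:
            A minimal sketch, assuming a CUDA device is available::

                scheduler = LinearBetaScheduler().to("cuda")
                scheduler.betas.device  # device(type='cuda', index=0)
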
        """
        self.betas = self.betas.to(device)
        self.alpha_bars = self.alpha_bars.to(device)
        return self

    @classmethod
    def from_tensors(
        cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
    ):
        """Instantiate a beta scheduler from tensors.

        This is particularly useful for loading checkpoints.

        Args:
            steps: The number of steps for the beta scheduler.
            betas: The pre-computed beta values for the noise scheduler.
            alpha_bars: The pre-computed alpha bar values for the noise scheduler.

        Returns:
            A beta scheduler populated with the given tensors.
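
        Example:
            A minimal sketch of restoring a scheduler from checkpointed tensors
            (``ckpt`` is a hypothetical dict holding the saved tensors)::

                scheduler = BaseBetaScheduler.from_tensors(
                    steps=len(ckpt["betas"]),
                    betas=ckpt["betas"],
                    alpha_bars=ckpt["alpha_bars"],
                )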
        """
        generic_beta_scheduler = cls(steps, initialize=False)
        generic_beta_scheduler.steps = steps
        generic_beta_scheduler.betas = betas
        generic_beta_scheduler.alpha_bars = alpha_bars
        return generic_beta_scheduler


class LinearBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        steps: int = 1000,
        enforce_zero_terminal_snr: bool = True,
    ):
        """A Linear Beta scheduler.

        A simple linear beta scheduler with betas linearly spaced between
        ``beta_start`` and ``beta_end``.

        Args:
            beta_start: The starting value of the betas.
            beta_end: The end value of the betas.
            steps: The number of steps for the beta scheduler. This is also the
                number of betas.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
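
        Example:
            A minimal usage sketch with the default arguments::

                scheduler = LinearBetaScheduler(steps=1000)
                scheduler.betas.shape       # torch.Size([1000])
                scheduler.alpha_bars.shape  # torch.Size([1000])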
        """
        self.beta_start: float = beta_start
        """The starting value of the betas."""
        self.beta_end: float = beta_end
        """The end value of the betas."""
        super().__init__(
            steps=steps,
            enforce_zero_terminal_snr=enforce_zero_terminal_snr,
        )

    def sample_betas(self) -> torch.Tensor:
        """Return linearly spaced betas between ``self.beta_start`` and ``self.beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.steps)

    def compute_alpha_bar(self):
        """Return :math:`\\bar{\\alpha}` computed from the beta values."""
        alphas = 1 - self.betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        return alpha_bar


class CosineBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        offset: float = 0.008,
        steps: int = 1000,
        max_beta: Optional[float] = 0.999,
    ):
        """A Cosine Beta scheduler.

        A Cosine Beta scheduler based on the following formulas:

        .. math::
            :nowrap:

            \\begin{equation}
            \\left\\{ \\begin{aligned}
            \\bar{\\alpha}_t &= \\frac{f(t)}{f(0)} \\\\
            \\beta_t &= 1 - \\frac{\\bar{\\alpha}_t}{\\bar{\\alpha}_{t-1}}
            \\end{aligned} \\right.
            \\end{equation}

        where

        .. math::

            f(t) = \\cos(\\frac{t/T + s}{1 + s} * \\frac{\\pi}{2})^2

        where

        .. math::
            :nowrap:

            \\begin{equation}
            \\left\\{ \\begin{aligned}
            s & \\text{ is the offset} \\\\
            T & \\text{ is the number of steps}
            \\end{aligned} \\right.
            \\end{equation}

        Args:
            offset: The offset :math:`s` defined above.
            steps: The number of steps for the beta scheduler.
            max_beta: The maximum beta value. Higher values will be clipped to it.
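
        Example:
            A minimal usage sketch with the default arguments::

                scheduler = CosineBetaScheduler(steps=1000)
                scheduler.betas       # cosine-schedule betas, clipped at ``max_beta``
                scheduler.alpha_bars  # the corresponding cumulative alpha values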
        """
        self.offset: float = offset
        """The offset :math:`s` defined above."""
        self.max_beta: Optional[float] = max_beta
        """The maximum beta value. Higher values will be clipped to it."""
        self.steps: int = steps
        """The number of steps for the beta scheduler."""
        self._alpha_bars = self._compute_alpha_bar()
        self._betas = self._compute_betas()

        super().__init__(
            steps=steps,
        )

    def f(self, t: torch.Tensor) -> torch.Tensor:
        """A helper function to compute :math:`\\bar{\\alpha}_t`.

        Args:
            t: The timestep(s) at which to evaluate the function.

        Returns:

            .. math::

                f(t) = \\cos(\\frac{t/T + s}{1 + s} * \\frac{\\pi}{2})^2

        """
        return (
            torch.cos(
                (((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
            )
            ** 2
        )

    def _compute_betas(self):
        betas = 1 - self._alpha_bars[1:] / self._alpha_bars[:-1]
        if self.max_beta is not None:
            betas = torch.clip(betas, max=self.max_beta)
        return betas

    def _compute_alpha_bar(self):
        t = torch.linspace(0, self.steps, self.steps, dtype=torch.float32)
        return self.f(t) / self.f(torch.tensor([0], dtype=torch.float32))

    def sample_betas(self) -> torch.Tensor:
        """Return the pre-computed cosine-schedule betas."""
        return self._betas

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return the pre-computed :math:`\\bar{\\alpha}` values."""
        return self._alpha_bars