1import abc
2import logging
3from typing import Optional
4
5import torch
6
7
[docs]
class BaseBetaScheduler:
    def __init__(
        self,
        steps: int,
        enforce_zero_terminal_snr: bool = False,
        initialize: bool = True,
    ):
        """Initializes a beta scheduler.

        BaseBetaScheduler is an abstract base class for different beta scheduler
        implementations. It defines the interface that all beta schedulers should
        adhere to.

        Args:
            steps: The number of steps for the beta.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR inline
                with `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
                <https://arxiv.org/abs/2305.08891>`_.
                Defaults to ``False``. Only applied when ``initialize`` is ``True``,
                because the adjustment needs ``self.betas`` and ``self.alpha_bars``.
            initialize: Whether to initialize the beta scheduler. If this is
                set to ``False``, you will need to manually set ``self.betas`` and
                ``self.alpha_bars`` (and call :meth:`enforce_zero_terminal_snr`
                yourself if required). Otherwise, they are initialized using your
                ``sample_betas`` and ``compute_alpha_bar`` methods.

        Warnings:
            Do not instantiate this class directly. Instead, build your own Beta
            scheduler by inheriting from BaseBetaScheduler.
            (see :class:`~.LinearBetaScheduler`)
        """
        self.steps: int = steps
        """The number of steps for the beta scheduler."""
        self.betas: Optional[torch.Tensor] = None
        """The :math:`\\beta` computed according to :meth:`~.BaseBetaScheduler.sample_betas`."""
        self.alpha_bars: Optional[torch.Tensor] = None
        """The :math:`\\bar{\\alpha}` computed according to :meth:`~.BaseBetaScheduler.compute_alpha_bar`."""

        if initialize:
            self._initialize()
            if enforce_zero_terminal_snr:
                self.enforce_zero_terminal_snr()
        elif enforce_zero_terminal_snr:
            # Nothing to rescale yet: betas/alpha_bars are still None.
            logging.warning(
                "enforce_zero_terminal_snr was requested but initialize is False; "
                "call enforce_zero_terminal_snr() after setting betas and alpha_bars."
            )

    def _initialize(self):
        """Populate ``betas`` first, then ``alpha_bars`` (which may read ``betas``)."""
        self.betas = self.sample_betas()
        self.alpha_bars = self.compute_alpha_bar()

    def enforce_zero_terminal_snr(self):
        """Enforce terminal SNR by adjusting :math:`\\beta` and :math:`\\bar{\\alpha}`.

        This method enforces zero terminal SNR according to
        `"Common Diffusion Noise Schedules and Sample Steps are Flawed"
        <https://arxiv.org/abs/2305.08891>`_.

        :math:`\\sqrt{\\bar{\\alpha}}` is recomputed from ``self.betas``, shifted so
        its last entry is zero and rescaled so its first entry keeps its value.
        ``self.betas`` and ``self.alpha_bars`` are replaced only when the result
        has as many entries as the current ``self.alpha_bars``; otherwise a
        warning is logged and both are left untouched. A schedule whose first
        and last :math:`\\sqrt{\\bar{\\alpha}}` coincide (e.g. a single step)
        cannot be rescaled and is also left untouched with a warning.
        """
        alpha_bar_length = len(self.alpha_bars)

        # Convert betas to alphas_bar_sqrt
        alphas = 1 - self.betas
        alphas_bar_sqrt = alphas.cumprod(0).sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
        span = alphas_bar_sqrt_0 - alphas_bar_sqrt_t
        if bool(torch.any(span == 0)):
            # The rescale factor would be x / 0 and turn the schedule into NaN.
            logging.warning(
                "Cannot enforce zero terminal SNR: first and last alpha_bar are equal."
            )
            return

        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_t
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / span

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2
        alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
        betas = 1 - alphas
        if len(alphas) == alpha_bar_length:
            self.betas = betas
            self.alpha_bars = alphas_bar
        else:
            logging.warning(
                "Got different alpha_bar length after enforcing zero SNR. Please check your beta scheduler"
            )

    # NOTE: the class does not use ``abc.ABCMeta``, so ``abstractmethod`` is not
    # enforced at instantiation; the methods raise when actually called.
    @abc.abstractmethod
    def sample_betas(self) -> torch.Tensor:
        """Compute :math:`\\beta` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\beta` values.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def compute_alpha_bar(self) -> torch.Tensor:
        """Compute :math:`\\bar{\\alpha}` for noise scheduling.

        Returns:
            A torch tensor of the :math:`\\bar{\\alpha}` values.
        """
        raise NotImplementedError()

    def to(self, device: str):
        """Moves the beta scheduler to the given device.

        Args:
            device: The device to which ``betas`` and ``alpha_bars`` are moved.
                It is passed unchanged to ``torch.Tensor.to``.

        Returns:
            The scheduler itself, so calls can be chained.
        """
        self.betas = self.betas.to(device)
        self.alpha_bars = self.alpha_bars.to(device)
        return self

    @classmethod
    def from_tensors(
        cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
    ):
        """Instantiate a beta scheduler from tensors.

        Instantiate a beta scheduler from tensors. This is particularly useful for
        loading checkpoints.

        The instance is built with ``cls(steps, initialize=False)``, so ``cls``
        must accept that call signature.

        Args:
            steps: The number of steps for the beta scheduler.
            betas: The pre-computed beta values for the noise scheduler.
            alpha_bars: The pre-computed alpha bar values for the noise scheduler.

        Returns:
            An instance of ``cls`` whose ``steps``, ``betas`` and ``alpha_bars``
            are the given values (the tensors are stored as-is, not copied).
        """
        generic_beta_scheduler = cls(steps, initialize=False)
        generic_beta_scheduler.steps = steps
        generic_beta_scheduler.betas = betas
        generic_beta_scheduler.alpha_bars = alpha_bars
        return generic_beta_scheduler
141
142
[docs]
class LinearBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        steps: int = 1000,
        enforce_zero_terminal_snr: bool = True,
    ):
        """A Linear Beta scheduler.

        A simple linear beta scheduler with betas linearly spaced between
        ``beta_start`` and ``beta_end``.

        Args:
            beta_start: The starting value of the betas.
            beta_end: The end value of the betas.
            steps: The number of steps for the beta scheduler. This is also the number
                of betas.
            enforce_zero_terminal_snr: Whether to enforce zero terminal SNR.
                Defaults to ``True``.
        """
        # These must be set before the base constructor runs, because it calls
        # ``sample_betas`` which reads them.
        self.beta_start: float = beta_start
        """The starting value of the betas."""
        self.beta_end: float = beta_end
        """The end value of the betas."""
        super().__init__(
            steps=steps,
            enforce_zero_terminal_snr=enforce_zero_terminal_snr,
        )

    def sample_betas(self) -> torch.Tensor:
        """Return linearly spaced betas between ``self.beta_start`` and ``self.beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.steps)

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return :math:`\\bar{\\alpha}` computed from the beta values.

        Returns:
            The cumulative product of ``1 - self.betas`` along dimension 0.
        """
        alphas = 1 - self.betas
        return torch.cumprod(alphas, dim=0)
181
182
[docs]
class CosineBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        offset: float = 0.008,
        steps: int = 1000,
        max_beta: Optional[float] = 0.999,
    ):
        """A Cosine Beta scheduler.

        A Cosine Beta Scheduler based on the following formulas:

        .. math::
            :nowrap:

            \\begin{equation}
            \\left\\{ \\begin{aligned}
            \\bar{\\alpha}_t &= \\frac{f(t)}{f(0)} \\\\
            \\beta_t &= 1 - \\frac{\\bar{\\alpha}_t}{\\bar{\\alpha}_{t-1}}
            \\end{aligned} \\right.
            \\end{equation}

        where

        .. math::

            f(t) = \\cos(\\frac{t/T + s}{1 + s} * \\frac{\\pi}{2})^2

        where

        .. math::
            :nowrap:

            \\begin{equation}
            \\left\\{ \\begin{aligned}

            s & \\text{ is the offset} \\\\
            T & \\text{ is the number of steps}

            \\end{aligned} \\right.
            \\end{equation}

        For the first entry the preceding :math:`\\bar{\\alpha}` is taken to be 1,
        so ``betas`` and ``alpha_bars`` both have ``steps`` entries.

        Args:
            offset: The offset :math:`s` defined above.
            steps: The number of steps for the beta scheduler.
            max_beta: The maximum beta values. Higher values will be clipped.
                ``None`` disables clipping.
        """
        self.offset: float = offset
        """The offset :math:`s` defined above."""
        self.max_beta: Optional[float] = max_beta
        """The maximum beta values. Higher values will be clipped."""
        # ``steps`` is needed by ``f`` before the base constructor assigns it.
        self.steps: int = steps
        """The number of steps for the beta scheduler."""
        # Pre-compute both tensors; the base constructor picks them up through
        # ``sample_betas`` / ``compute_alpha_bar``.
        self._alpha_bars = self._compute_alpha_bar()
        self._betas = self._compute_betas()

        super().__init__(
            steps=steps,
        )

    def f(self, t: torch.Tensor) -> torch.Tensor:
        """A helper function to compute the :math:`\\bar{\\alpha}_t`.

        Args:
            t: The timestep to compute.

        Returns:

            .. math::

                f(t) = \\cos(\\frac{t/T + s}{1 + s} * \\frac{\\pi}{2})^2

        """
        phase = ((t / self.steps) + self.offset) / (1 + self.offset)
        return torch.cos(phase * (torch.pi / 2)) ** 2

    def _compute_betas(self) -> torch.Tensor:
        """Derive one beta per alpha bar as ``1 - alpha_bar[i] / alpha_bar[i - 1]``."""
        # Prepend 1 as the "alpha bar before the first step" so that betas has
        # the same length as alpha_bars and betas[i] pairs with alpha_bars[i].
        previous = torch.cat(
            [torch.ones_like(self._alpha_bars[:1]), self._alpha_bars[:-1]]
        )
        betas = 1 - self._alpha_bars / previous
        # Compare against None so an explicit 0.0 bound is still honoured.
        if self.max_beta is not None:
            betas = torch.clip(betas, max=self.max_beta)
        return betas

    def _compute_alpha_bar(self) -> torch.Tensor:
        """Evaluate ``f(t) / f(0)`` on ``steps`` points evenly spaced over ``[0, steps]``."""
        t = torch.linspace(0, self.steps, self.steps, dtype=torch.float32)
        return self.f(t) / self.f(torch.tensor([0], dtype=torch.float32))

    def sample_betas(self) -> torch.Tensor:
        """Return the betas pre-computed in the constructor."""
        return self._betas

    def compute_alpha_bar(self) -> torch.Tensor:
        """Return the alpha bars pre-computed in the constructor."""
        return self._alpha_bars