import abc
import logging
from typing import Optional

import torch


class BaseBetaScheduler:
    def __init__(self, steps: int, enforce_zero_terminal_snr: bool = False):
        super().__init__()
        self.steps = steps
        self.betas = self.sample_betas()
        self.alpha_bars = self.compute_alpha_bar()

        if enforce_zero_terminal_snr:
            self.enforce_zero_terminal_snr()

    def enforce_zero_terminal_snr(self):
        """Rescale the schedule so the terminal SNR is exactly zero, following
        Lin et al. (2023), "Common Diffusion Noise Schedules and Sample Steps
        are Flawed"."""
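        # Effect of the affine rescaling below: with a_t = sqrt(alpha_bar_t),
        # the map a -> (a - a_T) * a_0 / (a_0 - a_T) sends a_T to 0 and keeps
        # a_0 fixed, so alpha_bar_T becomes exactly 0 (zero terminal SNR)
        # while the first step retains its original signal level.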
        alpha_bar_length = len(self.alpha_bars)

        # Convert betas to alphas_bar_sqrt.
        alphas = 1 - self.betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so the first timestep is back to its old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
            alphas_bar_sqrt_0 - alphas_bar_sqrt_T
        )

        # Convert alphas_bar_sqrt back to betas.
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        if len(alphas) == alpha_bar_length:
            self.betas = betas
            self.alpha_bars = alphas_bar
        else:
            logging.warning(
                "Got a different alpha_bar length after enforcing zero terminal "
                "SNR. Please check your beta scheduler."
            )

    @abc.abstractmethod
    def sample_betas(self):
        """Return the per-timestep betas as a 1-D tensor."""
        pass

    @abc.abstractmethod
    def compute_alpha_bar(self):
        """Return the cumulative products of (1 - beta) as a 1-D tensor."""
        pass

    def to(self, device: str):
        self.betas = self.betas.to(device)
        self.alpha_bars = self.alpha_bars.to(device)
        return self

    @classmethod
    def from_tensors(
        cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
    ):
        # Bypass __init__: subclass constructors do not necessarily take
        # `steps` as their first positional argument, so instantiating via
        # cls(0) could silently build (and discard) an unrelated schedule.
        generic_beta_scheduler = cls.__new__(cls)
        generic_beta_scheduler.steps = steps
        generic_beta_scheduler.betas = betas
        generic_beta_scheduler.alpha_bars = alpha_bars
        return generic_beta_scheduler
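
# Minimal usage sketch (illustrative, not part of the original module):
# `from_tensors` wraps precomputed schedule tensors without invoking a
# subclass constructor.
#
#     betas = torch.linspace(1e-4, 0.02, 10)
#     alpha_bars = torch.cumprod(1 - betas, dim=0)
#     scheduler = BaseBetaScheduler.from_tensors(10, betas, alpha_bars)
#     scheduler.to("cpu")  # or "cuda"; moves both tensors to the device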


class LinearBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        steps: int = 1000,
        enforce_zero_terminal_snr: bool = True,
    ):
        # beta_start/beta_end must be set before super().__init__(), which
        # immediately calls sample_betas().
        self.beta_start = beta_start
        self.beta_end = beta_end
        super().__init__(
            steps=steps,
            enforce_zero_terminal_snr=enforce_zero_terminal_snr,
        )

    def sample_betas(self):
        return torch.linspace(self.beta_start, self.beta_end, self.steps)

    def compute_alpha_bar(self):
        alphas = 1 - self.betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        return alpha_bar
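
# Hedged usage sketch (illustrative, not from the original source): with the
# default enforce_zero_terminal_snr=True, the rescaled schedule reaches zero
# terminal SNR, so the last alpha_bar is 0 and the last beta is 1.
#
#     scheduler = LinearBetaScheduler(steps=1000)
#     assert scheduler.alpha_bars[-1].abs().item() < 1e-8
#     assert abs(scheduler.betas[-1].item() - 1.0) < 1e-6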


class CosineBetaScheduler(BaseBetaScheduler):
    def __init__(
        self,
        offset: float = 0.008,
        steps: int = 1000,
        max_beta: Optional[float] = 0.999,
    ):
        self.offset = offset
        self.max_beta = max_beta
        # steps is needed by _compute_alpha_bar()/f() below, and the cached
        # tensors must exist before super().__init__() calls sample_betas()
        # and compute_alpha_bar().
        self.steps = steps
        self._alpha_bars = self._compute_alpha_bar()
        self._betas = self._compute_betas()

        super().__init__(
            steps=steps,
        )

    def f(self, t: torch.Tensor):
        # Cosine schedule of Nichol & Dhariwal (2021), "Improved Denoising
        # Diffusion Probabilistic Models":
        #   f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2), with s = offset.
        return (
            torch.cos(
                (((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
            )
            ** 2
        )

    def _compute_betas(self):
        betas = 1 - self._alpha_bars[1:] / self._alpha_bars[:-1]
        if self.max_beta:
            # Clip to avoid a singular beta near t = T.
            betas = torch.clip(betas, max=self.max_beta)
        return betas

    def _compute_alpha_bar(self):
        # Evaluate alpha_bar at the steps + 1 integer points t = 0, ..., steps,
        # so that the betas derived in _compute_betas() have exactly `steps`
        # entries; normalizing by f(0) makes alpha_bar(0) = 1.
        t = torch.linspace(0, self.steps, self.steps + 1, dtype=torch.float32)
        return self.f(t) / self.f(torch.tensor([0], dtype=torch.float32))

    def sample_betas(self):
        return self._betas

    def compute_alpha_bar(self):
        # Drop the leading alpha_bar(0) = 1 so that alpha_bars and betas have
        # matching lengths, as in the other schedulers.
        return self._alpha_bars[1:]
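

# Minimal smoke-test sketch (illustrative; assumes only the classes defined
# above). Running the module directly prints schedule shapes and terminal
# alpha_bar values.
if __name__ == "__main__":
    linear = LinearBetaScheduler(steps=1000)
    cosine = CosineBetaScheduler(steps=1000)
    print("linear:", tuple(linear.betas.shape), tuple(linear.alpha_bars.shape))
    print("cosine:", tuple(cosine.betas.shape), tuple(cosine.alpha_bars.shape))
    # Zero terminal SNR is enforced by default for the linear scheduler.
    print("linear terminal alpha_bar:", linear.alpha_bars[-1].item())
    print("cosine terminal alpha_bar:", cosine.alpha_bars[-1].item())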