# Diffusion Model taken from DeepFindr (https://www.youtube.com/@DeepFindr)
import math

import torch
from torch import nn

from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
    GaussianDiffuser,
)
from diffusion_models.models.base_diffusion_model import BaseDiffusionModel


class Block(nn.Module):
    """One time-conditioned convolutional stage of the U-Net.

    The stage applies conv -> ReLU -> batch norm, adds a linear projection of
    the time embedding to every spatial position, applies a second
    conv -> ReLU -> batch norm, and finally resamples the feature map with a
    kernel-4 / stride-2 / padding-1 layer (a strided convolution when
    downsampling, a transposed convolution when upsampling).

    Args:
        in_ch: Channel count of the feature map this stage is paired with.
            When ``up`` is true the first convolution expects ``2 * in_ch``
            input channels, because ``SimpleUnet.forward`` concatenates a
            skip tensor onto the input along the channel axis.
        out_ch: Channel count produced by the stage.
        time_emb_dim: Size of the last dimension of the time embedding ``t``.
        up: Build an upsampling stage instead of a downsampling one.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        time_emb_dim: int,
        up: bool = False,
    ):
        super().__init__()
        # Projects the time embedding to one value per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            # Input is the previous feature map concatenated with a skip
            # tensor, hence twice the channels.
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            # Kernel 4, stride 2, padding 1: doubles height and width.
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            # Kernel 4, stride 2, padding 1: halves height and width.
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the stage on a feature map and a time embedding.

        Args:
            x: 4-D feature map ``(batch, channels, height, width)``; the
                channel count must match what ``conv1`` was built for.
            t: Time embedding whose last dimension is ``time_emb_dim``.

        Returns:
            The resampled feature map with ``out_ch`` channels.
        """
        # First convolution; note the norm is applied after the activation.
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Project the time embedding and append two singleton axes so it
        # broadcasts over height and width.
        time_emb = self.relu(self.time_mlp(t))[..., None, None]
        h = h + time_emb
        # Second convolution.
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down- or upsample.
        return self.transform(h)
45
46
class SinusoidalPositionEmbeddings(nn.Module):
    """Map scalar timesteps to fixed sinusoidal feature vectors.

    For each timestep the output holds ``dim // 2`` sine values followed by
    ``dim // 2`` cosine values, evaluated at geometrically spaced frequencies
    running from ``1`` down to ``1 / 10000``.

    Args:
        dim: Requested embedding size. The output has ``2 * (dim // 2)``
            features, so an odd ``dim`` yields ``dim - 1`` features.
    """

    def __init__(self, dim: int):
        super().__init__()
        # forward() reads self.dim; without this assignment the first call
        # fails with AttributeError.
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """Embed a batch of timesteps.

        Args:
            time: 1-D tensor of timesteps, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, 2 * (dim // 2))`` laid out as
            ``[sin(...), cos(...)]``.
        """
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) keeps dim in (2, 3) from dividing by zero; for
        # half_dim >= 2 the value is unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(
            torch.arange(half_dim, device=device) * -log_step
        )
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * frequencies[None, :]
        # TODO: Double check the ordering here (sine block first, then the
        # cosine block; the two are not interleaved).
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
61
62
class SimpleUnet(BaseDiffusionModel):
    """
    A simplified variant of the Unet architecture.

    The network projects the input to 64 channels, runs four downsampling
    ``Block`` stages (64 -> 128 -> 256 -> 512 -> 1024 channels), then four
    upsampling ``Block`` stages back to 64 channels, and finishes with a 1x1
    convolution to ``image_channels``. Every stage receives the same time
    embedding.

    Args:
        diffuser: Forwarded unchanged to ``BaseDiffusionModel.__init__``.
        image_channels: Channel count of both the input and the output.
        time_emb_dim: Width of the time embedding fed to every ``Block``.
            NOTE(review): the default of 32 follows the DeepFindr reference
            implementation credited at the top of this file; confirm it
            against the project's history.
    """

    def __init__(
        self,
        diffuser: GaussianDiffuser,
        image_channels: int,
        time_emb_dim: int = 32,
    ):
        super().__init__(diffuser=diffuser)
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = image_channels

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample: one Block per consecutive pair of channel widths.
        self.downs = nn.ModuleList(
            [
                Block(ch_in, ch_out, time_emb_dim)
                for ch_in, ch_out in zip(down_channels, down_channels[1:])
            ]
        )
        # Upsample
        self.ups = nn.ModuleList(
            [
                Block(ch_in, ch_out, time_emb_dim, up=True)
                for ch_in, ch_out in zip(up_channels, up_channels[1:])
            ]
        )

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        """Run the U-Net on a batch of images at the given timesteps.

        Args:
            x: Image batch ``(batch, image_channels, height, width)``. Each
                down stage halves the spatial size and each up stage doubles
                it, so height and width should be multiples of 16 for the
                skip tensors to line up and the output to match the input
                size.
            timestep: 1-D tensor with one timestep per batch element.

        Returns:
            Tensor with ``image_channels`` channels.
        """
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet: keep every down-stage output for the skip connections.
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            # Skips are consumed deepest-first. The first pop returns the
            # bottleneck output itself, so the first up stage sees that
            # tensor concatenated with itself.
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)
109 # if self.training:
110 # x = self.augmentations(x)
111
112 # Embedd time
113 t = self.time_mlp(timestep)
114 # Initial conv
115 x = self.conv0(x)
116 # Unet
117 residual_inputs = []
118 for down in self.downs:
119 x = down(x, t)
120 residual_inputs.append(x)
121 for up in self.ups:
122 residual_x = residual_inputs.pop()
123 # Add residual x as additional channels
124 x = torch.cat((x, residual_x), dim=1)
125 x = up(x, t)
126 return self.output(x)