# Diffusion Model taken from DeepFindr (https://www.youtube.com/@DeepFindr)
import math

import torch
from torch import nn

from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
    GaussianDiffuser,
)
from diffusion_models.models.base_diffusion_model import BaseDiffusionModel


class Block(nn.Module):
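    """One stage of the U-Net: two 3x3 convolutions with an additive
    time-embedding injection, followed by a down- or upsampling transform.
    In the upsampling case, conv1 expects the skip connection concatenated
    along the channel axis (hence 2 * in_ch input channels)."""
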
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(
        self,
        x,
        t,
    ):
        # First conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Append two singleton spatial dimensions for broadcasting
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down- or upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
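    """Transformer-style sinusoidal embedding of the scalar diffusion timestep."""
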
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(BaseDiffusionModel):
    def __init__(
        self,
        diffuser: GaussianDiffuser,
        image_channels: int,
    ):
        """A simplified variant of the U-Net architecture used in DDPM.

        Args:
            diffuser: A Gaussian diffuser.
            image_channels: The number of image channels.
        """
        super().__init__(diffuser=diffuser)
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = image_channels
        time_emb_dim = 32  # size of the sinusoidal time embedding

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList(
            [
                Block(down_channels[i], down_channels[i + 1], time_emb_dim)
                for i in range(len(down_channels) - 1)
            ]
        )
        # Upsample
        self.ups = nn.ModuleList(
            [
                Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
                for i in range(len(up_channels) - 1)
            ]
        )

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
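        """Apply the U-Net to a batch of images ``x`` conditioned on ``timestep``."""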
        # if self.training:
        #     x = self.augmentations(x)

        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # U-Net: downsampling path, storing skip connections
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        # Upsampling path, consuming the skip connections
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)
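

# Minimal shape-check sketch: exercises the self-contained Block and
# SinusoidalPositionEmbeddings layers. SimpleUnet is not constructed here
# because building a GaussianDiffuser is project-specific; the tensor sizes
# below are illustrative assumptions, not values required by the model.
if __name__ == "__main__":
    batch_size, time_emb_dim = 4, 32
    timesteps = torch.randint(0, 300, (batch_size,))  # random diffusion timesteps
    t_emb = SinusoidalPositionEmbeddings(time_emb_dim)(timesteps)
    print(t_emb.shape)  # torch.Size([4, 32])

    down_block = Block(64, 128, time_emb_dim)  # one downsampling stage
    x = torch.randn(batch_size, 64, 32, 32)
    print(down_block(x, t_emb).shape)  # torch.Size([4, 128, 16, 16])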