# Diffusion Model taken from DeepFindr (https://www.youtube.com/@DeepFindr)
import math

import torch
from torch import nn

from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
    GaussianDiffuser,
)
from diffusion_models.models.base_diffusion_model import BaseDiffusionModel


class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(
        self,
        x,
        t,
    ):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions: (B, out_ch) -> (B, out_ch, 1, 1) so the
        # embedding broadcasts over the spatial dimensions
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)
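
# Shape sketch for Block (illustrative; the batch and spatial sizes below are
# assumptions, not values fixed by this module):
#   down = Block(64, 128, time_emb_dim=32)
#   down(torch.randn(8, 64, 32, 32), torch.randn(8, 32)).shape  # (8, 128, 16, 16)
#   up = Block(128, 64, time_emb_dim=32, up=True)  # conv1 expects 2 * in_ch channels
#   up(torch.randn(8, 256, 16, 16), torch.randn(8, 32)).shape   # (8, 64, 32, 32)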


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings
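
# Worked example (illustrative): for dim = 32, half_dim = 16 and the frequencies
# are w_k = exp(-k * ln(10000) / 15) for k = 0..15.  A batch of timesteps of
# shape (B,) maps to embeddings of shape (B, 32), laid out as
#   [sin(t * w_0), ..., sin(t * w_15), cos(t * w_0), ..., cos(t * w_15)],
# i.e. all sines followed by all cosines rather than interleaved sin/cos pairs.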


class SimpleUnet(BaseDiffusionModel):
    def __init__(
        self,
        diffuser: BaseDiffuser,
        image_channels: int,
    ):
        """A simplified variant of the U-Net architecture used in DDPM.

        Args:
            diffuser: A Gaussian diffuser.
            image_channels: The number of image channels.
        """
        super().__init__(diffuser=diffuser)
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = image_channels
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList(
            [
                Block(down_channels[i], down_channels[i + 1], time_emb_dim)
                for i in range(len(down_channels) - 1)
            ]
        )
        # Upsample
        self.ups = nn.ModuleList(
            [
                Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
                for i in range(len(up_channels) - 1)
            ]
        )
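        # Note: forward() concatenates the skip connection before each up Block,
        # so conv1 in Block(up=True) takes 2 * up_channels[i] input channels.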

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # if self.training:
        #     x = self.augmentations(x)

        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)
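
# Minimal usage sketch (illustrative; the GaussianDiffuser constructor arguments
# and the 300-step schedule are assumptions -- check the gaussian_diffusion
# package for the real API):
#   diffuser = GaussianDiffuser(...)  # hypothetical construction
#   model = SimpleUnet(diffuser=diffuser, image_channels=3)
#   x = torch.randn(4, 3, 64, 64)            # batch of noisy images
#   timestep = torch.randint(0, 300, (4,))   # one diffusion step per sample
#   noise_pred = model(x, timestep)          # -> (4, 3, 64, 64)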