Source code for diffusion_models.models.SimpleUnet

# Diffusion Model taken from DeepFindr (https://www.youtube.com/@DeepFindr)
import math

import torch
from torch import nn

from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
    GaussianDiffuser,
)
from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down- or upsample
        return self.transform(h)
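As an illustrative aside (not part of the module source), the sketch below checks the shape behaviour of Block with hypothetical sizes: the downsampling variant halves the spatial resolution with a strided convolution, while the up=True variant expects 2 * in_ch input channels (the skip concatenation) and doubles the resolution with a transposed convolution. It assumes Block defined above is in scope, e.g. imported from diffusion_models.models.SimpleUnet.

# Illustrative sketch only; assumes Block (defined above) is importable.
import torch

down = Block(in_ch=64, out_ch=128, time_emb_dim=32)             # strided conv halves H and W
up = Block(in_ch=128, out_ch=64, time_emb_dim=32, up=True)      # expects 2 * in_ch input channels

x = torch.randn(4, 64, 32, 32)   # (batch, channels, H, W)
t = torch.randn(4, 32)           # one time embedding per sample

h = down(x, t)                        # -> torch.Size([4, 128, 16, 16])
h = up(torch.cat((h, h), dim=1), t)   # cat stands in for a skip connection -> torch.Size([4, 64, 32, 32])
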
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings
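For reference (illustrative only, not part of the module), the embedding maps a batch of integer timesteps to dim-dimensional vectors whose first half holds sine terms and second half cosine terms at geometrically spaced frequencies:

# Illustrative sketch only; assumes SinusoidalPositionEmbeddings (defined above) is in scope.
import torch

pos_emb = SinusoidalPositionEmbeddings(dim=32)
t = torch.tensor([0, 1, 500, 999])   # example diffusion timesteps
e = pos_emb(t)                       # shape (4, 32); e[:, :16] are sines, e[:, 16:] are cosines
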
class SimpleUnet(BaseDiffusionModel):
    def __init__(
        self,
        diffuser: GaussianDiffuser,
        image_channels: int,
    ):
        """A simplified variant of the U-Net architecture used in DDPM.

        Args:
            diffuser: A Gaussian diffuser.
            image_channels: The number of image channels.
        """
        super().__init__(diffuser=diffuser)
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = image_channels
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList(
            [
                Block(down_channels[i], down_channels[i + 1], time_emb_dim)
                for i in range(len(down_channels) - 1)
            ]
        )
        # Upsample
        self.ups = nn.ModuleList(
            [
                Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
                for i in range(len(up_channels) - 1)
            ]
        )

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # if self.training:
        #     x = self.augmentations(x)

        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # U-Net
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)
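A minimal usage sketch follows (not part of the module source). It assumes a GaussianDiffuser instance has already been built elsewhere (its constructor lives in diffusion_models.gaussian_diffusion.gaussian_diffuser and is not shown in this listing) and that BaseDiffusionModel behaves like a standard nn.Module, so calling the model invokes SimpleUnet.forward. Input height and width divisible by 16 keep the four downsampling/upsampling stages and their skip connections shape-compatible.

# Hypothetical usage sketch; the diffuser construction is intentionally left out.
import torch
from diffusion_models.models.SimpleUnet import SimpleUnet

diffuser = ...  # a GaussianDiffuser instance; its constructor is not shown in this listing
model = SimpleUnet(diffuser=diffuser, image_channels=3)

x = torch.randn(8, 3, 64, 64)            # batch of noisy images
timestep = torch.randint(0, 1000, (8,))  # one diffusion step index per sample
noise_pred = model(x, timestep)          # shape (8, 3, 64, 64), same as the input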