Source code for diffusion_models.models.SimpleUnet

  1# Diffusion Model taken from DeepFindr (https://www.youtube.com/@DeepFindr)
  2import math
  3
  4import torch
  5from torch import nn
  6
  7from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
  8from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
  9  GaussianDiffuser,
 10)
 11from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
 12
 13
class Block(nn.Module):
    """One time-conditioned U-Net stage followed by a resampling convolution.

    Pipeline: conv1 -> ReLU -> BatchNorm, add a per-channel time embedding,
    conv2 -> ReLU -> BatchNorm, then ``transform``.

    In down mode ``transform`` is a stride-2 ``Conv2d`` (kernel 4, padding 1),
    which halves height and width (floored). In up mode it is a stride-2
    ``ConvTranspose2d`` with the same kernel/padding, which doubles them, and
    ``conv1`` expects ``2 * in_ch`` input channels (presumably the incoming
    features concatenated with a skip connection of equal width).

    Args:
        in_ch: Input channel count (doubled internally when ``up`` is true).
        out_ch: Output channel count.
        time_emb_dim: Width of the time embedding passed to ``forward``.
        up: Select upsampling (True) or downsampling (False) mode.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Projects the time embedding to one bias value per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        first_conv_inputs = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(first_conv_inputs, out_ch, 3, padding=1)
        resampler = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resampler(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(
        self,
        x,
        t,
    ):
        """Run the stage.

        Args:
            x: Feature map; channel dim must match ``conv1``'s input channels.
            t: Time embedding whose last dim is ``time_emb_dim`` (expected
                ``(batch, time_emb_dim)`` so it broadcasts over ``x``).

        Returns:
            The resampled feature map with ``out_ch`` channels.
        """
        features = self.bnorm1(self.relu(self.conv1(x)))
        step_bias = self.relu(self.time_mlp(t))
        # Two trailing singleton axes let the bias broadcast over H and W.
        features = features + step_bias[..., None, None]
        features = self.bnorm2(self.relu(self.conv2(features)))
        return self.transform(features)
class SinusoidalPositionEmbeddings(nn.Module):
    """Map timesteps to fixed (non-learned) sinusoidal feature vectors.

    Frequencies are log-spaced from 1 down to 1/10000 across ``dim // 2``
    channels; the output is ``[sin(t * f) | cos(t * f)]``. For odd ``dim`` a
    zero column is appended so the output width always equals ``dim``.

    Args:
        dim: Width of the embedding produced for each timestep.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time``.

        Args:
            time: Tensor of timesteps, indexed as ``time[:, None]``, so it is
                presumably 1-D with shape ``(batch,)``.

        Returns:
            Float tensor of shape ``time.shape + (self.dim,)``.
        """
        device = time.device
        half_dim = self.dim // 2
        # Guard the denominator: for dim in {2, 3} half_dim - 1 == 0, which
        # previously raised ZeroDivisionError. Larger dims are unaffected.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        args = time[:, None] * freqs[None, :]
        # Concatenated [sin | cos] halves (not interleaved) is intentional;
        # it matches the DDPM reference timestep embedding.
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Pad the final feature with zeros so the width equals self.dim
            # and matches layers sized on ``dim`` (e.g. nn.Linear(dim, ...)).
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings
class SimpleUnet(BaseDiffusionModel):
    def __init__(
        self,
        diffuser: BaseDiffuser,
        image_channels: int,
        *,
        channels: tuple = (64, 128, 256, 512, 1024),
        time_emb_dim: int = 32,
    ):
        """A simplified variant of the U-Net architecture used in DDPM.

        Args:
            diffuser: A gaussian diffuser, forwarded to ``BaseDiffusionModel``.
            image_channels: The number of image channels, used for both the
                input projection and the output projection.
            channels: Feature widths per resolution level, from the first
                projection down to the bottleneck. The decoder mirrors these
                widths in reverse. Each of the ``len(channels) - 1`` encoder
                blocks halves height and width, so inputs are expected to have
                spatial sizes divisible by ``2 ** (len(channels) - 1)``.
                Defaults to the original hard-coded widths.
            time_emb_dim: Width of the sinusoidal timestep embedding.

        Raises:
            ValueError: If ``channels`` is empty.
        """
        super().__init__(diffuser=diffuser)
        down_channels = tuple(channels)
        if not down_channels:
            raise ValueError("channels must contain at least one width")
        # The decoder must mirror the encoder: every up block concatenates its
        # input with the skip connection of the same width (see forward).
        up_channels = down_channels[::-1]

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList(
            [
                Block(c_in, c_out, time_emb_dim)
                for c_in, c_out in zip(down_channels, down_channels[1:])
            ]
        )
        # Upsample
        self.ups = nn.ModuleList(
            [
                Block(c_in, c_out, time_emb_dim, up=True)
                for c_in, c_out in zip(up_channels, up_channels[1:])
            ]
        )

        # 1x1 projection back to image_channels.
        # Edit (original author): corrected a bug found by Jakub C (see YouTube comment).
        self.output = nn.Conv2d(up_channels[-1], image_channels, 1)

    def forward(self, x, timestep):
        """Run the U-Net on a batch of images.

        Args:
            x: Image batch with ``image_channels`` channels.
            timestep: Diffusion timesteps, embedded by ``self.time_mlp``
                (presumably a 1-D tensor of shape ``(batch,)``).

        Returns:
            A tensor with ``image_channels`` channels. What it represents
            (e.g. predicted noise) depends on the diffuser/training objective
            and is not determined here.
        """
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Encoder: keep every block output as a skip connection.
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        # Decoder: consume skips last-in-first-out. The first up block pops the
        # bottleneck output itself, so it sees that tensor concatenated with
        # itself.
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)