examples.train_model
Running a simple training script.
Below is an example of a training script. It contains a basic configuration for a simple diffusion model and is a good starting point.
import torch
from torch.nn import functional as F
from torch.optim import AdamW
from torchvision import datasets
from torchvision.transforms import v2

from diffusion_models.diffusion_trainer import DiffusionTrainer
from diffusion_models.gaussian_diffusion.beta_schedulers import (
    LinearBetaScheduler,
)
from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
    GaussianDiffuser,
)
from diffusion_models.models.SimpleUnet import SimpleUnet
from diffusion_models.utils.schemas import LogConfiguration
from diffusion_models.utils.schemas import TrainingConfiguration


if __name__ == "__main__":
    image_size = 64
    image_channels = 3

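    # Core training hyperparameters. Note that 2 * 10e-4 evaluates to 2e-3,
    # not 2e-4.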
    training_configuration = TrainingConfiguration(
        batch_size=128,
        learning_rate=2 * 10e-4,
        number_of_epochs=500,
        training_name="SimpleTraining",
        checkpoint_rate=100,
        mixed_precision_training=True,
        gradient_clip=1.0,
    )
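    # Logging settings; judging by the field names, log_rate and image_rate
    # control how often metrics and sample images are written.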
    log_configuration = LogConfiguration(
        log_rate=10,
        image_rate=5000,
        number_of_images=5,
    )
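    # A small U-Net paired with a Gaussian diffuser. The linear beta schedule
    # (beta from 1e-4 to 0.02 over 1000 steps) matches the values used in the
    # original DDPM paper.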
    model = SimpleUnet(
        image_channels=image_channels,
        diffuser=GaussianDiffuser(
            beta_scheduler=LinearBetaScheduler(
                beta_start=0.0001,
                beta_end=0.02,
                steps=1000,
            ),
        ),
    )

    print("Num params: ", sum(p.numel() for p in model.parameters()))
    # model.compile(fullgraph=True, mode="reduce-overhead")

    # Define Image Transforms and Reverse Transforms
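    # The forward transforms convert the uint8 images to float in [0, 1] and
    # rescale them to [-1, 1], assuming the model expects inputs in the usual
    # DDPM range; the reverse transforms below undo that scaling.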
    image_transforms = v2.Compose(
        [
            v2.ToImage(),
            v2.Resize((image_size, image_size)),
            v2.ToDtype(torch.float32, scale=True),
            v2.Lambda(lambda x: x * 2 - 1),
        ]
    )

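    # Map generated samples from [-1, 1] back to [0, 1] and upscale them,
    # presumably for visualization during logging.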
    reverse_transforms = v2.Compose(
        [v2.Lambda(lambda x: (x + 1) / 2), v2.Resize((128, 128))]
    )

    # Define Dataset
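    # With download=False, torchvision expects the CelebA files to already be
    # present under ../data/celeba; set download=True to fetch them first.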
    dataset = datasets.CelebA(
        root="../data", download=False, transform=image_transforms, split="train"
    )

    # Instantiate DiffusionTrainer
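    # The trainer receives the model, dataset, optimizer, and an L1 loss,
    # which in standard DDPM training compares the predicted and true noise.
    # scheduler=None appears to disable learning-rate scheduling.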
    trainer = DiffusionTrainer(
        model=model,
        dataset=dataset,
        optimizer=AdamW(
            model.parameters(), lr=training_configuration.learning_rate
        ),
        reverse_transforms=reverse_transforms,
        training_configuration=training_configuration,
        loss_function=F.l1_loss,
        scheduler=None,
        log_configuration=log_configuration,
        device="cuda",
    )

    # Launch training
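    # train() runs the full loop for number_of_epochs epochs, presumably
    # writing checkpoints at the configured checkpoint_rate.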
    trainer.train()