examples.train_model
Running a simple training script.
Below is a simple example of a training script. It contains a basic configuration for a simple diffusion model and is a good starting point for your own experiments.
1from torch.nn import functional as F
2from torch.optim import AdamW
3from torchvision import datasets
4from torchvision.transforms import v2
5
6from diffusion_models.diffusion_trainer import DiffusionTrainer
7from diffusion_models.gaussian_diffusion.beta_schedulers import (
8 LinearBetaScheduler,
9)
10from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
11 GaussianDiffuser,
12)
13from diffusion_models.models.SimpleUnet import SimpleUnet
14from diffusion_models.utils.schemas import LogConfiguration
15from diffusion_models.utils.schemas import TrainingConfiguration
16
17
18if __name__ == "__main__":
# Side length the training images are resized to, and the number of image
# channels the model is built for.
image_size = 64
image_channels = 3

# Optimisation hyper-parameters handed to the trainer.
training_configuration = TrainingConfiguration(
    batch_size=256,
    # NOTE(review): this was written as `2 * 10e-4`. In Python `10e-4` is
    # 1e-3 (not 10**-4), so the effective rate has always been 2e-3; it is
    # spelled out here so the value is unambiguous. If 2e-4 was the intent,
    # change it deliberately.
    learning_rate=2e-3,
    number_of_epochs=500,
    training_name="ReworkedFrameworkBase",
    checkpoint_rate=100,
    mixed_precision_training=False,
    # gradient_clip=0.1,  # optional setting, left disabled in this example
)

# Logging settings handed to the trainer.
# NOTE(review): image_rate=635 presumably approximates one epoch of CelebA
# at batch_size=256 -- confirm the unit against LogConfiguration.
log_configuration = LogConfiguration(
    log_rate=10,
    image_rate=635,
    number_of_images=5,
)
# Build the model from the inside out: beta schedule -> diffuser -> network.
beta_scheduler = LinearBetaScheduler(
    beta_start=0.0001,
    beta_end=0.02,
    steps=1000,
)
diffuser = GaussianDiffuser(beta_scheduler=beta_scheduler)
model = SimpleUnet(image_channels=image_channels, diffuser=diffuser)

# Print the total number of parameter elements as a quick size check.
parameter_count = sum(p.numel() for p in model.parameters())
print("Num params: ", parameter_count)

# Optional compilation step, left disabled.
# NOTE(review): on a plain torch.nn.Module, `compile()` presumably works in
# place and returns None, which would rebind `model` to None -- confirm what
# SimpleUnet.compile returns before re-enabling.
# model = model.compile(fullgraph=True, mode="reduce-overhead")
49
import torch  # local import: supplies the dtype passed to v2.ToDtype

# Forward transform applied to every dataset image: convert to a tensor
# image, resize to the training resolution, and normalise to [-1, 1].
image_transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize((image_size, image_size)),
        # ToImage does not rescale values, so convert uint8 [0, 255] to
        # float32 [0, 1] before normalising.
        v2.ToDtype(torch.float32, scale=True),
        # [0, 1] -> [-1, 1]. The previous `(x + 1) / 2` was the *reverse*
        # mapping and left the data un-normalised.
        v2.Lambda(lambda x: x * 2 - 1),
    ]
)

# Reverse transform for model outputs: map [-1, 1] back to [0, 1] (the exact
# inverse of the forward normalisation) and resize to 128x128.
reverse_transforms = v2.Compose(
    [
        v2.Lambda(lambda x: (x + 1) / 2),
        v2.Resize((128, 128)),
    ]
)
62
# Training data: the "train" split of CelebA read from ../data. With
# download=False nothing is fetched, so the files must already be there.
dataset = datasets.CelebA(
    root="../data",
    split="train",
    transform=image_transforms,
    download=False,
)
# Alternative dataset, kept for reference:
# dataset = datasets.MNIST(
#     root="../data",
#     transform=image_transforms,
#     download=True,
# )
72
73 # Instantiate DiffusionTrainer
74 trainer = DiffusionTrainer(
75 model=model,
76 dataset=dataset,
77 optimizer=AdamW(
78 model.parameters(), lr=training_configuration.learning_rate
79 ),
80 reverse_transforms=reverse_transforms,
81 training_configuration=training_configuration,
82 loss_function=F.l1_loss,
83 scheduler=None,
84 log_configuration=log_configuration,