examples.train_model

Running a simple training script.

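Below is a simple example of a training script. It trains a small U-Net diffusion model with a linear beta schedule on the CelebA dataset, using basic configuration values, and is a good starting point for your own experiments. Note that the script expects CelebA to already be present under `../data` (it is not downloaded automatically) and a CUDA device to be available.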

 1from torch.nn import functional as F
 2from torch.optim import AdamW
 3from torchvision import datasets
 4from torchvision.transforms import v2
 5
 6from diffusion_models.diffusion_trainer import DiffusionTrainer
 7from diffusion_models.gaussian_diffusion.beta_schedulers import (
 8  LinearBetaScheduler,
 9)
10from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
11  GaussianDiffuser,
12)
13from diffusion_models.models.SimpleUnet import SimpleUnet
14from diffusion_models.utils.schemas import LogConfiguration
15from diffusion_models.utils.schemas import TrainingConfiguration
16
17
18if __name__ == "__main__":
19  image_size = 64
20  image_channels = 3
21
22  training_configuration = TrainingConfiguration(
23    batch_size=128,
24    learning_rate=2 * 10e-4,
25    number_of_epochs=500,
26    training_name="SimpleTraining",
27    checkpoint_rate=100,
28    mixed_precision_training=True,
29    gradient_clip=1.0,
30  )
31  log_configuration = LogConfiguration(
32    log_rate=10,
33    image_rate=5000,
34    number_of_images=5,
35  )
36  model = SimpleUnet(
37    image_channels=image_channels,
38    diffuser=GaussianDiffuser(
39      beta_scheduler=LinearBetaScheduler(
40        beta_start=0.0001,
41        beta_end=0.02,
42        steps=1000,
43      ),
44    ),
45  )
46
47  print("Num params: ", sum(p.numel() for p in model.parameters()))
48  # model = model.compile(fullgraph=True, mode="reduce-overhead")
49
50  # Define Image Transforms and Reverse Transforms
51  image_transforms = v2.Compose(
52    [
53      v2.ToImage(),
54      v2.Resize((image_size, image_size)),
55      v2.Lambda(lambda x: (x + 1) / 2),
56    ]
57  )
58
59  reverse_transforms = v2.Compose(
60    [v2.Lambda(lambda x: (x + 1) / 2), v2.Resize((128, 128))]
61  )
62
63  # Define Dataset
64  dataset = datasets.CelebA(
65    root="../data", download=False, transform=image_transforms, split="train"
66  )
67
68  # Instantiate DiffusionTrainer
69  trainer = DiffusionTrainer(
70    model=model,
71    dataset=dataset,
72    optimizer=AdamW(
73      model.parameters(), lr=training_configuration.learning_rate
74    ),
75    reverse_transforms=reverse_transforms,
76    training_configuration=training_configuration,
77    loss_function=F.l1_loss,
78    scheduler=None,
79    log_configuration=log_configuration,
80    device="cuda",
81  )
82
83  # Launch training
84  trainer.train()