examples.train_model

Running a simple training script.

Below is an example of a training script. It contains a basic configuration for a simple diffusion model and is a good starting point for your own experiments.

 1from torch.nn import functional as F
 2from torch.optim import AdamW
 3from torchvision import datasets
 4from torchvision.transforms import v2
 5
 6from diffusion_models.diffusion_trainer import DiffusionTrainer
 7from diffusion_models.gaussian_diffusion.beta_schedulers import (
 8  LinearBetaScheduler,
 9)
10from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
11  GaussianDiffuser,
12)
13from diffusion_models.models.SimpleUnet import SimpleUnet
14from diffusion_models.utils.schemas import LogConfiguration
15from diffusion_models.utils.schemas import TrainingConfiguration
16
17
18if __name__ == "__main__":
19  image_size = 64
20  image_channels = 3
21
22  training_configuration = TrainingConfiguration(
23    batch_size=256,
24    learning_rate=2 * 10e-4,
25    number_of_epochs=500,
26    training_name="ReworkedFrameworkBase",
27    checkpoint_rate=100,
28    mixed_precision_training=False,
29    # gradient_clip=0.1,
30  )
31  log_configuration = LogConfiguration(
32    log_rate=10,
33    image_rate=635,
34    number_of_images=5,
35  )
36  model = SimpleUnet(
37    image_channels=image_channels,
38    diffuser=GaussianDiffuser(
39      beta_scheduler=LinearBetaScheduler(
40        beta_start=0.0001,
41        beta_end=0.02,
42        steps=1000,
43      ),
44    ),
45  )
46
47  print("Num params: ", sum(p.numel() for p in model.parameters()))
48  # model = model.compile(fullgraph=True, mode="reduce-overhead")
49
50  # Define Image Transforms and Reverse Transforms
51  image_transforms = v2.Compose(
52    [
53      v2.ToImage(),
54      v2.Resize((image_size, image_size)),
55      v2.Lambda(lambda x: (x + 1) / 2),
56    ]
57  )
58
59  reverse_transforms = v2.Compose(
60    [v2.Lambda(lambda x: (x + 1) / 2), v2.Resize((128, 128))]
61  )
62
63  # Define Dataset
64  dataset = datasets.CelebA(
65    root="../data", download=False, transform=image_transforms, split="train"
66  )
67  # dataset = datasets.MNIST(
68  #   transform=image_transforms,
69  #   root="../data",
70  #   download=True
71  # )
72
73  # Instantiate DiffusionTrainer
74  trainer = DiffusionTrainer(
75    model=model,
76    dataset=dataset,
77    optimizer=AdamW(
78      model.parameters(), lr=training_configuration.learning_rate
79    ),
80    reverse_transforms=reverse_transforms,
81    training_configuration=training_configuration,
82    loss_function=F.l1_loss,
83    scheduler=None,
84    log_configuration=log_configuration,