diffusion_models.diffusion_trainer

Module Contents

class DiffusionTrainer(model, dataset, optimizer, training_configuration, loss_function=F.l1_loss, log_configuration=LogConfiguration(), reverse_transforms=lambda x: ..., device='cuda')

A diffusion trainer framework.

This is a simplified framework for training diffusion models.

Parameters:
  • model (BaseDiffusionModel) – A diffusion model.

  • dataset (Dataset) – A dataset to train on.

  • optimizer (Optimizer) – The optimizer to use.

  • training_configuration (TrainingConfiguration) – The training configuration to use.

  • loss_function (Callable) – The loss function to use.

  • log_configuration (LogConfiguration) – The logging configuration to use.

  • reverse_transforms (Callable) – The reverse transforms to use.

  • device (str) – The device to use.
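The default loss_function, F.l1_loss, computes the mean absolute error between prediction and target when called with its default reduction='mean'. The dependency-free sketch below reproduces that arithmetic on flat lists of floats. The helper name l1_loss is hypothetical, and this page does not confirm that the trainer calls loss_function(prediction, target) in that order. Another callable with the same shape, such as F.mse_loss, would change the training objective.

```python
from typing import Sequence


def l1_loss(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error, mirroring F.l1_loss with its default reduction='mean'."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    if not prediction:
        raise ValueError("inputs must be non-empty")
    # Average the element-wise absolute differences.
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(prediction)


# Usage: if the model predicts noise (an assumption; common in DDPM-style
# training), the target would be the noise that was added to the sample.
print(l1_loss([0.5, -0.2, 1.0], [0.0, 0.0, 1.0]))
```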

model

The diffusion model to use.

optimizer

The optimizer to use.

loss_function

The loss function to use.

training_configuration

The training configuration to use.

device

The device to use.

dataloader

A PyTorch DataLoader.
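The dataloader is built internally. It presumably wraps dataset and takes its batch size from training_configuration, but this page does not document that. The number of batches per epoch determines how far global_step advances each epoch, so the arithmetic is worth spelling out. The sketch below uses a hypothetical helper to compute the count that len() reports for a map-style torch.utils.data.DataLoader with its default batch sampler.

```python
def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches yielded per epoch by a map-style DataLoader with its default batch sampler."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        # The final incomplete batch is discarded.
        return num_samples // batch_size
    # The final incomplete batch is kept (integer ceiling division).
    return (num_samples + batch_size - 1) // batch_size


print(batches_per_epoch(1000, 32), batches_per_epoch(1000, 32, drop_last=True))
```

For example, 1,000 samples with batch_size=32 give 32 batches per epoch, or 31 with drop_last=True, because the final 8 samples are dropped.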

scaler

A PyTorch GradScaler object, used to scale the loss during mixed-precision training.
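The standard PyTorch pattern for a GradScaler is scaler.scale(loss).backward(), then scaler.step(optimizer), then scaler.update(). This page does not show whether the trainer follows exactly that sequence. The dependency-free toy below, whose class name is hypothetical, models only the loss-scale bookkeeping GradScaler performs. When gradients overflow, it skips the step and shrinks the scale. After growth_interval clean steps in a row, it grows the scale. The defaults match GradScaler's documented defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000). In the real class, the skip decision happens in step() and the scale change in update(); the toy merges both into one method.

```python
class DynamicLossScale:
    """Toy model of the loss-scale bookkeeping done by PyTorch's GradScaler."""

    def __init__(self, init_scale: float = 2.0**16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000) -> None:
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive steps without inf/NaN gradients

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale after one step; return True if the optimizer step would run."""
        if found_inf:
            # Overflow: skip this optimizer step and shrink the scale.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            # A full interval without overflow: try a larger scale.
            self.scale *= self.growth_factor
            self._clean_steps = 0
        return True


scaler = DynamicLossScale(init_scale=8.0, growth_interval=2)
print(scaler.update(found_inf=True), scaler.scale)  # prints: False 4.0
```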

log_configuration

A LogConfiguration object.

checkpoint_path

The path under which checkpoints are saved.

tensorboard_manager

A TensorBoard manager instance.

reverse_transforms

The reverse transforms, as a single callable.
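The default for reverse_transforms is elided in the signature above, so its actual behavior is not shown here. In diffusion pipelines, a common choice is to map model outputs from the [-1, 1] training range back to 8-bit pixel values for visualization, but that is an assumption and is not confirmed for this trainer. The plain-Python sketch below applies that mapping to a flat list of floats. The function name is hypothetical, and a real implementation would operate on tensors.

```python
from typing import List, Sequence


def to_uint8_pixels(values: Sequence[float]) -> List[int]:
    """Map values in [-1, 1] to integers in [0, 255], clamping out-of-range inputs."""
    pixels = []
    for v in values:
        v = max(-1.0, min(1.0, v))  # clamp: generated samples can overshoot [-1, 1]
        pixels.append(round((v + 1.0) / 2.0 * 255.0))  # rescale to [0, 255] and round
    return pixels


print(to_uint8_pixels([-1.0, 0.0, 1.0, 1.7]))  # prints [0, 128, 255, 255]
```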

benchmark = True

save_checkpoint(epoch, checkpoint_name)

Save a checkpoint.

Parameters:
  • epoch (int) – The current epoch.

  • checkpoint_name (str) – The name of the checkpoint.
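This page does not specify the checkpoint format. In PyTorch, the usual approach is to torch.save a dictionary holding the epoch together with model.state_dict() and optimizer.state_dict(). With mixed precision, scaler.state_dict() is often included as well. The stand-in below uses pickle so it runs without PyTorch. The function names and dictionary keys are hypothetical, and it assumes the file is written to checkpoint_path / checkpoint_name.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict


def save_checkpoint_sketch(checkpoint_path: Path, epoch: int, checkpoint_name: str,
                           model_state: Dict[str, Any],
                           optimizer_state: Dict[str, Any]) -> Path:
    """Write a checkpoint dict to checkpoint_path / checkpoint_name and return the file path."""
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    target = checkpoint_path / checkpoint_name
    state = {
        "epoch": epoch,                      # lets training resume from the next epoch
        "model_state": model_state,          # stand-in for model.state_dict()
        "optimizer_state": optimizer_state,  # stand-in for optimizer.state_dict()
    }
    with target.open("wb") as fh:
        pickle.dump(state, fh)  # a PyTorch codebase would call torch.save(state, target)
    return target


def load_checkpoint_sketch(path: Path) -> Dict[str, Any]:
    """Read back a checkpoint written by save_checkpoint_sketch."""
    with path.open("rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    saved = save_checkpoint_sketch(Path(tmp) / "checkpoints", epoch=5,
                                   checkpoint_name="epoch_5.pkl",
                                   model_state={"weight": [0.1, 0.2]},
                                   optimizer_state={"lr": 1e-4})
    print(load_checkpoint_sketch(saved)["epoch"])  # prints 5
```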

train()

Start the diffusion training.
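Because train() takes no arguments, the number of epochs and the logging interval presumably come from training_configuration and log_configuration. Their field names are not documented here. The sketch below shows a typical epoch/step loop and how a global_step counter advances across epochs, such as the one passed to log_to_tensorboard. All names in it are hypothetical, and the forward/backward pass is replaced by a callback.

```python
from typing import Callable, Iterable, List, Tuple


def run_training(
    num_epochs: int,
    batches: Callable[[], Iterable[float]],
    train_step: Callable[[float], float],
    log_every: int,
) -> List[Tuple[int, int, float]]:
    """Run a generic epoch/batch loop; return (epoch, global_step, loss) for each logged step."""
    logged = []
    global_step = 0
    for epoch in range(num_epochs):
        for batch in batches():  # fresh iterator each epoch, as with a DataLoader
            loss = train_step(batch)  # stands in for forward pass, loss, backward, optimizer step
            if global_step % log_every == 0:
                logged.append((epoch, global_step, loss))  # where metrics would be logged
            global_step += 1  # counts optimizer steps across all epochs
    return logged


# Usage: 2 epochs x 3 batches, logging every 2 steps; the dummy step returns the batch as its loss.
print(run_training(2, lambda: [0.9, 0.5, 0.3], lambda b: b, log_every=2))
# prints [(0, 0, 0.9), (0, 2, 0.3), (1, 4, 0.5)]
```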

log_to_tensorboard(metrics, global_step)

Log to TensorBoard.

This method logs the supplied metrics, along with visualizations, to TensorBoard.

Parameters:
  • metrics (Dict[str, float]) – A dictionary mapping metric names to values.

  • global_step (int) – The current global step.
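For scalar metrics, this method presumably reduces to one add_scalar call per entry in metrics. In PyTorch, that writer is torch.utils.tensorboard.SummaryWriter, whose add_scalar(tag, scalar_value, global_step) records one point on a TensorBoard curve. The function below is a hypothetical sketch, not this trainer's code, and it does not cover the visualizations. It accepts any object with an add_scalar method, so a small recording stub can stand in for a real SummaryWriter.

```python
from typing import Dict, List, Protocol, Tuple


class ScalarWriter(Protocol):
    """The subset of SummaryWriter's interface this sketch relies on."""

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None: ...


def log_metrics(writer: ScalarWriter, metrics: Dict[str, float], global_step: int) -> None:
    """Write every metric as a scalar at the same global step."""
    for name, value in metrics.items():
        writer.add_scalar(name, value, global_step)


class RecordingWriter:
    """Test double that records calls instead of writing event files."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, float, int]] = []

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None:
        self.calls.append((tag, scalar_value, global_step))


writer = RecordingWriter()  # with PyTorch installed: SummaryWriter(log_dir=...)
log_metrics(writer, {"train/loss": 0.12, "train/lr": 1e-4}, global_step=100)
print(writer.calls)
```

Using a shared global_step across all metrics keeps the curves aligned on the same x-axis in TensorBoard.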