diffusion_models.diffusion_trainer
Module Contents
- class DiffusionTrainer(model, dataset, optimizer, training_configuration, loss_function=F.l1_loss, log_configuration=LogConfiguration(), reverse_transforms=lambda x: ..., device='cuda')
A diffusion trainer framework.
This is a simplified framework for training diffusion models.
- Parameters:
model (diffusion_models.models.base_diffusion_model.BaseDiffusionModel) – A diffusion model.
dataset (Dataset) – A dataset to train on.
optimizer (Optimizer) – The optimizer to use.
training_configuration (TrainingConfiguration) – The training configuration to use.
loss_function (Callable) – The loss function to use. Defaults to F.l1_loss.
log_configuration (LogConfiguration) – The logging configuration to use. Defaults to LogConfiguration().
reverse_transforms (Callable) – The reverse transforms to use.
device (str) – The device to train on. Defaults to 'cuda'.
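The class documentation does not show the training loop itself. The sketch below is a rough, dependency-free illustration of what a trainer with this signature might do on each step. It assumes a DDPM-style noise-prediction objective: sample a timestep, noise a clean sample with the forward process, and regress the added noise under `loss_function` (mean absolute error by default, mirroring `F.l1_loss`). Everything in it is hypothetical and not part of `diffusion_models`: `linear_alpha_bars`, `add_noise`, `train`, the linear noise predictor that ignores the timestep, and the finite-difference gradient update that stands in for a PyTorch `Optimizer`.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def l1_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error, a pure-Python stand-in for F.l1_loss (mean reduction)."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def linear_alpha_bars(num_steps: int, beta_start: float = 1e-4,
                      beta_end: float = 0.02) -> List[float]:
    """Cumulative products of (1 - beta_t) for a hypothetical linear beta schedule."""
    alpha_bars, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / max(num_steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def add_noise(x0: Sequence[float], noise: Sequence[float], alpha_bar: float) -> Vector:
    """Forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise."""
    a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + s * n for x, n in zip(x0, noise)]


def train(dataset: Sequence[Sequence[float]], epochs: int = 20, lr: float = 0.05,
          num_steps: int = 100,
          loss_function: Callable[[Sequence[float], Sequence[float]], float] = l1_loss,
          seed: int = 0) -> Tuple[Tuple[float, float], List[float]]:
    """Toy loop: predict the added noise with eps_hat = w * x_t + b (ignores t)."""
    rng = random.Random(seed)
    alpha_bars = linear_alpha_bars(num_steps)
    params = [0.0, 0.0]  # [w, b] of the hypothetical linear noise predictor
    h = 1e-4  # step size for central finite differences
    history = []
    for _ in range(epochs):
        total = 0.0
        for x0 in dataset:
            t = rng.randrange(num_steps)  # sample a timestep uniformly
            noise = [rng.gauss(0.0, 1.0) for _ in x0]
            xt = add_noise(x0, noise, alpha_bars[t])

            def loss_at(p: Sequence[float]) -> float:
                # Loss between the predicted noise and the noise actually added.
                return loss_function([p[0] * x + p[1] for x in xt], noise)

            total += loss_at(params)
            # Estimate the gradient numerically so any loss_function works.
            grads = []
            for i in range(len(params)):
                up, down = list(params), list(params)
                up[i] += h
                down[i] -= h
                grads.append((loss_at(up) - loss_at(down)) / (2 * h))
            # Plain gradient-descent update in place of a real optimizer.
            params = [p - lr * g for p, g in zip(params, grads)]
        history.append(total / len(dataset))
    return (params[0], params[1]), history


data = [[0.5, -0.2, 0.1], [0.3, 0.0, -0.4]]
(w, b), history = train(data, epochs=5)
print(len(history))  # 5
```

In the real class, the model would be a neural network conditioned on the timestep, the update would come from the `optimizer` argument, and batch sizes and epoch counts would presumably come from `training_configuration`.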