diffusion_models.diffusion_trainer

Module Contents

class DiffusionTrainer(model, dataset, optimizer, training_configuration, loss_function=F.l1_loss, scheduler=None, log_configuration=LogConfiguration(), reverse_transforms=lambda x: ..., device='cuda')[source]
model[source]
optimizer[source]
loss_function[source]
training_configuration[source]
scheduler[source]
device[source]
dataloader[source]
scaler[source]
log_configuration[source]
checkpoint_path[source]
tensorboard_manager[source]
reverse_transforms[source]
benchmark = True [source]
save_checkpoint(epoch, checkpoint_name)[source]
train()[source]
log_to_tensorboard(metrics, global_step)[source]