diffusion_models.diffusion_trainer
==================================

.. py:module:: diffusion_models.diffusion_trainer


Module Contents
---------------

.. py:class:: DiffusionTrainer(model, dataset, optimizer, training_configuration, loss_function = F.l1_loss, log_configuration = LogConfiguration(), reverse_transforms = lambda x: x, device = 'cuda')

   A diffusion trainer framework.

   This is a simplified framework for training diffusion models.

   :param model: The diffusion model to train.
   :param dataset: The dataset to train on.
   :param optimizer: The optimizer used to update the model parameters.
   :param training_configuration: The training configuration.
   :param loss_function: The loss function to minimize (defaults to ``F.l1_loss``).
   :param log_configuration: The logging configuration (defaults to ``LogConfiguration()``).
   :param reverse_transforms: The reverse transforms to apply, given as a single callable (defaults to the identity, ``lambda x: x``).
   :param device: The device to train on (defaults to ``'cuda'``).


   .. py:attribute:: model

      The diffusion model being trained.


   .. py:attribute:: optimizer

      The optimizer used to update the model parameters.


   .. py:attribute:: loss_function

      The loss function to minimize.


   .. py:attribute:: training_configuration

      The training configuration.


   .. py:attribute:: device

      The device to train on.


   .. py:attribute:: dataloader

      A torch ``DataLoader`` that iterates over the training data.


   .. py:attribute:: scaler

      A torch ``GradScaler`` object used for gradient scaling in mixed-precision training.


   .. py:attribute:: log_configuration

      A ``LogConfiguration`` object.


   .. py:attribute:: checkpoint_path

      The path where checkpoints are saved.


   .. py:attribute:: tensorboard_manager

      A TensorBoard manager instance.


   .. py:attribute:: reverse_transforms

      The reverse transforms, given as a single callable.


   .. py:method:: save_checkpoint(epoch, checkpoint_name)

      Save a checkpoint.

      :param epoch: The current epoch.
      :param checkpoint_name: The name of the checkpoint.


   .. py:method:: train()

      Run the diffusion training loop.


   .. py:method:: log_to_tensorboard(metrics, global_step)

      Log metrics to TensorBoard.

      This method logs the given metrics, together with visualizations, to TensorBoard.

      :param metrics: A dictionary mapping metric names to values.
      :param global_step: The current global step.
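
The default ``reverse_transforms`` of ``DiffusionTrainer`` is the identity, ``lambda x: x``. As a minimal sketch, assuming the training data was normalized to the range [-1, 1] (an assumption; this module does not state how data is normalized), a reverse transform could map samples back to [0, 1] before they are visualized. ``to_unit_range`` is a hypothetical helper, not part of ``diffusion_models``.

.. code-block:: python

   import torch


   def to_unit_range(x: torch.Tensor) -> torch.Tensor:
       """Hypothetical reverse transform: map values from [-1, 1] to [0, 1].

       Assumes the forward transform normalized the data to [-1, 1];
       values outside that range are clamped into [0, 1].
       """
       return ((x + 1.0) / 2.0).clamp(0.0, 1.0)


   samples = torch.tensor([-1.0, 0.0, 1.0, 1.5])
   images = to_unit_range(samples)  # values: 0.0, 0.5, 1.0, 1.0

Such a function would be passed as ``reverse_transforms=to_unit_range`` when constructing ``DiffusionTrainer``; any other single-argument callable fits the same slot.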
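
The contents of a checkpoint written by ``DiffusionTrainer.save_checkpoint`` are not documented here. A common PyTorch pattern, shown as a hypothetical sketch rather than this class's implementation, stores the epoch together with the ``state_dict()`` of the model, the optimizer, and the ``GradScaler`` so that training can resume where it stopped. The helpers ``write_checkpoint`` and ``read_checkpoint``, their arguments, and the dictionary keys are all assumptions.

.. code-block:: python

   import os
   import tempfile

   import torch
   from torch import nn


   def write_checkpoint(checkpoint_dir, epoch, checkpoint_name, model, optimizer, scaler):
       """Hypothetical: save training state to checkpoint_dir/checkpoint_name."""
       os.makedirs(checkpoint_dir, exist_ok=True)
       path = os.path.join(checkpoint_dir, checkpoint_name)
       torch.save(
           {
               "epoch": epoch,
               "model": model.state_dict(),
               "optimizer": optimizer.state_dict(),
               "scaler": scaler.state_dict(),  # empty dict when the scaler is disabled
           },
           path,
       )
       return path


   def read_checkpoint(path, model, optimizer, scaler):
       """Hypothetical: restore state in place and return the saved epoch."""
       checkpoint = torch.load(path, map_location="cpu")
       model.load_state_dict(checkpoint["model"])
       optimizer.load_state_dict(checkpoint["optimizer"])
       scaler.load_state_dict(checkpoint["scaler"])  # no-op for a disabled scaler
       return checkpoint["epoch"]


   # CPU-only usage: the scaler is disabled, so it acts as a pass-through.
   torch.manual_seed(0)
   model = nn.Linear(2, 2)
   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
   scaler = torch.cuda.amp.GradScaler(enabled=False)
   path = write_checkpoint(tempfile.mkdtemp(), 3, "epoch_3.pt", model, optimizer, scaler)

   # A freshly initialized model picks up the saved weights and epoch.
   restored = nn.Linear(2, 2)
   restored_optimizer = torch.optim.SGD(restored.parameters(), lr=0.1)
   epoch = read_checkpoint(path, restored, restored_optimizer, scaler)

Saving the scaler state alongside the optimizer matters for mixed-precision runs, because the current loss scale would otherwise reset to its initial value on resume.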
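
The steps performed inside ``DiffusionTrainer.train`` are not spelled out in this reference. As a hedged sketch of how the documented pieces (``loss_function``, ``optimizer``, and the ``GradScaler`` held in ``scaler``) typically fit together in one PyTorch mixed-precision optimization step, the hypothetical ``training_step`` below computes the loss under autocast, scales it, and steps the optimizer through the scaler. The diffusion-specific parts (noise schedule, timestep sampling, what the model predicts) are deliberately left out and replaced by a plain input/target pair, and a toy ``nn.Linear`` stands in for the diffusion model.

.. code-block:: python

   import torch
   import torch.nn.functional as F
   from torch import nn


   def training_step(model, inputs, target, optimizer, scaler,
                     loss_function=F.l1_loss, device="cpu"):
       """Hypothetical single optimization step with gradient scaling."""
       model.train()
       optimizer.zero_grad(set_to_none=True)
       inputs, target = inputs.to(device), target.to(device)
       # Mixed precision is only active when the scaler is enabled (CUDA);
       # otherwise autocast and the scaler are both pass-throughs.
       with torch.autocast(device_type=torch.device(device).type,
                           enabled=scaler.is_enabled()):
           prediction = model(inputs)
           loss = loss_function(prediction, target)
       # Scale the loss so small fp16 gradients do not underflow; step() unscales
       # the gradients before the optimizer update and skips it on inf/NaN.
       scaler.scale(loss).backward()
       scaler.step(optimizer)
       scaler.update()
       return loss.item()


   # Toy usage: fit a linear layer to the identity map on one fixed batch.
   torch.manual_seed(0)
   use_amp = torch.cuda.is_available()
   device = "cuda" if use_amp else "cpu"
   model = nn.Linear(4, 4).to(device)
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
   scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
   batch = torch.randn(16, 4)
   losses = [training_step(model, batch, batch, optimizer, scaler, device=device)
             for _ in range(50)]

Calling a step like this once per batch drawn from ``dataloader`` would make up an epoch; whether ``train`` is organized exactly this way is an assumption.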