diffusion_models.utils.schemas

Module Contents

class TrainingConfiguration[source]

A training configuration for simple experiment management.

training_name: str[source]

The name of the training run.

batch_size: int[source]

The batch size used for training.

learning_rate: float[source]

The learning rate used for training.

number_of_epochs: int[source]

The number of epochs used for training.

checkpoint_rate: int = 100[source]

The rate at which checkpoints are saved.

mixed_precision_training: bool = False[source]

Whether or not to use automatic mixed precision training.

gradient_clip: float | None = None[source]

The value at which gradients are clipped, or None to disable gradient clipping.
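
A minimal construction sketch (assuming the class is instantiated directly from keyword arguments matching the fields above; all values shown are illustrative):

    from diffusion_models.utils.schemas import TrainingConfiguration

    config = TrainingConfiguration(
        training_name="ddpm-example",     # illustrative experiment name
        batch_size=128,
        learning_rate=2e-4,
        number_of_epochs=500,
        checkpoint_rate=100,              # default checkpoint rate
        mixed_precision_training=True,    # enable automatic mixed precision
        gradient_clip=1.0,                # clip gradients; None (default) disables clipping
    )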

class LogConfiguration[source]

An object to manage logging configuration.

log_rate: int = 10[source]

The rate at which training metrics are logged.

image_rate: int = 50[source]

The rate at which images are generated for visualization. This can be used to validate model performance.

number_of_images: int = 5[source]

The number of images to generate.
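
An illustrative override of the defaults (field names are taken from the attributes above; the specific values are arbitrary):

    from diffusion_models.utils.schemas import LogConfiguration

    log_config = LogConfiguration(
        log_rate=10,          # log training metrics at the default rate
        image_rate=50,        # generate validation images at the default rate
        number_of_images=5,   # images generated per visualization step
    )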

class BetaSchedulerConfiguration[source]

A simplified beta scheduler configuration.

steps: int[source]

The number of steps in the beta scheduler.

betas: Tensor[source]

The beta values.

alpha_bars: Tensor[source]

The alpha bar values.
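
As an illustration, the fields can be populated from a standard linear DDPM beta schedule (a sketch only; the scheduler used to produce these values in practice is not part of this module):

    import torch

    from diffusion_models.utils.schemas import BetaSchedulerConfiguration

    steps = 1000
    betas = torch.linspace(1e-4, 0.02, steps)    # linear beta schedule
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)    # cumulative product of (1 - beta_t)

    beta_config = BetaSchedulerConfiguration(
        steps=steps,
        betas=betas,
        alpha_bars=alpha_bars,
    )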

class Checkpoint[source]

A simplified checkpoint framework for easy saving and loading.

epoch: int[source]

The current epoch.

model_state_dict: Dict[str, Any][source]

The model state dict.

optimizer_state_dict: Dict[str, Any][source]

The optimizer state dict.

beta_scheduler_config: BetaSchedulerConfiguration[source]

The beta scheduler configuration.

scaler: GradScaler | None = None[source]

The GradScaler instance.

tensorboard_run_name: str | None = None[source]

The name of the tensorboard run.

image_channels: int = 3[source]

The number of image channels used in the training.

loss: float | None = None[source]

The final loss value recorded.

Note

This is a legacy parameter and will be removed in a future release.

classmethod from_file(file_path, map_location=None)[source]

Load and instantiate a checkpoint from a file.

Parameters:
  • file_path (str) – The path to the checkpoint file.

  • map_location (Optional[str]) – A function, torch.device, string or a dict specifying how to remap storage locations.

Returns:

A checkpoint instance.

Return type:

Checkpoint

to_file(file_path)[source]

Save a checkpoint to a file.
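
A save/load round trip sketch (assuming Checkpoint is constructed from the fields listed above; model, optimizer and beta_config are placeholders for objects produced by the training loop):

    from diffusion_models.utils.schemas import Checkpoint

    checkpoint = Checkpoint(
        epoch=100,                                       # illustrative epoch count
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        beta_scheduler_config=beta_config,
        image_channels=3,
    )
    checkpoint.to_file("checkpoints/run_0100.pt")

    # Later: restore on CPU regardless of the device the checkpoint was written from.
    restored = Checkpoint.from_file("checkpoints/run_0100.pt", map_location="cpu")
    model.load_state_dict(restored.model_state_dict)
    optimizer.load_state_dict(restored.optimizer_state_dict)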

class OldCheckpoint[source]

epoch: int[source]

model_state_dict: Dict[str, Any][source]

optimizer_state_dict: Dict[str, Any][source]

scaler: GradScaler | None[source]

tensorboard_run_name: str | None = None[source]

loss: float | None = None[source]

classmethod from_file(file_path)[source]

to_file(file_path)[source]

to_new_checkpoint(beta_scheduler)[source]
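
A migration sketch (assuming to_new_checkpoint returns a Checkpoint built from the legacy fields plus the supplied beta scheduler; beta_scheduler is a placeholder for the scheduler object used during training):

    from diffusion_models.utils.schemas import OldCheckpoint

    old = OldCheckpoint.from_file("checkpoints/legacy_run.pt")
    new = old.to_new_checkpoint(beta_scheduler)    # convert to the current Checkpoint format
    new.to_file("checkpoints/legacy_run_converted.pt")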

class Timestep[source]

current: Tensor[source]

previous: Tensor | None = None[source]