Source code for diffusion_models.utils.schemas

import dataclasses
import pathlib
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

import torch
from torch.cuda.amp import GradScaler

from diffusion_models.gaussian_diffusion.beta_schedulers import (
  BaseBetaScheduler,
)


@dataclass
class TrainingConfiguration:
  """A training configuration for simple experiment management."""

  training_name: str
  """The name of the training run."""
  batch_size: int
  """The batch size used for training."""
  learning_rate: float
  """The learning rate of the training."""
  number_of_epochs: int
  """The number of epochs used for training."""
  checkpoint_rate: int = 100
  """The rate at which checkpoints are saved."""
  mixed_precision_training: bool = False  # TODO: This is not complete yet
  """Whether or not to use automatic mixed precision training."""
  gradient_clip: Optional[float] = None  # TODO: This is not complete yet
  """The value to clip gradient norms to, or ``None`` to disable clipping."""


@dataclass
class LogConfiguration:
  """An object to manage logging configuration."""

  log_rate: int = 10
  """The rate at which training metrics are logged."""
  image_rate: int = 50
  """The rate at which images are generated for visualization. This can be
  used to validate model performance."""
  number_of_images: int = 5
  """The number of images to generate."""
  # metrics: Dict[str, float]  # TODO: consider Dict[str, Callable]
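
# A minimal usage sketch (illustrative; these are simply the defaults made
# explicit):
#
#   log_config = LogConfiguration(
#     log_rate=10, image_rate=50, number_of_images=5
#   )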


@dataclass
class BetaSchedulerConfiguration:
  """A simplified beta scheduler configuration."""

  steps: int
  """The number of steps in the beta scheduler."""
  betas: torch.Tensor
  """The beta values."""
  alpha_bars: torch.Tensor
  """The alpha bar values."""


@dataclass
class Checkpoint:
  """A simplified checkpoint framework for easy saving and loading."""

  epoch: int
  """The current epoch."""
  model_state_dict: Dict[str, Any]
  """The model state dict."""
  optimizer_state_dict: Dict[str, Any]
  """The optimizer state dict."""
  scaler: Optional[GradScaler]
  """The GradScaler instance used for mixed precision training, if any."""
  beta_scheduler_config: BetaSchedulerConfiguration
  """The beta scheduler configuration."""
  tensorboard_run_name: Optional[str] = None
  """The name of the tensorboard run."""
  image_channels: int = 3
  """The number of image channels used in the training."""
  loss: Optional[float] = (
    None  # TODO: remove legacy parameter and resave models
  )
  """The final loss value recorded.

  Note:
    This is a legacy parameter and will be removed in a future release.

  """

  @classmethod
  def from_file(cls, file_path: str) -> "Checkpoint":
    """Load and instantiate a checkpoint from a file.

    Args:
      file_path: The path to the checkpoint file.

    Returns:
      A checkpoint instance.
    """
    # weights_only=True restricts unpickling to tensors and other safe types.
    checkpoint = torch.load(f=file_path, weights_only=True)
    checkpoint = cls(**checkpoint)
    # The beta scheduler configuration is saved as a plain dict, so it is
    # rebuilt into its dataclass here.
    beta_scheduler_config = BetaSchedulerConfiguration(
      **checkpoint.beta_scheduler_config
    )
    checkpoint.beta_scheduler_config = beta_scheduler_config
    return checkpoint

  def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
    """Saves a checkpoint to a file."""
    torch.save(dataclasses.asdict(self), file_path)
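
# A minimal round-trip sketch (illustrative; `model`, `optimizer`, `schedule`,
# and the file name are hypothetical):
#
#   checkpoint = Checkpoint(
#     epoch=10,
#     model_state_dict=model.state_dict(),
#     optimizer_state_dict=optimizer.state_dict(),
#     scaler=None,
#     beta_scheduler_config=schedule,
#   )
#   checkpoint.to_file("checkpoint_epoch_010.pt")
#   restored = Checkpoint.from_file("checkpoint_epoch_010.pt")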


@dataclass
class OldCheckpoint:
  """A legacy checkpoint format without a beta scheduler configuration."""

  epoch: int
  model_state_dict: Dict[str, Any]
  optimizer_state_dict: Dict[str, Any]
  scaler: Optional[GradScaler]
  # beta_scheduler_config: BetaSchedulerConfiguration
  tensorboard_run_name: Optional[str] = None
  loss: Optional[float] = (
    None  # TODO: remove legacy parameter and resave models
  )

  @classmethod
  def from_file(cls, file_path: str) -> "OldCheckpoint":
    """Load and instantiate a legacy checkpoint from a file."""
    checkpoint = torch.load(f=file_path)
    return cls(**checkpoint)

  def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
    """Saves a legacy checkpoint to a file."""
    torch.save(dataclasses.asdict(self), file_path)

  def to_new_checkpoint(self, beta_scheduler: BaseBetaScheduler) -> Checkpoint:
    """Convert this legacy checkpoint to the current Checkpoint format.

    The missing beta scheduler configuration is reconstructed from the given
    beta scheduler.
    """
    beta_scheduler_config = BetaSchedulerConfiguration(
      steps=beta_scheduler.steps,
      betas=beta_scheduler.betas,
      alpha_bars=beta_scheduler.alpha_bars,
    )
    return Checkpoint(
      **dataclasses.asdict(self), beta_scheduler_config=beta_scheduler_config
    )
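
# A minimal migration sketch (illustrative; the file names and `scheduler`
# instance are hypothetical):
#
#   old = OldCheckpoint.from_file("legacy_checkpoint.pt")
#   new = old.to_new_checkpoint(beta_scheduler=scheduler)
#   new.to_file("migrated_checkpoint.pt")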