Source code for diffusion_models.utils.schemas

  1import dataclasses
  2import pathlib
  3from dataclasses import dataclass
  4from typing import Any
  5from typing import Dict
  6from typing import Optional
  7from typing import Union
  8
  9import torch
 10from torch.cuda.amp import GradScaler
 11
 12from diffusion_models.gaussian_diffusion.beta_schedulers import (
 13  BaseBetaScheduler,
 14)
 15
 16
 17@dataclass
[docs] 18class TrainingConfiguration: 19 """A training configuration for simple experiment management.""" 20
[docs] 21 training_name: str
22 """The name of the training."""
[docs] 23 batch_size: int
24 """The batch size used for training."""
[docs] 25 learning_rate: float
26 """The learning rate of the training."""
[docs] 27 number_of_epochs: int
28 """The number of epoch used for training."""
[docs] 29 checkpoint_rate: int = 100
30 """The rate at which checkpoints are saved.."""
[docs] 31 mixed_precision_training: bool = False # TODO: This is not complete yet
32 """Whether or not to use automatic mixed precision training."""
[docs] 33 gradient_clip: Optional[float] = None # TODO: This is not complete yet
34 """Whether or not to clip gradients."""
@dataclass
class LogConfiguration:
  """Settings that manage how training progress is logged."""

  log_rate: int = 10
  """Rate at which training metrics are logged."""

  image_rate: int = 50
  """Rate at which images are generated for visualization.

  This can be used to validate model performance.
  """

  number_of_images: int = 5
  """Number of images to generate."""
# TODO(LogConfiguration): add a `metrics: Dict[str, float]` field; consider Dict[str, Callable].


@dataclass
class BetaSchedulerConfiguration:
  """Plain-data, simplified description of a beta scheduler."""

  steps: int
  """Number of steps in the beta scheduler."""

  betas: torch.Tensor
  """Beta values."""

  alpha_bars: torch.Tensor
  """Alpha bar values."""
@dataclass
class Checkpoint:
  """A simplified checkpoint framework for easy saving and loading.

  ``to_file`` writes the fields as a plain dictionary with ``torch.save``;
  ``from_file`` reads such a dictionary back and rebuilds an instance.
  """

  epoch: int
  """The current epoch."""

  model_state_dict: Dict[str, Any]
  """The model state dict."""

  optimizer_state_dict: Dict[str, Any]
  """The optimizer state dict."""

  scaler: Optional[GradScaler]
  """The GradScaler instance."""

  beta_scheduler_config: BetaSchedulerConfiguration
  """The beta scheduler configuration."""

  tensorboard_run_name: Optional[str] = None
  """The name of the tensorboard run."""

  image_channels: int = 3
  """The number of image channels used in the training."""

  loss: Optional[float] = (
    None  # TODO: remove legacy parameter and resave models
  )
  """The final loss value recorded.

  Note:
    This is a legacy parameter and will be removed in a future release.

  """

  @classmethod
  def from_file(
    cls,
    file_path: Union[str, pathlib.Path],
    map_location: Any = None,
  ) -> "Checkpoint":
    """Load and instantiate a checkpoint from a file.

    Args:
      file_path: The path to the checkpoint file.
      map_location: A function, torch.device, string or a dict specifying
        how to remap storage locations. It is forwarded unchanged to
        ``torch.load``.

    Returns:
      A checkpoint instance.

    Raises:
      TypeError: If the stored dictionary does not match the fields of
        ``Checkpoint`` or of ``BetaSchedulerConfiguration``.
    """
    # NOTE(review): ``weights_only=True`` restricts what ``torch.load`` will
    # unpickle. A file saved with a non-None ``scaler`` object may be
    # rejected by this call -- TODO confirm against the torch version in use.
    state = torch.load(
      f=file_path, weights_only=True, map_location=map_location
    )

    # ``to_file`` stores the nested configuration as a plain dict; turn it
    # back into its dataclass before building the checkpoint so the instance
    # never holds a raw dict in that field.
    if isinstance(state, dict) and isinstance(
      state.get("beta_scheduler_config"), dict
    ):
      state["beta_scheduler_config"] = BetaSchedulerConfiguration(
        **state["beta_scheduler_config"]
      )
    return cls(**state)

  def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
    """Save the checkpoint to a file.

    Args:
      file_path: Destination path handed to ``torch.save``.
    """
    # Build the mapping field by field instead of ``dataclasses.asdict(self)``:
    # ``asdict`` recursively deep-copies every value, which would duplicate
    # all tensors of the model and optimizer state dicts just to write them.
    payload = {
      field.name: getattr(self, field.name)
      for field in dataclasses.fields(self)
    }

    # Only the nested configuration must become a dict so that ``from_file``
    # can rebuild it with ``BetaSchedulerConfiguration(**...)``.
    config = payload["beta_scheduler_config"]
    if dataclasses.is_dataclass(config) and not isinstance(config, type):
      payload["beta_scheduler_config"] = dataclasses.asdict(config)

    torch.save(payload, file_path)
@dataclass
class OldCheckpoint:
  """Older checkpoint layout without a beta scheduler configuration.

  It carries the same fields as ``Checkpoint`` except
  ``beta_scheduler_config`` and ``image_channels``. Use
  ``to_new_checkpoint`` to convert an instance into a ``Checkpoint``.
  """

  epoch: int
  model_state_dict: Dict[str, Any]
  optimizer_state_dict: Dict[str, Any]
  scaler: Optional[GradScaler]
  # beta_scheduler_config: BetaSchedulerConfiguration
  tensorboard_run_name: Optional[str] = None
  loss: Optional[float] = (
    None  # TODO: remove legacy parameter and resave models
  )

  @classmethod
  def from_file(
    cls,
    file_path: Union[str, pathlib.Path],
    map_location: Any = None,
  ) -> "OldCheckpoint":
    """Load and instantiate an old-format checkpoint from a file.

    Args:
      file_path: The path to the checkpoint file.
      map_location: Optional storage remapping forwarded unchanged to
        ``torch.load``. The default ``None`` keeps the previous behaviour.

    Returns:
      An ``OldCheckpoint`` built from the stored dictionary.

    Raises:
      TypeError: If the stored dictionary does not match the fields of
        ``OldCheckpoint``.
    """
    # NOTE(review): ``weights_only`` is not passed, so torch's own default
    # applies, and that default differs between torch versions. When full
    # unpickling is in effect, only open files from a trusted source.
    stored_fields = torch.load(f=file_path, map_location=map_location)
    return cls(**stored_fields)

  def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
    """Save the checkpoint to a file.

    Args:
      file_path: Destination path handed to ``torch.save``.
    """
    # A shallow field mapping is enough for serialization;
    # ``dataclasses.asdict`` would deep-copy every tensor of the state dicts.
    payload = {
      field.name: getattr(self, field.name)
      for field in dataclasses.fields(self)
    }
    torch.save(payload, file_path)

  def to_new_checkpoint(self, beta_scheduler: BaseBetaScheduler) -> Checkpoint:
    """Convert this checkpoint into the current ``Checkpoint`` format.

    Args:
      beta_scheduler: Scheduler whose ``steps``, ``betas`` and ``alpha_bars``
        are recorded in the new checkpoint's configuration.

    Returns:
      A ``Checkpoint`` holding copies of this checkpoint's fields plus the
      beta scheduler configuration.
    """
    beta_scheduler_config = BetaSchedulerConfiguration(
      steps=beta_scheduler.steps,
      betas=beta_scheduler.betas,
      alpha_bars=beta_scheduler.alpha_bars,
    )
    # ``asdict`` is kept here on purpose: it copies the field values, so the
    # returned checkpoint does not share state dicts with this instance.
    return Checkpoint(
      **dataclasses.asdict(self), beta_scheduler_config=beta_scheduler_config
    )