import dataclasses
import pathlib
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

import torch
from torch.cuda.amp import GradScaler

from diffusion_models.gaussian_diffusion.beta_schedulers import (
    BaseBetaScheduler,
)

@dataclass
class TrainingConfiguration:
    """A training configuration for simple experiment management."""

    training_name: str  # NOTE: field name inferred from its docstring; the original declaration line was lost.
    """The name of the training."""
    batch_size: int  # NOTE: field name inferred from its docstring; the original declaration line was lost.
    """The batch size used for training."""
    learning_rate: float
    """The learning rate of the training."""
    number_of_epochs: int
    """The number of epochs used for training."""
    checkpoint_rate: int = 100
    """The rate at which checkpoints are saved."""
    mixed_precision_training: bool = False  # TODO: This is not complete yet
    """Whether or not to use automatic mixed precision training."""
    gradient_clip: Optional[float] = None  # TODO: This is not complete yet
    """The maximum gradient norm to clip to, or ``None`` to disable gradient clipping."""

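# A minimal usage sketch of ``TrainingConfiguration``; the hyperparameter
# values below are illustrative placeholders, not recommended defaults:
#
#     config = TrainingConfiguration(
#         training_name="ddpm_baseline",
#         batch_size=64,
#         learning_rate=1e-4,
#         number_of_epochs=100,
#     )
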
@dataclass
class LogConfiguration:
    """An object to manage logging configuration."""

    log_rate: int  # NOTE: field name inferred from its docstring; the original declaration line was lost.
    """The rate at which training metrics are logged."""
    image_rate: int = 50
    """The rate at which images are generated for visualization. This can be
    used to validate model performance."""
    number_of_images: int = 5
    """The number of images to generate."""
    # metrics: Dict[str, float]  # TODO: consider Dict[str, Callable]

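# A minimal usage sketch of ``LogConfiguration``; all values shown are
# illustrative:
#
#     log_config = LogConfiguration(log_rate=10, image_rate=50, number_of_images=5)
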
@dataclass
class BetaSchedulerConfiguration:
    """A simplified beta scheduler configuration."""

    steps: int
    """The number of steps in the beta scheduler."""
    betas: torch.Tensor
    """The beta values."""
    alpha_bars: torch.Tensor
    """The alpha bar values."""

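# A sketch of capturing a scheduler's state for serialization, assuming a
# ``BaseBetaScheduler`` instance named ``scheduler`` (this mirrors
# ``OldCheckpoint.to_new_checkpoint`` below):
#
#     beta_config = BetaSchedulerConfiguration(
#         steps=scheduler.steps,
#         betas=scheduler.betas,
#         alpha_bars=scheduler.alpha_bars,
#     )
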
@dataclass
class Checkpoint:
    """A simplified checkpoint framework for easy saving and loading."""

    epoch: int  # NOTE: field name inferred from its docstring and from OldCheckpoint; the original declaration line was lost.
    """The current epoch."""
    model_state_dict: Dict[str, Any]
    """The model state dict."""
    optimizer_state_dict: Dict[str, Any]
    """The optimizer state dict."""
    beta_scheduler_config: BetaSchedulerConfiguration
    """The beta scheduler configuration."""
    scaler: Optional[GradScaler] = None
    """The GradScaler instance."""
    tensorboard_run_name: Optional[str] = None
    """The name of the tensorboard run."""
    image_channels: int = 3
    """The number of image channels used in the training."""
    loss: Optional[float] = (
        None  # TODO: remove legacy parameter and resave models
    )
    """The final loss value recorded.

    Note:
        This is a legacy parameter and will be removed in a future release.

    """
    @classmethod
    def from_file(
        cls, file_path: str, map_location: Optional[str] = None
    ) -> "Checkpoint":
        """Load and instantiate a checkpoint from a file.

        Args:
            file_path: The path to the checkpoint file.
            map_location: A function, torch.device, string, or dict specifying
                how to remap storage locations.

        Returns:
            A checkpoint instance.
        """
        checkpoint = torch.load(
            f=file_path, weights_only=True, map_location=map_location
        )
        checkpoint = cls(**checkpoint)
        # The nested configuration is stored as a plain dict on disk, so it is
        # re-instantiated as a BetaSchedulerConfiguration here.
        beta_scheduler_config = BetaSchedulerConfiguration(
            **checkpoint.beta_scheduler_config
        )
        checkpoint.beta_scheduler_config = beta_scheduler_config
        return checkpoint

    def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
        """Save the checkpoint to a file.

        Args:
            file_path: The path to write the checkpoint to.
        """
        torch.save(dataclasses.asdict(self), file_path)

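# A round-trip sketch using the methods above; ``model``, ``optimizer``,
# ``beta_config``, and the file path are hypothetical:
#
#     checkpoint = Checkpoint(
#         epoch=10,
#         model_state_dict=model.state_dict(),
#         optimizer_state_dict=optimizer.state_dict(),
#         beta_scheduler_config=beta_config,
#     )
#     checkpoint.to_file("checkpoints/epoch_10.pt")
#     restored = Checkpoint.from_file("checkpoints/epoch_10.pt", map_location="cpu")
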
@dataclass
class OldCheckpoint:
    """A legacy checkpoint format, kept so that older saved models can still be loaded."""

    epoch: int  # NOTE: field name inferred; Checkpoint requires it when to_new_checkpoint unpacks this dataclass.
    model_state_dict: Dict[str, Any]
    optimizer_state_dict: Dict[str, Any]
    scaler: Optional[GradScaler]
    # beta_scheduler_config: BetaSchedulerConfiguration
    tensorboard_run_name: Optional[str] = None
    loss: Optional[float] = (
        None  # TODO: remove legacy parameter and resave models
    )
    @classmethod
    def from_file(cls, file_path: str) -> "OldCheckpoint":
        """Load and instantiate a legacy checkpoint from a file."""
        checkpoint = torch.load(f=file_path)
        return cls(**checkpoint)

    def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
        """Save the legacy checkpoint to a file."""
        torch.save(dataclasses.asdict(self), file_path)

    def to_new_checkpoint(self, beta_scheduler: BaseBetaScheduler) -> Checkpoint:
        """Convert this legacy checkpoint into a ``Checkpoint``.

        Args:
            beta_scheduler: The beta scheduler whose state should be embedded
                in the new checkpoint.

        Returns:
            A new-format checkpoint with the same contents.
        """
        beta_scheduler_config = BetaSchedulerConfiguration(
            steps=beta_scheduler.steps,
            betas=beta_scheduler.betas,
            alpha_bars=beta_scheduler.alpha_bars,
        )
        return Checkpoint(
            **dataclasses.asdict(self), beta_scheduler_config=beta_scheduler_config
        )

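# A migration sketch, assuming ``scheduler`` is the BaseBetaScheduler used for
# the original run; the file paths are hypothetical:
#
#     old = OldCheckpoint.from_file("checkpoints/legacy.pt")
#     old.to_new_checkpoint(scheduler).to_file("checkpoints/migrated.pt")
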
@dataclass
class Timestep:
    """A container for a sampling timestep and, optionally, the previous timestep."""

    current: torch.Tensor
    """The current timestep."""
    previous: Optional[torch.Tensor] = None
    """The previous timestep, if any."""
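
# A minimal usage sketch of ``Timestep``; the step values are illustrative:
#
#     t = Timestep(current=torch.tensor([999]), previous=torch.tensor([998]))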