import dataclasses
import pathlib
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

import torch
from torch.cuda.amp import GradScaler

from diffusion_models.gaussian_diffusion.beta_schedulers import (
    BaseBetaScheduler,
)


@dataclass
class TrainingConfiguration:
    """A training configuration for simple experiment management."""

    training_name: str
    """The name of the training run."""
    batch_size: int
    """The batch size used for training."""
    learning_rate: float
    """The learning rate of the training."""
    number_of_epochs: int
    """The number of epochs used for training."""
    checkpoint_rate: int = 100
    """The rate at which checkpoints are saved."""
    mixed_precision_training: bool = False  # TODO: This is not complete yet
    """Whether or not to use automatic mixed precision training."""
    gradient_clip: Optional[float] = None  # TODO: This is not complete yet
    """The gradient clipping threshold; ``None`` disables clipping."""


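# A hypothetical usage sketch, assuming the field names reconstructed above;
# the values are illustrative, not project defaults.
#
# >>> config = TrainingConfiguration(
# ...     training_name="ddpm_baseline",
# ...     batch_size=64,
# ...     learning_rate=2e-4,
# ...     number_of_epochs=1000,
# ... )
# >>> config.checkpoint_rate
# 100

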
@dataclass
class LogConfiguration:
    """An object to manage logging configuration."""

    log_rate: int
    """The rate at which training metrics are logged."""
    image_rate: int = 50
    """The rate at which images are generated for visualization. This can be used to validate model performance."""
    number_of_images: int = 5
    """The number of images to generate."""
    # metrics: Dict[str, float]  # TODO: consider Dict[str, Callable]


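# A minimal sketch, assuming the ``log_rate`` field reconstructed above;
# the value is illustrative.
#
# >>> log_config = LogConfiguration(log_rate=10)
# >>> (log_config.image_rate, log_config.number_of_images)
# (50, 5)

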
@dataclass
class BetaSchedulerConfiguration:
    """A simplified beta scheduler configuration."""

    steps: int
    """The number of steps in the beta scheduler."""
    betas: torch.Tensor
    """The beta values."""
    alpha_bars: torch.Tensor
    """The alpha bar values."""


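# A sketch of building a configuration from a scheduler, mirroring the
# mapping used by ``OldCheckpoint.to_new_checkpoint`` below; ``scheduler``
# is assumed to be any ``BaseBetaScheduler`` instance.
#
# >>> beta_config = BetaSchedulerConfiguration(
# ...     steps=scheduler.steps,
# ...     betas=scheduler.betas,
# ...     alpha_bars=scheduler.alpha_bars,
# ... )

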
@dataclass
class Checkpoint:
    """A simplified checkpoint framework for easy saving and loading."""

    epoch: int
    """The current epoch."""
    model_state_dict: Dict[str, Any]
    """The model state dict."""
    optimizer_state_dict: Dict[str, Any]
    """The optimizer state dict."""
    scaler: Optional[GradScaler]
    """The GradScaler instance."""
    beta_scheduler_config: BetaSchedulerConfiguration
    """The beta scheduler configuration."""
    tensorboard_run_name: Optional[str] = None
    """The name of the tensorboard run."""
    image_channels: int = 3
    """The number of image channels used in the training."""
    loss: Optional[float] = None  # TODO: remove legacy parameter and resave models
    """The final loss value recorded.

    Note:
        This is a legacy parameter and will be removed in a future release.

    """

    @classmethod
    def from_file(cls, file_path: str) -> "Checkpoint":
        """Load and instantiate a checkpoint from a file.

        Args:
            file_path: The path to the checkpoint file.

        Returns:
            A checkpoint instance.
        """
        checkpoint = torch.load(f=file_path)
        checkpoint = cls(**checkpoint)
        beta_scheduler_config = BetaSchedulerConfiguration(
            **checkpoint.beta_scheduler_config
        )
        checkpoint.beta_scheduler_config = beta_scheduler_config
        return checkpoint

    def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
        """Save a checkpoint to a file."""
        torch.save(dataclasses.asdict(self), file_path)

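# A hypothetical save/load roundtrip; ``checkpoint`` is assumed to be a
# fully populated ``Checkpoint`` and the path is illustrative.
#
# >>> checkpoint.to_file("checkpoints/epoch_0100.pt")
# >>> restored = Checkpoint.from_file("checkpoints/epoch_0100.pt")
# >>> restored.epoch == checkpoint.epoch
# True
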

@dataclass
class OldCheckpoint:
    """A legacy checkpoint layout predating :class:`BetaSchedulerConfiguration`."""

    epoch: int
    model_state_dict: Dict[str, Any]
    optimizer_state_dict: Dict[str, Any]
    scaler: Optional[GradScaler]
    # beta_scheduler_config: BetaSchedulerConfiguration
    tensorboard_run_name: Optional[str] = None
    loss: Optional[float] = None  # TODO: remove legacy parameter and resave models

    @classmethod
    def from_file(cls, file_path: str) -> "OldCheckpoint":
        """Load and instantiate an old checkpoint from a file."""
        checkpoint = torch.load(f=file_path)
        return cls(**checkpoint)

    def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
        """Save an old checkpoint to a file."""
        torch.save(dataclasses.asdict(self), file_path)

    def to_new_checkpoint(self, beta_scheduler: BaseBetaScheduler) -> Checkpoint:
        """Convert this legacy checkpoint into a :class:`Checkpoint` by
        attaching the configuration of the given beta scheduler."""
        beta_scheduler_config = BetaSchedulerConfiguration(
            steps=beta_scheduler.steps,
            betas=beta_scheduler.betas,
            alpha_bars=beta_scheduler.alpha_bars,
        )
        return Checkpoint(
            **dataclasses.asdict(self), beta_scheduler_config=beta_scheduler_config
        )
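

# A hypothetical migration path: load a legacy file, rebuild the checkpoint
# with the scheduler the model was trained with, and resave it. ``scheduler``
# is assumed to be the original ``BaseBetaScheduler``; paths are illustrative.
#
# >>> old = OldCheckpoint.from_file("checkpoints/legacy.pt")
# >>> new = old.to_new_checkpoint(beta_scheduler=scheduler)
# >>> new.to_file("checkpoints/migrated.pt")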