import dataclasses
import pathlib
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

import torch
from torch.cuda.amp import GradScaler

from diffusion_models.gaussian_diffusion.beta_schedulers import (
    BaseBetaScheduler,
)

@dataclass
class TrainingConfiguration:
    learning_rate: float
    number_of_epochs: int

    image_rate: int = 100
    checkpoint_rate: int = 100

    mixed_precision_training: bool = False  # TODO: This is not complete yet
    gradient_clip: Optional[float] = None  # TODO: This is not complete yet
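
# A minimal sketch of how the two TODO flags above could drive a training
# step; `model`, `optimizer`, `loss_fn`, and `batch` are assumptions standing
# in for real training objects, not part of this module:
#
#     scaler = GradScaler(enabled=config.mixed_precision_training)
#     with torch.autocast(
#         device_type="cuda", enabled=config.mixed_precision_training
#     ):
#         loss = loss_fn(model(batch))
#     scaler.scale(loss).backward()
#     if config.gradient_clip is not None:
#         scaler.unscale_(optimizer)
#         torch.nn.utils.clip_grad_norm_(
#             model.parameters(), config.gradient_clip
#         )
#     scaler.step(optimizer)
#     scaler.update()
#     optimizer.zero_grad()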


@dataclass
class LogConfiguration:
    image_rate: int = 50
    number_of_images: int = 5
    # metrics: Dict[str, float]  # TODO: consider Dict[str, Callable]
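
# A sketch of how a training loop might consume this configuration; `writer`
# (a tensorboard SummaryWriter), `step`, `model`, and `sample_images` are
# assumptions for illustration:
#
#     if step % log_config.image_rate == 0:
#         images = sample_images(model, count=log_config.number_of_images)
#         writer.add_images("samples", images, global_step=step)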


@dataclass
class BetaSchedulerConfiguration:
    steps: int
    betas: torch.Tensor
    alpha_bars: torch.Tensor
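
# For reference, in DDPM-style schedulers `alpha_bars` is typically the
# cumulative product of (1 - betas). A sketch under that assumption (not
# necessarily how BaseBetaScheduler computes it):
#
#     steps = 1000
#     betas = torch.linspace(1e-4, 0.02, steps)
#     alpha_bars = torch.cumprod(1.0 - betas, dim=0)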


@dataclass
class Checkpoint:
    model_state_dict: Dict[str, Any]
    optimizer_state_dict: Dict[str, Any]
    scaler: Optional[GradScaler]
    beta_scheduler_config: BetaSchedulerConfiguration
    tensorboard_run_name: Optional[str] = None
    image_channels: int = 3
    loss: Optional[float] = (
        None  # TODO: remove legacy parameter and resave models
    )

    @classmethod
    def from_file(cls, file_path: str) -> "Checkpoint":
        checkpoint = torch.load(f=file_path)
        checkpoint = cls(**checkpoint)
        # `to_file` saves via `dataclasses.asdict`, which flattens the nested
        # dataclass into a plain dict, so rebuild it here on load.
        beta_scheduler_config = BetaSchedulerConfiguration(
            **checkpoint.beta_scheduler_config
        )
        checkpoint.beta_scheduler_config = beta_scheduler_config
        return checkpoint

    def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
        torch.save(dataclasses.asdict(self), file_path)
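
# Example round trip (a sketch; `model`, `optimizer`, and `scheduler` are
# assumptions standing in for real training objects):
#
#     checkpoint = Checkpoint(
#         model_state_dict=model.state_dict(),
#         optimizer_state_dict=optimizer.state_dict(),
#         scaler=None,
#         beta_scheduler_config=BetaSchedulerConfiguration(
#             steps=scheduler.steps,
#             betas=scheduler.betas,
#             alpha_bars=scheduler.alpha_bars,
#         ),
#     )
#     checkpoint.to_file("checkpoint.pt")
#     restored = Checkpoint.from_file("checkpoint.pt")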


@dataclass
class OldCheckpoint:
    model_state_dict: Dict[str, Any]
    optimizer_state_dict: Dict[str, Any]
    scaler: Optional[GradScaler]
    # beta_scheduler_config: BetaSchedulerConfiguration
    tensorboard_run_name: Optional[str] = None
    loss: Optional[float] = (
        None  # TODO: remove legacy parameter and resave models
    )

    @classmethod
    def from_file(cls, file_path: str) -> "OldCheckpoint":
        checkpoint = torch.load(f=file_path)
        return cls(**checkpoint)

    def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
        torch.save(dataclasses.asdict(self), file_path)

    def to_new_checkpoint(self, beta_scheduler: BaseBetaScheduler) -> Checkpoint:
        # Pull the schedule tensors off the scheduler so the legacy checkpoint
        # gains the `beta_scheduler_config` field that `Checkpoint` requires.
        beta_scheduler_config = BetaSchedulerConfiguration(
            steps=beta_scheduler.steps,
            betas=beta_scheduler.betas,
            alpha_bars=beta_scheduler.alpha_bars,
        )
        return Checkpoint(
            **dataclasses.asdict(self), beta_scheduler_config=beta_scheduler_config
        )
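
# Migration sketch for files saved before `beta_scheduler_config` existed
# (the paths and the `scheduler` instance are assumptions):
#
#     old = OldCheckpoint.from_file("legacy_checkpoint.pt")
#     new = old.to_new_checkpoint(beta_scheduler=scheduler)
#     new.to_file("migrated_checkpoint.pt")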