Source code for diffusion_models.utils.schemas

import dataclasses
import pathlib
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

import torch
from torch.cuda.amp import GradScaler

from diffusion_models.gaussian_diffusion.beta_schedulers import (
  BaseBetaScheduler,
)


@dataclass
class TrainingConfiguration:
  """Hyperparameters and logging/checkpoint cadence for one training run."""

  training_name: str
  batch_size: int
  learning_rate: float
  number_of_epochs: int

  log_rate: int = 10
  image_rate: int = 100
  checkpoint_rate: int = 100

  mixed_precision_training: bool = False  # TODO: This is not complete yet
  gradient_clip: Optional[float] = None  # TODO: This is not complete yet
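
# Usage sketch (illustrative only; the values below are assumptions, not
# defaults taken from this project):
#
#   config = TrainingConfiguration(
#     training_name="ddpm_mnist",
#     batch_size=128,
#     learning_rate=2e-4,
#     number_of_epochs=50,
#   )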


@dataclass
class LogConfiguration:
  """How often to log scalars and sample images, and how many images per log."""

  log_rate: int = 10
  image_rate: int = 50
  number_of_images: int = 5
  # metrics: Dict[str, float]  # TODO: consider Dict[str, Callable]
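
# Sketch (illustrative): log scalars every 10 steps and 5 sample images
# every 50 steps, matching the field defaults above.
#
#   log_config = LogConfiguration(log_rate=10, image_rate=50, number_of_images=5)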


@dataclass
class BetaSchedulerConfiguration:
  """Serialized parameters of a beta scheduler: steps, betas, alpha_bars."""

  steps: int
  betas: torch.Tensor
  alpha_bars: torch.Tensor


@dataclass
class Checkpoint:
  """A self-contained training checkpoint, including the beta schedule."""

  epoch: int
  model_state_dict: Dict[str, Any]
  optimizer_state_dict: Dict[str, Any]
  scaler: Optional[GradScaler]
  beta_scheduler_config: BetaSchedulerConfiguration
  tensorboard_run_name: Optional[str] = None
  image_channels: int = 3
  loss: Optional[float] = (
    None  # TODO: remove legacy parameter and resave models
  )

  @classmethod
  def from_file(cls, file_path: str) -> "Checkpoint":
    checkpoint = torch.load(f=file_path)
    checkpoint = cls(**checkpoint)
    # to_file saves via dataclasses.asdict, which flattens the nested
    # dataclass into a plain dict, so rebuild it here.
    beta_scheduler_config = BetaSchedulerConfiguration(
      **checkpoint.beta_scheduler_config
    )
    checkpoint.beta_scheduler_config = beta_scheduler_config
    return checkpoint

  def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
    torch.save(dataclasses.asdict(self), file_path)
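
# Round-trip sketch (illustrative only; `model`, `optimizer`, and `scheduler`
# are assumed to exist elsewhere and are not defined in this module):
#
#   ckpt = Checkpoint(
#     epoch=10,
#     model_state_dict=model.state_dict(),
#     optimizer_state_dict=optimizer.state_dict(),
#     scaler=None,
#     beta_scheduler_config=BetaSchedulerConfiguration(
#       steps=scheduler.steps,
#       betas=scheduler.betas,
#       alpha_bars=scheduler.alpha_bars,
#     ),
#   )
#   ckpt.to_file("checkpoint.pt")
#   restored = Checkpoint.from_file("checkpoint.pt")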


@dataclass
class OldCheckpoint:
  """Legacy checkpoint layout that predates beta_scheduler_config."""

  epoch: int
  model_state_dict: Dict[str, Any]
  optimizer_state_dict: Dict[str, Any]
  scaler: Optional[GradScaler]
  # beta_scheduler_config: BetaSchedulerConfiguration
  tensorboard_run_name: Optional[str] = None
  loss: Optional[float] = (
    None  # TODO: remove legacy parameter and resave models
  )

  @classmethod
  def from_file(cls, file_path: str) -> "OldCheckpoint":
    checkpoint = torch.load(f=file_path)
    return cls(**checkpoint)

  def to_file(self, file_path: Union[str, pathlib.Path]) -> None:
    torch.save(dataclasses.asdict(self), file_path)

  def to_new_checkpoint(self, beta_scheduler: BaseBetaScheduler) -> Checkpoint:
    # Attach the schedule parameters that legacy checkpoints did not store.
    beta_scheduler_config = BetaSchedulerConfiguration(
      steps=beta_scheduler.steps,
      betas=beta_scheduler.betas,
      alpha_bars=beta_scheduler.alpha_bars,
    )
    return Checkpoint(
      **dataclasses.asdict(self), beta_scheduler_config=beta_scheduler_config
    )
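
# Migration sketch (illustrative only; `LinearBetaScheduler` is a stand-in
# name for whichever BaseBetaScheduler subclass the legacy run actually used):
#
#   old = OldCheckpoint.from_file("legacy_checkpoint.pt")
#   new = old.to_new_checkpoint(beta_scheduler=LinearBetaScheduler())
#   new.to_file("migrated_checkpoint.pt")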