Source code for diffusion_models.utils.tensorboard

from datetime import datetime
from typing import Dict
from typing import Optional

import torch
from torch.utils.tensorboard import SummaryWriter

class TensorboardManager:
    """Thin wrapper around a TensorBoard ``SummaryWriter`` for a single run.

    Event files are written to ``f"{log_root}/{log_name}"``. The manager can
    be used as a context manager (``with TensorboardManager(...) as tb:``) so
    the underlying writer is closed when the block exits.

    Args:
        log_name: Name of the run sub-directory. When ``None``, a timestamp
            of the form ``YYYY-MM-DD_HH-MM-SS`` is used, so separate unnamed
            runs do not write into the same directory.
        log_root: Parent directory for all runs. Defaults to ``"../runs"``.
            This is a relative path, so it is resolved against the current
            working directory.
    """

    def __init__(self, log_name: Optional[str] = None, log_root: str = "../runs"):
        if log_name is None:
            # A fixed placeholder name would make every unnamed run share one
            # directory and interleave their event files; use a timestamp.
            log_name = self._default_log_name()
        self.log_directory = f"{log_root}/{log_name}"
        self.summary_writer = SummaryWriter(log_dir=self.log_directory)

    @staticmethod
    def _default_log_name() -> str:
        """Return a run name built from the current local time, e.g. ``2024-01-31_13-05-59``."""
        return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    def log_metrics(self, metrics: Dict[str, float], global_step: int):
        """Write every ``name -> value`` pair in *metrics* as a scalar.

        Args:
            metrics: Mapping from scalar tag to value.
            global_step: Step recorded for all the scalars in this call.
        """
        for metric_name, value in metrics.items():
            self.summary_writer.add_scalar(metric_name, value, global_step=global_step)

    def log_images(
        self,
        tag: str,
        images: torch.Tensor,
        timestep: int,
        dataformats: str = "NCHW",
    ):
        """Write a batch of images under *tag*.

        Args:
            tag: TensorBoard tag for the image grid.
            images: Image batch whose axis layout is described by *dataformats*.
            timestep: Step recorded for the images (passed as ``global_step``).
            dataformats: Axis layout of *images*; forwarded to
                ``SummaryWriter.add_images``. The default ``"NCHW"`` matches
                that method's own default, so existing callers are unaffected.
        """
        self.summary_writer.add_images(
            tag=tag,
            img_tensor=images,
            global_step=timestep,
            dataformats=dataformats,
        )

    def close(self) -> None:
        """Flush pending events and close the underlying ``SummaryWriter``."""
        self.summary_writer.close()

    def __enter__(self) -> "TensorboardManager":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always close the writer; returning None lets any exception propagate.
        self.close()