Source code for diffusion_models.utils.tensorboard

 1from typing import Dict
 2from typing import Optional
 3
 4import torch
 5from torch.utils.tensorboard import SummaryWriter
 6
 7
class TensorboardManager:
    """Thin wrapper around ``SummaryWriter`` for simplified tensorboard logging.

    Can be used as a context manager; the underlying writer is closed on exit.
    """

    def __init__(self, log_name: Optional[str] = None, log_root: str = "../runs"):
        """A tensorboard manager for simplified tensorboard logging.

        Args:
            log_name: The name of the tensorboard log run. If ``None`` or empty,
                a timestamp-based name is generated so that separate unnamed
                runs are written to separate directories instead of sharing one.
            log_root: Directory under which the run directory is created. As a
                relative path it is resolved against the current working
                directory, not this module's location. Defaults to ``"../runs"``.
        """
        if not log_name:
            # An empty name would resolve to ``log_root`` itself, so treat it
            # like ``None`` and generate a unique-per-second run name.
            log_name = self._default_log_name()
        self.log_directory = f"{log_root}/{log_name}"
        """The directory where tensorboard logs are saved."""
        self.summary_writer = SummaryWriter(log_dir=self.log_directory)
        """The tensorboard summary writer."""

    @staticmethod
    def _default_log_name() -> str:
        """Return a filesystem-friendly run name derived from the current local time."""
        from datetime import datetime

        return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    def log_metrics(self, metrics: Dict[str, float], global_step: int):
        """Log metrics to tensorboard.

        Each entry is written as a separate scalar whose tag is the metric name.

        Args:
            metrics: A dictionary mapping metric names to values.
            global_step: The step at which the metrics are recorded.
        """
        for metric_name, value in metrics.items():
            self.summary_writer.add_scalar(
                metric_name, value, global_step=global_step
            )

    def log_images(
        self,
        tag: str,
        images: torch.Tensor,
        timestep: int,
        dataformats: str = "NCHW",
    ):
        """Log a batch of images to tensorboard.

        Args:
            tag: The name to give the images in tensorboard.
            images: A tensor holding a batch of images, laid out according to
                ``dataformats``.
            timestep: The timestep at which the images are produced.
            dataformats: Axis layout of ``images`` as understood by
                ``SummaryWriter.add_images`` (e.g. ``"NCHW"`` or ``"NHWC"``).
                Defaults to ``"NCHW"``, which is ``add_images``' own default.
        """
        self.summary_writer.add_images(
            tag=tag,
            img_tensor=images,
            global_step=timestep,
            dataformats=dataformats,
        )

    def close(self) -> None:
        """Close the underlying ``SummaryWriter``, flushing buffered events to disk."""
        self.summary_writer.close()

    def __enter__(self) -> "TensorboardManager":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the writer; do not suppress any exception.
        self.close()