1from typing import Dict
2from typing import Optional
3
4import torch
5from torch.utils.tensorboard import SummaryWriter
6
7
class TensorboardManager:
    """Thin wrapper owning one TensorBoard ``SummaryWriter`` for a single run.

    Event files are written under ``log_directory``. Call :meth:`close` when
    finished (or use the instance as a context manager) so buffered events
    are flushed to disk.

    Args:
        log_name: Sub-directory name for this run under ``log_root``. When
            ``None``, a unique timestamped name is generated so unnamed runs
            do not share (and mix their events in) a single directory.
        log_root: Parent directory of all runs. Defaults to ``"../runs"``,
            which is resolved relative to the current working directory.
    """

    def __init__(
        self, log_name: Optional[str] = None, *, log_root: str = "../runs"
    ) -> None:
        if log_name is None:
            log_name = self._default_log_name()
        # Same "<root>/<name>" format as before, so existing paths are unchanged.
        self.log_directory = f"{log_root}/{log_name}"
        self.summary_writer = SummaryWriter(log_dir=self.log_directory)

    @staticmethod
    def _default_log_name() -> str:
        """Return a sortable, collision-resistant run name.

        Example shape: ``20240131-142501_3f9a1c2e``.
        """
        # Function-scope imports keep this class self-contained.
        import uuid
        from datetime import datetime

        # Timestamp for readability and sorting; the random suffix keeps two
        # runs started within the same second in distinct directories.
        return f"{datetime.now():%Y%m%d-%H%M%S}_{uuid.uuid4().hex[:8]}"

    def log_metrics(self, metrics: Dict[str, float], global_step: int) -> None:
        """Write every ``name -> value`` pair in *metrics* as a scalar.

        Args:
            metrics: Mapping of scalar tag to value.
            global_step: Step index recorded with every scalar.
        """
        for metric_name, value in metrics.items():
            self.summary_writer.add_scalar(
                metric_name, value, global_step=global_step
            )

    def log_images(self, tag: str, images: torch.Tensor, timestep: int) -> None:
        """Write a batch of images under *tag* at step *timestep*.

        *images* is passed unchanged to ``SummaryWriter.add_images``, whose
        default ``dataformats`` is ``"NCHW"``.
        """
        self.summary_writer.add_images(
            tag=tag, img_tensor=images, global_step=timestep
        )

    def flush(self) -> None:
        """Force any pending events to be written to disk."""
        self.summary_writer.flush()

    def close(self) -> None:
        """Flush pending events and release the underlying writer."""
        self.summary_writer.close()

    def __enter__(self) -> "TensorboardManager":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always close; returning None lets any exception propagate.
        self.close()