1from typing import Dict
2from typing import Optional
3
4import torch
5from torch.utils.tensorboard import SummaryWriter
6
7
class TensorboardManager:
    """A tensorboard manager for simplified tensorboard logging.

    Wraps a ``torch.utils.tensorboard.SummaryWriter`` that writes to
    ``<log_root>/<log_name>``. Call :meth:`close` when finished, or use the
    manager in a ``with`` statement. This ensures buffered events are
    written to disk.
    """

    def __init__(
        self,
        log_name: Optional[str] = None,
        log_root: str = "../runs",
    ):
        """Create a summary writer for one tensorboard run.

        Args:
            log_name: The name of the tensorboard log run. If ``None``, a name
                of the form ``<MonDD_HH-MM-SS>_<hostname>`` is generated.
                This is the same scheme ``SummaryWriter`` uses by default, so
                unnamed runs no longer all share a single directory.
            log_root: Directory under which run directories are placed.
                Relative paths are resolved against the current working
                directory. Defaults to ``"../runs"``.
        """
        if log_name is None:
            # Local imports: only needed to build the auto-generated run name.
            import socket
            from datetime import datetime

            log_name = f"{datetime.now():%b%d_%H-%M-%S}_{socket.gethostname()}"
        self.log_directory = f"{log_root}/{log_name}"
        """The directory where tensorboard logs are saved."""
        self.summary_writer = SummaryWriter(log_dir=self.log_directory)
        """The tensorboard summary writer."""

    def log_metrics(self, metrics: Dict[str, float], global_step: int):
        """Log scalar metrics to tensorboard.

        Each entry is written as its own scalar series, tagged by its key.

        Args:
            metrics: A dictionary mapping metric names to values.
            global_step: The step at which the metrics are recorded.
        """
        for metric_name, value in metrics.items():
            self.summary_writer.add_scalar(
                metric_name, value, global_step=global_step
            )

    def log_images(
        self,
        tag: str,
        images: torch.Tensor,
        timestep: int,
        dataformats: str = "NCHW",
    ):
        """Log a batch of images to tensorboard.

        Args:
            tag: The name to give the images in tensorboard.
            images: A tensor holding the images to log. Its layout is
                described by ``dataformats``.
            timestep: The timestep at which the images are produced. It is
                passed to the writer as ``global_step``.
            dataformats: Layout string forwarded to ``add_images``. Defaults
                to ``"NCHW"``, the writer's own default.
        """
        self.summary_writer.add_images(
            tag=tag,
            img_tensor=images,
            global_step=timestep,
            dataformats=dataformats,
        )

    def flush(self):
        """Write any pending events to disk."""
        self.summary_writer.flush()

    def close(self):
        """Flush pending events and close the underlying summary writer."""
        self.summary_writer.close()

    def __enter__(self) -> "TensorboardManager":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the writer. Returning None (falsy) lets any
        # exception from the ``with`` body propagate unchanged.
        self.close()