# Source code for betty.logging.logger_tensorboard
import atexit
import os
import socket
from datetime import datetime
import torch
# Optional dependency: tensorboard support is enabled only when
# ``torch.utils.tensorboard`` imports cleanly.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    HAS_TENSORBOARD = False
else:
    HAS_TENSORBOARD = True
from betty.logging.logger_base import LoggerBase
class TensorBoardLogger(LoggerBase):
    """Logger that records scalar stats with PyTorch's tensorboard ``SummaryWriter``.

    Event files go to ``betty_tensorboard/<time>_<hostname><comment>``, where
    ``<time>`` is formatted as ``%b%d_%H-%M-%S``. If ``torch.utils.tensorboard``
    could not be imported, no writer is created and the logger is a no-op.
    """

    def __init__(self, comment=""):
        """
        :param comment: Suffix appended (without separator) to the log directory name
        :type comment: str, optional
        """
        # None when tensorboard is unavailable; ``log``/``close`` then do nothing.
        self.writer = None
        if not HAS_TENSORBOARD:
            # ``SummaryWriter`` is not bound when the import failed, so bail out
            # here rather than raising NameError.
            return
        current_time = datetime.now().strftime("%b%d_%H-%M-%S")
        log_dir = os.path.join(
            "betty_tensorboard", current_time + "_" + socket.gethostname() + comment
        )
        self.writer = SummaryWriter(log_dir=log_dir)
        # Register the exit hook only after the writer exists, so a failed
        # construction does not leave behind a hook that raises at exit.
        atexit.register(self.close)

    def close(self):
        """
        Close PyTorch's tensorboard ``SummaryWriter``.

        Does nothing when no writer was created. This method is also
        registered with ``atexit`` in ``__init__``, so it may run again at
        interpreter exit after an explicit call.
        """
        writer = getattr(self, "writer", None)
        if writer is not None:
            writer.close()

    def log(self, stats, tag=None, step=None):
        """
        Log metrics/stats to PyTorch tensorboard.

        Each entry is written as a scalar named ``<tag>/<key>`` (or ``<key>``
        when ``tag`` is None). Tuple/list values are expanded into one scalar
        per element, named ``<key>_0``, ``<key>_1``, ... Tensor values are
        converted to Python numbers with ``.item()`` before writing.

        :param stats: Dictionary of values and their names to be recorded
        :type stats: dict
        :param tag: Data identifier
        :type tag: str, optional
        :param step: step value associated with ``stats`` to record
        :type step: int, optional
        """
        if not HAS_TENSORBOARD or stats is None:
            return
        # The tag prefix is the same for every key, so compute it once.
        prefix = "" if tag is None else tag + "/"
        for key, values in stats.items():
            key_extended = prefix + key
            if isinstance(values, (tuple, list)):
                for value_idx, value in enumerate(values):
                    self._add_scalar(key_extended + "_" + str(value_idx), value, step)
            else:
                self._add_scalar(key_extended, values, step)

    def _add_scalar(self, key, value, step):
        """Write one scalar, unwrapping a tensor with ``.item()`` first."""
        if torch.is_tensor(value):
            value = value.item()
        self.writer.add_scalar(key, value, step)