Source code for nntoolbox.callbacks.logger
from torch.utils.tensorboard import SummaryWriter
from .callbacks import Callback
from typing import Sequence, Dict, Any
__all__ = ['Tensorboard', 'LossLogger', 'MultipleMetricLogger']
class Tensorboard(Callback):
    """Callback that forwards training logs to a TensorBoard ``SummaryWriter``.

    Scalars are written at the end of every ``every_iter``-th iteration and
    scalars/images at the end of every ``every_epoch``-th epoch, depending on
    which keys are present in the ``logs`` dictionary handed to the hooks.
    """

    def __init__(self, every_iter: int = 1, every_epoch: int = 1):
        """Create the writer and remember the logging periods.

        Args:
            every_iter: write batch-level scalars when ``iter_cnt`` is a
                multiple of this value.
            every_epoch: write epoch-level data when ``epoch`` is a multiple
                of this value.
        """
        # The writer is built with default arguments, so the log location is
        # whatever SummaryWriter chooses on its own.
        self._writer = SummaryWriter()
        self._every_iter = every_iter
        self._every_epoch = every_epoch

    def on_batch_end(self, logs):
        """Write the per-iteration scalars found in ``logs``.

        Reads ``logs["iter_cnt"]`` as the global step. When present,
        ``logs["loss"]`` is logged as "Training loss" and
        ``logs["allocated_memory"]`` as "Allocated memory".
        """
        step = logs["iter_cnt"]
        if step % self._every_iter != 0:
            return

        if "loss" in logs:
            # The loss must expose ``.item()`` -- presumably a 0-dim tensor;
            # confirm against the caller.
            self._writer.add_scalar(
                tag="Training loss",
                scalar_value=logs["loss"].item(),
                global_step=step,
            )

        if "allocated_memory" in logs:
            self._writer.add_scalar(
                tag="Allocated memory",
                scalar_value=logs["allocated_memory"],
                global_step=step,
            )

    def on_epoch_end(self, logs: Dict[str, Any]):
        """Write the per-epoch metrics and images found in ``logs``.

        Reads ``logs["epoch"]`` as the global step. Every entry of
        ``logs["epoch_metrics"]`` (if present) is logged as a scalar tagged
        "Validation <name>". If both ``logs["draw"]`` and ``logs["tag"]`` are
        present, ``logs["draw"][i]`` is logged as an image under
        ``logs["tag"][i]``.

        Returns:
            Always ``False``. NOTE(review): presumably the "stop training"
            flag of the callback protocol -- confirm against ``Callback``.
        """
        epoch = logs["epoch"]
        if epoch % self._every_epoch != 0:
            return False

        if "epoch_metrics" in logs:
            epoch_metrics = logs["epoch_metrics"]
            for name in epoch_metrics:
                self._writer.add_scalar(
                    tag="Validation " + name,
                    scalar_value=epoch_metrics[name],
                    global_step=epoch,
                )

        if "draw" in logs and "tag" in logs:
            images = logs["draw"]
            # The tags drive the loop; images are looked up by position.
            for position, image_tag in enumerate(logs["tag"]):
                self._writer.add_image(
                    tag=image_tag,
                    img_tensor=images[position],
                    global_step=epoch,
                )

        return False
class LossLogger(Callback):
    """Callback that prints the training loss to stdout every few iterations."""

    def __init__(self, print_every=1000):
        """Remember the printing period.

        Args:
            print_every: print when ``iter_cnt`` is a multiple of this value.
        """
        self._print_every = print_every

    def on_batch_end(self, logs):
        """Print ``logs["loss"]`` together with the iteration counter.

        Reads ``logs["iter_cnt"]`` on every call and ``logs["loss"]`` only on
        iterations that are actually printed.
        """
        iteration = logs["iter_cnt"]
        if iteration % self._print_every != 0:
            return

        loss = logs["loss"]
        # ``!s`` forces str() on both values; a bare ``{}`` would go through
        # ``__format__``, which an object may implement differently.
        print(f"Iteration {iteration!s}: {loss!s}")
class MultipleMetricLogger(Callback):
    """Callback that prints a chosen set of metrics to stdout.

    Iteration-level metrics are printed every ``print_every`` iterations;
    epoch-level metrics are printed at the end of every epoch. All metrics
    are looked up by name directly in the ``logs`` dictionary.
    """

    def __init__(self, iter_metrics: Sequence[str]=(), epoch_metrics: Sequence[str]=(), print_every=1000):
        """Remember which metrics to print and how often.

        Args:
            iter_metrics: names of the ``logs`` entries to print at the end
                of a batch. Defaults to none.
            epoch_metrics: names of the ``logs`` entries to print at the end
                of an epoch. Defaults to none.
            print_every: print iteration metrics when ``iter_cnt`` is a
                multiple of this value.
        """
        # The defaults are immutable tuples rather than ``[]``: a list default
        # is created once and would be shared by every instance that relies
        # on it.
        self._print_every = print_every
        self._iter_metrics = iter_metrics
        self._epoch_metrics = epoch_metrics

    def on_batch_end(self, logs):
        """Print the iteration metrics every ``print_every`` iterations.

        Raises:
            AssertionError: if a requested metric is missing from ``logs``.
        """
        if logs["iter_cnt"] % self._print_every == 0:
            print("Iteration " + str(logs["iter_cnt"]) + " with:")
            self._print_metrics(logs, self._iter_metrics)

    def on_epoch_end(self, logs) -> bool:
        """Print the epoch metrics.

        NOTE(review): metrics are looked up at the top level of ``logs``,
        whereas ``Tensorboard`` in this module reads epoch metrics from
        ``logs["epoch_metrics"]`` -- confirm which layout the caller provides.

        Returns:
            Always ``False``.

        Raises:
            AssertionError: if a requested metric is missing from ``logs``.
        """
        print("Epoch " + str(logs["epoch"]) + " with:")
        self._print_metrics(logs, self._epoch_metrics)
        return False

    @staticmethod
    def _print_metrics(logs, metrics) -> None:
        """Print one ``name: value`` line for each name in ``metrics``."""
        for metric in metrics:
            # Kept as an assert so the exception type is unchanged; the
            # message names the metric that was not found.
            assert metric in logs, "metric '" + str(metric) + "' not found in logs"
            print(metric + ": " + str(logs[metric]))