Source code for nntoolbox.callbacks.nan
from .callbacks import Callback
from typing import Dict, Any
from torch import Tensor
from ..utils import is_nan
from warnings import warn
__all__ = ['NaNWarner', 'SkipNaN']
[docs]class NaNWarner(Callback):
[docs] def on_batch_end(self, logs: Dict[str, Any]):
for key in logs:
if isinstance(logs[key], Tensor) and is_nan(logs[key]):
warn(key + " becomes NaN at iteration " + str(logs["iter_cnt"]))
[docs]class SkipNaN(Callback):
"""
Skip when loss or output is nan (UNTESTED)
"""
[docs] def after_outputs(self, outputs: Dict[str, Tensor], train: bool) -> bool:
for key in outputs:
if is_nan(outputs[key]):
print("One of the loss is nan. Skip")
return False
[docs] def after_losses(self, losses: Dict[str, Tensor], train: bool) -> bool:
for key in losses:
if is_nan(losses[key]):
print("One of the losses is nan. Skip")
self.learner._optimizer.zero_grad()
return False
class TerminateOnNaN(Callback):
"""
Terminate training when encounter NaN (INCOMPLETE)
"""
def on_batch_end(self, logs: Dict[str, Any]):
for key in logs:
if isinstance(logs[key], Tensor) and is_nan(logs[key]):
raise ValueError(key + " becomes NaN at iteration " + str(logs["iter_cnt"]))