Source code for nntoolbox.callbacks.debug
"""
Implement a debug callback. Adapt from fastai course2 v3 notebook 11 a
"""
from .callbacks import Callback
from typing import Callable, Dict, Any
from torch import Tensor
CALLBACK_STEPS = [
'on_train_begin', 'on_epoch_begin', 'on_batch_begin',
'after_outputs', 'after_losses', 'on_backward_begin',
'after_backward', 'after_step', 'on_batch_end',
'on_epoch_end', 'on_train_end'
]
# UNTESTED
[docs]class DebugCallback(Callback):
def __init__(self, step_to_debug: str, func):
assert step_to_debug in CALLBACK_STEPS
self.step_to_debug = step_to_debug
self.func = func
[docs] def on_train_begin(self):
if self.step_to_debug == 'on_train_begin':
self.func(self.learner)
[docs] def on_epoch_begin(self):
if self.step_to_debug == 'on_epoch_begin':
self.func(self.learner)
[docs] def on_batch_begin(self, data: Dict[str, Tensor], train) -> Dict[str, Tensor]:
if self.step_to_debug == 'on_batch_begin':
self.func(self.learner)
return data
[docs] def after_outputs(self, outputs: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
if self.step_to_debug == 'after_outputs':
self.func(self.learner)
return outputs
[docs] def after_losses(self, losses: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
if self.step_to_debug == 'after_losses':
self.func(self.learner)
return losses
[docs] def on_backward_begin(self) -> bool:
if self.step_to_debug == 'on_backward_begin':
self.func(self.learner)
return True # if false, skip backward
[docs] def after_backward(self) -> bool:
if self.step_to_debug == 'after_backward':
self.func(self.learner)
return True # whether to continue with iteration
[docs] def after_step(self) -> bool:
if self.step_to_debug == 'after_step':
self.func(self.learner)
return True
[docs] def on_batch_end(self, logs: Dict[str, Any]):
if self.step_to_debug == 'on_batch_end':
self.func(self.learner)
[docs] def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
if self.step_to_debug == 'on_epoch_end':
self.func(self.learner)
return super().on_epoch_end(logs) # whether to stop training
[docs] def on_train_end(self):
if self.step_to_debug == 'on_train_end':
self.func(self.learner)