from typing import Any, Dict, List, Optional

from torch import Tensor

from ..metrics import Metric
class Callback:
    """Base class for training callbacks. Every hook is a no-op by default,
    so subclasses only override the events they care about."""

    order: int = 0

    def on_train_begin(self): pass
    def on_epoch_begin(self): pass
    def on_batch_begin(self, data: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]: return data
    def after_outputs(self, outputs: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]: return outputs
    def after_losses(self, losses: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]: return losses
    def on_backward_begin(self) -> bool: return True   # if False, skip the backward pass
    def after_backward(self) -> bool: return True      # whether to continue with the iteration
    def after_step(self) -> bool: return True          # whether to continue with the iteration
    # def on_phase_begin(self): pass
    def on_epoch_end(self, logs: Dict[str, Any]) -> bool: return False  # whether to stop training
    # def on_phase_end(self): pass
    def on_batch_end(self, logs: Dict[str, Any]): pass
    def on_train_end(self): pass
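
# Illustrative sketch (not part of the original module): a minimal Callback
# subclass that clips gradient norms after the backward pass. It assumes the
# handler has attached a `learner` whose `model` attribute is an nn.Module;
# adapt the attribute name to your trainer.
import torch

class GradClipCB(Callback):
    def __init__(self, max_norm: float = 1.0):
        self._max_norm = max_norm

    def after_backward(self) -> bool:
        # `self.learner` is attached by CallbackHandler.__init__ below.
        torch.nn.utils.clip_grad_norm_(self.learner.model.parameters(), self._max_norm)
        return True  # continue with the iteration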
class GroupCallback(Callback):
    """
    Group several callbacks together so they can be registered as one (UNTESTED).
    The group inherits its order from the first callback in the list.
    """

    def __init__(self, callbacks: List[Callback]):
        assert len(callbacks) > 0
        self._callbacks = callbacks
        self.order = callbacks[0].order

    def on_train_begin(self):
        for cb in self._callbacks: cb.on_train_begin()

    def on_epoch_begin(self):
        for cb in self._callbacks: cb.on_epoch_begin()

    def on_batch_begin(self, data: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        for cb in self._callbacks:
            data = cb.on_batch_begin(data, train)
        return data

    def after_outputs(self, outputs: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        for cb in self._callbacks:
            outputs = cb.after_outputs(outputs, train)
        return outputs

    def after_losses(self, losses: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        for cb in self._callbacks:
            losses = cb.after_losses(losses, train)
        return losses

    def on_backward_begin(self) -> bool:
        ret = True
        for cb in self._callbacks:
            # Call the hook first so every callback runs even after one
            # returns False (`ret and cb.on_backward_begin()` would
            # short-circuit and skip the remaining hooks).
            ret = cb.on_backward_begin() and ret
        return ret  # if False, skip the backward pass

    def after_backward(self) -> bool:
        ret = True
        for cb in self._callbacks:
            ret = cb.after_backward() and ret
        return ret  # whether to continue with the iteration

    def after_step(self) -> bool:
        ret = True
        for cb in self._callbacks:
            ret = cb.after_step() and ret
        return ret  # whether to continue with the iteration

    # def on_phase_begin(self): pass

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        ret = False
        for cb in self._callbacks:
            ret = cb.on_epoch_end(logs) or ret
        return ret  # whether to stop training

    # def on_phase_end(self): pass

    def on_batch_end(self, logs: Dict[str, Any]):
        for cb in self._callbacks: cb.on_batch_end(logs)

    def on_train_end(self):
        for cb in self._callbacks: cb.on_train_end()
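
# Illustrative usage (kept as a comment so nothing executes at import time):
# bundle two callbacks so the handler treats them as one unit, ordered by the
# first callback's `order`. `GradClipCB` is the hypothetical sketch above.
#
#     group = GroupCallback([GradClipCB(max_norm=5.0),
#                            EarlyStoppingCB(monitor='loss', patience=3)])
#     handler = CallbackHandler(learner, n_epoch=10, callbacks=[group],
#                               metrics=metrics)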
class CallbackHandler:
    """Dispatches training events to a sorted list of callbacks and
    aggregates epoch metrics."""

    def __init__(
        self, learner, n_epoch: int, callbacks: Optional[List[Callback]] = None,
        metrics: Optional[Dict[str, Metric]] = None, final_metric: str = 'accuracy'
    ):
        if metrics is not None:
            assert final_metric in metrics
        if callbacks is not None:
            # Give every callback (including the members of a GroupCallback)
            # access to the learner and the epoch budget.
            for callback in callbacks:
                if isinstance(callback, GroupCallback):
                    for subcb in callback._callbacks:
                        subcb.learner = learner
                        subcb.n_epoch = n_epoch
                else:
                    callback.learner = learner
                    callback.n_epoch = n_epoch
            # Sort a copy so the caller's list is left untouched.
            callbacks = sorted(callbacks, key=lambda cb: cb.order)
        self._callbacks = callbacks
        self._metrics = metrics
        self._final_metric = final_metric
        self._iter_cnt = 0
        self._epoch = 0
        self.learner = learner

    def on_train_begin(self):
        if self._callbacks is not None:
            for callback in self._callbacks:
                callback.on_train_begin()

    def on_epoch_begin(self):
        if self._callbacks is not None:
            for callback in self._callbacks:
                callback.on_epoch_begin()

    def on_batch_begin(self, data: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        if self._callbacks is not None:
            for callback in self._callbacks:
                data = callback.on_batch_begin(data, train)
        return data

    def after_outputs(self, outputs: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        if self._callbacks is not None:
            for callback in self._callbacks:
                outputs = callback.after_outputs(outputs, train)
        return outputs

    def after_losses(self, losses: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        if self._callbacks is not None:
            for callback in self._callbacks:
                losses = callback.after_losses(losses, train)
        return losses

    def on_backward_begin(self) -> bool:
        ret = True
        if self._callbacks is not None:
            for callback in self._callbacks:
                # Call the hook first so every callback runs;
                # `ret and callback.on_backward_begin()` would short-circuit.
                ret = callback.on_backward_begin() and ret
        return ret

    def after_backward(self) -> bool:
        ret = True
        if self._callbacks is not None:
            for callback in self._callbacks:
                ret = callback.after_backward() and ret
        return ret

    def after_step(self) -> bool:
        ret = True
        if self._callbacks is not None:
            for callback in self._callbacks:
                ret = callback.after_step() and ret
        return ret

    def on_batch_end(self, logs: Dict[str, Any]):
        logs["iter_cnt"] = self._iter_cnt
        if self._callbacks is not None:
            for callback in self._callbacks:
                callback.on_batch_end(logs)
        self._iter_cnt += 1

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        print(f"Evaluate for epoch {self._epoch}:")
        logs["epoch"] = self._epoch
        stop_training = False
        if self._metrics is not None:
            epoch_metrics = dict()
            for name, metric in self._metrics.items():
                epoch_metrics[name] = metric(logs)
                print(f"{name}: {epoch_metrics[name]}")
            logs["epoch_metrics"] = epoch_metrics
        if self._callbacks is not None:
            for callback in self._callbacks:
                stop_training = callback.on_epoch_end(logs) or stop_training
        self._epoch += 1
        return stop_training

    def on_train_end(self) -> float:
        if self._callbacks is not None:
            for callback in self._callbacks:
                callback.on_train_end()
        if self._metrics is None:
            return 0.0
        return self._metrics[self._final_metric].get_best()
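
# Sketch of the training loop the handler is meant to drive (a comment-only
# assumption inferred from the hook names; the actual Learner loop lives
# elsewhere). `model`, `optimizer`, `compute_losses`, and `loader` are
# placeholders, as is the 'loss' key.
#
#     handler.on_train_begin()
#     for _ in range(n_epoch):
#         handler.on_epoch_begin()
#         logs = {}
#         for data in loader:
#             data = handler.on_batch_begin(data, train=True)
#             outputs = handler.after_outputs(model(data), train=True)
#             losses = handler.after_losses(compute_losses(outputs, data), train=True)
#             if handler.on_backward_begin():
#                 losses['loss'].backward()
#                 if not handler.after_backward():
#                     continue
#                 optimizer.step()
#                 optimizer.zero_grad()
#                 if not handler.after_step():
#                     continue
#             handler.on_batch_end(logs)
#         if handler.on_epoch_end(logs):
#             break
#     best = handler.on_train_end()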
class EarlyStoppingCB(Callback):
    """Stop training when the monitored metric has not improved by at least
    ``min_delta`` for more than ``patience`` consecutive epochs."""

    def __init__(self, monitor: str = 'loss', min_delta: float = 0.0, patience: int = 0,
                 mode: str = 'min', baseline: Optional[float] = None):
        assert mode in ('min', 'max')
        self._monitor = monitor
        self._min_delta = min_delta
        self._patience = patience
        self._cur_p = 0          # epochs elapsed without improvement
        self._mode = mode
        self._baseline = baseline
        self._metrics = []
    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        epoch_metrics = logs['epoch_metrics']
        assert self._monitor in epoch_metrics
        current = epoch_metrics[self._monitor]
        if self._mode == "min":
            # Improved iff we match the best value seen so far by at least
            # min_delta (and meet the baseline, if one is given). Note the
            # original stored min_delta but never used it; it is applied here.
            best = min(self._metrics) if self._metrics else float('inf')
            improved = current <= best - self._min_delta and \
                (self._baseline is None or current <= self._baseline)
        else:
            best = max(self._metrics) if self._metrics else float('-inf')
            improved = current >= best + self._min_delta and \
                (self._baseline is None or current >= self._baseline)
        self._metrics.append(current)
        if improved:
            self._cur_p = 0
        else:
            self._cur_p += 1
        return self._cur_p > self._patience
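
# Illustrative usage (comment-only so nothing runs at import time): stop when
# validation loss has not dropped by at least 1e-3 for 5 consecutive epochs.
# Assumes the metrics dict passed to CallbackHandler produces a 'loss' entry
# in `logs['epoch_metrics']` each epoch.
#
#     early_stop = EarlyStoppingCB(monitor='loss', min_delta=1e-3,
#                                  patience=5, mode='min')
#     handler = CallbackHandler(learner, n_epoch=50, callbacks=[early_stop],
#                               metrics=metrics, final_metric='loss')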