nntoolbox.callbacks.swa module

class nntoolbox.callbacks.swa.StochasticWeightAveraging(learner, average_after: int, update_every: int = 1, timescale: str = 'iter', device=device(type='cpu'))

Bases: nntoolbox.callbacks.callbacks.Callback

get_averaged_model() → torch.nn.modules.module.Module

Return the post-training averaged model.

Returns: the averaged model
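Stochastic weight averaging conventionally keeps an equal-weight running mean of weight snapshots, so a new snapshot w updates the mean of n earlier snapshots as avg_new = (n * avg + w) / (n + 1). The minimal sketch below shows only that update rule. It assumes parameters are plain Python lists in a dict (a hypothetical stand-in for a model's state dict), and `swa_update` is an illustrative name, not part of nntoolbox.

```python
from typing import Dict, List

# Hypothetical stand-in for a model's state dict: parameter name -> flat list of floats.
Params = Dict[str, List[float]]


def swa_update(avg: Params, current: Params, n_averaged: int) -> Params:
    """Fold `current` into `avg`, the equal-weight mean of `n_averaged` earlier snapshots."""
    return {
        name: [(a * n_averaged + c) / (n_averaged + 1) for a, c in zip(avg[name], current[name])]
        for name in avg
    }


# Usage: average three snapshots of a single two-element parameter.
snapshots = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, {"w": [5.0, 6.0]}]
avg = snapshots[0]  # the first snapshot is the mean of one snapshot
for n, snap in enumerate(snapshots[1:], start=1):
    avg = swa_update(avg, snap, n)
print(avg)  # {'w': [3.0, 4.0]}
```

Updating the mean incrementally means only the current average and the count n need to be kept, rather than every snapshot.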

on_batch_end(logs: Dict[str, Any])
on_epoch_end(logs: Dict[str, Any]) → bool
on_train_end()
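The constructor's `timescale` parameter together with the two hooks suggests that the averaging schedule is checked in `on_batch_end` when counting iterations and in `on_epoch_end` otherwise. The sketch below only records the steps at which a snapshot would be taken. It assumes that averaging begins once `average_after` steps have completed and then repeats every `update_every` steps, that 'epoch' is the other valid timescale, and that returning False from `on_epoch_end` means "keep training". None of this is confirmed by the signatures above, `SWAScheduleSketch` is a hypothetical name, and `on_train_end` is not modeled.

```python
from typing import Any, Dict, List


class SWAScheduleSketch:
    """Hypothetical stand-in: records the steps at which SWA would snapshot weights."""

    def __init__(self, average_after: int, update_every: int = 1, timescale: str = "iter"):
        # Assumed: only 'iter' (count batches) and 'epoch' (count epochs) are valid.
        if timescale not in ("iter", "epoch"):
            raise ValueError("timescale must be 'iter' or 'epoch'")
        self.average_after = average_after
        self.update_every = update_every
        self.timescale = timescale
        self.steps = 0
        self.snapshot_steps: List[int] = []

    def _advance(self) -> None:
        self.steps += 1
        # Assumed rule: start once `average_after` steps are done, then every `update_every` steps.
        elapsed = self.steps - self.average_after
        if elapsed >= 0 and elapsed % self.update_every == 0:
            # A real callback would fold the current weights into the average here.
            self.snapshot_steps.append(self.steps)

    def on_batch_end(self, logs: Dict[str, Any]) -> None:
        if self.timescale == "iter":
            self._advance()

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        if self.timescale == "epoch":
            self._advance()
        return False  # assumed meaning: do not stop training


# Usage: 3 epochs of 4 batches each, checking the schedule per iteration.
cb = SWAScheduleSketch(average_after=5, update_every=3, timescale="iter")
for epoch in range(3):
    for batch in range(4):
        cb.on_batch_end({})
    cb.on_epoch_end({})
print(cb.snapshot_steps)  # [5, 8, 11]
```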