nntoolbox.callbacks.fge module

class nntoolbox.callbacks.fge.FastGeometricEnsembling(model: torch.nn.modules.module.Module, max_n_model: int, save_after: int, save_every: int = 1, timescale: str = 'iter')

Bases: nntoolbox.callbacks.callbacks.Callback

get_models() → List[torch.nn.modules.module.Module]

Return the models saved during training.

Returns: the list of saved model snapshots
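Because get_models() returns a list of models rather than a single averaged one, the caller has to combine them. FGE as originally described ensembles snapshots by averaging their predictions, not their weights. The sketch below shows uniform prediction averaging. The helper ensemble_predict is hypothetical and not part of nntoolbox. Each "model" is stubbed as a plain callable returning per-class probabilities so the example runs without PyTorch.

```python
from typing import Callable, List, Sequence

# A "model" here is any callable mapping an input to per-class scores.
Model = Callable[[Sequence[float]], List[float]]


def ensemble_predict(models: List[Model], x: Sequence[float]) -> List[float]:
    """Hypothetical helper: average the outputs of all snapshots (uniform weights)."""
    if not models:
        raise ValueError("need at least one model")
    outputs = [m(x) for m in models]
    n = len(outputs)
    # zip(*outputs) groups the i-th class score of every snapshot together.
    return [sum(scores) / n for scores in zip(*outputs)]


# Usage with two stub snapshots that disagree on the input.
snapshot_a = lambda x: [0.75, 0.25]
snapshot_b = lambda x: [0.25, 0.75]
print(ensemble_predict([snapshot_a, snapshot_b], [1.0, 2.0]))  # [0.5, 0.5]
```

With real PyTorch snapshots, each callable would typically be a model in eval mode whose logits are passed through a softmax before averaging.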

on_batch_end(logs: Dict[str, Any])
on_epoch_end(logs: Dict[str, Any]) → bool
on_train_end()
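The constructor parameters suggest a snapshot schedule. Nothing is saved before save_after steps, and after that a copy of model is kept every save_every steps, up to max_n_model copies. timescale selects whether a "step" is a batch (on_batch_end) or an epoch (on_epoch_end). The following is a minimal sketch of that schedule under stated assumptions, not nntoolbox's implementation. The class name FGESnapshotSketch, the exact save condition, the stop-at-cap policy and the meaning of on_epoch_end's bool are all assumptions. A dict stands in for the model, but copy.deepcopy works the same way on a torch.nn.Module.

```python
import copy
from typing import Any, Dict, List


class FGESnapshotSketch:
    """Hypothetical stand-in for FastGeometricEnsembling's snapshot logic.

    Assumptions (not taken from nntoolbox's source):
      * a snapshot is taken at step t when t >= save_after and
        (t - save_after) % save_every == 0, with steps counted from 1;
      * once max_n_model snapshots are stored, no more are taken.
    """

    def __init__(self, model: Any, max_n_model: int, save_after: int,
                 save_every: int = 1, timescale: str = "iter"):
        if timescale not in ("iter", "epoch"):
            raise ValueError("timescale must be 'iter' or 'epoch'")
        self.model = model
        self.max_n_model = max_n_model
        self.save_after = save_after
        self.save_every = save_every
        self.timescale = timescale
        self.models: List[Any] = []
        self._iter = 0
        self._epoch = 0

    def _maybe_save(self, step: int) -> None:
        if len(self.models) >= self.max_n_model:
            return  # assumed policy: stop collecting at the cap
        if step >= self.save_after and (step - self.save_after) % self.save_every == 0:
            # Deep copy so later weight updates do not alter the stored snapshot.
            self.models.append(copy.deepcopy(self.model))

    def on_batch_end(self, logs: Dict[str, Any]) -> None:
        self._iter += 1
        if self.timescale == "iter":
            self._maybe_save(self._iter)

    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        self._epoch += 1
        if self.timescale == "epoch":
            self._maybe_save(self._epoch)
        # Assumption: returning False means "do not stop training".
        return False

    def get_models(self) -> List[Any]:
        return self.models


# Usage: 10 iterations, first snapshot at iteration 2, then every 3 iterations.
model = {"w": 0}
fge = FGESnapshotSketch(model, max_n_model=3, save_after=2, save_every=3)
for it in range(1, 11):
    model["w"] = it  # stand-in for an optimizer step
    fge.on_batch_end({})
print([m["w"] for m in fge.get_models()])  # [2, 5, 8]
```

Deep-copying at save time is the key design point. Storing a reference to model instead would leave every "snapshot" pointing at the final weights.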