# Source code for nntoolbox.ensembler.cv
from typing import Callable
from typing import List
from typing import Optional

from sklearn.model_selection import KFold
from torch import nn
from torch.utils.data import Subset, Dataset

from ..utils import load_model
# Public API exported by ``from ... import *``.
__all__ = ["CVEnsembler"]
class CVEnsembler:
    """
    Create an ensemble of identically-constructed models via k-fold cross-validation.

    The data is split into ``k = n_model`` folds. Model ``i`` is trained on the other
    (k - 1) folds and validated on fold ``i``, so every sample is used for validation
    by exactly one model.

    References:
        Anders Krogh and Jesper Vedelsby. "Neural Network Ensembles, Cross Validation, and Active Learning."
        https://papers.nips.cc/paper/1001-neural-network-ensembles-cross-validation-and-active-learning.pdf
    """

    def __init__(
            self, data: Dataset, path: str, n_model: int, model_fn: Callable[..., nn.Module],
            learn_fn: Callable[[Dataset, Dataset, nn.Module, str], None],
            *, shuffle: bool = False, random_state: Optional[int] = None
    ):
        """
        :param data: the full dataset
        :param path: prefix of the checkpoint files. Model ``i`` is saved to
            ``path + "model_<i>.pt"`` by plain string concatenation, so include a trailing
            separator (e.g. ``"checkpoints/"``) when ``path`` is meant to be a directory
        :param n_model: number of models to generate for the ensemble; this is also the number
            of cross-validation folds (KFold requires at least 2)
        :param model_fn: a function called with no arguments that returns a new model; it is
            called once per fold in ``learn`` and once per model in ``get_models``
        :param learn_fn: a function that takes a train dataset, a val dataset, a model and a save
            path, and saves the learned model at the save path
        :param shuffle: whether to shuffle sample indices before splitting into folds. Without
            shuffling the folds are contiguous index ranges, which yields skewed folds when the
            dataset is ordered (e.g. sorted by class). Defaults to False (original behavior)
        :param random_state: seed for the shuffle; only meaningful when ``shuffle`` is True
        """
        self.model_fn = model_fn
        self.n_model = n_model
        # Fold i of this splitter is the validation set of model i.
        self.kf = KFold(n_splits=n_model, shuffle=shuffle, random_state=random_state)
        self.data = data
        self.learn_fn = learn_fn
        self.path = path

    def learn(self) -> None:
        """
        Train one model per cross-validation fold and save each to its checkpoint path.

        For fold ``i``, a new model is built with ``model_fn`` and passed to ``learn_fn`` together
        with the training subset (all other folds), the validation subset (fold ``i``) and the
        save path ``path + "model_<i>.pt"``.
        """
        # KFold's split depends only on the number of samples, so split over the index range.
        indices = list(range(len(self.data)))
        for model_ind, (train_idx, val_idx) in enumerate(self.kf.split(indices)):
            save_path = self._model_path(model_ind)
            # KFold yields numpy integer arrays; convert to plain Python ints so datasets whose
            # __getitem__ type-checks for ``int`` receive the index type they expect.
            train_data = Subset(self.data, train_idx.tolist())
            val_data = Subset(self.data, val_idx.tolist())
            model = self.model_fn()
            self.learn_fn(train_data, val_data, model, save_path)

    def get_models(self) -> List[nn.Module]:
        """
        Rebuild the ensemble from the checkpoints written by ``learn``.

        :return: a list of ``n_model`` models; model ``i`` is a new ``model_fn()`` instance that
            was passed to ``load_model`` together with the checkpoint path of fold ``i``
        """
        models = []
        for model_ind in range(self.n_model):
            model = self.model_fn()
            # load_model's return value is ignored, so it is presumably expected to load the
            # checkpoint into ``model`` in place — confirm against ..utils.load_model.
            load_model(model, self._model_path(model_ind))
            models.append(model)
        return models

    def _model_path(self, model_ind: int) -> str:
        """Return the checkpoint path of model ``model_ind`` (shared by ``learn`` and ``get_models``)."""
        return f"{self.path}model_{model_ind}.pt"