Source code for nntoolbox.models.ensemble
from typing import List, Optional
import numpy as np
# Public names exported by ``from <this module> import *``.
__all__ = ['Ensemble']
class Ensemble:
    """A collection of models paired with normalized per-model weights.

    Only construction is defined here: ``self.models`` holds the models as
    given, and ``self.model_weights`` is a 1-D float numpy array with one
    entry per model, scaled so the entries sum to 1.
    """

    def __init__(self, models, model_weights: Optional[List[float]] = None):
        """Store the models and their normalized weights.

        :param models: non-empty sized collection of models (must support ``len``).
        :param model_weights: optional 1-D sequence (list or numpy array) with one
            weight per model. The weights are divided by their sum. When omitted,
            every model gets weight ``1 / len(models)``.
        :raises ValueError: if ``models`` is empty, if ``model_weights`` is not
            1-D, if its length differs from ``len(models)``, or if the weights
            sum to zero. (These were previously ``assert`` statements, which are
            stripped under ``python -O``.)
        """
        if len(models) == 0:
            raise ValueError("Ensemble requires at least one model.")
        self.models = models

        if model_weights is None:
            n_models = len(models)
            # Uniform weighting: each model contributes equally.
            self.model_weights = np.full(n_models, 1.0 / n_models)
            return

        # Convert first: the annotated input type is a plain list, which has no
        # ``.shape``, so shape checks must be done on the numpy view.
        weights = np.asarray(model_weights, dtype=float)
        if weights.ndim != 1:
            raise ValueError(
                "model_weights must be 1-dimensional, got shape {}.".format(weights.shape)
            )
        if len(weights) != len(models):
            raise ValueError(
                "Expected {} model weights, got {}.".format(len(models), len(weights))
            )

        total = weights.sum()
        # Guard the normalization; a zero total would yield nan/inf weights.
        if total == 0:
            raise ValueError("model_weights must not sum to zero.")
        self.model_weights = weights / total