Source code for nntoolbox.components.mixture

"""Implement mixture of probability distribution layers"""
import torch
from torch import Tensor, nn
from torch.nn import Module
import torch.nn.functional as F
from typing import List, Union, Tuple


__all__ = ['MixtureOfGaussian', 'MixtureOfExpert']


class MixtureOfGaussian(nn.Linear):
    """
    A layer that generates the means, standard deviations, and mixing coefficients
    of a mixture of Gaussian distributions. Used as the final layer of a (Gaussian)
    mixture density network. Only supports isotropic covariances for the components.

    References:

        Christopher Bishop. "Pattern Recognition and Machine Learning"
    """
    def __init__(self, in_features: int, out_features: int, n_dist: int, bias: bool=True):
        assert n_dist > 0 and in_features > 0 and out_features > 0
        self.n_dist = n_dist
        # One output per mixing coefficient and per std, plus out_features per mean:
        super(MixtureOfGaussian, self).__init__(in_features, n_dist * (2 + out_features), bias)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        :param input: input tensor of shape (batch_size, in_features)
        :return: means (batch_size, n_dist * out_features), stds (batch_size, n_dist)
            and mixing coefficients (batch_size, n_dist)
        """
        features = super().forward(input)
        # Softmax keeps the mixing coefficients positive and summing to one
        mixing_coeffs = F.softmax(features[:, :self.n_dist], dim=-1)
        # Exponentiation keeps the standard deviations positive
        stds = torch.exp(features[:, self.n_dist:self.n_dist * 2])
        means = features[:, self.n_dist * 2:]
        return means, stds, mixing_coeffs
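

# A minimal usage sketch (illustrative, not part of the original module): the
# sizes below are assumptions, and `_example_mdn_nll` is a hypothetical helper
# showing how the layer's outputs feed the negative log-likelihood of an
# isotropic Gaussian mixture, as in Bishop's mixture density networks.
def _example_mdn_nll():
    import math

    layer = MixtureOfGaussian(in_features=64, out_features=2, n_dist=5)
    x = torch.randn(8, 64)
    means, stds, mixing_coeffs = layer(x)  # shapes: (8, 10), (8, 5), (8, 5)

    target = torch.randn(8, 2)
    batch, d = target.shape
    means = means.view(batch, 5, d)                              # (batch, n_dist, d)
    sq_dist = ((target.unsqueeze(1) - means) ** 2).sum(dim=-1)   # (batch, n_dist)
    # log N(target | mu_k, sigma_k^2 I) for each component k
    log_prob = (-0.5 * sq_dist / stds ** 2
                - d * torch.log(stds)
                - 0.5 * d * math.log(2 * math.pi))
    # NLL of the mixture: -log sum_k pi_k N(target | mu_k, sigma_k^2 I)
    return -torch.logsumexp(torch.log(mixing_coeffs) + log_prob, dim=-1).mean()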


class MixtureOfExpert(Module):
    def __init__(self, experts: List[Module], gate: Module, return_mixture: bool=True):
        """
        :param experts: list of separate expert networks. Each must take the same input
            and return output of the same dimensionality
        :param gate: takes the input and outputs an (un-normalized) score for each expert
        """
        super(MixtureOfExpert, self).__init__()
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.softmax = nn.Softmax(dim=-1)
        self.return_mixture = return_mixture

    def forward(self, input: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """
        :param input: input tensor, fed to the gate and to every expert
        :return: if return_mixture, return the mixture-of-expert output; else return both
            the expert outputs and the expert scores (with the n_expert channel coming last)
        """
        expert_scores = self.softmax(self.gate(input))
        # Stack the expert outputs along a trailing n_expert dimension
        expert_outputs = torch.stack([expert(input) for expert in self.experts], dim=-1)
        # Insert singleton dimensions so the scores broadcast against the expert outputs
        expert_scores = expert_scores.view(
            list(expert_scores.shape)[:-1]
            + [1 for _ in range(len(expert_outputs.shape) - len(expert_scores.shape))]
            + list(expert_scores.shape)[-1:]
        )
        if self.return_mixture:
            return torch.sum(expert_outputs * expert_scores, dim=-1)
        else:
            return expert_outputs, expert_scores
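

# A minimal usage sketch (illustrative, not part of the original module): the
# expert and gate architectures below are assumptions; any modules with matching
# input and output shapes work.
def _example_mixture_of_expert():
    n_experts, in_dim, out_dim = 3, 16, 4
    experts = [nn.Sequential(nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, out_dim))
               for _ in range(n_experts)]
    gate = nn.Linear(in_dim, n_experts)  # one un-normalized score per expert

    moe = MixtureOfExpert(experts, gate, return_mixture=True)
    x = torch.randn(8, in_dim)
    mixture = moe(x)                      # (8, out_dim): score-weighted sum of expert outputs

    moe.return_mixture = False
    outputs, scores = moe(x)              # (8, out_dim, 3) and (8, 1, 3)
    return mixture, outputs, scores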