Source code for nntoolbox.components.maxout

import torch
from torch import nn, Tensor


class MaxoutLinear(nn.Module):
    """
    A linear maxout layer.

    The layer holds ``nb_features`` independent affine "pieces"
    ``z_j = W_j x + b_j`` (each an ``nn.Linear(in_features, out_features)``)
    and returns their element-wise maximum:

        output_i = max_{j = 1, ..., k} (W_j x + b_j)_i,   k = nb_features

    References:

        Ian J. Goodfellow et al. "Maxout Networks."
        https://arxiv.org/pdf/1302.4389.pdf

    :param in_features: size of the last dimension of the input
    :param out_features: size of the last dimension of the output
    :param nb_features: number of affine pieces k to take the maximum over (must be >= 1)
    :param bias: whether each piece has a learnable bias
    :raises ValueError: if ``nb_features < 1``
    """

    def __init__(self, in_features: int, out_features: int, nb_features: int, bias: bool = True):
        super().__init__()
        if nb_features < 1:
            # Without at least one piece, forward would fail inside torch.stack
            # with a much less helpful error; fail fast at construction instead.
            raise ValueError(f"nb_features must be at least 1, got {nb_features}")
        # Attribute name kept as ``_features`` so existing state_dict keys
        # (``_features.{j}.weight`` / ``_features.{j}.bias``) remain loadable.
        self._features = nn.ModuleList(
            [
                nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
                for _ in range(nb_features)
            ]
        )

    def forward(self, input: Tensor) -> Tensor:
        """
        :param input: (batch_size, in_features); any leading dimensions are
            accepted, i.e. (*, in_features), as with ``nn.Linear``
        :return: (batch_size, out_features), or (*, out_features) in general
        """
        # Stack the k piece outputs on a new trailing axis: (*, out_features, k).
        pieces = torch.stack([piece(input) for piece in self._features], dim=-1)
        # torch.max along a dim (unlike torch.amax) routes the gradient to a
        # single winning piece, matching the original implementation.
        return torch.max(pieces, dim=-1).values