Source code for nntoolbox.components.merge

import torch
from torch import nn, Tensor
from typing import List


__all__ = ['Multiply', 'Mean', 'Sum']


[docs]class Multiply(nn.Module): def __init__(self, modules: List[nn.Module]): super().__init__() self.module_list = nn.ModuleList(modules)
[docs] def forward(self, input: Tensor) -> Tensor: return torch.stack([module(input) for module in self.module_list], dim=-1).prod(dim=-1)
[docs]class Sum(nn.Module): def __init__(self, modules: List[nn.Module]): super().__init__() self.module_list = nn.ModuleList(modules)
[docs] def forward(self, input: Tensor) -> Tensor: return torch.stack([module(input) for module in self.module_list], dim=-1).sum(dim=-1)
[docs]class Mean(nn.Module): def __init__(self, modules: List[nn.Module]): super().__init__() self.module_list = nn.ModuleList(modules)
[docs] def forward(self, input: Tensor) -> Tensor: return torch.stack([module(input) for module in self.module_list], dim=-1).mean(dim=-1)