Source code for nntoolbox.components.pool
import torch
from torch import nn
[docs]class AveragePool(nn.Module):
def __init__(self, dim):
super(AveragePool, self).__init__()
self._dim = dim
[docs] def forward(self, input):
return torch.mean(input, dim=self._dim)
[docs]class MaxPool(nn.Module):
def __init__(self, dim):
super(MaxPool, self).__init__()
self._dim = dim
[docs] def forward(self, input):
return torch.max(input, dim=self._dim).values()
[docs]class ConcatPool(nn.Module):
def __init__(self, pool_dim, concat_dim):
super(ConcatPool, self).__init__()
self._pool_dim = pool_dim
self._concat_dim = concat_dim - 1 if pool_dim < concat_dim and concat_dim > 0 else concat_dim
[docs] def forward(self, input):
max = torch.max(input, dim=self._pool_dim).values
avg = torch.mean(input, dim=self._pool_dim)
return torch.cat([max, avg, input[-1]], dim=self._concat_dim)