Source code for nntoolbox.tabular.components.components
import torch
from torch import nn
import math
[docs]class CrossLayer(nn.Module):
"""
Implement a (residual) crossing layer for Deep and Cross Net (DCN):
x_{l+1} = x_0 x^T_l w + b + x_l
Based on: https://arxiv.org/pdf/1708.05123.pdf
"""
def __init__(self, n_hidden, bias=True, return_first=False):
super(CrossLayer, self).__init__()
self.weight = nn.Parameter(torch.Tensor(1, n_hidden))
if bias:
self.bias = nn.Parameter(torch.Tensor(n_hidden))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self._return_first = return_first
[docs] def reset_parameters(self):
"""
Reset the parameters of the model
"""
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
nn.init.uniform_(self.bias, 0, 0)
[docs] def forward(self, inputs):
"""
:param inputs: a tuple: first element is the orinal features, second element is the output of last layer
:return:
"""
input, first = inputs
interaction = torch.bmm(
first.view(first.shape[0], first.shape[1], 1),
input.view(input.shape[0], 1, input.shape[1])
)
interaction = torch.bmm(
interaction,
self.weight.t().view(-1, self.weight.shape[1], self.weight.shape[0]).repeat(interaction.shape[0], 1, 1)
).view(interaction.shape[0], interaction.shape[1])
if self.bias is not None:
interaction += self.bias
if self._return_first:
return (interaction + input, first)
else:
return interaction + input