Source code for nntoolbox.tabular.models.collab
import torch
from torch import nn, Tensor
from typing import Optional
from ...components import MLP
import numpy as np
from ...init import sqrt_uniform_init
# Public names exported by ``from ... import *``.
__all__ = [
    'CollabFiltering',
    'NonLinearCF',
]
class CollabFiltering(nn.Module):
    """A simple collaborative model.

    Every user and every product gets a learned embedding of size ``embedding_dim``;
    the score of a (user, product) pair is the dot product of the two embeddings.
    """

    def __init__(self, n_users: int, n_products: int, embedding_dim: int):
        """
        :param n_users: number of user embeddings (user indices must lie in ``0 .. n_users - 1``)
        :param n_products: number of product embeddings (product indices must lie in ``0 .. n_products - 1``)
        :param embedding_dim: size shared by the user and the product embeddings
        """
        super().__init__()
        self.users = nn.Embedding(num_embeddings=n_users, embedding_dim=embedding_dim)
        self.products = nn.Embedding(num_embeddings=n_products, embedding_dim=embedding_dim)
        self.embedding_dim = embedding_dim
        # Project-specific initialiser (imported from the package's ``init`` module),
        # applied once both embedding tables exist.
        sqrt_uniform_init(self)

    def forward(self, inputs: Tensor) -> Tensor:
        """
        Score user-product pairs.

        :param inputs: integer tensor of user-product index pairs of shape (batch_size, 2);
            any leading shape (..., 2) is also accepted. Entry 0 of the last axis is the
            user index, entry 1 the product index.
        :return: the pair scores, of shape (batch_size, 1) (or (..., 1) in general)
        """
        # Select along the *last* axis: ``inputs[:, 0]`` would index axis 1 for
        # rank > 2 inputs and silently pair the wrong users and products.
        user_emb = self.users(inputs[..., 0])
        product_emb = self.products(inputs[..., 1])
        return (user_emb * product_emb).sum(-1, keepdim=True)

    def get_score(self, users: Tensor, products: Tensor) -> Tensor:
        """
        Return the score of corresponding pairs of users-products

        :param users: user indices, shape (batch_size, )
        :param products: product indices, same shape as ``users``
        :return: the pair scores, shape (batch_size, 1)
        """
        return self.forward(torch.stack((users, products), dim=-1))
class NonLinearCF(nn.Module):
    """
    A non-linear collaborative model. If no body model is provided, default to a one-hidden layer net

    The user embedding and the product embedding (whose sizes may differ) are concatenated
    and passed through ``body``, whose output is the score.
    """

    def __init__(
            self, n_users: int, n_products: int, user_dim: int, product_dim: int, body: Optional[nn.Module]=None
    ):
        """
        :param n_users: number of user embeddings (user indices must lie in ``0 .. n_users - 1``)
        :param n_products: number of product embeddings (product indices must lie in ``0 .. n_products - 1``)
        :param user_dim: size of each user embedding
        :param product_dim: size of each product embedding
        :param body: module applied to the concatenated features of size ``user_dim + product_dim``;
            when ``None``, an ``MLP`` with one hidden layer and ``out_features=1`` is built
        """
        super().__init__()
        self.users = nn.Embedding(num_embeddings=n_users, embedding_dim=user_dim)
        self.products = nn.Embedding(num_embeddings=n_products, embedding_dim=product_dim)
        # NOTE: runs before ``self.body`` is assigned, so the body is not yet a
        # registered submodule when this project initialiser is applied.
        sqrt_uniform_init(self)
        if body is None:
            in_features = product_dim + user_dim
            # Hidden width is 2 ** floor(log2(sqrt(in_features))), i.e. the largest
            # power of two not exceeding sqrt(in_features).
            hidden_size = 2 ** int(np.log2(np.sqrt(in_features)))
            self.body = MLP(
                in_features=in_features,
                hidden_layer_sizes=(hidden_size, ),
                out_features=1
            )
        else:
            self.body = body

    def forward(self, inputs: Tensor) -> Tensor:
        """
        Return the score of corresponding pairs of products and users

        :param inputs: integer tensor of user-product index pairs of shape (batch_size, 2);
            any leading shape (..., 2) is also accepted. Entry 0 of the last axis is the
            user index, entry 1 the product index.
        :return: ``body`` applied to the concatenated embeddings; (batch_size, 1) with the
            default body
        """
        # Select along the *last* axis: ``inputs[:, 0]`` would index axis 1 for
        # rank > 2 inputs and silently pair the wrong users and products.
        features = torch.cat((self.users(inputs[..., 0]), self.products(inputs[..., 1])), dim=-1)
        return self.body(features)

    def get_score(self, users: Tensor, products: Tensor) -> Tensor:
        """
        Return the score of corresponding pairs of users-products

        :param users: user indices, shape (batch_size, )
        :param products: product indices, same shape as ``users``
        :return: output of ``forward`` on the stacked pairs
        """
        return self.forward(torch.stack((users, products), dim=-1))