Source code for nntoolbox.components.rbf

from typing import Optional
from torch import nn, Tensor
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestCentroid
from .kernel import DistKernel
import torch


# Names exported by a star-import of this module.
__all__ = ['RBFLayer']


class RBFLayer(nn.Linear):
    """
    RBF Layer (used for output or as the hidden layer for RBF network)

    The layer subclasses ``nn.Linear`` with the bias disabled. The
    ``(out_features, in_features)`` weight matrix stores one center per row and
    is registered under both the ``weight`` and ``centers`` attribute names
    (they are the same ``nn.Parameter`` object).

    References:

        Lecun et al. "Gradient-Based Learning Applied to Document Recognition."
        http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            trainable_centers: bool = True,
            normalized: bool = False,
            kernel: Optional["DistKernel"] = None,
            initial_centers: Optional[Tensor] = None
    ):
        """
        :param in_features: dimension of input
        :param out_features: number of centers
        :param trainable_centers: whether the center can be moved
        :param normalized: whether the output should be normalized (i.e sum to 1)
        :param kernel: (optional) a distance-based kernel function
        :param initial_centers: (optional) initial centers placement, of shape
            (out_features, in_features)
        :raises ValueError: if initial_centers does not have shape
            (out_features, in_features)
        """
        super(RBFLayer, self).__init__(in_features, out_features, False)
        self.normalized = normalized
        self.kernel = kernel

        if initial_centers is not None:
            # Explicit check instead of ``assert``: asserts vanish under ``-O``
            # and indexing shape[1] fails obscurely on non-2-D tensors.
            expected_shape = (out_features, in_features)
            if tuple(initial_centers.shape) != expected_shape:
                raise ValueError(
                    "initial_centers must have shape {}, got {}".format(
                        expected_shape, tuple(initial_centers.shape)
                    )
                )
            self.weight = self.centers = nn.Parameter(initial_centers, requires_grad=trainable_centers)
        else:
            self.centers = self.weight
            self.centers.requires_grad = trainable_centers

    def _set_centers(self, values) -> None:
        """
        Overwrite the centers in place, keeping the existing parameter object.

        :param values: array-like whose shape is compatible with the weight matrix
        """
        # no_grad() lets us write into a leaf that may require grad; as_tensor
        # converts straight to the parameter's dtype/device (no float32 detour).
        with torch.no_grad():
            self.weight.copy_(
                torch.as_tensor(values, dtype=self.weight.dtype, device=self.weight.device)
            )

    def cluster_initialize(self, input: Tensor):
        """
        (Re-)initialize the centers based on k-mean clustering on the input

        :param input: samples to cluster; moved to CPU and converted to numpy
            before fitting ``KMeans`` with ``out_features`` clusters
        """
        model = KMeans(self.out_features)
        model.fit(input.cpu().detach().numpy())
        self._set_centers(model.cluster_centers_)

    def centroids_initialize(self, input: Tensor, labels: Tensor):
        """
        (Re-)initialize the centers based on nearest centroids algorithm

        One centroid per distinct label is copied into the centers, so the
        number of distinct labels is expected to match ``out_features``.

        :param input: samples used to compute the class centroids
        :param labels: class label of each sample (flattened to 1-D before fitting)
        """
        model = NearestCentroid()
        model.fit(input.cpu().detach().numpy(), labels.cpu().detach().numpy().ravel())
        self._set_centers(model.centroids_)

    def forward(self, input: Tensor) -> Tensor:
        """
        Compute the response of every center to every input row.

        Without a kernel the output is the squared Euclidean distance to each
        center; with a kernel it is the kernel applied to the (non-squared)
        distance. If ``normalized`` is set, each row is divided by its sum
        (a row summing to zero is not guarded against).

        :param input: (M, in_features)
        :return: (M, out_features)
        """
        output = pairwise_dist(input, self.centers, squared=True)

        if self.kernel is not None:
            # d/dx sqrt(x) is infinite at x == 0, which turns into NaN gradients
            # as soon as an input coincides with a center. Swap exact zeros for a
            # dummy 1 before the sqrt and restore 0 afterwards: forward values
            # are unchanged and those entries receive a zero gradient.
            zero_mask = output == 0
            dists = output.masked_fill(zero_mask, 1.0).sqrt().masked_fill(zero_mask, 0.0)
            output = self.kernel(dists)

        if self.normalized:
            return output / output.sum(dim=-1, keepdim=True)
        return output
def pairwise_dist(A: Tensor, B: Tensor, squared: bool = True) -> Tensor:
    """
    Euclidean distance between every row of A and every row of B, computed
    through the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>.

    :param A: (M, D)
    :param B: (N, D)
    :param squared: whether to return squared distance or just distance
    :return: (M, N)
    """
    cross = A.matmul(B.t())
    a_norms = A.pow(2).sum(-1).unsqueeze(-1)
    b_norms = B.pow(2).sum(-1).unsqueeze(0)
    # Round-off in the expansion can yield tiny negative values; clamp them to 0.
    dist_sq = torch.clamp(a_norms + b_norms - 2 * cross, min=0)
    if squared:
        return dist_sq

    # sqrt has an infinite derivative at 0, which back-propagates as NaN when
    # two points coincide. Replace exact zeros by a dummy 1 for the sqrt and put
    # 0 back afterwards: same values, zero (finite) gradient at those entries.
    zero_mask = dist_sq == 0
    return dist_sq.masked_fill(zero_mask, 1.0).sqrt().masked_fill(zero_mask, 0.0)