Source code for nntoolbox.components.kernel
from torch import nn, Tensor
import torch
__all__ = ['DistKernel', 'GaussianDistKernel']
[docs]class DistKernel(nn.Module):
def __call__(self, dists: Tensor) -> Tensor: pass
[docs]class GaussianDistKernel(DistKernel):
def __init__(self, log_beta: float=0.0, trainable_beta: bool=False):
"""
:param log_beta: log of beta (which is inverse of bandwidth)
:param trainable_beta: whether beta should be trainable
"""
super(GaussianDistKernel, self).__init__()
self.log_beta = nn.Parameter(torch.tensor(log_beta), requires_grad=trainable_beta)
def __call__(self, dists: Tensor) -> Tensor:
return torch.exp(-torch.exp(self.log_beta) * dists.pow(2))