Source code for nntoolbox.utils.norm_dist

"""Utility functions involving computing norms and distances"""
import torch
from torch import Tensor
import torch.nn.functional as F


__all__ = ['emb_pairwise_dist', 'compute_squared_norm', 'pairwise_dist']


# Follows https://github.com/omoindrot/tensorflow-triplet-loss/blob/master/model/triplet_loss.py
[docs]@torch.no_grad() def emb_pairwise_dist(embeddings: Tensor, squared: bool=True, eps: float=1e-16) -> Tensor: interaction = embeddings.mm(torch.t(embeddings)) # EE^T, (M, M) # norm = torch.norm(embeddings, dim = -1).view(embeddings.shape[0], 1) # sqr_norm_i = \sum_j E_{i, j}^2 = E_i E^T_i square_norm = torch.diag(interaction).view(embeddings.shape[0], 1) # (M, 1) squared_dist = square_norm - 2 * interaction + torch.t(square_norm) squared_dist = F.relu(squared_dist) if squared: return squared_dist else: mask = torch.eq(squared_dist, 0).float() squared_dist = squared_dist + mask * eps dist = torch.sqrt(squared_dist) dist = dist * (1.0 - mask) return dist
[docs]@torch.no_grad() def compute_squared_norm(A: Tensor) -> Tensor: """ Compute the squared norm of each row of A :param A: (M, D) :return: squared norm (M, 1) """ return torch.diag(A.mm(torch.t(A)))
[docs]@torch.no_grad() def pairwise_dist(A: Tensor, B: Tensor) -> Tensor: """ Compute pairwise distance from each row vector of A to row vector of B :param A: (N, D) :param B: (M, D) :return: (M, N) """ sq_norm_A = compute_squared_norm(A).view(1, A.shape[0]) # (1, N) sq_norm_B = compute_squared_norm(B).view(B.shape[0], 1) # (M, 1) interaction = B.mm(torch.t(A)) # (M, N) return F.relu(sq_norm_A - 2 * interaction + sq_norm_B)