nntoolbox.utils.norm_dist module

Utility functions for computing norms and distances.

nntoolbox.utils.norm_dist.compute_squared_norm(A: torch.Tensor) → torch.Tensor

Compute the squared norm of each row of A.

Parameters

A – tensor of shape (M, D)

Returns

squared row norms, shape (M, 1)
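For reference, here is a minimal PyTorch sketch of the documented behaviour. It assumes the norm is the Euclidean (L2) norm. The name squared_row_norms is hypothetical, and this is not the library's own source.

```python
import torch

def squared_row_norms(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for compute_squared_norm (assumes the L2 norm).

    A: tensor of shape (M, D). Returns a tensor of shape (M, 1).
    """
    # Sum of squares over the feature dimension; keepdim=True leaves a column
    # vector, which broadcasts cleanly in distance formulas such as
    # ||a||^2 - 2 a.b + ||b||^2
    return (A * A).sum(dim=1, keepdim=True)

A = torch.tensor([[3.0, 4.0], [1.0, 2.0]])
print(squared_row_norms(A))  # values: [[25.0], [5.0]]
```

Keeping the result as an (M, 1) column instead of a flat (M,) vector is what lets it be added directly to an (M, N) matrix of dot products.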

nntoolbox.utils.norm_dist.emb_pairwise_dist(embeddings: torch.Tensor, squared: bool = True, eps: float = 1e-16) → torch.Tensor

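This entry has no description. Judging only from the signature, one plausible reading is that it returns the distance between every pair of rows in an (B, D) embeddings tensor, with squared selecting squared distances and eps guarding the square root against the infinite gradient of sqrt at zero. The sketch below implements that assumed behaviour; the name emb_pairwise_dist_sketch and the exact role of eps are assumptions, not the library's documented semantics.

```python
import torch

def emb_pairwise_dist_sketch(
    embeddings: torch.Tensor, squared: bool = True, eps: float = 1e-16
) -> torch.Tensor:
    """Hypothetical reading of emb_pairwise_dist: (B, D) -> (B, B) distances.

    ASSUMPTION: squared=True returns squared Euclidean distances and eps is
    only used to keep the sqrt differentiable where the distance is zero.
    """
    sq_norm = (embeddings * embeddings).sum(dim=1, keepdim=True)  # (B, 1)
    # ||x_i - x_j||^2 = ||x_i||^2 - 2 x_i.x_j + ||x_j||^2 for all pairs
    sq_dist = sq_norm - 2.0 * embeddings @ embeddings.t() + sq_norm.t()
    # Rounding can produce tiny negative values; clamp them to zero
    sq_dist = sq_dist.clamp(min=0.0)
    if squared:
        return sq_dist
    # sqrt has an infinite derivative at 0, so shift exact zeros by eps
    # before the sqrt and zero those entries out again afterwards
    zero_mask = (sq_dist == 0.0).to(sq_dist.dtype)
    dist = torch.sqrt(sq_dist + zero_mask * eps)
    return dist * (1.0 - zero_mask)

E = torch.tensor([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
print(emb_pairwise_dist_sketch(E, squared=False))
# values: [[0, 5, 10], [5, 0, 5], [10, 5, 0]]
```

The mask trick matters mainly for training: the diagonal of a self-distance matrix is always zero, and without it the backward pass through sqrt would produce inf or NaN gradients.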
nntoolbox.utils.norm_dist.pairwise_dist(A: torch.Tensor, B: torch.Tensor) → torch.Tensor

Compute the pairwise distance from each row vector of A to each row vector of B.

Parameters
  • A – tensor of shape (N, D)

  • B – tensor of shape (M, D)

Returns

pairwise distances, shape (M, N)
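A minimal sketch of the usual vectorised approach, assuming plain (non-squared) Euclidean distance; the page does not say which is returned. The name pairwise_dist_sketch is hypothetical. Note that this sketch returns an (N, M) tensor whose entry [i, j] is the distance from A[i] to B[j]; transpose it if the (M, N) layout documented above is required.

```python
import torch

def pairwise_dist_sketch(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: Euclidean distance between rows of A and rows of B.

    A: (N, D), B: (M, D). Returns (N, M) with entry [i, j] = ||A[i] - B[j]||.
    """
    # Expand ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 for all pairs at once
    sq_a = (A * A).sum(dim=1, keepdim=True)      # (N, 1)
    sq_b = (B * B).sum(dim=1, keepdim=True).t()  # (1, M)
    sq_dist = sq_a - 2.0 * A @ B.t() + sq_b      # (N, M) via broadcasting
    # Cancellation can leave tiny negatives; clamp before taking the root
    return sq_dist.clamp(min=0.0).sqrt()

A = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
B = torch.tensor([[0.0, 0.0], [6.0, 8.0], [3.0, 0.0]])
print(pairwise_dist_sketch(A, B))  # values: [[0, 10, 3], [5, 5, 4]]
```

The expansion trades one (N, M, D) intermediate for a single matrix multiply, which is why it is the common choice on GPUs. PyTorch's built-in torch.cdist(A, B) computes the same (N, M) Euclidean distances and is a reasonable drop-in when no custom behaviour is needed.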