Source code for nntoolbox.vision.losses.metrics

import torch
from torch import nn, Tensor
from torch.nn import functional as F
import numpy as np
from typing import Tuple
from ...components import MLP


# Public API: one loss class per metric-learning formulation.
__all__ = [
    'VerificationLoss', 'ContrastiveLoss', 'TripletSoftMarginLoss',
    'TripletMarginLossV2', 'AngularLoss', 'NPairLoss', 'NPairAngular'
]


class VerificationLoss(nn.Module):
    """
    Verify if two embeddings belong to the same class.

    A small MLP scores the absolute difference of the two embeddings and the
    score is trained with binary cross-entropy (with logits) against the
    0/1 same-class label.

    :param embedding_dim: dimensionality of each input embedding.
    """
    def __init__(self, embedding_dim: int):
        super(VerificationLoss, self).__init__()
        self.embedding_dim = embedding_dim
        self.verifier = MLP(
            in_features=embedding_dim, out_features=1,
            hidden_layer_sizes=[embedding_dim // 2]
        )
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, data: Tuple[Tensor, ...]) -> Tensor:
        """
        :param data: ((first embeddings, second embeddings), labels).
        :return: scalar BCE-with-logits loss.
        """
        (emb_a, emb_b), labels = data
        assert emb_a.shape[-1] == emb_b.shape[-1] == self.embedding_dim
        assert len(emb_a) == len(emb_b) == len(labels)
        labels = labels.float()
        # Give the labels a trailing unit dim when the embeddings have one more dim.
        if len(emb_a.shape) == len(labels.shape) + 1:
            labels = labels.unsqueeze(-1)
        logits = self.verifier(torch.abs(emb_a - emb_b))
        return self.loss(logits, labels)

    def get_verifier(self):
        """Return the internal verifier network."""
        return self.verifier
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss function.

    Based on:
    https://github.com/delijati/pytorch-siamese/blob/master/contrastive.py#L20

    :param margin: distance margin enforced between negative pairs.
    """
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def check_type_forward(self, in_types):
        """Sanity-check the (x0, x1, y) triple: 2-D inputs, 1-D labels, matching batch."""
        assert len(in_types) == 3
        first, second, labels = in_types
        assert first.size() == second.shape
        assert second.size()[0] == labels.shape[0]
        assert second.size()[0] > 0
        assert first.dim() == 2
        assert second.dim() == 2
        assert labels.dim() == 1

    def forward(self, data: Tuple[Tensor, ...]) -> Tensor:
        """
        :param data: ((x0, x1), y) with y == 1 for similar pairs, 0 for dissimilar.
        :return: scalar contrastive loss.
        """
        (x0, x1), y = data
        self.check_type_forward((x0, x1, y))
        y = y.to(x0.dtype)
        # Euclidean distance between the two embedding batches.
        pair_dist = self.dist(x0, x1, squared=False)
        squared_dist = pair_dist.pow(2)
        # Hinge on the margin for dissimilar pairs.
        hinge = torch.clamp(self.margin - pair_dist, min=0.0)
        per_pair = y * squared_dist + (1 - y) * hinge.pow(2)
        return torch.sum(per_pair) / 2.0 / x0.shape[0]

    def dist(self, x_0, x_1, eps=1e-8, squared=False):
        """
        Pairwise (squared) Euclidean distances between rows of x_0 and x_1.

        :param eps: added under the sqrt for numerical stability.
        :param squared: if True, return squared distances (no sqrt, no eps).
        """
        cross = x_0.mm(torch.t(x_1))
        sq_norm_0 = torch.diag(x_0.mm(torch.t(x_0))).view(x_0.shape[0], 1)
        sq_norm_1 = torch.diag(x_1.mm(torch.t(x_1))).view(1, x_1.shape[0])
        sq_dist = sq_norm_0 - 2 * cross + sq_norm_1
        if squared:
            return sq_dist
        # Clamp first: tiny negative values appear from floating-point cancellation.
        return torch.sqrt(torch.clamp(sq_dist, 0.0) + eps)
class TripletSoftMarginLoss(nn.Module):
    """
    Triplet loss with a soft margin: mean of log(1 + exp(d(a, p) - d(a, n))).

    Implemented with F.softplus, which equals log1p(exp(x)) exactly but does
    not overflow: the previous torch.log1p(torch.exp(x)) form returns inf in
    float32 once the distance gap exceeds ~88.

    :param p: order of the norm used for the pairwise distances.
    """
    def __init__(self, p: float=2.0):
        super(TripletSoftMarginLoss, self).__init__()
        self._p = p

    def forward(self, data: Tuple[Tensor, Tensor, Tensor]) -> Tensor:
        """
        :param data: (anchor, positive, negative) embedding batches.
        :return: scalar soft-margin triplet loss.
        """
        anchor, positive, negative = data
        ap = torch.norm(anchor - positive, dim=-1, p=self._p)
        an = torch.norm(anchor - negative, dim=-1, p=self._p)
        # softplus(x) == log(1 + exp(x)), computed without overflow
        return torch.mean(F.softplus(ap - an))
class TripletMarginLossV2(nn.TripletMarginLoss):
    """
    A quick wrapper for margin loss.

    Identical to nn.TripletMarginLoss except that forward takes the
    (anchor, positive, negative) triple packed in a single tuple, matching
    the calling convention of the other losses in this module.
    """
    def __init__(self, margin=1.0, p=2.0, eps=1e-06, swap=False,
                 size_average=None, reduce=None, reduction='mean'):
        super(TripletMarginLossV2, self).__init__(
            margin, p, eps, swap, size_average, reduce, reduction
        )

    def forward(self, data: Tuple[Tensor, Tensor, Tensor]) -> Tensor:
        """Unpack the triple and defer to the parent implementation."""
        anc, pos, neg = data
        return super().forward(anc, pos, neg)
class NPairLoss(nn.Module):
    """
    N-pair loss: cross-entropy over the anchor/positive similarity matrix
    (the matching pair sits on the diagonal), plus an L2 penalty on the
    un-normalized embeddings.

    :param reg_lambda: weight of the L2 embedding regularizer.
    """
    def __init__(self, reg_lambda: float=0.002):
        super(NPairLoss, self).__init__()
        self._reg_lambda = reg_lambda
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, data: Tuple[Tensor, Tensor]) -> Tensor:
        """
        :param data: (anchors, positives), each of shape (N, D).
        :return: scalar n-pair loss.
        """
        anchors, positives = data
        # (N, N); entry (i, j) = similarity of anchor i with positive j
        interaction = anchors.mm(torch.t(positives))
        # Row i's correct class is column i. torch.arange on the anchors'
        # device replaces the old numpy round-trip + manual .cuda() and
        # works on any device, not just CUDA.
        labels = torch.arange(len(anchors), device=anchors.device)
        reg_an = torch.mean(torch.sum(anchors * anchors, dim=-1))
        reg_pos = torch.mean(torch.sum(positives * positives, dim=-1))
        l2_reg = self._reg_lambda * 0.25 * (reg_an + reg_pos)
        return self.ce_loss(interaction, labels) + l2_reg
class AngularLoss(nn.Module):
    """
    Angular loss over (anchor, positive) batches; every other anchor and
    positive in the batch serves as a negative.

    Based on
    https://github.com/leeesangwon/PyTorch-Image-Retrieval/blob/public/losses.py

    :param alpha: angle bound in degrees.
    """
    def __init__(self, alpha = 45):
        super(AngularLoss, self).__init__()
        self._alpha = torch.from_numpy(np.deg2rad([alpha])).float()

    def forward(self, data: Tuple[Tensor, Tensor]) -> Tensor:
        """
        :param data: (anchors, positives), each of shape (N, D).
        :return: scalar angular loss.
        """
        anchors, positives = data
        # Move a copy of alpha to the inputs' device instead of mutating
        # self._alpha in place; also covers non-CUDA accelerators.
        alpha = self._alpha.to(anchors.device)
        # Normalize anchors and positives onto the unit sphere.
        anchors = F.normalize(anchors, dim=-1, p=2)
        positives = F.normalize(positives, dim=-1, p=2)
        n_pair = len(anchors)
        # Indices of all other samples for each row: (N, N - 1).
        # int64 instead of uint8 — uint8 wrapped around for batches > 255
        # and uint8 tensor indexing is deprecated.
        all_pairs = np.array(
            [[j for j in range(n_pair) if j != i] for i in range(n_pair)],
            dtype=np.int64
        )
        stack_an = torch.stack([anchors[all_pairs[i]] for i in range(n_pair)])    # (N, N - 1, D)
        stack_pos = torch.stack([positives[all_pairs[i]] for i in range(n_pair)]) # (N, N - 1, D)
        negatives = torch.cat((stack_an, stack_pos), dim=1)  # (N, 2 * (N - 1), D)
        anchors = torch.unsqueeze(anchors, dim=1)            # (N, 1, D)
        positives = torch.unsqueeze(positives, dim=1)        # (N, 1, D)
        angle_bound = torch.tan(alpha).pow(2)
        # (N, 1, 2 * (N - 1))
        interaction = 4. * angle_bound * torch.matmul((anchors + positives), negatives.transpose(1, 2)) \
            - 2. * (1. + angle_bound) * torch.matmul(anchors, positives.transpose(1, 2))
        # Stabilized log-sum-exp with the row max subtracted.
        with torch.no_grad():
            t = torch.max(interaction, dim=2)[0]
        interaction = torch.exp(interaction - t.unsqueeze(dim=1))
        interaction = torch.log(torch.exp(-t) + torch.sum(interaction, 2))
        loss = torch.mean(t + interaction)
        return loss
class NPairAngular(nn.Module):
    """
    Combining N-Pair loss and Angular loss as a weighted average.

    :param alpha: angle bound (degrees) passed to AngularLoss.
    :param reg_lambda: L2 regularizer weight passed to NPairLoss.
    :param angular_lambda: relative weight of the angular term.
    """
    def __init__(self, alpha = 45, reg_lambda = 0.002, angular_lambda = 2):
        super(NPairAngular, self).__init__()
        self._angular_loss = AngularLoss(alpha)
        self._npair_loss = NPairLoss(reg_lambda)
        self._angular_lambda = angular_lambda

    def forward(self, data: Tuple[Tensor, Tensor]) -> Tensor:
        """
        :param data: (anchors, positives), each of shape (N, D).
        :return: scalar combined loss.
        """
        anchors, positives = data
        # Both sub-losses' forward() takes a single (anchors, positives)
        # tuple; the old code passed two positional tensors, which raised
        # TypeError as soon as this loss was evaluated.
        pair = (anchors, positives)
        return (
            self._npair_loss(pair)
            + self._angular_lambda * self._angular_loss(pair)
        ) / (1 + self._angular_lambda)