# Source code for nntoolbox.losses.smooth

from torch.nn import Module
from torch import Tensor
from nntoolbox.utils import to_onehot
import torch.nn.functional as F
from typing import Optional


# Public API of this module.
__all__ = ['SmoothedCrossEntropy']


class SmoothedCrossEntropy(Module):
    """
    Drop-in replacement for cross entropy loss with label smoothing:

        loss(y_hat, y) = -sum_c p_c * log y_hat_c

    where p_c = 1 - epsilon if c = y and epsilon / (C - 1) otherwise

    Based on:

        http://openaccess.thecvf.com/content_CVPR_2019/papers/He_Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks_CVPR_2019_paper.pdf

    Note that deprecated arguments of CrossEntropyLoss are not included
    """

    def __init__(self, weight: Optional[Tensor]=None, reduction: str='mean', eps: float=0.1):
        """
        :param weight: optional 1D tensor of per-class rescaling weights (length C)
        :param reduction: 'mean', 'sum' or 'none'
        :param eps: smoothing factor epsilon
        :raises ValueError: if reduction is unknown or weight is not 1D
        """
        # Raise instead of assert so validation survives `python -O`.
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError("reduction must be 'mean', 'sum' or 'none', got %r" % (reduction,))
        if weight is not None and len(weight.shape) != 1:
            raise ValueError("weight must be a 1D tensor of per-class weights")
        super(SmoothedCrossEntropy, self).__init__()
        self.eps = eps
        self.weight = weight
        self.reduction = reduction

    def forward(self, output: Tensor, label: Tensor) -> Tensor:
        """
        :param output: Predicted class scores. (batch_size, C, *)
        :param label: The true label. (batch_size, *)
        :return: scalar loss for 'mean'/'sum' reduction, per-element loss
            of shape (batch_size, *) for 'none'
        :raises ValueError: if weight length does not match the class dim
        """
        if self.weight is not None and len(self.weight) != output.shape[1]:
            raise ValueError("weight must have one entry per class")
        smoothed_label = self.smooth_label(label, output.shape[1]).to(output.dtype)
        output = F.log_softmax(output, 1)
        loss = -output * smoothed_label
        if self.weight is not None:
            # Reshape weights to (1, C, 1, ...) so they broadcast over the
            # class dimension (dim 1) only.
            weight_shape = [1, self.weight.shape[0]] + [1] * (len(output.shape) - 2)
            loss = loss * self.weight.view(weight_shape)
        loss = loss.sum(1)
        if self.reduction == 'none':
            return loss
        elif self.reduction == 'mean':
            # NOTE(review): unlike nn.CrossEntropyLoss, 'mean' combined with
            # `weight` divides by the element count rather than by the sum of
            # the weights; kept as-is to preserve existing behavior.
            return loss.mean()
        else:
            return loss.sum()

    def smooth_label(self, label: Tensor, n_class: int) -> Tensor:
        """
        Smooth the label: 1 - eps on the true class, eps / (C - 1) elsewhere.

        :param label: (batch_size, *), integer class indices
        :param n_class: number of classes C of the output (must be >= 2)
        :return: (batch_size, C, *)
        """
        # F.one_hot places the class dimension last; move it to dim 1 to match
        # the (batch_size, C, *) layout expected by forward.
        label_oh = F.one_hot(label, n_class).movedim(-1, 1).float()
        return (1 - self.eps) * label_oh + self.eps / (n_class - 1) * (1 - label_oh)