# Source code for nntoolbox.vision.losses.robust

"""More robust loss functions (UNTESTED)"""
from torch import nn, Tensor


__all__ = ['GeneralizedCharbonnierLoss', 'CharbonnierLoss', 'CharbonnierLossV2']


class GeneralizedCharbonnierLoss(nn.Module):
    """
    Generalized Charbonnier loss function, computed element-wise as

        l(input, target) = ((input - target)^2 + eps^2)^(alpha / 2)

    and then reduced according to ``reduction``.

    With alpha = 1 this is the Charbonnier loss sqrt((input - target)^2 + eps^2).
    With alpha = 2 it is the squared error shifted by eps^2.

    :param alpha: the smoothed squared residual is raised to the power alpha / 2.
    :param eps: smoothing constant added (squared) to the squared residual; a non-zero
        value keeps the base of the power strictly positive.
    :param reduction: how to reduce the element-wise loss. Use 'mean' (default) to
        average over all elements, 'sum' to add them up, or 'none' to return the
        element-wise loss tensor unchanged.
    :raises ValueError: if ``reduction`` is not one of 'mean', 'sum', 'none'.

    References:
        Deqing Sun et al. "Secrets of Optical Flow Estimation and Their Principles."
        http://cs.brown.edu/~dqsun/pubs/cvpr_2010_flow.pdf
    """

    def __init__(self, alpha: float = 1.0, eps: float = 1e-6, reduction: str = 'mean'):
        super(GeneralizedCharbonnierLoss, self).__init__()
        # Validate eagerly so a typo fails at construction rather than silently
        # falling through to the 'none' branch in forward.
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                f"reduction must be one of 'mean', 'sum', 'none'; got {reduction!r}"
            )
        self.alpha = alpha
        self.eps = eps
        self.reduction = reduction

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return the generalized Charbonnier loss between ``input`` and ``target``."""
        loss = ((input - target).pow(2) + self.eps ** 2).pow(self.alpha / 2)
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
class CharbonnierLoss(GeneralizedCharbonnierLoss):
    """
    Charbonnier loss function:

        l(input, target) = sqrt((input - target)^2 + eps^2)

    Implemented as GeneralizedCharbonnierLoss with alpha fixed to 1.0, so that
    (squared residual + eps^2) is raised to the power 1/2.

    :param eps: smoothing constant; defaults to 1e-3 here, which is larger than the
        base class default.

    References:
        Wei-Sheng Lai et al. "Fast and Accurate Image Super-Resolution with Deep
        Laplacian Pyramid Networks." https://arxiv.org/pdf/1710.01992.pdf
    """

    def __init__(self, eps: float = 1e-3):
        super().__init__(alpha=1.0, eps=eps)
class CharbonnierLossV2(nn.Module):
    """
    Charbonnier loss function, written directly with an explicit square root:

        l(input, target) = sqrt((input - target)^2 + eps^2)

    The element-wise values are averaged over all elements.

    :param eps: smoothing constant (default 1e-3).

    References:
        Wei-Sheng Lai et al. "Fast and Accurate Image Super-Resolution with Deep
        Laplacian Pyramid Networks." https://arxiv.org/pdf/1710.01992.pdf
    """

    def __init__(self, eps: float = 1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return the mean Charbonnier loss between ``input`` and ``target``."""
        residual = input - target
        # With eps != 0 the sqrt argument stays strictly positive, so the
        # gradient remains finite where residual == 0.
        smoothed = residual.pow(2) + self.eps ** 2
        return smoothed.sqrt().mean()