Source code for nntoolbox.losses.pinball
from torch import nn, Tensor
import torch
__all__ = ['PinballLoss']
class PinballLoss(nn.Module):
    """
    Pinball loss for quantile regression:

    L_tau(y_true, y_pred) = max(tau * (y_true - y_pred), (tau - 1) * (y_true - y_pred))

    With error = y_true - y_pred, under-prediction (error > 0) is weighted by
    tau and over-prediction (error < 0) by (1 - tau).

    References:

        https://www.tensorflow.org/addons/api_docs/python/tfa/losses/PinballLoss

        Ingo Steinwart and Andreas Christmann, "Estimating conditional quantiles with the help of the pinball loss."
        https://projecteuclid.org/download/pdfview_1/euclid.bj/1297173840
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, tau: float=0.5, reduction: str='mean'):
        """
        :param tau: quantile level, strictly between 0 and 1
        :param reduction: 'mean' averages the elementwise losses, 'sum' adds
            them up, 'none' returns them unreduced
        :raises ValueError: if tau is outside (0, 1) or reduction is not one of
            'mean', 'sum', 'none'
        """
        super().__init__()
        # Explicit raises instead of `assert`: assertions are stripped under
        # `python -O`, which would let invalid settings through silently (an
        # unknown reduction would then quietly behave like 'none' in forward).
        if not 0.0 < tau < 1.0:
            raise ValueError(f"tau must lie strictly between 0 and 1, got {tau!r}")
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.tau, self.reduction = tau, reduction

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Compute the pinball loss between predictions and targets.

        :param input: predicted values
        :param target: true values; shape must be compatible with input for
            elementwise subtraction
        :return: a scalar tensor for 'mean' / 'sum' reduction, otherwise the
            elementwise losses with the shape of (target - input)
        """
        error = target - input
        # Stack both linear branches along a new leading dim and take the
        # elementwise maximum over it; [0] picks the values (not the indices).
        losses = torch.stack(
            [self.tau * error, (self.tau - 1.0) * error], dim=0
        ).max(dim=0)[0]

        if self.reduction == 'mean':
            return losses.mean()
        elif self.reduction == 'sum':
            return losses.sum()
        else:
            return losses