Source code for nntoolbox.losses.losses
from torch.nn import MSELoss
import torch.nn.functional as F
from torch import Tensor, nn
import torch
from typing import List, Optional
__all__ = ['RMSELoss', 'LogSigmoidLoss', 'CombinedLoss']
[docs]class RMSELoss(MSELoss):
def __init__(self, reduction='mean', eps: float=1e-8):
super(RMSELoss, self).__init__(reduction=reduction)
self._eps = eps
[docs] def forward(self, input: Tensor, target: Tensor) -> Tensor:
return torch.sqrt(super().forward(input, target) + self._eps)
[docs]class LogSigmoidLoss(nn.Module):
[docs] def forward(self, input: Tensor) -> Tensor: return -F.logsigmoid(input).mean(0)
[docs]class CombinedLoss(nn.Module):
def __init__(self, losses: List[nn.Module], weights: Optional[List[float]]=None):
super(CombinedLoss, self).__init__()
if weights is not None: assert len(weights) == len(losses)
else: weights = [1.0 / len(losses) for _ in range(len(losses))]
self.losses = nn.ModuleList(losses)
self.weights = weights
[docs] def forward(self, input: Tensor, target: Tensor) -> Tensor:
losses = torch.stack([self.losses[i](input, target) * self.weights[i] for i in range(len(self.losses))], -1)
return torch.sum(losses)