Source code for nntoolbox.losses.moe
from torch import nn, Tensor
from typing import Tuple
__all__ = ['CompetitiveMOELoss']
[docs]class CompetitiveMOELoss(nn.Module):
"""
Encourage expert specialization:
loss(expert_op, expert_weight, target) = sum_e weight_e * base_loss(op_e, target)
Reference:
https://www.cs.toronto.edu/~hinton/absps/jjnh91.pdf
"""
def __init__(self, base_loss: nn.Module=nn.MSELoss(reduction='none')):
super(CompetitiveMOELoss, self).__init__()
self.base_loss = base_loss
setattr(self.base_loss, 'reduction', 'none')
[docs] def forward(self, experts: Tuple[Tensor, Tensor], targets: Tensor) -> Tensor:
"""
:param experts: expert_output: (batch_size, *, n_expert), expert_score: (batch_size, *, n_expert)
:param targets: (batch_size, *)
:return:
"""
expert_outputs, expert_scores = experts
targets = targets.unsqueeze(-1)
loss = self.base_loss(expert_outputs, targets) # (batch_size, *, n_expert)
return (loss * expert_scores).sum(-1)