nntoolbox.losses.moe module

class nntoolbox.losses.moe.CompetitiveMOELoss(base_loss: torch.nn.modules.module.Module = MSELoss())

Bases: torch.nn.modules.module.Module

Competitive mixture-of-experts (MOE) loss that encourages expert specialization:

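loss(expert_op, expert_weight, target) = sum_e weight_e * base_loss(op_e, target), where op_e is the output of expert e, weight_e is the gating weight assigned to expert e, and the sum runs over all n_expert experts.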
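To make the formula concrete, here is a minimal plain-Python sketch for a single scalar target. The names competitive_moe_loss and squared_error are hypothetical and not part of nntoolbox; squared_error stands in for the MSELoss default, and the gating weights are assumed to be given already.

```python
from typing import Callable, Sequence


def squared_error(pred: float, target: float) -> float:
    # Hypothetical stand-in for the MSELoss default, for one scalar prediction.
    return (pred - target) ** 2


def competitive_moe_loss(
    expert_outputs: Sequence[float],
    expert_weights: Sequence[float],
    target: float,
    base_loss: Callable[[float, float], float] = squared_error,
) -> float:
    # sum_e weight_e * base_loss(op_e, target): each expert is scored on its
    # own output, and the per-expert errors are then weighted by the gate.
    return sum(w * base_loss(op, target) for op, w in zip(expert_outputs, expert_weights))


# Two experts: the first matches the target exactly, the second is off by 2.
loss = competitive_moe_loss([2.0, 4.0], [0.75, 0.25], target=2.0)
print(loss)  # 1.0
```

Blending the outputs first would give (0.75 * 2.0 + 0.25 * 4.0 - 2.0) ** 2 = 0.25; the competitive form reports 1.0 instead, because the poorly fitting expert is penalized in proportion to its weight rather than being averaged away. Since the derivative of the loss with respect to weight_e is base_loss(op_e, target), lowering the loss pushes the gate toward whichever expert has the smallest error.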

Reference:

Jacobs, Jordan, Nowlan and Hinton (1991), "Adaptive Mixtures of Local Experts": https://www.cs.toronto.edu/~hinton/absps/jjnh91.pdf

forward(experts: Tuple[torch.Tensor, torch.Tensor], targets: torch.Tensor) → torch.Tensor
Parameters
  • experts – a tuple (expert_output, expert_score), where expert_output has shape (batch_size, *, n_expert) and expert_score has shape (batch_size, *, n_expert)

  • targets – the target tensor, of shape (batch_size, *)

Returns

The competitive MOE loss, as a torch.Tensor.
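The forward signature can be sketched as a standalone PyTorch module. This is a hypothetical re-implementation (CompetitiveMOELossSketch is not part of nntoolbox) under two assumptions the docstring does not state: expert_score holds unnormalized gating scores that are turned into weights with a softmax over the last (expert) axis, and base_loss is elementwise (reduction="none") so that each expert's error survives until the weighted sum. Whether the library's default MSELoss() is applied this way is not specified here.

```python
from typing import Optional, Tuple

import torch
from torch import nn


class CompetitiveMOELossSketch(nn.Module):
    """Hypothetical sketch of a competitive MOE loss; not the nntoolbox source."""

    def __init__(self, base_loss: Optional[nn.Module] = None):
        super().__init__()
        # Assumption: base_loss must not reduce, so per-expert errors are kept.
        self.base_loss = base_loss if base_loss is not None else nn.MSELoss(reduction="none")

    def forward(self, experts: Tuple[torch.Tensor, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        expert_output, expert_score = experts  # each (batch_size, *, n_expert)
        # Assumption: gating weights are a softmax of the scores over experts.
        weights = torch.softmax(expert_score, dim=-1)
        # Repeat targets (batch_size, *) along a new expert axis to match the outputs.
        expanded = targets.unsqueeze(-1).expand_as(expert_output)
        per_expert = self.base_loss(expert_output, expanded)
        # Weighted sum over experts, then the mean over batch and remaining dims.
        return (weights * per_expert).sum(dim=-1).mean()


torch.manual_seed(0)
target = torch.randn(4, 3)
# Expert 0 reproduces the target; expert 1 is off by 1 everywhere.
out = torch.stack([target, target + 1.0], dim=-1)  # (4, 3, 2)
loss_fn = CompetitiveMOELossSketch()
equal_gate = loss_fn((out, torch.zeros(4, 3, 2)), target)
```

With equal scores each expert gets weight 0.5, so the loss is about 0.5 * 0 + 0.5 * 1 = 0.5; a gate that strongly favors expert 0 drives the loss toward 0.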
