from typing import Sequence, Tuple, Dict, Any
from .callbacks import Callback
# from ..vision.learner import SupervisedImageLearner
import numpy as np
from torch import Tensor, nn
import torch
__all__ = ['ManifoldMixupCallback', 'ManifoldMixupModule']
class ManifoldMixupModule(nn.Module):
    """
    Wrapper module that applies manifold mixup to its input when this layer is selected for mixing.
    """
    def __init__(self, base_module: nn.Module):
        super().__init__()
        self._base_module = base_module
        self._mixup = None
        self.is_mixing = False

    def forward(self, input):
        # Mix only during training, and only when this layer was picked by the callback
        if self.training and self._mixup is not None and self.is_mixing:
            input = self._mixup.transform_input(input)
        return self._base_module(input)
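
# Illustrative usage (not part of the library): wrap the layers that should be eligible
# for mixing in ManifoldMixupModule when building the model; the callback below picks one
# of them per batch. The architecture and names here are hypothetical.
#
#     model = nn.Sequential(
#         nn.Linear(784, 256), nn.ReLU(),
#         ManifoldMixupModule(nn.Linear(256, 128)), nn.ReLU(),
#         ManifoldMixupModule(nn.Linear(128, 64)), nn.ReLU(),
#         nn.Linear(64, 10)
#     )
#     mixup_modules = [m for m in model.modules() if isinstance(m, ManifoldMixupModule)]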
class ManifoldMixupCallback(Callback):
    """
    Implement manifold mixup regularization as a callback. On each training iteration, pick one of
    the wrapped layers at random and mix its hidden representation (and the corresponding labels)
    with those of a shuffled copy of the batch:

        x = lambda * x_1 + (1 - lambda) * x_2
        y = lambda * y_1 + (1 - lambda) * y_2

    Reference: Verma et al., "Manifold Mixup: Better Representations by Interpolating Hidden States."
    https://arxiv.org/pdf/1806.05236.pdf

    Based on the fastai mixup implementation:
    https://github.com/fastai/fastai/blob/master/fastai/callbacks/mixup.py
    """
    def __init__(self, learner, modules: Sequence[ManifoldMixupModule], alpha: float = 2.0):
        for module in modules:
            module._mixup = self
        self._modules = modules
        self._alpha = alpha
        self._learner = learner
        self._old_loss = learner._criterion
        self._old_reduction = self._old_loss.reduction
        self._shuffle = None
        self._lambd = None
    def on_batch_begin(self, data: Dict[str, Any], train: bool) -> Dict[str, Any]:
        if train:
            # Swap in a mixup-aware criterion for this batch
            self._learner._criterion = self.transform_loss(self._learner._criterion, True)
            labels = data["labels"]
            # Pair each example with a random partner from the same batch
            self._shuffle = torch.randperm(labels.size(0)).to(labels.device)
            self._lambd = self.get_lambd(labels.size(0), labels.device)
            data["labels"] = self.transform_labels(data["labels"])
            # Pick exactly one wrapped layer to perform the mixing this iteration
            mix_ind = np.random.choice(len(self._modules))
            for ind, module in enumerate(self._modules):
                module.is_mixing = (ind == mix_ind)
        return data
    def on_batch_end(self, logs: Dict[str, Any]):
        # Restore the original (unwrapped) criterion after each batch
        self._learner._criterion = self.transform_loss(self._learner._criterion, False)

    def on_train_end(self):
        self.deregister()

    def deregister(self):
        # Detach this callback from the wrapped modules
        for module in self._modules:
            module._mixup = None
    def get_lambd(self, batch_size, device):
        # Sample one mixing coefficient per example from Beta(alpha, alpha), then take the
        # elementwise max of lambd and 1 - lambd so the coefficient is always >= 0.5 and
        # the unshuffled example dominates the mix
        lambd = np.random.beta(self._alpha, self._alpha, batch_size)
        lambd = np.concatenate([lambd[:, None], 1 - lambd[:, None]], 1).max(1)
        return torch.from_numpy(lambd).float().to(device)
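
# The transform_input, transform_labels and transform_loss helpers referenced above are
# not shown in this section. A minimal sketch of what they plausibly look like, modelled
# on the fastai mixup callback this class is based on; the real implementations may differ
# (e.g. in how non-classification targets or loss reductions are handled):
#
#     def transform_input(self, input: Tensor) -> Tensor:
#         # Mix each hidden state with its shuffled partner; broadcast lambd over
#         # all non-batch dimensions
#         lambd = self._lambd.view(-1, *([1] * (input.dim() - 1)))
#         return lambd * input + (1 - lambd) * input[self._shuffle]
#
#     def transform_labels(self, labels: Tensor) -> Tensor:
#         # Keep original labels, partner labels and the coefficients side by side,
#         # so the wrapped loss can mix the two per-example loss terms
#         return torch.stack(
#             [labels.float(), labels[self._shuffle].float(), self._lambd], dim=1
#         )
#
#     def transform_loss(self, criterion, train: bool):
#         # When train is True, return a wrapper that computes the mixed loss; when
#         # False, restore the original criterion saved in __init__
#         if not train:
#             return self._old_loss
#         def mixed_loss(pred, target):
#             y1, y2, lambd = target[:, 0].long(), target[:, 1].long(), target[:, 2]
#             criterion.reduction = 'none'
#             loss = lambd * criterion(pred, y1) + (1 - lambd) * criterion(pred, y2)
#             criterion.reduction = self._old_reduction
#             return loss.mean()
#         return mixed_loss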