# Source code for nntoolbox.components.regularization
import warnings
from typing import List, Optional, Union

import torch.nn.functional as F
from torch import nn
__all__ = ['DropConnect']
class DropConnect(nn.Module):
    """
    Randomly zero out individual weights of a wrapped module during training.

    Implementation based on fastai's WeightDropout (from course 2 v3 notebook)
    Reference:
    Li Wan, Matthew Zeiler, Sixin Zhang, Yann Le Cun, Rob Fergus. "Regularization of Neural Networks using DropConnect."
    http://yann.lecun.com/exdb/publis/pdf/wan-icml-13.pdf
    """
    def __init__(
            self, module: nn.Module, ps: Union[List[float], float] = 0.0,
            weight_names: Optional[List[str]] = None
    ):
        """
        :param module: the wrapped module whose weights will have DropConnect applied.
        :param ps: drop probability — either one float shared by all weights,
            or a list with exactly one float per entry of ``weight_names``.
        :param weight_names: names of the parameters of ``module`` to drop.
            Defaults to ``['weight']``.
        """
        # None-sentinel instead of a mutable default: the original signature
        # used weight_names=['weight'], which shares one list object across
        # every instantiation (classic Python pitfall).
        if weight_names is None:
            weight_names = ['weight']
        if isinstance(ps, list):
            assert len(ps) == len(weight_names), \
                "Need exactly one drop probability per weight name."
        else:
            # Broadcast the single probability to every named weight.
            ps = [ps for _ in range(len(weight_names))]
        super(DropConnect, self).__init__()
        self.module, self.ps, self.weight_names = module, ps, weight_names
        for weight, p in zip(self.weight_names, self.ps):
            w = getattr(self.module, weight)
            # Keep the trainable copy on this wrapper as "<name>_raw"; the
            # wrapped module's slot is recomputed from it before every forward.
            self.register_parameter(weight + "_raw", nn.Parameter(w.data))
            # training=False makes this dropout a no-op here; it just installs
            # the initial (undropped) weights into the wrapped module's slot.
            # NOTE(review): this stores a plain tensor in module._parameters,
            # as fastai's WeightDropout does — confirm this still plays well
            # with parameters()/state_dict() on the targeted PyTorch version.
            self.module._parameters[weight] = F.dropout(w, p=p, training=False)

    def _setweights(self):
        """Recompute the wrapped module's weights by dropping from the raw copies."""
        for weight, p in zip(self.weight_names, self.ps):
            raw_w = getattr(self, weight + "_raw")
            # Dropout only fires in training mode; in eval the raw weights pass through.
            self.module._parameters[weight] = F.dropout(raw_w, p=p, training=self.training)

    def forward(self, *inputs):
        """Refresh the (possibly dropped) weights, then delegate to the wrapped module."""
        self._setweights()
        with warnings.catch_warnings():
            # To avoid the warning that comes because the weights aren't flattened.
            warnings.simplefilter("ignore")
            return self.module.forward(*inputs)