# Source code for nntoolbox.vision.components.pretrain

from torch import nn
from .layers import InputNormalization
from torchvision.models import resnet18, vgg16_bn
from typing import Optional


class PretrainedModel(nn.Sequential):
    """
    Wrap the feature layers of a pretrained model as a flat ``nn.Sequential``.

    based on https://github.com/chenyuntc/pytorch-book/blob/master/chapter8-%E9%A3%8E%E6%A0%BC%E8%BF%81%E7%A7%BB(Neural%20Style)/PackedVGG.py

    Each feature layer is registered as a child module named ``"layer_<index>"``.

    :param model: either an ``nn.Module`` instance, or a callable (such as a torchvision
        model constructor) that is invoked as ``model(pretrained=True)``.
    :param embedding_size: currently unused; kept only for backward compatibility.
    :param fine_tune: if False, ``requires_grad`` is set to False on every parameter of
        ``model`` (the whole model, not just the extracted layers).
    """

    def __init__(self, model=resnet18, embedding_size=128, fine_tune=False):
        super(PretrainedModel, self).__init__()
        if not isinstance(model, nn.Module):
            model = model(pretrained=True)

        if not fine_tune:
            for param in model.parameters():
                param.requires_grad = False

        if hasattr(model, "features"):
            # Models that expose their trunk as ``model.features`` (e.g. the VGG family).
            features = list(model.features)
        else:
            # Models without ``features`` (the default ``resnet18`` used to crash here with
            # AttributeError): keep every top-level child except the last one. For torchvision
            # ResNets the last child is the ``fc`` classifier head.
            features = list(model.children())[:-1]

        for ind, layer in enumerate(features):
            self.add_module("layer_" + str(ind), layer)
class FeatureExtractor(nn.Module):
    """
    Run inputs through the ``features`` layers of a (pretrained) model and return
    intermediate activations.

    based on https://github.com/chenyuntc/pytorch-book/blob/master/chapter8-%E9%A3%8E%E6%A0%BC%E8%BF%81%E7%A7%BB(Neural%20Style)/PackedVGG.py

    :param model: either an ``nn.Module`` exposing an iterable ``features`` attribute, or a
        callable (such as a torchvision model constructor) invoked as ``model(pretrained=True)``.
    :param mean: per-channel mean for input normalization (used only if ``std`` is also given).
    :param std: per-channel std for input normalization (used only if ``mean`` is also given).
    :param last_layer: if given, only ``features[0..last_layer]`` (inclusive) are kept.
    :param default_extracted_feature: index of the layer whose output ``forward`` returns when
        ``layers`` is not given; None means the last kept layer.
    :param fine_tune: if False, ``requires_grad`` is set to False on every parameter of ``model``.
    :param device: if given, the model and the normalization module are moved to this device.
    """

    def __init__(
        self, model, mean=None, std=None, last_layer=None,
        default_extracted_feature: Optional[int]=None, fine_tune=True, device=None
    ):
        super(FeatureExtractor, self).__init__()
        if mean is not None and std is not None:
            self._normalization = InputNormalization(mean=mean, std=std)
        else:
            self._normalization = nn.Identity()

        if not isinstance(model, nn.Module):
            model = model(pretrained=True)
        if device is not None:
            model.to(device)
            self._normalization.to(device)
        if not fine_tune:
            for param in model.parameters():
                param.requires_grad = False

        self.default_extracted_feature = default_extracted_feature
        features = list(model.features)
        if last_layer is not None:
            features = features[:last_layer + 1]
        # ModuleList registers the layers so .to()/state_dict()/replication see them.
        self._features = nn.ModuleList(features)

    def forward(self, input, layers=None):
        """
        Normalize ``input`` and run it through the feature layers.

        :param input: the model input.
        :param layers: None, a single layer index, or an iterable of layer indices.
        :return: if ``layers`` is None, the output of layer ``default_extracted_feature``
            (or of the last layer when that is None; the normalized input if there are no
            layers). Otherwise the outputs of the requested layers in increasing layer
            order, unwrapped to a single output when exactly one layer was collected.
        :raises IndexError: if ``layers`` is None and ``default_extracted_feature`` is not a
            valid index into the kept layers.
        """
        input = self._normalization(input)
        n_layers = len(self._features)

        if layers is None:
            target = self.default_extracted_feature
            if target is None:
                target = n_layers - 1
            elif not 0 <= target < n_layers:
                # Previously this silently returned an empty list instead of a tensor.
                raise IndexError(
                    "default_extracted_feature={} is out of range for {} feature layers".format(
                        target, n_layers
                    )
                )
            for ind, layer in enumerate(self._features):
                input = layer(input)
                if ind == target:
                    break
            return input

        if isinstance(layers, int):
            layers = [layers]
        wanted = set(layers)
        op = []
        if wanted:
            # Stop as soon as the deepest requested layer has been computed.
            last_wanted = max(wanted)
            for ind, layer in enumerate(self._features):
                input = layer(input)
                if ind in wanted:
                    op.append(input)
                    if ind >= last_wanted:
                        break
        if len(op) == 1:
            return op[0]
        return op
class FeatureExtractorSequential(nn.Sequential):
    """
    ``nn.Sequential`` variant of a feature extractor: child 0 is the input normalization,
    children 1..N are the ``features`` layers of the wrapped model.

    based on https://github.com/chenyuntc/pytorch-book/blob/master/chapter8-%E9%A3%8E%E6%A0%BC%E8%BF%81%E7%A7%BB(Neural%20Style)/PackedVGG.py

    :param model: either an ``nn.Module`` exposing an iterable ``features`` attribute, or a
        callable (such as a torchvision model constructor) invoked as ``model(pretrained=True)``.
    :param mean: per-channel mean for input normalization (used only if ``std`` is also given).
    :param std: per-channel std for input normalization (used only if ``mean`` is also given).
    :param last_layer: if given, only ``features[0..last_layer]`` (inclusive) are kept.
    :param default_extracted_feature: index (among the feature layers, not counting the
        normalization) whose output ``forward`` returns when ``layers`` is not given;
        None means the last feature layer.
    :param fine_tune: if False, ``requires_grad`` is set to False on every parameter of ``model``.
    """

    def __init__(
        self, model, mean=None, std=None, last_layer=None,
        default_extracted_feature: Optional[int]=None, fine_tune=True
    ):
        if mean is not None and std is not None:
            normalization = InputNormalization(mean=mean, std=std)
        else:
            normalization = nn.Identity()

        if not isinstance(model, nn.Module):
            model = model(pretrained=True)
        if not fine_tune:
            for param in model.parameters():
                param.requires_grad = False

        features = list(model.features)
        if last_layer is not None:
            features = features[:last_layer + 1]
        super(FeatureExtractorSequential, self).__init__(*([normalization] + features))
        self.default_extracted_feature = default_extracted_feature

    @property
    def _features(self):
        """The feature layers: every registered child after the normalization at index 0."""
        # Derived from the registered children on every access instead of being cached as a
        # plain list, so replaced layers (``seq[i] = m``) and nn.DataParallel replicas are
        # always the modules actually used.
        return list(self)[1:]

    def forward(self, input, layers=None):
        """
        Normalize ``input`` and run it through the feature layers.

        :param input: the model input.
        :param layers: None, a single feature-layer index, or an iterable of such indices.
        :return: if ``layers`` is None, the output of layer ``default_extracted_feature``
            (or of the last layer when that is None; the normalized input if there are no
            layers). Otherwise the outputs of the requested layers in increasing layer
            order, unwrapped to a single output when exactly one layer was collected.
        :raises IndexError: if ``layers`` is None and ``default_extracted_feature`` is not a
            valid feature-layer index.
        """
        children = list(self._modules.values())
        input = children[0](input)
        features = children[1:]
        n_layers = len(features)

        if layers is None:
            target = self.default_extracted_feature
            if target is None:
                target = n_layers - 1
            elif not 0 <= target < n_layers:
                # Previously this silently returned an empty list instead of a tensor.
                raise IndexError(
                    "default_extracted_feature={} is out of range for {} feature layers".format(
                        target, n_layers
                    )
                )
            for ind, layer in enumerate(features):
                input = layer(input)
                if ind == target:
                    break
            return input

        if isinstance(layers, int):
            layers = [layers]
        wanted = set(layers)
        op = []
        if wanted:
            # Stop as soon as the deepest requested layer has been computed.
            last_wanted = max(wanted)
            for ind, layer in enumerate(features):
                input = layer(input)
                if ind in wanted:
                    op.append(input)
                    if ind >= last_wanted:
                        break
        if len(op) == 1:
            return op[0]
        return op