Source code for nntoolbox.utils.transfer
"""Some utility functions for transfer learning"""
from torch.nn import Sequential, Module, AdaptiveAvgPool2d
from typing import Callable, Tuple
from torch import nn
__all__ = ['cut_sequential_model', 'cut_model']
[docs]def cut_sequential_model(
model: Sequential, sep: Callable[..., Module]=AdaptiveAvgPool2d
) -> Tuple[Sequential, Sequential]:
"""
Cut a sequential model at the first instance of layer type
:param model:
:param sep:
:return:
"""
cut_ind = next(i for i, o in enumerate(model.children()) if isinstance(o, sep))
return model[:cut_ind], model[cut_ind:]
[docs]def cut_model(
model: Sequential, sep: Callable[..., Module]=AdaptiveAvgPool2d
) -> Tuple[Sequential, Sequential]:
"""
Cut a non-sequential model at the first instance of layer type
:param model:
:param sep:
:return:
"""
modules = [model._modules[key] for key in model._modules]
cut_ind = [i for i in range(len(modules)) if isinstance(modules[i], sep)][0]
return nn.Sequential(*modules[:cut_ind]), nn.Sequential(*modules[cut_ind:])