Source code for nntoolbox.hooks.hooks

"""
Implements an abstraction for PyTorch module hooks.
Adapted from FastAI.
"""
import torch
from torch.nn import Module
from typing import Callable, Any, List
from torch import Tensor
from functools import partial


# Public names exported on a star-import of this module.
__all__ = ['Hook', 'Hooks']


class Hook:
    """
    Register a single forward or backward hook on a module and keep its handle.

    The callback receives this ``Hook`` instance as its first argument (bound
    with ``functools.partial``), followed by whatever arguments the module
    passes to the registered hook; this lets the callback keep state on the
    ``Hook`` object. A ``Hook`` can also be used as a context manager, in which
    case the hook is removed when the ``with`` block exits.

    :param module: module to attach the hook to
    :param hook_func: callback invoked as ``hook_func(hook, *hook_args)``
    :param forward: register a forward hook if true, otherwise a backward hook
    """

    def __init__(
            self,
            module: Module,
            hook_func: Callable[['Hook', Module, Tensor, Tensor], Any],
            forward: bool = True
    ):
        bound_func = partial(hook_func, self)
        if forward:
            self.hook = module.register_forward_hook(bound_func)
        else:
            # NOTE(review): recent PyTorch releases deprecate
            # register_backward_hook in favour of register_full_backward_hook;
            # kept unchanged here to preserve the existing behaviour.
            self.hook = module.register_backward_hook(bound_func)
        self._removed = False

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.remove()

    def __del__(self):
        self.remove()

    def remove(self):
        """
        Detach the hook from its module.

        Safe to call more than once, and safe on an instance whose
        registration never completed (e.g. ``__init__`` raised), which matters
        because ``__del__`` always calls this method.
        """
        handle = getattr(self, 'hook', None)
        if handle is None or getattr(self, '_removed', False):
            return
        handle.remove()
        self._removed = True
class Hooks:
    """
    Manage a collection of hooks, one per module, as a single unit.

    Supports iteration, ``len()``, indexing, item replacement/deletion and use
    as a context manager (all hooks are removed when the ``with`` block exits).

    :param ms: modules to attach hooks to
    :param hook_fn: callback shared by every hook; see ``Hook`` for its signature
    :param forward: either a single flag applied to every module, or a
        list/tuple with one flag per module (true = forward hook,
        false = backward hook)
    :raises ValueError: if ``forward`` is a list/tuple whose length differs
        from the number of modules
    """

    def __init__(
            self,
            ms: List[Module],
            hook_fn: Callable[['Hook', Module, Tensor, Tensor], Any],
            forward=True
    ):
        ms = list(ms)
        if isinstance(forward, (list, tuple)):
            # zip() would silently drop the unmatched modules/flags.
            if len(forward) != len(ms):
                raise ValueError(
                    "forward must have one flag per module: got {} flags "
                    "for {} modules".format(len(forward), len(ms))
                )
        else:
            forward = [forward] * len(ms)

        # Built incrementally so that hooks which were already registered can
        # be detached again if a later registration fails.
        self.hooks = []
        try:
            for module, is_forward in zip(ms, forward):
                self.hooks.append(Hook(module, hook_fn, is_forward))
        except Exception:
            self.remove()
            raise

    def __iter__(self):
        return iter(self.hooks)

    def __len__(self) -> int:
        return len(self.hooks)

    def __getitem__(self, i: int) -> 'Hook':
        return self.hooks[i]

    def __setitem__(self, i: int, hook: 'Hook'):
        previous = self.hooks[i]
        # Mirror __delitem__: a hook that leaves the collection is detached,
        # otherwise it would stay registered with nothing left to remove it.
        if previous is not hook:
            previous.remove()
        self.hooks[i] = hook

    def __delitem__(self, i: int):
        self.hooks[i].remove()
        del self.hooks[i]

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()

    def __del__(self):
        self.remove()

    def remove(self):
        """
        Detach every hook in the collection.

        Tolerates an instance whose ``hooks`` attribute was never set, since
        ``__del__`` calls this even when ``__init__`` did not complete.
        """
        for hook in getattr(self, 'hooks', ()):
            hook.remove()