Source code for nntoolbox.vision.components.nap

from torch import nn, Tensor
from typing import List, Tuple, Union, Optional


# Public API of this module: the names exported by ``from ... import *``.
__all__ = ['NeuralAbstractionPyramid']


class NeuralAbstractionPyramid(nn.Module):
    """
    Neural Abstraction Pyramid Module.

    Sharing weights both spatially (convolutions) and temporally (the same modules are reused at every
    timestep). At timestep t, the state of level l is updated as:

        a^t_l = norm(activation(f_l(a^{t - 1}_l) + g_l(a^{t - 1}_{l - 1}) + h_l(a^{t - 1}_{l + 1})))

    where f_l is the lateral connection of level l, g_l the forward connection coming from level l - 1 and
    h_l the backward connection coming from level l + 1. The g term is absent at the bottom level (l = 0)
    and the h term is absent at the top level (l = depth).

    If f_l, g_l and h_l are repeated for all layers (i.e. the same module instances are passed for every
    level), then we can also share weights across the depth dimension.

    (UNTESTED)

    References:

        Sven Behnke and Raúl Rojas. "Neural Abstraction Pyramid: A hierarchical image understanding architecture."
        http://page.mi.fu-berlin.de/rojas/1998/pyramid.pdf

        Sven Behnke. "Hierarchical Neural Networks for Image Interpretation."
        https://www.ais.uni-bonn.de/books/LNCS2766.pdf

        Sven Behnke. "Face Localization and Tracking in the Neural Abstraction Pyramid."
        https://www.ais.uni-bonn.de/behnke/papers/nca04.pdf
    """

    def __init__(
            self, lateral_connections: List[nn.Module], forward_connections: List[nn.Module],
            backward_connections: List[nn.Module], activation_function: nn.Module,
            normalization: nn.Module, duration: int
    ):
        """
        Note that here we assume the forward direction decreases the resolution (downsampling) and the
        backward direction reverses it (upsampling). The opposite convention also works, as long as all the
        connections follow it consistently.

        :param lateral_connections: consist of depth + 1 conv layers, each with output of same dimension as
            input. Aggregate information from a local neighborhood of the same resolution from the previous
            timestep.
        :param forward_connections: consist of depth downsampling conv layers. forward_connections[l] maps the
            state of level l to the shape of level l + 1, transforming information from a region of larger
            resolution (i.e. previous layer) from the previous timestep.
        :param backward_connections: consist of depth upsampling layers. backward_connections[l] maps the state
            of level l + 1 to the shape of level l, retrieving feedback from a region of smaller resolution
            (i.e. next layer) from the previous timestep.
        :param activation_function: activation applied to the summed inputs of every level.
        :param normalization: normalization applied after the activation. The same module is shared by all
            levels, so it must accept the state of every level (e.g. a BatchNorm2d would require every level
            to have the same number of channels).
        :param duration: default number of timesteps to process data
        """
        assert len(lateral_connections) - 1 == len(forward_connections) == len(backward_connections), \
            "expected depth + 1 lateral connections and depth forward and backward connections"
        super().__init__()
        self.depth = len(forward_connections)
        self.duration = duration
        self.lateral_connections = nn.ModuleList(lateral_connections)
        self.forward_connections = nn.ModuleList(forward_connections)
        self.backward_connections = nn.ModuleList(backward_connections)
        # A single activation + normalization pipeline, applied to every level at every timestep.
        self.activ_norm = nn.Sequential(activation_function, normalization)

    def forward(
            self, input: Tensor, return_all_states: bool=False, duration: Optional[int]=None
    ) -> Union[List[Tensor], Tuple[List[Tensor], List[List[Tensor]]]]:
        """
        Run the pyramid for a number of timesteps.

        :param input: input of the bottom level; the initial states of the other levels are derived from it
            by get_initial_states.
        :param return_all_states: whether to also return the states of all timesteps
        :param duration: number of timesteps to process data (defaults to self.duration)
        :return: the states of all levels after the last timestep. If return_all_states, a tuple of those
            states and the list of states at every timestep (entry 0 holds the initial states, so the list has
            duration + 1 entries).
        """
        if duration is None:
            duration = self.duration
        assert duration > 0, "duration must be positive"

        states = self.get_initial_states(input)
        # Only record the history when the caller asks for it, so this method does not keep every
        # timestep's states alive otherwise.
        all_states = [states] if return_all_states else None
        for _ in range(duration):
            new_states = []
            for l in range(self.depth + 1):
                # Every term reads `states` from the previous timestep, so all levels update synchronously.
                new_state = self.lateral_connections[l](states[l])
                if l > 0:
                    # input from level l - 1
                    new_state = new_state + self.forward_connections[l - 1](states[l - 1])
                if l < self.depth:
                    # feedback from level l + 1
                    new_state = new_state + self.backward_connections[l](states[l + 1])
                new_states.append(self.activ_norm(new_state))
            states = new_states
            if all_states is not None:
                all_states.append(states)

        if return_all_states:
            return states, all_states
        return states

    def get_initial_states(self, input: Tensor) -> List[Tensor]:
        """
        Compute the initial state of every level with a single pass through the forward connections.

        :param input: state of the bottom level
        :return: list of depth + 1 tensors; entry 0 is input itself and entry l is
            forward_connections[l - 1] applied to entry l - 1
        """
        states = [input]
        for layer in self.forward_connections:
            states.append(layer(states[-1]))
        return states