# Source code for nntoolbox.components.dndf

"""Deep Neural Decision Forest"""
from functools import partial
import torch
from torch import nn, Tensor
from .merge import Mean


__all__ = ['DNDFTree', 'DNDF']


class DNDFTree(nn.Module):
    """
    A single balanced soft decision tree, based on Deep Neural Decision Forest, but with the
    leaf predictions computed from the input (so they are trained end-to-end together with the
    split nodes) instead of being held as free parameters.

    One linear layer produces, for every sample, the logits of all ``2 ** tree_depth - 1``
    split nodes (first columns) followed by ``out_features`` logits for each of the
    ``2 ** tree_depth`` leaves (remaining columns). The probability of reaching each leaf is
    computed level by level (BFS order), reusing the previous level's probabilities (DP).

    :param in_features: size of each input sample
    :param out_features: size of each leaf prediction and of the tree's output
    :param tree_depth: number of split levels; the tree has ``2 ** tree_depth`` leaves
    :param output_activation: zero-argument factory for the module applied to the leaf logits,
        which are shaped (batch, out_features, n_leaves); the default is a softmax over the
        ``out_features`` axis, making every leaf a distribution over outputs

    References:

        Peter Kontschieder, Madalina Fiterau, Antonio Criminisi, Samuel Rota Bulo.
        "Deep Neural Decision Forests."
        https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.pdf
    """
    def __init__(
            self, in_features: int, out_features: int, tree_depth: int,
            output_activation=partial(nn.Softmax, dim=1)
    ):
        super().__init__()
        self.n_leaves = 2 ** tree_depth
        self.tree_depth = tree_depth
        self.out_features = out_features
        self.output_activation = output_activation()
        # n_leaves - 1 split logits + out_features logits per leaf.
        self.transform = nn.Linear(in_features, out_features * self.n_leaves + self.n_leaves - 1)

    def forward(self, input: Tensor) -> Tensor:
        """
        Compute the tree's prediction for a batch.

        :param input: 2-D tensor of shape (batch, in_features)
        :return: tensor of shape (batch, out_features): the sum over leaves of each leaf's
            prediction weighted by the probability of routing the sample to that leaf
        """
        features = self.transform(input)
        batch_size = features.shape[0]
        n_splits = self.n_leaves - 1

        # Split probabilities come from the first n_splits columns; the leaves use the
        # *remaining* columns. Slicing the leaves from column 0 would overlap the split
        # logits and leave the tail of the linear layer unused.
        decision_nodes = torch.sigmoid(features[:, :n_splits])
        leaf_logits = features[:, n_splits:]
        leaves = self.output_activation(
            leaf_logits.reshape(batch_size, self.out_features, self.n_leaves)
        )

        # routing[:, j] = probability of reaching node j of the current level. Each node's
        # left child gets p and its right child 1 - p; the stack + reshape interleaves them
        # so children of node j land at 2j and 2j + 1 (BFS order of the next level).
        routing = features.new_ones(batch_size, 1)
        for depth in range(self.tree_depth):
            width = 2 ** depth
            start = width - 1  # index of the first split node at this depth (BFS numbering)
            level = decision_nodes[:, start:start + width]
            routing = torch.stack((routing * level, routing * (1.0 - level)), dim=-1)
            routing = routing.reshape(batch_size, 2 * width)

        # Weight each leaf's prediction by its reach probability and sum over leaves.
        return (leaves * routing.unsqueeze(1)).sum(dim=-1)
class DNDF(Mean):
    """
    Deep Neural Decision Forest: an ensemble of ``n_trees`` independently parameterized
    ``DNDFTree`` modules, all built with the same configuration and handed to the ``Mean``
    base class (presumably averaging the trees' outputs -- see ``.merge.Mean``).

    :param in_features: size of each input sample, forwarded to every tree
    :param out_features: size of each tree's output, forwarded to every tree
    :param n_trees: number of trees in the forest
    :param tree_depth: depth of every tree, forwarded to every tree
    :param output_activation: zero-argument factory for each tree's leaf activation,
        forwarded to every tree

    References:

        Peter Kontschieder, Madalina Fiterau, Antonio Criminisi, Samuel Rota Bulo.
        "Deep Neural Decision Forests."
        https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.pdf
    """
    def __init__(
            self, in_features: int, out_features: int, n_trees: int, tree_depth: int,
            output_activation=partial(nn.Softmax, dim=1)
    ):
        forest = [
            DNDFTree(
                in_features=in_features,
                out_features=out_features,
                tree_depth=tree_depth,
                output_activation=output_activation,
            )
            for _ in range(n_trees)
        ]
        super().__init__(forest)