Source code for nntoolbox.vision.components.stn

import torch
from torch import nn
import torch.nn.functional as F


[docs]class SpatialTransformerModule(nn.Module): ''' Implement a spatial transformer module Adapt from https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html https://papers.nips.cc/paper/5854-spatial-transformer-networks.pdf ''' def __init__(self): super(SpatialTransformerModule, self).__init__() self.localization = nn.Sequential( nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2, stride=2), nn.ReLU(True), nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2, stride=2), nn.ReLU(True) ) # regress transformation parameters self.fc_loc = nn.Sequential( nn.Linear(10 * 3 * 3, 32), nn.ReLU(True), nn.Linear(32, 3 * 2) ) # regressor for 3 x 2 affine matrix # Initialize the weights/bias with identity transformation self.fc_loc[2].weight.data.zero_() self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
[docs] def forward(self, input): inputs = self.localization(input) inputs = inputs.view(-1, 10 * 3 * 3) theta = self.fc_loc(inputs) theta = theta.view(-1, 2, 3) grid = F.affine_grid(theta, input.size()) return F.grid_sample(input, grid)