import torch
from torch import nn, Tensor
import torch.nn.functional as F
import numpy as np
__all__ = ['SAGANAttention', 'StandAloneSelfAttention', 'StandAloneMultiheadAttention']


class SAGANAttention(nn.Module):
    """
    Implements the SAGAN attention module.

    References:
        Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena.
        "Self-Attention Generative Adversarial Networks."
        https://arxiv.org/pdf/1805.08318.pdf
    """
    def __init__(self, in_channels: int, reduction_ratio: int = 8):
        assert in_channels % reduction_ratio == 0
        super().__init__()
        # Single 1x1 convolution producing the key, query and value feature maps.
        self.transform = nn.Conv2d(
            in_channels=in_channels,
            out_channels=(in_channels // reduction_ratio) * 3,
            kernel_size=1, bias=False
        )
        self.softmax = nn.Softmax(dim=1)
        # Projects the attended features back to the input channel count.
        self.op_transform = nn.Conv2d(
            in_channels=in_channels // reduction_ratio,
            out_channels=in_channels,
            kernel_size=1, bias=False
        )
        # Learnable residual gate, initialised to zero as in the paper.
        self.scale = nn.Parameter(torch.zeros(1), requires_grad=True)
    def forward(self, input: Tensor) -> Tensor:
        batch_size, _, h, w = input.shape
        transformed = self.transform(input)
        key, query, value = transformed.chunk(3, 1)
        # (batch, h*w, c') x (batch, c', h*w) -> (batch, h*w, h*w) attention scores.
        attention_scores = key.view((batch_size, -1, h * w)).permute(0, 2, 1).bmm(
            query.view((batch_size, -1, h * w))
        )
        attention_weights = self.softmax(attention_scores)
        output = value.view(batch_size, value.shape[1], -1).bmm(attention_weights)
        output = output.view(batch_size, output.shape[1], h, w)
        output = self.op_transform(output)
        # Gated residual connection: the module starts out as the identity mapping.
        return self.scale * output + input
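
# Minimal usage sketch (illustrative values, not part of the original module):
# SAGANAttention is a drop-in layer for 4-D feature maps and preserves the
# input shape. Because `scale` is initialised to zero, the layer initially
# behaves as the identity.
#
#     attn = SAGANAttention(in_channels=64, reduction_ratio=8)
#     x = torch.randn(2, 64, 32, 32)
#     out = attn(x)   # shape (2, 64, 32, 32)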


class StandAloneSelfAttention(nn.Conv2d):
    """
    A single head of Stand-Alone Self-Attention for vision models.

    References:
        Prajit Ramachandran, Niki Parmar, Ashish Vaswani, Irwan Bello, Anselm Levskaya, Jonathon Shlens.
        "Stand-Alone Self-Attention in Vision Models." https://arxiv.org/pdf/1906.05909.pdf
    """
    def __init__(
        self, in_channels: int, out_channels: int, kernel_size, stride=1,
        padding: int = 0, dilation: int = 1,
        bias: bool = True, padding_mode: str = 'zeros'
    ):
        assert out_channels % 2 == 0
        super(StandAloneSelfAttention, self).__init__(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, 1,
            bias, padding_mode
        )
        # The convolution kernel is not used; attention replaces it entirely.
        self.weight = None
        self.bias = None
        # 1x1 convolution producing the key, query and value feature maps.
        self.transform = nn.Conv2d(in_channels, out_channels * 3, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)
        # Relative positional embeddings, split between row and column offsets.
        self.rel_h = nn.Embedding(num_embeddings=self.kernel_size[0], embedding_dim=out_channels // 2)
        self.rel_w = nn.Embedding(num_embeddings=self.kernel_size[1], embedding_dim=out_channels // 2)
        self.h_range = nn.Parameter(torch.arange(0, self.kernel_size[0])[:, None], requires_grad=False)
        self.w_range = nn.Parameter(torch.arange(0, self.kernel_size[1])[None, :], requires_grad=False)
    def forward(self, input: Tensor) -> Tensor:
        batch_size, _, inp_h, inp_w = input.shape
        output_h, output_w = self.compute_output_shape(inp_h, inp_w)
        if self.padding_mode == 'circular':
            expanded_padding = [(self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2]
            input = F.pad(input, expanded_padding, mode='circular')
            padding = 0
        else:
            padding = self.padding
        transformed = self.transform(input)
        key, query, value = transformed.chunk(3, 1)
        # Key: keep only the centre pixel of each extracted patch.
        key_uf = F.unfold(
            key, kernel_size=self.kernel_size, dilation=self.dilation,
            padding=padding, stride=self.stride
        ).view(
            batch_size, self.out_channels, self.kernel_size[0], self.kernel_size[1], -1
        )[:, :, self.kernel_size[0] // 2, self.kernel_size[1] // 2, :]
        # Query and value: the full (kernel_h * kernel_w) neighbourhood of each patch.
        query_uf = F.unfold(
            query, kernel_size=self.kernel_size, dilation=self.dilation,
            padding=padding, stride=self.stride
        ).view(batch_size, self.out_channels, self.kernel_size[0] * self.kernel_size[1], -1)
        value_uf = F.unfold(
            value, kernel_size=self.kernel_size, dilation=self.dilation,
            padding=padding, stride=self.stride
        ).view(batch_size, self.out_channels, self.kernel_size[0] * self.kernel_size[1], -1)
        # Relative positional embedding, added to the neighbourhood queries.
        rel_embedding = self.get_rel_embedding()[None, :, :, None]
        logits = (key_uf[:, :, None, :] * (query_uf + rel_embedding)).sum(1, keepdim=True)
        attention_weights = self.softmax(logits)
        output = (attention_weights * value_uf).sum(2).view(batch_size, -1, output_h, output_w)
        return output
    def compute_output_shape(self, height, width):
        def compute_shape_helper(inp_dim, padding, kernel_size, dilation, stride):
            return np.floor(
                (inp_dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
            ).astype(np.uint32)
        return (
            compute_shape_helper(height, self.padding[0], self.kernel_size[0], self.dilation[0], self.stride[0]),
            compute_shape_helper(width, self.padding[1], self.kernel_size[1], self.dilation[1], self.stride[1]),
        )
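    # Worked example (illustrative values): for height 16, padding 1, kernel 3,
    # dilation 1 and stride 1 this gives floor((16 + 2*1 - 1*(3 - 1) - 1) / 1 + 1) = 16,
    # i.e. the standard convolution output-size arithmetic.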
    def get_rel_embedding(self) -> Tensor:
        h_embedding = self.rel_h(self.h_range).repeat(1, self.kernel_size[1], 1)
        w_embedding = self.rel_w(self.w_range).repeat(self.kernel_size[0], 1, 1)
        return torch.cat((h_embedding, w_embedding), dim=-1).view(-1, self.out_channels).transpose(0, 1)
    def to(self, *args, **kwargs):
        # h_range and w_range are registered parameters, so the base method already
        # moves them; the override only makes sure the module itself is returned.
        return super().to(*args, **kwargs)
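
# Minimal usage sketch (illustrative values, not part of the original module):
# the layer is constructed with convolution-style arguments and follows the
# usual conv output-shape arithmetic, so it can replace an nn.Conv2d.
#
#     attn = StandAloneSelfAttention(in_channels=32, out_channels=64,
#                                    kernel_size=3, stride=1, padding=1)
#     x = torch.randn(2, 32, 16, 16)
#     out = attn(x)   # shape (2, 64, 16, 16)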


class StandAloneMultiheadAttention(nn.Module):
    """
    Stand-Alone Multihead Self-Attention for vision models.

    References:
        Prajit Ramachandran, Niki Parmar, Ashish Vaswani, Irwan Bello, Anselm Levskaya, Jonathon Shlens.
        "Stand-Alone Self-Attention in Vision Models." https://arxiv.org/pdf/1906.05909.pdf
    """
    def __init__(
        self, num_heads: int, in_channels: int, out_channels: int, kernel_size, stride=1,
        padding: int = 0, dilation: int = 1,
        bias: bool = True, padding_mode: str = 'zeros'
    ):
        assert out_channels % num_heads == 0
        super(StandAloneMultiheadAttention, self).__init__()
        # Each head attends over out_channels // num_heads channels;
        # the head outputs are concatenated along the channel dimension.
        self.heads = nn.ModuleList(
            [
                StandAloneSelfAttention(
                    in_channels, out_channels // num_heads, kernel_size, stride,
                    padding, dilation, bias, padding_mode
                )
                for _ in range(num_heads)
            ]
        )
    def forward(self, input: Tensor) -> Tensor:
        heads = [head(input) for head in self.heads]
        return torch.cat(heads, dim=1)
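

if __name__ == '__main__':
    # Hedged smoke test (illustrative shapes only, not part of the original module):
    # exercises all three layers on a small random feature map.
    x = torch.randn(2, 16, 8, 8)

    sagan = SAGANAttention(in_channels=16, reduction_ratio=8)
    assert sagan(x).shape == x.shape

    single = StandAloneSelfAttention(in_channels=16, out_channels=32, kernel_size=3, padding=1)
    assert single(x).shape == (2, 32, 8, 8)

    multi = StandAloneMultiheadAttention(num_heads=4, in_channels=16, out_channels=32,
                                         kernel_size=3, padding=1)
    assert multi(x).shape == (2, 32, 8, 8)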