1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
| # encoding: utf-8
|
|
| import torch
| from torch import nn
| from .batch_norm import get_norm
|
|
class Non_local(nn.Module):
    """Non-local (self-attention) block over spatial positions, with a residual.

    Dot-product variant: pairwise affinities between all N = h * w positions
    are computed as ``theta(x)^T phi(x)`` and scaled by ``1 / N`` (no
    softmax). The aggregated values ``g(x)`` are projected back to
    ``in_channels`` by ``W`` and added to the input.

    :param in_channels: number of channels of the input feature map.
    :param bn_norm: norm-layer spec forwarded unchanged to ``get_norm``.
    :param num_splits: forwarded unchanged to ``get_norm``; its meaning is
        defined there (NOTE(review): confirm against ``get_norm``).
    :param reduc_ratio: channel reduction factor of the embedding
        convolutions; the bottleneck width is
        ``max(1, in_channels // reduc_ratio)``.
    """

    def __init__(self, in_channels, bn_norm, num_splits, reduc_ratio=2):
        super().__init__()

        self.in_channels = in_channels
        # Bottleneck width of the theta/phi/g embeddings. A previous version
        # computed ``reduc_ratio // reduc_ratio``, which is always 1.
        # NOTE: checkpoints saved with that 1-channel layout have different
        # parameter shapes and will not load into this module.
        self.inter_channels = self._bottleneck_width(in_channels, reduc_ratio)

        # 1x1 "value" embedding.
        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        # Output projection back to in_channels, followed by a norm layer.
        self.W = nn.Sequential(
            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0),
            get_norm(bn_norm, self.in_channels, num_splits),
        )
        # Zero the norm layer's affine parameters. Assuming the norm scales
        # and shifts by these (as BatchNorm does), W(y) starts at 0 and the
        # block is initially an identity mapping.
        nn.init.constant_(self.W[1].weight, 0.0)
        nn.init.constant_(self.W[1].bias, 0.0)

        # 1x1 "query" and "key" embeddings.
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)

        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _bottleneck_width(in_channels, reduc_ratio):
        """Return the reduced channel count, never less than 1."""
        return max(1, in_channels // reduc_ratio)

    def forward(self, x):
        """Apply non-local attention over the spatial positions of ``x``.

        :param x: feature map of shape (b, c, h, w) with c == in_channels
        :return: tensor of the same shape as ``x`` (``W(y) + x``)
        """
        batch_size = x.size(0)

        # (b, inter, h, w) -> (b, N, inter), where N = h * w
        g_x = self.g(x).flatten(2).permute(0, 2, 1)
        theta_x = self.theta(x).flatten(2).permute(0, 2, 1)
        # Keys are left as (b, inter, N) so the matmul yields (b, N, N).
        phi_x = self.phi(x).flatten(2)

        f = torch.matmul(theta_x, phi_x)
        # Dot-product variant: normalize by the number of positions N.
        f_div_C = f / f.size(-1)

        y = torch.matmul(f_div_C, g_x)  # (b, N, inter)
        y = y.permute(0, 2, 1).reshape(batch_size, self.inter_channels, *x.shape[2:])
        return self.W(y) + x
|
|