| | |
| | |
| | |
| | import torch |
| | import torch.nn.functional as F |
| | import torch.nn as nn |
| |
|
| | from functools import partial |
| |
|
| | |
| | |
| | |
class LocalAffinity(nn.Module):
    """Differences between each pixel and its 8 neighbours, at several dilations.

    ``forward`` takes a 4-D tensor ``x`` of shape (B, K, H, W), treats each of
    the K channels independently, and returns a tensor of shape
    (B, K, P * len(dilations), H, W), where P is the number of kernels built by
    :meth:`_init_aff` (8 here: ``x[centre] - x[neighbour]`` for each neighbour
    of the 3x3 window). Image borders are handled with replicate padding.
    Subclasses change the behaviour by overriding :meth:`_init_aff`.
    """

    def __init__(self, dilations=(1,)):
        """
        Args:
            dilations: iterable of int dilation rates; one group of outputs per
                rate, concatenated in order along dim 2. Stored as a fresh list
                so instances never share a mutable default or alias the
                caller's list.
        """
        super(LocalAffinity, self).__init__()
        self.dilations = list(dilations)
        weight = self._init_aff()
        # A buffer follows .to()/.cuda()/.half() and is saved in state_dict.
        self.register_buffer('kernel', weight)

    def _init_aff(self):
        """Return the (8, 1, 3, 3) centre-minus-neighbour kernel bank.

        Kernel i has +1 on the centre tap and -1 on the i-th neighbour, with
        neighbours enumerated row-major and the centre skipped. A pristine copy
        is kept in ``self.weight_check`` for the sanity check in ``forward``.
        """
        weight = torch.zeros(8, 1, 3, 3)
        weight[:, 0, 1, 1] = 1  # centre tap of every kernel
        neighbours = [(r, c) for r in range(3) for c in range(3) if (r, c) != (1, 1)]
        for i, (r, c) in enumerate(neighbours):
            weight[i, 0, r, c] = -1

        self.weight_check = weight.clone()
        return weight

    def forward(self, x):
        """Apply the kernel bank to every channel of ``x`` at every dilation.

        Args:
            x: 4-D tensor (B, K, H, W); it may be non-contiguous.

        Returns:
            Tensor of shape (B, K, P * len(dilations), H, W).
        """
        # Sanity check that the kernel buffer was not altered (e.g. by loading a
        # mismatched state_dict). The reference copy is moved to the kernel's
        # own dtype/device -- the kernel is what conv2d uses -- so the
        # comparison never mixes devices.
        self.weight_check = self.weight_check.to(self.kernel)
        assert torch.equal(self.weight_check, self.kernel)

        B, K, H, W = x.size()
        # Use reshape, not view: it also accepts non-contiguous inputs (e.g. permuted images).
        x = x.reshape(B * K, 1, H, W)

        x_affs = []
        for d in self.dilations:
            # Pad by d on every side so the dilated 3x3 conv keeps H x W.
            x_pad = F.pad(x, [d] * 4, mode='replicate')
            x_affs.append(F.conv2d(x_pad, self.kernel, dilation=d))

        x_aff = torch.cat(x_affs, 1)
        return x_aff.view(B, K, -1, H, W)
| |
|
class LocalAffinityCopy(LocalAffinity):
    """Gathers the raw values of each pixel's 8 neighbours.

    Identical to LocalAffinity except that the kernels are one-hot with no
    centre term, so the inherited ``forward`` yields ``x[neighbour]`` rather
    than ``x[centre] - x[neighbour]``.
    """

    def _init_aff(self):
        """Return (8, 1, 3, 3) one-hot kernels, one per neighbour tap.

        Taps are enumerated row-major with the centre skipped; a copy of the
        result is kept in ``self.weight_check``.
        """
        kernels = torch.zeros(8, 1, 3, 3)
        taps = ((row, col) for row in range(3) for col in range(3) if (row, col) != (1, 1))
        for slot, (row, col) in enumerate(taps):
            kernels[slot, 0, row, col] = 1
        self.weight_check = kernels.clone()
        return kernels
| |
|
class LocalStDev(LocalAffinity):
    """Local standard deviation over each pixel's 3x3 window(s).

    ``_init_aff`` builds 9 one-hot kernels (centre included), so the inherited
    ``forward`` gathers the raw window values. This class's ``forward`` then
    reduces them with ``std`` over dim 2, jointly across all dilations.
    """

    def _init_aff(self):
        """Return (9, 1, 3, 3) one-hot kernels, one per tap of the 3x3 window.

        Row i of ``eye(9)`` has its single 1 at flat index i, i.e. at tap
        (i // 3, i % 3) in row-major order. A copy is kept in
        ``self.weight_check``.
        """
        weight = torch.eye(9).view(9, 1, 3, 3)
        self.weight_check = weight.clone()
        return weight

    def forward(self, x):
        """Return the local std of ``x``, shape (B, K, 1, H, W).

        Uses torch's default (unbiased) std estimator over the
        9 * len(dilations) gathered window values.
        """
        x = super(LocalStDev, self).forward(x)
        return x.std(2, keepdim=True)
| |
|
class LocalAffinityAbs(LocalAffinity):
    """LocalAffinity followed by an element-wise absolute value.

    Yields ``|x[centre] - x[neighbour]|`` with the same shape as the parent's
    output.
    """

    def forward(self, x):
        affinities = super().forward(x)
        return affinities.abs()
| |
|
| | |
| | |
| | |
class PAMR(nn.Module):
    """Pixel-Adaptive Mask Refinement.

    Iteratively replaces each mask value with a weighted average of its
    neighbours' values. The neighbour weights come from how similar the
    corresponding pixels of ``x`` are to the centre pixel, normalised by the
    local standard deviation of ``x``.
    """

    def __init__(self, num_iter=1, dilations=(1,)):
        """
        Args:
            num_iter: number of propagation steps applied in ``forward``.
            dilations: iterable of int neighbourhood dilation rates, passed to
                all three affinity helpers. Defaults to an immutable tuple so
                instances never share a mutable default.
        """
        super(PAMR, self).__init__()

        self.num_iter = num_iter
        self.aff_x = LocalAffinityAbs(dilations)   # |x_centre - x_neighbour|
        self.aff_m = LocalAffinityCopy(dilations)  # gathers neighbour values
        self.aff_std = LocalStDev(dilations)       # local std of x

    def forward(self, x, mask):
        """Refine ``mask`` using the local structure of ``x``.

        Args:
            x: 4-D tensor, presumably (B, K, H, W) image or features.
            mask: 4-D tensor (B, C, h, w); bilinearly resized to x's (H, W) first.

        Returns:
            Refined mask of shape (B, C, H, W).
        """
        mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)

        # Local spread of x: (B, K, 1, H, W).
        x_std = self.aff_std(x)

        # Affinity logits: close to 0 for neighbours similar to the centre and
        # increasingly negative for dissimilar ones. The 1e-8 guards against
        # division by zero in flat regions where the local std is 0.
        x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
        # Average over the K channels of x -> one weight map shared by all mask channels.
        x = x.mean(1, keepdim=True)
        # Normalise over the neighbour axis so the weights at each pixel sum to 1.
        x = F.softmax(x, 2)

        for _ in range(self.num_iter):
            m = self.aff_m(mask)   # (B, C, P, H, W) neighbour values of the mask
            mask = (m * x).sum(2)  # weighted neighbour average -> (B, C, H, W)

        return mask
| |
|