import torch
import torch.nn.functional as F
import torch.nn as nn

#
# Helper modules
#

# (row, col) cells of a 3x3 kernel that surround the centre, in row-major
# order. Output channel ``p`` of the neighbour kernels looks at cell ``p``.
_NEIGHBOUR_CELLS = (
    (0, 0), (0, 1), (0, 2),
    (1, 0),         (1, 2),
    (2, 0), (2, 1), (2, 2),
)


class LocalAffinity(nn.Module):
    """Per-pixel differences to the eight neighbours of a dilated 3x3 window.

    For every dilation ``d`` the input is replicate-padded by ``d`` pixels and
    convolved with a fixed kernel holding ``+1`` at the centre and ``-1`` at
    one neighbour, so each output channel is ``x[centre] - x[neighbour]``.

    Args:
        dilations: iterable of integer dilation rates. It is copied into a
            list, so the module never aliases the caller's (or the default)
            sequence.
    """

    def __init__(self, dilations=(1,)):
        super().__init__()
        self.dilations = list(dilations)

        weight = self._init_aff()
        # Reference copy used by ``forward`` to verify that the registered
        # buffer still holds the kernel it was built with.
        self.weight_check = weight.clone()
        self.register_buffer('kernel', weight)

    def _init_aff(self):
        """Build the (8, 1, 3, 3) centre-minus-neighbour kernel."""
        weight = torch.zeros(len(_NEIGHBOUR_CELLS), 1, 3, 3)
        weight[:, 0, 1, 1] = 1
        for channel, (row, col) in enumerate(_NEIGHBOUR_CELLS):
            weight[channel, 0, row, col] = -1
        return weight

    def forward(self, x):
        """Apply the kernel at every dilation.

        Args:
            x: tensor of shape (B, K, H, W).

        Returns:
            Tensor of shape (B, K, P, H, W), where P is the number of kernel
            output channels times ``len(self.dilations)``; the channels of
            each dilation are concatenated in the order of ``self.dilations``.
        """
        self.weight_check = self.weight_check.type_as(x)
        assert torch.all(self.weight_check.eq(self.kernel)), \
            "the fixed affinity kernel buffer has been modified"

        B, K, H, W = x.size()
        # Every (batch, channel) plane is filtered independently.
        # ``reshape`` rather than ``view`` so that non-contiguous inputs
        # (e.g. channel slices) are accepted as well.
        x = x.reshape(B * K, 1, H, W)

        # Padding by ``d`` on each side keeps the H x W size for a 3x3 kernel
        # with dilation ``d``.
        x_affs = [
            F.conv2d(F.pad(x, [d] * 4, mode='replicate'),
                     self.kernel, dilation=d)
            for d in self.dilations
        ]

        x_aff = torch.cat(x_affs, 1)
        return x_aff.view(B, K, -1, H, W)


class LocalAffinityCopy(LocalAffinity):
    """Gathers the eight neighbours' values (no subtraction of the centre)."""

    def _init_aff(self):
        """Build the (8, 1, 3, 3) kernel that selects one neighbour each."""
        weight = torch.zeros(len(_NEIGHBOUR_CELLS), 1, 3, 3)
        for channel, (row, col) in enumerate(_NEIGHBOUR_CELLS):
            weight[channel, 0, row, col] = 1
        return weight


class LocalStDev(LocalAffinity):
    """Standard deviation over the nine values of each dilated 3x3 window."""

    def _init_aff(self):
        """Build the (9, 1, 3, 3) kernel that selects each window cell once."""
        weight = torch.zeros(9, 1, 3, 3)
        for channel in range(9):
            weight[channel, 0, channel // 3, channel % 3] = 1
        return weight

    def forward(self, x):
        """Return ``std`` over the gathered values, shape (B, K, 1, H, W).

        The parent ``forward`` yields (B, K, P, H, W), where P is the number
        of gathered locations; the deviation is taken over that axis, i.e.
        jointly over all dilations.
        """
        x = super().forward(x)
        return x.std(2, keepdim=True)


class LocalAffinityAbs(LocalAffinity):
    """Absolute value of the centre-minus-neighbour differences."""

    def forward(self, x):
        """Return ``|LocalAffinity(x)|``, shape (B, K, P, H, W)."""
        x = super().forward(x)
        return torch.abs(x)

#
# PACRN module
#


class TeRN(nn.Module):
    """Iteratively re-weights a mask by local affinities computed from ``x``.

    Args:
        num_iter: number of propagation iterations run in ``forward``.
        dilations: iterable of integer dilation rates shared by all three
            local operators.
    """

    def __init__(self, num_iter=10, dilations=(1,)):
        super().__init__()

        self.num_iter = num_iter
        self.aff_x = LocalAffinityAbs(dilations)
        self.aff_m = LocalAffinityCopy(dilations)
        self.aff_std = LocalStDev(dilations)

    def forward(self, x, mask):
        """Propagate ``mask`` for ``num_iter`` steps.

        Args:
            x: [BxKxHxW] tensor the affinities are computed from.
            mask: [BxCxhxw] tensor; it is bilinearly resized to HxW first.

        Returns:
            The propagated mask, [BxCxHxW].
        """
        mask = F.interpolate(mask, size=x.size()[-2:],
                             mode="bilinear", align_corners=True)

        # Negative absolute differences, scaled by the local deviation
        # (the epsilon guards flat regions where the deviation is zero).
        x_std = self.aff_std(x)
        x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)

        # Average over the K channels, then normalise over the P neighbours
        # so that the weights at every pixel sum to one: [Bx1xPxHxW].
        x = x.mean(1, keepdim=True)
        x = F.softmax(x, 2)

        for _ in range(self.num_iter):
            m = self.aff_m(mask)  # [BxCxPxHxW]
            mask = (m * x).sum(2)

        # mask: [BxCxHxW]
        return mask