import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
# ``wn`` is used below as ``wn(module)`` to wrap layers with weight normalisation.
import torch.nn.utils.weight_norm as wn

################
# Upsampler
################

# Interpolation modes for which F.interpolate accepts ``align_corners``;
# passing it together with any other mode (e.g. "nearest") raises ValueError.
_ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")


def _identity(module):
    """Return ``module`` unchanged (default ``wrap`` for the builders below)."""
    return module


def _as_size_pair(out_size):
    """Expand an int ``out_size`` to ``[out_size, out_size]``; pass others through."""
    if isinstance(out_size, int):
        return [out_size, out_size]
    return out_size


def _resize(x, out_size, mode):
    """Resize ``x`` to spatial size ``out_size`` with ``F.interpolate``.

    ``align_corners=False`` is forwarded only for modes that accept it.
    """
    if mode in _ALIGN_CORNERS_MODES:
        return F.interpolate(x, out_size, mode=mode, align_corners=False)
    return F.interpolate(x, out_size, mode=mode)


def _up_x2_stage(n_feats, kSize, wrap=_identity):
    """Build one x2 stage: conv to ``4 * n_feats`` channels, then PixelShuffle(2)."""
    return nn.Sequential(
        wrap(
            nn.Conv2d(
                n_feats, n_feats * 4, kSize, padding=(kSize - 1) // 2, stride=1
            )
        ),
        nn.PixelShuffle(2),
    )


def _fuse_head(in_channels, out_channels, wrap=_identity):
    """Build the 1x1-conv fuse stack ``in -> 256 -> 256 -> 256 -> 256 -> out``.

    A ReLU sits between consecutive convs and ``wrap`` is applied to every
    conv. Convs land at indices 0, 2, 4, 6, 8 so state_dict keys are
    ``fuse.0`` ... ``fuse.8``.
    """
    layers = [wrap(nn.Conv2d(in_channels, 256, kernel_size=1, padding=0, stride=1))]
    for _ in range(3):
        layers.append(nn.ReLU())
        layers.append(wrap(nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1)))
    layers.append(nn.ReLU())
    layers.append(
        wrap(nn.Conv2d(256, out_channels, kernel_size=1, padding=0, stride=1))
    )
    return nn.Sequential(*layers)


def _level_weights(scale_aware_layer, in_height, out_size, device):
    """Return one weight per level, shape ``(levels,)``, from the in/out height ratio."""
    ratio = torch.tensor([in_height / out_size[0]], device=device)
    return scale_aware_layer(ratio.unsqueeze(0))[0]


def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    Args:
        shape: number of cells along each dimension, e.g. ``(H, W)``.
        ranges: optional per-dimension ``(v0, v1)`` bounds; ``(-1, 1)`` is
            used for every dimension when omitted.
        flatten: if True the result has shape ``(prod(shape), len(shape))``,
            otherwise ``(*shape, len(shape))``.

    Returns:
        A float32 CPU tensor whose last dimension holds the coordinates of
        each cell center, in the same dimension order as ``shape``.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        v0, v1 = (-1, 1) if ranges is None else ranges[i]
        # r is half a cell: centers are at v0 + r, v0 + 3r, v0 + 5r, ...
        r = (v1 - v0) / (2 * n)
        coord_seqs.append(v0 + r + (2 * r) * torch.arange(n).float())
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret


class UPLayer_MS_V9(nn.Module):
    """Multi-scale, scale-aware up-sampling head.

    The input plus ``levels - 1`` successive x2 stages (conv + PixelShuffle)
    give ``levels`` feature maps. Each is resized to ``out_size``, scaled by a
    weight predicted from the input/output height ratio, concatenated along
    channels and fused by a stack of 1x1 convs.

    Args:
        n_feats: channels of the input feature map(s).
        kSize: kernel size of the x2 up-sampling convs.
        out_channels: channels of the returned tensor.
        interpolate_mode: ``mode`` passed to ``F.interpolate`` for the resize.
        levels: number of resolution levels that are fused.
    """

    def __init__(self, n_feats, kSize, out_channels, interpolate_mode, levels=4):
        super().__init__()
        self.interpolate_mode = interpolate_mode
        self.levels = levels
        # Creation/registration order (stages, scale layer, stage container,
        # fuse) is kept so seeded init and parameters() order are unchanged.
        up_stages = [_up_x2_stage(n_feats, kSize) for _ in range(levels - 1)]
        # Maps the scalar size ratio to one weight in (0, 1) per level.
        self.scale_aware_layer = nn.Sequential(
            nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, levels), nn.Sigmoid()
        )
        # Indexed per stage, not run as a chain.
        self.UPNet_x2_list = nn.Sequential(*up_stages)
        self.fuse = _fuse_head(n_feats * levels, out_channels)

    def forward(self, x, out_size):
        """Up-sample ``x`` (or a list/tuple of per-level maps) to ``out_size``.

        ``out_size`` is an int (square) or an ``(H, W)`` pair.
        """
        out_size = _as_size_pair(out_size)
        if isinstance(x, (list, tuple)):
            return self.forward_list(x, out_size)
        scale_w = _level_weights(self.scale_aware_layer, x.shape[2], out_size, x.device)
        feats = [x]
        for up in self.UPNet_x2_list:
            feats.append(up(feats[-1]))
        resized = [
            _resize(f, out_size, self.interpolate_mode) * scale_w[l]
            for l, f in enumerate(feats)
        ]
        return self.fuse(torch.cat(resized, dim=1))

    def forward_list(self, h_list, out_size):
        """Fuse one feature map per level; level ``l`` goes through the first ``l`` x2 stages."""
        assert (
            len(h_list) == self.levels
        ), "The Length of input list must equal to the number of levels"
        out_size = _as_size_pair(out_size)
        scale_w = _level_weights(
            self.scale_aware_layer, h_list[0].shape[2], out_size, h_list[0].device
        )
        resized = []
        for l, h in enumerate(h_list):
            for i in range(l):
                h = self.UPNet_x2_list[i](h)
            resized.append(_resize(h, out_size, self.interpolate_mode) * scale_w[l])
        return self.fuse(torch.cat(resized, dim=1))


class UPLayer_MS_WN(nn.Module):
    """Weight-normalised variant of the multi-scale, scale-aware up-sampler.

    It matches ``UPLayer_MS_V9``, except that every conv/linear is wrapped
    with ``wn``. ``interpolate_mode`` may also be ``"MLP"``, which resizes
    with a learned ``MLP_Interpolate``.

    Args:
        n_feats: channels of the input feature map(s).
        kSize: kernel size of the x2 up-sampling convs.
        out_channels: channels of the returned tensor.
        interpolate_mode: one of "bilinear", "bicubic", "nearest", "MLP".
        levels: number of resolution levels that are fused.
    """

    def __init__(self, n_feats, kSize, out_channels, interpolate_mode, levels=4):
        super().__init__()
        self.interpolate_mode = interpolate_mode
        self.levels = levels
        # Same creation/registration order as before (see UPLayer_MS_V9).
        up_stages = [_up_x2_stage(n_feats, kSize, wrap=wn) for _ in range(levels - 1)]
        self.scale_aware_layer = nn.Sequential(
            wn(nn.Linear(1, 64)), nn.ReLU(), wn(nn.Linear(64, levels)), nn.Sigmoid()
        )
        self.UPNet_x2_list = nn.Sequential(*up_stages)
        self.fuse = _fuse_head(n_feats * levels, out_channels, wrap=wn)
        assert self.interpolate_mode in (
            "bilinear",
            "bicubic",
            "nearest",
            "MLP",
        ), "Interpolate mode must be bilinear/bicubic/nearest/MLP"
        if self.interpolate_mode == "MLP":
            self.feature_interpolater = MLP_Interpolate(n_feats, radius=3)
        else:
            # A partial of a module-level function (unlike a lambda) keeps
            # the module picklable.
            self.feature_interpolater = functools.partial(
                _resize, mode=self.interpolate_mode
            )

    def forward(self, x, out_size):
        """Up-sample ``x`` (or a list/tuple of per-level maps) to ``out_size``."""
        out_size = _as_size_pair(out_size)
        if isinstance(x, (list, tuple)):
            return self.forward_list(x, out_size)
        scale_w = _level_weights(self.scale_aware_layer, x.shape[2], out_size, x.device)
        feats = [x]
        for up in self.UPNet_x2_list:
            feats.append(up(feats[-1]))
        resized = [
            self.feature_interpolater(f, out_size) * scale_w[l]
            for l, f in enumerate(feats)
        ]
        return self.fuse(torch.cat(resized, dim=1))

    def forward_list(self, h_list, out_size):
        """Fuse one feature map per level; level ``l`` goes through the first ``l`` x2 stages."""
        assert (
            len(h_list) == self.levels
        ), "The Length of input list must equal to the number of levels"
        out_size = _as_size_pair(out_size)
        scale_w = _level_weights(
            self.scale_aware_layer, h_list[0].shape[2], out_size, h_list[0].device
        )
        resized = []
        for l, h in enumerate(h_list):
            for i in range(l):
                h = self.UPNet_x2_list[i](h)
            resized.append(self.feature_interpolater(h, out_size) * scale_w[l])
        return self.fuse(torch.cat(resized, dim=1))


class UPLayer_MS_WN_woSA(UPLayer_MS_WN):
    """``UPLayer_MS_WN`` without scale awareness: levels are fused unweighted.

    ``scale_aware_layer`` is still built (inherited ``__init__``) but unused.
    """

    def forward(self, x, out_size):
        """Up-sample ``x`` (or a list/tuple of per-level maps) to ``out_size``."""
        out_size = _as_size_pair(out_size)
        if isinstance(x, (list, tuple)):
            return self.forward_list(x, out_size)
        feats = [x]
        for up in self.UPNet_x2_list:
            feats.append(up(feats[-1]))
        resized = [self.feature_interpolater(f, out_size) for f in feats]
        return self.fuse(torch.cat(resized, dim=1))

    def forward_list(self, h_list, out_size):
        """Fuse one feature map per level; level ``l`` goes through the first ``l`` x2 stages."""
        assert (
            len(h_list) == self.levels
        ), "The Length of input list must equal to the number of levels"
        out_size = _as_size_pair(out_size)
        resized = []
        for l, h in enumerate(h_list):
            for i in range(l):
                h = self.UPNet_x2_list[i](h)
            resized.append(self.feature_interpolater(h, out_size))
        return self.fuse(torch.cat(resized, dim=1))


class OSM(nn.Module):
    """Over-scale head: up-sample by ``overscale``, map to 3 channels, then bicubic-resize.

    Args:
        n_feats: channels of the input feature map.
        overscale: PixelShuffle factor.
        mid_channels: channels after the PixelShuffle (64 by default, which
            reproduces the previous 1600-channel conv for ``overscale=5``).
    """

    def __init__(self, n_feats, overscale, mid_channels=64):
        super().__init__()
        self.body = nn.Sequential(
            wn(nn.Conv2d(n_feats, mid_channels * overscale ** 2, 3, padding=1)),
            nn.PixelShuffle(overscale),
            wn(nn.Conv2d(mid_channels, 3, 3, padding=1)),
        )

    def forward(self, x, out_size):
        h = self.body(x)
        return F.interpolate(h, out_size, mode="bicubic", align_corners=False)


class MLP_Interpolate(nn.Module):
    """Resize a feature map with a small MLP queried at each output pixel.

    For every output pixel, the nearest input pixel is selected. Its
    ``radius x radius`` neighbourhood (zero-padded at the border) is
    concatenated with the offset from that pixel's center to the query point,
    measured in input-pixel units. The result is mapped back to ``n_feat``
    channels by a 2-layer MLP.

    Note: ``radius`` is the window side length (kernel size), not a
    half-width. Even values are supported.
    """

    def __init__(self, n_feat, radius=2):
        super().__init__()
        self.radius = radius
        self.f_transfer = nn.Sequential(
            nn.Linear(n_feat * self.radius * self.radius + 2, n_feat),
            nn.ReLU(True),
            nn.Linear(n_feat, n_feat),
        )

    def forward(self, x, out_size):
        """Return a ``(B, C, *out_size)`` tensor; ``x`` must be ``(B, C, H, W)``."""
        out_size = _as_size_pair(out_size)
        b, c, h, w = x.shape
        k = self.radius
        # Asymmetric padding keeps the unfolded grid at H x W for any k; for
        # odd k it equals F.unfold(..., padding=k // 2).
        lo, hi = (k - 1) // 2, k // 2
        x_unfold = F.unfold(F.pad(x, (lo, hi, lo, hi)), k).view(b, c * k * k, h, w)

        # Cell-center coordinates of the input grid: (B, 2, H, W).
        in_coord = (
            make_coord((h, w), flatten=False)
            .to(x)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .expand(b, 2, h, w)
        )
        # Query coordinates: (B, Q, 2) with Q = out_h * out_w.
        out_coord = make_coord(out_size, flatten=True).to(x)
        out_coord = out_coord.expand(b, *out_coord.shape)
        # grid_sample wants (x, y) in the last dim, so flip (row, col).
        grid = out_coord.flip(-1).unsqueeze(1)

        q_feat = F.grid_sample(
            x_unfold, grid, mode="nearest", align_corners=False
        )[:, :, 0, :].permute(0, 2, 1)
        q_coord = F.grid_sample(
            in_coord, grid, mode="nearest", align_corners=False
        )[:, :, 0, :].permute(0, 2, 1)

        # Offset from the selected input-pixel center, in input-pixel units.
        rel_coord = out_coord - q_coord
        rel_coord[:, :, 0] *= h
        rel_coord[:, :, 1] *= w

        inp = torch.cat([q_feat, rel_coord], dim=-1)
        q = out_coord.shape[1]
        pred = self.f_transfer(inp.view(b * q, -1)).view(b, q, -1)
        return pred.view(b, *out_size, c).permute(0, 3, 1, 2).contiguous()


class LIIF_Upsampler(nn.Module):
    """Placeholder for a LIIF-based up-sampler; constructing it raises."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError

    def forward(self):
        pass