"""Feedback image-upscaling network with a position-conditioned upscale head.

The module defines the ``metafpn`` model (built by ``make_model``) together
with the convolution / padding / activation helper factories it relies on.
"""
import math
import sys
import time
from collections import OrderedDict

import torch
import torch.nn as nn

try:
    # Only referenced by optional debug code that dumps feature maps to disk;
    # the model itself does not need torchvision.
    import torchvision.utils as SI
except ImportError:  # pragma: no cover - depends on the environment
    SI = None


def make_model(args, parent=False):
    """Build a ``metafpn`` model from an options object.

    Args:
        args: options object; it is stored on the model so that
            ``metafpn.set_scale`` can read ``args.scale``.
        parent: accepted for interface compatibility; unused here.

    Returns:
        A ``metafpn`` instance constructed with its default hyper-parameters.
    """
    return metafpn(args=args)


class Pos2Weight(nn.Module):
    """MLP mapping a 3-vector per output position to a flat conv kernel.

    The output of ``forward`` has
    ``kernel_size * kernel_size * inC * outC`` values per input row.
    """

    def __init__(self, inC, kernel_size=3, outC=3):
        super(Pos2Weight, self).__init__()
        self.inC = inC
        self.kernel_size = kernel_size
        self.outC = outC
        self.meta_block = nn.Sequential(
            nn.Linear(3, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, self.kernel_size * self.kernel_size * self.inC * self.outC),
        )

    def forward(self, x):
        """Return the predicted flat kernels for the rows of ``x``."""
        return self.meta_block(x)


class RDB_Conv(nn.Module):
    """A size-preserving ``Conv2d`` (stride 1, "same" padding) followed by ReLU."""

    def __init__(self, inChannels, growRate, kSize=3):
        super(RDB_Conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inChannels, growRate, kSize, padding=(kSize - 1) // 2, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the conv + ReLU pair to ``x``."""
        return self.conv(x)


class FPN(nn.Module):
    """Pyramid of conv stages (4, 3, 2, 1 convs) fused by 1x1 compressions.

    Every stage keeps ``G0`` channels; the four stage residuals are
    concatenated, compressed back to ``G0`` channels and added to the input.
    """

    def __init__(self, G0, kSize=3):
        super(FPN, self).__init__()
        kSize1 = 1
        pad1 = (kSize1 - 1) // 2

        self.conv1 = RDB_Conv(G0, G0, kSize)
        self.conv2 = RDB_Conv(G0, G0, kSize)
        self.conv3 = RDB_Conv(G0, G0, kSize)
        self.conv4 = RDB_Conv(G0, G0, kSize)
        self.conv5 = RDB_Conv(G0, G0, kSize)
        self.conv6 = RDB_Conv(G0, G0, kSize)
        self.conv7 = RDB_Conv(G0, G0, kSize)
        self.conv8 = RDB_Conv(G0, G0, kSize)
        self.conv9 = RDB_Conv(G0, G0, kSize)
        self.conv10 = RDB_Conv(G0, G0, kSize)

        self.compress_in1 = nn.Conv2d(4 * G0, G0, kSize1, padding=pad1, stride=1)
        self.compress_in2 = nn.Conv2d(3 * G0, G0, kSize1, padding=pad1, stride=1)
        self.compress_in3 = nn.Conv2d(2 * G0, G0, kSize1, padding=pad1, stride=1)
        # Not used by forward(); kept so the module's parameter set (and thus
        # its state_dict keys) stays unchanged.
        self.compress_in4 = nn.Conv2d(2 * G0, G0, kSize1, padding=pad1, stride=1)
        self.compress_out = nn.Conv2d(4 * G0, G0, kSize1, padding=pad1, stride=1)

    def forward(self, x):
        """Run the four stages and return the fused features plus ``x``."""
        # Stage 1: four chained convs.
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        res1 = x + x4
        stage2_in = self.compress_in1(torch.cat((x1, x2, x3, x4), dim=1))

        # Stage 2: three chained convs.
        x5 = self.conv5(stage2_in)
        x6 = self.conv6(x5)
        x7 = self.conv7(x6)
        res2 = stage2_in + x7
        stage3_in = self.compress_in2(torch.cat((x5, x6, x7), dim=1))

        # Stage 3: two chained convs.
        x8 = self.conv8(stage3_in)
        x9 = self.conv9(x8)
        res3 = stage3_in + x9
        stage4_in = self.compress_in3(torch.cat((x8, x9), dim=1))

        # Stage 4: a single conv.
        x10 = self.conv10(stage4_in)
        res4 = stage4_in + x10

        fused = self.compress_out(torch.cat((res1, res2, res3, res4), dim=1))
        return fused + x


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Return a ``Conv2d`` padded by ``kernel_size // 2``."""
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)


class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution computing ``x / std + sign * rgb_range * mean / std``.

    The weight and bias are constants and are excluded from training.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Setting an attribute on the module does not freeze anything; the
        # flag has to be cleared on each parameter.
        for param in self.parameters():
            param.requires_grad = False
        # Retained for any caller that reads this flag from the module.
        self.requires_grad = False


class BasicBlock(nn.Sequential):
    """``Conv2d`` optionally followed by ``BatchNorm2d`` and an activation."""

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):
        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    """Two convs (activation after the first) with a scaled residual connection."""

    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``body(x) * res_scale + x``."""
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class Upsampler(nn.Sequential):
    """Pixel-shuffle upsampler for power-of-two scales and for scale 3.

    Raises:
        NotImplementedError: for any other scale.
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            # log2 is exact for powers of two, unlike log(scale, 2).
            for _ in range(int(math.log2(scale))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))
        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


class ResidualDenseBlock_8C(nn.Module):
    '''
    Residual Dense Block
    style: 8 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Each of the eight convs sees the block input concatenated with all
    earlier conv outputs; a final 1x1 conv maps back to ``nc`` channels and
    the result is scaled by 0.2 before being added to the input.
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='relu', mode='CNA'):
        super(ResidualDenseBlock_8C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        shared = dict(bias=bias, pad_type=pad_type, norm_type=norm_type, mode=mode)
        self.conv1 = ConvBlock(nc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv2 = ConvBlock(nc + gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv3 = ConvBlock(nc + 2 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv4 = ConvBlock(nc + 3 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv5 = ConvBlock(nc + 4 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv6 = ConvBlock(nc + 5 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv7 = ConvBlock(nc + 6 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv8 = ConvBlock(nc + 7 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        # In 'CNA' mode the fusing conv has no trailing activation.
        last_act = None if mode == 'CNA' else act_type
        self.conv9 = ConvBlock(nc + 8 * gc, nc, 1, stride, act_type=last_act, **shared)

    def forward(self, x):
        """Return ``0.2 * fused_dense_features + x``."""
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4,
                     self.conv5, self.conv6, self.conv7, self.conv8):
            # conv1 receives x alone; later convs receive everything so far.
            inp = feats[0] if len(feats) == 1 else torch.cat(feats, 1)
            feats.append(conv(inp))
        x9 = self.conv9(torch.cat(feats, 1))
        return x9.mul(0.2) + x


def ConvBlock(in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True,
              valid_padding=True, padding=0,
              act_type='relu', norm_type='bn', pad_type='zero', mode='CNA'):
    """Build a conv layer with optional padding layer, norm and activation.

    Args:
        valid_padding: when true, ``padding`` is replaced by the value from
            ``get_valid_padding`` so the feature size is preserved at stride 1.
        pad_type: ``'zero'`` (or falsy) pads inside the conv; ``'reflect'`` /
            ``'replicate'`` insert an explicit padding layer instead.
        mode: ``'CNA'`` orders layers pad -> conv -> norm -> act;
            ``'NAC'`` orders them norm -> act -> pad -> conv.

    Returns:
        An ``nn.Sequential`` of the requested layers.

    Raises:
        AssertionError: if ``mode`` is neither ``'CNA'`` nor ``'NAC'``.
    """
    assert (mode in ['CNA', 'NAC']), '[ERROR] Wrong mode in [%s]!' % sys.modules[__name__]

    if valid_padding:
        padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    # When an explicit padding layer is used the conv must not pad again,
    # otherwise the input is padded twice and the output grows.
    conv_padding = padding if p is None else 0
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                     padding=conv_padding, dilation=dilation, bias=bias)

    if mode == 'CNA':
        act = activation(act_type) if act_type else None
        n = norm(out_channels, norm_type) if norm_type else None
        return sequential(p, conv, n, act)
    elif mode == 'NAC':
        act = activation(act_type, inplace=False) if act_type else None
        n = norm(in_channels, norm_type) if norm_type else None
        return sequential(n, act, p, conv)


def DeconvBlock(in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, padding=0,
                act_type='relu', norm_type='bn', pad_type='zero', mode='CNA'):
    """Build a transposed-conv layer with optional padding layer, norm and activation.

    Layer ordering follows ``mode`` exactly as in ``ConvBlock``.

    Raises:
        AssertionError: if ``mode`` is neither ``'CNA'`` nor ``'NAC'``.
    """
    assert (mode in ['CNA', 'NAC']), '[ERROR] Wrong mode in [%s]!' % sys.modules[__name__]

    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
                                dilation=dilation, bias=bias)

    if mode == 'CNA':
        act = activation(act_type) if act_type else None
        n = norm(out_channels, norm_type) if norm_type else None
        return sequential(p, deconv, n, act)
    elif mode == 'NAC':
        act = activation(act_type, inplace=False) if act_type else None
        n = norm(in_channels, norm_type) if norm_type else None
        return sequential(n, act, p, deconv)


def get_valid_padding(kernel_size, dilation):
    """
    Padding value to remain feature size.
    """
    # Effective kernel extent once dilation is taken into account.
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


def pad(pad_type, padding):
    """Return a reflection / replication padding layer, or ``None`` for zero width.

    Raises:
        NotImplementedError: for any other ``pad_type``.
    """
    pad_type = pad_type.lower()
    if padding == 0:
        return None

    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError('[ERROR] Padding layer [%s] is not implemented!' % pad_type)
    return layer


def activation(act_type='relu', inplace=True, slope=0.2, n_prelu=1):
    """Return the activation layer named by ``act_type`` (relu / lrelu / prelu).

    Raises:
        NotImplementedError: for any other ``act_type``.
    """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'lrelu':
        layer = nn.LeakyReLU(slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=slope)
    else:
        raise NotImplementedError('[ERROR] Activation layer [%s] is not implemented!' % act_type)
    return layer


def norm(n_feature, norm_type='bn'):
    """Return the normalization layer named by ``norm_type`` (only ``'bn'``).

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    norm_type = norm_type.lower()
    if norm_type == 'bn':
        layer = nn.BatchNorm2d(n_feature)
    else:
        raise NotImplementedError('[ERROR] Normalization layer [%s] is not implemented!' % norm_type)
    return layer


def sequential(*args):
    """Flatten the given modules into one ``nn.Sequential``.

    Nested ``nn.Sequential`` arguments are unpacked and non-module arguments
    (e.g. ``None``) are skipped. A single argument is returned unchanged.

    Raises:
        NotImplementedError: if the single argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError(
                '[ERROR] %s.sequential() does not support OrderedDict' % sys.modules[__name__])
        else:
            return args[0]
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module:
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


class FeedbackBlock(nn.Module):
    """Recurrent block: four chained ``FPN`` modules with a hidden state.

    Each call fuses the input with the previous call's output. Call
    ``reset_state()`` to start a new sequence; the first call after a reset
    uses a copy of the input as the hidden state.

    ``num_groups``, ``upscale_factor``, ``act_type`` and ``norm_type`` are
    accepted for interface compatibility but do not affect the layers built.
    """

    def __init__(self, num_features, num_groups, upscale_factor, act_type, norm_type):
        super(FeedbackBlock, self).__init__()
        kSize1 = 1
        pad1 = (kSize1 - 1) // 2

        self.fpn1 = FPN(num_features)
        self.fpn2 = FPN(num_features)
        self.fpn3 = FPN(num_features)
        self.fpn4 = FPN(num_features)
        self.compress_in = nn.Conv2d(2 * num_features, num_features, kSize1, padding=pad1, stride=1)
        self.compress_out = nn.Conv2d(4 * num_features, num_features, kSize1, padding=pad1, stride=1)

        # Start in the "reset" state so forward() works without an explicit
        # reset_state() call.
        self.should_reset = True
        self.last_hidden = None

    def forward(self, x):
        """Fuse ``x`` with the hidden state and return the new hidden state."""
        if self.should_reset:
            # Seed the hidden state with a copy of the input; clone() keeps
            # the input's device and dtype instead of assuming CUDA.
            self.last_hidden = x.clone()
            self.should_reset = False

        x = torch.cat((x, self.last_hidden), dim=1)  # tensor concatenation
        x = self.compress_in(x)

        fpn1 = self.fpn1(x)
        fpn2 = self.fpn2(fpn1)
        fpn3 = self.fpn3(fpn2)
        fpn4 = self.fpn4(fpn3)
        output = self.compress_out(torch.cat((fpn1, fpn2, fpn3, fpn4), dim=1))

        self.last_hidden = output
        return output

    def reset_state(self):
        """Make the next ``forward`` call re-seed the hidden state from its input."""
        self.should_reset = True


class metafpn(nn.Module):
    """Feedback network whose upscale head uses position-predicted kernels.

    ``forward`` runs the feedback block ``num_steps`` times and returns one
    upscaled image per step.

    Args:
        RDNkSize: accepted for interface compatibility; unused.
        G0: number of feature channels.
        n_colors: number of input image channels.
        act_type: activation used by the feature-extraction convs.
        norm_type: normalization used by the feature-extraction convs.
        args: optional options object; required only by ``set_scale``, which
            reads ``args.scale``.
    """

    def __init__(self, RDNkSize=3, G0=64, n_colors=3, act_type='prelu', norm_type=None,
                 args=None):
        # Initialise the nn.Module machinery before registering sub-modules.
        super(metafpn, self).__init__()
        self.args = args
        self.num_steps = 4
        self.num_features = G0
        self.scale_idx = 0
        self.scale = 1
        in_channels = n_colors
        num_groups = 6

        # RGB mean for DIV2K
        # rgb_mean = (0.4488, 0.4371, 0.4040)
        # rgb_std = (1.0, 1.0, 1.0)
        # self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        # self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        # LR feature extraction block
        # 3x3 conv: each kernel produces one feature map.
        self.conv_in = ConvBlock(in_channels, 4 * self.num_features,
                                 kernel_size=3,
                                 act_type=act_type, norm_type=norm_type)
        self.feat_in = ConvBlock(4 * self.num_features, self.num_features,
                                 kernel_size=1,
                                 act_type=act_type, norm_type=norm_type)

        # basic block
        self.block = FeedbackBlock(self.num_features, num_groups, self.scale, act_type, norm_type)

        # reconstruction block: position -> per-pixel 3x3 kernels
        self.P2W = Pos2Weight(inC=self.num_features)

    def repeat_x(self, x):
        """Repeat ``x`` ``ceil(scale) ** 2`` times along the batch dimension.

        An input of size ``(N, C, H, W)`` yields ``(N * r * r, C, H, W)``
        with ``r = ceil(self.scale)``.
        """
        scale_int = math.ceil(self.scale)
        N, C, H, W = x.size()
        x = x.view(N, C, H, 1, W, 1)

        x = torch.cat([x] * scale_int, 3)
        x = torch.cat([x] * scale_int, 5).permute(0, 3, 5, 1, 2, 4)

        return x.contiguous().view(-1, C, H, W)

    def forward(self, x, pos_mat):
        """Return the list of upscaled outputs, one per feedback step.

        Args:
            x: input image batch.
            pos_mat: position matrix; it is reshaped to
                ``(pos_mat.size(1), 3)`` before being fed to ``P2W``
                (presumably one row per output pixel - TODO confirm with
                the caller).

        Returns:
            A list of ``num_steps`` tensors, each the bilinear upsample of
            the input plus that step's predicted 3-channel output.
        """
        self._reset_state()

        # x = self.sub_mean(x)
        scale_int = math.ceil(self.scale)
        inter_res = nn.functional.interpolate(x, scale_factor=scale_int, mode='bilinear',
                                              align_corners=False)

        x = self.conv_in(x)
        x = self.feat_in(x)
        batch, _, height, width = x.size()

        # The predicted kernels depend only on pos_mat, so compute and
        # reshape them once instead of on every feedback step.
        # (outH*outW, outC*inC*kernel_size*kernel_size)
        local_weight = self.P2W(pos_mat.view(pos_mat.size(1), -1))
        local_weight = local_weight.contiguous().view(
            height, scale_int, width, scale_int, -1, 3
        ).permute(1, 3, 0, 2, 4, 5).contiguous()
        local_weight = local_weight.view(scale_int ** 2, height * width, -1, 3)

        outs = []
        for _ in range(self.num_steps):
            h = self.block(x)

            # (N*r*r, inC, inH, inW)
            up_x = self.repeat_x(h)
            # (N*r*r, inC*kH*kW, inH*inW)
            cols = nn.functional.unfold(up_x, 3, padding=1)
            cols = cols.contiguous().view(
                cols.size(0) // (scale_int ** 2), scale_int ** 2,
                cols.size(1), cols.size(2), 1
            ).permute(0, 1, 3, 4, 2).contiguous()

            # Per-pixel kernel application as a batched matrix product.
            out = torch.matmul(cols, local_weight).permute(0, 1, 4, 2, 3)
            out = out.contiguous().view(
                batch, scale_int, scale_int, 3, height, width
            ).permute(0, 3, 4, 1, 5, 2)
            out = out.contiguous().view(batch, 3, scale_int * height, scale_int * width)

            h = torch.add(inter_res, out)
            # h = self.add_mean(h)
            outs.append(h)

        return outs  # return output of every timesteps

    def _reset_state(self):
        """Reset the feedback block's hidden state."""
        self.block.reset_state()

    def set_scale(self, scale_idx):
        """Select the active scale as ``args.scale[scale_idx]``.

        Raises:
            AttributeError: if the model was built without ``args``.
        """
        if self.args is None:
            raise AttributeError(
                "set_scale() needs the options object; build the model with "
                "make_model(args) or metafpn(args=args)")
        self.scale_idx = scale_idx
        self.scale = self.args.scale[scale_idx]