|
import time |
|
from collections import OrderedDict |
|
|
|
import torch |
|
import torch.nn as nn |
|
import math |
|
import torchvision.utils as SI |
|
|
|
def make_model(args, parent=False):
    """Build the ``metafpn`` network for a run.

    Args:
        args: Options object for the run. It is attached to the returned
            model as ``model.args`` because ``metafpn.set_scale`` reads
            ``self.args.scale``.
        parent: Not used by this factory.

    Returns:
        A ``metafpn`` instance with ``args`` attached.
    """
    # ``metafpn.__init__`` has no ``args`` parameter; passing ``args``
    # positionally would bind it to ``RDNkSize``.  Build the network with its
    # defaults and attach the options explicitly so ``set_scale`` can find
    # them.
    model = metafpn()
    model.args = args
    return model
|
|
|
|
|
class Pos2Weight(nn.Module):
    """Small MLP that turns 3-value rows into flattened filter weights.

    Every input row of length 3 is mapped to a vector holding
    ``kernel_size * kernel_size * inC * outC`` values.
    """

    def __init__(self, inC, kernel_size=3, outC=3):
        super(Pos2Weight, self).__init__()
        self.inC = inC
        self.kernel_size = kernel_size
        self.outC = outC

        # Length of one predicted weight vector.
        n_weights = self.kernel_size * self.kernel_size * self.inC * self.outC
        stages = [
            nn.Linear(3, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, n_weights),
        ]
        self.meta_block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP and return the predicted weights."""
        return self.meta_block(x)
|
|
|
|
|
class RDB_Conv(nn.Module):
    """A stride-1 convolution followed by ReLU.

    The padding is ``(kSize - 1) // 2``, which keeps the spatial size for
    odd kernel sizes.
    """

    def __init__(self, inChannels, growRate, kSize=3):
        super(RDB_Conv, self).__init__()
        same_pad = (kSize - 1) // 2
        conv_layer = nn.Conv2d(inChannels, growRate, kSize, padding=same_pad, stride=1)
        self.conv = nn.Sequential(conv_layer, nn.ReLU())

    def forward(self, x):
        """Apply the convolution and the ReLU to ``x``."""
        return self.conv(x)
|
|
|
|
|
class FPN(nn.Module):
    """Four-stage stack of ``RDB_Conv`` units with 1x1 fusion convolutions.

    The stages hold 4, 3, 2 and 1 conv units respectively.  The unit outputs
    of a stage are concatenated and squeezed back to ``G0`` channels to feed
    the next stage; each stage also yields a residual sum, and the four sums
    are fused by ``compress_out`` before the block input is added back.
    """

    def __init__(self, G0, kSize=3):
        super(FPN, self).__init__()

        # Ten G0 -> G0 conv+ReLU units, registered as conv1 .. conv10.
        for idx in range(1, 11):
            setattr(self, 'conv%d' % idx, RDB_Conv(G0, G0, kSize))

        def pointwise(n_inputs):
            # 1x1 convolution squeezing ``n_inputs * G0`` channels down to G0.
            return nn.Conv2d(n_inputs * G0, G0, 1, padding=0, stride=1)

        self.compress_in1 = pointwise(4)
        self.compress_in2 = pointwise(3)
        self.compress_in3 = pointwise(2)
        # Registered but not called in forward().
        self.compress_in4 = pointwise(2)
        self.compress_out = pointwise(4)

    def forward(self, x):
        """Return the fused multi-stage features plus the input ``x``."""
        # Stage 1: four chained units.
        a1 = self.conv1(x)
        a2 = self.conv2(a1)
        a3 = self.conv3(a2)
        a4 = self.conv4(a3)
        skip1 = x + a4

        # Stage 2: three chained units on the squeezed stage-1 features.
        stage2_in = self.compress_in1(torch.cat((a1, a2, a3, a4), dim=1))
        b1 = self.conv5(stage2_in)
        b2 = self.conv6(b1)
        b3 = self.conv7(b2)
        skip2 = stage2_in + b3

        # Stage 3: two chained units.
        stage3_in = self.compress_in2(torch.cat((b1, b2, b3), dim=1))
        c1 = self.conv8(stage3_in)
        c2 = self.conv9(c1)
        skip3 = stage3_in + c2

        # Stage 4: a single unit.
        stage4_in = self.compress_in3(torch.cat((c1, c2), dim=1))
        d1 = self.conv10(stage4_in)
        skip4 = stage4_in + d1

        fused = self.compress_out(torch.cat((skip1, skip2, skip3, skip4), dim=1))
        return fused + x
|
|
|
|
|
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Create an ``nn.Conv2d`` padded by ``kernel_size // 2`` on each side."""
    half_kernel = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=half_kernel, bias=bias)
|
|
|
|
|
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that shifts and scales three channels.

    The weight is the 3x3 identity divided per output channel by
    ``rgb_std``; the bias is ``sign * rgb_range * rgb_mean / rgb_std``.

    Args:
        rgb_range: Scalar multiplied into the mean to form the bias.
        rgb_mean: Three per-channel means.
        rgb_std: Three per-channel standard deviations.
        sign: Multiplier for the bias (default ``-1``).
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Setting ``requires_grad`` on the module only creates a plain
        # attribute; the flag has to be cleared on each parameter for the
        # layer to stay fixed during training.
        for param in self.parameters():
            param.requires_grad = False
        # Kept for any caller that reads the module-level flag.
        self.requires_grad = False
|
|
|
|
|
class BasicBlock(nn.Sequential):
    """Convolution, optionally followed by batch norm and an activation.

    The convolution is padded by ``kernel_size // 2``.  ``bn`` toggles the
    ``BatchNorm2d`` layer and ``act`` is appended unless it is ``None``.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):
        conv_layer = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)

        trailing = []
        if bn:
            trailing.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            trailing.append(act)

        super(BasicBlock, self).__init__(conv_layer, *trailing)
|
|
|
|
|
class ResBlock(nn.Module):
    """Two-convolution residual block with a scaled residual branch.

    The body is conv, optional BN, ``act``, conv, optional BN.  Its output
    is multiplied by ``res_scale`` and the block input is added to it.
    ``conv`` is a factory called as
    ``conv(n_feats, n_feats, kernel_size, bias=bias)``.
    """

    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()

        # First half: conv (+ BN) + activation.
        layers = [conv(n_feats, n_feats, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))
        layers.append(act)

        # Second half: conv (+ BN), no activation.
        layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``body(x) * res_scale + x``."""
        # ``mul`` yields a fresh tensor, so the in-place add leaves ``x`` intact.
        return self.body(x).mul(self.res_scale).add_(x)
|
|
|
|
|
class Upsampler(nn.Sequential):
    """Pixel-shuffle upsampler for power-of-two scales and for scale 3.

    A power-of-two ``scale`` is built from ``log2(scale)`` stages that each
    upsample by 2; ``scale == 3`` uses a single x3 stage.  Every stage is a
    conv producing ``factor**2 * n_feats`` channels, ``PixelShuffle``, an
    optional ``BatchNorm2d`` and an optional ``'relu'`` / ``'prelu'``
    activation.  ``conv`` is a factory called positionally as
    ``conv(in_ch, out_ch, 3, bias)``.

    Raises:
        NotImplementedError: if ``scale`` is neither a power of two nor 3.
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        if (scale & (scale - 1)) == 0:
            n_stages = int(math.log(scale, 2))
            factor = 2
        elif scale == 3:
            n_stages = 1
            factor = 3
        else:
            raise NotImplementedError

        layers = []
        for _ in range(n_stages):
            layers.append(conv(n_feats, factor * factor * n_feats, 3, bias))
            layers.append(nn.PixelShuffle(factor))
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))

            if act == 'relu':
                layers.append(nn.ReLU(True))
            elif act == 'prelu':
                layers.append(nn.PReLU(n_feats))

        super(Upsampler, self).__init__(*layers)
|
|
|
|
|
class ResidualDenseBlock_8C(nn.Module):
    """Residual dense block with eight densely connected convolutions.

    ``conv1`` .. ``conv8`` each receive the block input concatenated with
    every earlier conv output and emit ``gc`` channels.  ``conv9`` is a 1x1
    convolution over all of them that returns to ``nc`` channels; its output
    is scaled by 0.2 and added to the block input.

    Follows the design of "Residual Dense Network for Image
    Super-Resolution" (CVPR 18).
    """

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', norm_type=None, act_type='relu',
                 mode='CNA'):
        super(ResidualDenseBlock_8C, self).__init__()

        shared = dict(bias=bias, pad_type=pad_type, norm_type=norm_type, mode=mode)

        # conv{k} sees the input plus the k-1 previous outputs.
        for depth in range(8):
            growth_conv = ConvBlock(nc + depth * gc, gc, kernel_size, stride, act_type=act_type, **shared)
            setattr(self, 'conv%d' % (depth + 1), growth_conv)

        # In 'CNA' mode the fusion conv gets no activation.
        last_act = None if mode == 'CNA' else act_type
        self.conv9 = ConvBlock(nc + 8 * gc, nc, 1, stride, act_type=last_act, **shared)

    def forward(self, x):
        """Return ``0.2 * fused_dense_features + x``."""
        feats = [x, self.conv1(x)]
        for depth in range(2, 9):
            layer = getattr(self, 'conv%d' % depth)
            feats.append(layer(torch.cat(tuple(feats), 1)))
        fused = self.conv9(torch.cat(tuple(feats), 1))
        return fused.mul(0.2) + x
|
|
|
|
|
def ConvBlock(in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, valid_padding=True, padding=0,
              act_type='relu', norm_type='bn', pad_type='zero', mode='CNA'):
    """Assemble a convolution with optional padding, norm and activation.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, bias:
            Forwarded to ``nn.Conv2d``.
        valid_padding: When true, ``padding`` is replaced by
            ``get_valid_padding(kernel_size, dilation)``.
        padding: Padding used when ``valid_padding`` is false.
        act_type: Name handed to ``activation``; falsy disables it.
        norm_type: Name handed to ``norm``; falsy disables it.
        pad_type: ``'zero'`` (or falsy) pads inside the convolution; any
            other name builds an explicit padding layer through ``pad``.
        mode: ``'CNA'`` orders layers pad, conv, norm, act; ``'NAC'`` orders
            them norm, act, pad, conv.

    Returns:
        The result of ``sequential`` applied to the selected layers.

    Raises:
        AssertionError: if ``mode`` is neither ``'CNA'`` nor ``'NAC'``.
    """
    # The message previously referenced ``sys`` without importing it, so a
    # bad mode surfaced as NameError instead of this assertion.
    assert (mode in ['CNA', 'NAC']), '[ERROR] Wrong mode in [%s]!' % __name__

    if valid_padding:
        padding = get_valid_padding(kernel_size, dilation)

    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    # When an explicit padding layer is used the convolution must not pad
    # again, otherwise the feature map grows by ``2 * padding`` per side pair.
    conv_padding = padding if p is None else 0
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=conv_padding, dilation=dilation,
                     bias=bias)

    if mode == 'CNA':
        act = activation(act_type) if act_type else None
        n = norm(out_channels, norm_type) if norm_type else None
        return sequential(p, conv, n, act)
    elif mode == 'NAC':
        # Non-inplace activation: it runs on the block input in this ordering.
        act = activation(act_type, inplace=False) if act_type else None
        n = norm(in_channels, norm_type) if norm_type else None
        return sequential(n, act, p, conv)
|
|
|
|
|
def DeconvBlock(in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, padding=0,
                act_type='relu', norm_type='bn', pad_type='zero', mode='CNA'):
    """Assemble a transposed convolution with optional padding, norm and activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        bias: Forwarded to ``nn.ConvTranspose2d``.
        act_type: Name handed to ``activation``; falsy disables it.
        norm_type: Name handed to ``norm``; falsy disables it.
        pad_type: Any name other than ``'zero'`` (and not falsy) also builds
            an explicit padding layer through ``pad``.
        mode: ``'CNA'`` orders layers pad, deconv, norm, act; ``'NAC'``
            orders them norm, act, pad, deconv.

    Returns:
        The result of ``sequential`` applied to the selected layers.

    Raises:
        AssertionError: if ``mode`` is neither ``'CNA'`` nor ``'NAC'``.
    """
    # The message previously referenced ``sys`` without importing it, so a
    # bad mode surfaced as NameError instead of this assertion.
    assert (mode in ['CNA', 'NAC']), '[ERROR] Wrong mode in [%s]!' % __name__

    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, bias=bias)

    if mode == 'CNA':
        act = activation(act_type) if act_type else None
        n = norm(out_channels, norm_type) if norm_type else None
        return sequential(p, deconv, n, act)
    elif mode == 'NAC':
        # Non-inplace activation: it runs on the block input in this ordering.
        act = activation(act_type, inplace=False) if act_type else None
        n = norm(in_channels, norm_type) if norm_type else None
        return sequential(n, act, p, deconv)
|
|
|
|
|
def get_valid_padding(kernel_size, dilation):
    """Return the padding that keeps the feature size for a dilated kernel.

    The kernel is first widened to its dilated extent,
    ``kernel_size + (kernel_size - 1) * (dilation - 1)``, and half of that
    extent minus one (floored) is returned.
    """
    dilated_extent = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (dilated_extent - 1) // 2
|
|
|
|
|
def pad(pad_type, padding):
    """Build an explicit padding layer, or ``None`` when ``padding`` is 0.

    ``pad_type`` is matched case-insensitively against ``'reflect'`` and
    ``'replicate'``.

    Raises:
        NotImplementedError: for any other padding name when ``padding`` is
            non-zero.
    """
    kind = pad_type.lower()
    if padding == 0:
        return None

    factories = {
        'reflect': nn.ReflectionPad2d,
        'replicate': nn.ReplicationPad2d,
    }
    if kind not in factories:
        raise NotImplementedError('[ERROR] Padding layer [%s] is not implemented!' % kind)
    return factories[kind](padding)
|
|
|
|
|
def activation(act_type='relu', inplace=True, slope=0.2, n_prelu=1):
    """Build an activation layer from its (case-insensitive) name.

    ``'relu'`` gives ``nn.ReLU(inplace)``, ``'lrelu'`` gives
    ``nn.LeakyReLU(slope, inplace)`` and ``'prelu'`` gives
    ``nn.PReLU(num_parameters=n_prelu, init=slope)``.

    Raises:
        NotImplementedError: for any other name.
    """
    kind = act_type.lower()
    if kind == 'relu':
        return nn.ReLU(inplace)
    if kind == 'lrelu':
        return nn.LeakyReLU(slope, inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=slope)
    raise NotImplementedError('[ERROR] Activation layer [%s] is not implemented!' % kind)
|
|
|
|
|
def norm(n_feature, norm_type='bn'):
    """Build a normalization layer; only ``'bn'`` (case-insensitive) exists.

    Raises:
        NotImplementedError: for any other name.
    """
    kind = norm_type.lower()
    if kind != 'bn':
        raise NotImplementedError('[ERROR] Normalization layer [%s] is not implemented!' % kind)
    return nn.BatchNorm2d(n_feature)
|
|
|
|
|
def sequential(*args):
    """Chain modules into one ``nn.Sequential``, flattening and skipping gaps.

    A single argument is returned unchanged.  With several arguments, nested
    ``nn.Sequential`` containers are unpacked into their children, other
    ``nn.Module`` instances are kept, and everything else (for example
    ``None`` placeholders) is dropped.

    Raises:
        NotImplementedError: if the single argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            # The message used to reference ``sys`` without importing it,
            # which raised NameError instead of the intended error.
            raise NotImplementedError('[ERROR] %s.sequential() does not support OrderedDict' % __name__)
        return args[0]

    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            modules.extend(module)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)
|
|
|
class FeedbackBlock(nn.Module):
    """Recurrent block that fuses its input with its previous output.

    Each call concatenates the input with ``last_hidden`` (the previous
    output, or a copy of the input right after a reset), squeezes the result
    back to ``num_features`` channels, runs four chained ``FPN`` modules and
    fuses their outputs with a 1x1 convolution.

    Args:
        num_features: Channel count of the input and the output.
        num_groups, upscale_factor, act_type, norm_type: Accepted for
            signature compatibility; not used by this block.
    """

    def __init__(self, num_features, num_groups, upscale_factor, act_type, norm_type):
        super(FeedbackBlock, self).__init__()

        self.fpn1 = FPN(num_features)
        self.fpn2 = FPN(num_features)
        self.fpn3 = FPN(num_features)
        self.fpn4 = FPN(num_features)
        # 1x1 convolutions: merge [input, hidden] and merge the four FPN outputs.
        self.compress_in = nn.Conv2d(2 * num_features, num_features, 1, padding=0, stride=1)
        self.compress_out = nn.Conv2d(4 * num_features, num_features, 1, padding=0, stride=1)

        # Start in the reset state so forward() works even if reset_state()
        # was never called explicitly.
        self.should_reset = True
        self.last_hidden = None

    def forward(self, x):
        """Run one feedback step on ``x`` and remember the output."""
        if self.should_reset:
            # Seed the hidden state with the input, on the input's own device
            # and dtype (previously forced onto CUDA).
            self.last_hidden = torch.zeros_like(x)
            self.last_hidden.copy_(x)
            self.should_reset = False

        x = torch.cat((x, self.last_hidden), dim=1)
        x = self.compress_in(x)

        fpn1 = self.fpn1(x)
        fpn2 = self.fpn2(fpn1)
        fpn3 = self.fpn3(fpn2)
        fpn4 = self.fpn4(fpn3)
        output = torch.cat((fpn1, fpn2, fpn3, fpn4), dim=1)
        output = self.compress_out(output)

        self.last_hidden = output

        return output

    def reset_state(self):
        """Make the next forward() call re-seed the hidden state from its input."""
        self.should_reset = True
|
|
|
|
|
class metafpn(nn.Module):
    """Feedback network whose upscaling filters are predicted by ``Pos2Weight``.

    The input is embedded by two ``ConvBlock`` layers, then a single
    ``FeedbackBlock`` is run ``num_steps`` times.  After every step the
    features are upscaled with position-dependent 3x3 filters from ``P2W``
    and added to a bilinear upsampling of the input, giving one output per
    step.

    Args:
        RDNkSize: Kept for interface compatibility.  ``make_model`` passes
            the run options object in this position; a non-integer value is
            therefore stored as ``self.args`` so ``set_scale`` can use it.
        G0: Number of feature channels.
        n_colors: Number of input channels.
        act_type: Activation name forwarded to ``ConvBlock``.
        norm_type: Normalization name forwarded to ``ConvBlock``.
    """

    def __init__(self,
                 RDNkSize=3,
                 G0=64,
                 n_colors=3,
                 act_type='prelu',
                 norm_type=None
                 ):
        super(metafpn, self).__init__()

        # ``set_scale`` reads ``self.args.scale``; keep the options object
        # when one was handed in, otherwise leave it unset until assigned.
        self.args = None if isinstance(RDNkSize, int) else RDNkSize

        self.num_steps = 4
        self.num_features = G0
        self.scale_idx = 0
        self.scale = 1
        in_channels = n_colors
        num_groups = 6

        self.conv_in = ConvBlock(in_channels, 4 * self.num_features,
                                 kernel_size=3,
                                 act_type=act_type, norm_type=norm_type)
        self.feat_in = ConvBlock(4 * self.num_features, self.num_features,
                                 kernel_size=1,
                                 act_type=act_type, norm_type=norm_type)

        self.block = FeedbackBlock(self.num_features, num_groups, self.scale, act_type, norm_type)

        self.P2W = Pos2Weight(inC=self.num_features)

    def repeat_x(self, x):
        """Tile ``x`` ``ceil(scale) ** 2`` times along the batch dimension.

        An (N, C, H, W) tensor becomes (N * s * s, C, H, W) with
        ``s = ceil(self.scale)``.
        """
        scale_int = math.ceil(self.scale)
        N, C, H, W = x.size()
        x = x.view(N, C, H, 1, W, 1)

        x = torch.cat([x] * scale_int, 3)
        x = torch.cat([x] * scale_int, 5).permute(0, 3, 5, 1, 2, 4)

        return x.contiguous().view(-1, C, H, W)

    def forward(self, x, pos_mat):
        """Return a list with one upscaled output per feedback step.

        Args:
            x: Input batch; treated as (N, C, H, W).
            pos_mat: Position matrix fed to ``P2W`` after being reshaped to
                ``(pos_mat.size(1), -1)``.
                NOTE(review): presumably shaped (1, H*s*W*s, 3) with
                ``s = ceil(self.scale)`` -- confirm against the data loader.
        """
        self._reset_state()

        # Does not change during the call, so compute it once.
        scale_int = math.ceil(self.scale)

        inter_res = nn.functional.interpolate(x, scale_factor=scale_int, mode='bilinear', align_corners=False)

        x = self.conv_in(x)
        x = self.feat_in(x)

        outs = []
        for _ in range(self.num_steps):
            h = self.block(x)

            # One flattened 3x3 filter bank per row of the position matrix.
            local_weight = self.P2W(
                pos_mat.view(pos_mat.size(1), -1))
            up_x = self.repeat_x(h)

            # 3x3 patches around every pixel of the tiled features.
            cols = nn.functional.unfold(up_x, 3, padding=1)

            cols = cols.contiguous().view(cols.size(0) // (scale_int ** 2), scale_int ** 2, cols.size(1), cols.size(2),
                                          1).permute(0, 1, 3, 4, 2).contiguous()

            local_weight = local_weight.contiguous().view(
                x.size(2), scale_int, x.size(3), scale_int, -1, 3).permute(1, 3, 0, 2, 4, 5).contiguous()
            local_weight = local_weight.contiguous().view(scale_int ** 2, x.size(2) * x.size(3), -1, 3)

            # Apply each pixel's predicted filter to its patch.
            out = torch.matmul(cols, local_weight).permute(0, 1, 4, 2, 3)
            out = out.contiguous().view(
                x.size(0), scale_int, scale_int, 3, x.size(2), x.size(3)).permute(0, 3, 4, 1, 5, 2)
            out = out.contiguous().view(x.size(0), 3, scale_int * x.size(2), scale_int * x.size(3))

            h = torch.add(inter_res, out)

            outs.append(h)

        return outs

    def _reset_state(self):
        """Reset the hidden state of the feedback block."""
        self.block.reset_state()

    def set_scale(self, scale_idx):
        """Select the active scale as ``self.args.scale[scale_idx]``.

        Raises:
            AttributeError: if no options object has been attached as
                ``self.args``.
        """
        if self.args is None:
            raise AttributeError(
                'metafpn.set_scale() needs the run options in self.args, but none were provided')
        self.scale_idx = scale_idx
        self.scale = self.args.scale[scale_idx]