# FunSR/models/baselines/RSI_HFAS.py
import math
import sys
from collections import OrderedDict

import torch
import torch.nn as nn
def make_model(args, parent=False):
    # NOTE: metafpn no longer takes an args namespace; it is built with its
    # keyword defaults here (the original passed args positionally, which
    # would have bound it to RDNkSize).
    return metafpn()
class Pos2Weight(nn.Module):
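    """Meta-SR style weight predictor: maps each output pixel's positional
    triple (rel_h, rel_w, 1/scale) to the weights of a kernel_size x
    kernel_size conv kernel with inC input and outC output channels."""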
def __init__(self, inC, kernel_size=3, outC=3):
super(Pos2Weight, self).__init__()
self.inC = inC
self.kernel_size = kernel_size
self.outC = outC
self.meta_block = nn.Sequential(
nn.Linear(3, 256),
nn.ReLU(inplace=True),
nn.Linear(256, 512),
nn.ReLU(inplace=True),
nn.Linear(512, self.kernel_size * self.kernel_size * self.inC * self.outC)
)
def forward(self, x):
output = self.meta_block(x)
return output
class RDB_Conv(nn.Module):
def __init__(self, inChannels, growRate, kSize=3):
super(RDB_Conv, self).__init__()
Cin = inChannels
G = growRate
self.conv = nn.Sequential(*[
nn.Conv2d(Cin, G, kSize, padding=(kSize - 1) // 2, stride=1),
nn.ReLU()
])
def forward(self, x):
out = self.conv(x)
return out
class FPN(nn.Module):
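    """Pyramid of dense conv stages of shrinking depth (4 -> 3 -> 2 -> 1 convs),
    fused by 1x1 compression convs, a concat of the four stage residuals, and a
    global identity skip."""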
def __init__(self, G0, kSize=3):
super(FPN, self).__init__()
kSize1 = 1
self.conv1 = RDB_Conv(G0, G0, kSize)
self.conv2 = RDB_Conv(G0, G0, kSize)
self.conv3 = RDB_Conv(G0, G0, kSize)
self.conv4 = RDB_Conv(G0, G0, kSize)
self.conv5 = RDB_Conv(G0, G0, kSize)
self.conv6 = RDB_Conv(G0, G0, kSize)
self.conv7 = RDB_Conv(G0, G0, kSize)
self.conv8 = RDB_Conv(G0, G0, kSize)
self.conv9 = RDB_Conv(G0, G0, kSize)
self.conv10 = RDB_Conv(G0, G0, kSize)
self.compress_in1 = nn.Conv2d(4 * G0, G0, kSize1, padding=(kSize1 - 1) // 2, stride=1)
self.compress_in2 = nn.Conv2d(3 * G0, G0, kSize1, padding=(kSize1 - 1) // 2, stride=1)
self.compress_in3 = nn.Conv2d(2 * G0, G0, kSize1, padding=(kSize1 - 1) // 2, stride=1)
self.compress_in4 = nn.Conv2d(2 * G0, G0, kSize1, padding=(kSize1 - 1) // 2, stride=1)
self.compress_out = nn.Conv2d(4 * G0, G0, kSize1, padding=(kSize1 - 1) // 2, stride=1)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
x4 = self.conv4(x3)
x11 = x + x4
x5 = torch.cat((x1, x2, x3, x4), dim=1)
x5_res = self.compress_in1(x5)
x5 = self.conv5(x5_res)
x6 = self.conv6(x5)
x7 = self.conv7(x6)
x12 = x5_res + x7
x8 = torch.cat((x5, x6, x7), dim=1)
x8_res = self.compress_in2(x8)
x8 = self.conv8(x8_res)
x9 = self.conv9(x8)
x13 = x8_res + x9
x10 = torch.cat((x8, x9), dim=1)
x10_res = self.compress_in3(x10)
x10 = self.conv10(x10_res)
x14 = x10_res + x10
output = torch.cat((x11, x12, x13, x14), dim=1)
output = self.compress_out(output)
output = output + x
return output
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size // 2), bias=bias)
class MeanShift(nn.Conv2d):
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1)
self.weight.data.div_(std.view(3, 1, 1, 1))
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
self.bias.data.div_(std)
        # Freeze the layer: setting requires_grad on the module itself has no
        # effect, so disable it on the parameters instead.
        for p in self.parameters():
            p.requires_grad = False
class BasicBlock(nn.Sequential):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, bias=False,
bn=True, act=nn.ReLU(True)):
m = [nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size // 2), stride=stride, bias=bias)
]
if bn: m.append(nn.BatchNorm2d(out_channels))
if act is not None: m.append(act)
super(BasicBlock, self).__init__(*m)
class ResBlock(nn.Module):
def __init__(
self, conv, n_feats, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
if bn: m.append(nn.BatchNorm2d(n_feats))
if i == 0: m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x).mul(self.res_scale)
res += x
return res
class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(nn.PixelShuffle(2))
if bn: m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(nn.PixelShuffle(3))
if bn: m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
else:
raise NotImplementedError
super(Upsampler, self).__init__(*m)
class ResidualDenseBlock_8C(nn.Module):
    '''
    Residual Dense Block (8-conv variant).
    Core module of "Residual Dense Network for Image Super-Resolution" (CVPR 2018).
    '''
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', norm_type=None, act_type='relu',
mode='CNA'):
super(ResidualDenseBlock_8C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = ConvBlock(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=act_type, mode=mode)
self.conv2 = ConvBlock(nc + gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=act_type, mode=mode)
self.conv3 = ConvBlock(nc + 2 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=act_type, mode=mode)
self.conv4 = ConvBlock(nc + 3 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=act_type, mode=mode)
self.conv5 = ConvBlock(nc + 4 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=act_type, mode=mode)
self.conv6 = ConvBlock(nc + 5 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=act_type, mode=mode)
self.conv7 = ConvBlock(nc + 6 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=act_type, mode=mode)
self.conv8 = ConvBlock(nc + 7 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=act_type, mode=mode)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv9 = ConvBlock(nc + 8 * gc, nc, 1, stride, bias=bias, pad_type=pad_type, norm_type=norm_type,
act_type=last_act, mode=mode)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
x6 = self.conv6(torch.cat((x, x1, x2, x3, x4, x5), 1))
x7 = self.conv7(torch.cat((x, x1, x2, x3, x4, x5, x6), 1))
x8 = self.conv8(torch.cat((x, x1, x2, x3, x4, x5, x6, x7), 1))
x9 = self.conv9(torch.cat((x, x1, x2, x3, x4, x5, x6, x7, x8), 1))
return x9.mul(0.2) + x
def ConvBlock(in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, valid_padding=True, padding=0, \
act_type='relu', norm_type='bn', pad_type='zero', mode='CNA'):
assert (mode in ['CNA', 'NAC']), '[ERROR] Wrong mode in [%s]!' % sys.modules[__name__]
    if valid_padding:
        padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
bias=bias)
if mode == 'CNA':
act = activation(act_type) if act_type else None
n = norm(out_channels, norm_type) if norm_type else None
return sequential(p, conv, n, act)
elif mode == 'NAC':
act = activation(act_type, inplace=False) if act_type else None
n = norm(in_channels, norm_type) if norm_type else None
return sequential(n, act, p, conv)
def DeconvBlock(in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, padding=0, \
act_type='relu', norm_type='bn', pad_type='zero', mode='CNA'):
assert (mode in ['CNA', 'NAC']), '[ERROR] Wrong mode in [%s]!' % sys.modules[__name__]
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, bias=bias)
if mode == 'CNA':
act = activation(act_type) if act_type else None
n = norm(out_channels, norm_type) if norm_type else None
return sequential(p, deconv, n, act)
elif mode == 'NAC':
act = activation(act_type, inplace=False) if act_type else None
n = norm(in_channels, norm_type) if norm_type else None
return sequential(n, act, p, deconv)
def get_valid_padding(kernel_size, dilation):
"""
Padding value to remain feature size.
"""
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
def pad(pad_type, padding):
pad_type = pad_type.lower()
if padding == 0:
return None
layer = None
if pad_type == 'reflect':
layer = nn.ReflectionPad2d(padding)
elif pad_type == 'replicate':
layer = nn.ReplicationPad2d(padding)
else:
raise NotImplementedError('[ERROR] Padding layer [%s] is not implemented!' % pad_type)
return layer
def activation(act_type='relu', inplace=True, slope=0.2, n_prelu=1):
act_type = act_type.lower()
layer = None
if act_type == 'relu':
layer = nn.ReLU(inplace)
elif act_type == 'lrelu':
layer = nn.LeakyReLU(slope, inplace)
elif act_type == 'prelu':
layer = nn.PReLU(num_parameters=n_prelu, init=slope)
else:
raise NotImplementedError('[ERROR] Activation layer [%s] is not implemented!' % act_type)
return layer
def norm(n_feature, norm_type='bn'):
norm_type = norm_type.lower()
layer = None
if norm_type == 'bn':
layer = nn.BatchNorm2d(n_feature)
else:
raise NotImplementedError('[ERROR] Normalization layer [%s] is not implemented!' % norm_type)
return layer
def sequential(*args):
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError('[ERROR] %s.sequential() does not support OrderedDict' % sys.modules[__name__])
else:
return args[0]
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module:
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
class FeedbackBlock(nn.Module):
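    """Recurrent feedback block: fuses the input with the previous step's
    hidden state via a 1x1 conv, runs four stacked FPNs, and compresses the
    concatenation of their outputs into the new hidden state."""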
def __init__(self, num_features, num_groups, upscale_factor, act_type, norm_type):
super(FeedbackBlock, self).__init__()
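        # NOTE: the stride/padding/kernel_size table below is left over from the
        # original SRFBN-style deconv reconstruction head and is unused here.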
if upscale_factor == 2:
stride = 2
padding = 2
kernel_size = 6
elif upscale_factor == 3:
stride = 3
padding = 2
kernel_size = 7
elif upscale_factor == 4:
stride = 4
padding = 2
kernel_size = 8
elif upscale_factor == 8:
stride = 8
padding = 2
kernel_size = 12
kSize = 3
kSize1 = 1
self.fpn1 = FPN(num_features)
self.fpn2 = FPN(num_features)
self.fpn3 = FPN(num_features)
self.fpn4 = FPN(num_features)
        self.compress_in = nn.Conv2d(2 * num_features, num_features, kSize1, padding=(kSize1 - 1) // 2, stride=1)
        self.compress_out = nn.Conv2d(4 * num_features, num_features, kSize1, padding=(kSize1 - 1) // 2, stride=1)
        # Hidden-state bookkeeping for the feedback connection.
        self.should_reset = True
        self.last_hidden = None
def forward(self, x):
        if self.should_reset:
            # Initialise the hidden state from the current input; zeros_like
            # keeps it on the input's device (the original hard-coded .cuda()).
            self.last_hidden = torch.zeros_like(x)
            self.last_hidden.copy_(x)
            self.should_reset = False
        x = torch.cat((x, self.last_hidden), dim=1)  # concatenate input with hidden state along channels
x = self.compress_in(x)
fpn1 = self.fpn1(x)
fpn2 = self.fpn2(fpn1)
fpn3 = self.fpn3(fpn2)
fpn4 = self.fpn4(fpn3)
output = torch.cat((fpn1, fpn2, fpn3, fpn4), dim=1)
output = self.compress_out(output)
self.last_hidden = output
return output
def reset_state(self):
self.should_reset = True
class metafpn(nn.Module):
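    """Meta-upscale feedback pyramid network (the RSI-HFAS baseline).

    LR features are extracted once, then refined for `num_steps` feedback
    iterations; at every step the features are upscaled with kernels predicted
    from `pos_mat` (Meta-SR style) and added to a bilinear upsampling of the
    input, yielding one prediction per step.
    """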
def __init__(self,
RDNkSize=3,
G0=64,
n_colors=3,
act_type='prelu',
norm_type=None
):
        super(metafpn, self).__init__()  # initialise the inherited nn.Module state
kernel_size = RDNkSize
self.num_steps = 4
self.num_features = G0
self.scale_idx = 0
self.scale = 1
in_channels = n_colors
num_groups = 6
# RGB mean for DIV2K
# rgb_mean = (0.4488, 0.4371, 0.4040)
# rgb_std = (1.0, 1.0, 1.0)
# self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
# self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
# LR feature extraction block
        # 3x3 conv; each output channel is one feature map.
        self.conv_in = ConvBlock(in_channels, 4 * self.num_features,
                                 kernel_size=3,
                                 act_type=act_type, norm_type=norm_type)
self.feat_in = ConvBlock(4 * self.num_features, self.num_features,
kernel_size=1,
act_type=act_type, norm_type=norm_type)
# basic block
self.block = FeedbackBlock(self.num_features, num_groups, self.scale, act_type, norm_type)
# reconstruction block
# uncomment for pytorch 0.4.0
# self.upsample = nn.Upsample(scale_factor=upscale_factor, mode='bilinear')
# self.out = DeconvBlock(num_features, num_features,
# kernel_size=kernel_size, stride=stride, padding=padding,
# act_type='prelu', norm_type=norm_type)
self.P2W = Pos2Weight(inC=self.num_features)
def repeat_x(self, x):
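        """Replicate each LR feature map scale_int**2 times, one copy per HR
        sub-position: (N, C, H, W) -> (N * r * r, C, H, W)."""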
scale_int = math.ceil(self.scale)
N, C, H, W = x.size()
x = x.view(N, C, H, 1, W, 1)
x = torch.cat([x] * scale_int, 3)
x = torch.cat([x] * scale_int, 5).permute(0, 3, 5, 1, 2, 4)
return x.contiguous().view(-1, C, H, W)
def forward(self, x, pos_mat):
self._reset_state()
# x = self.sub_mean(x)
scale_int = math.ceil(self.scale)
# uncomment for pytorch 0.4.0
# inter_res = self.upsample(x)
# comment for pytorch 0.4.0
inter_res = nn.functional.interpolate(x, scale_factor=scale_int, mode='bilinear', align_corners=False)
x = self.conv_in(x)
x = self.feat_in(x)
outs = []
for _ in range(self.num_steps):
h = self.block(x)
            # --- meta-upscale: predict per-position kernels and apply them ---
            # Predict one kernel per output pixel: (outH*outW, outC*inC*kH*kW).
            local_weight = self.P2W(pos_mat.view(pos_mat.size(1), -1))
            # Replicate the LR features r*r times: (N*r*r, inC, inH, inW).
            up_x = self.repeat_x(h)
            # Unfold into sliding 3x3 patches: (N*r*r, inC*kH*kW, inH*inW).
            cols = nn.functional.unfold(up_x, 3, padding=1)
            scale_int = math.ceil(self.scale)
            cols = cols.contiguous().view(cols.size(0) // (scale_int ** 2), scale_int ** 2,
                                          cols.size(1), cols.size(2), 1).permute(0, 1, 3, 4, 2).contiguous()
            local_weight = local_weight.contiguous().view(
                x.size(2), scale_int, x.size(3), scale_int, -1, 3).permute(1, 3, 0, 2, 4, 5).contiguous()
            local_weight = local_weight.contiguous().view(scale_int ** 2, x.size(2) * x.size(3), -1, 3)
            # Batched matmul between each patch and its predicted kernel.
            out = torch.matmul(cols, local_weight).permute(0, 1, 4, 2, 3)
            out = out.contiguous().view(x.size(0), scale_int, scale_int, 3,
                                        x.size(2), x.size(3)).permute(0, 3, 4, 1, 5, 2)
            out = out.contiguous().view(x.size(0), 3, scale_int * x.size(2), scale_int * x.size(3))
h = torch.add(inter_res, out)
# h = self.add_mean(h)
outs.append(h)
        return outs  # one output per feedback step
def _reset_state(self):
self.block.reset_state()
    def set_scale(self, scale_idx):
        self.scale_idx = scale_idx
        # The original read self.args.scale[scale_idx], but no args namespace is
        # stored on this module; only index it if a caller has attached one,
        # otherwise callers should assign self.scale directly.
        if hasattr(self, 'args'):
            self.scale = self.args.scale[scale_idx]
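

if __name__ == '__main__':
    # Minimal CPU smoke test (a sketch, not part of the original file). It
    # assumes Meta-SR's pos_mat convention: one (rel_h, rel_w, 1/scale) triple
    # per output pixel, shaped (1, outH * outW, 3); random values are enough
    # to exercise the shapes.
    r = 2
    model = metafpn()
    model.scale = r  # assign directly; see set_scale above
    lr = torch.rand(1, 3, 16, 16)
    pos_mat = torch.rand(1, (16 * r) * (16 * r), 3)
    outs = model(lr, pos_mat)
    print([tuple(o.shape) for o in outs])  # num_steps x (1, 3, 32, 32)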