# NOTE: the three lines below are residue from a web UI (committer name,
# commit message, commit hash) and are not Python; kept as comments so the
# file parses.
# KyanChen's picture
# add
# 02c5426
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.nn.utils import weight_norm
import time
from os.path import exists
import os
from . import upsampler
from .dynamic_layers import ScaleAwareDynamicConv2d
from easydict import EasyDict
import models
from models import register
def spatial_fold(input, fold):
    """Space-to-channel rearrangement.

    Moves each ``fold`` x ``fold`` spatial tile into the channel dimension,
    shrinking height/width by ``fold`` and multiplying channels by
    ``fold ** 2``. Inverse of :func:`spatial_unfold`. ``fold == 1`` is a
    no-op.
    """
    if fold == 1:
        return input
    b, c, h, w = input.shape
    h_out, w_out = h // fold, w // fold
    folded = input.view(b, c, h_out, fold, w_out, fold)
    folded = folded.permute(0, 1, 3, 5, 2, 4)
    return folded.reshape(b, -1, h_out, w_out)
def spatial_unfold(input, unfold):
    """Channel-to-space rearrangement.

    Redistributes groups of ``unfold ** 2`` channels back into
    ``unfold`` x ``unfold`` spatial tiles, growing height/width by
    ``unfold``. Inverse of :func:`spatial_fold`. ``unfold == 1`` is a no-op.
    """
    if unfold == 1:
        return input
    b, c, h, w = input.shape
    expanded = input.view(b, -1, unfold, unfold, h, w)
    expanded = expanded.permute(0, 1, 4, 2, 5, 3)
    return expanded.reshape(b, -1, h * unfold, w * unfold)
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Return a stride-1 ``nn.Conv2d`` padded so odd kernels keep spatial size."""
    padding = kernel_size // 2
    return nn.Conv2d(
        in_channels, out_channels, kernel_size, padding=padding, bias=bias
    )
class WeightNormedConv(nn.Sequential):
    """Weight-normalized Conv2d, optionally followed by an activation.

    Padding is ``kernel_size // 2`` so odd kernels preserve spatial size.
    Pass ``act=None`` (or any falsy value) to build the conv alone.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        bias=True,
        act=nn.ReLU(True),
    ):
        layers = [
            weight_norm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                    stride=stride,
                    bias=bias,
                )
            )
        ]
        if act:
            layers.append(act)
        super().__init__(*layers)
class MeanShift(nn.Conv2d):
    """Fixed 1x1 conv that shifts per-channel means (and scales by a std).

    With ``sign=-1`` (default) computes ``(x - rgb_range * mean) / std``
    channel-wise; with ``sign=1`` applies the opposite shift. Weights are
    set at construction and frozen.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        if len(rgb_std) != len(rgb_mean):
            # A single scalar std may be given; broadcast it to every channel.
            assert len(rgb_std) == 1
            rgb_std = rgb_std * len(rgb_mean)
        channel = len(rgb_mean)
        super(MeanShift, self).__init__(channel, channel, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Identity kernel divided by std -> per-channel scaling only.
        self.weight.data = torch.eye(channel).view(channel, channel, 1, 1)
        self.weight.data.div_(std.view(channel, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Bug fix: the original wrote `self.requires_grad = False`, which only
        # sets a plain attribute on the Module and does NOT freeze anything.
        # Freeze the parameters explicitly so the shift stays constant.
        for p in self.parameters():
            p.requires_grad = False
class BasicBlock(nn.Sequential):
    """Conv2d -> optional BatchNorm -> optional activation.

    Padding keeps spatial size for odd kernels. The conv defaults to
    ``bias=False`` because a BatchNorm (enabled by default) follows it.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        bias=False,
        bn=True,
        act=nn.ReLU(True),
    ):
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=kernel_size // 2,
                stride=stride,
                bias=bias,
            )
        ]
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            layers.append(act)
        super(BasicBlock, self).__init__(*layers)
class ResBlock(nn.Module):
    """EDSR-style residual block: conv -> act -> conv with a scaled skip.

    ``conv`` is a factory called as ``conv(n_feats, n_feats, kernel_size,
    bias=...)``. BatchNorm layers are inserted after each conv when ``bn``
    is set; the activation sits only between the two convs.
    """

    def __init__(
        self,
        conv,
        n_feats,
        kernel_size,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super(ResBlock, self).__init__()
        layers = []
        for i in range(2):
            layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                layers.append(act)  # activate only after the first conv
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.body(x) * self.res_scale
def channel_shuffle(x, groups):
    """ShuffleNet-style channel shuffle: interleave channels across groups."""
    b, c, h, w = x.data.size()
    # split channels into (groups, channels_per_group), swap, and re-flatten
    shuffled = x.view(b, groups, c // groups, h, w)
    shuffled = shuffled.transpose(1, 2).contiguous()
    return shuffled.view(b, -1, h, w)
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    For each dimension of ``shape``, cell centers are spread uniformly inside
    ``ranges[i]`` (default ``[-1, 1]``). Returns a ``(*shape, len(shape))``
    tensor, or ``(prod(shape), len(shape))`` when ``flatten`` is True.
    """
    seqs = []
    for i, n in enumerate(shape):
        v0, v1 = (-1, 1) if ranges is None else ranges[i]
        half_cell = (v1 - v0) / (2 * n)  # half the width of one cell
        seqs.append(v0 + half_cell + (2 * half_cell) * torch.arange(n).float())
    ret = torch.stack(torch.meshgrid(*seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate: rescale channels by pooled attention.

    Global-average-pools to one value per channel, runs a two-layer
    bottleneck MLP with a sigmoid, and multiplies the input by the
    resulting per-channel gates in (0, 1).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        n, c = x.shape[:2]
        gates = self.fc(self.avg_pool(x).view(n, c))
        return x * gates.view(n, c, 1, 1).expand_as(x)
class WideConvBlock(nn.Module):
    """Residual wide-activation conv block (WDSR-style) with an SE gate.

    expand (3x3) -> project (3x3) -> conv (kernel_size) -> SEBlock, all in a
    residual connection. ``scale`` is accepted but unused so the forward
    signature matches DynamicWideConvBlock.
    """

    def __init__(self, num_features, kernel_size, width_multiplier=4, reduction=4):
        super().__init__()
        expanded = int(num_features * width_multiplier)
        self.body = nn.Sequential(
            WeightNormedConv(num_features, expanded, 3),
            WeightNormedConv(expanded, num_features, 3, act=None),
            WeightNormedConv(num_features, num_features, kernel_size, act=None),
            SEBlock(num_features, reduction),
        )

    def forward(self, x, scale):
        return x + self.body(x)
class DynamicWideConvBlock(nn.Module):
    """Residual wide-activation block whose last conv is scale-aware dynamic.

    expand -> project (static, weight-normed) -> ScaleAwareDynamicConv2d
    conditioned on ``scale`` -> SEBlock, wrapped in a residual connection.
    """

    def __init__(
        self,
        num_features,
        kernel_size,
        width_multiplier=4,
        dynamic_K=4,
        reduction=4,
    ):
        super().__init__()
        expanded = int(num_features * width_multiplier)
        self.body = nn.Sequential(
            WeightNormedConv(num_features, expanded, kernel_size),
            WeightNormedConv(expanded, num_features, kernel_size, act=None),
        )
        self.d_conv = weight_norm(
            ScaleAwareDynamicConv2d(
                num_features,
                num_features,
                kernel_size,
                padding=kernel_size // 2,
                K=dynamic_K,
            )
        )
        self.se_block = SEBlock(num_features, reduction)

    def forward(self, x, scale):
        out = self.d_conv(self.body(x), scale)
        return x + self.se_block(out)
class LocalDenseGroup(nn.Module):
    """Densely-connected group of wide conv blocks with 1x1 compressors.

    Every block's output is concatenated with all previous features; a 1x1
    WeightNormedConv compresses the concatenation back to ``num_features``
    before each block after the first. Returns the last block's output.
    """

    def __init__(
        self,
        num_features,
        width_multiplier,
        num_layers,
        reduction,
        use_dynamic_conv,
        dynamic_K,
    ):
        super().__init__()
        kernel_size = 3
        self.num_layers = num_layers
        self.use_dynamic_conv = use_dynamic_conv
        if use_dynamic_conv:
            blocks = [
                DynamicWideConvBlock(
                    num_features,
                    kernel_size,
                    width_multiplier=width_multiplier,
                    dynamic_K=dynamic_K,
                    reduction=reduction,
                )
                for _ in range(num_layers)
            ]
        else:
            blocks = [
                WideConvBlock(
                    num_features,
                    kernel_size,
                    width_multiplier=width_multiplier,
                    reduction=reduction,
                )
                for _ in range(num_layers)
            ]
        self.ConvBlockList = nn.ModuleList(blocks)
        self.compressList = nn.ModuleList(
            WeightNormedConv((idx + 1) * num_features, num_features, 1, act=None)
            for idx in range(1, num_layers)
        )

    def forward(self, x, scale):
        concat = x
        out = None
        for idx in range(self.num_layers):
            if idx == 0:
                out = self.ConvBlockList[idx](concat, scale)
            else:
                concat = torch.cat([concat, out], dim=1)
                compressed = self.compressList[idx - 1](concat)
                out = self.ConvBlockList[idx](compressed, scale)
        return out
class FeedbackBlock(nn.Module):
    """Stack of LocalDenseGroups with a recurrent hidden state.

    On each call the previous output (hidden state) is concatenated with the
    input and compressed before the dense groups run; the new output becomes
    the hidden state. Call :meth:`reset_state` before the first step of a new
    sequence so the hidden state is re-seeded from the input.
    """

    def __init__(
        self,
        num_features,
        width_multiplier,
        num_layers,
        num_groups,
        reduction,
        use_dynamic_conv,
        dynamic_K,
    ):
        super().__init__()
        kernel_size = 3
        self.num_groups = num_groups
        self.LDGList = nn.ModuleList(
            LocalDenseGroup(
                num_features,
                width_multiplier,
                num_layers,
                reduction,
                use_dynamic_conv,
                dynamic_K,
            )
            for _ in range(num_groups)
        )
        self.compressList = nn.ModuleList(
            WeightNormedConv((idx + 1) * num_features, num_features, 1, act=None)
            for idx in range(1, num_groups)
        )
        self.compress_in = WeightNormedConv(2 * num_features, num_features, kernel_size)
        self.should_reset = True
        self.last_hidden = None

    def forward(self, x, scale):
        if self.should_reset:
            # Seed the hidden state from x without joining the autograd graph
            # (fresh zeros tensor + in-place copy, as in the original).
            self.last_hidden = torch.zeros(x.size(), device=x.device)
            self.last_hidden.copy_(x)
            self.should_reset = False
        concat = self.compress_in(torch.cat((x, self.last_hidden), 1))
        out = None
        for idx in range(self.num_groups):
            if idx == 0:
                out = self.LDGList[idx](concat, scale)
            else:
                concat = torch.cat([concat, out], dim=1)
                out = self.LDGList[idx](self.compressList[idx - 1](concat), scale)
        self.last_hidden = out
        return out

    def reset_state(self):
        # Next forward() re-seeds last_hidden from its input.
        self.should_reset = True
@register('sadnarc')
class SADN(nn.Module):
    """Scale-Aware Dynamic Network for arbitrary-scale super-resolution.

    Pipeline: head conv -> recurrent FeedbackBlock (run ``levels`` times) ->
    tail conv, plus a long skip from the input; the per-level feature maps are
    upsampled jointly to ``out_size`` by the configured up-layer. The inverse
    scale factor (input height / output height) conditions the dynamic convs.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        num_features=64,
        num_layers=4,
        num_groups=4,
        reduction=4,
        width_multiplier=4,
        interpolate_mode='bilinear',
        levels=4,
        use_dynamic_conv=True,
        dynamic_K=3,
        which_uplayer="UPLayer_MS_WN",
        uplayer_ksize=3,
        rgb_range=1,
        *args,
        **kwargs
    ):
        super().__init__()
        kernel_size = 3
        skip_kernel_size = 5
        self.interpolate_mode = interpolate_mode
        self.levels = levels
        self.head = nn.Sequential(
            WeightNormedConv(in_channels, num_features, kernel_size)
        )
        self.body = FeedbackBlock(
            num_features,
            width_multiplier,
            num_layers,
            num_groups,
            reduction,
            use_dynamic_conv,
            dynamic_K,
        )
        self.tail = nn.Sequential(
            WeightNormedConv(num_features, num_features, kernel_size, act=None)
        )
        self.skip = WeightNormedConv(
            in_channels, num_features, skip_kernel_size, act=None
        )
        uplayer_cls = getattr(upsampler, which_uplayer)
        self.uplayer = uplayer_cls(
            num_features,
            uplayer_ksize,
            out_channels,
            interpolate_mode,
            levels,
        )

    def update_temperature(self):
        # Anneal the attention temperature of every scale-aware dynamic conv.
        for module in self.modules():
            if isinstance(module, ScaleAwareDynamicConv2d):
                module.update_temperature()

    def forward(self, x, out_size):
        self.body.reset_state()
        if isinstance(out_size, int):
            out_size = [out_size, out_size]
        # Inverse scale factor fed to the dynamic convolutions.
        scale = torch.tensor([x.shape[2] / out_size[0]], device=x.device)
        skip = self.skip(x)
        feat = self.head(x)
        h_list = []
        for _ in range(self.levels):
            h = self.tail(self.body(feat, scale)) + skip
            h_list.append(h)
        return self.uplayer(h_list, out_size)
class SADN_vis(nn.Module):
    """Debug/visualization variant of SADN.

    Same architecture as SADN but with MeanShift normalization enabled, and a
    side effect in forward(): the channel-mean of the last-level feature map
    is cropped, min-max normalized and written as a PNG under ``logs/vis``.
    Not intended for production inference.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_features,
        num_layers,
        num_groups,
        reduction,
        width_multiplier,
        interpolate_mode,
        levels,
        use_dynamic_conv,
        dynamic_K,
        which_uplayer,
        uplayer_ksize,
        rgb_range,
        rgb_mean,
        rgb_std,
    ):
        super().__init__()
        kernel_size = 3
        skip_kernel_size = 5
        num_inputs = in_channels
        n_feats = num_features
        self.interpolate_mode = interpolate_mode
        self.levels = levels
        # Normalize on the way in, de-normalize on the way out.
        self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)
        self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)
        self.head = nn.Sequential(
            *[WeightNormedConv(num_inputs, num_features, kernel_size)]
        )
        self.use_dynamic_conv = use_dynamic_conv
        self.body = FeedbackBlock(
            num_features,
            width_multiplier,
            num_layers,
            num_groups,
            reduction,
            use_dynamic_conv,
            dynamic_K,
        )
        self.tail = nn.Sequential(
            *[
                WeightNormedConv(
                    num_features, num_features, kernel_size, act=None
                )
            ]
        )
        self.skip = WeightNormedConv(
            num_inputs, num_features, skip_kernel_size, act=None
        )
        UpLayer = getattr(upsampler, which_uplayer)
        self.uplayer = UpLayer(
            n_feats,
            uplayer_ksize,
            out_channels,
            interpolate_mode,
            levels,
        )

    def update_temperature(self):
        # Anneal the attention temperature of every scale-aware dynamic conv.
        for m in self.modules():
            if isinstance(m, ScaleAwareDynamicConv2d):
                m.update_temperature()

    def forward(self, x, out_size):
        self.body.reset_state()
        if isinstance(out_size, int):
            out_size = [out_size, out_size]
        # Inverse scale factor fed to the dynamic convolutions.
        scale = torch.tensor([x.shape[2] / out_size[0]], device=x.device)
        x = self.sub_mean(x)
        skip = self.skip(x)
        x = self.head(x)
        h_list = []
        for _ in range(self.levels):
            h = self.body(x, scale)
            h = self.tail(h)
            h = h + skip
            h_list.append(h)
        # ---- visualization side effect (debug instrumentation) ----
        vis = torch.mean(h_list[-1], dim=1)
        vis = (vis - vis.min()) / (vis.max() - vis.min())
        # Hard-coded crop of the region of interest for the figure.
        vis = vis[..., 88:217, 32:161]
        print(torch.min(vis), torch.max(vis))
        savepath = "logs/vis"
        filename = "geo_residential_t7.png"
        if self.use_dynamic_conv:
            savepath = os.path.join(savepath, "dy" + filename.replace(".png", ""))
        else:
            savepath = os.path.join(savepath, "wo_dy" + filename.replace(".png", ""))
        # Fix: os.mkdir() raised FileNotFoundError when the "logs/vis" parent
        # did not exist, and the exists() pre-check raced with creation;
        # makedirs(..., exist_ok=True) handles both cases.
        os.makedirs(savepath, exist_ok=True)
        savepath = os.path.join(savepath, "x{0}.png".format(int((1 / scale).item())))
        plt.imsave(savepath, vis.cpu().numpy()[0], cmap="hsv")
        x = self.uplayer(h_list, out_size)
        x = self.add_mean(x)
        return x
@register('edsr-sadn')
class EDSR_MS(nn.Module):
    """EDSR backbone with a multi-scale up-layer tail.

    head conv -> ``n_resblocks`` residual blocks + trailing conv (with a
    global residual connection) -> configurable up-layer that resizes to the
    requested ``out_size``.
    """

    def __init__(
        self,
        n_resblocks=16,
        n_feats=64,
        in_channels=3,
        out_channels=3,
        res_scale=1,
        which_uplayer="UPLayer_MS_WN",
        uplayer_ksize=3,
        interpolate_mode='bilinear',
        levels=4,
        *args,
        **kwargs
    ):
        super().__init__()
        kernel_size = 3
        act = nn.ReLU(True)
        self.head = nn.Sequential(default_conv(in_channels, n_feats, kernel_size))
        body = [
            ResBlock(default_conv, n_feats, kernel_size, act=act, res_scale=res_scale)
            for _ in range(n_resblocks)
        ]
        body.append(default_conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*body)
        uplayer_cls = getattr(upsampler, which_uplayer)
        self.tail = uplayer_cls(
            n_feats,
            uplayer_ksize,
            out_channels,
            interpolate_mode,
            levels,
        )

    def forward(self, x, out_size):
        feat = self.head(x)
        res = self.body(feat) + feat  # global residual
        return self.tail(res, out_size)
class RDB_Conv(nn.Module):
    """One dense layer: conv+ReLU whose output is concatenated to its input.

    Grows the channel count by ``growRate`` while preserving spatial size.
    """

    def __init__(self, inChannels, growRate, kSize=3):
        super(RDB_Conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inChannels, growRate, kSize, padding=(kSize - 1) // 2, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return torch.cat((x, self.conv(x)), 1)
class RDB(nn.Module):
    """Residual Dense Block: stacked dense conv layers + local feature fusion.

    ``nConvLayers`` RDB_Conv layers grow the channels by ``growRate`` each;
    a 1x1 conv (LFF) squeezes the concatenation back to ``growRate0`` and the
    result is added to the block input.
    """

    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super(RDB, self).__init__()
        self.convs = nn.Sequential(
            *[RDB_Conv(growRate0 + c * growRate, growRate) for c in range(nConvLayers)]
        )
        # Local Feature Fusion: compress the dense concatenation back to G0.
        self.LFF = nn.Conv2d(
            growRate0 + nConvLayers * growRate, growRate0, 1, padding=0, stride=1
        )

    def forward(self, x):
        return x + self.LFF(self.convs(x))
class RDN(nn.Module):
    """Residual Dense Network (Zhang et al., CVPR 2018) for fixed-scale SR.

    NOTE(review): an up-sampling net is only built when ``scale`` is 2, 3
    or 4. For any other value ``self.UPNet`` is left undefined here — the
    RDN_MS subclass depends on this (it passes scale=0 and installs its own
    multi-scale UPNet afterwards), so do not "fix" this by raising.
    """
    def __init__(
        self,
        scale,            # integer up-scaling factor; see class note
        num_features,     # G0: base channel width
        num_blocks,       # D: number of RDB blocks
        num_layers,       # C: dense conv layers per RDB
        rgb_range,        # unused while the MeanShift layers remain commented out
        in_channels,
        out_channels,
        rgb_mean=(0.4488, 0.4371, 0.4040),
        rgb_std=(1.0, 1.0, 1.0),
    ):
        super().__init__()
        r = scale
        G0 = num_features
        kSize = 3
        # number of RDB blocks, conv layers, out channels
        self.D, C, G = [num_blocks, num_layers, num_features]
        # self.sub_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std)
        # self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1)
        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(
            in_channels, G0, kSize, padding=(kSize - 1) // 2, stride=1
        )
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1)
        # Redidual dense blocks and dense feature fusion
        self.RDBs = nn.ModuleList()
        for i in range(self.D):
            self.RDBs.append(RDB(growRate0=G0, growRate=G, nConvLayers=C))
        # Global Feature Fusion: fuse all RDB outputs, then refine with a 3x3.
        self.GFF = nn.Sequential(
            *[
                nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
                nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1),
            ]
        )
        # Up-sampling net (pixel-shuffle); x4 is done as two x2 stages.
        if r == 2 or r == 3:
            self.UPNet = nn.Sequential(
                *[
                    nn.Conv2d(G0, G * r * r, kSize, padding=(kSize - 1) // 2, stride=1),
                    nn.PixelShuffle(r),
                    nn.Conv2d(
                        G, out_channels, kSize, padding=(kSize - 1) // 2, stride=1
                    ),
                ]
            )
        elif r == 4:
            self.UPNet = nn.Sequential(
                *[
                    nn.Conv2d(G0, G * 4, kSize, padding=(kSize - 1) // 2, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2d(G, G * 4, kSize, padding=(kSize - 1) // 2, stride=1),
                    nn.PixelShuffle(2),
                    nn.Conv2d(
                        G, out_channels, kSize, padding=(kSize - 1) // 2, stride=1
                    ),
                ]
            )

    def forward(self, x, return_features=False):
        """Run the RDN trunk and up-sampler.

        Returns the SR image, or ``(image, pre-upsample features)`` when
        ``return_features`` is True.
        """
        # x = self.sub_mean(x)
        f__1 = self.SFENet1(x)
        x = self.SFENet2(f__1)
        RDBs_out = []
        for i in range(self.D):
            x = self.RDBs[i](x)
            RDBs_out.append(x)
        x = self.GFF(torch.cat(RDBs_out, 1))
        # Global residual from the first shallow feature map.
        feat = x + f__1
        out = self.UPNet(feat)
        # out = self.add_mean(out)
        if return_features:
            return out, feat
        return out
@register('rdn-sadn')
class RDN_MS(RDN):
    """
    The multi scale version of RDN, and you can specify rgb_mean/rgb_std/rgb_range!
    """

    def __init__(self, **args):
        args = EasyDict(args)
        # Fix: the original assigned these unconditionally, clobbering every
        # caller-supplied value (despite the docstring advertising that they
        # can be specified). Keep the caller's value when present, otherwise
        # fall back to the old defaults. (Assignment via attribute keeps
        # EasyDict's dict/attribute views in sync; dict.setdefault would not.)
        args.num_features = args.get('num_features', 64)
        args.num_blocks = args.get('num_blocks', 16)
        args.num_layers = args.get('num_layers', 8)
        args.rgb_range = args.get('rgb_range', 1)
        args.in_channels = args.get('in_channels', 3)
        args.out_channels = args.get('out_channels', 3)
        args.which_uplayer = args.get('which_uplayer', "UPLayer_MS_V9")
        args.uplayer_ksize = args.get('uplayer_ksize', 3)
        args.width_multiplier = args.get('width_multiplier', 4)
        args.interpolate_mode = args.get('interpolate_mode', 'bilinear')
        args.levels = args.get('levels', 4)
        super().__init__(
            scale=0,  # deliberately unsupported: RDN then builds no UPNet
            num_features=args.num_features,
            num_blocks=args.num_blocks,
            num_layers=args.num_layers,
            rgb_range=args.rgb_range,
            in_channels=args.in_channels,
            out_channels=args.out_channels,
        )
        # Redefine up-sampling net with a multi-scale up-layer.
        UpLayer = getattr(upsampler, args.which_uplayer)
        self.UPNet = UpLayer(
            args.num_features, 3, args.out_channels, args.interpolate_mode, args.levels
        )
        # Kept for the commented-out MeanShift layers below.
        rgb_mean = args.get("rgb_mean", (0.4488, 0.4371, 0.4040))
        rgb_std = args.get("rgb_std", (1.0, 1.0, 1.0))
        rgb_range = args.get("rgb_range")
        # self.sub_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std)
        # self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1)

    def forward(self, x, out_size):
        """Run the RDN trunk, then the multi-scale up-layer to ``out_size``."""
        # x = self.sub_mean(x)
        f__1 = self.SFENet1(x)
        x = self.SFENet2(f__1)
        RDBs_out = []
        for i in range(self.D):
            x = self.RDBs[i](x)
            RDBs_out.append(x)
        x = self.GFF(torch.cat(RDBs_out, 1))
        # Global residual from the first shallow feature map.
        x += f__1
        x = self.UPNet(x, out_size)
        # x = self.add_mean(x)
        return x