File size: 5,323 Bytes
02c5426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from argparse import Namespace

import torch
import torch.nn as nn
from models import register
import torch.nn.functional as F

def make_model(args, parent=False):
    """Construct the default model exported by this module.

    Args:
        args: Namespace-like config handed straight to ``CNN7``.
        parent: Accepted but not used by this factory.

    Returns:
        A newly constructed ``CNN7`` instance.
    """
    model = CNN7(args)
    return model


@register('LGCNET')
def make_lgcnet(scale_ratio, rgb_range=1):
    """Build an ``LGCNET`` model; registered under the key ``'LGCNET'``.

    The factory deliberately does not reuse the name ``LGCNET``: with a
    shared name, ``return LGCNET(args)`` reads as self-recursion and only
    reaches the class because the later ``class LGCNET`` statement rebinds
    the module-level name. The module attribute ``LGCNET`` still refers to
    the class, and the registry key is unchanged.

    Args:
        scale_ratio: Stored as ``args.scale = [scale_ratio]``.
        rgb_range: Stored as ``args.rgb_range``.

    Returns:
        An ``LGCNET`` instance built with ``args.n_colors = 3``.
    """
    args = Namespace()
    args.scale = [scale_ratio]
    args.n_colors = 3
    args.rgb_range = rgb_range
    return LGCNET(args)


class LGCNET(nn.Module):
    """Residual CNN that refines a bicubic resize of its input.

    ``forward`` resizes the input to ``out_size`` with bicubic interpolation,
    passes it through five 3x3 conv + ReLU layers, fuses the outputs of
    layers 3-5 with a 5x5 conv + ReLU, projects back to ``args.n_colors``
    channels and adds the resized input as a residual.

    Args:
        args: Namespace-like config; only ``args.n_colors`` (number of image
            channels, used for both input and output) is read.
        nfeats: Number of feature channels in the hidden layers.
    """

    def __init__(self, args, nfeats=32):
        super(LGCNET, self).__init__()
        n_colors = args.n_colors
        self.conv1 = nn.Conv2d(n_colors, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv3 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv4 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv5 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        # Fuses the channel-concatenated outputs of conv3, conv4 and conv5.
        self.conv6 = nn.Conv2d(nfeats * 3, nfeats * 2, kernel_size=5, stride=1, padding=2, bias=True)
        # Must emit n_colors channels (not a hard-coded 3) so the residual
        # addition in forward() matches the input's channel count.
        self.conv7 = nn.Conv2d(nfeats * 2, n_colors, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x, out_size):
        """Resize ``x`` to ``out_size`` and add the predicted residual.

        Args:
            x: Input images. Bicubic ``F.interpolate`` requires a 4-D tensor,
                presumably (N, n_colors, H, W).
            out_size: Target spatial size, passed as ``size`` to
                ``F.interpolate``.

        Returns:
            Tensor with the channel count of ``x`` and spatial size
            ``out_size``.
        """
        # align_corners=False is PyTorch's default for bicubic; it is stated
        # explicitly so the resampling convention is visible.
        x = F.interpolate(x, out_size, mode='bicubic', align_corners=False)
        residual = x
        im1 = self.relu(self.conv1(x))
        im2 = self.relu(self.conv2(im1))
        im3 = self.relu(self.conv3(im2))
        im4 = self.relu(self.conv4(im3))
        im5 = self.relu(self.conv5(im4))
        out = self.relu(self.conv6(torch.cat((im3, im4, im5), dim=1)))
        return self.conv7(out) + residual

    def load_state_dict(self, state_dict, strict=False):
        """Copy matching entries of ``state_dict`` into this model in place.

        Entries whose name contains ``'tail'`` may mismatch in shape (a message
        is printed and the entry is skipped). Any other mismatch raises.
        Unlike ``nn.Module.load_state_dict``, ``strict`` defaults to ``False``
        and nothing is returned.

        Args:
            state_dict: Mapping from parameter/buffer name to tensor.
            strict: If True, unexpected keys (except ``'tail'`` ones) and keys
                missing from ``state_dict`` raise ``KeyError``.

        Raises:
            RuntimeError: A non-``'tail'`` entry's shape differs from the
                model's, or copying it fails.
            KeyError: In strict mode, on unexpected or missing keys.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                if strict and 'tail' not in name:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
                continue
            if isinstance(param, nn.Parameter):
                param = param.data
            target = own_state[name]
            try:
                # Tensor.copy_ broadcasts, so a broadcastable-but-different
                # checkpoint tensor would otherwise be silently expanded.
                if param.shape != target.shape:
                    raise RuntimeError('shape mismatch')
                target.copy_(param)
            except Exception as err:
                if 'tail' in name:
                    print('Replace pre-trained upsampler to new one...')
                else:
                    ckpt_size = param.size() if torch.is_tensor(param) else type(param).__name__
                    raise RuntimeError('While copying the parameter named {}, '
                                       'whose dimensions in the model are {} and '
                                       'whose dimensions in the checkpoint are {}.'
                                       .format(name, target.size(), ckpt_size)) from err

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))


class CNN7(nn.Module):
    """Seven-layer residual CNN that preserves spatial size.

    Six 3x3 conv + ReLU layers are followed by a 3x3 projection back to
    ``args.n_colors`` channels. The input is added to that projection as a
    residual. All convolutions use stride 1 and padding 1, so height and
    width are unchanged.

    Args:
        args: Namespace-like config; only ``args.n_colors`` (number of image
            channels, used for both input and output) is read.
        nfeats: Number of feature channels in the hidden layers.
    """

    def __init__(self, args, nfeats=32):
        super(CNN7, self).__init__()
        n_colors = args.n_colors
        self.conv1 = nn.Conv2d(n_colors, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv3 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv4 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv5 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv6 = nn.Conv2d(nfeats, nfeats, kernel_size=3, stride=1, padding=1, bias=True)
        # Must emit n_colors channels (not a hard-coded 3) so the residual
        # addition in forward() matches the input's channel count.
        self.conv7 = nn.Conv2d(nfeats, n_colors, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the residual predicted by the conv stack.

        Args:
            x: Input images, presumably (N, n_colors, H, W) — TODO confirm.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x
        im1 = self.relu(self.conv1(x))
        im2 = self.relu(self.conv2(im1))
        im3 = self.relu(self.conv3(im2))
        im4 = self.relu(self.conv4(im3))
        im5 = self.relu(self.conv5(im4))
        im6 = self.relu(self.conv6(im5))
        return self.conv7(im6) + residual

    def load_state_dict(self, state_dict, strict=False):
        """Copy matching entries of ``state_dict`` into this model in place.

        Entries whose name contains ``'tail'`` may mismatch in shape (a message
        is printed and the entry is skipped). Any other mismatch raises.
        Unlike ``nn.Module.load_state_dict``, ``strict`` defaults to ``False``
        and nothing is returned.

        Args:
            state_dict: Mapping from parameter/buffer name to tensor.
            strict: If True, unexpected keys (except ``'tail'`` ones) and keys
                missing from ``state_dict`` raise ``KeyError``.

        Raises:
            RuntimeError: A non-``'tail'`` entry's shape differs from the
                model's, or copying it fails.
            KeyError: In strict mode, on unexpected or missing keys.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                if strict and 'tail' not in name:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
                continue
            if isinstance(param, nn.Parameter):
                param = param.data
            target = own_state[name]
            try:
                # Tensor.copy_ broadcasts, so a broadcastable-but-different
                # checkpoint tensor would otherwise be silently expanded.
                if param.shape != target.shape:
                    raise RuntimeError('shape mismatch')
                target.copy_(param)
            except Exception as err:
                if 'tail' in name:
                    print('Replace pre-trained upsampler to new one...')
                else:
                    ckpt_size = param.size() if torch.is_tensor(param) else type(param).__name__
                    raise RuntimeError('While copying the parameter named {}, '
                                       'whose dimensions in the model are {} and '
                                       'whose dimensions in the checkpoint are {}.'
                                       .format(name, target.size(), ckpt_size)) from err

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))