Spaces:
Running
Running
Update archs/arch_util.py
Browse files- archs/arch_util.py +12 -123
archs/arch_util.py
CHANGED
@@ -10,128 +10,6 @@ except:
|
|
10 |
from nafnet_utils.arch_util import LayerNorm2d
|
11 |
from nafnet_utils.arch_model import SimpleGate
|
12 |
|
13 |
-
'''
|
14 |
-
https://github.com/wangchx67/FourLLIE.git
|
15 |
-
'''
|
16 |
-
|
17 |
-
def initialize_weights(net_l, scale=1):
|
18 |
-
if not isinstance(net_l, list):
|
19 |
-
net_l = [net_l]
|
20 |
-
for net in net_l:
|
21 |
-
for m in net.modules():
|
22 |
-
if isinstance(m, nn.Conv2d):
|
23 |
-
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
24 |
-
m.weight.data *= scale # for residual block
|
25 |
-
if m.bias is not None:
|
26 |
-
m.bias.data.zero_()
|
27 |
-
elif isinstance(m, nn.Linear):
|
28 |
-
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
29 |
-
m.weight.data *= scale
|
30 |
-
if m.bias is not None:
|
31 |
-
m.bias.data.zero_()
|
32 |
-
elif isinstance(m, nn.BatchNorm2d):
|
33 |
-
init.constant_(m.weight, 1)
|
34 |
-
init.constant_(m.bias.data, 0.0)
|
35 |
-
|
36 |
-
|
37 |
-
def make_layer(block, n_layers):
|
38 |
-
layers = []
|
39 |
-
for _ in range(n_layers):
|
40 |
-
layers.append(block())
|
41 |
-
return nn.Sequential(*layers)
|
42 |
-
|
43 |
-
|
44 |
-
class ResidualBlock_noBN(nn.Module):
|
45 |
-
'''Residual block w/o BN
|
46 |
-
---Conv-ReLU-Conv-+-
|
47 |
-
|________________|
|
48 |
-
'''
|
49 |
-
|
50 |
-
def __init__(self, nf=64):
|
51 |
-
super(ResidualBlock_noBN, self).__init__()
|
52 |
-
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
53 |
-
self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
54 |
-
|
55 |
-
# initialization
|
56 |
-
initialize_weights([self.conv1, self.conv2], 0.1)
|
57 |
-
|
58 |
-
def forward(self, x):
|
59 |
-
identity = x
|
60 |
-
out = F.relu(self.conv1(x), inplace=True)
|
61 |
-
out = self.conv2(out)
|
62 |
-
return identity + out
|
63 |
-
|
64 |
-
class SpaBlock(nn.Module):
|
65 |
-
def __init__(self, nc):
|
66 |
-
super(SpaBlock, self).__init__()
|
67 |
-
self.block = nn.Sequential(
|
68 |
-
nn.Conv2d(nc,nc,3,1,1),
|
69 |
-
nn.LeakyReLU(0.1,inplace=True),
|
70 |
-
nn.Conv2d(nc, nc, 3, 1, 1),
|
71 |
-
nn.LeakyReLU(0.1, inplace=True))
|
72 |
-
|
73 |
-
def forward(self, x):
|
74 |
-
return x+self.block(x)
|
75 |
-
|
76 |
-
class FreBlock(nn.Module):
|
77 |
-
def __init__(self, nc):
|
78 |
-
super(FreBlock, self).__init__()
|
79 |
-
self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
|
80 |
-
self.process1 = nn.Sequential(
|
81 |
-
nn.Conv2d(nc, nc, 1, 1, 0),
|
82 |
-
nn.LeakyReLU(0.1, inplace=True),
|
83 |
-
nn.Conv2d(nc, nc, 1, 1, 0))
|
84 |
-
self.process2 = nn.Sequential(
|
85 |
-
nn.Conv2d(nc, nc, 1, 1, 0),
|
86 |
-
nn.LeakyReLU(0.1, inplace=True),
|
87 |
-
nn.Conv2d(nc, nc, 1, 1, 0))
|
88 |
-
|
89 |
-
def forward(self, x):
|
90 |
-
_, _, H, W = x.shape
|
91 |
-
x_freq = torch.fft.rfft2(self.fpre(x), norm='backward')
|
92 |
-
mag = torch.abs(x_freq)
|
93 |
-
pha = torch.angle(x_freq)
|
94 |
-
mag = self.process1(mag)
|
95 |
-
pha = self.process2(pha)
|
96 |
-
real = mag * torch.cos(pha)
|
97 |
-
imag = mag * torch.sin(pha)
|
98 |
-
x_out = torch.complex(real, imag)
|
99 |
-
x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
|
100 |
-
|
101 |
-
return x_out+x
|
102 |
-
|
103 |
-
class ProcessBlock(nn.Module):
|
104 |
-
def __init__(self, in_nc, spatial = True):
|
105 |
-
super(ProcessBlock,self).__init__()
|
106 |
-
self.spatial = spatial
|
107 |
-
self.spatial_process = SpaBlock(in_nc) if spatial else nn.Identity()
|
108 |
-
self.frequency_process = FreBlock(in_nc)
|
109 |
-
self.cat = nn.Conv2d(2*in_nc,in_nc,1,1,0) if spatial else nn.Conv2d(in_nc,in_nc,1,1,0)
|
110 |
-
|
111 |
-
def forward(self, x):
|
112 |
-
xori = x
|
113 |
-
x_freq = self.frequency_process(x)
|
114 |
-
x_spatial = self.spatial_process(x)
|
115 |
-
xcat = torch.cat([x_spatial,x_freq],1)
|
116 |
-
x_out = self.cat(xcat) if self.spatial else self.cat(x_freq)
|
117 |
-
|
118 |
-
return x_out+xori
|
119 |
-
|
120 |
-
class Attention_Light(nn.Module):
|
121 |
-
|
122 |
-
def __init__(self, img_channels = 3, width = 16, spatial = False):
|
123 |
-
super(Attention_Light, self).__init__()
|
124 |
-
self.block = nn.Sequential(
|
125 |
-
nn.Conv2d(in_channels = img_channels, out_channels = width//2, kernel_size = 1, padding = 0, stride = 1, groups = 1, bias = True),
|
126 |
-
ProcessBlock(in_nc = width //2, spatial = spatial),
|
127 |
-
nn.Conv2d(in_channels = width//2, out_channels = width, kernel_size = 1, padding = 0, stride = 1, groups = 1, bias = True),
|
128 |
-
ProcessBlock(in_nc = width, spatial = spatial),
|
129 |
-
nn.Conv2d(in_channels = width, out_channels = width, kernel_size = 1, padding = 0, stride = 1, groups = 1, bias = True),
|
130 |
-
ProcessBlock(in_nc=width, spatial = spatial),
|
131 |
-
nn.Sigmoid()
|
132 |
-
)
|
133 |
-
def forward(self, input):
|
134 |
-
return self.block(input)
|
135 |
|
136 |
class Branch(nn.Module):
|
137 |
'''
|
@@ -223,7 +101,18 @@ if __name__ == '__main__':
|
|
223 |
|
224 |
from ptflops import get_model_complexity_info
|
225 |
|
226 |
-
macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
|
|
|
|
|
|
|
227 |
|
|
|
|
|
|
|
228 |
|
|
|
|
|
|
|
|
|
229 |
print(macs, params)
|
|
|
|
10 |
from nafnet_utils.arch_util import LayerNorm2d
|
11 |
from nafnet_utils.arch_model import SimpleGate
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class Branch(nn.Module):
|
15 |
'''
|
|
|
101 |
|
102 |
from ptflops import get_model_complexity_info
|
103 |
|
104 |
+
# macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
|
105 |
+
|
106 |
+
# print('Values of EBlock:')
|
107 |
+
# print(macs, params)
|
108 |
|
109 |
+
channels = 128
|
110 |
+
resol = 32
|
111 |
+
ksize = 5
|
112 |
|
113 |
+
net = FAC(channels=channels, ksize=ksize)
|
114 |
+
inp_shape = (channels, resol, resol)
|
115 |
+
macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
|
116 |
+
print('Values of FAC:')
|
117 |
print(macs, params)
|
118 |
+
|