danifei committed
Commit 992c084 · verified · 1 Parent(s): 3de3832

Update archs/arch_util.py

Files changed (1)
  1. archs/arch_util.py +12 -123
archs/arch_util.py CHANGED
@@ -10,128 +10,6 @@ except:
 from nafnet_utils.arch_util import LayerNorm2d
 from nafnet_utils.arch_model import SimpleGate
 
-'''
-https://github.com/wangchx67/FourLLIE.git
-'''
-
-def initialize_weights(net_l, scale=1):
-    if not isinstance(net_l, list):
-        net_l = [net_l]
-    for net in net_l:
-        for m in net.modules():
-            if isinstance(m, nn.Conv2d):
-                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
-                m.weight.data *= scale  # for residual block
-                if m.bias is not None:
-                    m.bias.data.zero_()
-            elif isinstance(m, nn.Linear):
-                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
-                m.weight.data *= scale
-                if m.bias is not None:
-                    m.bias.data.zero_()
-            elif isinstance(m, nn.BatchNorm2d):
-                init.constant_(m.weight, 1)
-                init.constant_(m.bias.data, 0.0)
-
-
-def make_layer(block, n_layers):
-    layers = []
-    for _ in range(n_layers):
-        layers.append(block())
-    return nn.Sequential(*layers)
-
-
-class ResidualBlock_noBN(nn.Module):
-    '''Residual block w/o BN
-    ---Conv-ReLU-Conv-+-
-     |________________|
-    '''
-
-    def __init__(self, nf=64):
-        super(ResidualBlock_noBN, self).__init__()
-        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-
-        # initialization
-        initialize_weights([self.conv1, self.conv2], 0.1)
-
-    def forward(self, x):
-        identity = x
-        out = F.relu(self.conv1(x), inplace=True)
-        out = self.conv2(out)
-        return identity + out
-
-class SpaBlock(nn.Module):
-    def __init__(self, nc):
-        super(SpaBlock, self).__init__()
-        self.block = nn.Sequential(
-            nn.Conv2d(nc, nc, 3, 1, 1),
-            nn.LeakyReLU(0.1, inplace=True),
-            nn.Conv2d(nc, nc, 3, 1, 1),
-            nn.LeakyReLU(0.1, inplace=True))
-
-    def forward(self, x):
-        return x + self.block(x)
-
-class FreBlock(nn.Module):
-    def __init__(self, nc):
-        super(FreBlock, self).__init__()
-        self.fpre = nn.Conv2d(nc, nc, 1, 1, 0)
-        self.process1 = nn.Sequential(
-            nn.Conv2d(nc, nc, 1, 1, 0),
-            nn.LeakyReLU(0.1, inplace=True),
-            nn.Conv2d(nc, nc, 1, 1, 0))
-        self.process2 = nn.Sequential(
-            nn.Conv2d(nc, nc, 1, 1, 0),
-            nn.LeakyReLU(0.1, inplace=True),
-            nn.Conv2d(nc, nc, 1, 1, 0))
-
-    def forward(self, x):
-        _, _, H, W = x.shape
-        x_freq = torch.fft.rfft2(self.fpre(x), norm='backward')
-        mag = torch.abs(x_freq)
-        pha = torch.angle(x_freq)
-        mag = self.process1(mag)
-        pha = self.process2(pha)
-        real = mag * torch.cos(pha)
-        imag = mag * torch.sin(pha)
-        x_out = torch.complex(real, imag)
-        x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
-
-        return x_out + x
-
-class ProcessBlock(nn.Module):
-    def __init__(self, in_nc, spatial=True):
-        super(ProcessBlock, self).__init__()
-        self.spatial = spatial
-        self.spatial_process = SpaBlock(in_nc) if spatial else nn.Identity()
-        self.frequency_process = FreBlock(in_nc)
-        self.cat = nn.Conv2d(2 * in_nc, in_nc, 1, 1, 0) if spatial else nn.Conv2d(in_nc, in_nc, 1, 1, 0)
-
-    def forward(self, x):
-        xori = x
-        x_freq = self.frequency_process(x)
-        x_spatial = self.spatial_process(x)
-        xcat = torch.cat([x_spatial, x_freq], 1)
-        x_out = self.cat(xcat) if self.spatial else self.cat(x_freq)
-
-        return x_out + xori
-
-class Attention_Light(nn.Module):
-
-    def __init__(self, img_channels=3, width=16, spatial=False):
-        super(Attention_Light, self).__init__()
-        self.block = nn.Sequential(
-            nn.Conv2d(in_channels=img_channels, out_channels=width // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
-            ProcessBlock(in_nc=width // 2, spatial=spatial),
-            nn.Conv2d(in_channels=width // 2, out_channels=width, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
-            ProcessBlock(in_nc=width, spatial=spatial),
-            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
-            ProcessBlock(in_nc=width, spatial=spatial),
-            nn.Sigmoid()
-        )
-    def forward(self, input):
-        return self.block(input)
 
 class Branch(nn.Module):
     '''
@@ -223,7 +101,18 @@ if __name__ == '__main__':
 
     from ptflops import get_model_complexity_info
 
-    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
+    # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
+
+    # print('Values of EBlock:')
+    # print(macs, params)
 
+    channels = 128
+    resol = 32
+    ksize = 5
 
+    net = FAC(channels=channels, ksize=ksize)
+    inp_shape = (channels, resol, resol)
+    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
+    print('Values of FAC:')
     print(macs, params)
+
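As context for the removed FreBlock: it splits the spectrum of its input into magnitude and phase, processes each with 1x1 convolutions, and reassembles a complex tensor before the inverse FFT. Below is a minimal sketch (plain PyTorch, not part of this commit) of that amplitude/phase round-trip, showing it reconstructs the input almost exactly when the processing branches are identities:

    import torch

    # Amplitude/phase round-trip at the core of the removed FreBlock,
    # with the two processing branches replaced by identities.
    x = torch.randn(1, 4, 32, 32)
    x_freq = torch.fft.rfft2(x, norm='backward')     # real-to-complex 2D FFT
    mag = torch.abs(x_freq)                          # magnitude branch
    pha = torch.angle(x_freq)                        # phase branch
    x_out = torch.complex(mag * torch.cos(pha),      # reassemble the spectrum
                          mag * torch.sin(pha))
    x_rec = torch.fft.irfft2(x_out, s=x.shape[-2:], norm='backward')
    print((x - x_rec).abs().max())                   # ~1e-6, near-exact reconstruction

On the benchmarking side, ptflops' get_model_complexity_info returns the MAC and parameter counts as human-readable strings by default (as_strings=True), which is why the updated __main__ block can print them directly.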