Andykool69 committed
Commit d8b538d · verified · 1 Parent(s): a998fbc

Added files via upload

Files changed (3)
  1. model.py +279 -0
  2. requirements.txt +7 -0
  3. resnet.py +96 -0
model.py ADDED
@@ -0,0 +1,279 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+
+ from resnet import Resnet18
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+
+ class ConvBNReLU(nn.Module):
+     def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
+         super(ConvBNReLU, self).__init__()
+         self.conv = nn.Conv2d(in_chan,
+                               out_chan,
+                               kernel_size=ks,
+                               stride=stride,
+                               padding=padding,
+                               bias=False)
+         self.bn = nn.BatchNorm2d(out_chan)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = F.relu(self.bn(x))
+         return x
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
+
+ class BiSeNetOutput(nn.Module):
+     def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
+         super(BiSeNetOutput, self).__init__()
+         self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+         self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.conv_out(x)
+         return x
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class AttentionRefinementModule(nn.Module):
+     def __init__(self, in_chan, out_chan, *args, **kwargs):
+         super(AttentionRefinementModule, self).__init__()
+         self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+         self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+         self.bn_atten = nn.BatchNorm2d(out_chan)
+         self.sigmoid_atten = nn.Sigmoid()
+         self.init_weight()
+
+     def forward(self, x):
+         feat = self.conv(x)
+         atten = F.avg_pool2d(feat, feat.size()[2:])
+         atten = self.conv_atten(atten)
+         atten = self.bn_atten(atten)
+         atten = self.sigmoid_atten(atten)
+         out = torch.mul(feat, atten)
+         return out
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
+
+
+ class ContextPath(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(ContextPath, self).__init__()
+         self.resnet = Resnet18()
+         self.arm16 = AttentionRefinementModule(256, 128)
+         self.arm32 = AttentionRefinementModule(512, 128)
+         self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+         self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+         self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+         self.init_weight()
+
+     def forward(self, x):
+         H0, W0 = x.size()[2:]
+         feat8, feat16, feat32 = self.resnet(x)
+         H8, W8 = feat8.size()[2:]
+         H16, W16 = feat16.size()[2:]
+         H32, W32 = feat32.size()[2:]
+
+         avg = F.avg_pool2d(feat32, feat32.size()[2:])
+         avg = self.conv_avg(avg)
+         avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
+
+         feat32_arm = self.arm32(feat32)
+         feat32_sum = feat32_arm + avg_up
+         feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
+         feat32_up = self.conv_head32(feat32_up)
+
+         feat16_arm = self.arm16(feat16)
+         feat16_sum = feat16_arm + feat32_up
+         feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
+         feat16_up = self.conv_head16(feat16_up)
+
+         return feat8, feat16_up, feat32_up  # x8, x8, x16
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ ### This is not used, since it is replaced with the resnet feature of the same size
+ class SpatialPath(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(SpatialPath, self).__init__()
+         self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
+         self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+         self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+         self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
+         self.init_weight()
+
+     def forward(self, x):
+         feat = self.conv1(x)
+         feat = self.conv2(feat)
+         feat = self.conv3(feat)
+         feat = self.conv_out(feat)
+         return feat
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class FeatureFusionModule(nn.Module):
+     def __init__(self, in_chan, out_chan, *args, **kwargs):
+         super(FeatureFusionModule, self).__init__()
+         self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+         self.conv1 = nn.Conv2d(out_chan,
+                                out_chan//4,
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False)
+         self.conv2 = nn.Conv2d(out_chan//4,
+                                out_chan,
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False)
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+         self.init_weight()
+
+     def forward(self, fsp, fcp):
+         fcat = torch.cat([fsp, fcp], dim=1)
+         feat = self.convblk(fcat)
+         atten = F.avg_pool2d(feat, feat.size()[2:])
+         atten = self.conv1(atten)
+         atten = self.relu(atten)
+         atten = self.conv2(atten)
+         atten = self.sigmoid(atten)
+         feat_atten = torch.mul(feat, atten)
+         feat_out = feat_atten + feat
+         return feat_out
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class BiSeNet(nn.Module):
+     def __init__(self, n_classes, *args, **kwargs):
+         super(BiSeNet, self).__init__()
+         self.cp = ContextPath()
+         ## the spatial path (self.sp) is dropped here
+         self.ffm = FeatureFusionModule(256, 256)
+         self.conv_out = BiSeNetOutput(256, 256, n_classes)
+         self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
+         self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
+         self.init_weight()
+
+     def forward(self, x):
+         H, W = x.size()[2:]
+         feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # ContextPath also returns the res3b1 feature
+         feat_sp = feat_res8  # use the res3b1 feature in place of the spatial path feature
+         feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+         feat_out = self.conv_out(feat_fuse)
+         feat_out16 = self.conv_out16(feat_cp8)
+         feat_out32 = self.conv_out32(feat_cp16)
+
+         feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
+         feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
+         feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
+         return feat_out, feat_out16, feat_out32
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
+         for name, child in self.named_children():
+             child_wd_params, child_nowd_params = child.get_params()
+             if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
+                 lr_mul_wd_params += child_wd_params
+                 lr_mul_nowd_params += child_nowd_params
+             else:
+                 wd_params += child_wd_params
+                 nowd_params += child_nowd_params
+         return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
+
+
+ if __name__ == "__main__":
+     net = BiSeNet(19)
+     net.cuda()
+     net.eval()
+     in_ten = torch.randn(16, 3, 640, 480).cuda()
+     out, out16, out32 = net(in_ten)
+     print(out.shape)
+
+     net.get_params()
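
BiSeNet.get_params() above splits the parameters into four groups: conv/linear weights (with weight decay), biases and BatchNorm parameters (without), and the same two groups again for the FeatureFusionModule and BiSeNetOutput heads, which are intended to be trained with a larger learning rate. A minimal sketch of how these groups might be handed to an optimizer is shown below; the base learning rate, the weight-decay value, and the 10x multiplier on the head groups are illustrative assumptions, not values taken from this commit.

import torch
from model import BiSeNet

# Sketch only: the hyperparameters below are assumed, not from this repository.
net = BiSeNet(n_classes=19)  # note: building ContextPath downloads the ResNet-18 weights
wd, nowd, lr_mul_wd, lr_mul_nowd = net.get_params()

base_lr = 1e-2  # assumed base learning rate
optimizer = torch.optim.SGD(
    [
        {'params': wd, 'weight_decay': 5e-4},                             # conv/linear weights
        {'params': nowd, 'weight_decay': 0},                              # biases and BatchNorm
        {'params': lr_mul_wd, 'weight_decay': 5e-4, 'lr': base_lr * 10},  # FFM / output heads
        {'params': lr_mul_nowd, 'weight_decay': 0, 'lr': base_lr * 10},
    ],
    lr=base_lr,
    momentum=0.9,
)
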
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ numpy
+ opencv-python
+ pillow
+ streamlit
+ mediapipe
resnet.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.model_zoo as modelzoo
+
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     def __init__(self, in_chan, out_chan, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(in_chan, out_chan, stride)
+         self.bn1 = nn.BatchNorm2d(out_chan)
+         self.conv2 = conv3x3(out_chan, out_chan)
+         self.bn2 = nn.BatchNorm2d(out_chan)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = None
+         if in_chan != out_chan or stride != 1:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_chan, out_chan,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_chan),
+             )
+
+     def forward(self, x):
+         residual = self.conv1(x)
+         residual = F.relu(self.bn1(residual))
+         residual = self.conv2(residual)
+         residual = self.bn2(residual)
+
+         shortcut = x
+         if self.downsample is not None:
+             shortcut = self.downsample(x)
+
+         out = shortcut + residual
+         out = self.relu(out)
+         return out
+
+
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+     layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+     for i in range(bnum - 1):
+         layers.append(BasicBlock(out_chan, out_chan, stride=1))
+     return nn.Sequential(*layers)
+
+
+ class Resnet18(nn.Module):
+     def __init__(self):
+         super(Resnet18, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+         self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+         self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+         self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = F.relu(self.bn1(x))
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         feat8 = self.layer2(x)  # 1/8
+         feat16 = self.layer3(feat8)  # 1/16
+         feat32 = self.layer4(feat16)  # 1/32
+         return feat8, feat16, feat32
+
+     def init_weight(self):
+         state_dict = modelzoo.load_url(resnet18_url)
+         self_state_dict = self.state_dict()
+         for k, v in state_dict.items():
+             if 'fc' in k: continue
+             self_state_dict.update({k: v})
+         self.load_state_dict(self_state_dict)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
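
As a quick sanity check of the backbone's output strides, the sketch below (not part of the commit) pushes a dummy tensor through Resnet18; the 224x224 input size is an arbitrary choice. Note that init_weight() downloads the ImageNet checkpoint from resnet18_url via model_zoo, so constructing the model requires network access.

import torch
from resnet import Resnet18

net = Resnet18()  # loads the pretrained ResNet-18 weights, dropping the fc layer
net.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feat8, feat16, feat32 = net(x)

print(feat8.shape)   # torch.Size([1, 128, 28, 28]) -> 1/8 of the input resolution
print(feat16.shape)  # torch.Size([1, 256, 14, 14]) -> 1/16
print(feat32.shape)  # torch.Size([1, 512, 7, 7])   -> 1/32
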