IncreasingLoss commited on
Commit
dfef520
·
verified ·
1 Parent(s): 82fd976

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -181
app.py CHANGED
@@ -1,182 +1,181 @@
1
- import torch
2
- import gradio as gd
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from torch.nn import init
6
- from PIL import Image
7
- import torchvision.transforms as transforms
8
-
9
- '''MobileNetV3 in PyTorch.
10
-
11
- See the paper "Inverted Residuals and Linear Bottlenecks:
12
- Mobile Networks for Classification, Detection and Segmentation" for more details.
13
- '''
14
-
15
-
16
-
17
- class hswish(nn.Module):
18
- def forward(self, x):
19
- out = x * F.relu6(x + 3, inplace=True) / 6
20
- return out
21
-
22
-
23
- class hsigmoid(nn.Module):
24
- def forward(self, x):
25
- out = F.relu6(x + 3, inplace=True) / 6
26
- return out
27
-
28
-
29
- class SeModule(nn.Module):
30
- def __init__(self, in_size, reduction=4):
31
- super(SeModule, self).__init__()
32
- self.se = nn.Sequential(
33
- nn.AdaptiveAvgPool2d(1),
34
- nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
35
- nn.BatchNorm2d(in_size // reduction),
36
- nn.ReLU(inplace=True),
37
- nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
38
- nn.BatchNorm2d(in_size),
39
- hsigmoid()
40
- )
41
-
42
- def forward(self, x):
43
- return x * self.se(x)
44
-
45
-
46
- class Block(nn.Module):
47
- '''expand + depthwise + pointwise'''
48
- def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
49
- super(Block, self).__init__()
50
- self.stride = stride
51
- self.se = semodule
52
-
53
- self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
54
- self.bn1 = nn.BatchNorm2d(expand_size)
55
- self.nolinear1 = nolinear
56
- self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
57
- self.bn2 = nn.BatchNorm2d(expand_size)
58
- self.nolinear2 = nolinear
59
- self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
60
- self.bn3 = nn.BatchNorm2d(out_size)
61
-
62
- self.shortcut = nn.Sequential()
63
- if stride == 1 and in_size != out_size:
64
- self.shortcut = nn.Sequential(
65
- nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
66
- nn.BatchNorm2d(out_size),
67
- )
68
-
69
- def forward(self, x):
70
- out = self.nolinear1(self.bn1(self.conv1(x)))
71
- out = self.nolinear2(self.bn2(self.conv2(out)))
72
- out = self.bn3(self.conv3(out))
73
- if self.se != None:
74
- out = self.se(out)
75
- out = out + self.shortcut(x) if self.stride==1 else out
76
- return out
77
-
78
- class MobileNetV3_Small(nn.Module):
79
- def __init__(self, num_classes= 30):
80
- super(MobileNetV3_Small, self).__init__()
81
- self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
82
- self.bn1 = nn.BatchNorm2d(16)
83
- self.hs1 = hswish()
84
-
85
- self.bneck = nn.Sequential(
86
- Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
87
- Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
88
- Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
89
- Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
90
- Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
91
- Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
92
- Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
93
- Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
94
- Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
95
- Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
96
- Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
97
- )
98
-
99
-
100
- self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
101
- self.bn2 = nn.BatchNorm2d(576)
102
- self.hs2 = hswish()
103
- self.linear3 = nn.Linear(576, 1280)
104
- self.bn3 = nn.BatchNorm1d(1280)
105
- self.hs3 = hswish()
106
- self.linear4 = nn.Linear(1280, num_classes)
107
- self.init_params()
108
-
109
- def init_params(self):
110
- for m in self.modules():
111
- if isinstance(m, nn.Conv2d):
112
- init.kaiming_normal_(m.weight, mode='fan_out')
113
- if m.bias is not None:
114
- init.constant_(m.bias, 0)
115
- elif isinstance(m, nn.BatchNorm2d):
116
- init.constant_(m.weight, 1)
117
- init.constant_(m.bias, 0)
118
- elif isinstance(m, nn.Linear):
119
- init.normal_(m.weight, std=0.001)
120
- if m.bias is not None:
121
- init.constant_(m.bias, 0)
122
-
123
- def forward(self, x):
124
- out = self.hs1(self.bn1(self.conv1(x)))
125
- out = self.bneck(out)
126
- out = self.hs2(self.bn2(self.conv2(out)))
127
- out = F.avg_pool2d(out, 7)
128
- out = out.view(out.size(0), -1)
129
- out = self.hs3(self.bn3(self.linear3(out)))
130
- out = self.linear4(out)
131
- return out
132
-
133
-
134
-
135
- """creating modelinstance"""
136
- model = MobileNetV3_Small().to("cpu")
137
- model.load_state_dict( torch.load("MobileNet3_small_full.pth"))
138
- classes = ['antelope',
139
- 'buffalo',
140
- 'chimpanzee',
141
- 'cow',
142
- 'deer',
143
- 'dolphin',
144
- 'elephant',
145
- 'fox',
146
- 'giant+panda',
147
- 'giraffe',
148
- 'gorilla',
149
- 'grizzly+bear',
150
- 'hamster',
151
- 'hippopotamus',
152
- 'horse',
153
- 'humpback+whale',
154
- 'leopard',
155
- 'lion',
156
- 'moose',
157
- 'otter',
158
- 'ox',
159
- 'pig',
160
- 'polar+bear',
161
- 'rabbit',
162
- 'rhinoceros',
163
- 'seal',
164
- 'sheep',
165
- 'squirrel',
166
- 'tiger',
167
- 'zebra']
168
-
169
- def predicts(img):
170
- model.eval()
171
- with torch.inference_mode:
172
- logits = model(img.unsqueez(dim=0))
173
- preds = logits.argmax(dim=1)
174
- return classes[preds]
175
-
176
- """gradio inteface"""
177
- demo = gd.Interface(predicts ,gd.Image("Image", width=244, height=244, image_mode="RGB", ))
178
-
179
-
180
- """launch interface"""
181
- if __name__ == "__main__":
182
  demo.launch()
 
1
+ import torch
2
+ import gradio as gd
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import init
6
+ import torchvision.transforms as transforms
7
+
8
+ '''MobileNetV3 in PyTorch.
9
+
10
+ See the paper "Inverted Residuals and Linear Bottlenecks:
11
+ Mobile Networks for Classification, Detection and Segmentation" for more details.
12
+ '''
13
+
14
+
15
+
16
+ class hswish(nn.Module):
17
+ def forward(self, x):
18
+ out = x * F.relu6(x + 3, inplace=True) / 6
19
+ return out
20
+
21
+
22
+ class hsigmoid(nn.Module):
23
+ def forward(self, x):
24
+ out = F.relu6(x + 3, inplace=True) / 6
25
+ return out
26
+
27
+
28
+ class SeModule(nn.Module):
29
+ def __init__(self, in_size, reduction=4):
30
+ super(SeModule, self).__init__()
31
+ self.se = nn.Sequential(
32
+ nn.AdaptiveAvgPool2d(1),
33
+ nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
34
+ nn.BatchNorm2d(in_size // reduction),
35
+ nn.ReLU(inplace=True),
36
+ nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
37
+ nn.BatchNorm2d(in_size),
38
+ hsigmoid()
39
+ )
40
+
41
+ def forward(self, x):
42
+ return x * self.se(x)
43
+
44
+
45
+ class Block(nn.Module):
46
+ '''expand + depthwise + pointwise'''
47
+ def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
48
+ super(Block, self).__init__()
49
+ self.stride = stride
50
+ self.se = semodule
51
+
52
+ self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
53
+ self.bn1 = nn.BatchNorm2d(expand_size)
54
+ self.nolinear1 = nolinear
55
+ self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
56
+ self.bn2 = nn.BatchNorm2d(expand_size)
57
+ self.nolinear2 = nolinear
58
+ self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
59
+ self.bn3 = nn.BatchNorm2d(out_size)
60
+
61
+ self.shortcut = nn.Sequential()
62
+ if stride == 1 and in_size != out_size:
63
+ self.shortcut = nn.Sequential(
64
+ nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
65
+ nn.BatchNorm2d(out_size),
66
+ )
67
+
68
+ def forward(self, x):
69
+ out = self.nolinear1(self.bn1(self.conv1(x)))
70
+ out = self.nolinear2(self.bn2(self.conv2(out)))
71
+ out = self.bn3(self.conv3(out))
72
+ if self.se != None:
73
+ out = self.se(out)
74
+ out = out + self.shortcut(x) if self.stride==1 else out
75
+ return out
76
+
77
+ class MobileNetV3_Small(nn.Module):
78
+ def __init__(self, num_classes= 30):
79
+ super(MobileNetV3_Small, self).__init__()
80
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
81
+ self.bn1 = nn.BatchNorm2d(16)
82
+ self.hs1 = hswish()
83
+
84
+ self.bneck = nn.Sequential(
85
+ Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
86
+ Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
87
+ Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
88
+ Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
89
+ Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
90
+ Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
91
+ Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
92
+ Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
93
+ Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
94
+ Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
95
+ Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
96
+ )
97
+
98
+
99
+ self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
100
+ self.bn2 = nn.BatchNorm2d(576)
101
+ self.hs2 = hswish()
102
+ self.linear3 = nn.Linear(576, 1280)
103
+ self.bn3 = nn.BatchNorm1d(1280)
104
+ self.hs3 = hswish()
105
+ self.linear4 = nn.Linear(1280, num_classes)
106
+ self.init_params()
107
+
108
+ def init_params(self):
109
+ for m in self.modules():
110
+ if isinstance(m, nn.Conv2d):
111
+ init.kaiming_normal_(m.weight, mode='fan_out')
112
+ if m.bias is not None:
113
+ init.constant_(m.bias, 0)
114
+ elif isinstance(m, nn.BatchNorm2d):
115
+ init.constant_(m.weight, 1)
116
+ init.constant_(m.bias, 0)
117
+ elif isinstance(m, nn.Linear):
118
+ init.normal_(m.weight, std=0.001)
119
+ if m.bias is not None:
120
+ init.constant_(m.bias, 0)
121
+
122
+ def forward(self, x):
123
+ out = self.hs1(self.bn1(self.conv1(x)))
124
+ out = self.bneck(out)
125
+ out = self.hs2(self.bn2(self.conv2(out)))
126
+ out = F.avg_pool2d(out, 7)
127
+ out = out.view(out.size(0), -1)
128
+ out = self.hs3(self.bn3(self.linear3(out)))
129
+ out = self.linear4(out)
130
+ return out
131
+
132
+
133
+
134
+ """creating modelinstance"""
135
+ model = MobileNetV3_Small().to("cpu")
136
+ model.load_state_dict( torch.load("MobileNet3_small_full.pth"))
137
+ classes = ['antelope',
138
+ 'buffalo',
139
+ 'chimpanzee',
140
+ 'cow',
141
+ 'deer',
142
+ 'dolphin',
143
+ 'elephant',
144
+ 'fox',
145
+ 'giant+panda',
146
+ 'giraffe',
147
+ 'gorilla',
148
+ 'grizzly+bear',
149
+ 'hamster',
150
+ 'hippopotamus',
151
+ 'horse',
152
+ 'humpback+whale',
153
+ 'leopard',
154
+ 'lion',
155
+ 'moose',
156
+ 'otter',
157
+ 'ox',
158
+ 'pig',
159
+ 'polar+bear',
160
+ 'rabbit',
161
+ 'rhinoceros',
162
+ 'seal',
163
+ 'sheep',
164
+ 'squirrel',
165
+ 'tiger',
166
+ 'zebra']
167
+
168
+ def predicts(img):
169
+ model.eval()
170
+ with torch.inference_mode:
171
+ logits = model(img.unsqueez(dim=0))
172
+ preds = logits.argmax(dim=1)
173
+ return classes[preds]
174
+
175
+ """gradio inteface"""
176
+ demo = gd.Interface(predicts ,gd.Image("Image", width=244, height=244, image_mode="RGB", ))
177
+
178
+
179
+ """launch interface"""
180
+ if __name__ == "__main__":
 
181
  demo.launch()