IncreasingLoss committed (verified)
Commit f3d9287 · 1 Parent(s): aad9cf7

Create app.py

Files changed (1)
  1. app.py +182 -0
app.py ADDED
@@ -0,0 +1,182 @@
+ import torch
+ import gradio as gd
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+ '''MobileNetV3 in PyTorch.
+
+ See the paper "Searching for MobileNetV3" for more details.
+ '''
+
+
+
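+ # Hard-swish and hard-sigmoid activations used throughout MobileNetV3,
+ # built from ReLU6 as cheap piecewise approximations of swish and sigmoid.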
+ class hswish(nn.Module):
+     def forward(self, x):
+         out = x * F.relu6(x + 3, inplace=True) / 6
+         return out
+
+
+ class hsigmoid(nn.Module):
+     def forward(self, x):
+         out = F.relu6(x + 3, inplace=True) / 6
+         return out
+
+
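+ # Squeeze-and-Excitation module: global average pooling followed by a
+ # two-layer channel bottleneck that produces per-channel gating weights.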
+ class SeModule(nn.Module):
+     def __init__(self, in_size, reduction=4):
+         super(SeModule, self).__init__()
+         self.se = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
+             nn.BatchNorm2d(in_size // reduction),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
+             nn.BatchNorm2d(in_size),
+             hsigmoid()
+         )
+
+     def forward(self, x):
+         return x * self.se(x)
+
+
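+ # Inverted residual block: 1x1 expansion, depthwise convolution, 1x1 projection,
+ # with optional SE attention and a shortcut path when the stride is 1.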
+ class Block(nn.Module):
+     '''expand + depthwise + pointwise'''
+     def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
+         super(Block, self).__init__()
+         self.stride = stride
+         self.se = semodule
+
+         self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
+         self.bn1 = nn.BatchNorm2d(expand_size)
+         self.nolinear1 = nolinear
+         self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
+         self.bn2 = nn.BatchNorm2d(expand_size)
+         self.nolinear2 = nolinear
+         self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
+         self.bn3 = nn.BatchNorm2d(out_size)
+
+         self.shortcut = nn.Sequential()
+         if stride == 1 and in_size != out_size:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
+                 nn.BatchNorm2d(out_size),
+             )
+
+     def forward(self, x):
+         out = self.nolinear1(self.bn1(self.conv1(x)))
+         out = self.nolinear2(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         if self.se is not None:
+             out = self.se(out)
+         out = out + self.shortcut(x) if self.stride == 1 else out
+         return out
+
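+ # MobileNetV3-Small backbone followed by a small classification head
+ # (1x1 conv to 576 channels, global pooling, two linear layers).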
+ class MobileNetV3_Small(nn.Module):
+     def __init__(self, num_classes=30):
+         super(MobileNetV3_Small, self).__init__()
+         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(16)
+         self.hs1 = hswish()
+
+         self.bneck = nn.Sequential(
+             Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
+             Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
+             Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
+             Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
+             Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+             Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+             Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
+             Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
+             Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
+             Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+             Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+         )
+
+         self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
+         self.bn2 = nn.BatchNorm2d(576)
+         self.hs2 = hswish()
+         self.linear3 = nn.Linear(576, 1280)
+         self.bn3 = nn.BatchNorm1d(1280)
+         self.hs3 = hswish()
+         self.linear4 = nn.Linear(1280, num_classes)
+         self.init_params()
+
+     def init_params(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, mode='fan_out')
+                 if m.bias is not None:
+                     init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 init.constant_(m.weight, 1)
+                 init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 init.normal_(m.weight, std=0.001)
+                 if m.bias is not None:
+                     init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         out = self.hs1(self.bn1(self.conv1(x)))
+         out = self.bneck(out)
+         out = self.hs2(self.bn2(self.conv2(out)))
+         out = F.avg_pool2d(out, 7)  # 7x7 feature map -> 1x1 (assumes a 224x224 input)
+         out = out.view(out.size(0), -1)
+         out = self.hs3(self.bn3(self.linear3(out)))
+         out = self.linear4(out)
+         return out
+
+
+
+ """Create the model instance and load the trained weights."""
+ model = MobileNetV3_Small().to("cpu")
+ # map_location keeps the load on the CPU even if the checkpoint was saved from a GPU.
+ model.load_state_dict(torch.load("MobileNet3_small_full.pth", map_location="cpu"))
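+ # The 30 class labels returned by the app; the order is assumed to match the
+ # label indices the checkpoint was trained with.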
+ classes = ['antelope',
+            'buffalo',
+            'chimpanzee',
+            'cow',
+            'deer',
+            'dolphin',
+            'elephant',
+            'fox',
+            'giant+panda',
+            'giraffe',
+            'gorilla',
+            'grizzly+bear',
+            'hamster',
+            'hippopotamus',
+            'horse',
+            'humpback+whale',
+            'leopard',
+            'lion',
+            'moose',
+            'otter',
+            'ox',
+            'pig',
+            'polar+bear',
+            'rabbit',
+            'rhinoceros',
+            'seal',
+            'sheep',
+            'squirrel',
+            'tiger',
+            'zebra']
+
+ # Preprocess the PIL image from Gradio into the tensor format the network expects;
+ # 224x224 is assumed, matching the 7x7 average pool in forward(). No normalization
+ # is applied, mirroring the original script.
+ preprocess = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+ ])
+
+ def predicts(img):
+     model.eval()
+     with torch.inference_mode():
+         logits = model(preprocess(img).unsqueeze(dim=0))
+         preds = logits.argmax(dim=1)
+     return classes[preds.item()]
+
+ """Gradio interface: the uploaded image is passed to predicts() as a PIL image."""
+ demo = gd.Interface(fn=predicts,
+                     inputs=gd.Image(type="pil", label="Image", width=224, height=224, image_mode="RGB"),
+                     outputs=gd.Textbox(label="Prediction"))
+
+
+ """Launch the interface"""
+ if __name__ == "__main__":
+     demo.launch()