IncreasingLoss committed on
Commit 5f757e6 · verified · 1 Parent(s): 4e539c0

Update app.py

Files changed (1)
  1. app.py +111 -1
app.py CHANGED
@@ -7,8 +7,118 @@ from pathlib import Path
 from torch.nn import init
 import torchvision.transforms as transforms
 from PIL import Image
 
-# ... [Keep all your existing model definitions and initialization code] ...
+# MobileNetV3 Model Definition (keep this exactly as in your original code)
+class hswish(nn.Module):
+    def forward(self, x):
+        return x * F.relu6(x + 3) / 6
+
+class hsigmoid(nn.Module):
+    def forward(self, x):
+        return F.relu6(x + 3) / 6
+
+class SeModule(nn.Module):
+    def __init__(self, in_size, reduction=4):
+        super().__init__()
+        self.se = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_size, in_size//reduction, 1, bias=False),
+            nn.BatchNorm2d(in_size//reduction),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_size//reduction, in_size, 1, bias=False),
+            nn.BatchNorm2d(in_size),
+            hsigmoid()
+        )
+
+    def forward(self, x):
+        return x * self.se(x)
+
+class Block(nn.Module):
+    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
+        super().__init__()
+        self.stride = stride
+        self.se = semodule
+        self.conv1 = nn.Conv2d(in_size, expand_size, 1, 1, 0, bias=False)
+        self.bn1 = nn.BatchNorm2d(expand_size)
+        self.nolinear1 = nolinear
+        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride, kernel_size//2, groups=expand_size, bias=False)
+        self.bn2 = nn.BatchNorm2d(expand_size)
+        self.nolinear2 = nolinear
+        self.conv3 = nn.Conv2d(expand_size, out_size, 1, 1, 0, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_size)
+        self.shortcut = nn.Sequential()
+        if stride == 1 and in_size != out_size:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_size, out_size, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(out_size),
+            )
+
+    def forward(self, x):
+        out = self.nolinear1(self.bn1(self.conv1(x)))
+        out = self.nolinear2(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        if self.se: out = self.se(out)
+        return out + self.shortcut(x) if self.stride == 1 else out
+
+class MobileNetV3_Small(nn.Module):
+    def __init__(self, num_classes=30):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 16, 3, 2, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(16)
+        self.hs1 = hswish()
+        self.bneck = nn.Sequential(
+            Block(3, 16, 16, 16, nn.ReLU(), SeModule(16), 2),
+            Block(3, 16, 72, 24, nn.ReLU(), None, 2),
+            Block(3, 24, 88, 24, nn.ReLU(), None, 1),
+            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
+            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
+            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
+            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
+            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+        )
+        self.conv2 = nn.Conv2d(96, 576, 1, 1, 0, bias=False)
+        self.bn2 = nn.BatchNorm2d(576)
+        self.hs2 = hswish()
+        self.linear3 = nn.Linear(576, 1280)
+        self.bn3 = nn.BatchNorm1d(1280)
+        self.hs3 = hswish()
+        self.linear4 = nn.Linear(1280, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None: init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                init.normal_(m.weight, std=0.001)
+                if m.bias is not None: init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.hs1(self.bn1(self.conv1(x)))
+        x = self.bneck(x)
+        x = self.hs2(self.bn2(self.conv2(x)))
+        x = F.avg_pool2d(x, x.size()[2:])
+        x = x.view(x.size(0), -1)
+        x = self.hs3(self.bn3(self.linear3(x)))
+        return self.linear4(x)
+
+# Initialize Model
+model = MobileNetV3_Small().cpu()
+model.load_state_dict(torch.load("MobileNet3_small_StateDictionary.pth", map_location='cpu'))
+model.eval()
+
+# Class Labels
+classes = [
+    'antelope', 'buffalo', 'chimpanzee', 'cow', 'deer', 'dolphin',
+    'elephant', 'fox', 'giant+panda', 'giraffe', 'gorilla', 'grizzlybear',
+    'hamster', 'hippopotamus', 'horse', 'humpbackwhale', 'leopard', 'lion',
+    'moose', 'otter', 'ox', 'pig', 'polarbear', 'rabbit', 'rhinoceros',
+    'seal', 'sheep', 'squirrel', 'tiger', 'zebra'
+]
 
 # Precompute example image paths
 example_dir = "examples"
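
As a quick sanity check on the architecture added above, the snippet below (not part of this commit; the smoke test and the name `net` are illustrative only) instantiates MobileNetV3_Small and pushes a dummy batch through it. With a 224x224 input, conv1 plus the four stride-2 bneck blocks downsample by a factor of 32 to a 7x7 map before global average pooling, so the head should emit one logit per class:

import torch

# Hypothetical smoke test -- not part of app.py.
net = MobileNetV3_Small(num_classes=30)
net.eval()  # eval mode: BatchNorm1d in the head rejects batch size 1 during training

with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)   # one RGB image
    logits = net(dummy)

assert logits.shape == (1, 30)  # one logit per animal class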
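
The hunk stops before the Gradio prediction code, so none of the preprocessing is visible in this diff. A minimal sketch of the missing glue, assuming a standard torchvision pipeline (the actual resize and normalization used in training may differ) and reusing the `model` and `classes` defined above; `preprocess` and `predict` are hypothetical names:

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# Assumed preprocessing -- the real transform in app.py may differ.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet statistics (assumption)
                         std=[0.229, 0.224, 0.225]),
])

def predict(image_path):
    # Hypothetical helper: return the top class name and its softmax confidence.
    img = Image.open(image_path).convert("RGB")
    x = preprocess(img).unsqueeze(0)        # add batch dimension
    with torch.no_grad():
        probs = F.softmax(model(x), dim=1)[0]
    conf, idx = probs.max(dim=0)
    return classes[idx.item()], conf.item()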