IncreasingLoss committed on
Commit 13685f2 · verified · 1 Parent(s): 7130471

Update app.py

Files changed (1)
  1. app.py +170 -228
app.py CHANGED
@@ -1,228 +1,170 @@
- import torch
- import gradio as gr
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn import init
- import torchvision.transforms as transforms
-
- '''MobileNetV3 in PyTorch.
-
- See the paper "Inverted Residuals and Linear Bottlenecks:
- Mobile Networks for Classification, Detection and Segmentation" for more details.
- '''
-
-
-
- class hswish(nn.Module):
-     def forward(self, x):
-         out = x * F.relu6(x + 3, inplace=True) / 6
-         return out
-
-
- class hsigmoid(nn.Module):
-     def forward(self, x):
-         out = F.relu6(x + 3, inplace=True) / 6
-         return out
-
-
- class SeModule(nn.Module):
-     def __init__(self, in_size, reduction=4):
-         super(SeModule, self).__init__()
-         self.se = nn.Sequential(
-             nn.AdaptiveAvgPool2d(1),
-             nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
-             nn.BatchNorm2d(in_size // reduction),
-             nn.ReLU(inplace=True),
-             nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
-             nn.BatchNorm2d(in_size),
-             hsigmoid()
-         )
-
-     def forward(self, x):
-         return x * self.se(x)
-
-
- class Block(nn.Module):
-     '''expand + depthwise + pointwise'''
-     def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
-         super(Block, self).__init__()
-         self.stride = stride
-         self.se = semodule
-
-         self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
-         self.bn1 = nn.BatchNorm2d(expand_size)
-         self.nolinear1 = nolinear
-         self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
-         self.bn2 = nn.BatchNorm2d(expand_size)
-         self.nolinear2 = nolinear
-         self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
-         self.bn3 = nn.BatchNorm2d(out_size)
-
-         self.shortcut = nn.Sequential()
-         if stride == 1 and in_size != out_size:
-             self.shortcut = nn.Sequential(
-                 nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
-                 nn.BatchNorm2d(out_size),
-             )
-
-     def forward(self, x):
-         out = self.nolinear1(self.bn1(self.conv1(x)))
-         out = self.nolinear2(self.bn2(self.conv2(out)))
-         out = self.bn3(self.conv3(out))
-         if self.se != None:
-             out = self.se(out)
-         out = out + self.shortcut(x) if self.stride==1 else out
-         return out
-
- class MobileNetV3_Small(nn.Module):
-     def __init__(self, num_classes= 30):
-         super(MobileNetV3_Small, self).__init__()
-         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
-         self.bn1 = nn.BatchNorm2d(16)
-         self.hs1 = hswish()
-
-         self.bneck = nn.Sequential(
-             Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
-             Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
-             Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
-             Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
-             Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
-             Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
-             Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
-             Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
-             Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
-             Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
-             Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
-         )
-
-
-         self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
-         self.bn2 = nn.BatchNorm2d(576)
-         self.hs2 = hswish()
-         self.linear3 = nn.Linear(576, 1280)
-         self.bn3 = nn.BatchNorm1d(1280)
-         self.hs3 = hswish()
-         self.linear4 = nn.Linear(1280, num_classes)
-         self.init_params()
-
-     def init_params(self):
-         for m in self.modules():
-             if isinstance(m, nn.Conv2d):
-                 init.kaiming_normal_(m.weight, mode='fan_out')
-                 if m.bias is not None:
-                     init.constant_(m.bias, 0)
-             elif isinstance(m, nn.BatchNorm2d):
-                 init.constant_(m.weight, 1)
-                 init.constant_(m.bias, 0)
-             elif isinstance(m, nn.Linear):
-                 init.normal_(m.weight, std=0.001)
-                 if m.bias is not None:
-                     init.constant_(m.bias, 0)
-
-     def forward(self, x):
-         out = self.hs1(self.bn1(self.conv1(x)))
-         out = self.bneck(out)
-         out = self.hs2(self.bn2(self.conv2(out)))
-         out = F.avg_pool2d(out, 7)
-         out = out.view(out.size(0), -1)
-         out = self.hs3(self.bn3(self.linear3(out)))
-         out = self.linear4(out)
-         return out
-
-
-
- """creating modelinstance"""
- model = MobileNetV3_Small(num_classes=30).to("cpu")
- model.load_state_dict( torch.load("MobileNet3_small_StateDictionary.pth"))
-
- classes = ['antelope',
-            'buffalo',
-            'chimpanzee',
-            'cow',
-            'deer',
-            'dolphin',
-            'elephant',
-            'fox',
-            'giant+panda',
-            'giraffe',
-            'gorilla',
-            'grizzly+bear',
-            'hamster',
-            'hippopotamus',
-            'horse',
-            'humpback+whale',
-            'leopard',
-            'lion',
-            'moose',
-            'otter',
-            'ox',
-            'pig',
-            'polar+bear',
-            'rabbit',
-            'rhinoceros',
-            'seal',
-            'sheep',
-            'squirrel',
-            'tiger',
-            'zebra']
-
-
-
- def predict(img):
-     # Add preprocessing and fix inference mode
-     model.eval()
-
-     # Convert Gradio input to tensor
-     preprocess = transforms.Compose([
-         transforms.Resize(256),
-         transforms.CenterCrop(224),
-         transforms.ToTensor(),
-         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-     ])
-
-     img_tensor = preprocess(img).unsqueeze(0)
-
-     with torch.inference_mode():
-         logits = model(img_tensor)
-         preds = logits.argmax(dim=1)
-     return classes[preds.item()]
-
-
-
- """gradio interface"""
- """demo = gr.Interface(
-     fn=predict,
-     inputs=gr.Image(type="pil", width=244, height=244),
-     outputs="label",
-     title="Animal Classifier",
-     description="Classify 30 animal categories: antelope, buffalo, chimpanzee, cow, deer, dolphin, elephant, fox, giant+panda, giraffe, gorilla, grizzly+bear, hamster, hippopotamus, horse, humpback+whale, leopard, lion, moose, otter, ox, pig, polar+bear, rabbit, rhinoceros, seal, sheep, squirrel, tiger, zebra"
- )
- """
-
- with gr.Blocks() as demo:
-     gr.Markdown("## Animal Classifier")
-     gr.Markdown("Classify 30 animal categories: antelope, buffalo, chimpanzee, cow, deer, dolphin, elephant, fox, giant+panda, giraffe, gorilla, grizzly+bear, hamster, hippopotamus, horse, humpback+whale, leopard, lion, moose, otter, ox, pig, polar+bear, rabbit, rhinoceros, seal, sheep, squirrel, tiger, zebra")
-     with gr.Row():
-         upload = gr.File(
-             file_count="multiple",
-             file_types=["image"],
-             label="Upload Images"
-         )
-         submit = gr.Button("Classify")
-
-     with gr.Row():
-         gallery = gr.Gallery(label="Uploaded Images")
-         predictions = gr.Textbox(label="Predictions", interactive=False)
-
-     submit.click(
-         fn=lambda files: (
-             [f.name for f in files],
-             ", ".join([predict([file]) for file in files])
-         ),
-         inputs=upload,
-         outputs=[gallery, predictions]
-     )
-
- """launch interface"""
- if __name__ == "__main__":
-     demo.launch()
 
+ import torch
+ import gradio as gr
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+ import torchvision.transforms as transforms
+ from PIL import Image
+
+ # MobileNetV3 Model Definition (keep this exactly as in your original code)
+ class hswish(nn.Module):
+     def forward(self, x):
+         return x * F.relu6(x + 3) / 6
+
+ class hsigmoid(nn.Module):
+     def forward(self, x):
+         return F.relu6(x + 3) / 6
+
+ class SeModule(nn.Module):
+     def __init__(self, in_size, reduction=4):
+         super().__init__()
+         self.se = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_size, in_size//reduction, 1, bias=False),
+             nn.BatchNorm2d(in_size//reduction),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_size//reduction, in_size, 1, bias=False),
+             nn.BatchNorm2d(in_size),
+             hsigmoid()
+         )
+
+     def forward(self, x):
+         return x * self.se(x)
+
+ class Block(nn.Module):
+     def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
+         super().__init__()
+         self.stride = stride
+         self.se = semodule
+         self.conv1 = nn.Conv2d(in_size, expand_size, 1, 1, 0, bias=False)
+         self.bn1 = nn.BatchNorm2d(expand_size)
+         self.nolinear1 = nolinear
+         self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride, kernel_size//2, groups=expand_size, bias=False)
+         self.bn2 = nn.BatchNorm2d(expand_size)
+         self.nolinear2 = nolinear
+         self.conv3 = nn.Conv2d(expand_size, out_size, 1, 1, 0, bias=False)
+         self.bn3 = nn.BatchNorm2d(out_size)
+         self.shortcut = nn.Sequential()
+         if stride == 1 and in_size != out_size:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_size, out_size, 1, 1, 0, bias=False),
+                 nn.BatchNorm2d(out_size),
+             )
+
+     def forward(self, x):
+         out = self.nolinear1(self.bn1(self.conv1(x)))
+         out = self.nolinear2(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         if self.se: out = self.se(out)
+         return out + self.shortcut(x) if self.stride==1 else out
+
+ class MobileNetV3_Small(nn.Module):
+     def __init__(self, num_classes=30):
+         super().__init__()
+         self.conv1 = nn.Conv2d(3, 16, 3, 2, 1, bias=False)
+         self.bn1 = nn.BatchNorm2d(16)
+         self.hs1 = hswish()
+         self.bneck = nn.Sequential(
+             Block(3, 16, 16, 16, nn.ReLU(), SeModule(16), 2),
+             Block(3, 16, 72, 24, nn.ReLU(), None, 2),
+             Block(3, 24, 88, 24, nn.ReLU(), None, 1),
+             Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
+             Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+             Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+             Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
+             Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
+             Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
+             Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+             Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+         )
+         self.conv2 = nn.Conv2d(96, 576, 1, 1, 0, bias=False)
+         self.bn2 = nn.BatchNorm2d(576)
+         self.hs2 = hswish()
+         self.linear3 = nn.Linear(576, 1280)
+         self.bn3 = nn.BatchNorm1d(1280)
+         self.hs3 = hswish()
+         self.linear4 = nn.Linear(1280, num_classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, mode='fan_out')
+                 if m.bias is not None: init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 init.constant_(m.weight, 1)
+                 init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 init.normal_(m.weight, std=0.001)
+                 if m.bias is not None: init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         x = self.hs1(self.bn1(self.conv1(x)))
+         x = self.bneck(x)
+         x = self.hs2(self.bn2(self.conv2(x)))
+         x = F.avg_pool2d(x, x.size()[2:])
+         x = x.view(x.size(0), -1)
+         x = self.hs3(self.bn3(self.linear3(x)))
+         return self.linear4(x)
+
+ # Initialize Model
+ model = MobileNetV3_Small().cpu()
+ model.load_state_dict(torch.load("MobileNet3_small_StateDictionary.pth", map_location='cpu'))
+ model.eval()
+
+ # Class Labels
+ classes = [
+     'antelope', 'buffalo', 'chimpanzee', 'cow', 'deer', 'dolphin',
+     'elephant', 'fox', 'giant+panda', 'giraffe', 'gorilla', 'grizzly+bear',
+     'hamster', 'hippopotamus', 'horse', 'humpback+whale', 'leopard', 'lion',
+     'moose', 'otter', 'ox', 'pig', 'polar+bear', 'rabbit', 'rhinoceros',
+     'seal', 'sheep', 'squirrel', 'tiger', 'zebra'
+ ]
+
+ # Preprocessing
+ preprocess = transforms.Compose([
+     transforms.Resize(256),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ def predict(images):
+     """Process multiple images and return predictions"""
+     predictions = []
+
+     # Batch processing
+     batch = torch.stack([preprocess(Image.open(img).convert('RGB')) for img in images])
+
+     with torch.inference_mode():
+         outputs = model(batch)
+         _, preds = torch.max(outputs, 1)
+
+     return ", ".join([classes[p] for p in preds.cpu().numpy()])
+
+ # Gradio Interface
+ with gr.Blocks(title="Animal Classifier") as demo:
+     gr.Markdown("# 🐾 Animal Classifier")
+     gr.Markdown("Upload multiple animal images to get predictions!")
+
+     with gr.Row():
+         inputs = gr.File(
+             file_count="multiple",
+             file_types=["image"],
+             label="Upload Animal Images"
+         )
+         submit = gr.Button("Classify 🚀", variant="primary")
+
+     with gr.Row():
+         gallery = gr.Gallery(label="Upload Preview", columns=4)
+         outputs = gr.Textbox(label="Predictions", lines=5)
+
+     submit.click(
+         fn=lambda files: (
+             [f.name for f in files],  # Update gallery
+             predict([f.name for f in files])  # Get predictions
+         ),
+         inputs=inputs,
+         outputs=[gallery, outputs]
+     )
+
+ if __name__ == "__main__":
+     demo.launch(show_error=True)
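
For reviewers who want to exercise the new batched `predict` outside the Gradio UI, a minimal smoke test is sketched below. This is not part of the commit: the sample image paths are hypothetical, and importing `app` assumes `MobileNet3_small_StateDictionary.pth` is present, since the module loads the weights at import time.

```python
# smoke_test.py -- a sketch, not part of this commit; sample paths are hypothetical.
from app import predict

# predict() now takes a list of image file paths, stacks them into a single
# batch tensor, and returns one comma-separated string of class names.
sample_paths = ["samples/zebra.jpg", "samples/lion.jpg"]  # hypothetical test images
print(predict(sample_paths))  # expected output shaped like: "zebra, lion"
```

Note that the rewritten forward pass pools with `F.avg_pool2d(x, x.size()[2:])` rather than the old fixed 7×7 kernel, so the model itself no longer assumes 224×224 inputs, although `preprocess` still center-crops to 224.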