Update app.py
app.py
CHANGED
@@ -1,228 +1,170 @@
-import torch
-import gradio as gr
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import init
-import torchvision.transforms as transforms
-
[old lines 7–170 — model definition block; only stray ")" and "self." fragments survived the page extraction]
-
-def predict(img):
-    # Add preprocessing and fix inference mode
-    model.eval()
-
-    # Convert Gradio input to tensor
-    preprocess = transforms.Compose([
-        transforms.Resize(256),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
-
-    img_tensor = preprocess(img).unsqueeze(0)
-
-    with torch.inference_mode():
-        logits = model(img_tensor)
-        preds = logits.argmax(dim=1)
-    return classes[preds.item()]
-
-
-"""gradio interface"""
-"""demo = gr.Interface(
-    fn=predict,
-    inputs=gr.Image(type="pil", width=244, height=244),
-    outputs="label",
-    title="Animal Classifier",
-    description="Classify 30 animal categories: antelope, buffalo, chimpanzee, cow, deer, dolphin, elephant, fox, giant+panda, giraffe, gorilla, grizzly+bear, hamster, hippopotamus, horse, humpback+whale, leopard, lion, moose, otter, ox, pig, polar+bear, rabbit, rhinoceros, seal, sheep, squirrel, tiger, zebra"
-)
-"""
-
-with gr.Blocks() as demo:
-    gr.Markdown("## Animal Classifier")
-    gr.Markdown("Classify 30 animal categories: antelope, buffalo, chimpanzee, cow, deer, dolphin, elephant, fox, giant+panda, giraffe, gorilla, grizzly+bear, hamster, hippopotamus, horse, humpback+whale, leopard, lion, moose, otter, ox, pig, polar+bear, rabbit, rhinoceros, seal, sheep, squirrel, tiger, zebra")
-    with gr.Row():
-        upload = gr.File(
-            file_count="multiple",
-            file_types=["image"],
-            label="Upload Images"
-        )
-        submit = gr.Button("Classify")
-
-    with gr.Row():
-        gallery = gr.Gallery(label="Uploaded Images")
-        predictions = gr.Textbox(label="Predictions", interactive=False)
-
-    submit.click(
-        fn=lambda files: (
-            [f.name for f in files],
-            ", ".join([predict([file]) for file in files])
-        ),
-        inputs=upload,
-        outputs=[gallery, predictions]
-    )
-
-"""launch interface"""
-if __name__ == "__main__":
-    demo.launch()
+import torch
+import gradio as gr
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+import torchvision.transforms as transforms
+from PIL import Image
+
+# MobileNetV3 Model Definition (keep this exactly as in your original code)
+class hswish(nn.Module):
+    def forward(self, x):
+        return x * F.relu6(x + 3) / 6
+
+class hsigmoid(nn.Module):
+    def forward(self, x):
+        return F.relu6(x + 3) / 6
+
+class SeModule(nn.Module):
+    def __init__(self, in_size, reduction=4):
+        super().__init__()
+        self.se = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_size, in_size//reduction, 1, bias=False),
+            nn.BatchNorm2d(in_size//reduction),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_size//reduction, in_size, 1, bias=False),
+            nn.BatchNorm2d(in_size),
+            hsigmoid()
+        )
+
+    def forward(self, x):
+        return x * self.se(x)
+
+class Block(nn.Module):
+    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
+        super().__init__()
+        self.stride = stride
+        self.se = semodule
+        self.conv1 = nn.Conv2d(in_size, expand_size, 1, 1, 0, bias=False)
+        self.bn1 = nn.BatchNorm2d(expand_size)
+        self.nolinear1 = nolinear
+        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride, kernel_size//2, groups=expand_size, bias=False)
+        self.bn2 = nn.BatchNorm2d(expand_size)
+        self.nolinear2 = nolinear
+        self.conv3 = nn.Conv2d(expand_size, out_size, 1, 1, 0, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_size)
+        self.shortcut = nn.Sequential()
+        if stride == 1 and in_size != out_size:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_size, out_size, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(out_size),
+            )
+
+    def forward(self, x):
+        out = self.nolinear1(self.bn1(self.conv1(x)))
+        out = self.nolinear2(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        if self.se: out = self.se(out)
+        return out + self.shortcut(x) if self.stride==1 else out
+
+class MobileNetV3_Small(nn.Module):
+    def __init__(self, num_classes=30):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 16, 3, 2, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(16)
+        self.hs1 = hswish()
+        self.bneck = nn.Sequential(
+            Block(3, 16, 16, 16, nn.ReLU(), SeModule(16), 2),
+            Block(3, 16, 72, 24, nn.ReLU(), None, 2),
+            Block(3, 24, 88, 24, nn.ReLU(), None, 1),
+            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
+            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
+            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
+            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
+            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
+            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
+        )
+        self.conv2 = nn.Conv2d(96, 576, 1, 1, 0, bias=False)
+        self.bn2 = nn.BatchNorm2d(576)
+        self.hs2 = hswish()
+        self.linear3 = nn.Linear(576, 1280)
+        self.bn3 = nn.BatchNorm1d(1280)
+        self.hs3 = hswish()
+        self.linear4 = nn.Linear(1280, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None: init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                init.normal_(m.weight, std=0.001)
+                if m.bias is not None: init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.hs1(self.bn1(self.conv1(x)))
+        x = self.bneck(x)
+        x = self.hs2(self.bn2(self.conv2(x)))
+        x = F.avg_pool2d(x, x.size()[2:])
+        x = x.view(x.size(0), -1)
+        x = self.hs3(self.bn3(self.linear3(x)))
+        return self.linear4(x)
+
+# Initialize Model
+model = MobileNetV3_Small().cpu()
+model.load_state_dict(torch.load("MobileNet3_small_StateDictionary.pth", map_location='cpu'))
+model.eval()
+
+# Class Labels
+classes = [
+    'antelope', 'buffalo', 'chimpanzee', 'cow', 'deer', 'dolphin',
+    'elephant', 'fox', 'giant+panda', 'giraffe', 'gorilla', 'grizzly+bear',
+    'hamster', 'hippopotamus', 'horse', 'humpback+whale', 'leopard', 'lion',
+    'moose', 'otter', 'ox', 'pig', 'polar+bear', 'rabbit', 'rhinoceros',
+    'seal', 'sheep', 'squirrel', 'tiger', 'zebra'
+]

+# Preprocessing
+preprocess = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+def predict(images):
+    """Process multiple images and return predictions"""
+    predictions = []
+
+    # Batch processing
+    batch = torch.stack([preprocess(Image.open(img).convert('RGB')) for img in images])
+
+    with torch.inference_mode():
+        outputs = model(batch)
+        _, preds = torch.max(outputs, 1)
+
+    return ", ".join([classes[p] for p in preds.cpu().numpy()])
+
+# Gradio Interface
+with gr.Blocks(title="Animal Classifier") as demo:
+    gr.Markdown("# 🐾 Animal Classifier")
+    gr.Markdown("Upload multiple animal images to get predictions!")
+
+    with gr.Row():
+        inputs = gr.File(
+            file_count="multiple",
+            file_types=["image"],
+            label="Upload Animal Images"
+        )
+        submit = gr.Button("Classify 🚀", variant="primary")
+
+    with gr.Row():
+        gallery = gr.Gallery(label="Upload Preview", columns=4)
+        outputs = gr.Textbox(label="Predictions", lines=5)
+
+    submit.click(
+        fn=lambda files: (
+            [f.name for f in files],  # Update gallery
+            predict([f.name for f in files])  # Get predictions
+        ),
+        inputs=inputs,
+        outputs=[gallery, outputs]
+    )
+
+if __name__ == "__main__":
+    demo.launch(show_error=True)
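The new model definition can be sanity-checked without the MobileNet3_small_StateDictionary.pth checkpoint that the app loads at startup. A minimal sketch, assuming the class definitions above are pasted alongside it (importing app.py directly would attempt the checkpoint load); the dummy input shape and batch size are illustrative, not part of the commit:

# smoke_test.py -- hypothetical helper, not part of this commit.
# Requires the hswish/hsigmoid/SeModule/Block/MobileNetV3_Small
# definitions from app.py above this block.
import torch

net = MobileNetV3_Small(num_classes=30)  # random weights; no checkpoint needed
net.eval()                               # BatchNorm uses running stats in eval

dummy = torch.randn(2, 3, 224, 224)      # two fake RGB images at the 224x224 crop size
with torch.inference_mode():
    logits = net(dummy)

assert logits.shape == (2, 30)           # one logit per animal class
print(logits.argmax(dim=1))              # predicted class indices

With the checkpoint in place, the same pattern applies end to end: calling predict() on a list of local image paths (e.g. ["lion.jpg", "zebra.jpg"]) returns one comma-separated label string, which is what the submit.click lambda feeds into the Predictions textbox.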