IncreasingLoss committed on
Commit
f322570
·
verified ·
1 Parent(s): be24743

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -169
app.py CHANGED
@@ -1,170 +1,178 @@
1
- import torch
2
- import gradio as gr
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from torch.nn import init
6
- import torchvision.transforms as transforms
7
- from PIL import Image
8
-
9
- # MobileNetV3 Model Definition (keep this exactly as in your original code)
10
- class hswish(nn.Module):
11
- def forward(self, x):
12
- return x * F.relu6(x + 3) / 6
13
-
14
- class hsigmoid(nn.Module):
15
- def forward(self, x):
16
- return F.relu6(x + 3) / 6
17
-
18
- class SeModule(nn.Module):
19
- def __init__(self, in_size, reduction=4):
20
- super().__init__()
21
- self.se = nn.Sequential(
22
- nn.AdaptiveAvgPool2d(1),
23
- nn.Conv2d(in_size, in_size//reduction, 1, bias=False),
24
- nn.BatchNorm2d(in_size//reduction),
25
- nn.ReLU(inplace=True),
26
- nn.Conv2d(in_size//reduction, in_size, 1, bias=False),
27
- nn.BatchNorm2d(in_size),
28
- hsigmoid()
29
- )
30
-
31
- def forward(self, x):
32
- return x * self.se(x)
33
-
34
- class Block(nn.Module):
35
- def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
36
- super().__init__()
37
- self.stride = stride
38
- self.se = semodule
39
- self.conv1 = nn.Conv2d(in_size, expand_size, 1, 1, 0, bias=False)
40
- self.bn1 = nn.BatchNorm2d(expand_size)
41
- self.nolinear1 = nolinear
42
- self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride, kernel_size//2, groups=expand_size, bias=False)
43
- self.bn2 = nn.BatchNorm2d(expand_size)
44
- self.nolinear2 = nolinear
45
- self.conv3 = nn.Conv2d(expand_size, out_size, 1, 1, 0, bias=False)
46
- self.bn3 = nn.BatchNorm2d(out_size)
47
- self.shortcut = nn.Sequential()
48
- if stride == 1 and in_size != out_size:
49
- self.shortcut = nn.Sequential(
50
- nn.Conv2d(in_size, out_size, 1, 1, 0, bias=False),
51
- nn.BatchNorm2d(out_size),
52
- )
53
-
54
- def forward(self, x):
55
- out = self.nolinear1(self.bn1(self.conv1(x)))
56
- out = self.nolinear2(self.bn2(self.conv2(out)))
57
- out = self.bn3(self.conv3(out))
58
- if self.se: out = self.se(out)
59
- return out + self.shortcut(x) if self.stride==1 else out
60
-
61
- class MobileNetV3_Small(nn.Module):
62
- def __init__(self, num_classes=30):
63
- super().__init__()
64
- self.conv1 = nn.Conv2d(3, 16, 3, 2, 1, bias=False)
65
- self.bn1 = nn.BatchNorm2d(16)
66
- self.hs1 = hswish()
67
- self.bneck = nn.Sequential(
68
- Block(3, 16, 16, 16, nn.ReLU(), SeModule(16), 2),
69
- Block(3, 16, 72, 24, nn.ReLU(), None, 2),
70
- Block(3, 24, 88, 24, nn.ReLU(), None, 1),
71
- Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
72
- Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
73
- Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
74
- Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
75
- Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
76
- Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
77
- Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
78
- Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
79
- )
80
- self.conv2 = nn.Conv2d(96, 576, 1, 1, 0, bias=False)
81
- self.bn2 = nn.BatchNorm2d(576)
82
- self.hs2 = hswish()
83
- self.linear3 = nn.Linear(576, 1280)
84
- self.bn3 = nn.BatchNorm1d(1280)
85
- self.hs3 = hswish()
86
- self.linear4 = nn.Linear(1280, num_classes)
87
-
88
- for m in self.modules():
89
- if isinstance(m, nn.Conv2d):
90
- init.kaiming_normal_(m.weight, mode='fan_out')
91
- if m.bias is not None: init.constant_(m.bias, 0)
92
- elif isinstance(m, nn.BatchNorm2d):
93
- init.constant_(m.weight, 1)
94
- init.constant_(m.bias, 0)
95
- elif isinstance(m, nn.Linear):
96
- init.normal_(m.weight, std=0.001)
97
- if m.bias is not None: init.constant_(m.bias, 0)
98
-
99
- def forward(self, x):
100
- x = self.hs1(self.bn1(self.conv1(x)))
101
- x = self.bneck(x)
102
- x = self.hs2(self.bn2(self.conv2(x)))
103
- x = F.avg_pool2d(x, x.size()[2:])
104
- x = x.view(x.size(0), -1)
105
- x = self.hs3(self.bn3(self.linear3(x)))
106
- return self.linear4(x)
107
-
108
- # Initialize Model
109
- model = MobileNetV3_Small().cpu()
110
- model.load_state_dict(torch.load("MobileNet3_small_StateDictionary.pth", map_location='cpu'))
111
- model.eval()
112
-
113
- # Class Labels
114
- classes = [
115
- 'antelope', 'buffalo', 'chimpanzee', 'cow', 'deer', 'dolphin',
116
- 'elephant', 'fox', 'giant+panda', 'giraffe', 'gorilla', 'grizzlybear',
117
- 'hamster', 'hippopotamus', 'horse', 'humpbackwhale', 'leopard', 'lion',
118
- 'moose', 'otter', 'ox', 'pig', 'polarbear', 'rabbit', 'rhinoceros',
119
- 'seal', 'sheep', 'squirrel', 'tiger', 'zebra'
120
- ]
121
-
122
- # Preprocessing
123
- preprocess = transforms.Compose([
124
- transforms.Resize(256),
125
- transforms.CenterCrop(224),
126
- transforms.ToTensor(),
127
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
128
- ])
129
-
130
- def predict(images):
131
- """Process multiple images and return predictions"""
132
- predictions = []
133
-
134
- # Batch processing
135
- batch = torch.stack([preprocess(Image.open(img).convert('RGB')) for img in images])
136
-
137
- with torch.inference_mode():
138
- outputs = model(batch)
139
- _, preds = torch.max(outputs, 1)
140
-
141
- return ", ".join([classes[p] for p in preds.cpu().numpy()])
142
-
143
- # Gradio Interface
144
- with gr.Blocks(title="Animal Classifier") as demo:
145
- gr.Markdown("## 🐾 Animal Classifier")
146
- gr.Markdown("Upload multiple animal images to get predictions!")
147
- gr.Markdown("Detectable Classes: antelope, buffalo, chimpanzee, cow, deer, dolphin, elephant, fox, giantpanda, giraffe, gorilla, grizzlybear, hamster, hippopotamus, horse, humpbackwhale, leopard, lion, moose, otter, ox, pig, polarbear, rabbit, rhinoceros, seal, sheep, squirrel, tiger, zebra")
148
- with gr.Row():
149
- inputs = gr.File(
150
- file_count="multiple",
151
- file_types=["image"],
152
- label="Upload Animal Images"
153
- )
154
- submit = gr.Button("Classify 🚀", variant="primary")
155
-
156
- with gr.Row():
157
- gallery = gr.Gallery(label="Upload Preview", columns=4)
158
- outputs = gr.Textbox(label="Predictions", lines=5)
159
-
160
- submit.click(
161
- fn=lambda files: (
162
- [f.name for f in files], # Update gallery
163
- predict([f.name for f in files]) # Get predictions
164
- ),
165
- inputs=inputs,
166
- outputs=[gallery, outputs]
167
- )
168
-
169
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
170
  demo.launch(show_error=True)
 
1
+ import torch
2
+ import gradio as gr
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import init
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image
8
+
9
+ # MobileNetV3 Model Definition (keep this exactly as in your original code)
10
+ class hswish(nn.Module):
11
+ def forward(self, x):
12
+ return x * F.relu6(x + 3) / 6
13
+
14
+ class hsigmoid(nn.Module):
15
+ def forward(self, x):
16
+ return F.relu6(x + 3) / 6
17
+
18
+ class SeModule(nn.Module):
19
+ def __init__(self, in_size, reduction=4):
20
+ super().__init__()
21
+ self.se = nn.Sequential(
22
+ nn.AdaptiveAvgPool2d(1),
23
+ nn.Conv2d(in_size, in_size//reduction, 1, bias=False),
24
+ nn.BatchNorm2d(in_size//reduction),
25
+ nn.ReLU(inplace=True),
26
+ nn.Conv2d(in_size//reduction, in_size, 1, bias=False),
27
+ nn.BatchNorm2d(in_size),
28
+ hsigmoid()
29
+ )
30
+
31
+ def forward(self, x):
32
+ return x * self.se(x)
33
+
34
+ class Block(nn.Module):
35
+ def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
36
+ super().__init__()
37
+ self.stride = stride
38
+ self.se = semodule
39
+ self.conv1 = nn.Conv2d(in_size, expand_size, 1, 1, 0, bias=False)
40
+ self.bn1 = nn.BatchNorm2d(expand_size)
41
+ self.nolinear1 = nolinear
42
+ self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride, kernel_size//2, groups=expand_size, bias=False)
43
+ self.bn2 = nn.BatchNorm2d(expand_size)
44
+ self.nolinear2 = nolinear
45
+ self.conv3 = nn.Conv2d(expand_size, out_size, 1, 1, 0, bias=False)
46
+ self.bn3 = nn.BatchNorm2d(out_size)
47
+ self.shortcut = nn.Sequential()
48
+ if stride == 1 and in_size != out_size:
49
+ self.shortcut = nn.Sequential(
50
+ nn.Conv2d(in_size, out_size, 1, 1, 0, bias=False),
51
+ nn.BatchNorm2d(out_size),
52
+ )
53
+
54
+ def forward(self, x):
55
+ out = self.nolinear1(self.bn1(self.conv1(x)))
56
+ out = self.nolinear2(self.bn2(self.conv2(out)))
57
+ out = self.bn3(self.conv3(out))
58
+ if self.se: out = self.se(out)
59
+ return out + self.shortcut(x) if self.stride==1 else out
60
+
61
+ class MobileNetV3_Small(nn.Module):
62
+ def __init__(self, num_classes=30):
63
+ super().__init__()
64
+ self.conv1 = nn.Conv2d(3, 16, 3, 2, 1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(16)
66
+ self.hs1 = hswish()
67
+ self.bneck = nn.Sequential(
68
+ Block(3, 16, 16, 16, nn.ReLU(), SeModule(16), 2),
69
+ Block(3, 16, 72, 24, nn.ReLU(), None, 2),
70
+ Block(3, 24, 88, 24, nn.ReLU(), None, 1),
71
+ Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
72
+ Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
73
+ Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
74
+ Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
75
+ Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
76
+ Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
77
+ Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
78
+ Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
79
+ )
80
+ self.conv2 = nn.Conv2d(96, 576, 1, 1, 0, bias=False)
81
+ self.bn2 = nn.BatchNorm2d(576)
82
+ self.hs2 = hswish()
83
+ self.linear3 = nn.Linear(576, 1280)
84
+ self.bn3 = nn.BatchNorm1d(1280)
85
+ self.hs3 = hswish()
86
+ self.linear4 = nn.Linear(1280, num_classes)
87
+
88
+ for m in self.modules():
89
+ if isinstance(m, nn.Conv2d):
90
+ init.kaiming_normal_(m.weight, mode='fan_out')
91
+ if m.bias is not None: init.constant_(m.bias, 0)
92
+ elif isinstance(m, nn.BatchNorm2d):
93
+ init.constant_(m.weight, 1)
94
+ init.constant_(m.bias, 0)
95
+ elif isinstance(m, nn.Linear):
96
+ init.normal_(m.weight, std=0.001)
97
+ if m.bias is not None: init.constant_(m.bias, 0)
98
+
99
+ def forward(self, x):
100
+ x = self.hs1(self.bn1(self.conv1(x)))
101
+ x = self.bneck(x)
102
+ x = self.hs2(self.bn2(self.conv2(x)))
103
+ x = F.avg_pool2d(x, x.size()[2:])
104
+ x = x.view(x.size(0), -1)
105
+ x = self.hs3(self.bn3(self.linear3(x)))
106
+ return self.linear4(x)
107
+
108
+ # Initialize Model
109
+ model = MobileNetV3_Small().cpu()
110
+ model.load_state_dict(torch.load("MobileNet3_small_StateDictionary.pth", map_location='cpu'))
111
+ model.eval()
112
+
113
+ # Class Labels
114
+ classes = [
115
+ 'antelope', 'buffalo', 'chimpanzee', 'cow', 'deer', 'dolphin',
116
+ 'elephant', 'fox', 'giant+panda', 'giraffe', 'gorilla', 'grizzlybear',
117
+ 'hamster', 'hippopotamus', 'horse', 'humpbackwhale', 'leopard', 'lion',
118
+ 'moose', 'otter', 'ox', 'pig', 'polarbear', 'rabbit', 'rhinoceros',
119
+ 'seal', 'sheep', 'squirrel', 'tiger', 'zebra'
120
+ ]
121
+
122
+ # Preprocessing
123
+ preprocess = transforms.Compose([
124
+ transforms.Resize(256),
125
+ transforms.CenterCrop(224),
126
+ transforms.ToTensor(),
127
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
128
+ ])
129
+
130
+ def predict(images):
131
+ """Process multiple images and return predictions"""
132
+ predictions = []
133
+
134
+ # Batch processing
135
+ batch = torch.stack([preprocess(Image.open(img).convert('RGB')) for img in images])
136
+
137
+ with torch.inference_mode():
138
+ outputs = model(batch)
139
+ _, preds = torch.max(outputs, 1)
140
+
141
+ return ", ".join([classes[p] for p in preds.cpu().numpy()])
142
+
143
+ # Gradio Interface
144
+ with gr.Blocks(title="Animal Classifier") as demo:
145
+ gr.Markdown("## 🐾 Animal Classifier")
146
+ gr.Markdown("Upload multiple animal images to get predictions!")
147
+ gr.Markdown("Detectable Classes: antelope, buffalo, chimpanzee, cow, deer, dolphin, elephant, fox, giantpanda, giraffe, gorilla, grizzlybear, hamster, hippopotamus, horse, humpbackwhale, leopard, lion, moose, otter, ox, pig, polarbear, rabbit, rhinoceros, seal, sheep, squirrel, tiger, zebra")
148
+ with gr.Row():
149
+ inputs = gr.File(
150
+ file_count="multiple",
151
+ file_types=["image"],
152
+ label="Upload Animal Images"
153
+ )
154
+ submit = gr.Button("Classify 🚀", variant="primary")
155
+
156
+ with gr.Row():
157
+ gallery = gr.Gallery(label="Upload Preview", columns=4)
158
+ outputs = gr.Textbox(label="Predictions", lines=5)
159
+ with gr.Row():
160
+ gr.Markdown("Test Images")
161
+ gr.Gallery(
162
+ columns = 7,
163
+ preview = True,
164
+ height = 200
165
+ )
166
+
167
+
168
+ submit.click(
169
+ fn=lambda files: (
170
+ [f.name for f in files], # Update gallery
171
+ predict([f.name for f in files]) # Get predictions
172
+ ),
173
+ inputs=inputs,
174
+ outputs=[gallery, outputs]
175
+ )
176
+
177
+ if __name__ == "__main__":
178
  demo.launch(show_error=True)