IncreasingLoss commited on
Commit
ca77fd5
·
verified ·
1 Parent(s): aaa3c34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -151
app.py CHANGED
@@ -2,176 +2,100 @@ import torch
2
  import gradio as gr
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
 
5
  from torch.nn import init
6
  import torchvision.transforms as transforms
7
  from PIL import Image
8
 
9
- # MobileNetV3 Model Definition (keep this exactly as in your original code)
10
- class hswish(nn.Module):
11
- def forward(self, x):
12
- return x * F.relu6(x + 3) / 6
13
 
14
- class hsigmoid(nn.Module):
15
- def forward(self, x):
16
- return F.relu6(x + 3) / 6
 
17
 
18
- class SeModule(nn.Module):
19
- def __init__(self, in_size, reduction=4):
20
- super().__init__()
21
- self.se = nn.Sequential(
22
- nn.AdaptiveAvgPool2d(1),
23
- nn.Conv2d(in_size, in_size//reduction, 1, bias=False),
24
- nn.BatchNorm2d(in_size//reduction),
25
- nn.ReLU(inplace=True),
26
- nn.Conv2d(in_size//reduction, in_size, 1, bias=False),
27
- nn.BatchNorm2d(in_size),
28
- hsigmoid()
29
- )
30
-
31
- def forward(self, x):
32
- return x * self.se(x)
33
-
34
- class Block(nn.Module):
35
- def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
36
- super().__init__()
37
- self.stride = stride
38
- self.se = semodule
39
- self.conv1 = nn.Conv2d(in_size, expand_size, 1, 1, 0, bias=False)
40
- self.bn1 = nn.BatchNorm2d(expand_size)
41
- self.nolinear1 = nolinear
42
- self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size, stride, kernel_size//2, groups=expand_size, bias=False)
43
- self.bn2 = nn.BatchNorm2d(expand_size)
44
- self.nolinear2 = nolinear
45
- self.conv3 = nn.Conv2d(expand_size, out_size, 1, 1, 0, bias=False)
46
- self.bn3 = nn.BatchNorm2d(out_size)
47
- self.shortcut = nn.Sequential()
48
- if stride == 1 and in_size != out_size:
49
- self.shortcut = nn.Sequential(
50
- nn.Conv2d(in_size, out_size, 1, 1, 0, bias=False),
51
- nn.BatchNorm2d(out_size),
52
- )
53
-
54
- def forward(self, x):
55
- out = self.nolinear1(self.bn1(self.conv1(x)))
56
- out = self.nolinear2(self.bn2(self.conv2(out)))
57
- out = self.bn3(self.conv3(out))
58
- if self.se: out = self.se(out)
59
- return out + self.shortcut(x) if self.stride==1 else out
60
-
61
- class MobileNetV3_Small(nn.Module):
62
- def __init__(self, num_classes=30):
63
- super().__init__()
64
- self.conv1 = nn.Conv2d(3, 16, 3, 2, 1, bias=False)
65
- self.bn1 = nn.BatchNorm2d(16)
66
- self.hs1 = hswish()
67
- self.bneck = nn.Sequential(
68
- Block(3, 16, 16, 16, nn.ReLU(), SeModule(16), 2),
69
- Block(3, 16, 72, 24, nn.ReLU(), None, 2),
70
- Block(3, 24, 88, 24, nn.ReLU(), None, 1),
71
- Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
72
- Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
73
- Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
74
- Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
75
- Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
76
- Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
77
- Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
78
- Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
79
- )
80
- self.conv2 = nn.Conv2d(96, 576, 1, 1, 0, bias=False)
81
- self.bn2 = nn.BatchNorm2d(576)
82
- self.hs2 = hswish()
83
- self.linear3 = nn.Linear(576, 1280)
84
- self.bn3 = nn.BatchNorm1d(1280)
85
- self.hs3 = hswish()
86
- self.linear4 = nn.Linear(1280, num_classes)
87
-
88
- for m in self.modules():
89
- if isinstance(m, nn.Conv2d):
90
- init.kaiming_normal_(m.weight, mode='fan_out')
91
- if m.bias is not None: init.constant_(m.bias, 0)
92
- elif isinstance(m, nn.BatchNorm2d):
93
- init.constant_(m.weight, 1)
94
- init.constant_(m.bias, 0)
95
- elif isinstance(m, nn.Linear):
96
- init.normal_(m.weight, std=0.001)
97
- if m.bias is not None: init.constant_(m.bias, 0)
98
-
99
- def forward(self, x):
100
- x = self.hs1(self.bn1(self.conv1(x)))
101
- x = self.bneck(x)
102
- x = self.hs2(self.bn2(self.conv2(x)))
103
- x = F.avg_pool2d(x, x.size()[2:])
104
- x = x.view(x.size(0), -1)
105
- x = self.hs3(self.bn3(self.linear3(x)))
106
- return self.linear4(x)
107
-
108
- # Initialize Model
109
- model = MobileNetV3_Small().cpu()
110
- model.load_state_dict(torch.load("MobileNet3_small_StateDictionary.pth", map_location='cpu'))
111
- model.eval()
112
-
113
- # Class Labels
114
- classes = [
115
- 'antelope', 'buffalo', 'chimpanzee', 'cow', 'deer', 'dolphin',
116
- 'elephant', 'fox', 'giant+panda', 'giraffe', 'gorilla', 'grizzlybear',
117
- 'hamster', 'hippopotamus', 'horse', 'humpbackwhale', 'leopard', 'lion',
118
- 'moose', 'otter', 'ox', 'pig', 'polarbear', 'rabbit', 'rhinoceros',
119
- 'seal', 'sheep', 'squirrel', 'tiger', 'zebra'
120
- ]
121
-
122
- # Preprocessing
123
- preprocess = transforms.Compose([
124
- transforms.Resize(256),
125
- transforms.CenterCrop(224),
126
- transforms.ToTensor(),
127
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
128
- ])
129
-
130
- def predict(images):
131
- """Process multiple images and return predictions"""
132
- predictions = []
133
-
134
- # Batch processing
135
- batch = torch.stack([preprocess(Image.open(img).convert('RGB')) for img in images])
136
-
137
- with torch.inference_mode():
138
- outputs = model(batch)
139
- _, preds = torch.max(outputs, 1)
140
-
141
- return ", ".join([classes[p] for p in preds.cpu().numpy()])
142
 
143
- # Gradio Interface
144
- with gr.Blocks(title="Animal Classifier") as demo:
145
  gr.Markdown("## 🐾 Animal Classifier")
146
  gr.Markdown("Upload multiple animal images to get predictions!")
147
- gr.Markdown("Detectable Classes: antelope, buffalo, chimpanzee, cow, deer, dolphin, elephant, fox, giantpanda, giraffe, gorilla, grizzlybear, hamster, hippopotamus, horse, humpbackwhale, leopard, lion, moose, otter, ox, pig, polarbear, rabbit, rhinoceros, seal, sheep, squirrel, tiger, zebra")
 
 
 
148
  with gr.Row():
149
- inputs = gr.File(
150
- file_count="multiple",
151
- file_types=["image"],
152
- label="Upload Animal Images"
153
- )
154
  submit = gr.Button("Classify 🚀", variant="primary")
155
 
156
  with gr.Row():
157
  gallery = gr.Gallery(label="Upload Preview", columns=4)
158
  outputs = gr.Textbox(label="Predictions", lines=5)
159
- with gr.Row():
160
- gr.Markdown("Test Images")
161
- gr.Gallery(
162
- columns = 7,
163
- preview = True,
164
- height = 200
 
 
 
 
165
  )
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  submit.click(
169
- fn=lambda files: (
170
- [f.name for f in files], # Update gallery
171
- predict([f.name for f in files]) # Get predictions
172
- ),
173
- inputs=inputs,
174
- outputs=[gallery, outputs]
175
  )
176
 
177
  if __name__ == "__main__":
 
2
  import gradio as gr
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ import os
6
+ from pathlib import Path
7
  from torch.nn import init
8
  import torchvision.transforms as transforms
9
  from PIL import Image
10
 
11
+ # ... [Keep all your existing model definitions and initialization code] ...
 
 
 
12
 
13
+ # Precompute example image paths
14
+ example_dir = "examples"
15
+ example_images = [os.path.join(example_dir, f) for f in os.listdir(example_dir)
16
+ if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
17
 
18
+ # Custom CSS for styling
19
+ css = """
20
+ .centered-examples {
21
+ margin: 0 auto !important;
22
+ justify-content: center !important;
23
+ gap: 8px !important;
24
+ }
25
+ .centered-examples .thumb {
26
+ height: 100px !important;
27
+ width: 100px !important;
28
+ object-fit: cover !important;
29
+ }
30
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ with gr.Blocks(title="Animal Classifier", css=css) as demo:
 
33
  gr.Markdown("## 🐾 Animal Classifier")
34
  gr.Markdown("Upload multiple animal images to get predictions!")
35
+
36
+ # Store uploaded and example file paths
37
+ all_files_state = gr.State([])
38
+
39
  with gr.Row():
40
+ inputs = gr.File(file_count="multiple", file_types=["image"], label="Upload Animal Images")
 
 
 
 
41
  submit = gr.Button("Classify 🚀", variant="primary")
42
 
43
  with gr.Row():
44
  gallery = gr.Gallery(label="Upload Preview", columns=4)
45
  outputs = gr.Textbox(label="Predictions", lines=5)
46
+
47
+ # Example gallery with click handling
48
+ with gr.Row(variant="panel"):
49
+ examples_gallery = gr.Gallery(
50
+ value=example_images,
51
+ label="Example Images (Click to Add)",
52
+ columns=7,
53
+ height=120,
54
+ allow_preview=False,
55
+ elem_classes=["centered-examples"]
56
  )
57
 
58
+ # Update state when files are uploaded
59
+ def update_state(new_files):
60
+ return [f.name for f in new_files] if new_files else []
61
+
62
+ inputs.change(update_state, inputs, all_files_state)
63
+
64
+ # Handle example selection
65
+ def add_example(example_index, current_files):
66
+ selected_path = example_images[example_index]
67
+ return current_files + [selected_path]
68
+
69
+ examples_gallery.select(
70
+ add_example,
71
+ [all_files_state],
72
+ all_files_state,
73
+ show_progress=False
74
+ )
75
+
76
+ # Update gallery preview
77
+ def update_gallery(files):
78
+ return files if files else []
79
+
80
+ all_files_state.change(update_gallery, all_files_state, gallery)
81
+
82
+ # Modified prediction function
83
+ def predict(files):
84
+ if not files:
85
+ return ""
86
+ try:
87
+ batch = torch.stack([preprocess(Image.open(img).convert('RGB')) for img in files])
88
+ with torch.inference_mode():
89
+ outputs = model(batch)
90
+ _, preds = torch.max(outputs, 1)
91
+ return ", ".join([classes[p] for p in preds.cpu().numpy()])
92
+ except Exception as e:
93
+ return f"Error: {str(e)}"
94
 
95
  submit.click(
96
+ fn=predict,
97
+ inputs=all_files_state,
98
+ outputs=outputs
 
 
 
99
  )
100
 
101
  if __name__ == "__main__":