Update app.py
Browse files
app.py
CHANGED
@@ -26,92 +26,61 @@ css = """
|
|
26 |
height: 100px !important;
|
27 |
width: 100px !important;
|
28 |
object-fit: cover !important;
|
29 |
-
}
|
|
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
with gr.Blocks(title="Animal Classifier", css=css) as demo:
|
33 |
gr.Markdown("## 🐾 Animal Classifier")
|
34 |
-
gr.Markdown("
|
35 |
-
|
36 |
-
# Store uploaded and example file paths
|
37 |
-
all_files_state = gr.State([])
|
38 |
-
|
39 |
-
with gr.Row():
|
40 |
-
inputs = gr.File(file_count="multiple", file_types=["image"], label="Upload Animal Images")
|
41 |
-
submit = gr.Button("Classify 🚀", variant="primary")
|
42 |
|
43 |
with gr.Row():
|
44 |
-
|
45 |
-
|
46 |
|
47 |
# Example gallery with click handling
|
48 |
with gr.Row(variant="panel"):
|
49 |
examples_gallery = gr.Gallery(
|
50 |
value=example_images,
|
51 |
-
label="Example Images (Click to
|
52 |
columns=7,
|
53 |
-
height=
|
54 |
allow_preview=False,
|
55 |
elem_classes=["centered-examples"]
|
56 |
)
|
57 |
|
58 |
-
#
|
59 |
-
def
|
60 |
-
return [
|
61 |
|
62 |
-
inputs.change(update_state, inputs, all_files_state)
|
63 |
-
|
64 |
-
# Handle example selection - FIXED WITH VALIDATION
|
65 |
-
def add_example(example_index, current_files):
|
66 |
-
try:
|
67 |
-
# Handle Gradio's selection format
|
68 |
-
if isinstance(example_index, list):
|
69 |
-
if not example_index: # Empty selection
|
70 |
-
return current_files
|
71 |
-
selected_idx = example_index[0]
|
72 |
-
else:
|
73 |
-
selected_idx = example_index
|
74 |
-
|
75 |
-
# Validate index range
|
76 |
-
if 0 <= selected_idx < len(example_images):
|
77 |
-
selected_path = example_images[selected_idx]
|
78 |
-
return current_files + [selected_path]
|
79 |
-
return current_files
|
80 |
-
except Exception as e:
|
81 |
-
print(f"Error handling example selection: {str(e)}")
|
82 |
-
return current_files
|
83 |
-
|
84 |
-
# Corrected gallery select handler
|
85 |
examples_gallery.select(
|
86 |
-
|
87 |
-
|
88 |
-
all_files_state,
|
89 |
show_progress=False
|
90 |
)
|
91 |
|
92 |
-
#
|
93 |
-
|
94 |
-
return files if files else []
|
95 |
-
|
96 |
-
all_files_state.change(update_gallery, all_files_state, gallery)
|
97 |
-
|
98 |
-
# Modified prediction function
|
99 |
-
def predict(files):
|
100 |
-
if not files:
|
101 |
-
return ""
|
102 |
-
try:
|
103 |
-
batch = torch.stack([preprocess(Image.open(img).convert('RGB')) for img in files])
|
104 |
-
with torch.inference_mode():
|
105 |
-
outputs = model(batch)
|
106 |
-
_, preds = torch.max(outputs, 1)
|
107 |
-
return ", ".join([classes[p] for p in preds.cpu().numpy()])
|
108 |
-
except Exception as e:
|
109 |
-
return f"Error: {str(e)}"
|
110 |
-
|
111 |
-
submit.click(
|
112 |
fn=predict,
|
113 |
-
inputs=
|
114 |
-
outputs=
|
115 |
)
|
116 |
|
117 |
if __name__ == "__main__":
|
|
|
26 |
height: 100px !important;
|
27 |
width: 100px !important;
|
28 |
object-fit: cover !important;
|
29 |
+
}
|
30 |
+
"""
|
31 |
|
32 |
+
def predict(img):
|
33 |
+
"""Process single image and return prediction"""
|
34 |
+
if img is None:
|
35 |
+
return "Please select or upload an image"
|
36 |
+
|
37 |
+
try:
|
38 |
+
image = Image.open(img).convert('RGB')
|
39 |
+
tensor = preprocess(image).unsqueeze(0) # Add batch dimension
|
40 |
+
|
41 |
+
with torch.inference_mode():
|
42 |
+
outputs = model(tensor)
|
43 |
+
_, pred = torch.max(outputs, 1)
|
44 |
+
|
45 |
+
return classes[pred.item()]
|
46 |
+
|
47 |
+
except Exception as e:
|
48 |
+
return f"Error: {str(e)}"
|
49 |
|
50 |
with gr.Blocks(title="Animal Classifier", css=css) as demo:
|
51 |
gr.Markdown("## 🐾 Animal Classifier")
|
52 |
+
gr.Markdown("Click example images below or upload your own")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
with gr.Row():
|
55 |
+
input_image = gr.Image(type="filepath", label="Selected Image")
|
56 |
+
output_label = gr.Label(label="Prediction")
|
57 |
|
58 |
# Example gallery with click handling
|
59 |
with gr.Row(variant="panel"):
|
60 |
examples_gallery = gr.Gallery(
|
61 |
value=example_images,
|
62 |
+
label="Example Images (Click to Predict)",
|
63 |
columns=7,
|
64 |
+
height=120,
|
65 |
allow_preview=False,
|
66 |
elem_classes=["centered-examples"]
|
67 |
)
|
68 |
|
69 |
+
# Handle example image clicks
|
70 |
+
def select_example(evt: gr.SelectData):
|
71 |
+
return example_images[evt.index]
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
examples_gallery.select(
|
74 |
+
fn=select_example,
|
75 |
+
outputs=input_image,
|
|
|
76 |
show_progress=False
|
77 |
)
|
78 |
|
79 |
+
# Handle predictions for both upload and example clicks
|
80 |
+
input_image.change(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
fn=predict,
|
82 |
+
inputs=input_image,
|
83 |
+
outputs=output_label
|
84 |
)
|
85 |
|
86 |
if __name__ == "__main__":
|