IncreasingLoss commited on
Commit
d844a2a
·
verified ·
1 Parent(s): eca7b02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -18
app.py CHANGED
@@ -29,14 +29,14 @@ css = """
29
  }
30
  """
31
 
32
- def predict(img):
33
  """Process single image and return prediction"""
34
- if img is None:
35
- return "Please select or upload an image"
36
 
37
  try:
38
- image = Image.open(img).convert('RGB')
39
- tensor = preprocess(image).unsqueeze(0) # Add batch dimension
40
 
41
  with torch.inference_mode():
42
  outputs = model(tensor)
@@ -49,38 +49,51 @@ def predict(img):
49
 
50
  with gr.Blocks(title="Animal Classifier", css=css) as demo:
51
  gr.Markdown("## 🐾 Animal Classifier")
52
- gr.Markdown("Click example images below or upload your own")
 
 
 
53
 
54
  with gr.Row():
55
- input_image = gr.Image(type="filepath", label="Selected Image")
56
- output_label = gr.Label(label="Prediction")
 
 
 
57
 
58
- # Example gallery with click handling
59
  with gr.Row(variant="panel"):
60
  examples_gallery = gr.Gallery(
61
  value=example_images,
62
- label="Example Images (Click to Predict)",
63
  columns=7,
64
  height=120,
65
  allow_preview=False,
66
  elem_classes=["centered-examples"]
67
  )
68
-
69
- # Handle example image clicks
70
  def select_example(evt: gr.SelectData):
71
  return example_images[evt.index]
72
 
73
  examples_gallery.select(
74
  fn=select_example,
75
- outputs=input_image,
76
  show_progress=False
77
  )
78
-
79
- # Handle predictions for both upload and example clicks
80
- input_image.change(
 
 
 
 
 
 
 
81
  fn=predict,
82
- inputs=input_image,
83
- outputs=output_label
84
  )
85
 
86
  if __name__ == "__main__":
 
29
  }
30
  """
31
 
32
+ def predict(img_path):
33
  """Process single image and return prediction"""
34
+ if not img_path:
35
+ return "Please select or upload an image first"
36
 
37
  try:
38
+ image = Image.open(img_path).convert('RGB')
39
+ tensor = preprocess(image).unsqueeze(0)
40
 
41
  with torch.inference_mode():
42
  outputs = model(tensor)
 
49
 
50
  with gr.Blocks(title="Animal Classifier", css=css) as demo:
51
  gr.Markdown("## 🐾 Animal Classifier")
52
+ gr.Markdown("Select an image below or upload your own, then click Classify")
53
+
54
+ # Store current image path
55
+ current_image = gr.State()
56
 
57
  with gr.Row():
58
+ with gr.Column():
59
+ image_preview = gr.Image(label="Selected Image", type="filepath")
60
+ upload_btn = gr.UploadButton("Upload Custom Image", file_types=["image"])
61
+ classify_btn = gr.Button("Classify 🚀", variant="primary")
62
+ result = gr.Textbox(label="Prediction", lines=3)
63
 
64
+ # Example gallery at bottom
65
  with gr.Row(variant="panel"):
66
  examples_gallery = gr.Gallery(
67
  value=example_images,
68
+ label="Example Images (Click to Select)",
69
  columns=7,
70
  height=120,
71
  allow_preview=False,
72
  elem_classes=["centered-examples"]
73
  )
74
+
75
+ # Handle image selection from examples
76
  def select_example(evt: gr.SelectData):
77
  return example_images[evt.index]
78
 
79
  examples_gallery.select(
80
  fn=select_example,
81
+ outputs=[image_preview, current_image],
82
  show_progress=False
83
  )
84
+
85
+ # Handle custom uploads
86
+ upload_btn.upload(
87
+ fn=lambda file: (file.name, file.name),
88
+ inputs=upload_btn,
89
+ outputs=[image_preview, current_image]
90
+ )
91
+
92
+ # Handle classification
93
+ classify_btn.click(
94
  fn=predict,
95
+ inputs=current_image,
96
+ outputs=result
97
  )
98
 
99
  if __name__ == "__main__":