IncreasingLoss commited on
Commit
eca7b02
·
verified ·
1 Parent(s): 9125d4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -64
app.py CHANGED
@@ -26,92 +26,61 @@ css = """
26
  height: 100px !important;
27
  width: 100px !important;
28
  object-fit: cover !important;
29
- }"""
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  with gr.Blocks(title="Animal Classifier", css=css) as demo:
33
  gr.Markdown("## 🐾 Animal Classifier")
34
- gr.Markdown("Upload multiple animal images to get predictions!")
35
-
36
- # Store uploaded and example file paths
37
- all_files_state = gr.State([])
38
-
39
- with gr.Row():
40
- inputs = gr.File(file_count="multiple", file_types=["image"], label="Upload Animal Images")
41
- submit = gr.Button("Classify 🚀", variant="primary")
42
 
43
  with gr.Row():
44
- gallery = gr.Gallery(label="Upload Preview", columns=4)
45
- outputs = gr.Textbox(label="Predictions", lines=5)
46
 
47
  # Example gallery with click handling
48
  with gr.Row(variant="panel"):
49
  examples_gallery = gr.Gallery(
50
  value=example_images,
51
- label="Example Images (Click to Add)",
52
  columns=7,
53
- height=244,
54
  allow_preview=False,
55
  elem_classes=["centered-examples"]
56
  )
57
 
58
- # Update state when files are uploaded
59
- def update_state(new_files):
60
- return [f.name for f in new_files] if new_files else []
61
 
62
- inputs.change(update_state, inputs, all_files_state)
63
-
64
- # Handle example selection - FIXED WITH VALIDATION
65
- def add_example(example_index, current_files):
66
- try:
67
- # Handle Gradio's selection format
68
- if isinstance(example_index, list):
69
- if not example_index: # Empty selection
70
- return current_files
71
- selected_idx = example_index[0]
72
- else:
73
- selected_idx = example_index
74
-
75
- # Validate index range
76
- if 0 <= selected_idx < len(example_images):
77
- selected_path = example_images[selected_idx]
78
- return current_files + [selected_path]
79
- return current_files
80
- except Exception as e:
81
- print(f"Error handling example selection: {str(e)}")
82
- return current_files
83
-
84
- # Corrected gallery select handler
85
  examples_gallery.select(
86
- add_example,
87
- [all_files_state],
88
- all_files_state,
89
  show_progress=False
90
  )
91
 
92
- # Update gallery preview
93
- def update_gallery(files):
94
- return files if files else []
95
-
96
- all_files_state.change(update_gallery, all_files_state, gallery)
97
-
98
- # Modified prediction function
99
- def predict(files):
100
- if not files:
101
- return ""
102
- try:
103
- batch = torch.stack([preprocess(Image.open(img).convert('RGB')) for img in files])
104
- with torch.inference_mode():
105
- outputs = model(batch)
106
- _, preds = torch.max(outputs, 1)
107
- return ", ".join([classes[p] for p in preds.cpu().numpy()])
108
- except Exception as e:
109
- return f"Error: {str(e)}"
110
-
111
- submit.click(
112
  fn=predict,
113
- inputs=all_files_state,
114
- outputs=outputs
115
  )
116
 
117
  if __name__ == "__main__":
 
26
  height: 100px !important;
27
  width: 100px !important;
28
  object-fit: cover !important;
29
+ }
30
+ """
31
 
32
+ def predict(img):
33
+ """Process single image and return prediction"""
34
+ if img is None:
35
+ return "Please select or upload an image"
36
+
37
+ try:
38
+ image = Image.open(img).convert('RGB')
39
+ tensor = preprocess(image).unsqueeze(0) # Add batch dimension
40
+
41
+ with torch.inference_mode():
42
+ outputs = model(tensor)
43
+ _, pred = torch.max(outputs, 1)
44
+
45
+ return classes[pred.item()]
46
+
47
+ except Exception as e:
48
+ return f"Error: {str(e)}"
49
 
50
  with gr.Blocks(title="Animal Classifier", css=css) as demo:
51
  gr.Markdown("## 🐾 Animal Classifier")
52
+ gr.Markdown("Click example images below or upload your own")
 
 
 
 
 
 
 
53
 
54
  with gr.Row():
55
+ input_image = gr.Image(type="filepath", label="Selected Image")
56
+ output_label = gr.Label(label="Prediction")
57
 
58
  # Example gallery with click handling
59
  with gr.Row(variant="panel"):
60
  examples_gallery = gr.Gallery(
61
  value=example_images,
62
+ label="Example Images (Click to Predict)",
63
  columns=7,
64
+ height=120,
65
  allow_preview=False,
66
  elem_classes=["centered-examples"]
67
  )
68
 
69
+ # Handle example image clicks
70
+ def select_example(evt: gr.SelectData):
71
+ return example_images[evt.index]
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  examples_gallery.select(
74
+ fn=select_example,
75
+ outputs=input_image,
 
76
  show_progress=False
77
  )
78
 
79
+ # Handle predictions for both upload and example clicks
80
+ input_image.change(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  fn=predict,
82
+ inputs=input_image,
83
+ outputs=output_label
84
  )
85
 
86
  if __name__ == "__main__":