IncreasingLoss's picture
Update app.py
eca7b02 verified
raw
history blame
2.45 kB
import logging
import os
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.nn import init
# ... [Keep all your existing model definitions and initialization code] ...
# Example image paths shown in the gallery. The list is sorted so that the
# gallery order -- and therefore the ``evt.index`` -> path lookup done by the
# gallery click handler -- is stable across runs and filesystems. A missing
# directory yields an empty gallery instead of crashing the app at import.
example_dir = "examples"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")


def list_example_images(directory, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of image files directly inside *directory*.

    Args:
        directory: Directory to scan (not recursive).
        extensions: Tuple of lowercase filename suffixes to accept; matching
            is case-insensitive on the file name.

    Returns:
        A sorted list of path strings (``directory`` joined with the file
        name). It is empty if *directory* does not exist or is not a
        directory.
    """
    root = Path(directory)
    if not root.is_dir():
        return []
    return sorted(
        str(entry)
        for entry in root.iterdir()
        # Skip sub-directories that merely happen to be named "*.png" etc.
        if entry.is_file() and entry.name.lower().endswith(extensions)
    )


example_images = list_example_images(example_dir)
# Stylesheet handed to gr.Blocks(css=...). It targets the element carrying the
# "centered-examples" class (the example gallery) and the ".thumb" elements
# inside it, giving them a fixed 100px x 100px box with object-fit: cover.
# `!important` is presumably needed to beat Gradio's default theme rules.
# NOTE(review): confirm the installed Gradio version actually renders gallery
# thumbnails with a `thumb` class, otherwise the second rule is a no-op.
_CSS_LINES = (
    "",
    ".centered-examples {",
    "margin: 0 auto !important;",
    "justify-content: center !important;",
    "gap: 8px !important;",
    "}",
    ".centered-examples .thumb {",
    "height: 100px !important;",
    "width: 100px !important;",
    "object-fit: cover !important;",
    "}",
    "",
)
css = "\n".join(_CSS_LINES)
def predict(img):
"""Process single image and return prediction"""
if img is None:
return "Please select or upload an image"
try:
image = Image.open(img).convert('RGB')
tensor = preprocess(image).unsqueeze(0) # Add batch dimension
with torch.inference_mode():
outputs = model(tensor)
_, pred = torch.max(outputs, 1)
return classes[pred.item()]
except Exception as e:
return f"Error: {str(e)}"
# Build the Gradio UI: an image input and a prediction label side by side,
# followed by a clickable gallery of example images. Components are created
# inside the nested `with` contexts in layout order, so the statement order
# here determines the on-page layout.
with gr.Blocks(title="Animal Classifier", css=css) as demo:
    gr.Markdown("## 🐾 Animal Classifier")
    gr.Markdown("Click example images below or upload your own")
    with gr.Row():
        # type="filepath": handlers receive a path (or None when cleared),
        # which predict() checks for None and opens with Image.open.
        input_image = gr.Image(type="filepath", label="Selected Image")
        output_label = gr.Label(label="Prediction")
    # Example gallery with click handling
    with gr.Row(variant="panel"):
        examples_gallery = gr.Gallery(
            value=example_images,
            label="Example Images (Click to Predict)",
            columns=7,
            height=120,
            # presumably makes a click select the item rather than open a
            # full-size preview -- NOTE(review): confirm for the Gradio version
            allow_preview=False,
            # hooks this gallery up to the rules in `css`
            elem_classes=["centered-examples"]
        )

    # Handle example image clicks
    def select_example(evt: gr.SelectData):
        """Return the path of the clicked example image.

        The gallery was populated from ``example_images``, so the clicked
        item's ``evt.index`` is used directly as an index into that list.
        """
        return example_images[evt.index]

    # Copy the clicked example's path into input_image. The resulting value
    # change is what triggers the input_image.change handler below.
    examples_gallery.select(
        fn=select_example,
        outputs=input_image,
        show_progress=False
    )
    # Handle predictions for both upload and example clicks
    input_image.change(
        fn=predict,
        inputs=input_image,
        outputs=output_label
    )

if __name__ == "__main__":
    # show_error=True asks Gradio to surface handler exceptions in the UI.
    # NOTE(review): predict() catches its own exceptions and returns them as
    # text, so this mainly affects other handlers such as select_example.
    demo.launch(show_error=True)