IncreasingLoss's picture
Update app.py
4e539c0 verified
raw
history blame
3.33 kB
import torch
import gradio as gr
import torch.nn as nn
import torch.nn.functional as F
import os
from pathlib import Path
from torch.nn import init
import torchvision.transforms as transforms
from PIL import Image
# Model definitions and initialization go here (elided in this version of the file).
# They must define `model` and `classes`, both of which `predict` below relies on.
# Example image paths are collected once, further below (``example_images``),
# next to the preprocessing pipeline; a second listing here would be dead code.
# Gallery styling: centre the example strip and force fixed-size, cropped
# thumbnails. Assembled line by line; the resulting string starts and ends
# with a newline.
css = "\n".join([
    "",
    ".centered-examples {",
    "margin: 0 auto !important;",
    "justify-content: center !important;",
    "gap: 8px !important;",
    "}",
    ".centered-examples .thumb {",
    "height: 100px !important;",
    "width: 100px !important;",
    "object-fit: cover !important;",
    "}",
    "",
])
# Image preprocessing applied before inference: resize, centre-crop to 224x224,
# convert to a tensor and normalise each RGB channel.
# The mean/std values are the widely used ImageNet channel statistics --
# presumably matching how the model was trained; confirm against training code.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
# Precompute example image paths shown in the gallery.


def list_example_images(directory, extensions=('.png', '.jpg', '.jpeg', '.webp')):
    """Return the paths of image files directly inside *directory*.

    Files are matched by extension (case-insensitively) and returned sorted by
    file name, because ``os.listdir`` order is arbitrary and would otherwise
    make the gallery order platform-dependent. Sub-directories are skipped even
    if their names end in an image extension.

    Args:
        directory: Folder to scan (not recursive).
        extensions: Lower-case filename suffixes to accept.

    Returns:
        A list of ``os.path.join(directory, name)`` strings; empty when the
        directory does not exist, so the app can still start without examples.
    """
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        return []
    return [
        os.path.join(directory, name)
        for name in sorted(names)
        if name.lower().endswith(extensions)
        and os.path.isfile(os.path.join(directory, name))
    ]


example_dir = "examples"
example_images = list_example_images(example_dir)
def predict(img_path):
    """Classify a single image file and return the predicted class label.

    Uses the module-level ``preprocess`` pipeline plus ``model`` and ``classes``,
    which are expected to be defined elsewhere in the app.

    Args:
        img_path: Path to the image file, or a falsy value when nothing has been
            selected yet.

    Returns:
        The entry of ``classes`` for the highest-scoring output, a prompt string
        when no image is selected, or ``"Error: <message>"`` if loading,
        preprocessing or inference raises.
    """
    if not img_path:
        return "Please select or upload an image first"
    try:
        # The context manager closes the file handle PIL opens lazily;
        # convert() returns a new, loaded image that outlives the block.
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        batch = preprocess(image).unsqueeze(0)  # add a leading batch dimension
        with torch.inference_mode():
            outputs = model(batch)
        # assumes outputs is (1, num_classes) scores -- TODO confirm against model
        pred_index = outputs.argmax(dim=1).item()
        return classes[pred_index]
    except Exception as e:
        # UI boundary: surface the failure in the Prediction box instead of crashing.
        return f"Error: {str(e)}"
with gr.Blocks(title="Animal Classifier", css=css) as demo:
    gr.Markdown("## 🐾 Animal Classifier")
    gr.Markdown("Select an image below or upload your own, then click Classify")

    # Path of the image the Classify button runs on. The gallery and upload
    # handlers always write it together with image_preview.
    current_image = gr.State()

    with gr.Row():
        with gr.Column():
            image_preview = gr.Image(label="Selected Image", type="filepath")
            upload_btn = gr.UploadButton("Upload Custom Image", file_types=["image"])
            classify_btn = gr.Button("Classify 🚀", variant="primary")
            result = gr.Textbox(label="Prediction", lines=3)

    # Example gallery at bottom
    with gr.Row(variant="panel"):
        examples_gallery = gr.Gallery(
            value=example_images,
            label="Example Images (Click to Select)",
            columns=7,
            height=120,
            allow_preview=False,
            elem_classes=["centered-examples"],
        )

    def select_example(evt: gr.SelectData):
        """Return the clicked example's path once per output component.

        The select event below is wired to two outputs (the preview and the
        state), so the handler must return two values, not a single path.
        """
        path = example_images[evt.index]
        return path, path

    examples_gallery.select(
        fn=select_example,
        outputs=[image_preview, current_image],
        show_progress=False,
    )

    def load_upload(file):
        """Return the uploaded image's path once per output component.

        Depending on the Gradio version, UploadButton passes either a plain
        filepath string or a tempfile-like object with a ``.name`` attribute;
        both are accepted here.
        """
        path = getattr(file, "name", file)
        return path, path

    upload_btn.upload(
        fn=load_upload,
        inputs=upload_btn,
        outputs=[image_preview, current_image],
    )

    # Run the classifier on whichever image was last selected or uploaded.
    classify_btn.click(
        fn=predict,
        inputs=current_image,
        outputs=result,
    )
def main():
    """Start the Gradio server, reporting handler exceptions in the browser UI."""
    demo.launch(show_error=True)


if __name__ == "__main__":
    main()