|
import logging
import os
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.nn import init
|
|
|
|
|
|
|
|
|
# Directory holding the sample images shown in the example gallery.
example_dir = "examples"

# Lower-case file extensions treated as images when scanning example_dir.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")


def list_example_images(directory):
    """Return the sorted paths of image files directly inside *directory*.

    Only regular files whose name ends (case-insensitively) with one of
    IMAGE_EXTENSIONS are included. The result is sorted so the order is
    stable across runs and machines; os.listdir itself returns entries in
    arbitrary order. A missing directory yields an empty list, so the app
    still starts, just without examples.

    Args:
        directory: Path of the folder to scan.

    Returns:
        A list of ``os.path.join(directory, name)`` strings.
    """
    if not os.path.isdir(directory):
        return []
    candidates = (
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )
    # Skip directories that merely happen to have an image-like name.
    return sorted(path for path in candidates if os.path.isfile(path))


# Paths shown in the gallery. The gallery's click index is used to look
# entries up in this same list, so it must keep this order.
example_images = list_example_images(example_dir)


# Styles for the example gallery: centre the thumbnail row and render each
# thumbnail as a fixed 100x100 px square cropped to fill.
# NOTE(review): assumes Gradio renders gallery thumbnails with a `.thumb`
# class — verify against the installed Gradio version.
css = """
.centered-examples {
    margin: 0 auto !important;
    justify-content: center !important;
    gap: 8px !important;
}

.centered-examples .thumb {
    height: 100px !important;
    width: 100px !important;
    object-fit: cover !important;
}
"""
|
|
|
def predict(img):
    """Classify a single image and return the predicted class name.

    Args:
        img: Path to the image file (the input component is created with
            ``type="filepath"``), or ``None`` when no image is selected.

    Returns:
        The predicted entry of ``classes``, a prompt string when ``img`` is
        ``None``, or ``"Error: <message>"`` if loading or inference fails.
    """
    if img is None:
        return "Please select or upload an image"

    try:
        # Context manager releases the file handle even for multi-frame
        # formats, which PIL may otherwise keep open after convert().
        with Image.open(img) as src:
            image = src.convert('RGB')
        # Add a leading batch dimension of size 1 for the model.
        tensor = preprocess(image).unsqueeze(0)

        with torch.inference_mode():
            outputs = model(tensor)
        # Index of the highest score along dim 1 — presumably the class
        # dimension of (batch, num_classes) logits; confirm with the model.
        pred = outputs.argmax(dim=1)
        return classes[pred.item()]

    except Exception as e:
        # UI boundary: report the failure in the label rather than crashing
        # the request, but keep the full traceback in the server log.
        logging.getLogger(__name__).exception("Prediction failed for %r", img)
        return f"Error: {e}"
|
|
|
# Build the UI. Inside each context manager, components appear in the order
# they are created, so statement order here determines the page layout.
with gr.Blocks(title="Animal Classifier", css=css) as demo:
    gr.Markdown("## 🐾 Animal Classifier")
    gr.Markdown("Click example images below or upload your own")

    # Top row: the image being classified (handed to handlers as a file
    # path) next to the prediction label.
    with gr.Row():
        input_image = gr.Image(type="filepath", label="Selected Image")
        output_label = gr.Label(label="Prediction")

    # Bottom row: thumbnail gallery populated from example_images, with the
    # enlarged preview disabled; styled by the "centered-examples" CSS rules.
    with gr.Row(variant="panel"):
        examples_gallery = gr.Gallery(
            value=example_images,
            label="Example Images (Click to Predict)",
            columns=7,
            height=120,
            allow_preview=False,
            elem_classes=["centered-examples"]
        )

    def select_example(evt: gr.SelectData):
        """Return the path of the clicked gallery example.

        evt.index is the position of the clicked item. It is looked up in
        example_images, the same list the gallery was populated from.
        """
        return example_images[evt.index]

    # Clicking a thumbnail copies its path into the input image component.
    examples_gallery.select(
        fn=select_example,
        outputs=input_image,
        show_progress=False
    )

    # Whenever the input image's value changes (e.g. after an example click),
    # run the classifier and show the result; predict() returns a prompt
    # message when the value is None.
    input_image.change(
        fn=predict,
        inputs=input_image,
        outputs=output_label
    )
|
|
|
def main():
    """Start the Gradio server for the demo with error display enabled."""
    demo.launch(show_error=True)


if __name__ == "__main__":
    main()