|
import torch |
|
import gradio as gr |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import os |
|
from pathlib import Path |
|
from torch.nn import init |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
def list_example_images(directory):
    """Return the image files found directly inside *directory*.

    Args:
        directory: Path of the folder to scan (not recursive).

    Returns:
        A list of ``os.path.join(directory, name)`` strings for every entry
        whose name ends in .png, .jpg, .jpeg or .webp (case-insensitive),
        sorted by file name so the order is stable between runs. An empty
        list is returned when *directory* does not exist or is not a
        directory, so the app can still start without bundled examples.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.webp')
    try:
        names = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        return []
    # os.listdir() returns entries in arbitrary order; sort so gallery
    # positions are deterministic.
    return [
        os.path.join(directory, name)
        for name in sorted(names)
        if name.lower().endswith(image_suffixes)
    ]


# Folder holding the bundled example pictures and the paths discovered in it.
example_dir = "examples"
example_images = list_example_images(example_dir)
|
|
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...). It targets elements carrying
# the "centered-examples" class (assigned to the example gallery through
# elem_classes): the container is centred with an 8px gap, and each ".thumb"
# inside it is forced to a 100x100 box with object-fit: cover.
css = """
.centered-examples {
    margin: 0 auto !important;
    justify-content: center !important;
    gap: 8px !important;
}
.centered-examples .thumb {
    height: 100px !important;
    width: 100px !important;
    object-fit: cover !important;
}
"""
|
|
|
|
|
# Per-channel statistics used to normalise the image tensor.
# NOTE(review): these are the values commonly used for ImageNet-trained
# torchvision models; confirm they match how `model` was trained.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Ordered steps applied to every PIL image before it is passed to the model:
# resize, centre-crop to 224, convert to a tensor, then normalise per channel.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]

preprocess = transforms.Compose(_preprocess_steps)
|
|
|
|
|
# NOTE: `example_dir` and `example_images` are defined once, near the top of
# this module. They are deliberately not rebuilt here, so the examples folder
# is scanned a single time and both names keep one source of truth.
|
|
|
def predict(img_path):
    """Classify a single image file and return the predicted label.

    Args:
        img_path: Filesystem path of the image to classify. A falsy value
            (``None`` or ``""``) means no image has been chosen yet.

    Returns:
        One of three strings:
        * a prompt asking for an image when ``img_path`` is falsy;
        * the entry of ``classes`` at the index of the highest model output;
        * ``"Error: <message>"`` when loading or classifying the image
          raises, so the UI shows the problem instead of failing silently.
    """
    if not img_path:
        return "Please select or upload an image first"

    # NOTE(review): `model` and `classes` are not defined in the visible part
    # of this file; they must exist at module level before this is called,
    # otherwise every call ends up in the error branch below.
    try:
        # The context manager releases the file handle as soon as the pixel
        # data has been converted, instead of leaving it to the garbage
        # collector.
        with Image.open(img_path) as source:
            image = source.convert('RGB')

        # unsqueeze(0) adds a leading batch dimension of size 1.
        tensor = preprocess(image).unsqueeze(0)

        with torch.inference_mode():
            outputs = model(tensor)
            # Index of the largest value along dim 1 for the single sample.
            _, pred = torch.max(outputs, 1)

        return classes[pred.item()]

    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any failure is
        # reported back to the user as text.
        return f"Error: {e}"
|
|
|
# Builds the whole UI: a preview/upload/classify column, a clickable gallery of
# example images, and the event wiring between them.
with gr.Blocks(title="Animal Classifier", css=css) as demo:
    gr.Markdown("## 🐾 Animal Classifier")
    gr.Markdown("Select an image below or upload your own, then click Classify")

    # Holds the path of the image the Classify button sends to predict().
    current_image = gr.State()

    with gr.Row():
        with gr.Column():
            image_preview = gr.Image(label="Selected Image", type="filepath")
            upload_btn = gr.UploadButton("Upload Custom Image", file_types=["image"])
            # NOTE(review): the trailing "π" looks like the remnant of a
            # mis-decoded emoji; the intended glyph cannot be recovered from
            # the text, so the label is left unchanged. Restore it if known.
            classify_btn = gr.Button("Classify π", variant="primary")
            result = gr.Textbox(label="Prediction", lines=3)

    with gr.Row(variant="panel"):
        examples_gallery = gr.Gallery(
            value=example_images,
            label="Example Images (Click to Select)",
            columns=7,
            height=120,
            allow_preview=False,
            # Hooks the gallery up to the ".centered-examples" rules in `css`.
            elem_classes=["centered-examples"],
        )

    def select_example(evt: gr.SelectData):
        """Return the clicked example's path for both wired outputs.

        The select event below declares two outputs (the preview image and
        the `current_image` state), so one value per output is returned.
        """
        chosen_path = example_images[evt.index]
        return chosen_path, chosen_path

    # Clicking a thumbnail shows it in the preview and remembers its path.
    examples_gallery.select(
        fn=select_example,
        outputs=[image_preview, current_image],
        show_progress=False,
    )

    # An uploaded file is shown in the preview and remembered the same way.
    upload_btn.upload(
        fn=lambda file: (file.name, file.name),
        inputs=upload_btn,
        outputs=[image_preview, current_image],
    )

    # Classify whatever path is currently remembered and show the label.
    classify_btn.click(
        fn=predict,
        inputs=current_image,
        outputs=result,
    )
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script; show_error=True is
    # passed so handler errors are reported in the UI.
    launch_options = {"show_error": True}
    demo.launch(**launch_options)