DiffusersSpace / app.py
luckyo87's picture
Add application file
f224c77
import gradio as gr
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import torch
# Load the image classifier and its preprocessor via `from_pretrained`.
# NOTE(review): the preprocessor is taken from a different checkpoint than the
# classifier; confirm its resize/normalization settings match the ones the
# classifier was fine-tuned with.
_CLASSIFIER_CHECKPOINT = "akahana/vit-base-cats-vs-dogs"
_PREPROCESSOR_CHECKPOINT = "google/vit-base-patch16-224-in21k"

model = ViTForImageClassification.from_pretrained(_CLASSIFIER_CHECKPOINT)
feature_extractor = ViTFeatureExtractor.from_pretrained(_PREPROCESSOR_CHECKPOINT)
# Prediction function
def classify_image(image):
    """Return the predicted label for a single image.

    Args:
        image: The picture to classify. The Gradio input is configured with
            ``type="pil"``, so this is normally a ``PIL.Image.Image``; it is
            ``None`` when the button is clicked before anything is uploaded.

    Returns:
        The label string from ``model.config.id2label`` for the class with
        the highest logit.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of a preprocessing traceback.
    """
    if image is None:
        raise gr.Error("Por favor, sube una imagen antes de clasificar.")

    # Direct callers may pass RGBA/palette/grayscale PIL images; ViT
    # checkpoints are presumably trained on 3-channel RGB input, so normalize
    # the mode first. Non-PIL inputs are passed through unchanged.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: disable gradient tracking so no autograd graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is processed, so the argmax has exactly one element.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# User interface: an upload column with a button, and a result column.
with gr.Blocks() as demo:
    gr.Markdown("# 🐱🐶 Clasificador de Gatos vs Perros")
    gr.Markdown("Sube una imagen de un gato o un perro. Este modelo basado en Vision Transformer (ViT) te dirá cuál es.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="📷 Sube tu imagen", type="pil")
            submit_btn = gr.Button("🔍 Clasificar")
        with gr.Column():
            output_label = gr.Textbox(label="🔎 Resultado", interactive=False)
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(fn=classify_image, inputs=image_input, outputs=output_label)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. to reuse classify_image) does not launch a blocking server.
if __name__ == "__main__":
    demo.launch()