Spaces:
Running
Running
File size: 1,267 Bytes
75b6726 389dfcd 75b6726 f224c77 389dfcd 75b6726 389dfcd f224c77 d5a5ca6 f224c77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import gradio as gr
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import torch
# Load the classifier and the image preprocessor from the Hugging Face Hub.
# Checkpoint identifiers are kept in named constants so they are easy to swap.
MODEL_CHECKPOINT = "akahana/vit-base-cats-vs-dogs"
PROCESSOR_CHECKPOINT = "google/vit-base-patch16-224-in21k"

model = ViTForImageClassification.from_pretrained(MODEL_CHECKPOINT)
# NOTE(review): the preprocessor comes from a different checkpoint than the
# classifier; presumably the fine-tuned model was trained with the same
# preprocessing as the base in21k checkpoint -- confirm.
feature_extractor = ViTFeatureExtractor.from_pretrained(PROCESSOR_CHECKPOINT)
# Prediction function
def classify_image(image):
    """Classify an image as cat or dog with the ViT model.

    Args:
        image: The image delivered by the ``gr.Image(type="pil")`` component
            (a PIL image), or ``None`` when the user submits without
            uploading anything.

    Returns:
        The predicted label taken from ``model.config.id2label``, or a short
        prompt asking for an image when ``image`` is ``None``.
    """
    # The button can be clicked before an image is uploaded, in which case the
    # input is None; handing that to the feature extractor would fail, so
    # reply with a prompt instead of surfacing an error.
    if image is None:
        return "Por favor, sube una imagen."

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: disable autograd to avoid building a gradient graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    # argmax over the class dimension; .item() requires exactly one image in
    # the batch, which is what the UI sends.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Improved interface
with gr.Blocks() as demo:
    # Header: title plus a one-line usage description.
    gr.Markdown("# 🐱🐶 Clasificador de Gatos vs Perros")
    gr.Markdown("Sube una imagen de un gato o un perro. Este modelo basado en Vision Transformer (ViT) te dirá cuál es.")

    with gr.Row():
        # Left column: image upload (delivered as a PIL image) and the trigger.
        with gr.Column():
            image_input = gr.Image(label="📷 Sube tu imagen", type="pil")
            submit_btn = gr.Button("🔍 Clasificar")
        # Right column: read-only box that shows the predicted label.
        with gr.Column():
            output_label = gr.Textbox(label="🔎 Resultado", interactive=False)

    # Wire the button: classify_image(image_input) -> output_label.
    submit_btn.click(classify_image, image_input, output_label)

demo.launch()
|