JaviSwift's picture
Create app.py
b9e8f84 verified
raw
history blame
578 Bytes
import gradio as gr
import tensorflow as tf
model = tf.keras.models.load_model("hf://JaviSwift/cifar10_simple") # Cambia esto por tu usuario y nombre del modelo
def predict(image):
image = tf.image.resize(image, (32, 32)) # Ajusta según tu modelo
image = image / 255.0 # Normaliza si es necesario
image = tf.expand_dims(image, axis=0) # Añade dimensión para batch
predictions = model.predict(image)
return predictions.argmax(axis=1)[0] # Devuelve la clase predicha
iface = gr.Interface(fn=predict, inputs="image", outputs="label")
iface.launch()