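"""Gradio app that classifies brain images into four tumor categories
(glioma, meningioma, pituitary, no tumor) with a fine-tuned CvT-13 model."""
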
import gradio as gr
import torch
from transformers import CvtForImageClassification, AutoFeatureExtractor
from PIL import Image
import os

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the feature extractor from Hugging Face
extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")

# Class names, in the same order the model predicts them
class_names = [
    "glioma_tumor",
    "meningioma_tumor",
    "no_tumor",
    "pituitary_tumor"
]


# Load the model (only once)
def load_model():
    model_dir = "models"  # Path to the model weights
    model_file_pytorch = "cvt_model.pth"
    # Load the checkpoint from the .pth file
    checkpoint = torch.load(os.path.join(model_dir, model_file_pytorch), map_location=device)
    # The checkpoint may be either a full model or just a state dict
    if isinstance(checkpoint, CvtForImageClassification):
        model_pytorch = checkpoint  # The checkpoint is already a full model
    else:
        # Resize the classifier head to the four tumor classes so a
        # fine-tuned state dict loads without a shape mismatch
        model_pytorch = CvtForImageClassification.from_pretrained(
            "microsoft/cvt-13",
            num_labels=len(class_names),
            ignore_mismatched_sizes=True,
        )
        model_pytorch.load_state_dict(checkpoint)  # Load the weights into the model
    model_pytorch.to(device)
    model_pytorch.eval()
    return model_pytorch


# Load the model once when the app starts
model_pytorch = load_model()


# Run a prediction on the uploaded image
def predict_image(image):
    # Preprocess the image with the feature extractor
    inputs = extractor(images=image, return_tensors="pt").to(device)
    # Run inference with the model
    with torch.no_grad():
        outputs = model_pytorch(**inputs)
    # Get the logits from the output
    logits = outputs.logits
    # Convert the logits into probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # Get the predicted class (the index with the highest probability)
    predicted_index = probabilities.argmax(dim=-1).item()
    # Map the predicted index to its class name
    predicted_class = class_names[predicted_index]
    return predicted_class


# Clear the inputs and the output
def clear_inputs():
    return None, None, None


# Define the Gradio theme and interface
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="indigo",
).set(
    background_fill_primary='#121212',  # Dark background
    background_fill_secondary='#1e1e1e',
    block_background_fill='#1e1e1e',  # Almost black
    block_border_color='#333',
    block_label_text_color='#ffffff',
    block_label_text_color_dark='#ffffff',
    block_title_text_color_dark='#ffffff',
    button_primary_background_fill='#4f46e5',  # Violet
    button_primary_background_fill_hover='#2563eb',  # Light blue
    button_secondary_background_fill='#4f46e5',
    button_secondary_background_fill_hover='#2563eb',
    input_background_fill='#333',  # Dark grey
    input_border_color='#444',  # Intermediate grey
    block_label_background_fill='#4f46e5',
    block_label_background_fill_dark='#4f46e5',
    slider_color='#2563eb',
    slider_color_dark='#2563eb',
    button_primary_text_color='#ffffff',
    button_secondary_text_color='#ffffff',
    button_secondary_background_fill_hover_dark='#4f46e5',
    button_cancel_background_fill_hover='#444',
    button_cancel_background_fill_hover_dark='#444'
)

with gr.Blocks(theme=theme, css="""
    body, gradio-app {
        background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1');
        background-size: cover;
        color: white;
    }
    .gradio-container {
        background-color: transparent;
        background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1') !important;
        background-size: cover !important;
        color: white;
    }
    .gradio-container .gr-dropdown-container select::after {
        content: '▼';
        color: white;
        padding-left: 5px;
    }
    .gradio-container .gr-dropdown-container select:focus {
        outline: none;
        border-color: #4f46e5;
    }
    .gradio-container select {
        color: white;
    }
    input, select, span, button, svg, .secondary-wrap {
        color: white;
    }
    h1 {
        color: white;
        font-size: 4em;
        margin: 20px auto;
    }
    .gradio-container h1 {
        font-size: 5em;
        color: white;
        text-align: center;
        text-shadow: 2px 2px 0px #8A2BE2,
                     4px 4px 0px #00000033;
        text-transform: uppercase;
        margin: 18px auto;
    }
    .gradio-container input {
        color: white;
    }
    .gradio-container .output {
        color: white;
    }
    .required-dropdown li {
        color: white;
    }
    .button-style {
        background-color: #4f46e5;
        color: white;
    }
    .button-style:hover {
        background-color: #2563eb;
        color: white;
    }
    .gradio-container .contain textarea {
        color: white;
        font-weight: 600;
        font-size: 1.5rem;
    }
    .contain textarea {
        color: white;
        font-weight: 600;
        font-size: 1.5rem;
    }
    textarea {
        color: white;
        font-weight: 600;
        font-size: 1.5rem;
        background-color: black;
    }
    textarea.scroll-hide {
        color: white;
    }
    .scroll-hide.svelte-1f354aw {
        color: white;
    }
""") as demo:
    gr.Markdown("# Brain Tumor Classification 🧠")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an image")
            model_input = gr.Dropdown(choices=["model_1", "model_2"], label="Select a model", elem_classes=['required-dropdown'])
            classify_btn = gr.Button("Classify", elem_classes=['button-style'])
            clear_btn = gr.Button("Clear")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction")
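    # NOTE: the model dropdown is not wired into the prediction; predict_image
    # always uses the single CvT-13 model loaded at startup.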
    classify_btn.click(predict_image, inputs=[image_input], outputs=prediction_output)
    clear_btn.click(clear_inputs, inputs=[], outputs=[image_input, model_input, prediction_output])
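
# Launch the app (passing share=True to launch() would create a public link)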
demo.launch()