import gradio as gr
import torch
from transformers import CvtForImageClassification, AutoFeatureExtractor
from PIL import Image
import os

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the feature extractor from Hugging Face
extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")
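# Note: recent transformers releases expose the same preprocessing via
# AutoImageProcessor.from_pretrained("microsoft/cvt-13"); AutoFeatureExtractor
# still works for vision models but is deprecated in newer versions.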

# Define the classes in the same order the model predicts them
class_names = [
    "glioma_tumor",
    "meningioma_tumor",
    "no_tumor",
    "pituitary_tumor"
]
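# The order of this list must match the label order used when the checkpoint
# was fine-tuned; predict_image indexes into it with the argmax of the logits.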

# Function to load the model (runs only once)
def load_model():
    model_dir = "models"  # Path to the weights
    model_file_pytorch = "cvt_model.pth"

    # Load the model weights from the .pth file
    checkpoint = torch.load(os.path.join(model_dir, model_file_pytorch), map_location=device)

    # The checkpoint may be either a full model object or a bare state dict
    if isinstance(checkpoint, CvtForImageClassification):
        model_pytorch = checkpoint  # The checkpoint is already a complete model
    else:
        # Instantiate the base architecture with a 4-class head so the fine-tuned
        # state dict fits (the hub checkpoint ships a 1000-class ImageNet head)
        model_pytorch = CvtForImageClassification.from_pretrained(
            "microsoft/cvt-13",
            num_labels=len(class_names),
            ignore_mismatched_sizes=True,
        )
        model_pytorch.load_state_dict(checkpoint)  # Load the weights into the model

    model_pytorch.to(device)
    model_pytorch.eval()
    return model_pytorch
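
# For reference, both checkpoint formats handled above can be produced like this
# (illustrative only, not executed here):
#   torch.save(model, "models/cvt_model.pth")               # full model object
#   torch.save(model.state_dict(), "models/cvt_model.pth")  # weights-only state dict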

# Load the model once when the app starts
model_pytorch = load_model()

# Function to run a prediction on the uploaded image
def predict_image(image):
    # Preprocess the image with the feature extractor
    inputs = extractor(images=image, return_tensors="pt").to(device)

    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model_pytorch(**inputs)

    # Get the logits from the output
    logits = outputs.logits

    # Convert the logits into probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Get the predicted class (index with the highest probability)
    predicted_index = probabilities.argmax(dim=-1).item()

    # Map the predicted index to its class name
    predicted_class = class_names[predicted_index]

    return predicted_class
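
# Quick local sanity check (uncomment to try; "sample_scan.jpg" is a
# hypothetical file path, not part of this repo):
# sample = Image.open("sample_scan.jpg").convert("RGB")
# print(predict_image(sample))  # e.g. "no_tumor"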

# Function to clear the image, dropdown, and prediction fields
def clear_inputs():
    return None, None, None

# Define the Gradio theme and interface
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="indigo",
).set(
    background_fill_primary='#121212',  # Dark background
    background_fill_secondary='#1e1e1e',
    block_background_fill='#1e1e1e',  # Almost black
    block_border_color='#333',
    block_label_text_color='#ffffff',
    block_label_text_color_dark='#ffffff',
    block_title_text_color_dark='#ffffff',
    button_primary_background_fill='#4f46e5',  # Violet
    button_primary_background_fill_hover='#2563eb',  # Light blue
    button_secondary_background_fill='#4f46e5',
    button_secondary_background_fill_hover='#2563eb',
    input_background_fill='#333',  # Dark grey
    input_border_color='#444',  # Intermediate grey
    block_label_background_fill='#4f46e5',
    block_label_background_fill_dark='#4f46e5',
    slider_color='#2563eb',
    slider_color_dark='#2563eb',
    button_primary_text_color='#ffffff',
    button_secondary_text_color='#ffffff',
    button_secondary_background_fill_hover_dark='#4f46e5',
    button_cancel_background_fill_hover='#444',
    button_cancel_background_fill_hover_dark='#444'
)
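
# .set() overrides individual CSS variables on the Soft theme; any variable
# not listed keeps the theme default.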

with gr.Blocks(theme=theme, css="""
body, gradio-app {
    background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1');
    background-size: cover;
    color: white;
}
.gradio-container {
    background-color: transparent;
    background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1') !important;
    background-size: cover !important;
    color: white;
}
.gradio-container .gr-dropdown-container select::after {
    content: '▼';
    color: white;
    padding-left: 5px;
}
.gradio-container .gr-dropdown-container select:focus {
    outline: none;
    border-color: #4f46e5;
}
.gradio-container select {
    color: white;
}
input, select, span, button, svg, .secondary-wrap {
    color: white;
}
h1 {
    color: white;
    font-size: 4em;
    margin: 20px auto;
}
.gradio-container h1 {
    font-size: 5em;
    color: white;
    text-align: center;
    text-shadow: 2px 2px 0px #8A2BE2,
                 4px 4px 0px #00000033;
    text-transform: uppercase;
    margin: 18px auto;
}
.gradio-container input {
    color: white;
}
.gradio-container .output {
    color: white;
}
.required-dropdown li {
    color: white;
}
.button-style {
    background-color: #4f46e5;
    color: white;
}
.button-style:hover {
    background-color: #2563eb;
    color: white;
}
.gradio-container .contain textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
}
.contain textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
}
textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
    background-color: black;
}
textarea.scroll-hide {
    color: white;
}
.scroll-hide.svelte-1f354aw {
    color: white;
}
""") as demo:
    gr.Markdown("# Brain Tumor Classification 🧠")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an image")
            model_input = gr.Dropdown(choices=["model_1", "model_2"], label="Select a model", elem_classes=['required-dropdown'])
            classify_btn = gr.Button("Classify", elem_classes=['button-style'])
            clear_btn = gr.Button("Clear")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction")

    # Note: the dropdown value is not wired into predict_image; the same CvT
    # checkpoint is used regardless of the selection.
    classify_btn.click(predict_image, inputs=[image_input], outputs=prediction_output)
    clear_btn.click(clear_inputs, inputs=[], outputs=[image_input, model_input, prediction_output])

demo.launch()
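
# To expose the app beyond localhost, standard Gradio launch options can be
# used instead, e.g. demo.launch(server_name="0.0.0.0", share=True).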