File size: 6,391 Bytes
88282d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import gradio as gr
import torch
from transformers import CvtForImageClassification, AutoFeatureExtractor
from PIL import Image
import os

# Configuración del dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cargar el extractor de características de Hugging Face
extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")

# Definir las clases en el mismo orden que el modelo las predice
class_names = [
    "glioma_tumor",
    "meningioma_tumor",
    "no_tumor",
    "pituitary_tumor"
]

# Función para cargar el modelo (solo una vez)
def load_model():
    model_dir = "models"  # Ruta a los pesos
    model_file_pytorch = "cvt_model.pth"
    
    # Cargar los pesos del modelo desde el archivo .pth
    checkpoint = torch.load(os.path.join(model_dir, model_file_pytorch), map_location=device)
    
    # Cargar el modelo dependiendo de si tenemos el modelo completo o solo los pesos
    if isinstance(checkpoint, CvtForImageClassification):
        model_pytorch = checkpoint  # El checkpoint ya es un modelo completo
    else:
        model_pytorch = CvtForImageClassification.from_pretrained("microsoft/cvt-13")
        model_pytorch.load_state_dict(checkpoint)  # Cargar los pesos en el modelo
    
    model_pytorch.to(device)
    model_pytorch.eval()
    return model_pytorch

# Cargar el modelo una vez cuando la app se inicie
model_pytorch = load_model()

# Función para hacer predicción con la imagen cargada
def predict_image(image):
    # Preprocesar la imagen usando el extractor de características
    inputs = extractor(images=image, return_tensors="pt").to(device)
    
    # Hacer la predicción con el modelo
    with torch.no_grad():
        outputs = model_pytorch(**inputs)
    
    # Obtener los logits de la salida
    logits = outputs.logits
    
    # Convertir los logits en probabilidades
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    
    # Obtener la clase predicha (índice con mayor probabilidad)
    predicted_index = probabilities.argmax(dim=-1).item()

    # Mapear el índice de la clase predicha al nombre de la clase
    predicted_class = class_names[predicted_index]

    # Retornar el nombre de la clase predicha
    return predicted_class
# Función para limpiar los inputs
def clear_inputs():
    return None, None, None

# Definir el tema y la interfaz de Gradio
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="indigo",
).set(
    background_fill_primary='#121212',  # Dark background
    background_fill_secondary='#1e1e1e',
    block_background_fill='#1e1e1e',  # Almost black
    block_border_color='#333',
    block_label_text_color='#fffff',
    block_label_text_color_dark = '#fffff',
    block_title_text_color_dark = '#fffff',
    button_primary_background_fill='#4f46e5',  # Violet
    button_primary_background_fill_hover='#2563eb',  # Light blue
    button_secondary_background_fill='#4f46e5',
    button_secondary_background_fill_hover='#2563eb',
    input_background_fill='#333',  # Dark grey
    input_border_color='#444',  # Intermediate grey
    block_label_background_fill='#4f46e5',
    block_label_background_fill_dark='#4f46e5',
    slider_color='#2563eb',
    slider_color_dark='#2563eb',
    button_primary_text_color='#fffff',
    button_secondary_text_color='#fffff',
    button_secondary_background_fill_hover_dark='#4f46e5',
    button_cancel_background_fill_hover='#444',
    button_cancel_background_fill_hover_dark='#444'
)

with gr.Blocks(theme=theme, css="""
    body, gradio-app {
      background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1'); 
      background-size: cover;
      color: white;
    }
    .gradio-container {
      background-color: transparent;
      background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1') !important; 
      background-size: cover !important;
      color: white;
    }
    .gradio-container .gr-dropdown-container select::after {
        content: '▼'; 
        color: white; 
        padding-left: 5px; 
    }
    .gradio-container .gr-dropdown-container select:focus {
        outline: none; 
        border-color: #4f46e5;
    }
    .gradio-container select {
      color: white;
    }  
    input, select, span, button, svg, .secondary-wrap {
      color: white; 
    }
               
    h1 {
        color: white;
        font-size: 4em;  
        margin: 20px auto;
    }
    .gradio-container h1 { 
        font-size: 5em;  
        color: white; 
        text-align: center; 
        text-shadow: 2px 2px 0px #8A2BE2,  
                     4px 4px 0px #00000033; 
        text-transform: uppercase;
        margin: 18px auto;
    }
    .gradio-container input { 
        color: white; 
    }
    .gradio-container .output { 
        color: white; 
    }
    .required-dropdown li {
      color: white;
    }
    .button-style {
      background-color: #4f46e5;
      color: white;
    }
    .button-style:hover {
      background-color: #2563eb;
      color: white;
    }
               
    .gradio-container .contain textarea {
      color: white;
      font-weight: 600;
      font-size: 1.5rem;
    }
    .contain textarea {
      color: white;
      font-weight: 600;
      font-size: 1.5rem;
    }     
    textarea {
      color: white;
      font-weight: 600;
      font-size: 1.5rem;
      background-color: black;
    }     
    textarea .scroll-hide {
      color: white;
    }
    .scroll-hide svelte-1f354aw {
      color: white;
    }
    """) as demo:

    gr.Markdown("# Brain Tumor Classification 🧠")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Sube la imagen")
            model_input = gr.Dropdown(choices=["model_1", "model_2"], label="Selecciona un modelo", elem_classes=['required-dropdown'])
            classify_btn = gr.Button("Clasificar", elem_classes=['button-style'])
            clear_btn = gr.Button("Limpiar")
        with gr.Column():
            prediction_output = gr.Textbox(label="Predicción")
    
    classify_btn.click(predict_image, inputs=[image_input], outputs=prediction_output)
    clear_btn.click(clear_inputs, inputs=[], outputs=[image_input, model_input, prediction_output])

demo.launch()