File size: 20,470 Bytes
4790918 54b9d20 6ae44c8 6045f26 4790918 6ae44c8 4790918 6ae44c8 4790918 6ae44c8 7711f86 6ae44c8 4790918 6be1f83 4790918 6ae44c8 6be1f83 4790918 6be1f83 6ae44c8 4790918 6be1f83 6ae44c8 6be1f83 4790918 6ae44c8 6be1f83 4790918 6ae44c8 4790918 6be1f83 6ae44c8 6be1f83 6ae44c8 4790918 6ae44c8 4790918 6ae44c8 6be1f83 6ae44c8 6be1f83 6ae44c8 4790918 6ae44c8 6be1f83 6ae44c8 4790918 6ae44c8 54b9d20 6ae44c8 54b9d20 6ae44c8 54b9d20 6ae44c8 54b9d20 6ae44c8 6be1f83 54b9d20 6be1f83 6ae44c8 6be1f83 6ae44c8 6be1f83 6ae44c8 6be1f83 54b9d20 6ae44c8 54b9d20 6be1f83 6ae44c8 6be1f83 6ae44c8 54b9d20 6ae44c8 54b9d20 6ae44c8 54b9d20 6be1f83 6ae44c8 6be1f83 6ae44c8 54b9d20 6ae44c8 6be1f83 6ae44c8 54b9d20 6be1f83 6ae44c8 6be1f83 6ae44c8 6be1f83 6ae44c8 6be1f83 4790918 6ae44c8 4790918 6ae44c8 6be1f83 54b9d20 6ae44c8 54b9d20 6ae44c8 6be1f83 6ae44c8 4790918 6ae44c8 6be1f83 6ae44c8 54b9d20 6ae44c8 4790918 6ae44c8 6be1f83 6ae44c8 4790918 6be1f83 4790918 6ae44c8 6be1f83 4790918 54b9d20 6be1f83 6ae44c8 54b9d20 6ae44c8 6be1f83 6ae44c8 6be1f83 4790918 6be1f83 6ae44c8 4790918 6ae44c8 6be1f83 6ae44c8 6be1f83 54b9d20 6ae44c8 6be1f83 4790918 6be1f83 54b9d20 6ae44c8 6be1f83 4790918 6be1f83 6ae44c8 4790918 6ae44c8 54b9d20 6be1f83 6ae44c8 4790918 6be1f83 6ae44c8 4790918 6be1f83 6ae44c8 4790918 6ae44c8 6be1f83 6ae44c8 6be1f83 6ae44c8 6be1f83 6ae44c8 6be1f83 4790918 6ae44c8 4790918 6be1f83 4790918 6ae44c8 4790918 6ae44c8 54b9d20 6ae44c8 54b9d20 6ae44c8 4790918 6ae44c8 6be1f83 6ae44c8 a03be3c 6ae44c8 a03be3c 6be1f83 6ae44c8 4790918 6ae44c8 6be1f83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 |
import gradio as gr
import os
import io
import png
import tensorflow as tf
import tensorflow_text as tf_text
import tensorflow_hub as tf_hub
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, HfFolder
from sklearn.metrics.pairwise import cosine_similarity
import traceback
import time
import pandas as pd # Para formatear la salida en tabla
# --- Configuration ---
# Hugging Face Hub repository that holds the CXR Foundation model files.
MODEL_REPO_ID = "google/cxr-foundation"
# Local download directory (inside the Space container).
MODEL_DOWNLOAD_DIR = './hf_cxr_foundation_space'

# Classification thresholds:
#   comparative verdict -> PASS when sim(+) - sim(-) exceeds this margin;
SIMILARITY_DIFFERENCE_THRESHOLD = 0.1
#   simplified verdict  -> PASS when sim(+) alone exceeds this value.
POSITIVE_SIMILARITY_THRESHOLD = 0.1

print("Usando umbrales: Comp Δ={}, Simp τ={}".format(
    SIMILARITY_DIFFERENCE_THRESHOLD, POSITIVE_SIMILARITY_THRESHOLD))
# --- Prompts ---
# Each quality criterion is a (positive prompt, negative prompt) pair. The two
# public lists below are index-aligned: criteria_list_negative[i] is the
# opposite description of criteria_list_positive[i].
_CRITERIA_PROMPT_PAIRS = [
    ("optimal centering", "poorly centered"),
    ("optimal inspiration", "poor inspiration"),
    ("optimal penetration", "non-diagnostic exposure"),
    ("complete field of view", "cropped image"),
    ("scapulae retracted", "scapulae overlying lungs"),
    ("sharp image", "blurred image"),
    ("artifact free", "obscuring artifact"),
]
criteria_list_positive = [positive for positive, _ in _CRITERIA_PROMPT_PAIRS]
criteria_list_negative = [negative for _, negative in _CRITERIA_PROMPT_PAIRS]
# --- Helper functions ---
def preprocess_text(text):
    """Run the module-level BERT preprocessor (``bert_preprocessor_global``) on *text*.

    The global is looked up at call time, so this only works after the startup
    model-loading step has assigned it.

    NOTE(review): nothing in the visible source calls this helper;
    ``bert_tokenize`` receives the preprocessor as an explicit argument instead.
    """
    preprocessor = bert_preprocessor_global
    return preprocessor(text)
def bert_tokenize(text, preprocessor):
    """Tokenize *text* into the ``(ids, paddings)`` pair consumed by the QFormer.

    The text is lower-cased and passed through *preprocessor*. Its
    ``input_word_ids`` / ``input_mask`` outputs are then adapted:

    * ``paddings`` is the complement of the mask (1.0 = padding, 0.0 = token);
    * every [SEP] token (id 102) is replaced by 0 and marked as padding;
    * 2-D outputs get a middle axis inserted, (B, S) -> (B, 1, S).

    Args:
        text: Prompt to tokenize; non-string values are converted with ``str``.
        preprocessor: Callable returning a mapping with ``'input_word_ids'`` and
            ``'input_mask'`` tensors (the BERT preprocessor loaded at startup).

    Returns:
        Tuple ``(ids, paddings)``: int32 and float32 arrays of shape (1, 1, 128).

    Raises:
        ValueError: If *preprocessor* is None, or if the adapted arrays do not
            have shape (1, 1, 128).
    """
    if preprocessor is None:
        raise ValueError("BERT preprocessor no está cargado.")
    if not isinstance(text, str):
        text = str(text)

    out = preprocessor(tf.constant([text.lower()]))

    # astype() returns copies, so the in-place edits below never modify the
    # preprocessor's own output buffers.
    ids = out['input_word_ids'].numpy().astype(np.int32)
    masks = out['input_mask'].numpy().astype(np.float32)
    paddings = 1.0 - masks

    # Replace the [SEP] token (102) by 0 and mark that position as padding.
    end_token_idx = (ids == 102)
    ids[end_token_idx] = 0
    paddings[end_token_idx] = 1.0

    # The preprocessor may return (1, 128); the QFormer needs (1, 1, 128).
    if ids.ndim == 2:
        ids = np.expand_dims(ids, axis=1)
    if paddings.ndim == 2:
        paddings = np.expand_dims(paddings, axis=1)

    # Any 2-D output has been expanded above, so a mismatch here is a genuine
    # error rather than something another expand_dims could repair.
    expected_shape = (1, 1, 128)
    if ids.shape != expected_shape:
        raise ValueError(f"Shape incorrecta para ids: {ids.shape}, esperado {expected_shape}")
    if paddings.shape != expected_shape:
        raise ValueError(f"Shape incorrecta para paddings: {paddings.shape}, esperado {expected_shape}")
    return ids, paddings
def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
    """Encode a grayscale image as PNG bytes inside a ``tf.train.Example``.

    Pixels are converted to float32 and quantized for PNG encoding:

    * constant image that is uint8, or whose value lies in [0, 255]:
      cast to 8-bit as-is (no shifting);
    * any other constant image: written as an all-zero 16-bit image;
    * non-constant uint8 image: shifted so the minimum is 0, then stretched
      so the maximum is 255 (8-bit);
    * non-constant image of any other dtype: shifted to 0 and stretched to
      65535 (16-bit).

    NOTE(review): the non-constant uint8 branch always rescales to the full
    0-255 range (low-contrast inputs are contrast-stretched) rather than
    keeping the original range; confirm this is the intended model input.

    Args:
        image_array: 2-D array, or 3-D array with a single trailing channel.

    Returns:
        Example with features ``image/encoded`` (PNG bytes) and
        ``image/format`` (``b'png'``).

    Raises:
        ValueError: If the array is not 2-D after dropping a singleton channel.
    """
    # Accept (H, W, 1) by dropping the trailing singleton channel.
    if image_array.ndim == 3 and image_array.shape[2] == 1:
        image_array = np.squeeze(image_array, axis=2)
    elif image_array.ndim != 2:
        raise ValueError(f'Array debe ser 2-D (escala de grises). Dimensiones actuales: {image_array.ndim}')

    source_is_uint8 = image_array.dtype == np.uint8
    pixels = image_array.astype(np.float32)
    low = pixels.min()
    high = pixels.max()

    if high <= low:
        # Constant image: there is no range to divide by.
        if source_is_uint8 or (low >= 0 and high <= 255):
            quantized, bitdepth = pixels.astype(np.uint8), 8
        else:
            quantized, bitdepth = np.zeros_like(pixels, dtype=np.uint16), 16
    else:
        pixels -= low
        span = high - low
        if source_is_uint8:
            pixels *= 255 / span
            quantized, bitdepth = pixels.astype(np.uint8), 8
        else:
            pixels *= 65535 / span
            quantized, bitdepth = pixels.astype(np.uint16), 16

    buffer = io.BytesIO()
    writer = png.Writer(
        width=quantized.shape[1],
        height=quantized.shape[0],
        greyscale=True,
        bitdepth=bitdepth,
    )
    writer.write(buffer, quantized.tolist())

    example = tf.train.Example()
    feature_map = example.features.feature
    feature_map['image/encoded'].bytes_list.value.append(buffer.getvalue())
    feature_map['image/format'].bytes_list.value.append(b'png')
    return example
def generate_image_embedding(img_np, elixrc_infer, qformer_infer):
    """Compute the final 2-D contrastive embedding for a grayscale image.

    Pipeline: ``img_np`` -> PNG ``tf.train.Example`` -> ELIXR-C
    (``feature_maps_0``) -> QFormer with an empty text query (all-zero ids,
    all-one paddings) -> ``all_contrastive_img_emb``. If that output has more
    than 2 dimensions, every axis between the first and the last is averaged
    away. A 1-D result gets a leading axis.

    Args:
        img_np: 2-D grayscale image array (see ``png_to_tfexample``).
        elixrc_infer: ELIXR-C serving signature.
        qformer_infer: QFormer serving signature.

    Returns:
        2-D NumPy array (presumably (1, D) — confirm with the model outputs).

    Raises:
        ValueError: If either model is None, or the final embedding is not 2-D.
        Any other error is logged with its traceback and re-raised unchanged.
    """
    if elixrc_infer is None or qformer_infer is None:
        raise ValueError("Modelos ELIXR-C o QFormer no cargados.")
    try:
        # Stage 1: ELIXR-C image encoder.
        example_bytes = png_to_tfexample(img_np).SerializeToString()
        feature_map = elixrc_infer(input_example=tf.constant([example_bytes]))['feature_maps_0'].numpy()
        print(f" Embedding ELIXR-C shape: {feature_map.shape}")

        # Stage 2: QFormer with an empty text query (every position is padding).
        blank_ids = np.zeros((1, 1, 128), dtype=np.int32).tolist()
        blank_paddings = np.ones((1, 1, 128), dtype=np.float32).tolist()
        qformer_out = qformer_infer(
            image_feature=feature_map.tolist(),
            ids=blank_ids,
            paddings=blank_paddings,
        )
        embedding = qformer_out['all_contrastive_img_emb'].numpy()

        # Collapse to 2-D: average the middle axes, then lift 1-D to (1, D).
        if embedding.ndim > 2:
            print(f" Ajustando dimensiones embedding imagen (original: {embedding.shape})")
            middle_axes = tuple(range(1, embedding.ndim - 1))
            embedding = np.mean(embedding, axis=middle_axes)
        if embedding.ndim == 1:
            embedding = np.expand_dims(embedding, axis=0)

        print(f" Embedding final imagen shape: {embedding.shape}")
        if embedding.ndim != 2:
            raise ValueError(f"Embedding final de imagen no tiene 2 dimensiones: {embedding.shape}")
        return embedding
    except Exception as e:
        print(f"Error generando embedding de imagen: {e}")
        traceback.print_exc()
        raise  # Propagate so the Gradio callback can report it.
def _text_embedding(text, bert_preprocessor, qformer_infer, image_feature):
    """Return the 2-D QFormer contrastive text embedding (``contrastive_txt_emb``) for one prompt.

    *image_feature* fills the image input the QFormer signature requires next
    to the text; callers pass an all-zero placeholder.
    """
    ids, paddings = bert_tokenize(text, bert_preprocessor)
    output = qformer_infer(
        image_feature=image_feature,
        ids=ids.tolist(),
        paddings=paddings.tolist(),
    )
    embedding = output['contrastive_txt_emb'].numpy()
    if embedding.ndim == 1:
        embedding = np.expand_dims(embedding, axis=0)
    return embedding


def calculate_similarities_and_classify(image_embedding, bert_preprocessor, qformer_infer):
    """Score the image against every positive/negative prompt pair and classify it.

    For each index ``i``, the image embedding is compared by cosine similarity
    with the text embeddings of ``criteria_list_positive[i]`` and
    ``criteria_list_negative[i]``. Two verdicts are derived:

    * comparative: PASS if ``sim(+) - sim(-) > SIMILARITY_DIFFERENCE_THRESHOLD``;
    * simplified:  PASS if ``sim(+) > POSITIVE_SIMILARITY_THRESHOLD``.

    A failure while scoring one criterion is logged and recorded as ``"ERROR"``
    (with None similarities) for that criterion only; the others still run.

    Args:
        image_embedding: 2-D image embedding from ``generate_image_embedding``.
        bert_preprocessor: BERT preprocessor forwarded to ``bert_tokenize``.
        qformer_infer: QFormer serving signature.

    Returns:
        Dict keyed by the positive prompt. Each value holds ``positive_prompt``,
        ``negative_prompt``, ``similarity_positive``, ``similarity_negative``,
        ``difference`` (floats or None), ``classification_comparative`` and
        ``classification_simplified`` (``"PASS"``/``"FAIL"``/``"ERROR"``).

    Raises:
        ValueError: If any argument is None, or the two prompt lists differ in
            length (they are paired by index).
    """
    if image_embedding is None: raise ValueError("Embedding de imagen es None.")
    if bert_preprocessor is None: raise ValueError("Preprocesador BERT es None.")
    if qformer_infer is None: raise ValueError("QFormer es None.")
    if len(criteria_list_positive) != len(criteria_list_negative):
        raise ValueError(
            "Las listas de criterios tienen longitudes distintas: "
            f"{len(criteria_list_positive)} positivos vs {len(criteria_list_negative)} negativos."
        )

    # The all-zero placeholder image input is identical for every prompt, so
    # build (and convert to nested lists) once instead of twice per criterion.
    dummy_image_feature = np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist()

    detailed_results = {}
    print("\n--- Calculando similitudes y clasificando ---")
    for positive_text, negative_text in zip(criteria_list_positive, criteria_list_negative):
        criterion_name = positive_text  # The positive prompt doubles as the result key.
        print(f"Procesando criterio: \"{criterion_name}\"")
        similarity_positive = similarity_negative = difference = None
        classification_comp = classification_simp = "ERROR"
        try:
            text_embedding_pos = _text_embedding(positive_text, bert_preprocessor, qformer_infer, dummy_image_feature)
            text_embedding_neg = _text_embedding(negative_text, bert_preprocessor, qformer_infer, dummy_image_feature)

            # Cosine similarity needs matching embedding widths.
            for label, text_embedding in (("Pos", text_embedding_pos), ("Neg", text_embedding_neg)):
                if image_embedding.shape[1] != text_embedding.shape[1]:
                    raise ValueError(f"Dimensión incompatible: Imagen ({image_embedding.shape[1]}) vs Texto {label} ({text_embedding.shape[1]})")

            similarity_positive = cosine_similarity(image_embedding, text_embedding_pos)[0][0]
            similarity_negative = cosine_similarity(image_embedding, text_embedding_neg)[0][0]
            print(f" Sim (+)={similarity_positive:.4f}, Sim (-)={similarity_negative:.4f}")

            difference = similarity_positive - similarity_negative
            classification_comp = "PASS" if difference > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
            classification_simp = "PASS" if similarity_positive > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
            print(f" Diff={difference:.4f} -> Comp: {classification_comp}, Simp: {classification_simp}")
        except Exception as e:
            # Best-effort per criterion: log and keep the "ERROR" verdicts.
            print(f" ERROR procesando criterio '{criterion_name}': {e}")
            traceback.print_exc()

        detailed_results[criterion_name] = {
            'positive_prompt': positive_text,
            'negative_prompt': negative_text,
            'similarity_positive': float(similarity_positive) if similarity_positive is not None else None,
            'similarity_negative': float(similarity_negative) if similarity_negative is not None else None,
            'difference': float(difference) if difference is not None else None,
            'classification_comparative': classification_comp,
            'classification_simplified': classification_simp
        }
    return detailed_results
# --- Global model loading ---
# Runs ONCE, at import time, when the Gradio app / Space starts. On success it
# assigns the three *_global handles and sets models_loaded = True. On any
# failure it logs the error and sets models_loaded = False, so the UI still
# starts. Handles assigned before the failing step may already be non-None;
# models_loaded is the flag callers should check.
print("--- Iniciando carga global de modelos ---")
start_time = time.time()
models_loaded = False
bert_preprocessor_global = None
elixrc_infer_global = None
qformer_infer_global = None
try:
    # Optional Hugging Face auth check (only needed for private/gated models):
    # if HfFolder.get_token() is None:
    #     print("Advertencia: No se encontró token de Hugging Face.")
    # else:
    #     print("Token de Hugging Face encontrado.")

    # Create the download directory if it does not exist yet.
    os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
    print(f"Descargando/verificando modelos en: {MODEL_DOWNLOAD_DIR}")
    # Fetch only the two model subfolders loaded below, not the whole repo.
    # NOTE(review): newer huggingface_hub releases deprecate
    # local_dir_use_symlinks — confirm against the installed version.
    snapshot_download(repo_id=MODEL_REPO_ID, local_dir=MODEL_DOWNLOAD_DIR,
                      allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
                      local_dir_use_symlinks=False)  # Write real files, not symlinks.
    print("Modelos descargados/verificados.")

    # BERT preprocessor from TF Hub (uncased variant).
    print("Cargando Preprocesador BERT...")
    # An explicit handle URL can be more robust in some environments.
    bert_preprocess_handle = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
    bert_preprocessor_global = tf_hub.KerasLayer(bert_preprocess_handle)
    print("Preprocesador BERT cargado.")

    # ELIXR-C image encoder; inference goes through its 'serving_default' signature.
    print("Cargando ELIXR-C...")
    elixrc_model_path = os.path.join(MODEL_DOWNLOAD_DIR, 'elixr-c-v2-pooled')
    # NOTE(review): the loaded object stays referenced at module level; the
    # signature presumably relies on it staying alive — keep this name.
    elixrc_model = tf.saved_model.load(elixrc_model_path)
    elixrc_infer_global = elixrc_model.signatures['serving_default']
    print("Modelo ELIXR-C cargado.")

    # QFormer (ELIXR-B text model); the processing code reads both image and
    # text contrastive embeddings from its 'serving_default' signature.
    print("Cargando QFormer (ELIXR-B Text)...")
    qformer_model_path = os.path.join(MODEL_DOWNLOAD_DIR, 'pax-elixr-b-text')
    qformer_model = tf.saved_model.load(qformer_model_path)
    qformer_infer_global = qformer_model.signatures['serving_default']
    print("Modelo QFormer cargado.")

    models_loaded = True
    end_time = time.time()
    print(f"--- Modelos cargados globalmente con éxito en {end_time - start_time:.2f} segundos ---")
except Exception as e:
    # Deliberately broad: a loading failure must not stop the UI from starting.
    models_loaded = False
    print(f"--- ERROR CRÍTICO DURANTE LA CARGA GLOBAL DE MODELOS ---")
    print(e)
    traceback.print_exc()
    # Gradio still starts, but image analysis will fail (models_loaded is False).
# --- Main processing function for Gradio ---
def assess_quality(image_pil):
    """Gradio callback: assess the technical quality of one uploaded chest X-ray.

    Args:
        image_pil: PIL image delivered by ``gr.Image(type="pil")``, or None when
            no image was provided.

    Returns:
        Tuple ``(results_table, quality_label, raw_results)``:
        ``results_table`` is a DataFrame with one row per criterion;
        ``quality_label`` is e.g. ``"Good (5/7 criteria passed)"``, based on the
        comparative verdicts; ``raw_results`` is the dict produced by
        ``calculate_similarities_and_classify``. Without an image the result is
        ``(empty DataFrame, "N/A", None)``.

    Raises:
        gr.Error: If the models failed to load at startup, or if any processing
            step raises.
    """
    if not models_loaded:
        raise gr.Error("Error: Los modelos no se pudieron cargar. La aplicación no puede procesar imágenes.")
    if image_pil is None:
        # Nothing uploaded: empty table, placeholder label, empty JSON.
        return pd.DataFrame(), "N/A", None

    print("\n--- Iniciando evaluación para nueva imagen ---")
    started_at = time.time()
    try:
        # Gradio (type="pil") hands over a PIL image; force 8-bit grayscale.
        grey = np.array(image_pil.convert('L'))
        print(f"Imagen convertida a NumPy. Shape: {grey.shape}, Tipo: {grey.dtype}")

        print("Generando embedding de imagen...")
        embedding = generate_image_embedding(grey, elixrc_infer_global, qformer_infer_global)
        print("Embedding de imagen generado.")

        print("Calculando similitudes y clasificando criterios...")
        results = calculate_similarities_and_classify(embedding, bert_preprocessor_global, qformer_infer_global)
        print("Clasificación completada.")

        def fmt(value):
            # Four decimals, or "N/A" when the criterion errored.
            return "N/A" if value is None else f"{value:.4f}"

        rows = [
            [
                name,
                fmt(info['similarity_positive']),
                fmt(info['similarity_negative']),
                fmt(info['difference']),
                info['classification_comparative'],
                info['classification_simplified'],
            ]
            for name, info in results.items()
        ]
        table = pd.DataFrame(rows, columns=[
            "Criterion", "Sim (+)", "Sim (-)", "Difference", "Assessment (Comp)", "Assessment (Simp)"
        ])

        # Overall label from the share of comparative PASS verdicts.
        total = len(rows)
        passed = sum(1 for row in rows if row[4] == "PASS")
        if total == 0:
            verdict = "Error"
        else:
            rate = passed / total
            bands = ((0.85, "Excellent"), (0.70, "Good"), (0.50, "Fair"))
            verdict = next((name for cutoff, name in bands if rate >= cutoff), "Poor")
        label = f"{verdict} ({passed}/{total} criteria passed)"

        finished_at = time.time()
        print(f"--- Evaluación completada en {finished_at - started_at:.2f} segundos ---")
        return table, label, results
    except Exception as e:
        print(f"Error durante el procesamiento de la imagen en Gradio: {e}")
        traceback.print_exc()
        # Surface the failure in the Gradio UI.
        raise gr.Error(f"Error procesando la imagen: {str(e)}")
# --- Gradio interface definition ---
# CSS for the component with elem_id="quality-label" (the overall-quality Label).
css = """
#quality-label label {
font-size: 1.1em;
font-weight: bold;
}
"""
# Components render in creation order: header text, then a row with the input
# column (image + button) and the output column (label, table, raw JSON),
# followed by the explanation of each table column.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Chest X-ray Technical Quality Assessment
Upload a chest X-ray image (PNG, JPG, etc.) to evaluate its technical quality based on 7 standard criteria
using the ELIXR model family (comparative strategy: Positive vs Negative prompts).
**Note:** Model loading on startup might take a minute. Processing an image can take 10-30 seconds depending on server load.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio pass a PIL image to assess_quality.
            input_image = gr.Image(type="pil", label="Upload Chest X-ray")
            submit_button = gr.Button("Assess Quality", variant="primary")
            # To add example images, make sure an 'examples' folder exists and
            # contains them, then enable:
            # gr.Examples(
            #     examples=[os.path.join("examples", "sample_cxr.png")],  # list of example paths
            #     inputs=input_image
            # )
        with gr.Column(scale=2):
            output_label = gr.Label(label="Overall Quality Estimate", elem_id="quality-label")
            # Headers mirror the DataFrame columns built in assess_quality.
            output_dataframe = gr.DataFrame(
                headers=["Criterion", "Sim (+)", "Sim (-)", "Difference", "Assessment (Comp)", "Assessment (Simp)"],
                label="Detailed Quality Assessment",
                wrap=True,
                height=350
            )
            output_json = gr.JSON(label="Raw Results (for debugging)")
    # f-string: the thresholds shown here are the module-level constants.
    gr.Markdown(
        f"""
**Explanation:**
* **Criterion:** The quality aspect being evaluated (using the positive prompt text).
* **Sim (+):** Cosine similarity between the image and the *positive* text prompt (e.g., "optimal centering"). Higher is better.
* **Sim (-):** Cosine similarity between the image and the *negative* text prompt (e.g., "poorly centered"). Lower is better.
* **Difference:** Sim (+) - Sim (-). A large positive difference indicates the image is much closer to the positive description.
* **Assessment (Comp):** PASS if Difference > {SIMILARITY_DIFFERENCE_THRESHOLD}, otherwise FAIL. This is the main comparative assessment.
* **Assessment (Simp):** PASS if Sim (+) > {POSITIVE_SIMILARITY_THRESHOLD}, otherwise FAIL. A simpler check based only on positive similarity.
"""
    )
    # Wire the button to the processing function; the output order matches
    # assess_quality's return tuple (table, label, raw dict).
    submit_button.click(
        fn=assess_quality,
        inputs=input_image,
        outputs=[output_dataframe, output_label, output_json]
    )
# --- Launch the Gradio application ---
# When deployed on Hugging Face Spaces, Gradio handles launching automatically,
# so the guard below only matters when running the script directly.
# For a temporary public link when running locally: demo.launch(share=True)
if __name__ == "__main__":
    # server_name="0.0.0.0" accepts connections from the local network;
    # 7860 is the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)