import gradio as gr
import os
import io
import png
import tensorflow as tf
import tensorflow_text as tf_text  # noqa: F401 -- registers the custom ops required by the BERT preprocessor
import tensorflow_hub as tf_hub
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
from sklearn.metrics.pairwise import cosine_similarity
import traceback
import time
# --- Configuration ---
MODEL_REPO_ID = "google/cxr-foundation"
MODEL_DOWNLOAD_DIR = './hf_cxr_foundation_space'
SIMILARITY_DIFFERENCE_THRESHOLD = 0.1
POSITIVE_SIMILARITY_THRESHOLD = 0.1
print(f"Usando umbrales: Comp Δ={SIMILARITY_DIFFERENCE_THRESHOLD}, Simp τ={POSITIVE_SIMILARITY_THRESHOLD}")
# Improved default prompts
criteria_list_positive = [
"optimal centering mediastinum",
"deep inspiration",
"adequate penetration",
"complete lung fields",
"scapulae retracted outside lungs",
"sharp contrast",
"artifact-free image"
]
criteria_list_negative = [
"poor centering",
"shallow inspiration",
"overexposed image",
"underexposed image",
"cropped lung fields",
"scapular overlay on lungs",
"blurred image with artifacts"
]
# --- Helper Functions ---
def bert_tokenize(text, preprocessor):
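    """Tokenize `text` with the BERT preprocessor into (ids, paddings), each shaped (1, 1, 128)."""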
if preprocessor is None:
raise ValueError("BERT preprocessor no está cargado.")
text = str(text).lower()
out = preprocessor(tf.constant([text]))
ids = out['input_word_ids'].numpy().astype(np.int32)
masks = out['input_mask'].numpy().astype(np.float32)
paddings = 1.0 - masks
    # Zero out the [SEP] end-of-sequence token (id 102) and mark it as padding
end_token_idx = (ids == 102)
ids[end_token_idx] = 0
paddings[end_token_idx] = 1.0
    # Ensure shape (1, 1, 128)
if ids.ndim == 2: ids = np.expand_dims(ids, 1)
if paddings.ndim == 2: paddings = np.expand_dims(paddings, 1)
return ids, paddings
def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
    """Convert a 2-D NumPy array into a tf.train.Example holding a PNG-encoded image."""
if image_array.ndim == 3 and image_array.shape[2] == 1:
image_array = np.squeeze(image_array, axis=2)
elif image_array.ndim != 2:
        raise ValueError(f'Array must be 2-D. Got ndim={image_array.ndim}')
image = image_array.astype(np.float32)
    min_val, max_val = image.min(), image.max()
    if max_val <= min_val:
        # Constant image: keep it as-is at the matching bit depth
        if image_array.dtype == np.uint8 or (min_val >= 0 and max_val <= 255):
            pixel_array = image.astype(np.uint8)
            bitdepth = 8
        else:
            pixel_array = np.zeros_like(image, dtype=np.uint16)
            bitdepth = 16
    else:
        # Rescale to the full dynamic range of the target bit depth
        image -= min_val
        current_max = max_val - min_val
        if image_array.dtype != np.uint8:
            image *= 65535 / current_max
            pixel_array = image.astype(np.uint16)
            bitdepth = 16
        else:
            image *= 255 / current_max
            pixel_array = image.astype(np.uint8)
            bitdepth = 8
output = io.BytesIO()
png.Writer(width=pixel_array.shape[1], height=pixel_array.shape[0],
greyscale=True, bitdepth=bitdepth).write(output, pixel_array.tolist())
example = tf.train.Example()
features = example.features.feature
features['image/encoded'].bytes_list.value.append(output.getvalue())
features['image/format'].bytes_list.value.append(b'png')
return example
def generate_image_embedding(img_np, elixrc_infer, qformer_infer):
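    """Encode a grayscale image with ELIXR-C, then project it through the QFormer to a contrastive embedding."""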
if elixrc_infer is None or qformer_infer is None:
raise ValueError("Modelos ELIXR-C o QFormer no cargados.")
try:
serialized = png_to_tfexample(img_np).SerializeToString()
elixrc_out = elixrc_infer(input_example=tf.constant([serialized]))
elixr_emb = elixrc_out['feature_maps_0'].numpy()
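        # Query the QFormer with the ELIXR-C feature maps and a blank text
        # input to obtain the contrastive image embedding.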
q_in = {
'image_feature': elixr_emb.tolist(),
'ids': np.zeros((1,1,128),dtype=np.int32).tolist(),
'paddings': np.ones((1,1,128),dtype=np.float32).tolist(),
}
q_out = qformer_infer(**q_in)
img_emb = q_out['all_contrastive_img_emb'].numpy()
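        # Pool any extra token axes down to a single (1, D) vector for cosine similarity.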
if img_emb.ndim > 2:
img_emb = img_emb.mean(axis=tuple(range(1, img_emb.ndim-1)))
if img_emb.ndim == 1:
img_emb = img_emb[np.newaxis, :]
return img_emb
except Exception as e:
print(f"Error embedding imagen: {e}")
traceback.print_exc()
raise
def calculate_similarities_and_classify(image_embedding, bert_preprocessor, qformer_infer,
criteria_positive, criteria_negative):
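    """Embed each positive/negative prompt pair and score it against the image embedding.

    Returns a dict keyed by positive prompt with both similarities, their
    difference, and PASS/FAIL labels for the comparative (Comp) and simple
    (Simp) tests.
    """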
results = {}
for pos, neg in zip(criteria_positive, criteria_negative):
sim_pos = sim_neg = diff = None
comp = simp = "ERROR"
try:
            # Positive text embedding
ids_p, pad_p = bert_tokenize(pos, bert_preprocessor)
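            # Text-only query: pass a zeroed image_feature of the shape the
            # QFormer expects, (1, 8, 8, 1376).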
inp_p = {'image_feature': np.zeros([1,8,8,1376],dtype=np.float32).tolist(),
'ids': ids_p.tolist(), 'paddings': pad_p.tolist()}
txt_p = qformer_infer(**inp_p)['contrastive_txt_emb'].numpy()
            # Negative text embedding
ids_n, pad_n = bert_tokenize(neg, bert_preprocessor)
inp_n = {'image_feature': np.zeros([1,8,8,1376],dtype=np.float32).tolist(),
'ids': ids_n.tolist(), 'paddings': pad_n.tolist()}
txt_n = qformer_infer(**inp_n)['contrastive_txt_emb'].numpy()
sim_pos = float(cosine_similarity(image_embedding, txt_p.reshape(1,-1))[0][0])
sim_neg = float(cosine_similarity(image_embedding, txt_n.reshape(1,-1))[0][0])
diff = sim_pos - sim_neg
comp = "PASS" if diff > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
simp = "PASS" if sim_pos > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
except Exception as e:
print(f"Error en criterio '{pos}': {e}")
results[pos] = {
'positive_prompt': pos,
'negative_prompt': neg,
'sim_pos': sim_pos,
'sim_neg': sim_neg,
'difference': diff,
'comp': comp,
'simp': simp
}
return results
# --- Global Model Loading ---
print("--- Starting model load ---")
start_time = time.time()
models_loaded = False
bert_preproc = elixr = qformer = None
try:
hf_token = os.environ.get("HF_TOKEN")
os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
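    # Fetch only the two SavedModels this app needs: the ELIXR-C image
    # encoder and the ELIXR-B text/QFormer model.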
snapshot_download(repo_id=MODEL_REPO_ID, local_dir=MODEL_DOWNLOAD_DIR,
allow_patterns=['elixr-c-v2-pooled/*','pax-elixr-b-text/*'],
local_dir_use_symlinks=False, token=hf_token)
bert_preproc = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
elixr = tf.saved_model.load(os.path.join(MODEL_DOWNLOAD_DIR,'elixr-c-v2-pooled')).signatures['serving_default']
qformer = tf.saved_model.load(os.path.join(MODEL_DOWNLOAD_DIR,'pax-elixr-b-text')).signatures['serving_default']
models_loaded = True
print(f"Modelos cargados en {time.time()-start_time:.2f}s")
except Exception as e:
print("ERROR cargando modelos:", e)
traceback.print_exc()
# --- Main Gradio Function ---
def assess_quality_and_update_ui(image_pil, pos_input, neg_input):
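    """Run the quality assessment and return updates for every output component."""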
if not models_loaded:
raise gr.Error("No se pudieron cargar los modelos.")
if image_pil is None:
        # Returns: welcome visible, results hidden, no image, "N/A" label, empty HTML, empty JSON
return (
gr.update(visible=True),
gr.update(visible=False),
None,
"N/A",
"",
{}
)
    # Parse the prompt lists
pos_list = [l.strip() for l in pos_input.splitlines() if l.strip()]
neg_list = [l.strip() for l in neg_input.splitlines() if l.strip()]
if len(pos_list) != len(neg_list):
raise gr.Error("El número de prompts positivos y negativos debe coincidir.")
    # Image embedding (the model works on a 2-D grayscale array)
img_np = np.array(image_pil.convert('L'))
emb = generate_image_embedding(img_np, elixr, qformer)
    # Compute similarities and classify each criterion
details = calculate_similarities_and_classify(emb, bert_preproc, qformer, pos_list, neg_list)
    # Build the results table HTML
    passed = total = 0
    rows = ""

    def fmt(v):
        # A criterion that errored leaves its scores as None
        return f"{v:.4f}" if v is not None else "N/A"

    for crit, d in details.items():
        total += 1
        if d['comp'] == "PASS":
            passed += 1
        c_style = "color:#22c55e;font-weight:bold;" if d['comp'] == "PASS" else "color:#ef4444;font-weight:bold;"
        s_style = "color:#22c55e;font-weight:bold;" if d['simp'] == "PASS" else "color:#ef4444;font-weight:bold;"
        rows += (
            f"<tr>"
            f"<td>{crit}</td>"
            f"<td>{fmt(d['sim_pos'])}</td>"
            f"<td>{fmt(d['sim_neg'])}</td>"
            f"<td>{fmt(d['difference'])}</td>"
            f"<td style='{c_style}'>{d['comp']}</td>"
            f"<td style='{s_style}'>{d['simp']}</td>"
            f"</tr>"
        )
html = f"""
<table style="width:100%;border-collapse:collapse;">
<thead style="background:#f2f2f2;">
<tr>
<th>Criterion</th><th>Sim (+)</th><th>Sim (-)</th><th>Diff</th>
<th>Assessment (Comp)</th><th>Assessment (Simp)</th>
</tr>
</thead>
<tbody>{rows}</tbody>
</table>
"""
    # Overall quality label
pass_rate = passed/total if total>0 else 0
if pass_rate>=0.85: overall="Excellent"
elif pass_rate>=0.70: overall="Good"
elif pass_rate>=0.50: overall="Fair"
else: overall="Poor"
quality_label = f"{overall} ({passed}/{total} passed)"
    # Return UI updates
return (
gr.update(visible=False),
gr.update(visible=True),
image_pil,
quality_label,
html,
details
)
def reset_ui():
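    """Restore the initial UI state: show the welcome block and clear all outputs."""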
return (
gr.update(visible=True),
gr.update(visible=False),
        None,   # clear input_image
        None,   # clear output_image
        "N/A",  # quality label
        "",     # HTML table
        {}      # JSON details
)
# --- Theme Definition ---
dark_theme = gr.themes.Default(
primary_hue=gr.themes.colors.blue,
secondary_hue=gr.themes.colors.blue,
neutral_hue=gr.themes.colors.gray,
font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "Consolas", "monospace"],
).set(
body_background_fill="#111827",
background_fill_primary="#1f2937",
background_fill_secondary="#374151",
block_background_fill="#1f2937",
body_text_color="#d1d5db",
block_label_text_color="#d1d5db",
block_title_text_color="#ffffff",
border_color_accent="#374151",
border_color_primary="#4b5563",
button_primary_background_fill="*primary_600",
button_primary_text_color="#ffffff",
button_secondary_background_fill="*neutral_700",
button_secondary_text_color="#ffffff",
input_background_fill="#374151",
input_border_color="#4b5563",
shadow_drop="rgba(0,0,0,0.2) 0px 2px 4px",
block_shadow="rgba(0,0,0,0.2) 0px 2px 5px",
)
# --- Gradio Interface ---
with gr.Blocks(theme=dark_theme, title="CXR Quality Assessment") as demo:
    # Header
gr.Markdown("""
# <span style="color: #e5e7eb;">CXR Quality Assessment</span>
<p style="color: #9ca3af;">Evalúa la calidad técnica de radiografías de tórax con AI</p>
""")
    # Editable prompts
with gr.Row():
        positive_prompts_input = gr.Textbox(
            label="Positive Prompts (one per line)",
            value="\n".join(criteria_list_positive),
            lines=7
        )
        negative_prompts_input = gr.Textbox(
            label="Negative Prompts (one per line)",
            value="\n".join(criteria_list_negative),
            lines=7
        )
    # Main content
with gr.Row(equal_height=False):
with gr.Column(scale=1, min_width=300):
gr.Markdown("### 1. Carga de Imagen")
            input_image = gr.Image(type="pil", label="Upload your CXR", height=300)
with gr.Row():
analyze_btn = gr.Button("Analizar", variant="primary")
reset_btn = gr.Button("Reset", variant="secondary")
gr.Markdown("<p style='color:#9ca3af; font-size:0.9em;'>La carga de modelos tarda ~1 min; el análisis ~15–40 s.</p>")
with gr.Column(scale=2):
with gr.Column(visible=True) as welcome_block:
gr.Markdown("### ¡Bienvenido! Sube una radiografía y haz clic en «Analizar».")
with gr.Column(visible=False) as results_block:
gr.Markdown("### 2. Resultados")
with gr.Row():
                    output_image = gr.Image(type="pil", label="Analyzed Image", interactive=False)
with gr.Column():
gr.Markdown("#### Calidad Global")
output_label = gr.Label(value="N/A")
gr.Markdown("#### Evaluación Detallada")
output_html = gr.HTML()
with gr.Accordion("Ver JSON (debug)", open=False):
output_json = gr.JSON()
    # Event wiring
analyze_btn.click(
fn=assess_quality_and_update_ui,
inputs=[input_image, positive_prompts_input, negative_prompts_input],
outputs=[welcome_block, results_block, output_image, output_label, output_html, output_json]
)
reset_btn.click(
fn=reset_ui,
inputs=None,
outputs=[welcome_block, results_block, input_image, output_image, output_label, output_html, output_json]
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)