"""CXR quality-assessment demo.

Scores a chest X-ray against paired positive/negative text prompts using
Google's CXR Foundation models (ELIXR-C image encoder + QFormer text/image
contrastive head) and presents the results in a Gradio UI.
"""

import io
import os
import time
import traceback

import gradio as gr
import numpy as np
import png
import tensorflow as tf
import tensorflow_text as tf_text  # noqa: F401 -- registers ops required by the BERT preprocessor
import tensorflow_hub as tf_hub
from huggingface_hub import snapshot_download
from PIL import Image  # noqa: F401 -- Gradio supplies PIL images
from sklearn.metrics.pairwise import cosine_similarity

# --- Configuration ---
MODEL_REPO_ID = "google/cxr-foundation"
MODEL_DOWNLOAD_DIR = './hf_cxr_foundation_space'
SIMILARITY_DIFFERENCE_THRESHOLD = 0.1   # Comp: pos-minus-neg margin required to PASS
POSITIVE_SIMILARITY_THRESHOLD = 0.1     # Simp: absolute positive similarity required to PASS

print(f"Usando umbrales: Comp Δ={SIMILARITY_DIFFERENCE_THRESHOLD}, Simp τ={POSITIVE_SIMILARITY_THRESHOLD}")

# Improved default prompts (one positive/negative pair per quality criterion)
criteria_list_positive = [
    "optimal centering mediastinum",
    "deep inspiration",
    "adequate penetration",
    "complete lung fields",
    "scapulae retracted outside lungs",
    "sharp contrast",
    "artifact-free image",
]
criteria_list_negative = [
    "poor centering",
    "shallow inspiration",
    "overexposed image",
    "underexposed image",
    "cropped lung fields",
    "scapular overlay on lungs",
    "blurred image with artifacts",
]


# --- Helper functions ---
def bert_tokenize(text, preprocessor):
    """Tokenize ``text`` with the TF-Hub BERT preprocessor.

    Returns ``(ids, paddings)`` shaped ``(1, 1, 128)``: int32 token ids and
    float32 paddings (1.0 marks a padded position), the layout the QFormer
    signature expects.
    """
    if preprocessor is None:
        raise ValueError("BERT preprocessor no está cargado.")
    text = str(text).lower()
    out = preprocessor(tf.constant([text]))
    ids = out['input_word_ids'].numpy().astype(np.int32)
    masks = out['input_mask'].numpy().astype(np.float32)
    paddings = 1.0 - masks
    # Blank out the [SEP] end token (BERT id 102) and mark it as padding.
    end_token_idx = (ids == 102)
    ids[end_token_idx] = 0
    paddings[end_token_idx] = 1.0
    # Ensure shape (1, 1, 128)
    if ids.ndim == 2:
        ids = np.expand_dims(ids, 1)
    if paddings.ndim == 2:
        paddings = np.expand_dims(paddings, 1)
    return ids, paddings


def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
    """Encode a 2-D grayscale NumPy array as a PNG inside a ``tf.train.Example``.

    Pixel values are min-max rescaled to the full 8- or 16-bit range (8-bit is
    kept for uint8 input, 16-bit otherwise).

    Raises:
        ValueError: if the array is not 2-D (or 3-D with a single channel).
    """
    if image_array.ndim == 3 and image_array.shape[2] == 1:
        image_array = np.squeeze(image_array, axis=2)
    elif image_array.ndim != 2:
        # FIX: original literal contained a raw newline inside a single-quoted
        # string (extraction artifact); message kept, now on one line.
        raise ValueError(f'Array debe ser 2-D. Dimensiones: {image_array.ndim}')

    image = image_array.astype(np.float32)
    min_val, max_val = image.min(), image.max()

    if max_val <= min_val:
        # Constant image: no dynamic range to stretch, pick a fitting bit depth.
        if image_array.dtype == np.uint8 or (min_val >= 0 and max_val <= 255):
            pixel_array = image.astype(np.uint8)
            bitdepth = 8
        else:
            pixel_array = np.zeros_like(image, dtype=np.uint16)
            bitdepth = 16
    else:
        image -= min_val
        current_max = max_val - min_val
        if image_array.dtype != np.uint8:
            image *= 65535 / current_max
            pixel_array = image.astype(np.uint16)
            bitdepth = 16
        else:
            image *= 255 / current_max
            pixel_array = image.astype(np.uint8)
            bitdepth = 8

    output = io.BytesIO()
    png.Writer(
        width=pixel_array.shape[1],
        height=pixel_array.shape[0],
        greyscale=True,
        bitdepth=bitdepth,
    ).write(output, pixel_array.tolist())

    example = tf.train.Example()
    features = example.features.feature
    features['image/encoded'].bytes_list.value.append(output.getvalue())
    features['image/format'].bytes_list.value.append(b'png')
    return example


def generate_image_embedding(img_np, elixrc_infer, qformer_infer):
    """Run ELIXR-C then QFormer to get a contrastive image embedding.

    Args:
        img_np: 2-D grayscale image array.
        elixrc_infer: ELIXR-C serving signature.
        qformer_infer: QFormer serving signature.

    Returns:
        A 2-D ``(1, dim)`` NumPy embedding (extra axes are mean-pooled away).

    Raises:
        ValueError: if either model is missing; re-raises inference errors.
    """
    if elixrc_infer is None or qformer_infer is None:
        raise ValueError("Modelos ELIXR-C o QFormer no cargados.")
    try:
        serialized = png_to_tfexample(img_np).SerializeToString()
        elixrc_out = elixrc_infer(input_example=tf.constant([serialized]))
        elixr_emb = elixrc_out['feature_maps_0'].numpy()

        # Text inputs are zeroed/padded: only the image branch is used here.
        q_in = {
            'image_feature': elixr_emb.tolist(),
            'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
            'paddings': np.ones((1, 1, 128), dtype=np.float32).tolist(),
        }
        q_out = qformer_infer(**q_in)
        img_emb = q_out['all_contrastive_img_emb'].numpy()

        # Pool any intermediate axes down to (1, dim).
        if img_emb.ndim > 2:
            img_emb = img_emb.mean(axis=tuple(range(1, img_emb.ndim - 1)))
        if img_emb.ndim == 1:
            img_emb = img_emb[np.newaxis, :]
        return img_emb
    except Exception as e:
        print(f"Error embedding imagen: {e}")
        traceback.print_exc()
        raise


def calculate_similarities_and_classify(image_embedding, bert_preprocessor,
                                        qformer_infer, criteria_positive,
                                        criteria_negative):
    """Score the image embedding against each positive/negative prompt pair.

    For each pair: Comp = PASS when sim(pos) - sim(neg) exceeds the difference
    threshold; Simp = PASS when sim(pos) alone exceeds the positive threshold.
    A failing criterion is recorded with ``None`` similarities and "ERROR"
    verdicts rather than aborting the whole run.

    Returns:
        Dict keyed by positive prompt with similarity/verdict details.
    """
    results = {}
    # Hoisted loop invariant: a zero image feature map tells the QFormer to use
    # only the text branch. NOTE(review): the (1, 8, 8, 1376) shape is assumed
    # to match the model's expected feature-map layout -- confirm against the
    # cxr-foundation model card if it changes.
    zero_image_feature = np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist()

    for pos, neg in zip(criteria_positive, criteria_negative):
        sim_pos = sim_neg = diff = None
        comp = simp = "ERROR"
        try:
            # Positive text embedding
            ids_p, pad_p = bert_tokenize(pos, bert_preprocessor)
            inp_p = {'image_feature': zero_image_feature,
                     'ids': ids_p.tolist(), 'paddings': pad_p.tolist()}
            txt_p = qformer_infer(**inp_p)['contrastive_txt_emb'].numpy()

            # Negative text embedding
            ids_n, pad_n = bert_tokenize(neg, bert_preprocessor)
            inp_n = {'image_feature': zero_image_feature,
                     'ids': ids_n.tolist(), 'paddings': pad_n.tolist()}
            txt_n = qformer_infer(**inp_n)['contrastive_txt_emb'].numpy()

            sim_pos = float(cosine_similarity(image_embedding, txt_p.reshape(1, -1))[0][0])
            sim_neg = float(cosine_similarity(image_embedding, txt_n.reshape(1, -1))[0][0])
            diff = sim_pos - sim_neg
            comp = "PASS" if diff > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
            simp = "PASS" if sim_pos > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
        except Exception as e:
            print(f"Error en criterio '{pos}': {e}")

        results[pos] = {
            'positive_prompt': pos,
            'negative_prompt': neg,
            'sim_pos': sim_pos,
            'sim_neg': sim_neg,
            'difference': diff,
            'comp': comp,
            'simp': simp,
        }
    return results


# --- Global model loading ---
print("--- Iniciando carga de modelos ---")
start_time = time.time()
models_loaded = False
# FIX: original initialized `elixrc`, a name never used again, leaving `elixr`
# potentially undefined on a partial load failure.
bert_preproc = elixr = qformer = None
try:
    hf_token = os.environ.get("HF_TOKEN")
    os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=MODEL_DOWNLOAD_DIR,
        allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
        local_dir_use_symlinks=False,
        token=hf_token,
    )
    bert_preproc = tf_hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
    elixr = tf.saved_model.load(
        os.path.join(MODEL_DOWNLOAD_DIR, 'elixr-c-v2-pooled')
    ).signatures['serving_default']
    qformer = tf.saved_model.load(
        os.path.join(MODEL_DOWNLOAD_DIR, 'pax-elixr-b-text')
    ).signatures['serving_default']
    models_loaded = True
    print(f"Modelos cargados en {time.time()-start_time:.2f}s")
except Exception as e:
    print("ERROR cargando modelos:", e)
    traceback.print_exc()


def _fmt_sim(value):
    """Format a similarity value for the HTML table; None (criterion error) -> N/A."""
    # FIX: original formatted with :.4f unconditionally, raising TypeError when
    # a criterion had errored and left the value as None.
    return f"{value:.4f}" if value is not None else "N/A"


# --- Main Gradio callback ---
def assess_quality_and_update_ui(image_pil, pos_input, neg_input):
    """Analyze an uploaded CXR and return the UI updates for the results view.

    Returns (in `outputs` order): welcome-block visibility, results-block
    visibility, analyzed image, overall quality label, detail HTML table, and
    the raw details dict for the debug JSON view.

    Raises:
        gr.Error: if models failed to load or prompt lists are unbalanced.
    """
    if not models_loaded:
        raise gr.Error("No se pudieron cargar los modelos.")
    if image_pil is None:
        # No image: show welcome, hide results, clear outputs.
        return (gr.update(visible=True), gr.update(visible=False),
                None, "N/A", "", {})

    # Parse prompt lists (one prompt per non-empty line)
    pos_list = [l.strip() for l in pos_input.splitlines() if l.strip()]
    neg_list = [l.strip() for l in neg_input.splitlines() if l.strip()]
    if len(pos_list) != len(neg_list):
        raise gr.Error("El número de prompts positivos y negativos debe coincidir.")

    # Image embedding
    img_np = np.array(image_pil.convert('L'))
    emb = generate_image_embedding(img_np, elixr, qformer)

    # Similarities per criterion
    details = calculate_similarities_and_classify(
        emb, bert_preproc, qformer, pos_list, neg_list)

    # Build the HTML table. NOTE(review): the markup below was reconstructed --
    # the extracted source had lost the HTML tags from these f-strings (the
    # c_style/s_style values were computed but never interpolated).
    passed = total = 0
    rows = ""
    for crit, d in details.items():
        total += 1
        if d['comp'] == "PASS":
            passed += 1
        c_style = ("color:#22c55e;font-weight:bold;" if d['comp'] == "PASS"
                   else "color:#ef4444;font-weight:bold;")
        s_style = ("color:#22c55e;font-weight:bold;" if d['simp'] == "PASS"
                   else "color:#ef4444;font-weight:bold;")
        rows += (
            f"<tr>"
            f"<td>{crit}</td>"
            f"<td>{_fmt_sim(d['sim_pos'])}</td>"
            f"<td>{_fmt_sim(d['sim_neg'])}</td>"
            f"<td>{_fmt_sim(d['difference'])}</td>"
            f"<td style='{c_style}'>{d['comp']}</td>"
            f"<td style='{s_style}'>{d['simp']}</td>"
            f"</tr>"
        )
    html = f"""
<table style="width:100%;border-collapse:collapse;text-align:left;">
  <thead>
    <tr>
      <th>Criterion</th><th>Sim (+)</th><th>Sim (-)</th><th>Diff</th>
      <th>Assessment (Comp)</th><th>Assessment (Simp)</th>
    </tr>
  </thead>
  <tbody>
    {rows}
  </tbody>
</table>
"""

    # Overall label from the Comp pass rate
    pass_rate = passed / total if total > 0 else 0
    if pass_rate >= 0.85:
        overall = "Excellent"
    elif pass_rate >= 0.70:
        overall = "Good"
    elif pass_rate >= 0.50:
        overall = "Fair"
    else:
        overall = "Poor"
    quality_label = f"{overall} ({passed}/{total} passed)"

    return (gr.update(visible=False), gr.update(visible=True),
            image_pil, quality_label, html, details)


def reset_ui():
    """Restore the initial UI state (welcome visible, all outputs cleared)."""
    return (
        gr.update(visible=True),
        gr.update(visible=False),
        None,    # clear input_image
        None,    # clear output_image
        "N/A",   # quality label
        "",      # HTML
        {},      # JSON
    )


# --- Theme ---
dark_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "Consolas", "monospace"],
).set(
    body_background_fill="#111827",
    background_fill_primary="#1f2937",
    background_fill_secondary="#374151",
    block_background_fill="#1f2937",
    body_text_color="#d1d5db",
    block_label_text_color="#d1d5db",
    block_title_text_color="#ffffff",
    border_color_accent="#374151",
    border_color_primary="#4b5563",
    button_primary_background_fill="*primary_600",
    button_primary_text_color="#ffffff",
    button_secondary_background_fill="*neutral_700",
    button_secondary_text_color="#ffffff",
    input_background_fill="#374151",
    input_border_color="#4b5563",
    shadow_drop="rgba(0,0,0,0.2) 0px 2px 4px",
    block_shadow="rgba(0,0,0,0.2) 0px 2px 5px",
)

# --- Gradio UI ---
with gr.Blocks(theme=dark_theme, title="CXR Quality Assessment") as demo:
    # Header
    gr.Markdown("""
# CXR Quality Assessment

Evalúa la calidad técnica de radiografías de tórax con AI
""")

    # Editable prompts. FIX: `gr.Textarea` is not a Gradio component; a
    # multiline `gr.Textbox` renders the same widget.
    with gr.Row():
        positive_prompts_input = gr.Textbox(
            label="Prompts Positivos (uno por línea)",
            value="\n".join(criteria_list_positive),
            lines=7,
        )
        negative_prompts_input = gr.Textbox(
            label="Prompts Negativos (uno por línea)",
            value="\n".join(criteria_list_negative),
            lines=7,
        )

    # Main content
    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Carga de Imagen")
            input_image = gr.Image(type="pil", label="Sube tu CXR", height=300)
            with gr.Row():
                analyze_btn = gr.Button("Analizar", variant="primary")
                reset_btn = gr.Button("Reset", variant="secondary")
            gr.Markdown("La carga de modelos tarda ~1 min; el análisis ~15–40 s.")
        with gr.Column(scale=2):
            with gr.Column(visible=True) as welcome_block:
                gr.Markdown("### ¡Bienvenido! Sube una radiografía y haz clic en «Analizar».")
            with gr.Column(visible=False) as results_block:
                gr.Markdown("### 2. Resultados")
                with gr.Row():
                    output_image = gr.Image(type="pil", label="Imagen Analizada",
                                            interactive=False)
                    with gr.Column():
                        gr.Markdown("#### Calidad Global")
                        output_label = gr.Label(value="N/A")
                gr.Markdown("#### Evaluación Detallada")
                output_html = gr.HTML()
                with gr.Accordion("Ver JSON (debug)", open=False):
                    output_json = gr.JSON()

    # Wiring
    analyze_btn.click(
        fn=assess_quality_and_update_ui,
        inputs=[input_image, positive_prompts_input, negative_prompts_input],
        outputs=[welcome_block, results_block, output_image,
                 output_label, output_html, output_json],
    )
    reset_btn.click(
        fn=reset_ui,
        inputs=None,
        outputs=[welcome_block, results_block, input_image, output_image,
                 output_label, output_html, output_json],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)