|
import gradio as gr |
|
import os |
|
import io |
|
import png |
|
import tensorflow as tf |
|
import tensorflow_text as tf_text |
|
import tensorflow_hub as tf_hub |
|
import numpy as np |
|
from PIL import Image |
|
from huggingface_hub import snapshot_download |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import traceback |
|
import time |
|
|
|
|
|
# Hugging Face repository the model snapshot is downloaded from.
MODEL_REPO_ID: str = "google/cxr-foundation"
# Local directory the snapshot is downloaded into.
MODEL_DOWNLOAD_DIR: str = './hf_cxr_foundation_space'

# "Comp" verdict: a criterion passes when sim(+) - sim(-) is above this value.
SIMILARITY_DIFFERENCE_THRESHOLD: float = 0.1
# "Simp" verdict: a criterion passes when sim(+) alone is above this value.
POSITIVE_SIMILARITY_THRESHOLD: float = 0.1

# Log the active thresholds once, at import time.
_threshold_banner = f"Usando umbrales: Comp Δ={SIMILARITY_DIFFERENCE_THRESHOLD}, Simp τ={POSITIVE_SIMILARITY_THRESHOLD}"
print(_threshold_banner)
|
|
|
|
|
# Default prompt pairs shown in the UI.
#
# The two lists below are consumed pairwise BY INDEX (positive[i] is scored
# against negative[i]), so they are derived from explicit (positive, negative)
# tuples to keep every pair describing the same quality criterion.
#
# NOTE(review): the previous defaults were misaligned from the third entry on
# (e.g. "complete lung fields" was compared against "underexposed image").
# The positive prompts are unchanged; the negative prompts were re-aligned,
# which required merging the two exposure prompts and splitting the combined
# blur/artifact prompt. The new negative wordings have not been validated
# against the model — confirm that the scores remain sensible.
_DEFAULT_CRITERIA_PAIRS = [
    ("optimal centering mediastinum", "poor centering"),
    ("deep inspiration", "shallow inspiration"),
    ("adequate penetration", "overexposed or underexposed image"),
    ("complete lung fields", "cropped lung fields"),
    ("scapulae retracted outside lungs", "scapular overlay on lungs"),
    ("sharp contrast", "blurred image"),
    ("artifact-free image", "image with artifacts"),
]

criteria_list_positive = [positive for positive, _ in _DEFAULT_CRITERIA_PAIRS]
criteria_list_negative = [negative for _, negative in _DEFAULT_CRITERIA_PAIRS]
|
|
|
|
|
def bert_tokenize(text, preprocessor):
    """Tokenize a single text prompt and return (ids, paddings) arrays.

    Args:
        text: Prompt to tokenize; it is converted with ``str()`` and lowercased.
        preprocessor: Callable applied to a one-element string tensor. Its
            result must expose ``'input_word_ids'`` and ``'input_mask'``.

    Returns:
        A tuple ``(ids, paddings)``: ``ids`` as int32 and ``paddings`` as
        float32 (1.0 minus the input mask). 2-D results get an extra axis
        inserted at position 1.

    Raises:
        ValueError: If ``preprocessor`` is None.
    """
    if preprocessor is None:
        raise ValueError("BERT preprocessor no está cargado.")

    lowered = str(text).lower()
    encoded = preprocessor(tf.constant([lowered]))
    ids = encoded['input_word_ids'].numpy().astype(np.int32)
    paddings = 1.0 - encoded['input_mask'].numpy().astype(np.float32)

    # Every position holding token id 102 is blanked out and flagged as padding.
    # NOTE(review): 102 is presumably the [SEP] id of the BERT uncased vocab — confirm.
    is_end_token = ids == 102
    ids[is_end_token] = 0
    paddings[is_end_token] = 1.0

    # (batch, seq) -> (batch, 1, seq)
    if ids.ndim == 2:
        ids = ids[:, np.newaxis, :]
    if paddings.ndim == 2:
        paddings = paddings[:, np.newaxis, :]
    return ids, paddings
|
|
|
def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
    """Encode a greyscale image array as PNG and wrap it in a tf.train.Example.

    Args:
        image_array: 2-D array, or 3-D array whose last axis has length 1.

    Returns:
        An Example with the PNG bytes under ``'image/encoded'`` and ``b'png'``
        under ``'image/format'``.

    Raises:
        ValueError: If the array is neither 2-D nor 3-D with a trailing axis
            of length 1.
    """
    has_singleton_channel = image_array.ndim == 3 and image_array.shape[2] == 1
    if has_singleton_channel:
        image_array = np.squeeze(image_array, axis=2)
    elif image_array.ndim != 2:
        raise ValueError(f'Array debe ser 2-D. Dimensiones: {image_array.ndim}')

    source_is_uint8 = image_array.dtype == np.uint8
    scaled = image_array.astype(np.float32)
    lowest, highest = scaled.min(), scaled.max()

    if highest <= lowest:
        # Constant image: nothing to stretch. Keep the value as 8-bit when it
        # fits in a byte, otherwise emit an all-zero 16-bit image.
        if source_is_uint8 or (lowest >= 0 and highest <= 255):
            bitdepth = 8
            pixel_array = scaled.astype(np.uint8)
        else:
            bitdepth = 16
            pixel_array = np.zeros_like(scaled, dtype=np.uint16)
    else:
        # Stretch the value range to the full range of the output bit depth:
        # 8-bit for uint8 input, 16-bit for everything else.
        scaled -= lowest
        value_range = highest - lowest
        if source_is_uint8:
            bitdepth, full_scale, out_dtype = 8, 255, np.uint8
        else:
            bitdepth, full_scale, out_dtype = 16, 65535, np.uint16
        scaled *= full_scale / value_range
        pixel_array = scaled.astype(out_dtype)

    height, width = pixel_array.shape
    buffer = io.BytesIO()
    writer = png.Writer(width=width, height=height, greyscale=True, bitdepth=bitdepth)
    writer.write(buffer, pixel_array.tolist())

    example = tf.train.Example()
    feature_map = example.features.feature
    feature_map['image/encoded'].bytes_list.value.append(buffer.getvalue())
    feature_map['image/format'].bytes_list.value.append(b'png')
    return example
|
|
|
def generate_image_embedding(img_np, elixrc_infer, qformer_infer):
    """Compute the contrastive image embedding for one greyscale image.

    The image is serialized with ``png_to_tfexample``, passed through
    ``elixrc_infer`` and the resulting ``'feature_maps_0'`` is fed to
    ``qformer_infer`` together with all-zero ids and all-one paddings.

    Args:
        img_np: Image array accepted by ``png_to_tfexample``.
        elixrc_infer: Callable taking ``input_example=...``.
        qformer_infer: Callable taking ``image_feature``, ``ids`` and
            ``paddings`` keyword arguments.

    Returns:
        The ``'all_contrastive_img_emb'`` output as a NumPy array. Arrays with
        more than two axes are averaged over every axis except the first and
        the last; a 1-D result gets a leading axis.

    Raises:
        ValueError: If either model callable is None.
        Exception: Any error raised during inference is logged and re-raised.
    """
    if elixrc_infer is None or qformer_infer is None:
        raise ValueError("Modelos ELIXR-C o QFormer no cargados.")

    try:
        example_bytes = png_to_tfexample(img_np).SerializeToString()
        elixrc_outputs = elixrc_infer(input_example=tf.constant([example_bytes]))
        feature_maps = elixrc_outputs['feature_maps_0'].numpy()

        # Only the image branch is of interest here, so the text inputs are
        # blank: zero ids with every position marked as padding.
        blank_ids = np.zeros((1, 1, 128), dtype=np.int32)
        all_padding = np.ones((1, 1, 128), dtype=np.float32)
        qformer_outputs = qformer_infer(
            image_feature=feature_maps.tolist(),
            ids=blank_ids.tolist(),
            paddings=all_padding.tolist(),
        )

        embedding = qformer_outputs['all_contrastive_img_emb'].numpy()
        if embedding.ndim > 2:
            middle_axes = tuple(range(1, embedding.ndim - 1))
            embedding = embedding.mean(axis=middle_axes)
        if embedding.ndim == 1:
            embedding = embedding[np.newaxis, :]
        return embedding
    except Exception as e:
        print(f"Error embedding imagen: {e}")
        traceback.print_exc()
        raise
|
|
|
def _embed_text_prompt(prompt, bert_preprocessor, qformer_infer, blank_image_feature):
    """Return the ``'contrastive_txt_emb'`` output for one prompt as a (1, D) array."""
    ids, paddings = bert_tokenize(prompt, bert_preprocessor)
    outputs = qformer_infer(
        image_feature=blank_image_feature,
        ids=ids.tolist(),
        paddings=paddings.tolist(),
    )
    return outputs['contrastive_txt_emb'].numpy().reshape(1, -1)


def calculate_similarities_and_classify(image_embedding, bert_preprocessor, qformer_infer,
                                        criteria_positive, criteria_negative):
    """Score an image embedding against paired positive/negative text prompts.

    Prompts are paired by position (``zip``), so extra entries in the longer
    list are ignored. For every pair the cosine similarity between the image
    embedding and each prompt embedding is computed, then two verdicts are
    derived:

    * ``comp``: "PASS" when ``sim_pos - sim_neg`` exceeds
      ``SIMILARITY_DIFFERENCE_THRESHOLD``.
    * ``simp``: "PASS" when ``sim_pos`` exceeds ``POSITIVE_SIMILARITY_THRESHOLD``.

    Args:
        image_embedding: 2-D array accepted by ``cosine_similarity``.
        bert_preprocessor: Preprocessor forwarded to ``bert_tokenize``.
        qformer_infer: Callable taking ``image_feature``, ``ids`` and
            ``paddings`` keyword arguments.
        criteria_positive: Iterable of positive prompts.
        criteria_negative: Iterable of negative prompts.

    Returns:
        Dict keyed by positive prompt. Each value holds ``positive_prompt``,
        ``negative_prompt``, ``sim_pos``, ``sim_neg``, ``difference``,
        ``comp`` and ``simp``. When a pair fails, both verdicts are "ERROR"
        and any score that was not computed stays None. Because the key is the
        positive prompt, a repeated positive prompt overwrites the earlier entry.
    """
    # The text branch ignores the image input, so an all-zero feature map is
    # passed. It is identical for every prompt: build the nested list once
    # instead of twice per criterion.
    blank_image_feature = np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist()

    results = {}
    for pos, neg in zip(criteria_positive, criteria_negative):
        sim_pos = sim_neg = diff = None
        comp = simp = "ERROR"
        try:
            txt_p = _embed_text_prompt(pos, bert_preprocessor, qformer_infer, blank_image_feature)
            txt_n = _embed_text_prompt(neg, bert_preprocessor, qformer_infer, blank_image_feature)

            sim_pos = float(cosine_similarity(image_embedding, txt_p)[0][0])
            sim_neg = float(cosine_similarity(image_embedding, txt_n)[0][0])
            diff = sim_pos - sim_neg
            comp = "PASS" if diff > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
            simp = "PASS" if sim_pos > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
        except Exception as e:
            # Deliberate best-effort: one failing criterion is reported as
            # "ERROR" in the results and must not abort the remaining ones.
            print(f"Error en criterio '{pos}': {e}")
        results[pos] = {
            'positive_prompt': pos,
            'negative_prompt': neg,
            'sim_pos': sim_pos,
            'sim_neg': sim_neg,
            'difference': diff,
            'comp': comp,
            'simp': simp,
        }
    return results
|
|
|
|
|
# Load every model once at import time. Failures are logged instead of raised
# so the UI can still start; `models_loaded` tells the request handler whether
# inference is available.
print("--- Iniciando carga de modelos ---")
start_time = time.time()
models_loaded = False
# `elixr` is the name the request handler passes to generate_image_embedding,
# so it must exist even when loading fails. `elixrc` was the placeholder name
# used previously and is kept only so existing references do not break.
bert_preproc = elixr = elixrc = qformer = None
try:
    # Optional access token, read from the environment.
    hf_token = os.environ.get("HF_TOKEN")
    os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
    # Only the two model directories that are loaded below are fetched.
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=MODEL_DOWNLOAD_DIR,
        allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
        local_dir_use_symlinks=False,
        token=hf_token,
    )
    bert_preproc = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
    elixr = tf.saved_model.load(
        os.path.join(MODEL_DOWNLOAD_DIR, 'elixr-c-v2-pooled')
    ).signatures['serving_default']
    qformer = tf.saved_model.load(
        os.path.join(MODEL_DOWNLOAD_DIR, 'pax-elixr-b-text')
    ).signatures['serving_default']
    # Set last: only True when every step above succeeded.
    models_loaded = True
    print(f"Modelos cargados en {time.time()-start_time:.2f}s")
except Exception as e:
    print("ERROR cargando modelos:", e)
    traceback.print_exc()
|
|
|
|
|
def assess_quality_and_update_ui(image_pil, pos_input, neg_input):
    """Assess one uploaded image against the prompt pairs and build the UI updates.

    Args:
        image_pil: Uploaded PIL image, or None when nothing was uploaded.
        pos_input: Positive prompts, one per line (blank lines are ignored).
        neg_input: Negative prompts, one per line; paired with the positive
            prompts by line order.

    Returns:
        A 6-tuple: (welcome block update, results block update, image,
        overall quality text, HTML results table, per-criterion details dict).
        With no image the welcome block stays visible and the remaining
        values are placeholders.

    Raises:
        gr.Error: If the models failed to load, if the two prompt lists have
            different lengths, or if no prompt was provided.
    """
    # Function-scope import: the file has no module-level `html` import.
    from html import escape

    if not models_loaded:
        raise gr.Error("No se pudieron cargar los modelos.")
    if image_pil is None:
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            None,
            "N/A",
            "",
            {},
        )

    pos_list = [line.strip() for line in (pos_input or "").splitlines() if line.strip()]
    neg_list = [line.strip() for line in (neg_input or "").splitlines() if line.strip()]
    if len(pos_list) != len(neg_list):
        raise gr.Error("El número de prompts positivos y negativos debe coincidir.")
    if not pos_list:
        # Without prompts there is nothing to score; fail fast instead of
        # running the models and reporting a misleading "Poor (0/0 passed)".
        raise gr.Error("Debe indicar al menos un par de prompts.")

    img_np = np.array(image_pil.convert('L'))
    emb = generate_image_embedding(img_np, elixr, qformer)
    details = calculate_similarities_and_classify(emb, bert_preproc, qformer, pos_list, neg_list)

    pass_style = "color:#22c55e;font-weight:bold;"
    fail_style = "color:#ef4444;font-weight:bold;"

    def verdict_style(verdict):
        # Anything that is not "PASS" (i.e. "FAIL" or "ERROR") is shown in red.
        return pass_style if verdict == "PASS" else fail_style

    def format_score(value):
        # Scores are None for criteria whose evaluation failed.
        return "N/A" if value is None else f"{value:.4f}"

    # Prompt text comes from the user, so it is HTML-escaped before rendering.
    rows = "".join(
        "<tr>"
        f"<td>{escape(crit)}</td>"
        f"<td>{format_score(d['sim_pos'])}</td>"
        f"<td>{format_score(d['sim_neg'])}</td>"
        f"<td>{format_score(d['difference'])}</td>"
        f"<td style='{verdict_style(d['comp'])}'>{d['comp']}</td>"
        f"<td style='{verdict_style(d['simp'])}'>{d['simp']}</td>"
        "</tr>"
        for crit, d in details.items()
    )
    table_html = f"""
    <table style="width:100%;border-collapse:collapse;">
      <thead style="background:#f2f2f2;">
        <tr>
          <th>Criterion</th><th>Sim (+)</th><th>Sim (-)</th><th>Diff</th>
          <th>Assessment (Comp)</th><th>Assessment (Simp)</th>
        </tr>
      </thead>
      <tbody>{rows}</tbody>
    </table>
    """

    # The overall rating is based on the comparative ("comp") verdicts only.
    total = len(details)
    passed = sum(1 for d in details.values() if d['comp'] == "PASS")
    pass_rate = passed / total if total > 0 else 0
    if pass_rate >= 0.85:
        overall = "Excellent"
    elif pass_rate >= 0.70:
        overall = "Good"
    elif pass_rate >= 0.50:
        overall = "Fair"
    else:
        overall = "Poor"
    quality_label = f"{overall} ({passed}/{total} passed)"

    return (
        gr.update(visible=False),
        gr.update(visible=True),
        image_pil,
        quality_label,
        table_html,
        details,
    )
|
|
|
def reset_ui():
    """Return the values that put the interface back into its initial state.

    Returns:
        A 7-tuple: an update making the first block visible, an update hiding
        the second block, two None values, the "N/A" placeholder, an empty
        string and an empty dict.
    """
    show_welcome = gr.update(visible=True)
    hide_results = gr.update(visible=False)
    cleared_first_image = None
    cleared_second_image = None
    return (
        show_welcome,
        hide_results,
        cleared_first_image,
        cleared_second_image,
        "N/A",
        "",
        {},
    )
|
|
|
|
|
# Font stacks: a Google font first, then generic fallbacks.
_SANS_FONTS = [gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"]
_MONO_FONTS = [gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "Consolas", "monospace"]

# Theme variables overridden on top of the default theme (dark palette).
_DARK_OVERRIDES = {
    "body_background_fill": "#111827",
    "background_fill_primary": "#1f2937",
    "background_fill_secondary": "#374151",
    "block_background_fill": "#1f2937",
    "body_text_color": "#d1d5db",
    "block_label_text_color": "#d1d5db",
    "block_title_text_color": "#ffffff",
    "border_color_accent": "#374151",
    "border_color_primary": "#4b5563",
    "button_primary_background_fill": "*primary_600",
    "button_primary_text_color": "#ffffff",
    "button_secondary_background_fill": "*neutral_700",
    "button_secondary_text_color": "#ffffff",
    "input_background_fill": "#374151",
    "input_border_color": "#4b5563",
    "shadow_drop": "rgba(0,0,0,0.2) 0px 2px 4px",
    "block_shadow": "rgba(0,0,0,0.2) 0px 2px 5px",
}

# Theme object handed to gr.Blocks below.
dark_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    font=_SANS_FONTS,
    font_mono=_MONO_FONTS,
).set(**_DARK_OVERRIDES)
|
|
|
|
|
# UI definition. Components are created in display order; the nesting of the
# `with` blocks determines the layout.
with gr.Blocks(theme=dark_theme, title="CXR Quality Assessment") as demo:

    gr.Markdown("""
    # <span style="color: #e5e7eb;">CXR Quality Assessment</span>
    <p style="color: #9ca3af;">Evalúa la calidad técnica de radiografías de tórax con AI</p>
    """)

    # Editable prompt lists, pre-filled with the default criteria.
    # gr.Textbox with lines=7 is used instead of `gr.Textarea`, which is not
    # an exported Gradio component name.
    with gr.Row():
        positive_prompts_input = gr.Textbox(
            label="Prompts Positivos (uno por línea)",
            value="\n".join(criteria_list_positive),
            lines=7,
        )
        negative_prompts_input = gr.Textbox(
            label="Prompts Negativos (uno por línea)",
            value="\n".join(criteria_list_negative),
            lines=7,
        )

    with gr.Row(equal_height=False):
        # Left column: image upload and action buttons.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Carga de Imagen")
            input_image = gr.Image(type="pil", label="Sube tu CXR", height=300)
            with gr.Row():
                analyze_btn = gr.Button("Analizar", variant="primary")
                reset_btn = gr.Button("Reset", variant="secondary")
            gr.Markdown("<p style='color:#9ca3af; font-size:0.9em;'>La carga de modelos tarda ~1 min; el análisis ~15–40 s.</p>")

        # Right column: welcome message, replaced by the results after analysis.
        with gr.Column(scale=2):
            with gr.Column(visible=True) as welcome_block:
                gr.Markdown("### ¡Bienvenido! Sube una radiografía y haz clic en «Analizar».")
            with gr.Column(visible=False) as results_block:
                gr.Markdown("### 2. Resultados")
                with gr.Row():
                    output_image = gr.Image(type="pil", label="Imagen Analizada", interactive=False)
                    with gr.Column():
                        gr.Markdown("#### Calidad Global")
                        output_label = gr.Label(value="N/A")
                gr.Markdown("#### Evaluación Detallada")
                output_html = gr.HTML()
                with gr.Accordion("Ver JSON (debug)", open=False):
                    output_json = gr.JSON()

    # The order of `outputs` must match the tuple returned by each handler.
    analyze_btn.click(
        fn=assess_quality_and_update_ui,
        inputs=[input_image, positive_prompts_input, negative_prompts_input],
        outputs=[welcome_block, results_block, output_image, output_label, output_html, output_json],
    )
    reset_btn.click(
        fn=reset_ui,
        inputs=None,
        outputs=[welcome_block, results_block, input_image, output_image, output_label, output_html, output_json],
    )


if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|