import streamlit as st
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoConfig
import numpy as np
import supervision as sv
import albumentations as A
import cv2
import yaml
# Set Streamlit page configuration for a wide layout
st.set_page_config(layout="wide")
# Load Model and Processor
@st.cache_resource
def load_model():
    REVISION = 'refs/pr/6'
    # MODEL_NAME = "RioJune/AD-KD-MICCAI25"
    MODEL_NAME = 'Anonymous-AC/AD-KD'
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The checkpoint is a Florence-2 fine-tune: reuse the base config, forcing the
    # vision backbone type expected by the remote code.
    config_model = AutoConfig.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
    config_model.vision_config.model_type = "davit"
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True, config=config_model).to(DEVICE)
    BASE_PROCESSOR = "microsoft/Florence-2-base-ft"
    processor = AutoProcessor.from_pretrained(BASE_PROCESSOR, trust_remote_code=True)
    # Force 512x512 preprocessing to match the training resolution
    processor.image_processor.size = 512
    processor.image_processor.crop_size = 512
    return model, processor, DEVICE
model, processor, DEVICE = load_model()
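# st.cache_resource keeps a single model/processor instance per server process,
# so widget-triggered reruns reuse the loaded weights instead of reloading them.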
# Load Definitions
@st.cache_resource
def load_definitions():
    vindr_path = 'configs/vindr_definition.yaml'
    padchest_path = 'configs/padchest_definition.yaml'
    prompt_path = 'examples/prompt.yaml'
    with open(vindr_path, 'r') as file:
        vindr_definitions = yaml.safe_load(file)
    with open(padchest_path, 'r') as file:
        padchest_definitions = yaml.safe_load(file)
    with open(prompt_path, 'r') as file:
        prompt_definitions = yaml.safe_load(file)
    return vindr_definitions, padchest_definitions, prompt_definitions
vindr_definitions, padchest_definitions, prompt_definitions = load_definitions()
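# Assumed YAML layout, inferred from the lookups below: each definition file maps
# a finding name to a short textual definition, e.g.
#   Cardiomegaly: "Enlargement of the cardiac silhouette ..."
# and examples/prompt.yaml maps an example image path to its list of findings.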
dataset_options = {"Vindr": vindr_definitions, "PadChest": padchest_definitions}
def load_example_images():
    return list(prompt_definitions.keys())
example_images = load_example_images()
def apply_transform(image, size_mode=512):
    # Letterbox: scale the longest side to size_mode, pad to square, then resize
    pad_resize_transform = A.Compose([
        A.LongestMaxSize(max_size=size_mode, interpolation=cv2.INTER_AREA),
        A.PadIfNeeded(min_height=size_mode, min_width=size_mode, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0)),
        A.Resize(height=512, width=512, interpolation=cv2.INTER_AREA),
    ])
    image_np = np.array(image)
    transformed = pad_resize_transform(image=image_np)
    return transformed["image"]
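# Quick sanity check (hypothetical input, not executed at import time):
#   dummy = Image.fromarray(np.zeros((600, 1000, 3), dtype=np.uint8))
#   apply_transform(dummy).shape  # -> (512, 512, 3): scaled to 512x307, padded square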
# Streamlit UI title and intro
st.markdown(
    "<h1 style='text-align: center;'>Enhancing Abnormality Grounding for Vision Language Models with Knowledge Descriptions</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "Welcome to a simple demo of our work! Choose an example or upload your own image to get started!",
    unsafe_allow_html=True,
)
# Display Example Images First
st.subheader("π Example Images")
selected_example = st.selectbox("Choose an example", example_images)
image = Image.open(selected_example).convert("RGB")
example_diseases = prompt_definitions.get(selected_example, [])
st.write("**Associated Diseases:**", ", ".join(example_diseases))
# Layout for Original Image and Instructions
col1, col2 = st.columns([1, 2])
# Left column for original image
with col1:
    st.image(image, caption=f"Original Example Image: {selected_example}", width=400)
# Right column for Instructions and Run Inference Button
with col2:
    st.subheader("Instructions to Get Started:")
    st.write("""
    - **Choose an Example**: Select an example image from the dataset to view its associated diseases.
    - **Run Inference**: Click the "Run Inference on Example" button to process the image and display the results.
    - **Upload Your Own Image**: Upload an image of your choice to analyze it for diseases.
    - **Select Dataset**: Choose between the available datasets (Vindr or PadChest) for disease information.
    - **Select Disease**: Pick the disease to analyze from the list of diseases in the selected dataset.
    """)
    st.subheader("Warning:")
    st.write("""
    - **Please avoid uploading non-frontal chest X-ray images.** Our model has been trained on **frontal chest X-ray images** only.
    - This demo is intended for **research purposes only** and must **not be used for medical diagnosis**.
    - The model's responses may contain **hallucinations or incorrect information**.
    - Always consult a **medical professional** for accurate diagnosis and advice.
    """, unsafe_allow_html=True)
st.markdown("", unsafe_allow_html=True)
# Run Inference Button
if st.button("Run Inference on Example", key="example"):
if image is None:
st.error("β Please select an example image first.")
else:
# Use the selected example's disease and definition for inference
disease_choice = example_diseases[0] if example_diseases else ""
definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))
# Generate the prompt for the model
det_obj = f"{disease_choice} means {definition}."
st.write(f"**Definition:** {definition}")
prompt = f"Locate the phrases in the caption: {det_obj}."
prompt = f"{prompt}"
# Prepare the image and input
np_image = np.array(image)
inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
with st.spinner("Processing... β³"):
outputs = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3,
output_scores=True, # Make sure we get the scores/logits
return_dict_in_generate=True # Ensures you get both sequences and scores in the output
)
# Ensure transition_scores is properly extracted
transition_scores = model.compute_transition_scores(
outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
)
# Get the generated token IDs (ignoring the input tokens part)
generated_ids = outputs.sequences
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# Get input length
input_length = inputs.input_ids.shape[1]
generated_tokens = outputs.sequences
# Calculate output length (number of generated tokens)
output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
# Get length penalty
length_penalty = model.generation_config.length_penalty
# Calculate total score for the generated sentence
reconstructed_scores = transition_scores.cpu().sum(axis=1) / (output_length**length_penalty)
# Convert log-probability to probability (0-1 range)
probabilities = np.exp(reconstructed_scores.cpu().numpy())
# Streamlit UI to display the result
st.markdown(f"**π― Probability of the Results:** {probabilities[0] * 100:.2f}%", unsafe_allow_html=True)
predictions = processor.post_process_generation(generated_text, task="", image_size=np_image.shape[:2])
detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=np_image.shape[:2])
# Annotate the image with bounding boxes and labels
bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
image_with_predictions = bounding_box_annotator.annotate(np_image.copy(), detection)
image_with_predictions = label_annotator.annotate(image_with_predictions, detection)
annotated_image = Image.fromarray(image_with_predictions.astype(np.uint8))
# Display the original and result images side by side
col1, col2 = st.columns([1, 1])
with col1:
st.image(image, caption=f"Original Image: {selected_example}", width=400)
with col2:
st.image(annotated_image, caption="Inference Results πΌοΈ", width=400)
# Display the generated text
st.write("**Generated Text:**", generated_text)
# Upload Image section
st.subheader("π€ Upload Your Own Image")
col1, col2 = st.columns([1, 1])
with col1:
dataset_choice = st.selectbox("Select Dataset π", options=list(dataset_options.keys()))
disease_options = list(dataset_options[dataset_choice].keys())
with col2:
disease_choice = st.selectbox("Select Disease π¦ ", options=disease_options)
uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
col1, col2 = st.columns([1, 2])
with col1:
    # Handle file upload
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
        image = apply_transform(image)  # letterbox the upload to 512x512
        st.image(image, caption="Uploaded Image", width=400)
    # Fall back to the example's first disease if nothing is selected
    disease_choice = disease_choice if disease_choice else example_diseases[0]
    # Definition priority: dataset definitions first, then manual input
    definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))
    if not definition:
        definition = st.text_input("Enter Definition Manually", value="")
with col2:
    # Instructions and warnings
    st.subheader("Instructions to Get Started:")
    st.write("""
    - **Upload Your Own Image**: Upload an image of your choice to analyze it for diseases.
    - **Select Dataset**: Choose between the available datasets (Vindr or PadChest) for disease information.
    - **Select Disease**: Pick the disease to analyze from the list of diseases in the selected dataset.
    - **Run Inference**: Click the "Run Inference" button to process the image and display the results.
    """)
    st.subheader("Warning:")
    st.write("""
    - **Please avoid uploading non-frontal chest X-ray images.** Our model has been trained on **frontal chest X-ray images** only.
    - This demo is intended for **research purposes only** and must **not be used for medical diagnosis**.
    - The model's responses may contain **hallucinations or incorrect information**.
    - Always consult a **medical professional** for accurate diagnosis and advice.
    """, unsafe_allow_html=True)
# Run inference after upload
if st.button("Run Inference πββοΈ"):
if image is None:
st.error("β Please upload an image or select an example.")
else:
det_obj = f"{disease_choice} means {definition}."
st.write(f"**Definition:** {definition}")
# Construct Prompt with Disease Definition
prompt = f"Locate the phrases in the caption: {det_obj}."
prompt = f"{prompt}"
np_image = np.array(image)
inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
with st.spinner("Processing... β³"):
# generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3)
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
outputs = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3,
output_scores=True, # Make sure we get the scores/logits
return_dict_in_generate=True # Ensures you get both sequences and scores in the output
)
transition_scores = model.compute_transition_scores(
outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
)
# Get the generated token IDs (ignoring the input tokens part)
generated_ids = outputs.sequences
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# Get input length
input_length = inputs.input_ids.shape[1]
# Extract generated tokens (ignoring the input tokens)
# generated_tokens = outputs.sequences[:, input_length:]
generated_tokens = outputs.sequences
# Calculate output length (number of generated tokens)
output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
# Get length penalty
length_penalty = model.generation_config.length_penalty
# Calculate total score for the generated sentence
reconstructed_scores = transition_scores.cpu().sum(axis=1) / (output_length**length_penalty)
# Convert log-probability to probability (0-1 range)
probabilities = np.exp(reconstructed_scores.cpu().numpy())
# Streamlit UI to display the result
# st.write(f"**Probability of the Results (0-1):** {probabilities[0]:.4f}")
st.markdown(f"**π― Probability of the Results:** {probabilities[0] * 100:.2f}%", unsafe_allow_html=True)
predictions = processor.post_process_generation(generated_text, task="", image_size=np_image.shape[:2])
detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=np_image.shape[:2])
bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
image_with_predictions = bounding_box_annotator.annotate(np_image.copy(), detection)
image_with_predictions = label_annotator.annotate(image_with_predictions, detection)
annotated_image = Image.fromarray(image_with_predictions.astype(np.uint8))
# Create two columns to display the original and the results side by side
col1, col2 = st.columns([1, 1])
# Left column for original image
with col1:
st.image(image, caption="Uploaded Image", width=400)
# Right column for result image
with col2:
st.image(annotated_image, caption="Inference Results πΌοΈ", width=400)
# Display the generated text
st.write("**Generated Text:**", generated_text)