import os
import torch
import time
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, LlamaTokenizer, TextIteratorStreamer
import threading
import queue
# Globals
current_model = None
current_tokenizer = None
# Curated models
model_choices = [
"meta-llama/Llama-3.2-3B-Instruct",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"google/gemma-7b-it",
"mistralai/Mistral-Nemo-Instruct-FP8-2407"
]
# Example patient database
patient_db = {
"001 - John Doe": {
"name": "John Doe",
"age": "45",
"id": "001",
"notes": "History of chest pain and hypertension. No prior surgeries."
},
"002 - Maria Sanchez": {
"name": "Maria Sanchez",
"age": "62",
"id": "002",
"notes": "Suspected pulmonary embolism. Shortness of breath, tachycardia."
},
"003 - Ahmed Al-Farsi": {
"name": "Ahmed Al-Farsi",
"age": "29",
"id": "003",
"notes": "Persistent migraines. MRI scheduled for brain imaging."
},
"004 - Lin Wei": {
"name": "Lin Wei",
"age": "51",
"id": "004",
"notes": "Annual screening. Family history of breast cancer."
}
}
# Store conversations per patient
patient_conversations = {}
class RichTextStreamer(TextIteratorStreamer):
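    """TextIteratorStreamer variant that queues each generated token individually,
    together with its id and a flag marking special tokens, skipping prompt tokens."""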
    def __init__(self, tokenizer, prompt_len=0, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.token_queue = queue.Queue()
        self.prompt_len = prompt_len
        self.count = 0
        self.end_of_generation = threading.Event()
def put(self, value):
if isinstance(value, torch.Tensor):
token_ids = value.view(-1).tolist()
elif isinstance(value, list):
token_ids = value
else:
token_ids = [value]
for token_id in token_ids:
self.count += 1
if self.count <= self.prompt_len:
continue # skip prompt tokens
token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
is_special = token_id in self.tokenizer.all_special_ids
self.token_queue.put({
"token_id": token_id,
"token": token_str,
"is_special": is_special
})
    def end(self):
        # Called by `generate` when generation finishes; lets __iter__ terminate.
        self.end_of_generation.set()
    def __iter__(self):
        while True:
            try:
                token_info = self.token_queue.get(timeout=self.timeout or 0.5)
                yield token_info
            except queue.Empty:
                if self.end_of_generation.is_set():
                    break
@spaces.GPU
def chat_with_model(messages, pid):
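    """Stream the assistant's reply for the selected patient.
    Generation runs in a background thread; the growing message list is yielded
    so Gradio renders tokens as they arrive, and the final conversation is saved
    to patient_conversations."""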
global current_model, current_tokenizer
if current_model is None or current_tokenizer is None:
yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
return
current_id = pid
if not current_id:
yield messages
return
max_new_tokens = 1024
output_text = ""
in_think = False
generated_tokens = 0
pad_id = current_tokenizer.pad_token_id or current_tokenizer.unk_token_id or 0
eos_id = current_tokenizer.eos_token_id
# --- Generate from full context
prompt = format_prompt(messages)
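    # Move the model to the GPU (allocated via the @spaces.GPU decorator) in half precision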
device = torch.device("cuda")
current_model.to(device).half()
inputs = current_tokenizer(prompt, return_tensors="pt").to(device)
prompt_len = inputs["input_ids"].shape[-1]
print(prompt)
streamer = RichTextStreamer(
tokenizer=current_tokenizer,
prompt_len=prompt_len,
skip_special_tokens=False
)
generation_kwargs = dict(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
streamer=streamer,
eos_token_id=eos_id,
pad_token_id=pad_id
)
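    # Run generation in a background thread so the streamer can be consumed here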
thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
thread.start()
# Now extend previous messages
updated_messages = messages.copy()
updated_messages.append({"role": "assistant", "content": ""})
print(updated_messages)
for token_info in streamer:
token_str = token_info["token"]
token_id = token_info["token_id"]
if token_id == eos_id:
break
if "" in token_str:
in_think = True
token_str = token_str.replace("", "")
output_text += "*"
if "" in token_str:
in_think = False
token_str = token_str.replace("", "")
output_text += token_str + "*"
else:
output_text += token_str
if "\nUser" in output_text:
output_text = output_text.split("\nUser")[0].rstrip()
updated_messages[-1]["content"] = output_text
break
if "\nSystem" in output_text:
output_text = output_text.split("\nSystem")[0].rstrip()
updated_messages[-1]["content"] = output_text
break
if "\nAssistant" in output_text:
output_text = output_text.split("\nAssistant")[0].rstrip()
updated_messages[-1]["content"] = output_text
break
generated_tokens += 1
if generated_tokens >= max_new_tokens:
break
updated_messages[-1]["content"] = output_text
patient_conversations[current_id] = updated_messages
yield updated_messages
    # Close an unterminated think block and flush the final state to the UI
    if in_think:
        output_text += "*"
    updated_messages[-1]["content"] = output_text
    patient_conversations[current_id] = updated_messages  # save the full conversation, including model output
    yield updated_messages
    torch.cuda.empty_cache()
def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
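    """Load the selected Hugging Face tokenizer and model (on CPU) into the module-level globals."""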
global current_model, current_tokenizer
token = os.getenv("HF_TOKEN")
progress(0, desc="Loading config...")
    config = AutoConfig.from_pretrained(model_name, token=token)
    progress(0.2, desc="Loading tokenizer...")
    # Default tokenizer for the selected checkpoint
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=token)
    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",  # loaded to CPU initially; moved to GPU in chat_with_model
        token=token
    )
progress(1, desc="Model ready.")
return f"{model_name} loaded and ready!"
# Format conversation as plain text
def format_prompt(messages):
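    """Flatten the chat history into a plain-text transcript such as
    "System: ...\nUser: ...\nAssistant: ...\n", always ending with "Assistant:"
    so the model continues from there."""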
prompt = ""
for msg in messages:
role = msg["role"]
if role == "user":
prompt += f"User: {msg['content'].strip()}\n"
elif role == "assistant":
prompt += f"Assistant: {msg['content'].strip()}\n"
elif role == "system":
prompt += f"System: {msg['content'].strip()}\n"
prompt += "Assistant:"
return prompt
def add_user_message(user_input, history, pid):
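    """Append the user's message to the history, persist it for the patient, and clear the input box."""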
if not pid:
return "", []
history.append({"role": "user", "content": user_input})
patient_conversations[pid] = history # single source of truth
return "", history # no extra welcome
def autofill_patient(patient_key):
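    """Return name, age, id and notes for the selected patient, creating an empty conversation entry if needed."""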
if patient_key in patient_db:
info = patient_db[patient_key]
# Init empty conversation if not existing
if info["id"] not in patient_conversations:
patient_conversations[info["id"]] = []
return info["name"], info["age"], info["id"], info["notes"]
return "", "", "", ""
# --- Helper functions (defined outside the Blocks context) ---
def resolve_model_choice(mode, dropdown_value, textbox_value):
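    """Return the custom model name when in custom mode, otherwise the dropdown selection."""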
return textbox_value.strip() if mode == "Enter custom model" else dropdown_value
def load_patient_conversation(patient_key):
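    """Return the saved conversation for the selected patient; if none exists yet,
    build a fresh one from the system prompt (with patient info) plus a welcome message."""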
if patient_key in patient_db:
patient_id_val = patient_db[patient_key]["id"]
history = patient_conversations.get(patient_id_val, [])
if not history:
system_message = [
                {
                    "role": "system",
                    "content": (
                        "You are a radiologist's companion, here to answer questions about the patient and assist in the diagnosis if asked to do so. "
                        "You are able to call specialized tools. "
                        "At the moment, you have one tool available: an organ segmentation algorithm for abdominal CTs.\n\n"
                        "If the user requests an organ segmentation, output a JSON object in this structure:\n"
                        "{\n"
                        " \"function\": \"segment_organ\",\n"
                        " \"arguments\": {\n"
                        " \"scan_path\": \"<path to the CT scan>\",\n"
                        " \"organ\": \"<organ to segment>\"\n"
                        " }\n"
                        "}\n\n"
                        "Once you call the function, the app will execute it and return the result."
                    )
                },
{
"role": "system",
"content": f"Patient Information:\nName: {patient_name.value}\nAge: {patient_age.value}\nID: {patient_id.value}\nNotes: {patient_notes.value}"
}
]
welcome_message = [
{
"role": "assistant",
"content": (
"Welcome to the Radiologist's Companion!\n\n"
"You can ask me about the patient's medical history or available imaging data.\n"
"- I can summarize key details from the EHR.\n"
"- I can tell you which medical images are available.\n"
"- If you'd like an organ segmentation (e.g. spleen, liver, kidney_left, colon, femur_right) on an abdominal CT scan, just ask!\n\n"
"Example Requests:\n"
"- \"What do we know about this patient?\"\n"
"- \"Which images are available for this patient?\"\n"
"- \"Can you segment the spleen from the CT scan?\"\n"
)
}
]
history = system_message + welcome_message
return history
return []
def get_patient_conversation():
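    """Return the conversation stored for the patient currently shown in the ID field."""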
current_id = patient_id.value
if not current_id:
return []
return patient_conversations.get(current_id, [])
# --- Gradio App ---
css = """
.equal-height > .gr-column {
height: 100% !important;
display: flex;
flex-direction: column;
}
"""
with gr.Blocks(css=css) as demo:
gr.Markdown("Radiologist's Companion
")
default_model = gr.State(model_choices[0])
with gr.Row(elem_classes="equal-height"):
# Patient Information
with gr.Column(scale=1):
gr.Markdown("### Patient Information")
patient_selector = gr.Dropdown(
choices=list(patient_db.keys()),
value=list(patient_db.keys())[0],
label="Select Patient",
allow_custom_value=False
)
patient_name = gr.Textbox(label="Name", placeholder="e.g., John Doe", interactive=False)
patient_age = gr.Textbox(label="Age", placeholder="e.g., 45", interactive=False)
patient_id = gr.Textbox(label="Patient ID", placeholder="e.g., 123456", interactive=False)
patient_notes = gr.Textbox(label="Clinical Notes", lines=10, interactive=False)
# Chat
with gr.Column(scale=3):
gr.Markdown("### Chat")
chatbot = gr.Chatbot(label="Chat", type="messages", height=450)
msg = gr.Textbox(label="Your message", placeholder="Enter your chat message...", show_label=False)
with gr.Row():
submit_btn = gr.Button("Submit", variant="primary")
clear_btn = gr.Button("Clear", variant="secondary")
# Model Settings
with gr.Column(scale=1):
gr.Markdown("### Model Settings")
mode = gr.Radio(["Choose from list", "Enter custom model"], value="Choose from list", label="Model Input Mode")
model_selector = gr.Dropdown(choices=model_choices, label="Select Predefined Model")
model_textbox = gr.Textbox(label="Or Enter HF Model Name")
model_status = gr.Textbox(label="Model Status", interactive=False)
# --- Event Bindings ---
# Load patient info + conversation + model on startup
demo.load(
lambda: autofill_patient(list(patient_db.keys())[0]),
inputs=None,
outputs=[patient_name, patient_age, patient_id, patient_notes]
).then(
lambda: load_patient_conversation(list(patient_db.keys())[0]),
inputs=None,
outputs=chatbot
).then(
load_model_on_selection,
inputs=default_model,
outputs=model_status
)
# Patient selection changes
patient_selector.change(
autofill_patient,
inputs=[patient_selector],
outputs=[patient_name, patient_age, patient_id, patient_notes]
).then(
load_patient_conversation,
inputs=[patient_selector],
outputs=[chatbot]
)
# Model selection logic
mode.select(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
load_model_on_selection, inputs=default_model, outputs=model_status
)
model_selector.change(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
load_model_on_selection, inputs=default_model, outputs=model_status
)
model_textbox.submit(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
load_model_on_selection, inputs=default_model, outputs=model_status
)
msg.submit(
add_user_message,
[msg, chatbot, patient_id],
[msg, chatbot],
queue=False,
).then(
chat_with_model,
[chatbot, patient_id],
chatbot,
)
submit_btn.click(
add_user_message,
[msg, chatbot, patient_id],
[msg, chatbot],
queue=False,
).then(
chat_with_model,
[chatbot, patient_id],
chatbot,
)
# Clear chat
clear_btn.click(lambda: [], None, chatbot, queue=False)
demo.launch()