File size: 4,947 Bytes
1155704 8c814cd 9aeb1dd 3f69fbe 9c9d2f8 1d2016b 16f16a5 1d2016b 537f975 1d2016b 6604d0d 1729ddc 4b98818 3f76413 1d2016b 537f975 0cec600 1155704 0cec600 1d2016b 47f0902 0cec600 47f0902 0cec600 1d2016b 47f0902 1d2016b 91d1d93 47f0902 0cec600 1d2016b 3f69fbe 91d1d93 537f975 1d2016b 6309d92 537f975 6309d92 1d2016b 537f975 84b4115 1d2016b 84b4115 ecaf6bd 1d2016b ecaf6bd 1d2016b ecaf6bd 1d2016b 6309d92 1d2016b 84b4115 537f975 1d2016b 84b4115 6309d92 1d2016b 6309d92 3f69fbe 1d2016b 3f69fbe 9c9d2f8 3f69fbe 1d2016b 3f69fbe 9c9d2f8 91d1d93 1d2016b 9c9d2f8 3f69fbe d756c15 9c9d2f8 3f69fbe ecaf6bd 9c9d2f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import html
import importlib
import inspect
import json
import os
import sys
from multiprocessing import freeze_support

import gradio as gr
# Environment tweaks go first: MKL reads MKL_THREADING_LAYER when it is
# loaded, so it must be set before any import that may load it.
# NOTE(review): assumes txagent pulls in an MKL-linked stack (e.g. torch) on
# import — confirm; setting these early is harmless either way.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Absolute directory of this script; used to locate ``src`` and ``data``.
current_dir = os.path.abspath(os.path.dirname(__file__))

# Put the local ``src`` tree at the front of sys.path so it takes precedence
# over any installed copy of the txagent package.
sys.path.insert(0, os.path.join(current_dir, "src"))

# Import, then reload, TxAgent from src/txagent/txagent.py.
# NOTE(review): in a fresh interpreter the reload only re-executes the module;
# it matters only if txagent.txagent had already been imported earlier.
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent

# Debug info: show which file TxAgent came from and whether it exposes the
# chat entry point the UI relies on.
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
# Model configuration. These look like Hugging Face Hub repo ids
# (organisation/model) — NOTE(review): confirm how TxAgent resolves them.
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"

# Additional tool definition files handed to TxAgent as ``tool_files_dict``;
# the single entry points at data/new_tool.json under this script's directory.
new_tool_files = dict(
    new_tool=os.path.join(current_dir, "data", "new_tool.json"),
)

# Example prompts for gr.Examples: one single-element row per question,
# matching its single input component (the message textbox).
question_examples = [
    [question]
    for question in (
        "Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?",
        "What treatment options exist for HER2+ breast cancer resistant to trastuzumab?",
    )
]
# Helper: format assistant responses in collapsible panels
def format_collapsible(content, summary="Answer"):
    """Wrap *content* in an HTML ``<details>`` panel for the chatbot.

    Dicts and lists are pretty-printed as JSON, falling back to ``str`` when
    they are not JSON-serialisable. Anything else is rendered with ``str``.
    The resulting text is HTML-escaped so characters such as ``<`` and ``&``
    in model or tool output are displayed literally instead of being parsed
    as markup.

    Args:
        content: The assistant message payload to display.
        summary: Label shown in the panel's ``<summary>`` header.

    Returns:
        An HTML string containing the ``<details>`` panel.
    """
    if isinstance(content, (dict, list)):
        try:
            # ensure_ascii=False keeps non-ASCII text readable instead of \uXXXX.
            formatted = json.dumps(content, indent=2, ensure_ascii=False)
        except (TypeError, ValueError):
            # Non-serialisable values (TypeError) or circular references (ValueError).
            formatted = str(content)
    else:
        formatted = str(content)
    # quote=False: only <, > and & matter inside element content.
    safe_summary = html.escape(summary, quote=False)
    safe_body = html.escape(formatted, quote=False)
    return (
        "<details style='border: 1px solid #ccc; padding: 8px; margin-top: 8px;'>"
        f"<summary style='font-weight: bold;'>{safe_summary}</summary>"
        f"<pre style='white-space: pre-wrap;'>{safe_body}</pre>"
        "</details>"
    )
# === UI setup
def _to_display_message(message):
    """Convert one agent message into a ``{"role", "content"}`` dict for the Chatbot.

    *message* may be a plain dict or an object exposing ``role``/``content``
    attributes. Missing attributes default to ``"assistant"`` / ``""``.
    Assistant content is wrapped with ``format_collapsible`` unless it is
    already a ``<details>`` panel.
    """
    if isinstance(message, dict):
        role, content = message["role"], message["content"]
    else:
        role = getattr(message, "role", "assistant")
        content = getattr(message, "content", "")
    # The Chatbot value (already wrapped) is handed back to the agent as
    # ``history``. If the agent re-yields those earlier turns, wrapping them
    # again would nest one panel inside another on every new question.
    already_wrapped = isinstance(content, str) and content.startswith("<details")
    if role == "assistant" and not already_wrapped:
        content = format_collapsible(content)
    return {"role": role, "content": content}


def create_ui(agent):
    """Build the Gradio chat UI wired to *agent*.

    Args:
        agent: Object whose ``run_gradio_chat`` generator yields lists of chat
            messages (dicts or objects with ``role``/``content``).

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.")

        # Generation controls
        temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
        # Per-session conversation object handed to the agent on every turn.
        conversation_state = gr.State([])

        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")

        # Main handler
        def handle_chat(message, history, temp, new_tokens, total_tokens,
                        use_multi_agent, conversation, rounds):
            """Stream agent updates to the Chatbot, one formatted list per update."""
            if not message or not message.strip():
                # Ignore blank submissions instead of starting a full agent run.
                yield history
                return
            generator = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temp,
                # Gradio types Slider values as float; these budgets are counts.
                max_new_tokens=int(new_tokens),
                max_token=int(total_tokens),
                call_agent=use_multi_agent,
                conversation=conversation,
                max_round=int(rounds),
            )
            for update in generator:
                yield [_to_display_message(m) for m in update]

        # Button and Enter triggers share the same inputs (order matches handle_chat).
        inputs = [message_input, chatbot, temperature, max_new_tokens, max_tokens,
                  multi_agent, conversation_state, max_round]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("**DISCLAIMER**: This demo is for research purposes only and does not provide medical advice.")
    return demo
# === Entry point
if __name__ == "__main__":
    # Required for multiprocessing in frozen Windows executables; a no-op
    # otherwise.
    freeze_support()
    try:
        # Cheap capability check first, so an incompatible TxAgent fails
        # before the (expensive) model weights are loaded.
        if not hasattr(TxAgent, "run_gradio_chat"):
            raise AttributeError("TxAgent missing run_gradio_chat")
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[]  # Avoid loading unimplemented tools
        )
        agent.init_model()
        demo = create_ui(agent)
        # Listen on all interfaces so the app is reachable from outside a container.
        demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
    except Exception as e:
        # Report on stderr, then re-raise so the traceback and exit status are kept.
        print(f"❌ App failed to start: {e}", file=sys.stderr)
        raise
|