import os
import sys
import gradio as gr
from multiprocessing import freeze_support
import importlib
import inspect
import json
import html

# Ensure the bundled src/ directory is importable so `txagent` resolves locally
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))

# Import, then force-reload, so edits to src/txagent/txagent.py take effect
# even when a stale module is already cached in sys.modules
import txagent.txagent
importlib.reload(txagent.txagent)
from txagent.txagent import TxAgent

# Debug info
print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))

# Environment tweaks: use the GNU OpenMP threading layer to avoid MKL/libgomp
# clashes, and silence the HF tokenizers fork-parallelism warning
current_dir = os.path.abspath(os.path.dirname(__file__))
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model config
model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
new_tool_files = {
    "new_tool": os.path.join(current_dir, "data", "new_tool.json")
}
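# The mapping above is handed to TxAgent as tool_files_dict below; the agent
# reads extra tool definitions from each listed JSON file.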

# Sample questions
question_examples = [
    ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
    ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
]

# Helper: format assistant responses in collapsible panels
def format_collapsible(content):
    if isinstance(content, (dict, list)):
        try:
            formatted = json.dumps(content, indent=2)
        except Exception:
            formatted = str(content)
    else:
        formatted = str(content)

    # Escape before embedding so JSON/text containing < or > cannot break
    # (or inject into) the rendered HTML panel
    return (
        "<details style='border: 1px solid #ccc; padding: 8px; margin-top: 8px;'>"
        "<summary style='font-weight: bold;'>Answer</summary>"
        f"<pre style='white-space: pre-wrap;'>{html.escape(formatted)}</pre>"
        "</details>"
    )
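
# Example (hypothetical input):
#   format_collapsible({"answer": "Avoid strong CYP3A4 inhibitors"})
# returns a <details> panel labeled "Answer" containing the pretty-printed,
# HTML-escaped JSON.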

# === UI setup
def create_ui(agent):
    with gr.Blocks() as demo:
        gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
        gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.")

        temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 32000, value=8192, label="Max Total Tokens")
        max_round = gr.Slider(1, 50, value=30, label="Max Rounds")
        multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
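        # gr.State is session-scoped: it carries the agent-side conversation
        # for this browser session, separate from the Chatbot's display history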
        conversation_state = gr.State([])

        chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
        message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
        send_button = gr.Button("Send", variant="primary")

        # Main handler
        def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
            generator = agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round
            )
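            # NB: run_gradio_chat is called with `max_token` (singular), so the
            # max_tokens slider value is deliberately passed under that name.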

            for update in generator:
                formatted = []
                for m in update:
                    role = m["role"] if isinstance(m, dict) else getattr(m, "role", "assistant")
                    content = m["content"] if isinstance(m, dict) else getattr(m, "content", "")

                    if role == "assistant":
                        content = format_collapsible(content)

                    formatted.append({"role": role, "content": content})
                yield formatted
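        # handle_chat is a generator, so Gradio streams each yielded message
        # list to the Chatbot and intermediate tool steps render live.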

        # Button and Enter triggers
        inputs = [message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        gr.Examples(examples=question_examples, inputs=message_input)
        gr.Markdown("**DISCLAIMER**: This demo is for research purposes only and does not provide medical advice.")

    return demo

# === Entry point
if __name__ == "__main__":
    freeze_support()

    try:
        agent = TxAgent(
            model_name=model_name,
            rag_model_name=rag_model_name,
            tool_files_dict=new_tool_files,
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=100,
            additional_default_tools=[]  # Avoid loading unimplemented tools
        )
        agent.init_model()

        if not hasattr(agent, "run_gradio_chat"):
            raise AttributeError("TxAgent missing run_gradio_chat")

        demo = create_ui(agent)
        demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)

    except Exception as e:
        print(f"❌ App failed to start: {e}")
        raise