Ali2206 commited on
Commit
0814161
·
verified ·
1 Parent(s): 50ffc74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +309 -133
app.py CHANGED
@@ -1,133 +1,309 @@
1
- import os
2
- import sys
3
- import gradio as gr
4
- from multiprocessing import freeze_support
5
- import importlib
6
- import inspect
7
- import json
8
-
9
- # Fix path to include src
10
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
11
-
12
- # Reload TxAgent from txagent.py
13
- import txagent.txagent
14
- importlib.reload(txagent.txagent)
15
- from txagent.txagent import TxAgent
16
-
17
- # Debug info
18
- print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
19
- print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
20
-
21
- # Env vars
22
- current_dir = os.path.abspath(os.path.dirname(__file__))
23
- os.environ["MKL_THREADING_LAYER"] = "GNU"
24
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
25
-
26
- # Model config
27
- model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
28
- rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
29
- new_tool_files = {
30
- "new_tool": os.path.join(current_dir, "data", "new_tool.json")
31
- }
32
-
33
- # Sample questions
34
- question_examples = [
35
- ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
36
- ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
37
- ]
38
-
39
- # Helper: format assistant responses in collapsible panels
40
- def format_collapsible(content):
41
- if isinstance(content, (dict, list)):
42
- try:
43
- formatted = json.dumps(content, indent=2)
44
- except Exception:
45
- formatted = str(content)
46
- else:
47
- formatted = str(content)
48
-
49
- return (
50
- "<details style='border: 1px solid #ccc; padding: 8px; margin-top: 8px;'>"
51
- "<summary style='font-weight: bold;'>Answer</summary>"
52
- f"<pre style='white-space: pre-wrap;'>{formatted}</pre>"
53
- "</details>"
54
- )
55
-
56
- # === UI setup
57
- def create_ui(agent):
58
- with gr.Blocks() as demo:
59
- gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
60
- gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.")
61
-
62
- temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
63
- max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
64
- max_tokens = gr.Slider(128, 32000, value=8192, label="Max Total Tokens")
65
- max_round = gr.Slider(1, 50, value=30, label="Max Rounds")
66
- multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
67
- conversation_state = gr.State([])
68
-
69
- chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
70
- message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
71
- send_button = gr.Button("Send", variant="primary")
72
-
73
- # Main handler
74
- def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
75
- generator = agent.run_gradio_chat(
76
- message=message,
77
- history=history,
78
- temperature=temperature,
79
- max_new_tokens=max_new_tokens,
80
- max_token=max_tokens,
81
- call_agent=multi_agent,
82
- conversation=conversation,
83
- max_round=max_round
84
- )
85
-
86
- for update in generator:
87
- formatted = []
88
- for m in update:
89
- role = m["role"] if isinstance(m, dict) else getattr(m, "role", "assistant")
90
- content = m["content"] if isinstance(m, dict) else getattr(m, "content", "")
91
-
92
- if role == "assistant":
93
- content = format_collapsible(content)
94
-
95
- formatted.append({"role": role, "content": content})
96
- yield formatted
97
-
98
- # Button and Enter triggers
99
- inputs = [message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round]
100
- send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
101
- message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)
102
-
103
- gr.Examples(examples=question_examples, inputs=message_input)
104
- gr.Markdown("**DISCLAIMER**: This demo is for research purposes only and does not provide medical advice.")
105
-
106
- return demo
107
-
108
- # === Entry point
109
- if __name__ == "__main__":
110
- freeze_support()
111
-
112
- try:
113
- agent = TxAgent(
114
- model_name=model_name,
115
- rag_model_name=rag_model_name,
116
- tool_files_dict=new_tool_files,
117
- force_finish=True,
118
- enable_checker=True,
119
- step_rag_num=10,
120
- seed=100,
121
- additional_default_tools=[] # Avoid loading unimplemented tools
122
- )
123
- agent.init_model()
124
-
125
- if not hasattr(agent, "run_gradio_chat"):
126
- raise AttributeError("TxAgent missing run_gradio_chat")
127
-
128
- demo = create_ui(agent)
129
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
130
-
131
- except Exception as e:
132
- print(f"❌ App failed to start: {e}")
133
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import gradio as gr
4
+ from multiprocessing import freeze_support
5
+ import importlib
6
+ import inspect
7
+ import json
8
+ from typing import Dict, List, Union
9
+
10
+ # Fix path to include src
11
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
12
+
13
+ # Reload TxAgent from txagent.py
14
+ import txagent.txagent
15
+ importlib.reload(txagent.txagent)
16
+ from txagent.txagent import TxAgent
17
+
18
+ # Debug info
19
+ print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
20
+ print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
21
+
22
+ # Env vars
23
+ current_dir = os.path.abspath(os.path.dirname(__file__))
24
+ os.environ["MKL_THREADING_LAYER"] = "GNU"
25
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
26
+
27
+ # Model config
28
+ model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
29
+ rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
30
+ new_tool_files = {
31
+ "new_tool": os.path.join(current_dir, "data", "new_tool.json")
32
+ }
33
+
34
+ # Sample questions
35
+ question_examples = [
36
+ ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
37
+ ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"],
38
+ ["What are the drug interactions between warfarin and ciprofloxacin?"]
39
+ ]
40
+
41
+ # Custom CSS for elegant design
42
+ custom_css = """
43
+ :root {
44
+ --primary-color: #4f46e5;
45
+ --secondary-color: #f9fafb;
46
+ --accent-color: #e5e7eb;
47
+ --text-color: #111827;
48
+ --border-radius: 8px;
49
+ }
50
+
51
+ body {
52
+ font-family: 'Inter', system-ui, -apple-system, sans-serif;
53
+ }
54
+
55
+ .dark body {
56
+ --secondary-color: #1f2937;
57
+ --text-color: #f9fafb;
58
+ }
59
+
60
+ .gradio-container {
61
+ max-width: 900px !important;
62
+ margin: 0 auto !important;
63
+ }
64
+
65
+ h1 {
66
+ color: var(--primary-color) !important;
67
+ font-weight: 600 !important;
68
+ margin-bottom: 1rem !important;
69
+ }
70
+
71
+ .chatbot {
72
+ min-height: 600px;
73
+ border-radius: var(--border-radius) !important;
74
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1) !important;
75
+ }
76
+
77
+ .textbox {
78
+ border-radius: var(--border-radius) !important;
79
+ }
80
+
81
+ .button-primary {
82
+ background: var(--primary-color) !important;
83
+ border-radius: var(--border-radius) !important;
84
+ }
85
+
86
+ .answer-panel {
87
+ background: var(--secondary-color) !important;
88
+ border-radius: var(--border-radius) !important;
89
+ padding: 16px !important;
90
+ margin-top: 8px !important;
91
+ border: 1px solid var(--accent-color) !important;
92
+ }
93
+
94
+ .answer-title {
95
+ font-weight: 600 !important;
96
+ color: var(--primary-color) !important;
97
+ margin-bottom: 8px !important;
98
+ }
99
+
100
+ .answer-content {
101
+ white-space: pre-wrap;
102
+ font-family: 'Roboto Mono', monospace;
103
+ font-size: 0.9em;
104
+ line-height: 1.5;
105
+ }
106
+
107
+ .settings-panel {
108
+ background: var(--secondary-color) !important;
109
+ border-radius: var(--border-radius) !important;
110
+ padding: 16px !important;
111
+ margin-bottom: 16px !important;
112
+ border: 1px solid var(--accent-color) !important;
113
+ }
114
+
115
+ .settings-title {
116
+ font-weight: 600 !important;
117
+ margin-bottom: 12px !important;
118
+ color: var(--text-color) !important;
119
+ }
120
+
121
+ .examples-panel {
122
+ margin-top: 16px !important;
123
+ }
124
+ """
125
+
126
+ # Helper: format assistant responses in elegant panels
127
+ def format_response(content: Union[str, Dict, List]) -> str:
128
+ """Format the assistant's response in a structured, user-friendly way."""
129
+ if isinstance(content, (dict, list)):
130
+ try:
131
+ formatted = json.dumps(content, indent=2)
132
+ except Exception:
133
+ formatted = str(content)
134
+ else:
135
+ formatted = str(content)
136
+
137
+ # Clean up common formatting issues
138
+ formatted = formatted.replace("\\n", "\n").replace("\\t", "\t")
139
+
140
+ return (
141
+ f"<div class='answer-panel'>"
142
+ f"<div class='answer-title'>Detailed Response</div>"
143
+ f"<div class='answer-content'>{formatted}</div>"
144
+ f"</div>"
145
+ )
146
+
147
+ # Helper: format tool calls in a structured way
148
+ def format_tool_call(tool_name: str, parameters: Dict) -> str:
149
+ """Format tool calls for display in the chat."""
150
+ return (
151
+ f"<div class='answer-panel' style='background: #f0f9ff;'>"
152
+ f"<div class='answer-title'>Tool Used: {tool_name}</div>"
153
+ f"<div class='answer-content'>Parameters: {json.dumps(parameters, indent=2)}</div>"
154
+ f"</div>"
155
+ )
156
+
157
+ # === UI setup
158
+ def create_ui(agent: TxAgent) -> gr.Blocks:
159
+ """Create the Gradio UI with elegant design and organized responses."""
160
+ with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
161
+ # Header section
162
+ gr.Markdown(
163
+ """
164
+ <div style='text-align: center; margin-bottom: 24px;'>
165
+ <h1 style='margin-bottom: 8px;'>Therapeutic Decision Support</h1>
166
+ <p style='color: #6b7280;'>Get evidence-based answers to your biomedical questions with step-by-step reasoning</p>
167
+ </div>
168
+ """
169
+ )
170
+
171
+ # Settings panel
172
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
173
+ with gr.Row():
174
+ temperature = gr.Slider(0, 1, value=0.3, label="Creativity", info="Higher values produce more creative outputs")
175
+ max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max Response Length", step=128)
176
+
177
+ with gr.Row():
178
+ max_tokens = gr.Slider(128, 32000, value=8192, label="Context Window", step=1024)
179
+ max_round = gr.Slider(1, 50, value=30, label="Max Reasoning Steps")
180
+
181
+ multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False, info="Uses multiple specialized agents for complex questions")
182
+
183
+ conversation_state = gr.State([])
184
+
185
+ # Chat interface
186
+ chatbot = gr.Chatbot(
187
+ label="Therapeutic Reasoning Chat",
188
+ height=600,
189
+ bubble_full_width=False,
190
+ avatar_images=(
191
+ "assets/user_avatar.png", # User avatar
192
+ "assets/bot_avatar.png" # Bot avatar
193
+ )
194
+ )
195
+
196
+ with gr.Row():
197
+ message_input = gr.Textbox(
198
+ placeholder="Ask your biomedical question...",
199
+ show_label=False,
200
+ container=False,
201
+ autofocus=True,
202
+ lines=3,
203
+ max_lines=6
204
+ )
205
+ send_button = gr.Button("Send", variant="primary", size="lg")
206
+
207
+ # Examples section
208
+ gr.Examples(
209
+ examples=question_examples,
210
+ inputs=message_input,
211
+ label="💡 Example Questions",
212
+ examples_per_page=3
213
+ )
214
+
215
+ # Disclaimer
216
+ gr.Markdown(
217
+ """
218
+ <div style='text-align: center; margin-top: 24px; color: #6b7280; font-size: 0.9em;'>
219
+ <strong>Disclaimer</strong>: This tool is for research purposes only and does not constitute medical advice.
220
+ Always consult a healthcare professional for medical decisions.
221
+ </div>
222
+ """
223
+ )
224
+
225
+ # Main handler
226
+ def handle_chat(
227
+ message: str,
228
+ history: List,
229
+ temperature: float,
230
+ max_new_tokens: int,
231
+ max_tokens: int,
232
+ multi_agent: bool,
233
+ conversation: List,
234
+ max_round: int
235
+ ):
236
+ generator = agent.run_gradio_chat(
237
+ message=message,
238
+ history=history,
239
+ temperature=temperature,
240
+ max_new_tokens=max_new_tokens,
241
+ max_token=max_tokens,
242
+ call_agent=multi_agent,
243
+ conversation=conversation,
244
+ max_round=max_round
245
+ )
246
+
247
+ for update in generator:
248
+ formatted = []
249
+ for m in update:
250
+ role = m["role"] if isinstance(m, dict) else getattr(m, "role", "assistant")
251
+ content = m["content"] if isinstance(m, dict) else getattr(m, "content", "")
252
+
253
+ # Format different types of messages appropriately
254
+ if role == "assistant":
255
+ if "tool_name" in m:
256
+ formatted.append({
257
+ "role": role,
258
+ "content": format_tool_call(m["tool_name"], m.get("parameters", {}))
259
+ })
260
+ else:
261
+ formatted.append({
262
+ "role": role,
263
+ "content": format_response(content)
264
+ })
265
+ else:
266
+ formatted.append({"role": role, "content": content})
267
+
268
+ yield formatted
269
+
270
+ # Event handlers
271
+ inputs = [message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round]
272
+ send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
273
+ message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)
274
+
275
+ return demo
276
+
277
+ # === Entry point
278
+ if __name__ == "__main__":
279
+ freeze_support()
280
+
281
+ try:
282
+ # Initialize the agent
283
+ agent = TxAgent(
284
+ model_name=model_name,
285
+ rag_model_name=rag_model_name,
286
+ tool_files_dict=new_tool_files,
287
+ force_finish=True,
288
+ enable_checker=True,
289
+ step_rag_num=10,
290
+ seed=100,
291
+ additional_default_tools=[]
292
+ )
293
+ agent.init_model()
294
+
295
+ if not hasattr(agent, "run_gradio_chat"):
296
+ raise AttributeError("TxAgent missing run_gradio_chat")
297
+
298
+ # Create and launch the UI
299
+ demo = create_ui(agent)
300
+ demo.launch(
301
+ server_name="0.0.0.0",
302
+ server_port=7860,
303
+ show_error=True,
304
+ favicon_path="assets/favicon.ico"
305
+ )
306
+
307
+ except Exception as e:
308
+ print(f"❌ Application failed to start: {e}")
309
+ raise