Ali2206 commited on
Commit
1155704
·
verified ·
1 Parent(s): 0cec600

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -173
app.py CHANGED
@@ -1,35 +1,27 @@
1
  import random
 
2
  import datetime
3
  import sys
4
  from txagent import TxAgent
5
  import spaces
6
  import gradio as gr
7
- import os
8
- import os
9
 
10
- # Determine the directory where the current file is located
11
  current_dir = os.path.dirname(os.path.abspath(__file__))
12
  os.environ["MKL_THREADING_LAYER"] = "GNU"
13
-
14
- # Set an environment variable
15
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
16
 
17
-
18
  DESCRIPTION = '''
19
  <div>
20
  <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools </h1>
21
  </div>
22
  '''
23
  INTRO = """
24
- Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations. We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions, contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions. TxAgent executes goal-oriented tool selection and iterative function calls to solve therapeutic tasks that require deep clinical understanding and cross-source validation. The ToolUniverse consolidates 211 tools linked to trusted sources, including all US FDA-approved drugs since 1939 and validated clinical insights from Open Targets.
25
  """
26
-
27
  LICENSE = """
28
- We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested in collaboration, please email Marinka Zitnik and Shanghua Gao.
29
-
30
- ### Medical Advice Disclaimer
31
- DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
32
- The information, including but not limited to, text, graphics, images and other material contained on this website are for informational purposes only. No material on this site is intended to be a substitute for professional medical advice, diagnosis or treatment. Always seek the advice of your physician or other qualified health care provider with any questions you may have regarding a medical condition or treatment and before undertaking a new health care regimen, and never disregard professional medical advice or delay in seeking it because of something you have read on this website.
33
  """
34
 
35
  PLACEHOLDER = """
@@ -37,8 +29,8 @@ PLACEHOLDER = """
37
  <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
38
  <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
39
  <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
40
- (top-right) to remove previous context before sumbmitting a new question.</p>
41
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
42
  </div>
43
  """
44
 
@@ -47,7 +39,6 @@ h1 {
47
  text-align: center;
48
  display: block;
49
  }
50
-
51
  #duplicate-button {
52
  margin: auto;
53
  color: white;
@@ -67,17 +58,13 @@ h1 {
67
  """
68
 
69
  chat_css = """
70
- .gr-button { font-size: 20px !important; } /* Enlarges button icons */
71
- .gr-button svg { width: 32px !important; height: 32px !important; } /* Enlarges SVG icons */
72
  """
73
 
74
- # model_name = '/n/holylfs06/LABS/mzitnik_lab/Lab/shgao/bioagent/bio/alignment-handbook/data_new/L8-qlora-biov49v9v7v16_32k_chat01_merged'
75
  model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
76
  rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'
77
 
78
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
79
-
80
-
81
  question_examples = [
82
  ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
83
  ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
@@ -98,12 +85,10 @@ agent = TxAgent(model_name,
98
  additional_default_tools=['DirectResponse', 'RequireClarification'])
99
  agent.init_model()
100
 
101
-
102
  def update_model_parameters(enable_finish, enable_rag, enable_summary,
103
  init_rag_num, step_rag_num, skip_last_k,
104
  summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
105
- # Update model instance parameters dynamically
106
- updated_params = agent.update_parameters(
107
  enable_finish=enable_finish,
108
  enable_rag=enable_rag,
109
  enable_summary=enable_summary,
@@ -117,33 +102,18 @@ def update_model_parameters(enable_finish, enable_rag, enable_summary,
117
  seed=seed,
118
  )
119
 
120
- return updated_params
121
-
122
-
123
  def update_seed():
124
- # Update model instance parameters dynamically
125
  seed = random.randint(0, 10000)
126
- updated_params = agent.update_parameters(
127
- seed=seed,
128
- )
129
- return updated_params
130
-
131
 
132
  def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
133
- print("Updated seed:", update_seed())
134
  new_history = history[:retry_data.index]
135
  previous_prompt = history[retry_data.index]['content']
136
-
137
- print("previous_prompt", previous_prompt)
138
-
139
  yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
140
 
141
-
142
  PASSWORD = "mypassword"
143
 
144
- # Function to check if the password is correct
145
-
146
-
147
  def check_password(input_password):
148
  if input_password == PASSWORD:
149
  return gr.update(visible=True), ""
@@ -151,143 +121,108 @@ def check_password(input_password):
151
  return gr.update(visible=False), "Incorrect password, try again!"
152
 
153
 
154
- conversation_state = gr.State([])
155
-
156
- # Gradio block
157
- chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
158
- label='TxAgent', type="messages", show_copy_button=True)
159
-
160
- with gr.Blocks(css=css) as demo:
161
- gr.Markdown(DESCRIPTION)
162
- gr.Markdown(INTRO)
163
- default_temperature = 0.3
164
- default_max_new_tokens = 1024
165
- default_max_tokens = 81920
166
- default_max_round = 30
167
- temperature_state = gr.State(value=default_temperature)
168
- max_new_tokens_state = gr.State(value=default_max_new_tokens)
169
- max_tokens_state = gr.State(value=default_max_tokens)
170
- max_round_state = gr.State(value=default_max_round)
171
- chatbot.retry(handle_retry, chatbot, chatbot, temperature_state, max_new_tokens_state,
172
- max_tokens_state, gr.Checkbox(value=False, render=False), conversation_state, max_round_state)
173
 
174
- gr.ChatInterface(
175
- fn=agent.run_gradio_chat,
176
- chatbot=chatbot,
177
- fill_height=True, fill_width=True, stop_btn=True,
178
- additional_inputs_accordion=gr.Accordion(
179
- label="⚙️ Inference Parameters", open=False, render=False),
180
- additional_inputs=[
181
- temperature_state, max_new_tokens_state, max_tokens_state,
182
- gr.Checkbox(
183
- label="Activate multi-agent reasoning mode (it requires additional time but offers a more comprehensive analysis).", value=False, render=False),
184
- conversation_state,
185
- max_round_state,
186
- gr.Number(label="Seed", value=100, render=False)
187
- ],
188
- examples=question_examples,
189
- cache_examples=False,
190
- css=chat_css,
191
  )
192
 
193
- with gr.Accordion("Settings", open=False):
194
-
195
- # Define the sliders
196
- temperature_slider = gr.Slider(
197
- minimum=0,
198
- maximum=1,
199
- step=0.1,
200
- value=default_temperature,
201
- label="Temperature"
202
- )
203
- max_new_tokens_slider = gr.Slider(
204
- minimum=128,
205
- maximum=4096,
206
- step=1,
207
- value=default_max_new_tokens,
208
- label="Max new tokens"
209
- )
210
- max_tokens_slider = gr.Slider(
211
- minimum=128,
212
- maximum=32000,
213
- step=1,
214
- value=default_max_tokens,
215
- label="Max tokens"
216
  )
217
- max_round_slider = gr.Slider(
218
- minimum=0,
219
- maximum=50,
220
- step=1,
221
- value=default_max_round,
222
- label="Max round")
223
-
224
- # Automatically update states when slider values change
225
- temperature_slider.change(
226
- lambda x: x, inputs=temperature_slider, outputs=temperature_state)
227
- max_new_tokens_slider.change(
228
- lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
229
- max_tokens_slider.change(
230
- lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
231
- max_round_slider.change(
232
- lambda x: x, inputs=max_round_slider, outputs=max_round_state)
233
 
234
- password_input = gr.Textbox(
235
- label="Enter Password for More Settings", type="password")
236
- incorrect_message = gr.Textbox(visible=False, interactive=False)
237
- with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
238
- with gr.Row():
239
- with gr.Column(scale=1):
240
- with gr.Accordion("⚙️ Model Loading", open=False):
241
- model_name_input = gr.Textbox(
242
- label="Enter model path", value=model_name)
243
- load_model_btn = gr.Button(value="Load Model")
244
- load_model_btn.click(
245
- agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
246
- with gr.Column(scale=1):
247
- with gr.Accordion("⚙️ Functional Parameters", open=False):
248
- # Create Gradio components for parameter inputs
249
- enable_finish = gr.Checkbox(
250
- label="Enable Finish", value=True)
251
- enable_rag = gr.Checkbox(
252
- label="Enable RAG", value=True)
253
- enable_summary = gr.Checkbox(
254
- label="Enable Summary", value=False)
255
- init_rag_num = gr.Number(
256
- label="Initial RAG Num", value=0)
257
- step_rag_num = gr.Number(
258
- label="Step RAG Num", value=10)
259
- skip_last_k = gr.Number(label="Skip Last K", value=0)
260
- summary_mode = gr.Textbox(
261
- label="Summary Mode", value='step')
262
- summary_skip_last_k = gr.Number(
263
- label="Summary Skip Last K", value=0)
264
- summary_context_length = gr.Number(
265
- label="Summary Context Length", value=None)
266
- force_finish = gr.Checkbox(
267
- label="Force FinalAnswer", value=True)
268
- seed = gr.Number(label="Seed", value=100)
269
- # Button to submit and update parameters
270
- submit_btn = gr.Button("Update Parameters")
271
-
272
- # Display the updated parameters
273
- updated_parameters_output = gr.JSON()
274
-
275
- # When button is clicked, update parameters
276
- submit_btn.click(fn=update_model_parameters,
277
- inputs=[enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
278
- summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed],
279
- outputs=updated_parameters_output)
280
- # Button to submit the password
281
- submit_button = gr.Button("Submit")
282
-
283
- # When the button is clicked, check if the password is correct
284
- submit_button.click(
285
- check_password,
286
- inputs=password_input,
287
- outputs=[protected_accordion, incorrect_message]
288
  )
289
- gr.Markdown(LICENSE)
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
- if __name__ == "__main__":
293
  demo.launch(share=True)
 
1
  import random
2
+ import os
3
  import datetime
4
  import sys
5
  from txagent import TxAgent
6
  import spaces
7
  import gradio as gr
 
 
8
 
9
+ # Set environment variables
10
  current_dir = os.path.dirname(os.path.abspath(__file__))
11
  os.environ["MKL_THREADING_LAYER"] = "GNU"
12
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
14
 
 
15
  DESCRIPTION = '''
16
  <div>
17
  <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools </h1>
18
  </div>
19
  '''
20
  INTRO = """
21
+ Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations...
22
  """
 
23
  LICENSE = """
24
+ DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE...
 
 
 
 
25
  """
26
 
27
  PLACEHOLDER = """
 
29
  <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
30
  <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
31
  <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
32
+ (top-right) to remove previous context before submitting a new question.</p>
33
+ <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
34
  </div>
35
  """
36
 
 
39
  text-align: center;
40
  display: block;
41
  }
 
42
  #duplicate-button {
43
  margin: auto;
44
  color: white;
 
58
  """
59
 
60
  chat_css = """
61
+ .gr-button { font-size: 20px !important; }
62
+ .gr-button svg { width: 32px !important; height: 32px !important; }
63
  """
64
 
 
65
  model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
66
  rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'
67
 
 
 
 
68
  question_examples = [
69
  ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
70
  ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
 
85
  additional_default_tools=['DirectResponse', 'RequireClarification'])
86
  agent.init_model()
87
 
 
88
  def update_model_parameters(enable_finish, enable_rag, enable_summary,
89
  init_rag_num, step_rag_num, skip_last_k,
90
  summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
91
+ return agent.update_parameters(
 
92
  enable_finish=enable_finish,
93
  enable_rag=enable_rag,
94
  enable_summary=enable_summary,
 
102
  seed=seed,
103
  )
104
 
 
 
 
105
  def update_seed():
 
106
  seed = random.randint(0, 10000)
107
+ return agent.update_parameters(seed=seed)
 
 
 
 
108
 
109
  def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
110
+ update_seed()
111
  new_history = history[:retry_data.index]
112
  previous_prompt = history[retry_data.index]['content']
 
 
 
113
  yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
114
 
 
115
  PASSWORD = "mypassword"
116
 
 
 
 
117
  def check_password(input_password):
118
  if input_password == PASSWORD:
119
  return gr.update(visible=True), ""
 
121
  return gr.update(visible=False), "Incorrect password, try again!"
122
 
123
 
124
+ if __name__ == "__main__":
125
+ conversation_state = gr.State([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ chatbot = gr.Chatbot(
128
+ height=800, placeholder=PLACEHOLDER, label='TxAgent',
129
+ type="messages", show_copy_button=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  )
131
 
132
+ with gr.Blocks(css=css) as demo:
133
+ gr.Markdown(DESCRIPTION)
134
+ gr.Markdown(INTRO)
135
+
136
+ default_temperature = 0.3
137
+ default_max_new_tokens = 1024
138
+ default_max_tokens = 81920
139
+ default_max_round = 30
140
+
141
+ temperature_state = gr.State(value=default_temperature)
142
+ max_new_tokens_state = gr.State(value=default_max_new_tokens)
143
+ max_tokens_state = gr.State(value=default_max_tokens)
144
+ max_round_state = gr.State(value=default_max_round)
145
+
146
+ chatbot.retry(
147
+ handle_retry,
148
+ chatbot, chatbot,
149
+ temperature_state, max_new_tokens_state,
150
+ max_tokens_state,
151
+ gr.Checkbox(value=False, render=False),
152
+ conversation_state,
153
+ max_round_state
 
154
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ gr.ChatInterface(
157
+ fn=agent.run_gradio_chat,
158
+ chatbot=chatbot,
159
+ fill_height=True, fill_width=True, stop_btn=True,
160
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Inference Parameters", open=False, render=False),
161
+ additional_inputs=[
162
+ temperature_state, max_new_tokens_state, max_tokens_state,
163
+ gr.Checkbox(label="Activate multi-agent reasoning mode", value=False, render=False),
164
+ conversation_state,
165
+ max_round_state,
166
+ gr.Number(label="Seed", value=100, render=False)
167
+ ],
168
+ examples=question_examples,
169
+ cache_examples=False,
170
+ css=chat_css,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  )
 
172
 
173
+ with gr.Accordion("Settings", open=False):
174
+ temperature_slider = gr.Slider(0, 1, step=0.1, value=default_temperature, label="Temperature")
175
+ max_new_tokens_slider = gr.Slider(128, 4096, step=1, value=default_max_new_tokens, label="Max new tokens")
176
+ max_tokens_slider = gr.Slider(128, 32000, step=1, value=default_max_tokens, label="Max tokens")
177
+ max_round_slider = gr.Slider(0, 50, step=1, value=default_max_round, label="Max round")
178
+
179
+ temperature_slider.change(lambda x: x, inputs=temperature_slider, outputs=temperature_state)
180
+ max_new_tokens_slider.change(lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
181
+ max_tokens_slider.change(lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
182
+ max_round_slider.change(lambda x: x, inputs=max_round_slider, outputs=max_round_state)
183
+
184
+ password_input = gr.Textbox(label="Enter Password for More Settings", type="password")
185
+ incorrect_message = gr.Textbox(visible=False, interactive=False)
186
+
187
+ with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
188
+ with gr.Row():
189
+ with gr.Column(scale=1):
190
+ with gr.Accordion("⚙️ Model Loading", open=False):
191
+ model_name_input = gr.Textbox(label="Enter model path", value=model_name)
192
+ load_model_btn = gr.Button(value="Load Model")
193
+ load_model_btn.click(agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
194
+ with gr.Column(scale=1):
195
+ with gr.Accordion("⚙️ Functional Parameters", open=False):
196
+ enable_finish = gr.Checkbox(label="Enable Finish", value=True)
197
+ enable_rag = gr.Checkbox(label="Enable RAG", value=True)
198
+ enable_summary = gr.Checkbox(label="Enable Summary", value=False)
199
+ init_rag_num = gr.Number(label="Initial RAG Num", value=0)
200
+ step_rag_num = gr.Number(label="Step RAG Num", value=10)
201
+ skip_last_k = gr.Number(label="Skip Last K", value=0)
202
+ summary_mode = gr.Textbox(label="Summary Mode", value='step')
203
+ summary_skip_last_k = gr.Number(label="Summary Skip Last K", value=0)
204
+ summary_context_length = gr.Number(label="Summary Context Length", value=None)
205
+ force_finish = gr.Checkbox(label="Force FinalAnswer", value=True)
206
+ seed = gr.Number(label="Seed", value=100)
207
+
208
+ submit_btn = gr.Button("Update Parameters")
209
+ updated_parameters_output = gr.JSON()
210
+ submit_btn.click(
211
+ fn=update_model_parameters,
212
+ inputs=[
213
+ enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
214
+ summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed
215
+ ],
216
+ outputs=updated_parameters_output
217
+ )
218
+
219
+ submit_button = gr.Button("Submit")
220
+ submit_button.click(
221
+ check_password,
222
+ inputs=password_input,
223
+ outputs=[protected_accordion, incorrect_message]
224
+ )
225
+
226
+ gr.Markdown(LICENSE)
227
 
 
228
  demo.launch(share=True)