Ali2206 committed on
Commit
bb17715
·
verified ·
1 Parent(s): 696fd36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -56
app.py CHANGED
@@ -1,57 +1,23 @@
1
  import gradio as gr
2
  import logging
3
- from txagent import TxAgent
4
- from tooluniverse import ToolUniverse
5
- from importlib.resources import files
6
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
- tx_app = None # Global TxAgent instance
11
-
12
- def init_txagent():
13
- logger.info("🔥 Initializing TxAgent...")
14
-
15
- tool_files = {
16
- "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
17
- "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
18
- "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
19
- "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
20
- }
21
-
22
- agent = TxAgent(
23
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
24
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
25
- tool_files_dict=tool_files,
26
- enable_finish=True,
27
- enable_rag=True,
28
- enable_summary=False,
29
- init_rag_num=0,
30
- step_rag_num=10,
31
- summary_mode='step',
32
- summary_skip_last_k=0,
33
- summary_context_length=None,
34
- force_finish=True,
35
- avoid_repeat=True,
36
- seed=42,
37
- enable_checker=True,
38
- enable_chat=False,
39
- additional_default_tools=["DirectResponse", "RequireClarification"]
40
- )
41
-
42
- agent.init_model()
43
- logger.info("✅ TxAgent fully initialized")
44
- return agent
45
-
46
  def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
47
  global tx_app
48
  if tx_app is None:
49
- return chat_history + [("", "⚠️ Model not ready yet. Please wait a few seconds and try again.")]
50
 
51
  try:
52
- if not isinstance(message, str) or len(message.strip()) <= 10:
53
- return chat_history + [("", "Please provide a valid message longer than 10 characters.")]
54
-
 
55
  if chat_history and isinstance(chat_history[0], dict):
56
  chat_history = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]
57
 
@@ -65,7 +31,7 @@ def respond(message, chat_history, temperature, max_new_tokens, max_tokens, mult
65
  call_agent=multi_agent,
66
  conversation=conversation_state,
67
  max_round=max_round,
68
- seed=42
69
  ):
70
  if isinstance(chunk, dict):
71
  response += chunk.get("content", "")
@@ -77,15 +43,20 @@ def respond(message, chat_history, temperature, max_new_tokens, max_tokens, mult
77
  yield chat_history + [("user", message), ("assistant", response)]
78
 
79
  except Exception as e:
80
- logger.error(f"Error in respond function: {str(e)}")
81
- yield chat_history + [("", f"⚠️ Error: {str(e)}")]
82
 
83
- # Top-level app object that HF Spaces can detect
84
  with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
85
  gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
86
 
87
  chatbot = gr.Chatbot(label="Conversation", height=600, type="messages")
88
- msg = gr.Textbox(label="Your medical query", placeholder="Enter your biomedical question...", lines=3)
 
 
 
 
 
89
 
90
  with gr.Row():
91
  temp = gr.Slider(0, 1, value=0.3, label="Temperature")
@@ -103,22 +74,53 @@ with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
103
  [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
104
  chatbot
105
  )
106
-
107
  clear.click(lambda: [], None, chatbot)
108
-
109
  msg.submit(
110
  respond,
111
  [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
112
  chatbot
113
  )
114
 
115
- # hidden init trigger on page load
116
- hidden_button = gr.Button(visible=False)
117
 
118
- def initialize_agent():
119
  global tx_app
120
- tx_app = init_txagent()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  return gr.update(visible=False)
122
 
123
- app.load(hidden_button.click(fn=initialize_agent))
124
-
 
1
  import gradio as gr
2
  import logging
3
+
4
+ # Delay heavy imports until later to avoid multiprocessing conflicts
5
+ tx_app = None # Global agent instance
6
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
+ # ========== Respond handler (returns a not-ready notice until the agent loads) ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
12
  global tx_app
13
  if tx_app is None:
14
+ return chat_history + [("", "⚠️ Model is still loading. Please wait a few seconds and try again.")]
15
 
16
  try:
17
+ if not isinstance(message, str) or len(message.strip()) < 10:
18
+ return chat_history + [("", "Please enter a longer message.")]
19
+
20
+ # Convert chat format if needed
21
  if chat_history and isinstance(chat_history[0], dict):
22
  chat_history = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]
23
 
 
31
  call_agent=multi_agent,
32
  conversation=conversation_state,
33
  max_round=max_round,
34
+ seed=42,
35
  ):
36
  if isinstance(chunk, dict):
37
  response += chunk.get("content", "")
 
43
  yield chat_history + [("user", message), ("assistant", response)]
44
 
45
  except Exception as e:
46
+ logger.error(f"Respond error: {e}")
47
+ yield chat_history + [("", f"⚠️ Error: {e}")]
48
 
49
+ # ========== Gradio UI ==========
50
  with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
51
  gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
52
 
53
  chatbot = gr.Chatbot(label="Conversation", height=600, type="messages")
54
+
55
+ msg = gr.Textbox(
56
+ label="Your medical query",
57
+ placeholder="Enter your biomedical question...",
58
+ lines=3
59
+ )
60
 
61
  with gr.Row():
62
  temp = gr.Slider(0, 1, value=0.3, label="Temperature")
 
74
  [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
75
  chatbot
76
  )
 
77
  clear.click(lambda: [], None, chatbot)
 
78
  msg.submit(
79
  respond,
80
  [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
81
  chatbot
82
  )
83
 
84
+ # === Hidden trigger to load model safely on app start ===
85
+ init_button = gr.Button(visible=False)
86
 
87
+ def load_model():
88
  global tx_app
89
+ import torch
90
+ from txagent import TxAgent
91
+ from importlib.resources import files
92
+
93
+ logger.info("🔧 Loading full TxAgent model...")
94
+
95
+ tool_files = {
96
+ "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
97
+ "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
98
+ "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
99
+ "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
100
+ }
101
+
102
+ tx_app = TxAgent(
103
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
104
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
105
+ tool_files_dict=tool_files,
106
+ enable_finish=True,
107
+ enable_rag=True,
108
+ enable_summary=False,
109
+ init_rag_num=0,
110
+ step_rag_num=10,
111
+ summary_mode='step',
112
+ summary_skip_last_k=0,
113
+ summary_context_length=None,
114
+ force_finish=True,
115
+ avoid_repeat=True,
116
+ seed=42,
117
+ enable_checker=True,
118
+ enable_chat=False,
119
+ additional_default_tools=["DirectResponse", "RequireClarification"]
120
+ )
121
+
122
+ tx_app.init_model()
123
+ logger.info("✅ Model initialized successfully")
124
  return gr.update(visible=False)
125
 
126
+ app.load(init_button.click(fn=load_model))