Ali2206 committed (verified)
Commit f15352f · 1 Parent(s): 37d892a

Update app.py

Files changed (1):
  1. app.py +76 -89
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 import logging
-import multiprocessing
 from txagent import TxAgent
 from tooluniverse import ToolUniverse
 from importlib.resources import files
@@ -9,100 +8,84 @@ from importlib.resources import files
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-tx_app = None  # Global holder for app instance (for Gradio to use)
+tx_app = None  # Will be initialized later in on_start

 def init_txagent():
     """Initialize the TxAgent with proper tool file paths"""
+    logger.info("Initializing TxAgent...")
+
+    tool_files = {
+        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
+        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
+        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
+        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
+    }
+
+    logger.info(f"Using tool files at: {tool_files}")
+
+    agent = TxAgent(
+        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+        tool_files_dict=tool_files,
+        enable_finish=True,
+        enable_rag=True,
+        enable_summary=False,
+        init_rag_num=0,
+        step_rag_num=10,
+        summary_mode='step',
+        summary_skip_last_k=0,
+        summary_context_length=None,
+        force_finish=True,
+        avoid_repeat=True,
+        seed=42,
+        enable_checker=True,
+        enable_chat=False,
+        additional_default_tools=["DirectResponse", "RequireClarification"]
+    )
+
+    agent.init_model()
+    logger.info("Model loading complete")
+    return agent
+
+def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
+    global tx_app
     try:
-        multiprocessing.set_start_method("spawn", force=True)
-        logger.info("Initializing TxAgent...")
-
-        tool_files = {
-            "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
-            "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
-            "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
-            "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
-        }
-
-        logger.info(f"Using tool files at: {tool_files}")
-
-        agent = TxAgent(
-            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-            tool_files_dict=tool_files,
-            enable_finish=True,
-            enable_rag=True,
-            enable_summary=False,
-            init_rag_num=0,
-            step_rag_num=10,
-            summary_mode='step',
-            summary_skip_last_k=0,
-            summary_context_length=None,
-            force_finish=True,
-            avoid_repeat=True,
-            seed=42,
-            enable_checker=True,
-            enable_chat=False,
-            additional_default_tools=["DirectResponse", "RequireClarification"]
-        )
-
-        agent.init_model()
-        logger.info("Model loading complete")
-        return agent
+        if not isinstance(message, str) or len(message.strip()) <= 10:
+            return chat_history + [("", "Please provide a valid message longer than 10 characters.")]
+
+        if chat_history and isinstance(chat_history[0], dict):
+            chat_history = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]
+
+        response = ""
+        for chunk in tx_app.run_gradio_chat(
+            message=message.strip(),
+            history=chat_history,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            max_token=max_tokens,
+            call_agent=multi_agent,
+            conversation=conversation_state,
+            max_round=max_round,
+            seed=42
+        ):
+            if isinstance(chunk, dict):
+                response += chunk.get("content", "")
+            elif isinstance(chunk, str):
+                response += chunk
+            else:
+                response += str(chunk)
+
+        yield chat_history + [("user", message), ("assistant", response)]

     except Exception as e:
-        logger.error(f"Initialization failed: {str(e)}")
-        raise
-
-class TxAgentApp:
-    def __init__(self):
-        self.agent = init_txagent()
-
-    def respond(self, message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
-        """Handle streaming responses with Gradio"""
-        try:
-            if not isinstance(message, str) or len(message.strip()) <= 10:
-                return chat_history + [("", "Please provide a valid message longer than 10 characters.")]
-
-            if chat_history and isinstance(chat_history[0], dict):
-                chat_history = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]
-
-            response = ""
-            for chunk in self.agent.run_gradio_chat(
-                message=message.strip(),
-                history=chat_history,
-                temperature=temperature,
-                max_new_tokens=max_new_tokens,
-                max_token=max_tokens,
-                call_agent=multi_agent,
-                conversation=conversation_state,
-                max_round=max_round,
-                seed=42
-            ):
-                if isinstance(chunk, dict):
-                    response += chunk.get("content", "")
-                elif isinstance(chunk, str):
-                    response += chunk
-                else:
-                    response += str(chunk)
-
-            yield chat_history + [("user", message), ("assistant", response)]
-
-        except Exception as e:
-            logger.error(f"Error in respond function: {str(e)}")
-            yield chat_history + [("", f"⚠️ Error: {str(e)}")]
-
-# Initialize the agent safely
-tx_app = TxAgentApp()
-
-# Define Gradio UI interface
+        logger.error(f"Error in respond function: {str(e)}")
+        yield chat_history + [("", f"⚠️ Error: {str(e)}")]
+
+# Define Gradio UI
 with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
     gr.Markdown("# 🧠 TxAgent Biomedical Assistant")

-    chatbot = gr.Chatbot(
-        label="Conversation",
-        height=600
-    )
+    chatbot = gr.Chatbot(label="Conversation", height=600)

     msg = gr.Textbox(
         label="Your medical query",
@@ -123,7 +106,7 @@ with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
     conversation_state = gr.State([])

     submit.click(
-        tx_app.respond,
+        respond,
         [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
         chatbot
     )
@@ -131,9 +114,13 @@ with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
     clear.click(lambda: [], None, chatbot)

     msg.submit(
-        tx_app.respond,
+        respond,
         [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
         chatbot
     )

-# This `app` will be served by Hugging Face automatically
+@app.on_start
+def load_model():
+    global tx_app
+    logger.info("🔥 Loading TxAgent model in Gradio @on_start")
+    tx_app = init_txagent()
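
Net effect of the diff: the agent is no longer built at import time (the old module-level `tx_app = TxAgentApp()` constructed and loaded the model as soon as app.py was imported); `tx_app` now starts as None, `respond` reads it as a global, and the `load_model` callback registered with `@app.on_start` fills it in after the UI is defined. A minimal sketch of the same deferred-initialization idea using only standard Python, with no dependency on a framework startup hook, follows; `get_agent`, `_agent`, and `_agent_lock` are illustrative names that are not part of the committed file, and the sketch assumes the `init_txagent()` factory defined in app.py above.

import threading

_agent = None
_agent_lock = threading.Lock()

def get_agent():
    """Return the shared TxAgent, building it on first use."""
    global _agent
    if _agent is None:
        # Double-checked locking so two concurrent first requests
        # do not both trigger the expensive model load.
        with _agent_lock:
            if _agent is None:
                _agent = init_txagent()  # factory defined in app.py above
    return _agent

With this variant, `respond` would start with `agent = get_agent()` and stream from `agent.run_gradio_chat(...)` exactly as in the diff: the Space begins serving immediately, and the first query pays the one-time model-loading cost.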