Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
-
import os
|
3 |
import logging
|
|
|
4 |
from txagent import TxAgent
|
5 |
from tooluniverse import ToolUniverse
|
6 |
from importlib.resources import files
|
@@ -9,52 +9,54 @@ from importlib.resources import files
|
|
9 |
logging.basicConfig(level=logging.INFO)
|
10 |
logger = logging.getLogger(__name__)
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
class TxAgentApp:
|
13 |
def __init__(self):
|
14 |
-
self.agent =
|
15 |
-
|
16 |
-
def _initialize_agent(self):
|
17 |
-
"""Initialize the TxAgent with proper tool file paths"""
|
18 |
-
try:
|
19 |
-
logger.info("Initializing TxAgent...")
|
20 |
-
|
21 |
-
# Get absolute paths to tool files from package installation
|
22 |
-
tool_files = {
|
23 |
-
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
|
24 |
-
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
|
25 |
-
"special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
|
26 |
-
"monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
|
27 |
-
}
|
28 |
-
|
29 |
-
logger.info(f"Using tool files at: {tool_files}")
|
30 |
-
|
31 |
-
agent = TxAgent(
|
32 |
-
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
33 |
-
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
34 |
-
tool_files_dict=tool_files,
|
35 |
-
enable_finish=True,
|
36 |
-
enable_rag=True,
|
37 |
-
enable_summary=False,
|
38 |
-
init_rag_num=0,
|
39 |
-
step_rag_num=10,
|
40 |
-
summary_mode='step',
|
41 |
-
summary_skip_last_k=0,
|
42 |
-
summary_context_length=None,
|
43 |
-
force_finish=True,
|
44 |
-
avoid_repeat=True,
|
45 |
-
seed=42,
|
46 |
-
enable_checker=True,
|
47 |
-
enable_chat=False,
|
48 |
-
additional_default_tools=["DirectResponse", "RequireClarification"]
|
49 |
-
)
|
50 |
-
|
51 |
-
agent.init_model()
|
52 |
-
logger.info("Model loading complete")
|
53 |
-
return agent
|
54 |
-
|
55 |
-
except Exception as e:
|
56 |
-
logger.error(f"Initialization failed: {str(e)}")
|
57 |
-
raise
|
58 |
|
59 |
def respond(self, message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
|
60 |
"""Handle streaming responses with Gradio"""
|
@@ -90,10 +92,10 @@ class TxAgentApp:
|
|
90 |
logger.error(f"Error in respond function: {str(e)}")
|
91 |
yield chat_history + [("", f"⚠️ Error: {str(e)}")]
|
92 |
|
93 |
-
# Initialize
|
94 |
tx_app = TxAgentApp()
|
95 |
|
96 |
-
# Define Gradio UI
|
97 |
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
|
98 |
gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
|
99 |
|
@@ -134,4 +136,4 @@ with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
|
|
134 |
chatbot
|
135 |
)
|
136 |
|
137 |
-
# `app` will be
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import logging
|
3 |
+
import multiprocessing
|
4 |
from txagent import TxAgent
|
5 |
from tooluniverse import ToolUniverse
|
6 |
from importlib.resources import files
|
|
|
9 |
logging.basicConfig(level=logging.INFO)
|
10 |
logger = logging.getLogger(__name__)
|
11 |
|
12 |
+
tx_app = None # Global holder for app instance (for Gradio to use)
|
13 |
+
|
14 |
+
def init_txagent():
|
15 |
+
"""Initialize the TxAgent with proper tool file paths"""
|
16 |
+
try:
|
17 |
+
multiprocessing.set_start_method("spawn", force=True)
|
18 |
+
logger.info("Initializing TxAgent...")
|
19 |
+
|
20 |
+
tool_files = {
|
21 |
+
"opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
|
22 |
+
"fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
|
23 |
+
"special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
|
24 |
+
"monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
|
25 |
+
}
|
26 |
+
|
27 |
+
logger.info(f"Using tool files at: {tool_files}")
|
28 |
+
|
29 |
+
agent = TxAgent(
|
30 |
+
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
31 |
+
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
32 |
+
tool_files_dict=tool_files,
|
33 |
+
enable_finish=True,
|
34 |
+
enable_rag=True,
|
35 |
+
enable_summary=False,
|
36 |
+
init_rag_num=0,
|
37 |
+
step_rag_num=10,
|
38 |
+
summary_mode='step',
|
39 |
+
summary_skip_last_k=0,
|
40 |
+
summary_context_length=None,
|
41 |
+
force_finish=True,
|
42 |
+
avoid_repeat=True,
|
43 |
+
seed=42,
|
44 |
+
enable_checker=True,
|
45 |
+
enable_chat=False,
|
46 |
+
additional_default_tools=["DirectResponse", "RequireClarification"]
|
47 |
+
)
|
48 |
+
|
49 |
+
agent.init_model()
|
50 |
+
logger.info("Model loading complete")
|
51 |
+
return agent
|
52 |
+
|
53 |
+
except Exception as e:
|
54 |
+
logger.error(f"Initialization failed: {str(e)}")
|
55 |
+
raise
|
56 |
+
|
57 |
class TxAgentApp:
|
58 |
def __init__(self):
|
59 |
+
self.agent = init_txagent()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
def respond(self, message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
|
62 |
"""Handle streaming responses with Gradio"""
|
|
|
92 |
logger.error(f"Error in respond function: {str(e)}")
|
93 |
yield chat_history + [("", f"⚠️ Error: {str(e)}")]
|
94 |
|
95 |
+
# Initialize the agent safely
|
96 |
tx_app = TxAgentApp()
|
97 |
|
98 |
+
# Define Gradio UI interface
|
99 |
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
|
100 |
gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
|
101 |
|
|
|
136 |
chatbot
|
137 |
)
|
138 |
|
139 |
+
# This `app` will be served by Hugging Face automatically
|