import gradio as gr
import logging
import os
import multiprocessing
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
tx_app = None
TOOL_CACHE_PATH = "/home/user/.cache/tool_embeddings_done"
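# Sentinel file: its presence records that the expensive tool-embedding pass
# already completed once, so later startups can skip it (see the cache check
# in the __main__ block below).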
# Chatbot response function
def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
    global tx_app
    if tx_app is None:
        # respond is a generator (it yields below), so early exits must yield, not return a value
        yield chat_history + [{"role": "assistant", "content": "⚠️ Model is still loading. Please wait a few seconds and try again."}]
        return
    try:
        if not isinstance(message, str) or len(message.strip()) < 10:
            yield chat_history + [{"role": "assistant", "content": "Please enter a longer message."}]
            return
        # run_gradio_chat expects (role, content) tuples; convert the messages-format history
        history = [(h["role"], h["content"]) for h in chat_history if isinstance(h, dict) and "role" in h and "content" in h]
        response = ""
        for chunk in tx_app.run_gradio_chat(
            message=message.strip(),
            history=history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            max_token=max_tokens,
            call_agent=multi_agent,
            conversation=conversation_state,
            max_round=max_round,
            seed=42,
        ):
            if isinstance(chunk, dict):
                response += chunk.get("content", "")
            elif isinstance(chunk, str):
                response += chunk
            else:
                response += str(chunk)
            # Stream the partial response back in the messages format the Chatbot expects
            yield chat_history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response},
            ]
    except Exception as e:
        logger.error(f"Respond error: {e}")
        yield chat_history + [{"role": "assistant", "content": f"⚠️ Error: {e}"}]

# === Gradio UI ===
with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
    gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
    chatbot = gr.Chatbot(label="Conversation", height=600, type="messages")
    msg = gr.Textbox(label="Your medical query", placeholder="Type your biomedical question...", lines=3)
    with gr.Row():
        temp = gr.Slider(0, 1, value=0.3, label="Temperature")
        max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
        max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
        max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
    multi_agent = gr.Checkbox(label="Multi-Agent Mode")
    conversation_state = gr.State([])
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    submit.click(
        respond,
        [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
        chatbot,
    )
    clear.click(lambda: [], None, chatbot)
    msg.submit(
        respond,
        [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
        chatbot,
    )
# === Safe model initialization ===
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)

    import tooluniverse
    from txagent import TxAgent
    from importlib.resources import files

    # ✅ Patch ToolUniverse to prevent exit() after embedding
    original_infer = tooluniverse.ToolUniverse.infer_tool_embeddings

    def patched_infer(self, *args, **kwargs):
        # Wrap the original method so the process keeps running after embeddings are built
        result = original_infer(self, *args, **kwargs)
        print("✅ Patched: Skipping forced exit() after embedding.")
        return result

    tooluniverse.ToolUniverse.infer_tool_embeddings = patched_infer
    logger.info("🏥 Initializing TxAgent...")
    tool_files = {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
    }
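    # init_rag_num=0 with step_rag_num=10 presumably defers tool retrieval to
    # each reasoning step rather than seeding the first prompt with tools; this
    # reading of the TxAgent parameters is an assumption, not documented here.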
    tx_app = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=tool_files,
        enable_finish=True,
        enable_rag=True,
        enable_summary=False,
        init_rag_num=0,
        step_rag_num=10,
        summary_mode='step',
        summary_skip_last_k=0,
        summary_context_length=None,
        force_finish=True,
        avoid_repeat=True,
        seed=42,
        enable_checker=True,
        enable_chat=False,
        additional_default_tools=["DirectResponse", "RequireClarification"],
    )
    # Run the full tool-embedding pass once, then write a sentinel so restarts skip it
    if not os.path.exists(TOOL_CACHE_PATH):
        tx_app.init_model()
        os.makedirs(os.path.dirname(TOOL_CACHE_PATH), exist_ok=True)
        with open(TOOL_CACHE_PATH, "w") as f:
            f.write("done")
    else:
        tx_app.init_model(skip_tool_embedding=True)

    logger.info("✅ TxAgent ready.")
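    # Launch the UI; without this call the Blocks app is built but never served.
    # A minimal sketch: queue() enables streaming of the generator's partial
    # responses, and server_name/server_port are assumptions for a containerized
    # deployment (drop them for local runs).
    app.queue().launch(server_name="0.0.0.0", server_port=7860)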