Ali2206 committed on
Commit
458d2c3
·
verified ·
1 Parent(s): 1c98688

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import gradio as gr
2
  import logging
 
3
 
4
  logging.basicConfig(level=logging.INFO)
5
  logger = logging.getLogger(__name__)
6
 
7
- tx_app = None # global agent
 
8
 
9
  def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
10
  global tx_app
@@ -42,7 +44,7 @@ def respond(message, chat_history, temperature, max_new_tokens, max_tokens, mult
42
  logger.error(f"Respond error: {e}")
43
  yield chat_history + [("", f"⚠️ Error: {e}")]
44
 
45
- # Define Gradio app at module level so Hugging Face Spaces can find it
46
  with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
47
  gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
48
 
@@ -72,16 +74,15 @@ with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
72
  chatbot
73
  )
74
 
75
- # 🔥 Safely initialize vLLM inside __main__
76
  if __name__ == "__main__":
77
  import multiprocessing
78
  multiprocessing.set_start_method("spawn", force=True)
79
 
80
- import torch
81
  from txagent import TxAgent
82
  from importlib.resources import files
83
 
84
- logger.info("🔥 Initializing TxAgent safely in __main__")
85
 
86
  tool_files = {
87
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
@@ -90,6 +91,7 @@ if __name__ == "__main__":
90
  "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
91
  }
92
 
 
93
  tx_app = TxAgent(
94
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
95
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -110,5 +112,15 @@ if __name__ == "__main__":
110
  additional_default_tools=["DirectResponse", "RequireClarification"]
111
  )
112
 
113
- tx_app.init_model()
114
- logger.info("✅ TxAgent initialized.")
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import logging
3
+ import os
4
 
5
  logging.basicConfig(level=logging.INFO)
6
  logger = logging.getLogger(__name__)
7
 
8
+ tx_app = None
9
+ TOOL_CACHE_PATH = "/home/user/.cache/tool_embeddings_done" # flag file for skip
10
 
11
  def respond(message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
12
  global tx_app
 
44
  logger.error(f"Respond error: {e}")
45
  yield chat_history + [("", f"⚠️ Error: {e}")]
46
 
47
+ # === Define Gradio interface ===
48
  with gr.Blocks(title="TxAgent Biomedical Assistant") as app:
49
  gr.Markdown("# 🧠 TxAgent Biomedical Assistant")
50
 
 
74
  chatbot
75
  )
76
 
77
+ # === Safe model init block for vLLM + Hugging Face ===
78
  if __name__ == "__main__":
79
  import multiprocessing
80
  multiprocessing.set_start_method("spawn", force=True)
81
 
 
82
  from txagent import TxAgent
83
  from importlib.resources import files
84
 
85
+ logger.info("🔥 Initializing TxAgent inside __main__...")
86
 
87
  tool_files = {
88
  "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
 
91
  "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json'))
92
  }
93
 
94
+ # Initialize agent
95
  tx_app = TxAgent(
96
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
97
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
112
  additional_default_tools=["DirectResponse", "RequireClarification"]
113
  )
114
 
115
+ # ✅ Only do tool embedding the first time
116
+ if not os.path.exists(TOOL_CACHE_PATH):
117
+ logger.info("🔧 First run: running full model + embedding")
118
+ tx_app.init_model() # runs full setup
119
+ os.makedirs(os.path.dirname(TOOL_CACHE_PATH), exist_ok=True)
120
+ with open(TOOL_CACHE_PATH, "w") as f:
121
+ f.write("done")
122
+ else:
123
+ logger.info("⚡️ Skipping tool embedding (cached)...")
124
+ tx_app.init_model(skip_tool_embedding=True) # assumes this param is supported
125
+
126
+ logger.info("✅ TxAgent is ready!")