Ali2206 commited on
Commit
167b103
·
verified ·
1 Parent(s): 8662842

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -85
app.py CHANGED
@@ -1,47 +1,54 @@
1
  import os
 
 
2
  import torch
3
- import requests
4
- from huggingface_hub import hf_hub_download, snapshot_download
5
  from txagent import TxAgent
6
  import gradio as gr
 
 
7
 
8
  # Configuration
9
  CONFIG = {
10
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
11
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
12
- "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_e27fb393f3144ec28f620f33d4d79911.pt",
13
  "local_dir": "./models",
14
  "tool_files": {
15
- 'new_tool': './data/new_tool.json',
16
- 'opentarget': './data/opentarget_tools.json',
17
- 'fda_drug_label': './data/fda_drug_labeling_tools.json',
18
- 'special_tools': './data/special_tools.json',
19
- 'monarch': './data/monarch_tools.json'
20
  }
21
  }
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def download_model_files():
24
- """Download all required model files from Hugging Face Hub"""
25
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
26
- os.makedirs("./data", exist_ok=True)
27
-
28
  print("Downloading model files...")
29
-
30
- # Download main model
31
  snapshot_download(
32
  repo_id=CONFIG["model_name"],
33
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
34
  resume_download=True
35
  )
36
-
37
- # Download RAG model
38
  snapshot_download(
39
  repo_id=CONFIG["rag_model_name"],
40
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
41
  resume_download=True
42
  )
43
-
44
- # Try to download the embeddings file
45
  try:
46
  hf_hub_download(
47
  repo_id=CONFIG["rag_model_name"],
@@ -55,30 +62,20 @@ def download_model_files():
55
  print("Will attempt to generate it instead")
56
 
57
  def generate_embeddings(agent):
58
- """Generate and save tool embeddings if missing"""
59
  embedding_path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])
60
-
61
  if os.path.exists(embedding_path):
62
  print("Embeddings file already exists")
63
  return
64
-
65
  print("Generating missing tool embeddings...")
66
-
67
  try:
68
- # Get all tools from the tool universe
69
  tools = agent.tooluniverse.get_all_tools()
70
- tool_descriptions = [tool['description'] for tool in tools]
71
-
72
- # Generate embeddings using the RAG model
73
- embeddings = agent.rag_model.generate_embeddings(tool_descriptions)
74
-
75
- # Save the embeddings
76
  torch.save(embeddings, embedding_path)
77
- print(f"Embeddings saved to {embedding_path}")
78
-
79
- # Update the RAG model to use the new embeddings
80
  agent.rag_model.tool_desc_embedding = embeddings
81
-
82
  except Exception as e:
83
  print(f"Failed to generate embeddings: {e}")
84
  raise
@@ -91,9 +88,8 @@ class TxAgentApp:
91
  def initialize(self):
92
  if self.is_initialized:
93
  return "Already initialized"
94
-
95
  try:
96
- # Initialize the agent
97
  self.agent = TxAgent(
98
  CONFIG["model_name"],
99
  CONFIG["rag_model_name"],
@@ -102,37 +98,24 @@ class TxAgentApp:
102
  enable_checker=True,
103
  step_rag_num=10,
104
  seed=100,
105
- additional_default_tools=['DirectResponse', 'RequireClarification']
106
  )
107
-
108
- # Initialize model
109
  self.agent.init_model()
110
-
111
- # Handle embeddings
112
  generate_embeddings(self.agent)
113
-
114
  self.is_initialized = True
115
- return "TxAgent initialized successfully"
116
-
117
  except Exception as e:
118
- return f"Initialization failed: {str(e)}"
119
 
120
  def chat(self, message, history):
121
  if not self.is_initialized:
122
- return history + [(message, "Error: Please initialize the model first")]
123
-
124
  try:
125
- # Convert history to messages format
126
- messages = []
127
- for user_msg, bot_msg in history:
128
- messages.append({"role": "user", "content": user_msg})
129
- messages.append({"role": "assistant", "content": bot_msg})
130
- messages.append({"role": "user", "content": message})
131
-
132
- # Get response
133
  response = ""
134
  for chunk in self.agent.run_gradio_chat(
135
- messages,
 
136
  temperature=0.3,
137
  max_new_tokens=1024,
138
  max_tokens=8192,
@@ -141,28 +124,24 @@ class TxAgentApp:
141
  max_round=30
142
  ):
143
  response += chunk
144
-
145
  return history + [(message, response)]
146
  except Exception as e:
147
  return history + [(message, f"Error: {str(e)}")]
148
 
149
  def create_interface():
150
  app = TxAgentApp()
151
-
152
  with gr.Blocks(title="TxAgent") as demo:
153
- gr.Markdown("# TxAgent: Therapeutic Reasoning AI")
154
-
155
- # Initialization
156
  with gr.Row():
157
  init_btn = gr.Button("Initialize Model", variant="primary")
158
  init_status = gr.Textbox(label="Initialization Status")
159
-
160
- # Chat interface
161
- chatbot = gr.Chatbot(height=600)
162
  msg = gr.Textbox(label="Your Question")
163
  submit_btn = gr.Button("Submit")
164
-
165
- # Examples
166
  gr.Examples(
167
  examples=[
168
  "How to adjust Journavx dosage for hepatic impairment?",
@@ -171,29 +150,15 @@ def create_interface():
171
  ],
172
  inputs=msg
173
  )
174
-
175
- # Event handlers
176
- init_btn.click(
177
- app.initialize,
178
- outputs=init_status
179
- )
180
-
181
- def respond(message, chat_history):
182
- return app.chat(message, chat_history)
183
-
184
- msg.submit(respond, [msg, chatbot], chatbot)
185
- submit_btn.click(respond, [msg, chatbot], chatbot)
186
-
187
  return demo
188
 
189
  if __name__ == "__main__":
190
- # First download all required files
191
  download_model_files()
192
-
193
- # Then create and launch the interface
194
  interface = create_interface()
195
- interface.launch(
196
- server_name="0.0.0.0",
197
- server_port=7860,
198
- share=True
199
- )
 
1
  import os
2
+ import json
3
+ import logging
4
  import torch
 
 
5
  from txagent import TxAgent
6
  import gradio as gr
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+ from tooluniverse import ToolUniverse
9
 
10
  # Configuration
11
  CONFIG = {
12
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
13
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
14
+ "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding.pt",
15
  "local_dir": "./models",
16
  "tool_files": {
17
+ "new_tool": "./data/new_tool.json"
 
 
 
 
18
  }
19
  }
20
 
21
+ # Logging setup
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ def prepare_tool_files():
26
+ os.makedirs("./data", exist_ok=True)
27
+ if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
28
+ logger.info("Generating tool list using ToolUniverse...")
29
+ tu = ToolUniverse()
30
+ tools = tu.get_all_tools()
31
+ with open(CONFIG["tool_files"]["new_tool"], "w") as f:
32
+ json.dump(tools, f, indent=2)
33
+ logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
34
+
35
+
36
  def download_model_files():
 
37
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
 
 
38
  print("Downloading model files...")
39
+
 
40
  snapshot_download(
41
  repo_id=CONFIG["model_name"],
42
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
43
  resume_download=True
44
  )
45
+
 
46
  snapshot_download(
47
  repo_id=CONFIG["rag_model_name"],
48
  local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
49
  resume_download=True
50
  )
51
+
 
52
  try:
53
  hf_hub_download(
54
  repo_id=CONFIG["rag_model_name"],
 
62
  print("Will attempt to generate it instead")
63
 
64
  def generate_embeddings(agent):
 
65
  embedding_path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])
66
+
67
  if os.path.exists(embedding_path):
68
  print("Embeddings file already exists")
69
  return
70
+
71
  print("Generating missing tool embeddings...")
 
72
  try:
 
73
  tools = agent.tooluniverse.get_all_tools()
74
+ descriptions = [tool["description"] for tool in tools]
75
+ embeddings = agent.rag_model.generate_embeddings(descriptions)
 
 
 
 
76
  torch.save(embeddings, embedding_path)
 
 
 
77
  agent.rag_model.tool_desc_embedding = embeddings
78
+ print(f"Embeddings saved to {embedding_path}")
79
  except Exception as e:
80
  print(f"Failed to generate embeddings: {e}")
81
  raise
 
88
  def initialize(self):
89
  if self.is_initialized:
90
  return "Already initialized"
91
+
92
  try:
 
93
  self.agent = TxAgent(
94
  CONFIG["model_name"],
95
  CONFIG["rag_model_name"],
 
98
  enable_checker=True,
99
  step_rag_num=10,
100
  seed=100,
101
+ additional_default_tools=["DirectResponse", "RequireClarification"]
102
  )
 
 
103
  self.agent.init_model()
 
 
104
  generate_embeddings(self.agent)
 
105
  self.is_initialized = True
106
+ return "TxAgent initialized successfully"
 
107
  except Exception as e:
108
+ return f"Initialization failed: {str(e)}"
109
 
110
  def chat(self, message, history):
111
  if not self.is_initialized:
112
+ return history + [(message, "⚠️ Error: Model not initialized")]
113
+
114
  try:
 
 
 
 
 
 
 
 
115
  response = ""
116
  for chunk in self.agent.run_gradio_chat(
117
+ message=message,
118
+ history=history,
119
  temperature=0.3,
120
  max_new_tokens=1024,
121
  max_tokens=8192,
 
124
  max_round=30
125
  ):
126
  response += chunk
127
+
128
  return history + [(message, response)]
129
  except Exception as e:
130
  return history + [(message, f"Error: {str(e)}")]
131
 
132
  def create_interface():
133
  app = TxAgentApp()
 
134
  with gr.Blocks(title="TxAgent") as demo:
135
+ gr.Markdown("# 🧠 TxAgent: Therapeutic Reasoning AI")
136
+
 
137
  with gr.Row():
138
  init_btn = gr.Button("Initialize Model", variant="primary")
139
  init_status = gr.Textbox(label="Initialization Status")
140
+
141
+ chatbot = gr.Chatbot(height=600, label="Conversation")
 
142
  msg = gr.Textbox(label="Your Question")
143
  submit_btn = gr.Button("Submit")
144
+
 
145
  gr.Examples(
146
  examples=[
147
  "How to adjust Journavx dosage for hepatic impairment?",
 
150
  ],
151
  inputs=msg
152
  )
153
+
154
+ init_btn.click(fn=app.initialize, outputs=init_status)
155
+ msg.submit(fn=app.chat, inputs=[msg, chatbot], outputs=chatbot)
156
+ submit_btn.click(fn=app.chat, inputs=[msg, chatbot], outputs=chatbot)
157
+
 
 
 
 
 
 
 
 
158
  return demo
159
 
160
  if __name__ == "__main__":
161
+ prepare_tool_files()
162
  download_model_files()
 
 
163
  interface = create_interface()
164
+ interface.launch(server_name="0.0.0.0", server_port=7860, share=False)