Ali2206 committed on
Commit
dc06321
·
verified ·
1 Parent(s): 66e2fa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -131
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  import logging
5
  import numpy
6
  import gradio as gr
7
- import torch.serialization
8
  from importlib.resources import files
9
  from txagent import TxAgent
10
  from tooluniverse import ToolUniverse
@@ -21,14 +20,7 @@ os.environ["MKL_THREADING_LAYER"] = "GNU"
21
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
  current_dir = os.path.dirname(os.path.abspath(__file__))
23
 
24
- # Allow loading old numpy types with torch.load
25
- torch.serialization.add_safe_globals([
26
- numpy.core.multiarray._reconstruct,
27
- numpy.ndarray,
28
- numpy.dtype,
29
- numpy.dtypes.Float32DType
30
- ])
31
-
32
  CONFIG = {
33
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
34
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -41,131 +33,135 @@ CONFIG = {
41
  }
42
  }
43
 
44
- def prepare_tool_files():
45
- """Prepare the tool files directory and create new_tool.json if it doesn't exist."""
46
- try:
47
- os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
48
- if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
49
- logger.info("Creating new_tool.json...")
50
- tu = ToolUniverse()
51
- tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
52
- with open(CONFIG["tool_files"]["new_tool"], "w") as f:
53
- json.dump(tools, f, indent=2)
54
- except Exception as e:
55
- logger.error(f"Failed to prepare tool files: {e}")
56
- raise
57
-
58
def create_agent():
    """Construct, initialize, and return a ready-to-use TxAgent instance.

    Prepares the tool files first, then builds the agent from CONFIG and loads
    its model weights. Raises: re-raises any exception after logging it.
    """
    try:
        prepare_tool_files()
        logger.info("Initializing TxAgent...")

        tx_agent = TxAgent(
            model_name=CONFIG["model_name"],
            rag_model_name=CONFIG["rag_model_name"],
            tool_files_dict=CONFIG["tool_files"],
            force_finish=True,
            enable_checker=True,
            step_rag_num=10,
            seed=42,  # fixed seed for reproducible tool-selection behavior
            additional_default_tools=["DirectResponse", "RequireClarification"],
        )

        logger.info("Initializing model...")
        tx_agent.init_model()
        logger.info("Agent initialization complete.")
        return tx_agent
    except Exception as exc:
        logger.error(f"Failed to create agent: {exc}")
        raise
82
-
83
def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Append the user's message to chat_history and stream the agent's reply into it.

    Returns the updated chat_history (list of {"role", "content"} dicts). Messages
    of 10 characters or fewer are rejected without contacting the agent. Any
    exception is caught, logged, and surfaced to the user as an assistant message.
    """
    try:
        # Guard clause: reject non-string or too-short input up front.
        if not isinstance(msg, str) or len(msg.strip()) <= 10:
            rejection = {"role": "assistant", "content": "Please provide a valid message longer than 10 characters."}
            return chat_history + [rejection]

        message = msg.strip()
        chat_history.append({"role": "user", "content": message})
        # Convert dict-style history to (role, content) tuples for the agent API.
        formatted_history = [
            (entry["role"], entry["content"])
            for entry in chat_history
            if "role" in entry and "content" in entry
        ]

        logger.info(f"Processing message: {message[:100]}...")

        # NOTE(review): `agent` is read from module scope here — confirm it is
        # assigned at module level before the UI invokes this callback.
        stream = agent.run_gradio_chat(
            message=message,
            history=formatted_history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            max_token=max_tokens,
            call_agent=multi_agent,
            conversation=conversation,
            max_round=max_round,
            seed=42,
        )

        # Accumulate streamed chunks; chunks may be dicts, strings, or other objects.
        pieces = []
        for chunk in stream:
            if isinstance(chunk, dict) and "content" in chunk:
                pieces.append(chunk["content"])
            elif isinstance(chunk, str):
                pieces.append(chunk)
            elif chunk is not None:
                pieces.append(str(chunk))
        collected = "".join(pieces)

        chat_history.append({"role": "assistant", "content": collected or "No response generated."})
        return chat_history

    except Exception as e:
        logger.error(f"Error in respond function: {e}")
        chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
        return chat_history
123
-
124
def create_demo(agent):
    """Build and return the Gradio Blocks interface for the TxAgent chat UI.

    NOTE(review): the `agent` parameter is not referenced inside this function —
    the submit handler uses the module-level `respond` — confirm this is intended.
    """
    with gr.Blocks(title="TxAgent", css=".gr-button { font-size: 18px !important; }") as demo:
        gr.Markdown("# TxAgent - Biomedical AI Assistant")

        with gr.Row():
            # Left column: conversation view and input box.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    avatar_images=(
                        "https://example.com/user.png",  # Replace with actual paths
                        "https://example.com/bot.png"
                    ),
                    height=600,
                )
                msg = gr.Textbox(
                    label="Your question",
                    placeholder="Ask a biomedical question...",
                    lines=3,
                )
                submit = gr.Button("Ask", variant="primary")

            # Right column: generation controls.
            with gr.Column(scale=1):
                temp = gr.Slider(0, 1, value=0.3, label="Temperature")
                max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max New Tokens")
                max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                multi_agent = gr.Checkbox(label="Multi-Agent Mode", value=False)
                clear_btn = gr.Button("Clear Chat")

        submit.click(
            respond,
            inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
            outputs=[chatbot],
        )
        clear_btn.click(lambda: [], None, chatbot, queue=False)

    return demo
 
 
 
 
162
 
163
  def main():
164
- """Main entry point for the application."""
165
  try:
166
- logger.info("Starting application initialization...")
167
- agent = create_agent()
168
- demo = create_demo(agent)
169
 
170
  logger.info("Launching Gradio interface...")
171
  demo.launch(
 
4
  import logging
5
  import numpy
6
  import gradio as gr
 
7
  from importlib.resources import files
8
  from txagent import TxAgent
9
  from tooluniverse import ToolUniverse
 
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
  current_dir = os.path.dirname(os.path.abspath(__file__))
22
 
23
+ # Configuration
 
 
 
 
 
 
 
24
  CONFIG = {
25
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
26
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
33
  }
34
  }
35
 
36
class TxAgentApp:
    """Wraps a TxAgent instance plus the Gradio UI that drives it.

    Construction eagerly initializes the agent (including model weights); any
    initialization failure is logged and re-raised.
    """

    def __init__(self):
        # self.agent holds the TxAgent once initialize_agent succeeds.
        self.agent = None
        self.initialize_agent()

    def initialize_agent(self):
        """Build the TxAgent from CONFIG and load its model, with error logging."""
        try:
            self.prepare_tool_files()
            logger.info("Initializing TxAgent...")

            self.agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=42,  # fixed seed for reproducible behavior
                additional_default_tools=["DirectResponse", "RequireClarification"],
            )

            logger.info("Initializing model...")
            self.agent.init_model()
            logger.info("Agent initialization complete")

        except Exception as exc:
            logger.error(f"Failed to initialize agent: {exc}")
            raise

    def prepare_tool_files(self):
        """Ensure the data directory exists and write new_tool.json if missing."""
        try:
            os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
            new_tool_path = CONFIG["tool_files"]["new_tool"]
            if os.path.exists(new_tool_path):
                return  # already generated on a previous run
            logger.info("Creating new_tool.json...")
            universe = ToolUniverse()
            # Prefer the public accessor; fall back to the raw attribute for
            # older ToolUniverse versions.
            if hasattr(universe, "get_all_tools"):
                tools = universe.get_all_tools()
            else:
                tools = getattr(universe, "tools", [])
            with open(new_tool_path, "w") as fh:
                json.dump(tools, fh, indent=2)
        except Exception as exc:
            logger.error(f"Failed to prepare tool files: {exc}")
            raise

    def respond(self, msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
        """Append the user's message to chat_history and stream the agent's reply.

        Returns the updated chat_history (list of {"role", "content"} dicts).
        Messages of 10 characters or fewer are rejected without contacting the
        agent; any exception becomes an assistant-visible error message.
        """
        try:
            # Guard clause: reject non-string or too-short input up front.
            if not isinstance(msg, str) or len(msg.strip()) <= 10:
                rejection = {"role": "assistant", "content": "Please provide a valid message longer than 10 characters."}
                return chat_history + [rejection]

            message = msg.strip()
            chat_history.append({"role": "user", "content": message})
            # Convert dict-style history to (role, content) tuples for the agent API.
            formatted_history = [
                (entry["role"], entry["content"])
                for entry in chat_history
                if "role" in entry and "content" in entry
            ]

            logger.info(f"Processing message: {message[:100]}...")

            stream = self.agent.run_gradio_chat(
                message=message,
                history=formatted_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round,
                seed=42,
            )

            # Accumulate streamed chunks; chunks may be dicts, strings, or other objects.
            pieces = []
            for chunk in stream:
                if isinstance(chunk, dict) and "content" in chunk:
                    pieces.append(chunk["content"])
                elif isinstance(chunk, str):
                    pieces.append(chunk)
                elif chunk is not None:
                    pieces.append(str(chunk))
            collected = "".join(pieces)

            chat_history.append({"role": "assistant", "content": collected or "No response generated."})
            return chat_history

        except Exception as e:
            logger.error(f"Error in respond function: {e}")
            chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            return chat_history

    def create_demo(self):
        """Build and return the Gradio Blocks interface bound to this app."""
        with gr.Blocks(title="TxAgent", css=".gr-button { font-size: 18px !important; }") as demo:
            gr.Markdown("# TxAgent - Biomedical AI Assistant")

            with gr.Row():
                # Left column: conversation view and input box.
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Conversation",
                        height=600,
                    )
                    msg = gr.Textbox(
                        label="Your question",
                        placeholder="Ask a biomedical question...",
                        lines=3,
                    )
                    submit = gr.Button("Ask", variant="primary")

                # Right column: generation controls.
                with gr.Column(scale=1):
                    temp = gr.Slider(0, 1, value=0.3, label="Temperature")
                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max New Tokens")
                    max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent = gr.Checkbox(label="Multi-Agent Mode", value=False)
                    clear_btn = gr.Button("Clear Chat")

            submit.click(
                self.respond,
                inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
                outputs=[chatbot],
            )
            clear_btn.click(lambda: [], None, chatbot, queue=False)

            # Add a dummy event to ensure the app stays alive
            demo.load(lambda: None, None, None)

        return demo
158
 
159
  def main():
160
+ """Main entry point for the application"""
161
  try:
162
+ logger.info("Starting TxAgent application...")
163
+ app = TxAgentApp()
164
+ demo = app.create_demo()
165
 
166
  logger.info("Launching Gradio interface...")
167
  demo.launch(