Ali2206 commited on
Commit
79fb3cd
·
verified ·
1 Parent(s): f2ac533

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +334 -222
app.py CHANGED
@@ -1,16 +1,25 @@
 
 
 
 
 
 
1
  import os
2
- import json
3
- import logging
4
  import torch
5
- import gradio as gr
6
- from tooluniverse import ToolUniverse
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
- import warnings
9
- from typing import List, Dict, Any
10
  from importlib.resources import files
 
11
 
12
- # Suppress specific warnings
13
- warnings.filterwarnings("ignore", category=UserWarning)
 
 
 
 
 
 
 
 
14
 
15
  # Configuration
16
  CONFIG = {
@@ -22,27 +31,76 @@ CONFIG = {
22
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
23
  "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
24
  "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
25
- "new_tool": "./data/new_tool.json"
26
  }
27
  }
28
 
29
- # Logging setup
30
- logging.basicConfig(
31
- level=logging.INFO,
32
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
33
- )
34
- logger = logging.getLogger(__name__)
35
 
36
- def prepare_tool_files():
37
- """Ensure tool files exist and are populated"""
38
- os.makedirs("./data", exist_ok=True)
39
- if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
40
- logger.info("Generating tool list using ToolUniverse...")
41
- tu = ToolUniverse()
42
- tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else []
43
- with open(CONFIG["tool_files"]["new_tool"], "w") as f:
44
- json.dump(tools, f, indent=2)
45
- logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def safe_load_embeddings(filepath: str) -> Any:
48
  """Safely load embeddings with proper weights_only handling"""
@@ -59,209 +117,263 @@ def safe_load_embeddings(filepath: str) -> Any:
59
  logger.error(f"Failed to load embeddings even with safe_globals: {str(e)}")
60
  return None
61
 
62
- class TxAgentWrapper:
63
- def __init__(self):
64
- self.model = None
65
- self.tokenizer = None
66
- self.rag_model = None
67
- self.tooluniverse = None
68
- self.is_initialized = False
69
- self.special_tools = ['Finish', 'Tool_RAG', 'DirectResponse', 'RequireClarification']
70
-
71
- def initialize(self) -> str:
72
- """Initialize the model from Hugging Face"""
73
- if self.is_initialized:
74
- return "✅ Already initialized"
75
-
76
- try:
77
- logger.info("Loading models from Hugging Face Hub...")
78
-
79
- # Verify tool files exist
80
- for tool_name, tool_path in CONFIG["tool_files"].items():
81
- if tool_name != "new_tool" and not os.path.exists(tool_path):
82
- raise FileNotFoundError(f"Tool file not found: {tool_path}")
83
-
84
- # Initialize ToolUniverse with verified paths
85
- self.tooluniverse = ToolUniverse(tool_files=CONFIG["tool_files"])
86
- if hasattr(self.tooluniverse, 'load_tools'):
87
- self.tooluniverse.load_tools()
88
- logger.info(f"Loaded {len(self.tooluniverse.tools)} tools")
89
- else:
90
- logger.error("ToolUniverse doesn't have load_tools method")
91
- return "❌ Failed to load tools"
92
-
93
- # Load main model
94
- self.tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
95
- self.model = AutoModelForCausalLM.from_pretrained(
96
- CONFIG["model_name"],
97
- device_map="auto",
98
- torch_dtype=torch.float16
99
- )
100
-
101
- # Load embeddings if file exists
102
- if os.path.exists(CONFIG["embedding_filename"]):
103
- self.rag_model = safe_load_embeddings(CONFIG["embedding_filename"])
104
- if self.rag_model is None:
105
- return "❌ Failed to load embeddings"
106
-
107
- self.is_initialized = True
108
- return "✅ Model initialized successfully"
109
-
110
- except Exception as e:
111
- logger.error(f"Initialization failed: {str(e)}")
112
- return f"❌ Initialization failed: {str(e)}"
113
-
114
- def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
115
- """Handle chat interactions with the model"""
116
- if not self.is_initialized:
117
- return history + [["", "⚠️ Please initialize the model first"]]
118
-
119
- try:
120
- if len(message) <= 10:
121
- return history + [["", "Please provide a more detailed question (at least 10 characters)"]]
122
-
123
- # Prepare tools prompt
124
- tools_prompt = self._prepare_tools_prompt(message)
125
-
126
- # Format conversation
127
- conversation = [
128
- {"role": "system", "content": "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning." + tools_prompt},
129
- *self._format_history(history),
130
- {"role": "user", "content": message}
131
- ]
132
-
133
- # Generate response
134
- inputs = self.tokenizer.apply_chat_template(
135
- conversation,
136
- add_generation_prompt=True,
137
- return_tensors="pt"
138
- ).to(self.model.device)
139
-
140
- outputs = self.model.generate(
141
- inputs,
142
- max_new_tokens=1024,
143
- temperature=0.7,
144
- do_sample=True,
145
- pad_token_id=self.tokenizer.eos_token_id
146
- )
147
-
148
- # Decode and clean response
149
- response = self.tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
150
- response = response.split("[TOOL_CALLS]")[0].strip()
151
-
152
- return history + [[message, response]]
153
-
154
- except Exception as e:
155
- logger.error(f"Chat error: {str(e)}")
156
- return history + [["", f"Error: {str(e)}"]]
157
-
158
- def _prepare_tools_prompt(self, message: str) -> str:
159
- """Prepare the tools prompt section"""
160
- if not hasattr(self.tooluniverse, 'tools'):
161
- return ""
162
-
163
- tools_prompt = "\n\nYou have access to the following tools:\n"
164
- for tool in self.tooluniverse.tools:
165
- if tool['name'] not in self.special_tools:
166
- tools_prompt += f"- {tool['name']}: {tool['description']}\n"
167
-
168
- # Add special tools
169
- tools_prompt += "\nSpecial tools:\n"
170
- tools_prompt += "- Finish: Use when you have the final answer\n"
171
- tools_prompt += "- Tool_RAG: Search for additional tools when needed\n"
172
 
173
- return tools_prompt
174
-
175
- def _format_history(self, history: List[List[str]]) -> List[Dict[str, str]]:
176
- """Format chat history for the model"""
177
- formatted = []
178
- for user_msg, bot_msg in history:
179
- formatted.append({"role": "user", "content": user_msg})
180
- if bot_msg:
181
- formatted.append({"role": "assistant", "content": bot_msg})
182
- return formatted
183
-
184
- def create_interface() -> gr.Blocks:
185
- """Create the Gradio interface"""
186
- agent = TxAgentWrapper()
187
-
188
- with gr.Blocks(
189
- title="TxAgent",
190
- css="""
191
- .gradio-container {max-width: 900px !important}
192
- """
193
- ) as demo:
194
- gr.Markdown("""
195
- # 🧠 TxAgent: Therapeutic Reasoning AI
196
- ### (Loading from Hugging Face Hub)
197
- """)
198
 
199
- with gr.Row():
200
- init_btn = gr.Button("Initialize Model", variant="primary")
201
- init_status = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- chatbot = gr.Chatbot(
204
- height=500,
205
- label="Conversation"
206
- )
207
- msg = gr.Textbox(label="Your clinical question")
208
- clear_btn = gr.Button("Clear Chat")
209
 
210
- gr.Examples(
211
- examples=[
212
- "How to adjust Journavx for renal impairment?",
213
- "Xolremdi and Prozac interaction in WHIM syndrome?",
214
- "Alternative to Warfarin for patient with amiodarone?"
215
- ],
216
- inputs=msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  )
218
-
219
- def wrapper_initialize():
220
- status = agent.initialize()
221
- return status, gr.update(interactive=False)
222
-
223
- init_btn.click(
224
- fn=wrapper_initialize,
225
- outputs=[init_status, init_btn]
226
  )
227
-
228
- msg.submit(
229
- fn=agent.chat,
230
- inputs=[msg, chatbot],
231
- outputs=chatbot
232
- ).then(
233
- lambda: "", # Clear message box
234
- outputs=msg
235
  )
236
-
237
- clear_btn.click(
238
- fn=lambda: ([], ""),
239
- outputs=[chatbot, msg]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  )
241
-
242
- return demo
243
 
244
  if __name__ == "__main__":
245
- try:
246
- logger.info("Starting application...")
247
-
248
- # Verify embedding file exists
249
- if not os.path.exists(CONFIG["embedding_filename"]):
250
- logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
251
- logger.info("Please ensure the file is in the root directory")
252
- else:
253
- logger.info(f"Found embedding file: {CONFIG['embedding_filename']}")
254
-
255
- # Prepare tool files
256
- prepare_tool_files()
257
-
258
- # Launch interface
259
- interface = create_interface()
260
- interface.launch(
261
- server_name="0.0.0.0",
262
- server_port=7860,
263
- share=False
264
- )
265
- except Exception as e:
266
- logger.error(f"Application failed to start: {str(e)}")
267
- raise
 
1
+ import random
2
+ import datetime
3
+ import sys
4
+ from txagent import TxAgent
5
+ import spaces
6
+ import gradio as gr
7
  import os
 
 
8
  import torch
9
+ import logging
 
 
 
 
10
  from importlib.resources import files
11
+ import traceback
12
 
13
+ # Set up logging
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
+ )
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Determine the directory where the current file is located
21
+ current_dir = os.path.dirname(os.path.abspath(__file__))
22
+ os.environ["MKL_THREADING_LAYER"] = "GNU"
23
 
24
  # Configuration
25
  CONFIG = {
 
31
  "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
32
  "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
33
  "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
34
+ "new_tool": os.path.join(current_dir, 'data', 'new_tool.json')
35
  }
36
  }
37
 
38
+ # Set an environment variable
39
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 
 
 
40
 
41
+ DESCRIPTION = '''
42
+ <div>
43
+ <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools </h1>
44
+ </div>
45
+ '''
46
+ INTRO = """
47
+ Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations. We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions, contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions. TxAgent executes goal-oriented tool selection and iterative function calls to solve therapeutic tasks that require deep clinical understanding and cross-source validation. The ToolUniverse consolidates 211 tools linked to trusted sources, including all US FDA-approved drugs since 1939 and validated clinical insights from Open Targets.
48
+ """
49
+
50
+ LICENSE = """
51
+ We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested in collaboration, please email Marinka Zitnik and Shanghua Gao.
52
+
53
+ ### Medical Advice Disclaimer
54
+ DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
55
+ The information, including but not limited to, text, graphics, images and other material contained on this website are for informational purposes only. No material on this site is intended to be a substitute for professional medical advice, diagnosis or treatment. Always seek the advice of your physician or other qualified health care provider with any questions you may have regarding a medical condition or treatment and before undertaking a new health care regimen, and never disregard professional medical advice or delay in seeking it because of something you have read on this website.
56
+ """
57
+
58
+ PLACEHOLDER = """
59
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
60
+ <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
61
+ <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
62
+ <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
63
+ (top-right) to remove previous context before sumbmitting a new question.</p>
64
+ <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
65
+ </div>
66
+ """
67
+
68
+ css = """
69
+ h1 {
70
+ text-align: center;
71
+ display: block;
72
+ }
73
+
74
+ #duplicate-button {
75
+ margin: auto;
76
+ color: white;
77
+ background: #1565c0;
78
+ border-radius: 100vh;
79
+ }
80
+ .small-button button {
81
+ font-size: 12px !important;
82
+ padding: 4px 8px !important;
83
+ height: 6px !important;
84
+ width: 4px !important;
85
+ }
86
+ .gradio-accordion {
87
+ margin-top: 0px !important;
88
+ margin-bottom: 0px !important;
89
+ }
90
+ """
91
+
92
+ chat_css = """
93
+ .gr-button { font-size: 20px !important; } /* Enlarges button icons */
94
+ .gr-button svg { width: 32px !important; height: 32px !important; } /* Enlarges SVG icons */
95
+ """
96
+
97
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
98
+
99
+ question_examples = [
100
+ ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
101
+ ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
102
+ ['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
103
+ ]
104
 
105
  def safe_load_embeddings(filepath: str) -> Any:
106
  """Safely load embeddings with proper weights_only handling"""
 
117
  logger.error(f"Failed to load embeddings even with safe_globals: {str(e)}")
118
  return None
119
 
120
+ def patch_embedding_loading():
121
+ """Monkey-patch the embedding loading functionality"""
122
+ try:
123
+ from txagent.toolrag import ToolRAGModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ original_load = ToolRAGModel.load_tool_desc_embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ def patched_load(self, tooluniverse):
128
+ try:
129
+ if not os.path.exists(CONFIG["embedding_filename"]):
130
+ logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
131
+ return False
132
+
133
+ # Load embeddings safely
134
+ self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
135
+
136
+ # Handle tool count mismatch
137
+ tools = tooluniverse.get_all_tools()
138
+ current_count = len(tools)
139
+ embedding_count = len(self.tool_desc_embedding)
140
+
141
+ if current_count != embedding_count:
142
+ logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
143
+
144
+ if current_count < embedding_count:
145
+ self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
146
+ logger.info(f"Truncated embeddings to match {current_count} tools")
147
+ else:
148
+ last_embedding = self.tool_desc_embedding[-1]
149
+ padding = [last_embedding] * (current_count - embedding_count)
150
+ self.tool_desc_embedding = torch.cat(
151
+ [self.tool_desc_embedding] + padding
152
+ )
153
+ logger.info(f"Padded embeddings to match {current_count} tools")
154
+
155
+ return True
156
+
157
+ except Exception as e:
158
+ logger.error(f"Failed to load embeddings: {str(e)}")
159
+ return False
160
 
161
+ # Apply the patch
162
+ ToolRAGModel.load_tool_desc_embedding = patched_load
163
+ logger.info("Successfully patched embedding loading")
 
 
 
164
 
165
+ except Exception as e:
166
+ logger.error(f"Failed to patch embedding loading: {str(e)}")
167
+ raise
168
+
169
+ def prepare_tool_files():
170
+ """Ensure tool files exist and are populated"""
171
+ os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
172
+ if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
173
+ logger.info("Generating tool list using ToolUniverse...")
174
+ tu = ToolUniverse()
175
+ tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else []
176
+ with open(CONFIG["tool_files"]["new_tool"], "w") as f:
177
+ json.dump(tools, f, indent=2)
178
+ logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
179
+
180
+ # Apply the embedding patch before creating the agent
181
+ patch_embedding_loading()
182
+ prepare_tool_files()
183
+
184
+ # Initialize the agent
185
+ agent = TxAgent(
186
+ CONFIG["model_name"],
187
+ CONFIG["rag_model_name"],
188
+ tool_files_dict=CONFIG["tool_files"],
189
+ force_finish=True,
190
+ enable_checker=True,
191
+ step_rag_num=10,
192
+ seed=100,
193
+ additional_default_tools=['DirectResponse', 'RequireClarification']
194
+ )
195
+ agent.init_model()
196
+
197
+ def update_model_parameters(enable_finish, enable_rag, enable_summary,
198
+ init_rag_num, step_rag_num, skip_last_k,
199
+ summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
200
+ # Update model instance parameters dynamically
201
+ updated_params = agent.update_parameters(
202
+ enable_finish=enable_finish,
203
+ enable_rag=enable_rag,
204
+ enable_summary=enable_summary,
205
+ init_rag_num=init_rag_num,
206
+ step_rag_num=step_rag_num,
207
+ skip_last_k=skip_last_k,
208
+ summary_mode=summary_mode,
209
+ summary_skip_last_k=summary_skip_last_k,
210
+ summary_context_length=summary_context_length,
211
+ force_finish=force_finish,
212
+ seed=seed,
213
+ )
214
+
215
+ return updated_params
216
+
217
+ def update_seed():
218
+ # Update model instance parameters dynamically
219
+ seed = random.randint(0, 10000)
220
+ updated_params = agent.update_parameters(
221
+ seed=seed,
222
+ )
223
+ return updated_params
224
+
225
+ def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
226
+ print("Updated seed:", update_seed())
227
+ new_history = history[:retry_data.index]
228
+ previous_prompt = history[retry_data.index]['content']
229
+
230
+ print("previous_prompt", previous_prompt)
231
+
232
+ yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
233
+
234
+ PASSWORD = "mypassword"
235
+
236
+ def check_password(input_password):
237
+ if input_password == PASSWORD:
238
+ return gr.update(visible=True), ""
239
+ else:
240
+ return gr.update(visible=False), "Incorrect password, try again!"
241
+
242
+ conversation_state = gr.State([])
243
+
244
+ # Gradio block
245
+ chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
246
+ label='TxAgent', type="messages", show_copy_button=True)
247
+
248
+ with gr.Blocks(css=css) as demo:
249
+ gr.Markdown(DESCRIPTION)
250
+ gr.Markdown(INTRO)
251
+ default_temperature = 0.3
252
+ default_max_new_tokens = 1024
253
+ default_max_tokens = 81920
254
+ default_max_round = 30
255
+ temperature_state = gr.State(value=default_temperature)
256
+ max_new_tokens_state = gr.State(value=default_max_new_tokens)
257
+ max_tokens_state = gr.State(value=default_max_tokens)
258
+ max_round_state = gr.State(value=default_max_round)
259
+ chatbot.retry(handle_retry, chatbot, chatbot, temperature_state, max_new_tokens_state,
260
+ max_tokens_state, gr.Checkbox(value=False, render=False), conversation_state, max_round_state)
261
+
262
+ gr.ChatInterface(
263
+ fn=agent.run_gradio_chat,
264
+ chatbot=chatbot,
265
+ fill_height=True, fill_width=True, stop_btn=True,
266
+ additional_inputs_accordion=gr.Accordion(
267
+ label="⚙️ Inference Parameters", open=False, render=False),
268
+ additional_inputs=[
269
+ temperature_state, max_new_tokens_state, max_tokens_state,
270
+ gr.Checkbox(
271
+ label="Activate multi-agent reasoning mode (it requires additional time but offers a more comprehensive analysis).", value=False, render=False),
272
+ conversation_state,
273
+ max_round_state,
274
+ gr.Number(label="Seed", value=100, render=False)
275
+ ],
276
+ examples=question_examples,
277
+ cache_examples=False,
278
+ css=chat_css,
279
+ )
280
+
281
+ with gr.Accordion("Settings", open=False):
282
+ # Define the sliders
283
+ temperature_slider = gr.Slider(
284
+ minimum=0,
285
+ maximum=1,
286
+ step=0.1,
287
+ value=default_temperature,
288
+ label="Temperature"
289
  )
290
+ max_new_tokens_slider = gr.Slider(
291
+ minimum=128,
292
+ maximum=4096,
293
+ step=1,
294
+ value=default_max_new_tokens,
295
+ label="Max new tokens"
 
 
296
  )
297
+ max_tokens_slider = gr.Slider(
298
+ minimum=128,
299
+ maximum=32000,
300
+ step=1,
301
+ value=default_max_tokens,
302
+ label="Max tokens"
 
 
303
  )
304
+ max_round_slider = gr.Slider(
305
+ minimum=0,
306
+ maximum=50,
307
+ step=1,
308
+ value=default_max_round,
309
+ label="Max round")
310
+
311
+ # Automatically update states when slider values change
312
+ temperature_slider.change(
313
+ lambda x: x, inputs=temperature_slider, outputs=temperature_state)
314
+ max_new_tokens_slider.change(
315
+ lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
316
+ max_tokens_slider.change(
317
+ lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
318
+ max_round_slider.change(
319
+ lambda x: x, inputs=max_round_slider, outputs=max_round_state)
320
+
321
+ password_input = gr.Textbox(
322
+ label="Enter Password for More Settings", type="password")
323
+ incorrect_message = gr.Textbox(visible=False, interactive=False)
324
+ with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
325
+ with gr.Row():
326
+ with gr.Column(scale=1):
327
+ with gr.Accordion("⚙️ Model Loading", open=False):
328
+ model_name_input = gr.Textbox(
329
+ label="Enter model path", value=CONFIG["model_name"])
330
+ load_model_btn = gr.Button(value="Load Model")
331
+ load_model_btn.click(
332
+ agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
333
+ with gr.Column(scale=1):
334
+ with gr.Accordion("⚙️ Functional Parameters", open=False):
335
+ # Create Gradio components for parameter inputs
336
+ enable_finish = gr.Checkbox(
337
+ label="Enable Finish", value=True)
338
+ enable_rag = gr.Checkbox(
339
+ label="Enable RAG", value=True)
340
+ enable_summary = gr.Checkbox(
341
+ label="Enable Summary", value=False)
342
+ init_rag_num = gr.Number(
343
+ label="Initial RAG Num", value=0)
344
+ step_rag_num = gr.Number(
345
+ label="Step RAG Num", value=10)
346
+ skip_last_k = gr.Number(label="Skip Last K", value=0)
347
+ summary_mode = gr.Textbox(
348
+ label="Summary Mode", value='step')
349
+ summary_skip_last_k = gr.Number(
350
+ label="Summary Skip Last K", value=0)
351
+ summary_context_length = gr.Number(
352
+ label="Summary Context Length", value=None)
353
+ force_finish = gr.Checkbox(
354
+ label="Force FinalAnswer", value=True)
355
+ seed = gr.Number(label="Seed", value=100)
356
+ # Button to submit and update parameters
357
+ submit_btn = gr.Button("Update Parameters")
358
+
359
+ # Display the updated parameters
360
+ updated_parameters_output = gr.JSON()
361
+
362
+ # When button is clicked, update parameters
363
+ submit_btn.click(fn=update_model_parameters,
364
+ inputs=[enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
365
+ summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed],
366
+ outputs=updated_parameters_output)
367
+ # Button to submit the password
368
+ submit_button = gr.Button("Submit")
369
+
370
+ # When the button is clicked, check if the password is correct
371
+ submit_button.click(
372
+ check_password,
373
+ inputs=password_input,
374
+ outputs=[protected_accordion, incorrect_message]
375
  )
376
+ gr.Markdown(LICENSE)
 
377
 
378
  if __name__ == "__main__":
379
+ demo.launch(share=True)