Ali2206 committed on
Commit
410d25f
·
verified ·
1 Parent(s): 9053cc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -33
app.py CHANGED
@@ -5,6 +5,10 @@ import torch
5
  from txagent import TxAgent
6
  import gradio as gr
7
  from tooluniverse import ToolUniverse
 
 
 
 
8
 
9
  # Configuration with hardcoded embedding file
10
  CONFIG = {
@@ -33,49 +37,59 @@ def prepare_tool_files():
33
  json.dump(tools, f, indent=2)
34
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
35
 
 
 
 
 
 
 
 
 
 
 
36
  def patch_embedding_loading():
37
  """Monkey-patch the embedding loading functionality"""
38
  try:
39
- # Try to get the RAG model class dynamically
40
- from txagent.txagent import TxAgent as TxAgentClass
41
- original_init = TxAgentClass.__init__
42
 
43
- def patched_init(self, *args, **kwargs):
44
- # First let the original initialization happen
45
- original_init(self, *args, **kwargs)
46
-
47
- # Then handle the embeddings our way
48
  try:
49
- if os.path.exists(CONFIG["embedding_filename"]):
50
- logger.info(f"Loading embeddings from {CONFIG['embedding_filename']}")
51
- self.rag_model.tool_desc_embedding = torch.load(CONFIG["embedding_filename"])
52
-
53
- # Handle tool count mismatch
54
- tools = self.tooluniverse.get_all_tools()
55
- current_count = len(tools)
56
- embedding_count = len(self.rag_model.tool_desc_embedding)
57
-
58
- if current_count != embedding_count:
59
- logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
60
-
61
- if current_count < embedding_count:
62
- self.rag_model.tool_desc_embedding = self.rag_model.tool_desc_embedding[:current_count]
63
- logger.info(f"Truncated embeddings to match {current_count} tools")
64
- else:
65
- last_embedding = self.rag_model.tool_desc_embedding[-1]
66
- padding = [last_embedding] * (current_count - embedding_count)
67
- self.rag_model.tool_desc_embedding = torch.cat(
68
- [self.rag_model.tool_desc_embedding] + padding
69
- )
70
- logger.info(f"Padded embeddings to match {current_count} tools")
71
- else:
72
  logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  except Exception as e:
75
  logger.error(f"Failed to load embeddings: {str(e)}")
 
76
 
77
  # Apply the patch
78
- TxAgentClass.__init__ = patched_init
79
  logger.info("Successfully patched embedding loading")
80
 
81
  except Exception as e:
 
5
  from txagent import TxAgent
6
  import gradio as gr
7
  from tooluniverse import ToolUniverse
8
+ import warnings
9
+
10
+ # Suppress all UserWarning messages (e.g. noisy startup warnings from torch/transformers)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
 
13
  # Configuration with hardcoded embedding file
14
  CONFIG = {
 
37
  json.dump(tools, f, indent=2)
38
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
39
 
40
def safe_load_embeddings(filepath):
    """Load a saved tensor of tool-description embeddings from *filepath*.

    Attempts the secure ``weights_only=True`` mode first, which restricts
    deserialization to plain tensors/primitives and cannot execute
    arbitrary pickled code.  Only if that fails does it fall back to the
    legacy full-pickle path.

    Args:
        filepath: Path to a file previously written with ``torch.save``.

    Returns:
        The deserialized object (expected to be a ``torch.Tensor`` of
        tool-description embeddings).

    Raises:
        Exception: Whatever ``torch.load`` raises when even the insecure
            fallback cannot read the file.
    """
    try:
        # Secure mode (PyTorch >= 1.13): refuses arbitrary pickled objects.
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
        # SECURITY: weights_only=False can execute arbitrary code embedded in
        # the pickle stream.  This is acceptable only because the embedding
        # file is produced locally by this app and is never untrusted input.
        return torch.load(filepath, weights_only=False)
49
+
50
  def patch_embedding_loading():
51
  """Monkey-patch the embedding loading functionality"""
52
  try:
53
+ from txagent.toolrag import ToolRAGModel
 
 
54
 
55
+ original_load = ToolRAGModel.load_tool_desc_embedding
56
+
57
+ def patched_load(self, tooluniverse):
 
 
58
  try:
59
+ if not os.path.exists(CONFIG["embedding_filename"]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
61
+ return False
62
+
63
+ # Load embeddings safely
64
+ self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
65
+
66
+ # Handle tool count mismatch
67
+ tools = tooluniverse.get_all_tools()
68
+ current_count = len(tools)
69
+ embedding_count = len(self.tool_desc_embedding)
70
+
71
+ if current_count != embedding_count:
72
+ logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
73
+
74
+ if current_count < embedding_count:
75
+ self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
76
+ logger.info(f"Truncated embeddings to match {current_count} tools")
77
+ else:
78
+ last_embedding = self.tool_desc_embedding[-1]
79
+ padding = [last_embedding] * (current_count - embedding_count)
80
+ self.tool_desc_embedding = torch.cat(
81
+ [self.tool_desc_embedding] + padding
82
+ )
83
+ logger.info(f"Padded embeddings to match {current_count} tools")
84
+
85
+ return True
86
+
87
  except Exception as e:
88
  logger.error(f"Failed to load embeddings: {str(e)}")
89
+ return False
90
 
91
  # Apply the patch
92
+ ToolRAGModel.load_tool_desc_embedding = patched_load
93
  logger.info("Successfully patched embedding loading")
94
 
95
  except Exception as e: