Ali2206 committed on
Commit 3e3b258 (verified)
1 Parent(s): 05b577c

Upload 13 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ img/q1.gif filter=lfs diff=lfs merge=lfs -text
+ img/q2.gif filter=lfs diff=lfs merge=lfs -text
+ img/q3.gif filter=lfs diff=lfs merge=lfs -text
ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d6f7e35367db5296b03cf366f3d276ecd3e08867b70aec24d91258af94a648df
+ size 20132
app.py ADDED
@@ -0,0 +1,133 @@
+ import os
+ import sys
+ import gradio as gr
+ from multiprocessing import freeze_support
+ import importlib
+ import inspect
+ import json
+
+ # Fix path to include src
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
+
+ # Reload TxAgent from txagent.py
+ import txagent.txagent
+ importlib.reload(txagent.txagent)
+ from txagent.txagent import TxAgent
+
+ # Debug info
+ print(">>> TxAgent loaded from:", inspect.getfile(TxAgent))
+ print(">>> TxAgent has run_gradio_chat:", hasattr(TxAgent, "run_gradio_chat"))
+
+ # Env vars
+ current_dir = os.path.abspath(os.path.dirname(__file__))
+ os.environ["MKL_THREADING_LAYER"] = "GNU"
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ # Model config
+ model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
+ rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
+ new_tool_files = {
+     "new_tool": os.path.join(current_dir, "data", "new_tool.json")
+ }
+
+ # Sample questions
+ question_examples = [
+     ["Given a patient with WHIM syndrome on prophylactic antibiotics, is it advisable to co-administer Xolremdi with fluconazole?"],
+     ["What treatment options exist for HER2+ breast cancer resistant to trastuzumab?"]
+ ]
+
+ # Helper: format assistant responses in collapsible panels
+ def format_collapsible(content):
+     if isinstance(content, (dict, list)):
+         try:
+             formatted = json.dumps(content, indent=2)
+         except Exception:
+             formatted = str(content)
+     else:
+         formatted = str(content)
+
+     return (
+         "<details style='border: 1px solid #ccc; padding: 8px; margin-top: 8px;'>"
+         "<summary style='font-weight: bold;'>Answer</summary>"
+         f"<pre style='white-space: pre-wrap;'>{formatted}</pre>"
+         "</details>"
+     )
+
+ # === UI setup
+ def create_ui(agent):
+     with gr.Blocks() as demo:
+         gr.Markdown("<h1 style='text-align: center;'>TxAgent: Therapeutic Reasoning</h1>")
+         gr.Markdown("Ask biomedical or therapeutic questions. Powered by step-by-step reasoning and tools.")
+
+         temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
+         max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
+         max_tokens = gr.Slider(128, 32000, value=8192, label="Max Total Tokens")
+         max_round = gr.Slider(1, 50, value=30, label="Max Rounds")
+         multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
+         conversation_state = gr.State([])
+
+         chatbot = gr.Chatbot(label="TxAgent", height=600, type="messages")
+         message_input = gr.Textbox(placeholder="Ask your biomedical question...", show_label=False)
+         send_button = gr.Button("Send", variant="primary")
+
+         # Main handler
+         def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
+             generator = agent.run_gradio_chat(
+                 message=message,
+                 history=history,
+                 temperature=temperature,
+                 max_new_tokens=max_new_tokens,
+                 max_token=max_tokens,
+                 call_agent=multi_agent,
+                 conversation=conversation,
+                 max_round=max_round
+             )
+
+             for update in generator:
+                 formatted = []
+                 for m in update:
+                     role = m["role"] if isinstance(m, dict) else getattr(m, "role", "assistant")
+                     content = m["content"] if isinstance(m, dict) else getattr(m, "content", "")
+
+                     if role == "assistant":
+                         content = format_collapsible(content)
+
+                     formatted.append({"role": role, "content": content})
+                 yield formatted
+
+         # Button and Enter triggers
+         inputs = [message_input, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round]
+         send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
+         message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)
+
+         gr.Examples(examples=question_examples, inputs=message_input)
+         gr.Markdown("**DISCLAIMER**: This demo is for research purposes only and does not provide medical advice.")
+
+     return demo
+
+ # === Entry point
+ if __name__ == "__main__":
+     freeze_support()
+
+     try:
+         agent = TxAgent(
+             model_name=model_name,
+             rag_model_name=rag_model_name,
+             tool_files_dict=new_tool_files,
+             force_finish=True,
+             enable_checker=True,
+             step_rag_num=10,
+             seed=100,
+             additional_default_tools=[] # Avoid loading unimplemented tools
+         )
+         agent.init_model()
+
+         if not hasattr(agent, "run_gradio_chat"):
+             raise AttributeError("TxAgent missing run_gradio_chat")
+
+         demo = create_ui(agent)
+         demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
+
+     except Exception as e:
+         print(f"❌ App failed to start: {e}")
+         raise
data/new_tool.json ADDED
@@ -0,0 +1 @@
+ []
img/q1.gif ADDED

Git LFS Details

  • SHA256: f0cbda2e1ec46defdae51233c03aee0ddea1ad1f28ad9ed79e4ea72a8f13edf9
  • Pointer size: 132 Bytes
  • Size of remote file: 7.65 MB
img/q2.gif ADDED

Git LFS Details

  • SHA256: a453c339ddcc333e28bc9626b287d9d6fa1554edec7b127611617bcb27b90591
  • Pointer size: 132 Bytes
  • Size of remote file: 6.31 MB
img/q3.gif ADDED

Git LFS Details

  • SHA256: ded44920ea272367247ac4f1a222c0a55932ad6b1173c2b14009a3ec4a79f524
  • Pointer size: 132 Bytes
  • Size of remote file: 8.83 MB
pyproject.toml ADDED
@@ -0,0 +1,3 @@
+ [build-system]
+ requires = ["setuptools", "wheel"]
+ build-backend = "setuptools.build_meta"
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio
+ tooluniverse
+ transformers
+ sentence-transformers
+ torch
+ vllm
+ accelerate
+ scipy
+ huggingface_hub
setup.cfg ADDED
@@ -0,0 +1,22 @@
+ [metadata]
+ name = txagent
+ version = 0.1.2
+ description = TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools
+ author = Shanghua Gao
+ author_email = [email protected]
+ url = https://github.com/mims-harvard/TxAgent
+ long_description = file: README.md
+ long_description_content_type = text/markdown
+
+ [options]
+ packages = find:
+ package_dir =
+     = src
+ python_requires = >=3.6
+ install_requires =
+     gradio
+     vllm
+     sentence_transformers
+
+ [options.packages.find]
+ where = src
src/txagent/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .txagent import TxAgent
+ from .toolrag import ToolRAGModel
+ __all__ = [
+     "TxAgent",
+     "ToolRAGModel",
+ ]
src/txagent/toolrag.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ import json
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from .utils import get_md5
+
+
+ class ToolRAGModel:
+     def __init__(self, rag_model_name):
+         self.rag_model_name = rag_model_name
+         self.rag_model = None
+         self.tool_desc_embedding = None
+         self.tool_name = None
+         self.tool_embedding_path = None
+         self.load_rag_model()
+
+     def load_rag_model(self):
+         self.rag_model = SentenceTransformer(self.rag_model_name)
+         self.rag_model.max_seq_length = 4096
+         self.rag_model.tokenizer.padding_side = "right"
+
+     def load_tool_desc_embedding(self, toolbox):
+         self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
+         all_tools_str = [json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
+         md5_value = get_md5(str(all_tools_str))
+         print("Computed MD5 for tool embedding:", md5_value)
+
+         self.tool_embedding_path = os.path.join(
+             os.path.dirname(__file__),
+             self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt"
+         )
+
+         if os.path.exists(self.tool_embedding_path):
+             try:
+                 self.tool_desc_embedding = torch.load(self.tool_embedding_path, map_location="cpu")
+                 assert len(self.tool_desc_embedding) == len(toolbox.all_tools), \
+                     "Tool count mismatch with loaded embeddings."
+                 print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
+                 return
+             except Exception as e:
+                 print(f"⚠️ Failed loading cached embeddings: {e}")
+                 self.tool_desc_embedding = None
+
+         print("\033[93mGenerating new tool_desc_embedding...\033[0m")
+         self.tool_desc_embedding = self.rag_model.encode(
+             all_tools_str, prompt="", normalize_embeddings=True
+         )
+
+         torch.save(self.tool_desc_embedding, self.tool_embedding_path)
+         print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")
+
+     def rag_infer(self, query, top_k=5):
+         torch.cuda.empty_cache()
+         queries = [query]
+         query_embeddings = self.rag_model.encode(
+             queries, prompt="", normalize_embeddings=True
+         )
+         if self.tool_desc_embedding is None:
+             raise RuntimeError("❌ tool_desc_embedding is not initialized. Did you forget to call load_tool_desc_embedding()?")
+
+         scores = self.rag_model.similarity(
+             query_embeddings, self.tool_desc_embedding
+         )
+         top_k = min(top_k, len(self.tool_name))
+         top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
+         top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
+         return top_k_tool_names
src/txagent/txagent.py ADDED
@@ -0,0 +1,945 @@
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+ import json
5
+ import gc
6
+ import numpy as np
7
+ from vllm import LLM, SamplingParams
8
+ from jinja2 import Template
9
+ from typing import List
10
+ import types
11
+ from tooluniverse import ToolUniverse
12
+ from gradio import ChatMessage
13
+ from .toolrag import ToolRAGModel
14
+ import torch
15
+ # near the top of txagent.py
16
+ import logging
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+ from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
21
+
22
+
23
+ class TxAgent:
24
+ def __init__(self, model_name,
25
+ rag_model_name,
26
+ tool_files_dict=None, # None leads to the default tool files in ToolUniverse
27
+ enable_finish=True,
28
+ enable_rag=True,
29
+ enable_summary=False,
30
+ init_rag_num=0,
31
+ step_rag_num=10,
32
+ summary_mode='step',
33
+ summary_skip_last_k=0,
34
+ summary_context_length=None,
35
+ force_finish=True,
36
+ avoid_repeat=True,
37
+ seed=None,
38
+ enable_checker=False,
39
+ enable_chat=False,
40
+ additional_default_tools=None,
41
+ ):
42
+ self.model_name = model_name
43
+ self.tokenizer = None
44
+ self.terminators = None
45
+ self.rag_model_name = rag_model_name
46
+ self.tool_files_dict = tool_files_dict
47
+ self.model = None
48
+ self.rag_model = ToolRAGModel(rag_model_name)
49
+ self.tooluniverse = None
50
+ # self.tool_desc = None
51
+ self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
52
+ self.self_prompt = "Strictly follow the instruction."
53
+ self.chat_prompt = "You are helpful assistant to chat with the user."
54
+ self.enable_finish = enable_finish
55
+ self.enable_rag = enable_rag
56
+ self.enable_summary = enable_summary
57
+ self.summary_mode = summary_mode
58
+ self.summary_skip_last_k = summary_skip_last_k
59
+ self.summary_context_length = summary_context_length
60
+ self.init_rag_num = init_rag_num
61
+ self.step_rag_num = step_rag_num
62
+ self.force_finish = force_finish
63
+ self.avoid_repeat = avoid_repeat
64
+ self.seed = seed
65
+ self.enable_checker = enable_checker
66
+ self.additional_default_tools = additional_default_tools
67
+ self.print_self_values()
68
+
69
+ def init_model(self):
70
+ self.load_models()
71
+ self.load_tooluniverse()
72
+ self.load_tool_desc_embedding()
73
+
74
+ def print_self_values(self):
75
+ for attr, value in self.__dict__.items():
76
+ print(f"{attr}: {value}")
77
+
78
+ def load_models(self, model_name=None):
79
+ if model_name is not None:
80
+ if model_name == self.model_name:
81
+ return f"The model {model_name} is already loaded."
82
+ self.model_name = model_name
83
+
84
+ self.model = LLM(model=self.model_name)
85
+ self.chat_template = Template(self.model.get_tokenizer().chat_template)
86
+ self.tokenizer = self.model.get_tokenizer()
87
+
88
+ return f"Model {model_name} loaded successfully."
89
+
90
+ def load_tooluniverse(self):
91
+ self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
92
+ self.tooluniverse.load_tools()
93
+ special_tools = self.tooluniverse.prepare_tool_prompts(
94
+ self.tooluniverse.tool_category_dicts["special_tools"])
95
+ self.special_tools_name = [tool['name'] for tool in special_tools]
96
+
97
+ def load_tool_desc_embedding(self):
98
+ self.rag_model.load_tool_desc_embedding(self.tooluniverse)
99
+
100
+ def rag_infer(self, query, top_k=5):
101
+ return self.rag_model.rag_infer(query, top_k)
102
+
103
+ def initialize_tools_prompt(self, call_agent, call_agent_level, message):
104
+ picked_tools_prompt = []
105
+ picked_tools_prompt = self.add_special_tools(
106
+ picked_tools_prompt, call_agent=call_agent)
107
+ if call_agent:
108
+ call_agent_level += 1
109
+ if call_agent_level >= 2:
110
+ call_agent = False
111
+
112
+ if not call_agent:
113
+ picked_tools_prompt += self.tool_RAG(
114
+ message=message, rag_num=self.init_rag_num)
115
+ return picked_tools_prompt, call_agent_level
116
+
117
+ def initialize_conversation(self, message, conversation=None, history=None):
118
+ if conversation is None:
119
+ conversation = []
120
+
121
+ conversation = self.set_system_prompt(
122
+ conversation, self.prompt_multi_step)
123
+ if history is not None:
124
+ if len(history) == 0:
125
+ conversation = []
126
+ print("clear conversation successfully")
127
+ else:
128
+ for i in range(len(history)):
129
+ if history[i]['role'] == 'user':
130
+ if i-1 >= 0 and history[i-1]['role'] == 'assistant':
131
+ conversation.append(
132
+ {"role": "assistant", "content": history[i-1]['content']})
133
+ conversation.append(
134
+ {"role": "user", "content": history[i]['content']})
135
+ if i == len(history)-1 and history[i]['role'] == 'assistant':
136
+ conversation.append(
137
+ {"role": "assistant", "content": history[i]['content']})
138
+
139
+ conversation.append({"role": "user", "content": message})
140
+
141
+ return conversation
142
+
143
+ def tool_RAG(self, message=None,
144
+ picked_tool_names=None,
145
+ existing_tools_prompt=[],
146
+ rag_num=5,
147
+ return_call_result=False):
148
+ extra_factor = 30 # Factor to retrieve more than rag_num
149
+ if picked_tool_names is None:
150
+ assert picked_tool_names is not None or message is not None
151
+ picked_tool_names = self.rag_infer(
152
+ message, top_k=rag_num*extra_factor)
153
+
154
+ picked_tool_names_no_special = []
155
+ for tool in picked_tool_names:
156
+ if tool not in self.special_tools_name:
157
+ picked_tool_names_no_special.append(tool)
158
+ picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
159
+ picked_tool_names = picked_tool_names_no_special[:rag_num]
160
+
161
+ picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
162
+ picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
163
+ picked_tools)
164
+ if return_call_result:
165
+ return picked_tools_prompt, picked_tool_names
166
+ return picked_tools_prompt
167
+
168
+ def add_special_tools(self, tools, call_agent=False):
169
+ if self.enable_finish:
170
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
171
+ 'Finish', return_prompt=True))
172
+ print("Finish tool is added")
173
+ if call_agent:
174
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
175
+ 'CallAgent', return_prompt=True))
176
+ print("CallAgent tool is added")
177
+ else:
178
+ if self.enable_rag:
179
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
180
+ 'Tool_RAG', return_prompt=True))
181
+ print("Tool_RAG tool is added")
182
+
183
+ if self.additional_default_tools is not None:
184
+ for each_tool_name in self.additional_default_tools:
185
+ tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
186
+ each_tool_name, return_prompt=True)
187
+ if tool_prompt is not None:
188
+ print(f"{each_tool_name} tool is added")
189
+ tools.append(tool_prompt)
190
+ return tools
191
+
192
+ def add_finish_tools(self, tools):
193
+ tools.append(self.tooluniverse.get_one_tool_by_one_name(
194
+ 'Finish', return_prompt=True))
195
+ print("Finish tool is added")
196
+ return tools
197
+
198
+ def set_system_prompt(self, conversation, sys_prompt):
199
+ if len(conversation) == 0:
200
+ conversation.append(
201
+ {"role": "system", "content": sys_prompt})
202
+ else:
203
+ conversation[0] = {"role": "system", "content": sys_prompt}
204
+ return conversation
205
+
206
+ def run_function_call(self, fcall_str,
207
+ return_message=False,
208
+ existing_tools_prompt=None,
209
+ message_for_call_agent=None,
210
+ call_agent=False,
211
+ call_agent_level=None,
212
+ temperature=None):
213
+
214
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
215
+ fcall_str, return_message=return_message, verbose=False)
216
+ call_results = []
217
+ special_tool_call = ''
218
+ if function_call_json is not None:
219
+ if isinstance(function_call_json, list):
220
+ for i in range(len(function_call_json)):
221
+ print("\033[94mTool Call:\033[0m", function_call_json[i])
222
+ if function_call_json[i]["name"] == 'Finish':
223
+ special_tool_call = 'Finish'
224
+ break
225
+ elif function_call_json[i]["name"] == 'Tool_RAG':
226
+ new_tools_prompt, call_result = self.tool_RAG(
227
+ message=message,
228
+ existing_tools_prompt=existing_tools_prompt,
229
+ rag_num=self.step_rag_num,
230
+ return_call_result=True)
231
+ existing_tools_prompt += new_tools_prompt
232
+ elif function_call_json[i]["name"] == 'CallAgent':
233
+ if call_agent_level < 2 and call_agent:
234
+ solution_plan = function_call_json[i]['arguments']['solution']
235
+ full_message = (
236
+ message_for_call_agent +
237
+ "\nYou must follow the following plan to answer the question: " +
238
+ str(solution_plan)
239
+ )
240
+ call_result = self.run_multistep_agent(
241
+ full_message, temperature=temperature,
242
+ max_new_tokens=1024, max_token=99999,
243
+ call_agent=False, call_agent_level=call_agent_level)
244
+ call_result = call_result.split(
245
+ '[FinalAnswer]')[-1].strip()
246
+ else:
247
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
248
+ else:
249
+ call_result = self.tooluniverse.run_one_function(
250
+ function_call_json[i])
251
+
252
+ call_id = self.tooluniverse.call_id_gen()
253
+ function_call_json[i]["call_id"] = call_id
254
+ print("\033[94mTool Call Result:\033[0m", call_result)
255
+ call_results.append({
256
+ "role": "tool",
257
+ "content": json.dumps({"content": call_result, "call_id": call_id})
258
+ })
259
+ else:
260
+ call_results.append({
261
+ "role": "tool",
262
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
263
+ })
264
+
265
+ revised_messages = [{
266
+ "role": "assistant",
267
+ "content": message.strip(),
268
+ "tool_calls": json.dumps(function_call_json)
269
+ }] + call_results
270
+
271
+ # Yield the final result.
272
+ return revised_messages, existing_tools_prompt, special_tool_call
273
+
274
+ def run_function_call_stream(self, fcall_str,
275
+ return_message=False,
276
+ existing_tools_prompt=None,
277
+ message_for_call_agent=None,
278
+ call_agent=False,
279
+ call_agent_level=None,
280
+ temperature=None,
281
+ return_gradio_history=True):
282
+
283
+ function_call_json, message = self.tooluniverse.extract_function_call_json(
284
+ fcall_str, return_message=return_message, verbose=False)
285
+ call_results = []
286
+ special_tool_call = ''
287
+ if return_gradio_history:
288
+ gradio_history = []
289
+ if function_call_json is not None:
290
+ if isinstance(function_call_json, list):
291
+ for i in range(len(function_call_json)):
292
+ if function_call_json[i]["name"] == 'Finish':
293
+ special_tool_call = 'Finish'
294
+ break
295
+ elif function_call_json[i]["name"] == 'Tool_RAG':
296
+ new_tools_prompt, call_result = self.tool_RAG(
297
+ message=message,
298
+ existing_tools_prompt=existing_tools_prompt,
299
+ rag_num=self.step_rag_num,
300
+ return_call_result=True)
301
+ existing_tools_prompt += new_tools_prompt
302
+ elif function_call_json[i]["name"] == 'DirectResponse':
303
+ call_result = function_call_json[i]['arguments']['respose']
304
+ special_tool_call = 'DirectResponse'
305
+ elif function_call_json[i]["name"] == 'RequireClarification':
306
+ call_result = function_call_json[i]['arguments']['unclear_question']
307
+ special_tool_call = 'RequireClarification'
308
+ elif function_call_json[i]["name"] == 'CallAgent':
309
+ if call_agent_level < 2 and call_agent:
310
+ solution_plan = function_call_json[i]['arguments']['solution']
311
+ full_message = (
312
+ message_for_call_agent +
313
+ "\nYou must follow the following plan to answer the question: " +
314
+ str(solution_plan)
315
+ )
316
+ sub_agent_task = "Sub TxAgent plan: " + \
317
+ str(solution_plan)
318
+ # When streaming, yield responses as they arrive.
319
+ call_result = yield from self.run_gradio_chat(
320
+ full_message, history=[], temperature=temperature,
321
+ max_new_tokens=1024, max_token=99999,
322
+ call_agent=False, call_agent_level=call_agent_level,
323
+ conversation=None,
324
+ sub_agent_task=sub_agent_task)
325
+
326
+ call_result = call_result.split(
327
+ '[FinalAnswer]')[-1]
328
+ else:
329
+ call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
330
+ else:
331
+ call_result = self.tooluniverse.run_one_function(
332
+ function_call_json[i])
333
+
334
+ call_id = self.tooluniverse.call_id_gen()
335
+ function_call_json[i]["call_id"] = call_id
336
+ call_results.append({
337
+ "role": "tool",
338
+ "content": json.dumps({"content": call_result, "call_id": call_id})
339
+ })
340
+ if return_gradio_history and function_call_json[i]["name"] != 'Finish':
341
+ if function_call_json[i]["name"] == 'Tool_RAG':
342
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
343
+ "title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
344
+
345
+ else:
346
+ gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
347
+ "title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
348
+ else:
349
+ call_results.append({
350
+ "role": "tool",
351
+ "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
352
+ })
353
+
354
+ revised_messages = [{
355
+ "role": "assistant",
356
+ "content": message.strip(),
357
+ "tool_calls": json.dumps(function_call_json)
358
+ }] + call_results
359
+
360
+ # Yield the final result.
361
+ if return_gradio_history:
362
+ return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
363
+ else:
364
+ return revised_messages, existing_tools_prompt, special_tool_call
365
+
366
+ def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
367
+ if conversation[-1]['role'] == 'assisant':
368
+ conversation.append(
369
+ {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
370
+ finish_tools_prompt = self.add_finish_tools([])
371
+
372
+ last_outputs_str = self.llm_infer(messages=conversation,
373
+ temperature=temperature,
374
+ tools=finish_tools_prompt,
375
+ output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
376
+ skip_special_tokens=True,
377
+ max_new_tokens=max_new_tokens, max_token=max_token)
378
+ print(last_outputs_str)
379
+ return last_outputs_str
380
+
381
+ def run_multistep_agent(self, message: str,
382
+ temperature: float,
383
+ max_new_tokens: int,
384
+ max_token: int,
385
+ max_round: int = 20,
386
+ call_agent=False,
387
+ call_agent_level=0) -> str:
388
+ """
389
+ Generate a streaming response using the llama3-8b model.
390
+ Args:
391
+ message (str): The input message.
392
+ temperature (float): The temperature for generating the response.
393
+ max_new_tokens (int): The maximum number of new tokens to generate.
394
+ Returns:
395
+ str: The generated response.
396
+ """
397
+ print("\033[1;32;40mstart\033[0m")
398
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
399
+ call_agent, call_agent_level, message)
400
+ conversation = self.initialize_conversation(message)
401
+
402
+ outputs = []
403
+ last_outputs = []
404
+ next_round = True
405
+ function_call_messages = []
406
+ current_round = 0
407
+ token_overflow = False
408
+ enable_summary = False
409
+ last_status = {}
410
+
411
+ if self.enable_checker:
412
+ checker = ReasoningTraceChecker(message, conversation)
413
+ try:
414
+ while next_round and current_round < max_round:
415
+ current_round += 1
416
+ if len(outputs) > 0:
417
+ function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
418
+ last_outputs, return_message=True,
419
+ existing_tools_prompt=picked_tools_prompt,
420
+ message_for_call_agent=message,
421
+ call_agent=call_agent,
422
+ call_agent_level=call_agent_level,
423
+ temperature=temperature)
424
+
425
+ if special_tool_call == 'Finish':
426
+ next_round = False
427
+ conversation.extend(function_call_messages)
428
+ if isinstance(function_call_messages[0]['content'], types.GeneratorType):
429
+ function_call_messages[0]['content'] = next(
430
+ function_call_messages[0]['content'])
431
+ return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
432
+
433
+ if (self.enable_summary or token_overflow) and not call_agent:
434
+ if token_overflow:
435
+ print("token_overflow, using summary")
436
+ enable_summary = True
437
+ last_status = self.function_result_summary(
438
+ conversation, status=last_status, enable_summary=enable_summary)
439
+
440
+ if function_call_messages is not None:
441
+ conversation.extend(function_call_messages)
442
+ outputs.append(tool_result_format(
443
+ function_call_messages))
444
+ else:
445
+ next_round = False
446
+ conversation.extend(
447
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
448
+ return ''.join(last_outputs).replace("</s>", "")
449
+ if self.enable_checker:
450
+ good_status, wrong_info = checker.check_conversation()
451
+ if not good_status:
452
+ next_round = False
453
+ print(
454
+ "Internal error in reasoning: " + wrong_info)
455
+ break
456
+ last_outputs = []
457
+ outputs.append("### TxAgent:\n")
458
+ last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
459
+ temperature=temperature,
460
+ tools=picked_tools_prompt,
461
+ skip_special_tokens=False,
462
+ max_new_tokens=max_new_tokens, max_token=max_token,
463
+ check_token_status=True)
464
+ if last_outputs_str is None:
465
+ next_round = False
466
+ print(
467
+ "The number of tokens exceeds the maximum limit.")
468
+ else:
469
+ last_outputs.append(last_outputs_str)
470
+ if max_round == current_round:
471
+ print("The number of rounds exceeds the maximum limit!")
472
+ if self.force_finish:
473
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
474
+ else:
475
+ return None
476
+
477
+ except Exception as e:
478
+ print(f"Error: {e}")
479
+ if self.force_finish:
480
+ return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
481
+ else:
482
+ return None
483
+
484
+ def build_logits_processor(self, messages, llm):
485
+ # Use the tokenizer from the LLM instance.
486
+ tokenizer = llm.get_tokenizer()
487
+ if self.avoid_repeat and len(messages) > 2:
488
+ assistant_messages = []
489
+ for i in range(1, len(messages) + 1):
490
+ if messages[-i]['role'] == 'assistant':
491
+ assistant_messages.append(messages[-i]['content'])
492
+ if len(assistant_messages) == 2:
493
+ break
494
+ forbidden_ids = [tokenizer.encode(
495
+ msg, add_special_tokens=False) for msg in assistant_messages]
496
+ return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
497
+ else:
498
+ return None
499
+
500
+ def llm_infer(self, messages, temperature=0.1, tools=None,
501
+ output_begin_string=None, max_new_tokens=2048,
502
+ max_token=None, skip_special_tokens=True,
503
+ model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
504
+
505
+ if model is None:
506
+ model = self.model
507
+
508
+ logits_processor = self.build_logits_processor(messages, model)
509
+ sampling_params = SamplingParams(
510
+ temperature=temperature,
511
+ max_tokens=max_new_tokens,
512
+
513
+ seed=seed if seed is not None else self.seed,
514
+ )
515
+
516
+ prompt = self.chat_template.render(
517
+ messages=messages, tools=tools, add_generation_prompt=True)
518
+ if output_begin_string is not None:
519
+ prompt += output_begin_string
520
+
521
+ if check_token_status and max_token is not None:
522
+ token_overflow = False
523
+ num_input_tokens = len(self.tokenizer.encode(
524
+ prompt, return_tensors="pt")[0])
525
+ if max_token is not None:
526
+ if num_input_tokens > max_token:
527
+ torch.cuda.empty_cache()
528
+ gc.collect()
529
+ print("Number of input tokens before inference:",
530
+ num_input_tokens)
531
+ logger.info(
532
+ "The number of tokens exceeds the maximum limit!!!!")
533
+ token_overflow = True
534
+ return None, token_overflow
535
+ output = model.generate(
536
+ prompt,
537
+ sampling_params=sampling_params,
538
+ )
539
+ output = output[0].outputs[0].text
540
+ print("\033[92m" + output + "\033[0m")
541
+ if check_token_status and max_token is not None:
542
+ return output, token_overflow
543
+
544
+ return output
545
+
546
+ def run_self_agent(self, message: str,
547
+ temperature: float,
548
+ max_new_tokens: int,
549
+ max_token: int) -> str:
550
+
551
+ print("\033[1;32;40mstart self agent\033[0m")
552
+ conversation = []
553
+ conversation = self.set_system_prompt(conversation, self.self_prompt)
554
+ conversation.append({"role": "user", "content": message})
555
+ return self.llm_infer(messages=conversation,
556
+ temperature=temperature,
557
+ tools=None,
558
+ max_new_tokens=max_new_tokens, max_token=max_token)
559
+
560
+ def run_chat_agent(self, message: str,
561
+ temperature: float,
562
+ max_new_tokens: int,
563
+ max_token: int) -> str:
564
+
565
+ print("\033[1;32;40mstart chat agent\033[0m")
566
+ conversation = []
567
+ conversation = self.set_system_prompt(conversation, self.chat_prompt)
568
+ conversation.append({"role": "user", "content": message})
569
+ return self.llm_infer(messages=conversation,
570
+ temperature=temperature,
571
+ tools=None,
572
+ max_new_tokens=max_new_tokens, max_token=max_token)
573
+
574
+ def run_format_agent(self, message: str,
575
+ answer: str,
576
+ temperature: float,
577
+ max_new_tokens: int,
578
+ max_token: int) -> str:
579
+
580
+ print("\033[1;32;40mstart format agent\033[0m")
581
+ if '[FinalAnswer]' in answer:
582
+ possible_final_answer = answer.split("[FinalAnswer]")[-1]
583
+ elif "\n\n" in answer:
584
+ possible_final_answer = answer.split("\n\n")[-1]
585
+ else:
586
+ possible_final_answer = answer.strip()
587
+ if len(possible_final_answer) == 1:
588
+ choice = possible_final_answer[0]
589
+ if choice in ['A', 'B', 'C', 'D', 'E']:
590
+ return choice
591
+ elif len(possible_final_answer) > 1:
592
+ if possible_final_answer[1] == ':':
593
+ choice = possible_final_answer[0]
594
+ if choice in ['A', 'B', 'C', 'D', 'E']:
595
+ print("choice", choice)
596
+ return choice
597
+
598
+ conversation = []
599
+ format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
600
+ conversation = self.set_system_prompt(conversation, format_prompt)
601
+ conversation.append({"role": "user", "content": message +
602
+ "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
603
+ return self.llm_infer(messages=conversation,
604
+ temperature=temperature,
605
+ tools=None,
606
+ max_new_tokens=max_new_tokens, max_token=max_token)
607
+
608
+ def run_summary_agent(self, thought_calls: str,
609
+ function_response: str,
610
+ temperature: float,
611
+ max_new_tokens: int,
612
+ max_token: int) -> str:
613
+ print("\033[1;32;40mSummarized Tool Result:\033[0m")
614
+ generate_tool_result_summary_training_prompt = """Thought and function calls:
615
+ {thought_calls}
616
+
617
+ Function calls' responses:
618
+ \"\"\"
619
+ {function_response}
620
+ \"\"\"
621
+
622
+ Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
623
+
624
+ Directly respond with the summarized sentence of the function calls' responses only.
625
+
626
+ Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
627
+ """.format(thought_calls=thought_calls, function_response=function_response)
628
+ conversation = []
629
+ conversation.append(
630
+ {"role": "user", "content": generate_tool_result_summary_training_prompt})
631
+ output = self.llm_infer(messages=conversation,
632
+ temperature=temperature,
633
+ tools=None,
634
+ max_new_tokens=max_new_tokens, max_token=max_token)
635
+
636
+ if '[' in output:
637
+ output = output.split('[')[0]
638
+ return output
639
+
640
+ def function_result_summary(self, input_list, status, enable_summary):
641
+ """
642
+ Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
643
+ Supports 'length' and 'step' modes, and skips the last 'k' groups.
644
+
645
+ Parameters:
646
+ input_list (list): A list of dictionaries containing role and other information.
647
+ summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
648
+ summary_context_length (int): The context length threshold for the 'length' mode.
649
+ last_processed_index (tuple or int): The last processed index.
650
+
651
+ Returns:
652
+ list: A list of extracted information from valid sequences.
653
+ """
654
+ if 'tool_call_step' not in status:
655
+ status['tool_call_step'] = 0
656
+
657
+ for idx in range(len(input_list)):
658
+ pos_id = len(input_list)-idx-1
659
+ if input_list[pos_id]['role'] == 'assistant':
660
+ if 'tool_calls' in input_list[pos_id]:
661
+ if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
662
+ status['tool_call_step'] += 1
663
+ break
664
+
665
+ if 'step' in status:
666
+ status['step'] += 1
667
+ else:
668
+ status['step'] = 0
669
+
670
+ if not enable_summary:
671
+ return status
672
+
673
+ if 'summarized_index' not in status:
674
+ status['summarized_index'] = 0
675
+
676
+ if 'summarized_step' not in status:
677
+ status['summarized_step'] = 0
678
+
679
+ if 'previous_length' not in status:
680
+ status['previous_length'] = 0
681
+
682
+ if 'history' not in status:
683
+ status['history'] = []
684
+
685
+ function_response = ''
686
+ idx = 0
687
+ current_summarized_index = status['summarized_index']
688
+
689
+ status['history'].append(self.summary_mode == 'step' and status['summarized_step']
690
+ < status['step']-status['tool_call_step']-self.summary_skip_last_k)
691
+
692
+ idx = current_summarized_index
693
+ while idx < len(input_list):
694
+ if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
695
+
696
+ if input_list[idx]['role'] == 'assistant':
697
+ if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
698
+ this_thought_calls = None
699
+ else:
700
+ if len(function_response) != 0:
701
+ print("internal summary")
702
+ status['summarized_step'] += 1
703
+ result_summary = self.run_summary_agent(
704
+ thought_calls=this_thought_calls,
705
+ function_response=function_response,
706
+ temperature=0.1,
707
+ max_new_tokens=1024,
708
+ max_token=99999
709
+ )
710
+
711
+ input_list.insert(
712
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
713
+ status['summarized_index'] = last_call_idx + 2
714
+ idx += 1
715
+
716
+ last_call_idx = idx
717
+ this_thought_calls = input_list[idx]['content'] + \
718
+ input_list[idx]['tool_calls']
719
+ function_response = ''
720
+
721
+ elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
722
+ function_response += input_list[idx]['content']
723
+ del input_list[idx]
724
+ idx -= 1
725
+
726
+ else:
727
+ break
728
+ idx += 1
729
+
730
+ if len(function_response) != 0:
731
+ status['summarized_step'] += 1
732
+ result_summary = self.run_summary_agent(
733
+ thought_calls=this_thought_calls,
734
+ function_response=function_response,
735
+ temperature=0.1,
736
+ max_new_tokens=1024,
737
+ max_token=99999
738
+ )
739
+
740
+ tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
741
+ for tool_call in tool_calls:
742
+ del tool_call['call_id']
743
+ input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
744
+ input_list.insert(
745
+ last_call_idx+1, {'role': 'tool', 'content': result_summary})
746
+ status['summarized_index'] = last_call_idx + 2
747
+
748
+ return status
749
+
750
+ # Following are Gradio related functions
751
+
752
+ # General update method that accepts any new arguments through kwargs
753
+ def update_parameters(self, **kwargs):
754
+ for key, value in kwargs.items():
755
+ if hasattr(self, key):
756
+ setattr(self, key, value)
757
+
758
+ # Return the updated attributes
759
+ updated_attributes = {key: value for key,
760
+ value in kwargs.items() if hasattr(self, key)}
761
+ return updated_attributes
762
+
763
+ def run_gradio_chat(self, message: str,
764
+ history: list,
765
+ temperature: float,
766
+ max_new_tokens: int,
767
+ max_token: int,
768
+ call_agent: bool,
769
+ conversation: gr.State,
770
+ max_round: int = 20,
771
+ seed: int = None,
772
+ call_agent_level: int = 0,
773
+ sub_agent_task: str = None) -> str:
774
+ """
775
+ Generate a streaming response using the llama3-8b model.
776
+ Args:
777
+ message (str): The input message.
778
+ history (list): The conversation history used by ChatInterface.
779
+ temperature (float): The temperature for generating the response.
780
+ max_new_tokens (int): The maximum number of new tokens to generate.
781
+ Returns:
782
+ str: The generated response.
783
+ """
784
+ print("\033[1;32;40mstart\033[0m")
785
+ print("len(message)", len(message))
786
+ if len(message) <= 10:
787
+ yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
788
+ return "Please provide a valid message."
789
+ outputs = []
790
+ outputs_str = ''
791
+ last_outputs = []
792
+
793
+ picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
794
+ call_agent,
795
+ call_agent_level,
796
+ message)
797
+
798
+ conversation = self.initialize_conversation(
799
+ message,
800
+ conversation=conversation,
801
+ history=history)
802
+ history = []
803
+
804
+ next_round = True
805
+ function_call_messages = []
806
+ current_round = 0
807
+ enable_summary = False
808
+ last_status = {} # for summary
809
+ token_overflow = False
810
+ if self.enable_checker:
811
+ checker = ReasoningTraceChecker(
812
+ message, conversation, init_index=len(conversation))
813
+
814
+ try:
815
+ while next_round and current_round < max_round:
816
+ current_round += 1
817
+ if len(last_outputs) > 0:
818
+ function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
819
+ last_outputs, return_message=True,
820
+ existing_tools_prompt=picked_tools_prompt,
821
+ message_for_call_agent=message,
822
+ call_agent=call_agent,
823
+ call_agent_level=call_agent_level,
824
+ temperature=temperature)
825
+ history.extend(current_gradio_history)
826
+ if special_tool_call == 'Finish':
827
+ yield history
828
+ next_round = False
829
+ conversation.extend(function_call_messages)
830
+ return function_call_messages[0]['content']
831
+ elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
832
+ history.append(
833
+ ChatMessage(role="assistant", content=history[-1].content))
834
+ yield history
835
+ next_round = False
836
+ return history[-1].content
837
+ if (self.enable_summary or token_overflow) and not call_agent:
838
+ if token_overflow:
839
+ print("token_overflow, using summary")
840
+ enable_summary = True
841
+ last_status = self.function_result_summary(
842
+ conversation, status=last_status,
843
+ enable_summary=enable_summary)
844
+ if function_call_messages is not None:
845
+ conversation.extend(function_call_messages)
846
+ formated_md_function_call_messages = tool_result_format(
847
+ function_call_messages)
848
+ yield history
849
+ else:
850
+ next_round = False
851
+ conversation.extend(
852
+ [{"role": "assistant", "content": ''.join(last_outputs)}])
853
+ return ''.join(last_outputs).replace("</s>", "")
854
+ if self.enable_checker:
855
+ good_status, wrong_info = checker.check_conversation()
856
+ if not good_status:
857
+ next_round = False
858
+ print("Internal error in reasoning: " + wrong_info)
859
+ break
860
+ last_outputs = []
861
+ last_outputs_str, token_overflow = self.llm_infer(
862
+ messages=conversation,
863
+ temperature=temperature,
864
+ tools=picked_tools_prompt,
865
+ skip_special_tokens=False,
866
+ max_new_tokens=max_new_tokens,
867
+ max_token=max_token,
868
+ seed=seed,
869
+ check_token_status=True)
870
+ last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
871
+ for each in history:
872
+ if each.metadata is not None:
873
+ each.metadata['status'] = 'done'
874
+ if '[FinalAnswer]' in last_thought:
875
+ final_thought, final_answer = last_thought.split(
876
+ '[FinalAnswer]')
877
+ history.append(
878
+ ChatMessage(role="assistant",
879
+ content=final_thought.strip())
880
+ )
881
+ yield history
882
+ history.append(
883
+ ChatMessage(
884
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
885
+ )
886
+ yield history
887
+ else:
888
+ history.append(ChatMessage(
889
+ role="assistant", content=last_thought))
890
+ yield history
891
+
892
+ last_outputs.append(last_outputs_str)
893
+
894
+ if next_round:
895
+ if self.force_finish:
896
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
897
+ conversation, temperature, max_new_tokens, max_token)
898
+ for each in history:
899
+ if each.metadata is not None:
900
+ each.metadata['status'] = 'done'
901
+ if '[FinalAnswer]' in last_thought:
902
+ final_thought, final_answer = last_thought.split(
903
+ '[FinalAnswer]')
904
+ history.append(
905
+ ChatMessage(role="assistant",
906
+ content=final_thought.strip())
907
+ )
908
+ yield history
909
+ history.append(
910
+ ChatMessage(
911
+ role="assistant", content="**Answer**:\n"+final_answer.strip())
912
+ )
913
+ yield history
914
+ else:
915
+ yield "The number of rounds exceeds the maximum limit!"
916
+
917
+ except Exception as e:
918
+ print(f"Error: {e}")
919
+ if self.force_finish:
920
+ last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
921
+ conversation,
922
+ temperature,
923
+ max_new_tokens,
924
+ max_token)
925
+ for each in history:
926
+ if each.metadata is not None:
927
+ each.metadata['status'] = 'done'
928
+ if '[FinalAnswer]' in last_thought or '"name": "Finish",' in last_outputs_str:
929
+ if '[FinalAnswer]' in last_thought:
930
+ final_thought, final_answer = last_thought.split('[FinalAnswer]', 1)
931
+ else:
932
+ final_thought = ""
933
+ final_answer = last_thought
934
+ history.append(
935
+ ChatMessage(role="assistant",
936
+ content=final_thought.strip())
937
+ )
938
+ yield history
939
+ history.append(
940
+ ChatMessage(
941
+ role="assistant", content="**Answer**:\n" + final_answer.strip())
942
+ )
943
+ yield history
944
+ else:
945
+ return None
src/txagent/utils.py ADDED
@@ -0,0 +1,117 @@
+ import sys
+ import json
+ import hashlib
+ import torch
+ from typing import List
+
+
+ def get_md5(input_str):
+     # Create an MD5 hash object
+     md5_hash = hashlib.md5()
+
+     # Encode the string and update the hash object
+     md5_hash.update(input_str.encode('utf-8'))
+
+     # Return the hexadecimal MD5 digest
+     return md5_hash.hexdigest()
+
+
+ def tool_result_format(function_call_messages):
+     current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
+     for each_message in function_call_messages:
+         if each_message['role'] == 'tool':
+             current_output += f"{each_message['content']}\n\n"
+     current_output += "</details>\n\n\n"
+     return current_output
+
+
+ class NoRepeatSentenceProcessor:
+     def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
+         """
+         Args:
+             forbidden_sequences (List[List[int]]): A list of token ID sequences corresponding to forbidden sentences.
+             allowed_prefix_length (int): The number k such that if the generated tokens match the first k tokens
+                 of a forbidden sequence, then the candidate token that would extend the match is blocked.
+         """
+         self.allowed_prefix_length = allowed_prefix_length
+         # Build a lookup dictionary: key is a tuple of the first k tokens, value is a set of tokens to block.
+         self.forbidden_prefix_dict = {}
+         for seq in forbidden_sequences:
+             if len(seq) > allowed_prefix_length:
+                 prefix = tuple(seq[:allowed_prefix_length])
+                 next_token = seq[allowed_prefix_length]
+                 self.forbidden_prefix_dict.setdefault(
+                     prefix, set()).add(next_token)
+
+     def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
+         """
+         Modifies the logits to block tokens that would extend a forbidden sentence.
+
+         Args:
+             token_ids (List[int]): List of token IDs generated so far.
+             logits (torch.Tensor): Logits tensor for the next token (shape: [vocab_size]).
+
+         Returns:
+             torch.Tensor: Modified logits.
+         """
+         if len(token_ids) >= self.allowed_prefix_length:
+             prefix = tuple(token_ids[:self.allowed_prefix_length])
+             if prefix in self.forbidden_prefix_dict:
+                 for token_id in self.forbidden_prefix_dict[prefix]:
+                     logits[token_id] = -float("inf")
+         return logits
+
+
+ class ReasoningTraceChecker:
+     def __init__(self, question, conversation, init_index=None):
+         self.question = question
+         self.conversation = conversation
+         self.existing_thoughts = []
+         self.existing_actions = []
+         if init_index is not None:
+             self.index = init_index
+         else:
+             self.index = 1
+         self.question = self.question.lower()
+         self.new_thoughts = []
+         self.new_actions = []
+
+     def check_conversation(self):
+         info = ''
+         current_index = self.index
+         for i in range(current_index, len(self.conversation)):
+             each = self.conversation[i]
+             self.index = i
+             if each['role'] == 'assistant':
+                 print(each)
+                 thought = each['content']
+                 actions = each['tool_calls']
+
+                 good_status, current_info = self.check_repeat_thought(thought)
+                 info += current_info
+                 if not good_status:
+                     return False, info
+
+                 good_status, current_info = self.check_repeat_action(actions)
+                 info += current_info
+                 if not good_status:
+                     return False, info
+         return True, info
+
+     def check_repeat_thought(self, thought):
+         if thought in self.existing_thoughts:
+             return False, "repeat_thought"
+         self.existing_thoughts.append(thought)
+         return True, ''
+
+     def check_repeat_action(self, actions):
+         if type(actions) != list:
+             actions = json.loads(actions)
+         for each_action in actions:
+             if 'call_id' in each_action:
+                 del each_action['call_id']
+             each_action = json.dumps(each_action)
+             if each_action in self.existing_actions:
+                 return False, "repeat_action"
+             self.existing_actions.append(each_action)
+         return True, ''