|
import gradio as gr |
|
import os |
|
import sys |
|
import json |
|
import gc |
|
import numpy as np |
|
from vllm import LLM, SamplingParams |
|
from jinja2 import Template |
|
from typing import List |
|
import types |
|
from tooluniverse import ToolUniverse |
|
from gradio import ChatMessage |
|
from .toolrag import ToolRAGModel |
|
import torch |
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
|
logger = logging.getLogger("TxAgent") |
|
|
|
from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format |
|
|
|
class TxAgent: |
|
def __init__(self, model_name, |
|
rag_model_name, |
|
tool_files_dict=None, |
|
enable_finish=True, |
|
enable_rag=False, |
|
enable_summary=False, |
|
init_rag_num=0, |
|
step_rag_num=0, |
|
summary_mode='step', |
|
summary_skip_last_k=0, |
|
summary_context_length=None, |
|
force_finish=True, |
|
avoid_repeat=True, |
|
seed=None, |
|
enable_checker=False, |
|
enable_chat=False, |
|
additional_default_tools=None): |
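        """Configure a TxAgent instance.

        Stores the main vLLM model name, the tool-retrieval (RAG) model name,
        the ToolUniverse tool files, and the flags controlling finishing, RAG,
        stepwise summarization of tool results, forced final answers, and the
        reasoning-trace checker. Models are not loaded here; call init_model()
        to load them.
        """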
|
self.model_name = model_name |
|
self.tokenizer = None |
|
self.terminators = None |
|
self.rag_model_name = rag_model_name |
|
self.tool_files_dict = tool_files_dict |
|
self.model = None |
|
self.rag_model = ToolRAGModel(rag_model_name) |
|
self.tooluniverse = None |
|
self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning." |
|
self.self_prompt = "Strictly follow the instruction." |
|
self.chat_prompt = "You are a helpful assistant for user chat." |
|
self.enable_finish = enable_finish |
|
self.enable_rag = enable_rag |
|
self.enable_summary = enable_summary |
|
self.summary_mode = summary_mode |
|
self.summary_skip_last_k = summary_skip_last_k |
|
self.summary_context_length = summary_context_length |
|
self.init_rag_num = init_rag_num |
|
self.step_rag_num = step_rag_num |
|
self.force_finish = force_finish |
|
self.avoid_repeat = avoid_repeat |
|
self.seed = seed |
|
self.enable_checker = enable_checker |
|
self.additional_default_tools = additional_default_tools |
|
logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name) |
|
|
|
def init_model(self): |
|
self.load_models() |
|
self.load_tooluniverse() |
|
|
|
def load_models(self, model_name=None): |
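        """Load the vLLM engine for self.model_name (or switch to model_name).

        Returns early if the requested model is already loaded; otherwise
        builds the vLLM engine and caches its tokenizer and chat template.
        """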
|
if model_name is not None: |
|
if model_name == self.model_name: |
|
return f"The model {model_name} is already loaded." |
|
self.model_name = model_name |
|
|
|
self.model = LLM( |
|
model=self.model_name, |
|
dtype="float16", |
|
max_model_len=131072, |
|
max_num_batched_tokens=65536, |
|
max_num_seqs=512, |
|
gpu_memory_utilization=0.95, |
|
trust_remote_code=True, |
|
) |
|
self.chat_template = Template(self.model.get_tokenizer().chat_template) |
|
self.tokenizer = self.model.get_tokenizer() |
|
        logger.info(
            "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d, gpu_memory_utilization=%.2f",
            self.model_name, 131072, 65536, 0.95
        )
        return f"Model {self.model_name} loaded successfully."
|
|
|
def load_tooluniverse(self): |
|
self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict) |
|
self.tooluniverse.load_tools() |
|
special_tools = self.tooluniverse.prepare_tool_prompts( |
|
self.tooluniverse.tool_category_dicts["special_tools"]) |
|
self.special_tools_name = [tool['name'] for tool in special_tools] |
|
logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name)) |
|
|
|
def load_tool_desc_embedding(self): |
|
cache_path = os.path.join(os.path.dirname(self.tool_files_dict["new_tool"]), "tool_embeddings.pkl") |
|
if os.path.exists(cache_path): |
|
self.rag_model.load_cached_embeddings(cache_path) |
|
else: |
|
self.rag_model.load_tool_desc_embedding(self.tooluniverse) |
|
self.rag_model.save_embeddings(cache_path) |
|
logger.debug("Tool description embeddings loaded") |
|
|
|
def rag_infer(self, query, top_k=5): |
|
return self.rag_model.rag_infer(query, top_k) |
|
|
|
def initialize_tools_prompt(self, call_agent, call_agent_level, message): |
|
picked_tools_prompt = [] |
|
picked_tools_prompt = self.add_special_tools( |
|
picked_tools_prompt, call_agent=call_agent) |
|
if call_agent: |
|
call_agent_level += 1 |
|
if call_agent_level >= 2: |
|
call_agent = False |
|
return picked_tools_prompt, call_agent_level |
|
|
|
def initialize_conversation(self, message, conversation=None, history=None): |
|
if conversation is None: |
|
conversation = [] |
|
|
|
conversation = self.set_system_prompt( |
|
conversation, self.prompt_multi_step) |
|
if history: |
|
for i in range(len(history)): |
|
if history[i]['role'] == 'user': |
|
conversation.append({"role": "user", "content": history[i]['content']}) |
|
elif history[i]['role'] == 'assistant': |
|
conversation.append({"role": "assistant", "content": history[i]['content']}) |
|
conversation.append({"role": "user", "content": message}) |
|
logger.debug("Conversation initialized with %d messages", len(conversation)) |
|
return conversation |
|
|
|
def tool_RAG(self, message=None, |
|
picked_tool_names=None, |
|
                 existing_tools_prompt=None,
|
rag_num=0, |
|
return_call_result=False): |
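        """Retrieve tool prompts relevant to `message` via the RAG model.

        Returns an empty list when RAG is disabled. Otherwise over-fetches
        candidate tools, filters out special tools (e.g. Finish), truncates to
        `rag_num`, and returns the prepared tool prompts (and, when
        `return_call_result` is True, the selected tool names as well).
        """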
|
if not self.enable_rag: |
|
return [] |
|
extra_factor = 10 |
|
        if picked_tool_names is None:
            assert message is not None, "Either picked_tool_names or message must be provided."
            picked_tool_names = self.rag_infer(
                message, top_k=rag_num * extra_factor)
|
|
|
picked_tool_names_no_special = [tool for tool in picked_tool_names if tool not in self.special_tools_name] |
|
picked_tool_names = picked_tool_names_no_special[:rag_num] |
|
|
|
picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names) |
|
picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools) |
|
logger.debug("Retrieved %d tools via RAG", len(picked_tools_prompt)) |
|
if return_call_result: |
|
return picked_tools_prompt, picked_tool_names |
|
return picked_tools_prompt |
|
|
|
def add_special_tools(self, tools, call_agent=False): |
|
if self.enable_finish: |
|
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True)) |
|
logger.debug("Finish tool added") |
|
if call_agent: |
|
tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True)) |
|
logger.debug("CallAgent tool added") |
|
return tools |
|
|
|
def add_finish_tools(self, tools): |
|
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True)) |
|
logger.debug("Finish tool added") |
|
return tools |
|
|
|
def set_system_prompt(self, conversation, sys_prompt): |
|
if not conversation: |
|
conversation.append({"role": "system", "content": sys_prompt}) |
|
else: |
|
conversation[0] = {"role": "system", "content": sys_prompt} |
|
return conversation |
|
|
|
def run_function_call(self, fcall_str, |
|
return_message=False, |
|
existing_tools_prompt=None, |
|
message_for_call_agent=None, |
|
call_agent=False, |
|
call_agent_level=None, |
|
temperature=None): |
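        """Parse and execute the tool calls emitted by the model.

        Extracts the function-call JSON from `fcall_str`, runs each call via
        ToolUniverse (delegating 'CallAgent' calls to a nested multistep
        agent), and returns the assistant/tool messages to append to the
        conversation, the unchanged tool prompts, and the name of any special
        tool ('Finish') that was invoked.
        """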
|
try: |
|
function_call_json, message = self.tooluniverse.extract_function_call_json( |
|
fcall_str, return_message=return_message, verbose=False) |
|
        except Exception as e:
            logger.error("Tool call parsing failed: %s", e)
            function_call_json = []
            message = fcall_str if isinstance(fcall_str, str) else ''.join(fcall_str)
|
|
|
call_results = [] |
|
special_tool_call = '' |
|
if function_call_json: |
|
if isinstance(function_call_json, list): |
|
for i in range(len(function_call_json)): |
|
logger.info("Tool Call: %s", function_call_json[i]) |
|
if function_call_json[i]["name"] == 'Finish': |
|
special_tool_call = 'Finish' |
|
break |
|
elif function_call_json[i]["name"] == 'CallAgent': |
|
if call_agent_level < 2 and call_agent: |
|
solution_plan = function_call_json[i]['arguments']['solution'] |
|
full_message = ( |
|
message_for_call_agent + |
|
"\nYou must follow the following plan to answer the question: " + |
|
str(solution_plan) |
|
) |
|
call_result = self.run_multistep_agent( |
|
full_message, temperature=temperature, |
|
max_new_tokens=512, max_token=131072, |
|
call_agent=False, call_agent_level=call_agent_level) |
|
if call_result is None: |
|
call_result = "⚠️ No content returned from sub-agent." |
|
else: |
|
call_result = call_result.split('[FinalAnswer]')[-1].strip() |
|
else: |
|
call_result = "Error: CallAgent disabled." |
|
else: |
|
call_result = self.tooluniverse.run_one_function(function_call_json[i]) |
|
call_id = self.tooluniverse.call_id_gen() |
|
function_call_json[i]["call_id"] = call_id |
|
logger.info("Tool Call Result: %s", call_result) |
|
call_results.append({ |
|
"role": "tool", |
|
"content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id}) |
|
}) |
|
else: |
|
call_results.append({ |
|
"role": "tool", |
|
"content": json.dumps({"content": "Invalid or no function call detected."}) |
|
}) |
|
|
|
revised_messages = [{ |
|
"role": "assistant", |
|
"content": message.strip(), |
|
"tool_calls": json.dumps(function_call_json) |
|
}] + call_results |
|
return revised_messages, existing_tools_prompt, special_tool_call |
|
|
|
def run_function_call_stream(self, fcall_str, |
|
return_message=False, |
|
existing_tools_prompt=None, |
|
message_for_call_agent=None, |
|
call_agent=False, |
|
call_agent_level=None, |
|
temperature=None, |
|
return_gradio_history=True): |
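        """Streaming variant of run_function_call used by the Gradio chat.

        Additionally handles the special tools 'DirectResponse' and
        'RequireClarification', and (when `return_gradio_history` is True)
        returns ChatMessage entries describing each executed tool call for
        display in the UI.
        """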
|
try: |
|
function_call_json, message = self.tooluniverse.extract_function_call_json( |
|
fcall_str, return_message=return_message, verbose=False) |
|
        except Exception as e:
            logger.error("Tool call parsing failed: %s", e)
            function_call_json = []
            message = fcall_str if isinstance(fcall_str, str) else ''.join(fcall_str)
|
|
|
call_results = [] |
|
special_tool_call = '' |
|
if return_gradio_history: |
|
gradio_history = [] |
|
if function_call_json: |
|
if isinstance(function_call_json, list): |
|
for i in range(len(function_call_json)): |
|
if function_call_json[i]["name"] == 'Finish': |
|
special_tool_call = 'Finish' |
|
break |
|
elif function_call_json[i]["name"] == 'DirectResponse': |
|
call_result = function_call_json[i]['arguments']['respose'] |
|
special_tool_call = 'DirectResponse' |
|
elif function_call_json[i]["name"] == 'RequireClarification': |
|
call_result = function_call_json[i]['arguments']['unclear_question'] |
|
special_tool_call = 'RequireClarification' |
|
elif function_call_json[i]["name"] == 'CallAgent': |
|
if call_agent_level < 2 and call_agent: |
|
solution_plan = function_call_json[i]['arguments']['solution'] |
|
full_message = ( |
|
message_for_call_agent + |
|
"\nYou must follow the following plan to answer the question: " + |
|
str(solution_plan) |
|
) |
|
sub_agent_task = "Sub TxAgent plan: " + str(solution_plan) |
|
call_result = yield from self.run_gradio_chat( |
|
full_message, history=[], temperature=temperature, |
|
max_new_tokens=512, max_token=131072, |
|
call_agent=False, call_agent_level=call_agent_level, |
|
conversation=None, sub_agent_task=sub_agent_task) |
|
if call_result is not None and isinstance(call_result, str): |
|
call_result = call_result.split('[FinalAnswer]')[-1] |
|
else: |
|
call_result = "⚠️ No content returned from sub-agent." |
|
else: |
|
call_result = "Error: CallAgent disabled." |
|
else: |
|
call_result = self.tooluniverse.run_one_function(function_call_json[i]) |
|
call_id = self.tooluniverse.call_id_gen() |
|
function_call_json[i]["call_id"] = call_id |
|
call_results.append({ |
|
"role": "tool", |
|
"content": json.dumps({"tool_name": function_call_json[i]["name"], "content": call_result, "call_id": call_id}) |
|
}) |
|
if return_gradio_history and function_call_json[i]["name"] != 'Finish': |
|
metadata = {"title": f"🧰 {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])} |
|
gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata)) |
|
else: |
|
call_results.append({ |
|
"role": "tool", |
|
"content": json.dumps({"content": "Invalid or no function call detected."}) |
|
}) |
|
|
|
revised_messages = [{ |
|
"role": "assistant", |
|
"content": message.strip(), |
|
"tool_calls": json.dumps(function_call_json) |
|
}] + call_results |
|
if return_gradio_history: |
|
return revised_messages, existing_tools_prompt, special_tool_call, gradio_history |
|
return revised_messages, existing_tools_prompt, special_tool_call |
|
|
|
def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None): |
|
if conversation[-1]['role'] == 'assistant': |
|
conversation.append( |
|
{'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'}) |
|
finish_tools_prompt = self.add_finish_tools([]) |
|
last_outputs_str = self.llm_infer( |
|
messages=conversation, |
|
temperature=temperature, |
|
tools=finish_tools_prompt, |
|
output_begin_string='[FinalAnswer]', |
|
skip_special_tokens=True, |
|
max_new_tokens=max_new_tokens, |
|
max_token=max_token) |
|
logger.info("Unfinished reasoning answer: %s", last_outputs_str[:100]) |
|
return last_outputs_str |
|
|
|
def run_multistep_agent(self, message: str, |
|
temperature: float, |
|
max_new_tokens: int, |
|
max_token: int, |
|
max_round: int = 5, |
|
call_agent=False, |
|
call_agent_level=0): |
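        """Run the non-streaming multi-round reasoning loop.

        Alternates between LLM inference and tool execution for up to
        `max_round` rounds. Returns the text after '[FinalAnswer]' once the
        Finish tool is called, or a forced answer when token/round limits are
        hit and force_finish is enabled.
        """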
|
logger.info("Starting multistep agent for message: %s", message[:100]) |
|
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt( |
|
call_agent, call_agent_level, message) |
|
conversation = self.initialize_conversation(message) |
|
outputs = [] |
|
last_outputs = [] |
|
next_round = True |
|
current_round = 0 |
|
token_overflow = False |
|
enable_summary = False |
|
last_status = {} |
|
|
|
while next_round and current_round < max_round: |
|
current_round += 1 |
|
if len(outputs) > 0: |
|
function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call( |
|
last_outputs, return_message=True, |
|
existing_tools_prompt=picked_tools_prompt, |
|
message_for_call_agent=message, |
|
call_agent=call_agent, |
|
call_agent_level=call_agent_level, |
|
temperature=temperature) |
|
|
|
if special_tool_call == 'Finish': |
|
next_round = False |
|
conversation.extend(function_call_messages) |
|
content = function_call_messages[0]['content'] |
|
if content is None: |
|
return "❌ No content returned after Finish tool call." |
|
return content.split('[FinalAnswer]')[-1] |
|
|
|
if (self.enable_summary or token_overflow) and not call_agent: |
|
enable_summary = True |
|
last_status = self.function_result_summary( |
|
conversation, status=last_status, enable_summary=enable_summary) |
|
|
|
if function_call_messages: |
|
conversation.extend(function_call_messages) |
|
outputs.append(tool_result_format(function_call_messages)) |
|
else: |
|
next_round = False |
|
conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}]) |
|
return ''.join(last_outputs).replace("</s>", "") |
|
|
|
last_outputs = [] |
|
outputs.append("### TxAgent:\n") |
|
last_outputs_str, token_overflow = self.llm_infer( |
|
messages=conversation, |
|
temperature=temperature, |
|
tools=picked_tools_prompt, |
|
skip_special_tokens=False, |
|
                max_new_tokens=max_new_tokens,
                max_token=max_token,
|
check_token_status=True) |
|
if last_outputs_str is None: |
|
logger.warning("Token limit exceeded") |
|
if self.force_finish: |
|
return self.get_answer_based_on_unfinished_reasoning( |
|
conversation, temperature, max_new_tokens, max_token) |
|
return "❌ Token limit exceeded." |
|
last_outputs.append(last_outputs_str) |
|
|
|
if max_round == current_round: |
|
logger.warning("Max rounds exceeded") |
|
if self.force_finish: |
|
return self.get_answer_based_on_unfinished_reasoning( |
|
conversation, temperature, max_new_tokens, max_token) |
|
return None |
|
|
|
def build_logits_processor(self, messages, llm): |
|
logger.warning("Logits processor disabled due to vLLM V1 limitation") |
|
return None |
|
|
|
def llm_infer(self, messages, temperature=0.1, tools=None, |
|
output_begin_string=None, max_new_tokens=512, |
|
max_token=131072, skip_special_tokens=True, |
|
model=None, tokenizer=None, terminators=None, |
|
seed=None, check_token_status=False): |
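        """Render the chat template and generate one completion with vLLM.

        Optionally appends `output_begin_string` to the prompt. When
        `check_token_status` is set, returns (text, token_overflow) and aborts
        with (None, True) if the rendered prompt exceeds `max_token`.
        """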
|
if model is None: |
|
model = self.model |
|
|
|
logits_processor = self.build_logits_processor(messages, model) |
|
sampling_params = SamplingParams( |
|
temperature=temperature, |
|
max_tokens=max_new_tokens, |
|
seed=seed if seed is not None else self.seed, |
|
) |
|
|
|
prompt = self.chat_template.render( |
|
messages=messages, tools=tools, add_generation_prompt=True) |
|
if output_begin_string is not None: |
|
prompt += output_begin_string |
|
|
|
if check_token_status and max_token is not None: |
|
token_overflow = False |
|
num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False)) |
|
logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token) |
|
if num_input_tokens > max_token: |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
logger.warning("Token overflow: %d > %d", num_input_tokens, max_token) |
|
return None, True |
|
|
|
output = model.generate(prompt, sampling_params=sampling_params) |
|
output_text = output[0].outputs[0].text |
|
output_tokens = len(self.tokenizer.encode(output_text, add_special_tokens=False)) |
|
logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens) |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
if check_token_status and max_token is not None: |
|
return output_text, token_overflow |
|
return output_text |
|
|
|
def run_self_agent(self, message: str, |
|
temperature: float, |
|
max_new_tokens: int, |
|
max_token: int): |
|
logger.info("Starting self agent") |
|
conversation = self.set_system_prompt([], self.self_prompt) |
|
conversation.append({"role": "user", "content": message}) |
|
return self.llm_infer( |
|
messages=conversation, |
|
temperature=temperature, |
|
tools=None, |
|
max_new_tokens=max_new_tokens, |
|
max_token=max_token) |
|
|
|
def run_chat_agent(self, message: str, |
|
temperature: float, |
|
max_new_tokens: int, |
|
max_token: int): |
|
logger.info("Starting chat agent") |
|
conversation = self.set_system_prompt([], self.chat_prompt) |
|
conversation.append({"role": "user", "content": message}) |
|
return self.llm_infer( |
|
messages=conversation, |
|
temperature=temperature, |
|
tools=None, |
|
max_new_tokens=max_new_tokens, |
|
max_token=max_token) |
|
|
|
def run_format_agent(self, message: str, |
|
answer: str, |
|
temperature: float, |
|
max_new_tokens: int, |
|
max_token: int): |
|
logger.info("Starting format agent") |
|
if '[FinalAnswer]' in answer: |
|
possible_final_answer = answer.split("[FinalAnswer]")[-1] |
|
elif "\n\n" in answer: |
|
possible_final_answer = answer.split("\n\n")[-1] |
|
else: |
|
possible_final_answer = answer.strip() |
|
if len(possible_final_answer) == 1 and possible_final_answer in ['A', 'B', 'C', 'D', 'E']: |
|
return possible_final_answer |
|
elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']: |
|
return possible_final_answer[0] |
|
|
|
conversation = self.set_system_prompt( |
|
[], "Transform the agent's answer to a single letter: 'A', 'B', 'C', 'D'.") |
|
conversation.append({"role": "user", "content": message + |
|
"\nAgent's answer: " + answer + "\nAnswer (must be a letter):"}) |
|
return self.llm_infer( |
|
messages=conversation, |
|
temperature=temperature, |
|
tools=None, |
|
max_new_tokens=max_new_tokens, |
|
max_token=max_token) |
|
|
|
def run_summary_agent(self, thought_calls: str, |
|
function_response: str, |
|
temperature: float, |
|
max_new_tokens: int, |
|
max_token: int): |
|
logger.info("Summarizing tool result") |
|
prompt = f"""Thought and function calls: |
|
{thought_calls} |
|
Function calls' responses: |
|
\"\"\" |
|
{function_response} |
|
\"\"\" |
|
Summarize the function calls' responses in one sentence with all necessary information.
|
""" |
|
conversation = [{"role": "user", "content": prompt}] |
|
output = self.llm_infer( |
|
messages=conversation, |
|
temperature=temperature, |
|
tools=None, |
|
max_new_tokens=max_new_tokens, |
|
max_token=max_token) |
|
if '[' in output: |
|
output = output.split('[')[0] |
|
return output |
|
|
|
def function_result_summary(self, input_list, status, enable_summary): |
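        """Compress older tool responses in the conversation into summaries.

        Tracks progress in the mutable `status` dict (step, summarized_step,
        summarized_index) and, depending on `summary_mode` ('step' or
        'length'), replaces runs of tool messages with a one-sentence summary
        produced by run_summary_agent. Mutates `input_list` in place and
        returns the updated `status`.
        """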
|
if 'tool_call_step' not in status: |
|
status['tool_call_step'] = 0 |
|
for idx in range(len(input_list)): |
|
pos_id = len(input_list) - idx - 1 |
|
if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]: |
|
break |
|
|
|
status['step'] = status.get('step', 0) + 1 |
|
if not enable_summary: |
|
return status |
|
|
|
status['summarized_index'] = status.get('summarized_index', 0) |
|
status['summarized_step'] = status.get('summarized_step', 0) |
|
status['previous_length'] = status.get('previous_length', 0) |
|
status['history'] = status.get('history', []) |
|
|
|
function_response = '' |
|
idx = status['summarized_index'] |
|
this_thought_calls = None |
|
|
|
while idx < len(input_list): |
|
if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \ |
|
(self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length): |
|
if input_list[idx]['role'] == 'assistant': |
|
if function_response: |
|
status['summarized_step'] += 1 |
|
result_summary = self.run_summary_agent( |
|
thought_calls=this_thought_calls, |
|
function_response=function_response, |
|
temperature=0.1, |
|
max_new_tokens=512, |
|
max_token=131072) |
|
input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary}) |
|
status['summarized_index'] = last_call_idx + 2 |
|
idx += 1 |
|
last_call_idx = idx |
|
                    this_thought_calls = (input_list[idx].get('content') or '') + input_list[idx].get('tool_calls', '')
|
function_response = '' |
|
elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None: |
|
function_response += input_list[idx]['content'] |
|
del input_list[idx] |
|
idx -= 1 |
|
else: |
|
break |
|
idx += 1 |
|
|
|
if function_response: |
|
status['summarized_step'] += 1 |
|
result_summary = self.run_summary_agent( |
|
thought_calls=this_thought_calls, |
|
function_response=function_response, |
|
temperature=0.1, |
|
max_new_tokens=512, |
|
max_token=131072) |
|
tool_calls = json.loads(input_list[last_call_idx]['tool_calls']) |
|
for tool_call in tool_calls: |
|
del tool_call['call_id'] |
|
input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls) |
|
input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary}) |
|
status['summarized_index'] = last_call_idx + 2 |
|
|
|
return status |
|
|
|
def update_parameters(self, **kwargs): |
|
updated_attributes = {} |
|
for key, value in kwargs.items(): |
|
if hasattr(self, key): |
|
setattr(self, key, value) |
|
updated_attributes[key] = value |
|
logger.info("Updated parameters: %s", updated_attributes) |
|
return updated_attributes |
|
|
|
def run_gradio_chat(self, message: str, |
|
history: list, |
|
temperature: float, |
|
max_new_tokens: int = 2048, |
|
max_token: int = 131072, |
|
call_agent: bool = False, |
|
conversation: gr.State = None, |
|
max_round: int = 5, |
|
seed: int = None, |
|
call_agent_level: int = 0, |
|
sub_agent_task: str = None, |
|
uploaded_files: list = None): |
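        """Generator driving the Gradio chat UI.

        Yields the growing ChatMessage history after each reasoning/tool round
        and finally returns the answer text. Mirrors run_multistep_agent but
        streams intermediate tool calls and handles 'DirectResponse' and
        'RequireClarification' responses.
        """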
|
logger.info("Chat started, message: %s", message[:100]) |
|
if not message or len(message.strip()) < 5: |
|
yield "Please provide a valid message or upload files to analyze." |
|
return |
|
|
|
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt( |
|
call_agent, call_agent_level, message) |
|
conversation = self.initialize_conversation( |
|
message, conversation, history) |
|
history = [] |
|
last_outputs = [] |
|
|
|
next_round = True |
|
current_round = 0 |
|
enable_summary = False |
|
last_status = {} |
|
token_overflow = False |
|
|
|
try: |
|
while next_round and current_round < max_round: |
|
current_round += 1 |
|
logger.debug("Starting round %d/%d", current_round, max_round) |
|
if last_outputs: |
|
function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream( |
|
last_outputs, return_message=True, |
|
existing_tools_prompt=picked_tools_prompt, |
|
message_for_call_agent=message, |
|
call_agent=call_agent, |
|
call_agent_level=call_agent_level, |
|
temperature=temperature) |
|
history.extend(current_gradio_history) |
|
|
|
if special_tool_call == 'Finish': |
|
logger.info("Finish tool called, ending chat") |
|
yield history |
|
next_round = False |
|
conversation.extend(function_call_messages) |
|
content = function_call_messages[0]['content'] |
|
if content: |
|
return content |
|
return "No content returned after Finish tool call." |
|
|
|
elif special_tool_call in ['RequireClarification', 'DirectResponse']: |
|
last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.") |
|
history.append(ChatMessage(role="assistant", content=last_msg.content)) |
|
logger.info("Special tool %s called, ending chat", special_tool_call) |
|
yield history |
|
next_round = False |
|
return last_msg.content |
|
|
|
if (self.enable_summary or token_overflow) and not call_agent: |
|
enable_summary = True |
|
last_status = self.function_result_summary( |
|
conversation, status=last_status, enable_summary=enable_summary) |
|
|
|
if function_call_messages: |
|
conversation.extend(function_call_messages) |
|
yield history |
|
else: |
|
next_round = False |
|
conversation.append({"role": "assistant", "content": ''.join(last_outputs)}) |
|
logger.info("No function call messages, ending chat") |
|
return ''.join(last_outputs).replace("</s>", "") |
|
|
|
last_outputs = [] |
|
last_outputs_str, token_overflow = self.llm_infer( |
|
messages=conversation, |
|
temperature=temperature, |
|
tools=picked_tools_prompt, |
|
skip_special_tokens=False, |
|
max_new_tokens=max_new_tokens, |
|
max_token=max_token, |
|
seed=seed, |
|
check_token_status=True) |
|
|
|
if last_outputs_str is None: |
|
logger.warning("Token limit exceeded") |
|
if self.force_finish: |
|
last_outputs_str = self.get_answer_based_on_unfinished_reasoning( |
|
conversation, temperature, max_new_tokens, max_token) |
|
history.append(ChatMessage(role="assistant", content=last_outputs_str.strip())) |
|
yield history |
|
return last_outputs_str |
|
error_msg = "Token limit exceeded." |
|
history.append(ChatMessage(role="assistant", content=error_msg)) |
|
yield history |
|
return error_msg |
|
|
|
last_thought = last_outputs_str.split("[TOOL_CALLS]")[0] |
|
for msg in history: |
|
if msg.metadata is not None: |
|
msg.metadata['status'] = 'done' |
|
|
|
if '[FinalAnswer]' in last_thought: |
|
parts = last_thought.split('[FinalAnswer]', 1) |
|
final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "") |
|
history.append(ChatMessage(role="assistant", content=final_thought.strip())) |
|
yield history |
|
history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip())) |
|
logger.info("Final answer provided: %s", final_answer[:100]) |
|
yield history |
|
next_round = False |
|
return final_answer |
|
else: |
|
history.append(ChatMessage(role="assistant", content=last_thought)) |
|
yield history |
|
|
|
last_outputs.append(last_outputs_str) |
|
|
|
if next_round: |
|
if self.force_finish: |
|
last_outputs_str = self.get_answer_based_on_unfinished_reasoning( |
|
conversation, temperature, max_new_tokens, max_token) |
|
parts = last_outputs_str.split('[FinalAnswer]', 1) |
|
final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "") |
|
history.append(ChatMessage(role="assistant", content=final_thought.strip())) |
|
yield history |
|
history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip())) |
|
logger.info("Forced final answer: %s", final_answer[:100]) |
|
yield history |
|
return final_answer |
|
else: |
|
error_msg = "Reasoning rounds exceeded limit." |
|
history.append(ChatMessage(role="assistant", content=error_msg)) |
|
yield history |
|
return error_msg |
|
|
|
except Exception as e: |
|
logger.error("Exception in run_gradio_chat: %s", e, exc_info=True) |
|
error_msg = f"Error: {e}" |
|
history.append(ChatMessage(role="assistant", content=error_msg)) |
|
yield history |
|
if self.force_finish: |
|
last_outputs_str = self.get_answer_based_on_unfinished_reasoning( |
|
conversation, temperature, max_new_tokens, max_token) |
|
parts = last_outputs_str.split('[FinalAnswer]', 1) |
|
final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "") |
|
history.append(ChatMessage(role="assistant", content=final_thought.strip())) |
|
yield history |
|
history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip())) |
|
logger.info("Forced final answer after error: %s", final_answer[:100]) |
|
yield history |
|
return final_answer |
|
return error_msg |
|
|
|
def run_gradio_chat_batch(self, messages: List[str], |
|
temperature: float, |
|
max_new_tokens: int = 2048, |
|
max_token: int = 131072, |
|
call_agent: bool = False, |
|
conversation: List = None, |
|
max_round: int = 5, |
|
seed: int = None, |
|
call_agent_level: int = 0): |
|
"""Run batch inference for multiple messages.""" |
|
logger.info("Starting batch chat for %d messages", len(messages)) |
|
batch_results = [] |
|
|
|
for message in messages: |
|
|
|
conv = self.initialize_conversation(message, conversation, history=None) |
|
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt( |
|
call_agent, call_agent_level, message) |
|
|
|
|
|
output, token_overflow = self.llm_infer( |
|
messages=conv, |
|
temperature=temperature, |
|
tools=picked_tools_prompt, |
|
max_new_tokens=max_new_tokens, |
|
max_token=max_token, |
|
skip_special_tokens=False, |
|
seed=seed, |
|
check_token_status=True |
|
) |
|
|
|
if output is None: |
|
logger.warning("Token limit exceeded for message: %s", message[:100]) |
|
batch_results.append("Token limit exceeded.") |
|
else: |
|
batch_results.append(output) |
|
|
|
logger.info("Batch chat completed for %d messages", len(messages)) |
|
return batch_results |
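

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the agent API). The
# model identifiers below are placeholders, not published checkpoints; point
# them at the TxAgent and ToolRAG weights you actually use. Because this
# module uses relative imports, run it via `python -m <package>.<module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    agent = TxAgent(
        model_name="path/or/hub-id/of/txagent-llm",         # placeholder
        rag_model_name="path/or/hub-id/of/toolrag-model",    # placeholder
        tool_files_dict=None,    # or a dict mapping names to tool JSON paths
        enable_rag=False,
        force_finish=True,
    )
    agent.init_model()
    answer = agent.run_multistep_agent(
        "What are the clinically relevant interactions of warfarin?",
        temperature=0.3,
        max_new_tokens=1024,
        max_token=131072,
    )
    print(answer)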