|
import sys |
|
import json |
|
import hashlib |
|
import torch |
|
from typing import List |
|
|
|
|
|
def get_md5(input_str):
    """Return the hexadecimal MD5 digest of *input_str*.

    The string is encoded as UTF-8 before hashing.

    Args:
        input_str: Text to hash.

    Returns:
        The 32-character lowercase hex digest.
    """
    encoded = input_str.encode('utf-8')
    return hashlib.md5(encoded).hexdigest()
|
|
|
|
|
def tool_result_format(function_call_messages):
    """Render tool messages as a collapsible HTML ``<details>`` section.

    Only messages whose ``role`` is ``'tool'`` contribute to the output.
    When a message's ``content`` is a JSON object, its ``tool_name``
    (default ``"Unknown Tool"``) and ``content`` (default: the raw message
    content) are rendered under a tool heading. Any other content (invalid
    JSON, non-string content, or JSON that is not an object) is emitted
    verbatim.

    Args:
        function_call_messages: Iterable of message dicts with ``role`` and
            ``content`` keys.

    Returns:
        The formatted string, always wrapped in the ``<details>`` header and
        footer even when no tool message is present.
    """
    parts = [
        "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
    ]
    for each_message in function_call_messages:
        if each_message['role'] != 'tool':
            continue
        raw_content = each_message['content']
        try:
            parsed = json.loads(raw_content)
        except (TypeError, ValueError):
            # TypeError: content is not str/bytes; ValueError covers
            # json.JSONDecodeError for malformed JSON.
            parsed = None
        if isinstance(parsed, dict):
            tool_name = parsed.get("tool_name", "Unknown Tool")
            tool_output = parsed.get("content", raw_content)
            parts.append(f"**🔧 Tool: {tool_name}**\n\n{tool_output}\n\n")
        else:
            # Best-effort fallback: show whatever the tool returned as-is.
            parts.append(f"{raw_content}\n\n")
    parts.append("</details>\n\n\n")
    return "".join(parts)
|
|
|
|
|
class NoRepeatSentenceProcessor:
    """Logits processor that masks continuations of forbidden sequences.

    Each forbidden sequence longer than ``allowed_prefix_length`` is split
    into a prefix (its first ``allowed_prefix_length`` tokens) and the token
    that immediately follows it. When the generated tokens start with such a
    prefix, the corresponding follow-up tokens have their logits set to
    ``-inf``.
    """

    def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
        """Index the forbidden sequences by their leading prefix.

        Args:
            forbidden_sequences: Token-id sequences whose continuation past
                the allowed prefix must be blocked.
            allowed_prefix_length: Number of leading tokens that may match a
                forbidden sequence before the next token is masked.
        """
        self.allowed_prefix_length = allowed_prefix_length
        # Maps prefix tuple -> set of token ids banned right after it.
        self.forbidden_prefix_dict = {}
        cut = allowed_prefix_length
        for tokens in forbidden_sequences:
            # Sequences no longer than the prefix have no "next token".
            if len(tokens) <= cut:
                continue
            key = tuple(tokens[:cut])
            banned = self.forbidden_prefix_dict.setdefault(key, set())
            banned.add(tokens[cut])

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """Mask banned tokens in *logits* in place and return it.

        Args:
            token_ids: Token ids generated so far.
            logits: Scores indexed by token id; modified in place.

        Returns:
            The same ``logits`` object that was passed in.
        """
        cut = self.allowed_prefix_length
        if len(token_ids) < cut:
            return logits
        # NOTE(review): the prefix is always taken from the start of
        # token_ids, so once it matches, the mask is applied on every later
        # call too (not only at length == allowed_prefix_length) - confirm
        # this is intended.
        banned = self.forbidden_prefix_dict.get(tuple(token_ids[:cut]), ())
        for token_id in banned:
            logits[token_id] = -float("inf")
        return logits
|
|
|
|
|
class ReasoningTraceChecker:
    """Detect repeated thoughts or tool calls in an assistant conversation.

    The checker keeps the thoughts and serialized actions it has already
    seen, plus the index of the next message to inspect, so that
    ``check_conversation`` can be called repeatedly as the conversation
    list grows.
    """

    def __init__(self, question, conversation, init_index=None):
        """Create a checker over *conversation*.

        Args:
            question: The user question; stored lower-cased.
            conversation: List of message dicts. The list object is kept by
                reference, so messages appended later are seen by
                subsequent ``check_conversation`` calls.
            init_index: Index of the first message to inspect. Defaults to
                1, which skips the message at index 0.
        """
        self.question = question.lower()
        self.conversation = conversation
        self.existing_thoughts = []
        self.existing_actions = []
        self.new_thoughts = []
        self.new_actions = []
        self.index = init_index if init_index is not None else 1

    def check_conversation(self):
        """Scan messages not yet inspected for repeated thoughts/actions.

        Only messages whose ``role`` is ``'assistant'`` are checked: their
        ``content`` is treated as the thought and their optional
        ``tool_calls`` as the actions.

        Returns:
            ``(True, info)`` when no repetition was found, otherwise
            ``(False, info)`` where ``info`` names the violation
            (``"repeat_thought"`` or ``"repeat_action"``).
        """
        info = ''
        for i in range(self.index, len(self.conversation)):
            each = self.conversation[i]
            # While a message is being checked (and if it fails), index
            # points at it.
            self.index = i
            if each['role'] == 'assistant':
                good_status, current_info = self.check_repeat_thought(each['content'])
                info += current_info
                if not good_status:
                    return False, info
                # Assistant messages without tool calls are valid.
                good_status, current_info = self.check_repeat_action(each.get('tool_calls'))
                info += current_info
                if not good_status:
                    return False, info
            # Advance past the message only after it passed; otherwise the
            # next call would re-scan it and flag its own thought as a
            # repeat.
            self.index = i + 1
        return True, info

    def check_repeat_thought(self, thought):
        """Record *thought*, reporting whether it was already seen.

        Returns:
            ``(False, "repeat_thought")`` if the exact same thought was
            recorded before, otherwise ``(True, '')``.
        """
        if thought in self.existing_thoughts:
            return False, "repeat_thought"
        self.existing_thoughts.append(thought)
        return True, ''

    def check_repeat_action(self, actions):
        """Record *actions*, reporting whether any was already seen.

        Actions are compared by their JSON serialization with the
        ``call_id`` key ignored. The caller's objects are not modified.

        Args:
            actions: A list of actions, a JSON string encoding such a list,
                or ``None``/empty for "no actions".

        Returns:
            ``(False, "repeat_action")`` on the first action that was
            recorded before, otherwise ``(True, '')``.
        """
        if not actions:
            # None, '' or []: nothing to check.
            return True, ''
        if not isinstance(actions, list):
            actions = json.loads(actions)
        for each_action in actions:
            if isinstance(each_action, dict):
                # Drop call_id on a copy so the conversation is untouched;
                # key order is preserved, keeping the serialization stable.
                each_action = {
                    key: value
                    for key, value in each_action.items()
                    if key != 'call_id'
                }
            serialized = json.dumps(each_action)
            if serialized in self.existing_actions:
                return False, "repeat_action"
            self.existing_actions.append(serialized)
        return True, ''
|
|