import sys
import json
import hashlib
import torch
from typing import List


def get_md5(input_str):
    """Return the hexadecimal MD5 digest of a UTF-8 encoded string."""
    md5_hash = hashlib.md5()
    md5_hash.update(input_str.encode('utf-8'))
    return md5_hash.hexdigest()


def tool_result_format(function_call_messages):
    """Render tool messages as a collapsible Markdown <details> block."""
    current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
    for each_message in function_call_messages:
        if each_message['role'] == 'tool':
            try:
                # Tool messages are expected to carry a JSON payload with
                # "tool_name" and "content"; fall back to the raw string otherwise.
                parsed = json.loads(each_message['content'])
                tool_name = parsed.get("tool_name", "Unknown Tool")
                tool_output = parsed.get("content", each_message['content'])
                current_output += f"**🔧 Tool: {tool_name}**\n\n{tool_output}\n\n"
            except Exception:
                current_output += f"{each_message['content']}\n\n"
    current_output += "</details>\n\n\n"
    return current_output


class NoRepeatSentenceProcessor:
    """Logits processor that blocks the continuation of forbidden token sequences.

    Each forbidden sequence is split into a prefix of ``allowed_prefix_length``
    tokens and the token that follows it; when the generated tokens start with
    such a prefix, the recorded continuation tokens are masked out.
    """

    def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
        self.allowed_prefix_length = allowed_prefix_length
        # Map: prefix tuple -> set of banned continuation token ids.
        self.forbidden_prefix_dict = {}
        for seq in forbidden_sequences:
            if len(seq) > allowed_prefix_length:
                prefix = tuple(seq[:allowed_prefix_length])
                next_token = seq[allowed_prefix_length]
                self.forbidden_prefix_dict.setdefault(prefix, set()).add(next_token)

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        # If the generation so far begins with a forbidden prefix, push the
        # banned continuation tokens to -inf so they cannot be sampled.
        if len(token_ids) >= self.allowed_prefix_length:
            prefix = tuple(token_ids[:self.allowed_prefix_length])
            if prefix in self.forbidden_prefix_dict:
                for token_id in self.forbidden_prefix_dict[prefix]:
                    logits[token_id] = -float("inf")
        return logits
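

# Note: the __call__ signature above, (token_ids, logits) -> logits, matches the
# per-step logits-processor convention used by common inference engines (for
# example, vLLM's SamplingParams(logits_processors=[...])). This is an assumption
# about the intended call site; the original file does not show where the
# processor is registered.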


class ReasoningTraceChecker:
    """Check an agent conversation for verbatim repeated thoughts or tool calls."""

    def __init__(self, question, conversation, init_index=None):
        self.question = question.lower()
        self.conversation = conversation
        self.existing_thoughts = []
        self.existing_actions = []
        self.new_thoughts = []
        self.new_actions = []
        # Index of the first conversation turn that has not been checked yet.
        self.index = init_index if init_index is not None else 1
    def check_conversation(self):
        """Scan unchecked turns; return (False, reason) on the first repetition."""
        info = ''
        current_index = self.index
        for i in range(current_index, len(self.conversation)):
            each = self.conversation[i]
            self.index = i
            if each['role'] == 'assistant':
                thought = each['content']
                # Assistant turns may carry no tool calls; treat that as an empty list.
                actions = each.get('tool_calls') or []
                good_status, current_info = self.check_repeat_thought(thought)
                info += current_info
                if not good_status:
                    return False, info
                good_status, current_info = self.check_repeat_action(actions)
                info += current_info
                if not good_status:
                    return False, info
        return True, info

    def check_repeat_thought(self, thought):
        if thought in self.existing_thoughts:
            return False, "repeat_thought"
        self.existing_thoughts.append(thought)
        return True, ''

    def check_repeat_action(self, actions):
        if not isinstance(actions, list):
            actions = json.loads(actions)
        for each_action in actions:
            # Compare actions by a canonical JSON form, ignoring the call id,
            # without mutating the caller's dict.
            each_action = {k: v for k, v in each_action.items() if k != 'call_id'}
            each_action = json.dumps(each_action, sort_keys=True)
            if each_action in self.existing_actions:
                return False, "repeat_action"
            self.existing_actions.append(each_action)
        return True, ''
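

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: all token ids,
    # messages, and values below are made up for demonstration only.

    # get_md5: stable fingerprint of a string.
    print(get_md5("hello"))  # 5d41402abc4b2a76b9719d911017c592

    # tool_result_format: render tool feedback as a collapsible Markdown block.
    tool_messages = [
        {'role': 'tool',
         'content': json.dumps({'tool_name': 'get_weather', 'content': 'Sunny, 22°C'})},
    ]
    print(tool_result_format(tool_messages))

    # NoRepeatSentenceProcessor: once the first two generated tokens match the
    # forbidden prefix (5, 8), the continuation token 13 is masked out.
    processor = NoRepeatSentenceProcessor(
        forbidden_sequences=[[5, 8, 13, 21]], allowed_prefix_length=2)
    masked = processor(token_ids=[5, 8], logits=torch.zeros(32))
    assert masked[13] == -float("inf")

    # ReasoningTraceChecker: an assistant turn that repeats an earlier thought
    # verbatim is flagged as "repeat_thought".
    conversation = [
        {'role': 'user', 'content': 'What is the weather in Paris?'},
        {'role': 'assistant', 'content': 'I should call the weather tool.',
         'tool_calls': [{'name': 'get_weather', 'arguments': {'city': 'Paris'}}]},
        {'role': 'assistant', 'content': 'I should call the weather tool.',
         'tool_calls': [{'name': 'get_weather', 'arguments': {'city': 'Berlin'}}]},
    ]
    checker = ReasoningTraceChecker('What is the weather in Paris?', conversation)
    ok, reason = checker.check_conversation()
    print(ok, reason)  # False repeat_thought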