import sys
import json
import hashlib
import torch
from typing import List


def get_md5(input_str):
    # Create an MD5 hash object
    md5_hash = hashlib.md5()

    # Encode the string and update the hash object
    md5_hash.update(input_str.encode('utf-8'))

    # Return the hexadecimal MD5 digest
    return md5_hash.hexdigest()


def tool_result_format(function_call_messages):
    current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
    for each_message in function_call_messages:
        if each_message['role'] == 'tool':
            current_output += f"{each_message['content']}\n\n"
    current_output += "</details>\n\n\n"
    return current_output
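

# Illustrative usage sketch (not part of the original module): tool_result_format
# collects the content of every message whose role is "tool" and wraps it in a
# collapsible Markdown <details> block. The message list below is a made-up example.
def _demo_tool_result_format():
    messages = [
        {"role": "assistant", "content": "Let me run the test suite."},
        {"role": "tool", "content": "All 12 unit tests passed."},
        {"role": "tool", "content": "Linting reported no issues."},
    ]
    # Returns a Markdown string whose <details> body contains both tool messages.
    return tool_result_format(messages)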


class NoRepeatSentenceProcessor:
    def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
        """

        Args:

            forbidden_sequences (List[List[int]]): A list of token ID sequences corresponding to forbidden sentences.

            allowed_prefix_length (int): The number k such that if the generated tokens match the first k tokens 

                                         of a forbidden sequence, then the candidate token that would extend the match is blocked.

        """
        self.allowed_prefix_length = allowed_prefix_length
        # Build a lookup dictionary: key is a tuple of the first k tokens, value is a set of tokens to block.
        self.forbidden_prefix_dict = {}
        for seq in forbidden_sequences:
            if len(seq) > allowed_prefix_length:
                prefix = tuple(seq[:allowed_prefix_length])
                next_token = seq[allowed_prefix_length]
                self.forbidden_prefix_dict.setdefault(
                    prefix, set()).add(next_token)

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """

        Modifies the logits to block tokens that would extend a forbidden sentence.



        Args:

            token_ids (List[int]): List of token IDs generated so far.

            logits (torch.Tensor): Logits tensor for the next token (shape: [vocab_size]).



        Returns:

            torch.Tensor: Modified logits.

        """
        if len(token_ids) >= self.allowed_prefix_length:
            prefix = tuple(token_ids[:self.allowed_prefix_length])
            # When the generation begins with a forbidden prefix, block every token
            # that would extend that prefix into a forbidden sentence.
            if prefix in self.forbidden_prefix_dict:
                for token_id in self.forbidden_prefix_dict[prefix]:
                    logits[token_id] = -float("inf")
        return logits
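

# Illustrative usage sketch (not part of the original module): the (token_ids, logits)
# call signature above matches the per-request logits-processor interface used by
# engines such as vLLM (e.g. SamplingParams(logits_processors=[...])), though that
# wiring should be verified against the engine version in use. The toy token IDs and
# vocabulary size below are made up for demonstration purposes.
def _demo_no_repeat_sentence_processor():
    forbidden = [[5, 7, 9, 11], [5, 7, 2, 4]]
    processor = NoRepeatSentenceProcessor(forbidden, allowed_prefix_length=2)
    logits = torch.zeros(16)  # pretend vocabulary of 16 tokens
    # The generated prefix [5, 7] matches both forbidden sequences, so the
    # continuation tokens 9 and 2 are pushed to -inf and can no longer be sampled.
    blocked = processor([5, 7], logits)
    assert blocked[9] == -float("inf") and blocked[2] == -float("inf")
    return blocked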


class ReasoningTraceChecker:
    def __init__(self, question, conversation, init_index=None):
        self.question = question.lower()
        self.conversation = conversation
        self.existing_thoughts = []
        self.existing_actions = []
        self.index = init_index if init_index is not None else 1
        self.new_thoughts = []
        self.new_actions = []

    def check_conversation(self):
        info = ''
        current_index = self.index
        for i in range(current_index, len(self.conversation)):
            each = self.conversation[i]
            self.index = i
            if each['role'] == 'assistant':
                print(each)
                thought = each['content']
                actions = each['tool_calls']

                good_status, current_info = self.check_repeat_thought(thought)
                info += current_info
                if not good_status:
                    return False, info

                good_status, current_info = self.check_repeat_action(actions)
                info += current_info
                if not good_status:
                    return False, info
        return True, info

    def check_repeat_thought(self, thought):
        if thought in self.existing_thoughts:
            return False, "repeat_thought"
        self.existing_thoughts.append(thought)
        return True, ''

    def check_repeat_action(self, actions):
        if not isinstance(actions, list):
            actions = json.loads(actions)
        for each_action in actions:
            # Drop the call id so that identical calls issued under different ids
            # are still detected as repeats.
            if 'call_id' in each_action:
                del each_action['call_id']
            each_action = json.dumps(each_action)
            if each_action in self.existing_actions:
                return False, "repeat_action"
            self.existing_actions.append(each_action)
        return True, ''
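

# Illustrative usage sketch (not part of the original module): running
# ReasoningTraceChecker over an agent conversation. The conversation below is a
# made-up example that mirrors the fields the checker reads ("role", "content",
# "tool_calls"); the second assistant turn repeats the first verbatim, so the
# check fails with "repeat_thought".
def _demo_reasoning_trace_checker():
    conversation = [
        {"role": "user", "content": "What is the weather in Paris?"},
        {
            "role": "assistant",
            "content": "I should look up the current weather.",
            "tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}],
        },
        {
            "role": "assistant",
            "content": "I should look up the current weather.",
            "tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}],
        },
    ]
    checker = ReasoningTraceChecker("What is the weather in Paris?", conversation)
    good_status, info = checker.check_conversation()
    return good_status, info  # (False, "repeat_thought")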