File size: 8,868 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import json
import os
import random
from pathlib import Path

import tiktoken
from datasets import Dataset

from opencompass.datasets.base import BaseDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET


def get_random_needles(file_path, needle_count):
    """Pick one needle record with exactly ``needle_count`` derivations.

    Loads the JSON list stored at ``file_path``, keeps the records whose
    ``derivation_count`` equals ``needle_count`` and draws one of them with
    the module-level ``random`` generator.

    Returns:
        A dict with keys ``'needles'`` (the record's ``derivations``),
        ``'answer'`` and ``'retrieval_question'`` (the record's
        ``question``), or ``None`` when no record matches.
    """
    with open(file_path, 'r', encoding='utf-8') as needle_file:
        records = json.load(needle_file)

    candidates = list(
        filter(lambda rec: rec.get('derivation_count') == needle_count,
               records))
    if not candidates:
        return None

    chosen = random.choice(candidates)
    return dict(
        needles=chosen['derivations'],
        answer=chosen['answer'],
        retrieval_question=chosen['question'],
    )


@LOAD_DATASET.register_module()
class NeedleBenchMultiDataset(BaseDataset):
    """NeedleBench multi-needle dataset.

    Builds prompts by taking haystack text from ``*.jsonl`` files, inserting
    several needle strings (the ``derivations`` of one record drawn from a
    needle JSON file) into it, and asking that record's question.
    """

    @staticmethod
    def load(
        path: str,
        length: int,
        depth: int,
        tokenizer_model: str,
        file_list: 'list[str]',
        num_repeats_per_file: int,
        length_buffer: int,
        guide: bool,
        language: str,
        needle_file_name: str,
        num_needles: int,
        diff: int,
    ):
        """Build the multi-needle dataset.

        Args:
            path: Directory holding the haystack ``*.jsonl`` files and the
                needle file. Each haystack line is a JSON object whose
                ``'text'`` field is used.
            length: Token budget; haystack plus needles are sized to roughly
                ``length - length_buffer`` tokens.
            depth: Position of the first needle, as a percentage of the
                (truncated) haystack length.
            tokenizer_model: Model name passed to
                ``tiktoken.encoding_for_model`` for token counting.
            file_list: Names of the ``*.jsonl`` files in ``path`` to use;
                other files are skipped.
            num_repeats_per_file: Number of prompts built per haystack file.
            length_buffer: Tokens subtracted from ``length`` (presumably
                headroom for the instructions and question -- confirm).
            guide: If true, append a hint asking the model to consider the
                most relevant part of the document before answering.
            language: ``'Chinese'`` or ``'English'``; anything else raises
                ``ValueError`` when a prompt is built.
            needle_file_name: JSON file in ``path`` whose records carry
                ``derivation_count``, ``derivations``, ``answer`` and
                ``question``.
            num_needles: Only needle records whose ``derivation_count``
                equals this value are used.
            diff: Gap between the end of one needle and the start of the
                next, as a percentage of the haystack length.

        Returns:
            A ``datasets.Dataset`` with ``'prompt'`` and ``'answer'``
            columns; each answer is stored as ``'<answer>*<answer>'``.

        Raises:
            ValueError: If ``language`` is unsupported, or if the needle
                file contains no record with ``num_needles`` derivations.
        """
        data = {'prompt': [], 'answer': []}
        tokenizer = tiktoken.encoding_for_model(tokenizer_model)

        def _generate_context(tokens_context, depth_percent, needles):
            """Insert all needles into ``tokens_context`` and decode it.

            The first needle goes at ``depth_percent`` percent of the
            haystack; every following needle starts ``diff`` percent of the
            haystack length after the end of the previous one. No needle is
            placed past the end of the (growing) token list.
            """
            tokens_needle = [
                _get_tokens_from_context(needle) for needle in needles
            ]
            total_length = len(tokens_context)
            gap = total_length * (diff / 100)

            # Needles are inserted one at a time and each position is
            # computed in the coordinates of the token list *after* all
            # earlier needles were inserted. Earlier code computed the
            # points this way but then shifted them again by each inserted
            # needle's length, counting every earlier needle twice and
            # spreading needles further apart than ``diff`` requests.
            previous_end = None
            for needle_tokens in tokens_needle:
                if previous_end is None:
                    insertion_point = int(total_length * (depth_percent / 100))
                else:
                    insertion_point = int(previous_end + gap)
                insertion_point = min(insertion_point, len(tokens_context))
                tokens_context = (tokens_context[:insertion_point] +
                                  needle_tokens +
                                  tokens_context[insertion_point:])
                previous_end = insertion_point + len(needle_tokens)

            return _decode_tokens(tokens_context)

        def _get_tokens_from_context(context):
            """Encode a string, or each string of a list, with the tokenizer."""
            if isinstance(context, list):
                return [tokenizer.encode(item) for item in context]
            else:
                return tokenizer.encode(context)

        def _decode_tokens(tokens):
            """Decode a token list back to text."""
            return tokenizer.decode(tokens)

        def _modify_retrieval_question(retrieval_question):
            """Append the language-specific 'think before answering' hint."""
            if language == 'Chinese':
                guide_retrieval_question = (retrieval_question +
                                            '在回答之前,请思考文档中与此问题'
                                            '最相关的内容是什么。')
                return guide_retrieval_question
            elif language == 'English':
                guide_retrieval_question = (
                    retrieval_question + 'Before answering, please consider'
                    ' what in the document is most relevant to this question.')
                return guide_retrieval_question
            else:
                raise ValueError(f"Language '{language}' is not supported.")

        def _generate_prompt(context, retrieval_question):
            """Wrap the document and question in the instruction template."""
            if guide:
                retrieval_question = _modify_retrieval_question(
                    retrieval_question)

            if language == 'Chinese':
                prompt = ('你是一个善于回答用户问题的智能AI助手\n'
                          '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                          ',或重复你的回答\n'
                          f'用户现在给你的文档是{context}\n\n'
                          f'现在请问:{retrieval_question}')
            elif language == 'English':
                prompt = ('You are an intelligent AI assistant skilled in '
                          'answering user questions.\n'
                          'Please keep your answers concise and clear. Do not'
                          ' talk about irrelevant topics or repeat your '
                          'answers.\n'
                          f'The document given to you by the user is {context}'
                          f'\n\nNow, the question is: {retrieval_question}')
            else:
                raise ValueError(f"Language '{language}' is not supported.")

            return prompt

        needle_file_path = os.path.join(path, needle_file_name)
        # Path.glob yields entries in filesystem order; sort so the row order
        # of the resulting dataset is reproducible across machines.
        for file in sorted(Path(path).glob('*.jsonl')):
            if file.name not in file_list:
                continue

            with open(file, 'r', encoding='utf-8') as f:
                lines_bak = [json.loads(line.strip()) for line in f]
            lines = lines_bak.copy()
            for counter in range(num_repeats_per_file):
                # Reseeds the global ``random`` generator. NOTE: ``lines`` is
                # shuffled in place and not reset from ``lines_bak``, so each
                # repeat's order builds on the previous one; output stays
                # reproducible because every repeat reseeds first.
                random.seed(counter)
                random.shuffle(lines)
                random_needle_data = get_random_needles(
                    needle_file_path, num_needles)
                if random_needle_data is None:
                    raise ValueError(
                        f'No record with derivation_count == {num_needles} '
                        f'found in needle file {needle_file_path!r}.')
                # Surround each needle with newlines so it forms its own line
                # inside the haystack.
                needles = [
                    '\n' + needle + '\n'
                    for needle in random_needle_data['needles']
                ]
                answer = random_needle_data['answer']
                keyword = answer
                retrieval_question = random_needle_data['retrieval_question']
                context_length = length - length_buffer
                # Leave room for the needles inside the context budget.
                target_length_per_record = context_length - \
                    sum(len(tokens) for tokens
                        in _get_tokens_from_context(needles))
                target_length_per_record = max(target_length_per_record, 0)
                accumulated_tokens = []
                for line in lines:
                    tokens_current_line = _get_tokens_from_context(
                        line['text'])
                    accumulated_tokens.extend(tokens_current_line)

                    if len(accumulated_tokens) >= target_length_per_record:
                        break

                processed_text = _generate_context(
                    accumulated_tokens[:target_length_per_record], depth,
                    needles)

                processed_prompt = _generate_prompt(processed_text,
                                                    retrieval_question)

                data['prompt'].append(processed_prompt)
                # The evaluator parses references as '<answer>*<keyword>'.
                data['answer'].append(answer + '*' + keyword)

        dataset = Dataset.from_dict({
            'prompt': data['prompt'],
            'answer': data['answer'],
        })
        return dataset


class NeedleBenchMultiEvaluator(BaseEvaluator):
    """Score predictions by the share of reference keywords they contain.

    References are strings of the form ``'<answer>*<keyword>'``; the keyword
    part is lower-cased and split on whitespace, and each keyword found as a
    substring of the lower-cased prediction earns ``100 / len(keywords)``.
    """

    def levenshtein_distance(self, s1, s2):
        """Return the edit distance between ``s1`` and ``s2``.

        Insertions, deletions and substitutions each cost 1. Keeps only two
        rows of the DP table, iterating over the longer string. Not used by
        :meth:`score` in this class.
        """
        if len(s1) < len(s2):
            return self.levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def score(self, predictions, gold):
        """Compute the average keyword-match score over all samples.

        Args:
            predictions: Model outputs, one string per sample.
            gold: References of the form ``'<answer>*<keyword>'``. Text
                after the first ``'*'`` is the keyword part; a reference
                without ``'*'`` is used whole as the keyword part.

        Returns:
            ``{'score': average, 'details': [...]}`` with one detail dict per
            sample, or ``{'error': ...}`` if the input lengths differ.
        """
        if len(predictions) != len(gold):
            return {'error': 'predictions and gold have different lengths'}

        total_score = 0
        details = []

        for prediction, reference in zip(predictions, gold):
            # ``partition`` never raises, unlike unpacking ``split('*')``
            # into two names, which fails when the reference has no '*' or
            # when the answer itself contains '*'.
            _, separator, keyword = reference.partition('*')
            if not separator:
                keyword = reference
            keywords = keyword.lower().split()
            prediction = prediction.lower()

            keyword_score = 100 / len(keywords) if keywords else 0

            matched_keywords = sum(1 for kword in keywords
                                   if kword in prediction)
            score = matched_keywords * keyword_score

            detail = {
                'pred': prediction,
                'answer': reference,
                'matched_keywords': matched_keywords,
                'score': score
            }

            total_score += score
            details.append(detail)

        average_score = total_score / len(predictions) if predictions else 0
        return {'score': average_score, 'details': details}