# Code modified from https://github.com/thunlp/LLMxMapReduce under Apache 2.0

import re
from typing import List, Union

from .utils import logger


def get_prompt_length(
    prompt: Union[str, List[str]], tokenizer, no_special_tokens=False, **kwargs
) -> int:
    """
    Returns the token length of a prompt using the given tokenizer. A list of
    strings is first joined with blank lines; extra keyword arguments are
    forwarded to `tokenizer.encode()`.
    """
    if isinstance(prompt, list):
        prompt = "\n\n".join(prompt)
    if no_special_tokens:
        kwargs["add_special_tokens"] = False
    return len(tokenizer.encode(prompt, **kwargs))
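
# Illustrative usage (sketch, not part of the original module). Any object that
# exposes an `.encode()` method can serve as `tokenizer`; a Hugging Face tokenizer
# is just one assumed choice:
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   n_tokens = get_prompt_length("Hello world", tok, no_special_tokens=True)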


def chunk_context(doc: str, chunk_size: int, tokenizer, separator="\n") -> List[str]:
    """
    Splits a long document into token-limited chunks based on a separator, ensuring each chunk fits within `chunk_size`.

    Uses a greedy approach to accumulate text segments (split by `separator`) into chunks that fit within the
    token limit. If a segment alone exceeds the limit, it is recursively broken down using sentence-level
    splitting. Attempts to preserve natural boundaries while minimizing excessive chunking.

    Args:
        doc (str): Input document to split.
        chunk_size (int): Maximum number of tokens allowed per chunk.
        tokenizer: Tokenizer instance with `.encode()` method to compute token length.
        separator (str): Delimiter to split initial segments (default: newline).

    Returns:
        List[str]: List of non-empty, token-constrained document chunks.
    """
    paragraphs = doc.split(separator)
    paragraphs = [paragraph for paragraph in paragraphs if paragraph]
    separator_len = get_prompt_length(separator, tokenizer, no_special_tokens=True)

    docs = []
    current_doc = []
    total = 0
    for paragraph in paragraphs:
        plen = get_prompt_length(paragraph, tokenizer, no_special_tokens=True)
        if total + plen + (separator_len if len(current_doc) > 0 else 0) > chunk_size:
            if total > chunk_size:
                logger.info(
                    f"Created a chunk of size {total}, "
                    f"which is longer than the specified {chunk_size}"
                )
                # If the single accumulated paragraph is itself longer than
                # chunk_size, split it into more granular (sentence-level) chunks.
                if len(current_doc) == 1:
                    split_again = split_into_granular_chunks(
                        current_doc[0], chunk_size, tokenizer
                    )
                    docs.extend(split_again)
                    current_doc = []
                    total = 0

            if len(current_doc) > 0:
                doc = separator.join(current_doc)
                if doc is not None:
                    docs.append(doc)
                # No overlap is kept between chunks, so drain everything
                # accumulated so far before starting the next chunk.
                while total > 0:
                    total -= get_prompt_length(
                        current_doc[0], tokenizer, no_special_tokens=True
                    ) + (separator_len if len(current_doc) > 1 else 0)
                    current_doc = current_doc[1:]

        current_doc.append(paragraph)
        total += plen + (separator_len if len(current_doc) > 1 else 0)
    # Handle whatever is still accumulated after the loop: a single remaining
    # paragraph that exceeds the limit is split at sentence level; otherwise
    # the remainder is joined into one final chunk.
    if (
        len(current_doc) == 1
        and get_prompt_length(current_doc[0], tokenizer, no_special_tokens=True)
        > chunk_size
    ):
        split_again = split_into_granular_chunks(current_doc[0], chunk_size, tokenizer)
        docs.extend(split_again)
        current_doc = []
    else:
        doc = separator.join(current_doc)
        if doc is not None:
            docs.append(doc)

    return [doc for doc in docs if doc.strip()]
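
# Illustrative usage (sketch, assuming `tok` is any tokenizer with an `.encode()`
# method, e.g. the Hugging Face tokenizer from the sketch above, and that
# `long_document` is some multi-paragraph string):
#   chunks = chunk_context(long_document, chunk_size=512, tokenizer=tok)
#   # Each chunk stays at (or close to) 512 tokens; a paragraph that is too long
#   # on its own is split further at sentence and then word boundaries.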


def split_sentences(text: str, spliter: str) -> List[str]:
    """
    Splits text into sentences or segments based on a given delimiter while preserving punctuation.

    For punctuation-based splitters (e.g., ".", "!", "。"), it interleaves text and punctuation.
    For space-based splitting, it preserves trailing spaces.

    Args:
        text (str): The input text to split.
        spliter (str): Delimiter regex pattern (e.g., r"([.!?])", r"(。)", or " ").

    Returns:
        List[str]: List of split sentence-like segments with punctuation retained.
    """

    # Split by punctuation and keep punctuation
    text = text.strip()
    sentence_list = re.split(spliter, text)

    # Rearrange sentences and punctuation
    if spliter != " ":
        sentences = ["".join(i) for i in zip(sentence_list[0::2], sentence_list[1::2])]
        if len(sentence_list) % 2 != 0 and sentence_list[-1] != "":
            sentences.append(sentence_list[-1])
    else:
        sentences = [i + " " for i in sentence_list if i != ""]
        if sentences:
            sentences[-1] = sentences[-1].strip()
    return sentences
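
# Illustrative behaviour (sketch, not part of the original module):
#   split_sentences("Hi there. How are you?", r"([.!?])")
#   # -> ["Hi there.", " How are you?"]
#   split_sentences("one two three", " ")
#   # -> ["one ", "two ", "three"]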


def split_into_granular_chunks(
    text: str, chunk_size: int, tokenizer, spliter=r"([。!?;.?!;])",
) -> List[str]:
    """
    Splits long text into granular, token-length-constrained chunks using sentence boundaries.

    Sentences are first extracted using a delimiter pattern (`spliter`), then grouped into chunks such that
    each chunk does not exceed the specified `chunk_size` (in tokens). If a chunk still exceeds the limit,
    it is recursively broken down further using whitespace as a fallback.

    Ensures that the final chunks are balanced: if the last chunk is too small, it redistributes the last two
    chunks more evenly by re-splitting and re-allocating their sentences.

    Args:
        text (str): Input text to be chunked.
        chunk_size (int): Maximum number of tokens per chunk.
        tokenizer: Tokenizer instance with `.encode()` method to compute token length.
        spliter (str): Regex pattern to split sentences.

    Returns:
        List[str]: List of token-limited chunks, each composed of one or more sentences.
    """
    sentences = split_sentences(text, spliter)

    chunks = []
    current_chunk = ""

    for sentence in sentences:
        sentence_length = get_prompt_length(sentence, tokenizer)

        if get_prompt_length(current_chunk, tokenizer) + sentence_length <= chunk_size:
            current_chunk += sentence
        else:
            if current_chunk:
                if get_prompt_length(current_chunk, tokenizer) <= chunk_size:
                    chunks.append(current_chunk)
                else:
                    if spliter != " ":  # Avoid infinite loops
                        chunks.extend(
                            split_into_granular_chunks(
                                current_chunk,
                                chunk_size=chunk_size,
                                tokenizer=tokenizer,
                                spliter=" ",
                            )
                        )
            current_chunk = sentence

    if current_chunk != "":
        if get_prompt_length(current_chunk, tokenizer) <= chunk_size:
            chunks.append(current_chunk)
        else:
            if spliter != " ":  # Avoid infinite loops
                chunks.extend(
                    split_into_granular_chunks(
                        current_chunk,
                        chunk_size=chunk_size,
                        tokenizer=tokenizer,
                        spliter=" ",
                    )
                )

    # If the last chunk is too short, rebalance the last two chunks
    if len(chunks) > 1 and get_prompt_length(chunks[-1], tokenizer) < chunk_size // 2:
        last_chunk = chunks.pop()
        penultimate_chunk = chunks.pop()
        combined_text = penultimate_chunk + last_chunk

        new_sentences = split_sentences(combined_text, spliter)

        # Reallocate sentences with two pointers, filling in from both ends
        new_penultimate_chunk = ""
        new_last_chunk = ""
        start, end = 0, len(new_sentences) - 1

        while start <= end and len(new_sentences) != 1:
            flag = False
            if (
                get_prompt_length(
                    new_penultimate_chunk + new_sentences[start], tokenizer
                )
                <= chunk_size
            ):
                flag = True
                new_penultimate_chunk += new_sentences[start]
                if start == end:
                    break
                start += 1
            if (
                get_prompt_length(new_last_chunk + new_sentences[end], tokenizer)
                <= chunk_size
            ):
                new_last_chunk = new_sentences[end] + new_last_chunk
                end -= 1
                flag = True
            if not flag:
                break
        if start < end:
            # Split any sentences left unallocated in the middle into words
            # and distribute them between the two chunks.
            remaining_sentences = new_sentences[start : end + 1]
            if remaining_sentences:
                remaining_text = "".join(remaining_sentences)
                words = remaining_text.split(" ")
                # Sentinel: len(words) means every word fit in the penultimate chunk.
                end_index = len(words)
                for index, w in enumerate(words):
                    if (
                        get_prompt_length(
                            " ".join([new_penultimate_chunk, w]), tokenizer
                        )
                        <= chunk_size
                    ):
                        new_penultimate_chunk = " ".join([new_penultimate_chunk, w])
                    else:
                        end_index = index
                        break
                if end_index < len(words):
                    new_last_chunk = " ".join(words[end_index:]) + " " + new_last_chunk
        if len(new_sentences) == 1:
            chunks.append(penultimate_chunk)
            chunks.append(last_chunk)
        else:
            chunks.append(new_penultimate_chunk)
            chunks.append(new_last_chunk)

    return chunks
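

# Minimal self-contained demo (illustrative sketch, not part of the original
# module). A trivial whitespace "tokenizer" stand-in is used so the demo needs no
# external dependency; real callers would typically pass a subword tokenizer
# (e.g. a Hugging Face tokenizer) instead. Because `logger` is imported relatively
# above, run this as a module inside its package (python -m ...).
if __name__ == "__main__":

    class _WhitespaceTokenizer:
        """Stand-in tokenizer: one token per whitespace-separated word."""

        def encode(self, text, **kwargs):
            # Extra keyword arguments (e.g. add_special_tokens) are ignored.
            return text.split()

    tok = _WhitespaceTokenizer()

    sample = "\n".join(
        f"Paragraph {i}: a few short sentences. Nothing fancy. Just filler text."
        for i in range(1, 6)
    )
    for idx, chunk in enumerate(chunk_context(sample, chunk_size=25, tokenizer=tok)):
        length = get_prompt_length(chunk, tok, no_special_tokens=True)
        print(f"chunk {idx}: {length} tokens: {chunk[:60]!r}")

    # A single over-long paragraph with no sentence punctuation falls back to
    # word-level splitting, and the last two pieces are rebalanced.
    long_paragraph = ("word " * 60).strip()
    pieces = split_into_granular_chunks(long_paragraph, chunk_size=25, tokenizer=tok)
    print("granular piece sizes:", [get_prompt_length(p, tok) for p in pieces])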