# Code modified from https://github.com/thunlp/LLMxMapReduce under Apache 2.0
import re
from typing import List

from .utils import logger


def get_prompt_length(prompt: str, tokenizer, no_special_tokens=False, **kwargs) -> int:
    """
    Returns the token length of a prompt using the given tokenizer.

    Accepts either a single string or a list of strings; a list is joined with a
    blank-line separator before encoding.
    """
    if isinstance(prompt, list):
        prompt = "\n\n".join(prompt)
    if no_special_tokens:
        kwargs["add_special_tokens"] = False
    return len(tokenizer.encode(prompt, **kwargs))
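
# Usage sketch (illustrative addition, not part of the original module). It assumes a
# Hugging Face tokenizer such as ``transformers.AutoTokenizer``; any object exposing
# ``encode(text, **kwargs)`` works the same way:
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     get_prompt_length("Hello world", tok)                  # with special tokens (if any)
#     get_prompt_length(["part one", "part two"], tok,
#                       no_special_tokens=True)              # list joined by blank lines
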

def chunk_context(doc: str, chunk_size: int, tokenizer, separator="\n") -> List[str]:
    """
    Splits a long document into token-limited chunks based on a separator, ensuring
    each chunk fits within `chunk_size`.

    Uses a greedy approach to accumulate text segments (split by `separator`) into
    chunks that fit within the token limit. If a segment alone exceeds the limit, it
    is recursively broken down using sentence-level splitting. Attempts to preserve
    natural boundaries while minimizing excessive chunking.

    Args:
        doc (str): Input document to split.
        chunk_size (int): Maximum number of tokens allowed per chunk.
        tokenizer: Tokenizer instance with `.encode()` method to compute token length.
        separator (str): Delimiter to split initial segments (default: newline).

    Returns:
        List[str]: List of non-empty, token-constrained document chunks.
    """
    paragraphs = doc.split(separator)
    paragraphs = [paragraph for paragraph in paragraphs if paragraph]
    separator_len = get_prompt_length(separator, tokenizer, no_special_tokens=True)

    docs = []
    current_doc = []
    total = 0
    for paragraph in paragraphs:
        plen = get_prompt_length(paragraph, tokenizer, no_special_tokens=True)
        if total + plen + (separator_len if len(current_doc) > 0 else 0) > chunk_size:
            if total > chunk_size:
                logger.info(
                    f"Created a chunk of size {total}, "
                    f"which is longer than the specified {chunk_size}"
                )
                # If the over-long accumulation is a single paragraph, split it
                # into more granular (sentence-level) chunks and start over
                if len(current_doc) == 1:
                    split_again = split_into_granular_chunks(
                        current_doc[0], chunk_size, tokenizer
                    )
                    docs.extend(split_again)
                    current_doc = []
                    total = 0
            if len(current_doc) > 0:
                doc = separator.join(current_doc)
                if doc is not None:
                    docs.append(doc)
                # Drain the accumulator (no chunk overlap is kept between chunks)
                while total > 0 or (
                    total + plen + (separator_len if len(current_doc) > 0 else 0)
                    > chunk_size
                    and total > 0
                ):
                    total -= get_prompt_length(
                        current_doc[0], tokenizer, no_special_tokens=True
                    ) + (separator_len if len(current_doc) > 1 else 0)
                    current_doc = current_doc[1:]
        current_doc.append(paragraph)
        total += plen + (separator_len if len(current_doc) > 1 else 0)

    # Check whether the last accumulated chunk still exceeds the limit
    if (
        get_prompt_length(current_doc[-1], tokenizer, no_special_tokens=True)
        > chunk_size
        and len(current_doc) == 1
    ):
        split_again = split_into_granular_chunks(current_doc[0], chunk_size, tokenizer)
        docs.extend(split_again)
        current_doc = []
    else:
        doc = separator.join(current_doc)
        if doc is not None:
            docs.append(doc)
    return [doc for doc in docs if doc.strip()]
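
# Usage sketch (illustrative addition, not part of the original module), reusing the
# hypothetical ``tok`` tokenizer from the sketch above; ``report`` stands for any long,
# newline-separated document string:
#
#     chunks = chunk_context(report, chunk_size=512, tokenizer=tok)
#     for i, c in enumerate(chunks):
#         print(i, get_prompt_length(c, tok, no_special_tokens=True))
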

def split_sentences(text: str, spliter: str):
    """
    Splits text into sentences or segments based on a given delimiter while
    preserving punctuation.

    For punctuation-based splitters (e.g., ".", "!", "。"), it interleaves text and
    punctuation. For space-based splitting, it preserves trailing spaces.

    Args:
        text (str): The input text to split.
        spliter (str): Delimiter regex pattern (e.g., r"([.!?])", r"(。)", or " ").

    Returns:
        List[str]: List of split sentence-like segments with punctuation retained.
    """
    # Split by punctuation and keep punctuation
    text = text.strip()
    sentence_list = re.split(spliter, text)
    # Rearrange sentences and punctuation
    if spliter != " ":
        sentences = ["".join(i) for i in zip(sentence_list[0::2], sentence_list[1::2])]
        if len(sentence_list) % 2 != 0 and sentence_list[-1] != "":
            sentences.append(sentence_list[-1])
    else:
        sentences = [i + " " for i in sentence_list if i != ""]
        sentences[-1] = sentences[-1].strip()
    return sentences
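
# Behavior sketch (illustrative addition, not part of the original module):
#
#     split_sentences("Hello world. How are you? Fine!", r"([.!?])")
#     # -> ['Hello world.', ' How are you?', ' Fine!']
#     split_sentences("one two  three", " ")
#     # -> ['one ', 'two ', 'three']
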

def split_into_granular_chunks(
    text: str, chunk_size: int, tokenizer, spliter=r"([。!?;.?!;])"
) -> List[str]:
    """
    Splits long text into granular, token-length-constrained chunks using sentence
    boundaries.

    Sentences are first extracted using a delimiter pattern (`spliter`), then grouped
    into chunks such that each chunk does not exceed the specified `chunk_size` (in
    tokens). If a chunk still exceeds the limit, it is recursively broken down further
    using whitespace as a fallback.

    Ensures that the final chunks are balanced: if the last chunk is too small, it
    redistributes the last two chunks more evenly by re-splitting and re-allocating
    their sentences.

    Args:
        text (str): Input text to be chunked.
        chunk_size (int): Maximum number of tokens per chunk.
        tokenizer: Tokenizer instance with `.encode()` method to compute token length.
        spliter (str): Regex pattern to split sentences.

    Returns:
        List[str]: List of token-limited chunks, each composed of one or more sentences.
    """
    sentences = split_sentences(text, spliter)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        sentence_length = get_prompt_length(sentence, tokenizer)
        if get_prompt_length(current_chunk, tokenizer) + sentence_length <= chunk_size:
            current_chunk += sentence
        else:
            if current_chunk:
                if get_prompt_length(current_chunk, tokenizer) <= chunk_size:
                    chunks.append(current_chunk)
                else:
                    if spliter != " ":  # Avoid infinite loops
                        chunks.extend(
                            split_into_granular_chunks(
                                current_chunk,
                                chunk_size=chunk_size,
                                tokenizer=tokenizer,
                                spliter=" ",
                            )
                        )
            current_chunk = sentence
    if current_chunk != "":
        if get_prompt_length(current_chunk, tokenizer) <= chunk_size:
            chunks.append(current_chunk)
        else:
            if spliter != " ":  # Avoid infinite loops
                chunks.extend(
                    split_into_granular_chunks(
                        current_chunk,
                        chunk_size=chunk_size,
                        tokenizer=tokenizer,
                        spliter=" ",
                    )
                )

    # If the last chunk is too short, re-balance the last two chunks
    if len(chunks) > 1 and get_prompt_length(chunks[-1], tokenizer) < chunk_size // 2:
        last_chunk = chunks.pop()
        penultimate_chunk = chunks.pop()
        combined_text = penultimate_chunk + last_chunk
        new_sentences = split_sentences(combined_text, spliter)

        # Reallocate sentences with two pointers, filling the penultimate chunk
        # from the front and the last chunk from the back
        new_penultimate_chunk = ""
        new_last_chunk = ""
        start, end = 0, len(new_sentences) - 1
        while start <= end and len(new_sentences) != 1:
            flag = False
            if (
                get_prompt_length(
                    new_penultimate_chunk + new_sentences[start], tokenizer
                )
                <= chunk_size
            ):
                flag = True
                new_penultimate_chunk += new_sentences[start]
                if start == end:
                    break
                start += 1
            if (
                get_prompt_length(new_last_chunk + new_sentences[end], tokenizer)
                <= chunk_size
            ):
                new_last_chunk = new_sentences[end] + new_last_chunk
                end -= 1
                flag = True
            if not flag:
                break

        if start < end:
            # If any part remains unallocated, split it on spaces and distribute it
            remaining_sentences = new_sentences[start : end + 1]
            if remaining_sentences:
                remaining_text = "".join(remaining_sentences)
                words = remaining_text.split(" ")
                end_index = len(words) - 1
                for index, w in enumerate(words):
                    if (
                        get_prompt_length(
                            " ".join([new_penultimate_chunk, w]), tokenizer
                        )
                        <= chunk_size
                    ):
                        new_penultimate_chunk = " ".join([new_penultimate_chunk, w])
                    else:
                        end_index = index
                        break
                if end_index != len(words) - 1:
                    new_last_chunk = " ".join(words[end_index:]) + " " + new_last_chunk

        if len(new_sentences) == 1:
            chunks.append(penultimate_chunk)
            chunks.append(last_chunk)
        else:
            chunks.append(new_penultimate_chunk)
            chunks.append(new_last_chunk)
    return chunks
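
# Usage sketch (illustrative addition, not part of the original module): break an
# over-long single paragraph into sentence-level pieces, again assuming the
# hypothetical Hugging Face tokenizer ``tok`` from the first sketch; ``long_paragraph``
# is a placeholder string:
#
#     pieces = split_into_granular_chunks(long_paragraph, chunk_size=128, tokenizer=tok)
#     for p in pieces:
#         print(get_prompt_length(p, tok), repr(p[:40]))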