import re
from typing import List, Union

from .utils import logger


def get_prompt_length(
    prompt: Union[str, List[str]], tokenizer, no_special_tokens: bool = False, **kwargs
) -> int:
    """
    Returns the token length of a prompt under the given tokenizer.

    Accepts either a single string or a list of strings; a list is joined
    with blank lines before encoding. If `no_special_tokens` is True,
    special tokens (e.g., BOS/EOS) are excluded from the count.
    """
    if isinstance(prompt, list):
        prompt = "\n\n".join(prompt)
    if no_special_tokens:
        kwargs["add_special_tokens"] = False
    return len(tokenizer.encode(prompt, **kwargs))
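

# Example (illustrative sketch; assumes a Hugging Face-style tokenizer whose
# `.encode()` accepts an `add_special_tokens` keyword, e.g. one created with
# `AutoTokenizer.from_pretrained("gpt2")`):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   get_prompt_length("Hello world!", tokenizer)
#   get_prompt_length(["part one", "part two"], tokenizer, no_special_tokens=True)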


def chunk_context(
    doc: str, chunk_size: int, tokenizer, separator: str = "\n"
) -> List[str]:
    """
    Splits a long document into token-limited chunks based on a separator,
    ensuring each chunk fits within `chunk_size`.

    Uses a greedy approach to accumulate text segments (split by `separator`)
    into chunks that fit within the token limit. If a segment alone exceeds
    the limit, it is broken down further with sentence-level splitting via
    `split_into_granular_chunks`. Natural boundaries are preserved where
    possible to avoid excessive fragmentation.

    Args:
        doc (str): Input document to split.
        chunk_size (int): Maximum number of tokens allowed per chunk.
        tokenizer: Tokenizer instance with an `.encode()` method used to
            compute token lengths.
        separator (str): Delimiter used to split the initial segments
            (default: newline).

    Returns:
        List[str]: List of non-empty, token-constrained document chunks.
    """
    paragraphs = [paragraph for paragraph in doc.split(separator) if paragraph]
    separator_len = get_prompt_length(separator, tokenizer, no_special_tokens=True)

    docs = []
    current_doc = []
    total = 0
    for paragraph in paragraphs:
        plen = get_prompt_length(paragraph, tokenizer, no_special_tokens=True)
        if total + plen + (separator_len if len(current_doc) > 0 else 0) > chunk_size:
            if total > chunk_size:
                logger.info(
                    f"Created a chunk of size {total}, "
                    f"which is longer than the specified {chunk_size}"
                )

            # A lone accumulated segment that is itself too long is re-split
            # at sentence granularity.
            if len(current_doc) == 1:
                docs.extend(
                    split_into_granular_chunks(current_doc[0], chunk_size, tokenizer)
                )
                current_doc = []
                total = 0

            if len(current_doc) > 0:
                docs.append(separator.join(current_doc))
                # No overlap is carried between chunks, so the accumulator is
                # drained completely before the next chunk starts.
                current_doc = []
                total = 0

        current_doc.append(paragraph)
        total += plen + (separator_len if len(current_doc) > 1 else 0)

    # Flush whatever remains in the accumulator after the last paragraph.
    if current_doc:
        if (
            len(current_doc) == 1
            and get_prompt_length(current_doc[0], tokenizer, no_special_tokens=True)
            > chunk_size
        ):
            docs.extend(
                split_into_granular_chunks(current_doc[0], chunk_size, tokenizer)
            )
        else:
            docs.append(separator.join(current_doc))

    return [chunk for chunk in docs if chunk.strip()]
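

# Example (illustrative sketch; `tokenizer` as above):
#
#   long_doc = "\n".join(f"Paragraph {i}. " + "lorem ipsum " * 50 for i in range(20))
#   chunks = chunk_context(long_doc, chunk_size=256, tokenizer=tokenizer)
#   # Each chunk stays within 256 tokens, except where a single indivisible
#   # word already exceeds the limit.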


def split_sentences(text: str, spliter: str) -> List[str]:
    """
    Splits text into sentences or segments based on a given delimiter while
    preserving punctuation.

    For punctuation-based splitters (e.g., r"([.!?])" or r"(。)"), the pattern
    must capture the delimiter so that `re.split` interleaves text and
    punctuation, which are then stitched back together pairwise. For
    space-based splitting, a trailing space is kept on every segment except
    the last.

    Args:
        text (str): The input text to split.
        spliter (str): Delimiter regex pattern (e.g., r"([.!?])", r"(。)",
            or " ").

    Returns:
        List[str]: Split sentence-like segments with punctuation retained.
    """
    text = text.strip()
    sentence_list = re.split(spliter, text)

    if spliter != " ":
        # With a capturing group, `re.split` yields alternating
        # [text, punctuation, text, punctuation, ...]; rejoin the pairs so
        # each sentence keeps its terminal punctuation.
        sentences = [
            "".join(pair) for pair in zip(sentence_list[0::2], sentence_list[1::2])
        ]
        if len(sentence_list) % 2 != 0 and sentence_list[-1] != "":
            # An odd length means a trailing fragment without punctuation.
            sentences.append(sentence_list[-1])
    else:
        sentences = [word + " " for word in sentence_list if word != ""]
        if sentences:
            sentences[-1] = sentences[-1].strip()
    return sentences
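

# Example (illustrative sketch):
#
#   split_sentences("Hi there! How are you?", r"([.!?])")
#   # -> ["Hi there!", " How are you?"]
#   split_sentences("a b  c", " ")
#   # -> ["a ", "b ", "c"]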


def split_into_granular_chunks(
    text: str, chunk_size: int, tokenizer, spliter: str = r"([。!?;.?!;])"
) -> List[str]:
    """
    Splits long text into granular, token-length-constrained chunks using
    sentence boundaries.

    Sentences are first extracted using a delimiter pattern (`spliter`), then
    greedily grouped into chunks so that each chunk stays within `chunk_size`
    tokens. If a chunk still exceeds the limit, it is recursively broken down
    further using whitespace as a fallback delimiter.

    Finally, the chunks are balanced: if the last chunk is much smaller than
    `chunk_size`, the last two chunks are re-split and their sentences
    redistributed more evenly.

    Args:
        text (str): Input text to be chunked.
        chunk_size (int): Maximum number of tokens per chunk.
        tokenizer: Tokenizer instance with an `.encode()` method used to
            compute token lengths.
        spliter (str): Regex pattern used to split sentences.

    Returns:
        List[str]: Token-limited chunks, each composed of one or more
            sentences.
    """
    sentences = split_sentences(text, spliter)

    chunks = []
    current_chunk = ""

    for sentence in sentences:
        sentence_length = get_prompt_length(sentence, tokenizer)

        if get_prompt_length(current_chunk, tokenizer) + sentence_length <= chunk_size:
            current_chunk += sentence
        else:
            if current_chunk:
                if get_prompt_length(current_chunk, tokenizer) <= chunk_size:
                    chunks.append(current_chunk)
                elif spliter != " ":
                    # An oversized sentence-level chunk is re-split at word
                    # granularity.
                    chunks.extend(
                        split_into_granular_chunks(
                            current_chunk,
                            chunk_size=chunk_size,
                            tokenizer=tokenizer,
                            spliter=" ",
                        )
                    )
                else:
                    # Already at word granularity; keep the oversized piece
                    # rather than dropping it silently.
                    chunks.append(current_chunk)
            current_chunk = sentence

    if current_chunk != "":
        if get_prompt_length(current_chunk, tokenizer) <= chunk_size:
            chunks.append(current_chunk)
        elif spliter != " ":
            chunks.extend(
                split_into_granular_chunks(
                    current_chunk,
                    chunk_size=chunk_size,
                    tokenizer=tokenizer,
                    spliter=" ",
                )
            )
        else:
            chunks.append(current_chunk)

    # If the final chunk came out much smaller than the limit, rebalance the
    # last two chunks so their sizes are more even.
    if len(chunks) > 1 and get_prompt_length(chunks[-1], tokenizer) < chunk_size // 2:
        last_chunk = chunks.pop()
        penultimate_chunk = chunks.pop()
        combined_text = penultimate_chunk + last_chunk

        new_sentences = split_sentences(combined_text, spliter)

        new_penultimate_chunk = ""
        new_last_chunk = ""
        start, end = 0, len(new_sentences) - 1

        # Grow the two chunks from opposite ends of the sentence list until
        # neither side can absorb another sentence within chunk_size.
        while start <= end and len(new_sentences) != 1:
            flag = False
            if (
                get_prompt_length(
                    new_penultimate_chunk + new_sentences[start], tokenizer
                )
                <= chunk_size
            ):
                flag = True
                new_penultimate_chunk += new_sentences[start]
                start += 1
                if start > end:
                    break
            if (
                get_prompt_length(new_last_chunk + new_sentences[end], tokenizer)
                <= chunk_size
            ):
                new_last_chunk = new_sentences[end] + new_last_chunk
                end -= 1
                flag = True
            if not flag:
                # Neither side can take another whole sentence.
                break

        if start <= end and len(new_sentences) != 1:
            # Sentences left in the middle could not be placed whole;
            # distribute their words between the two chunks.
            remaining_sentences = new_sentences[start : end + 1]
            if remaining_sentences:
                remaining_text = "".join(remaining_sentences)
                words = remaining_text.split(" ")
                end_index = len(words)
                for index, w in enumerate(words):
                    if (
                        get_prompt_length(
                            " ".join([new_penultimate_chunk, w]), tokenizer
                        )
                        <= chunk_size
                    ):
                        new_penultimate_chunk = " ".join([new_penultimate_chunk, w])
                    else:
                        end_index = index
                        break
                if end_index != len(words):
                    new_last_chunk = " ".join(words[end_index:]) + " " + new_last_chunk

        if len(new_sentences) == 1:
            # Re-splitting produced a single indivisible sentence; keep the
            # original pair unchanged.
            chunks.append(penultimate_chunk)
            chunks.append(last_chunk)
        else:
            chunks.append(new_penultimate_chunk)
            chunks.append(new_last_chunk)

    return chunks
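

# Example (illustrative sketch; `tokenizer` as above):
#
#   text = "First sentence. Second sentence. " * 40
#   pieces = split_into_granular_chunks(text, chunk_size=64, tokenizer=tokenizer)
#   # Each piece stays within 64 tokens and ends on a sentence boundary where
#   # possible; the last two pieces are rebalanced if the tail is very short.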