# Code modified from https://github.com/thunlp/LLMxMapReduce under Apache 2.0 import re from typing import List from .utils import logger def get_prompt_length(prompt: str, tokenizer, no_special_tokens=False, **kwargs) -> int: """ Returns the token length of a prompt using the given tokenizer. """ if isinstance(prompt, list): prompt = "\n\n".join(prompt) if no_special_tokens: kwargs["add_special_tokens"] = False return len(tokenizer.encode(prompt, **kwargs)) def chunk_context(doc: str, chunk_size: int, tokenizer, separator="\n",) -> List[str]: """ Splits a long document into token-limited chunks based on a separator, ensuring each chunk fits within `chunk_size`. Uses a greedy approach to accumulate text segments (split by `separator`) into chunks that fit within the token limit. If a segment alone exceeds the limit, it is recursively broken down using sentence-level splitting. Attempts to preserve natural boundaries while minimizing excessive chunking. Args: doc (str): Input document to split. chunk_size (int): Maximum number of tokens allowed per chunk. tokenizer: Tokenizer instance with `.encode()` method to compute token length. separator (str): Delimiter to split initial segments (default: newline). Returns: List[str]: List of non-empty, token-constrained document chunks. """ paragraphs = doc.split(separator) paragraphs = [paragraph for paragraph in paragraphs if paragraph] separator_len = get_prompt_length(separator, tokenizer, no_special_tokens=True) docs = [] current_doc = [] total = 0 for paragraph in paragraphs: plen = get_prompt_length(paragraph, tokenizer, no_special_tokens=True) if total + plen + (separator_len if len(current_doc) > 0 else 0) > chunk_size: if total > chunk_size: logger.info( f"Created a chunk of size {total}, " f"which is longer than the specified {chunk_size}" ) # If single chunk is too long, split into more granular if len(current_doc) == 1: split_again = split_into_granular_chunks( current_doc[0], chunk_size, tokenizer ) docs.extend(split_again) current_doc = [] total = 0 if len(current_doc) > 0: doc = separator.join(current_doc) if doc is not None: docs.append(doc) while total > 0 or ( total + plen + (separator_len if len(current_doc) > 0 else 0) > chunk_size and total > 0 ): total -= get_prompt_length( current_doc[0], tokenizer, no_special_tokens=True ) + (separator_len if len(current_doc) > 1 else 0) current_doc = current_doc[1:] current_doc.append(paragraph) total += plen + (separator_len if len(current_doc) > 1 else 0) # Check if the last one exceeds if ( get_prompt_length(current_doc[-1], tokenizer, no_special_tokens=True) > chunk_size and len(current_doc) == 1 ): split_again = split_into_granular_chunks(current_doc[0], chunk_size, tokenizer) docs.extend(split_again) current_doc = [] else: doc = separator.join(current_doc) if doc is not None: docs.append(doc) return [doc for doc in docs if doc.strip()] def split_sentences(text: str, spliter: str): """ Splits text into sentences or segments based on a given delimiter while preserving punctuation. For punctuation-based splitters (e.g., ".", "!", "。"), it interleaves text and punctuation. For space-based splitting, it preserves trailing spaces. Args: text (str): The input text to split. spliter (str): Delimiter regex pattern (e.g., r"([.!?])", r"(。)", or " "). Returns: List[str]: List of split sentence-like segments with punctuation retained. """ # Split by punctuation and keep punctuation text = text.strip() sentence_list = re.split(spliter, text) # Rearrange sentences and punctuation if spliter != " ": sentences = ["".join(i) for i in zip(sentence_list[0::2], sentence_list[1::2])] if len(sentence_list) % 2 != 0 and sentence_list[-1] != "": sentences.append(sentence_list[-1]) else: sentences = [i + " " for i in sentence_list if i != ""] sentences[-1] = sentences[-1].strip() return sentences def split_into_granular_chunks( text: str, chunk_size: int, tokenizer, spliter=r"([。!?;.?!;])", ) -> List[str]: """ Splits long text into granular, token-length-constrained chunks using sentence boundaries. Sentences are first extracted using a delimiter pattern (`spliter`), then grouped into chunks such that each chunk does not exceed the specified `chunk_size` (in tokens). If a chunk still exceeds the limit, it is recursively broken down further using whitespace as a fallback. Ensures that the final chunks are balanced: if the last chunk is too small, it redistributes the last two chunks more evenly by re-splitting and re-allocating their sentences. Args: text (str): Input text to be chunked. chunk_size (int): Maximum number of tokens per chunk. tokenizer: Tokenizer instance with `.encode()` method to compute token length. spliter (str): Regex pattern to split sentences. Returns: List[str]: List of token-limited chunks, each composed of one or more sentences. """ sentences = split_sentences(text, spliter) chunks = [] current_chunk = "" for sentence in sentences: sentence_length = get_prompt_length(sentence, tokenizer) if get_prompt_length(current_chunk, tokenizer) + sentence_length <= chunk_size: current_chunk += sentence else: if current_chunk: if get_prompt_length(current_chunk, tokenizer) <= chunk_size: chunks.append(current_chunk) else: if spliter != " ": # Avoid infinite loops chunks.extend( split_into_granular_chunks( current_chunk, chunk_size=chunk_size, tokenizer=tokenizer, spliter=" ", ) ) current_chunk = sentence if current_chunk != "": if get_prompt_length(current_chunk, tokenizer) <= chunk_size: chunks.append(current_chunk) else: if spliter != " ": # Avoid infinite loops chunks.extend( split_into_granular_chunks( current_chunk, chunk_size=chunk_size, tokenizer=tokenizer, spliter=" ", ) ) # If last chunk too short, re-balance the last two chunks if len(chunks) > 1 and get_prompt_length(chunks[-1], tokenizer) < chunk_size // 2: last_chunk = chunks.pop() penultimate_chunk = chunks.pop() combined_text = penultimate_chunk + last_chunk new_sentences = split_sentences(combined_text, spliter) # Reallocate sentence using double pointer new_penultimate_chunk = "" new_last_chunk = "" start, end = 0, len(new_sentences) - 1 while start <= end and len(new_sentences) != 1: flag = False if ( get_prompt_length( new_penultimate_chunk + new_sentences[start], tokenizer ) <= chunk_size ): flag = True new_penultimate_chunk += new_sentences[start] if start == end: break start += 1 if ( get_prompt_length(new_last_chunk + new_sentences[end], tokenizer) <= chunk_size ): new_last_chunk = new_sentences[end] + new_last_chunk end -= 1 flag = True if flag == False: break if start < end: # If there is any unallocated part, split it by punctuation or space and then allocate it remaining_sentences = new_sentences[start : end + 1] if remaining_sentences: remaining_text = "".join(remaining_sentences) words = remaining_text.split(" ") end_index = len(words) - 1 for index, w in enumerate(words): if ( get_prompt_length( " ".join([new_penultimate_chunk, w]), tokenizer ) <= chunk_size ): new_penultimate_chunk = " ".join([new_penultimate_chunk, w]) else: end_index = index break if end_index != len(words) - 1: new_last_chunk = " ".join(words[end_index:]) + " " + new_last_chunk if len(new_sentences) == 1: chunks.append(penultimate_chunk) chunks.append(last_chunk) else: chunks.append(new_penultimate_chunk) chunks.append(new_last_chunk) return chunks