from __future__ import annotations

import copy
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
)

from langchain_core.documents import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")

class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: Union[bool, Literal["start", "end"]] = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator and where to place it
                in each corresponding chunk (True='start')
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                every document
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace
    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""
    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    offset = index + previous_chunk_len - self._chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents
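    # Illustrative usage sketch (not part of the original module): the concrete
    # splitter, parameter values, and file name below are assumptions chosen
    # only to show how create_documents and add_start_index interact.
    #
    #     splitter = TokenTextSplitter(
    #         chunk_size=128, chunk_overlap=16, add_start_index=True
    #     )
    #     docs = splitter.create_documents(
    #         ["some long text ..."], metadatas=[{"source": "example.txt"}]
    #     )
    #     # Each Document gets a deep copy of its metadata dict; because
    #     # add_start_index=True, metadata["start_index"] records the chunk's
    #     # character offset within the original text.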
    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text
    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
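    # Worked example for _merge_splits (an illustrative sketch; the values are
    # assumptions, with length_function=len):
    #
    #     chunk_size=10, chunk_overlap=3:
    #     _merge_splits(["one", "two", "three", "four"], " ")
    #     -> ["one two", "two three", "four"]
    #
    # "two" is carried into the second chunk because the while loop above only
    # pops leading pieces until the running total is within the 3-character
    # overlap budget.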
    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)
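    # Usage sketch (assumes the optional tiktoken package is installed; the
    # model name and sizes are placeholders, not recommendations):
    #
    #     splitter = TokenTextSplitter.from_tiktoken_encoder(
    #         model_name="gpt-3.5-turbo", chunk_size=256, chunk_overlap=32
    #     )
    #
    # When cls is TokenTextSplitter, the encoding arguments are also forwarded
    # into its __init__ via extra_kwargs, in addition to being captured by the
    # tiktoken-based length_function.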
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))
class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special
    def split_text(self, text: str) -> List[str]:
        def _encode(_text: str) -> List[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)

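# Usage sketch for TokenTextSplitter (assumes tiktoken is installed; the sizes
# and the variable long_text are placeholders):
#
#     splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=10)
#     chunks = splitter.split_text(long_text)
#
# Note that chunk_size and chunk_overlap are measured in tokens here, since
# split_text delegates to split_text_on_tokens below.
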
class Language(str, Enum):
    """Enum of the programming languages."""

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    KOTLIN = "kotlin"
    JS = "js"
    TS = "ts"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    SOL = "sol"
    CSHARP = "csharp"
    COBOL = "cobol"
    C = "c"
    LUA = "lua"
    PERL = "perl"
    HASKELL = "haskell"

@dataclass
class Tokenizer:
    """Tokenizer data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk"""
    decode: Callable[[List[int]], str]
    """Function to decode a list of token ids to a string"""
    encode: Callable[[str], List[int]]
    """Function to encode a string to a list of token ids"""

def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        if cur_idx == len(input_ids):
            break
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
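
# Minimal runnable sketch (not part of the original module): exercises
# split_text_on_tokens with a toy character-level Tokenizer, so no external
# dependencies are needed. The text and sizes are arbitrary examples.
if __name__ == "__main__":
    _demo_tokenizer = Tokenizer(
        chunk_overlap=3,
        tokens_per_chunk=10,
        decode=lambda ids: "".join(chr(i) for i in ids),
        encode=lambda s: [ord(c) for c in s],
    )
    # Each chunk holds at most 10 "tokens" (characters) and repeats the last
    # 3 characters of the previous chunk as overlap.
    print(split_text_on_tokens(text="abcdefghijklmnop", tokenizer=_demo_tokenizer))
    # -> ['abcdefghij', 'hijklmnop']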