Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
from typing import Any, Callable, Dict, Iterable, List, Optional | |
from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
from langchain_core.documents import Document | |
from langchain_core.pydantic_v1 import Field | |
from langchain_core.retrievers import BaseRetriever | |
def default_preprocessing_func(text: str) -> List[str]:
    """Tokenize ``text`` by splitting it on runs of whitespace.

    Args:
        text: The raw text to tokenize.

    Returns:
        The whitespace-delimited tokens, in order. Leading, trailing and
        repeated whitespace never produce empty tokens.
    """
    tokens = text.split()
    return tokens
class BM25Retriever(BaseRetriever):
    """`BM25` retriever without Elasticsearch."""

    vectorizer: Any
    """ BM25 vectorizer."""
    docs: List[Document] = Field(repr=False)
    """ List of documents."""
    k: int = 4
    """ Number of documents to return."""
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
    """ Preprocessing function to use on the text before BM25 vectorization."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        """
        Create a BM25Retriever from a list of texts.

        Args:
            texts: The texts to vectorize. Any iterable is accepted, including
                one-shot iterators such as generators.
            metadatas: Metadata dicts to associate with each text, in the same
                order as ``texts``. If omitted (or empty), every document gets
                an empty metadata dict.
            bm25_params: Keyword arguments passed to ``rank_bm25.BM25Okapi``.
            preprocess_func: A function to preprocess each text before
                vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25Retriever instance.

        Raises:
            ImportError: If ``rank_bm25`` is not installed.
            ValueError: If ``metadatas`` is given and its length differs from
                the number of texts.
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError as e:
            raise ImportError(
                "Could not import rank_bm25, please install with `pip install "
                "rank_bm25`."
            ) from e

        # Materialize once: `texts` is iterated twice below (tokenization and
        # Document construction), which would silently yield zero documents
        # if a one-shot iterator were passed.
        texts = list(texts)
        # Keep the original truthiness rule: a missing or empty `metadatas`
        # falls back to one fresh empty dict per text.
        metadata_list = list(metadatas) if metadatas else [{} for _ in texts]
        # `zip` would silently truncate on a length mismatch, leaving `docs`
        # out of sync with the BM25 index so queries would fail later.
        if len(metadata_list) != len(texts):
            raise ValueError(
                f"Got {len(texts)} texts but {len(metadata_list)} metadatas; "
                "they must have the same length."
            )

        texts_processed = [preprocess_func(t) for t in texts]
        bm25_params = bm25_params or {}
        vectorizer = BM25Okapi(texts_processed, **bm25_params)
        docs = [
            Document(page_content=t, metadata=m)
            for t, m in zip(texts, metadata_list)
        ]
        return cls(
            vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        """
        Create a BM25Retriever from a list of Documents.

        Args:
            documents: The Documents to vectorize. Each document's
                ``page_content`` is indexed and its ``metadata`` is carried over.
            bm25_params: Keyword arguments passed to ``rank_bm25.BM25Okapi``.
            preprocess_func: A function to preprocess each text before
                vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25Retriever instance.
        """
        # Split the documents into parallel tuples of contents and metadata.
        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
        return cls.from_texts(
            texts=texts,
            bm25_params=bm25_params,
            metadatas=metadatas,
            preprocess_func=preprocess_func,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the top ``self.k`` documents for ``query`` by BM25 score.

        The query is tokenized with the same ``preprocess_func`` used to build
        the index. ``run_manager`` is part of the ``BaseRetriever`` hook
        signature and is not used here.
        """
        processed_query = self.preprocess_func(query)
        return self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)