|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List |
|
|
|
import numpy as np |
|
|
|
from camel.loaders import UnstructuredIO |
|
from camel.retrievers import BaseRetriever |
|
from camel.utils import dependencies_required |
|
|
|
# Number of results `BM25Retriever.query` returns when `top_k` is omitted.
DEFAULT_TOP_K_RESULTS: int = 1
|
|
|
|
|
class BM25Retriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `BM25` model.

    This class facilitates the retrieval of relevant information using a
    query-based approach: it ranks documents based on the occurrence and
    frequency of the query terms.

    Attributes:
        bm25 (BM25Okapi): An instance of the BM25Okapi class used for
            calculating document scores. `None` until :meth:`process` has
            successfully indexed some content.
        content_input_path (str): The path to the content that has been
            processed and stored.
        unstructured_modules (UnstructuredIO): A module for parsing files and
            URLs and chunking content based on specified parameters.
        chunks (List[Any]): The chunks produced by the most recent call to
            :meth:`process`; empty until content has been indexed.

    References:
        https://github.com/dorianbrown/rank_bm25
    """

    @dependencies_required('rank_bm25')
    def __init__(self) -> None:
        r"""Initializes the BM25Retriever."""
        from rank_bm25 import BM25Okapi

        self.bm25: BM25Okapi = None
        self.content_input_path: str = ""
        self.unstructured_modules: UnstructuredIO = UnstructuredIO()
        # Initialized here so the attribute always exists, even before
        # `process` runs or when it finds no content.
        self.chunks: List[Any] = []

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        r"""Splits text into tokens on any run of whitespace.

        Used for both the corpus and the query so both sides are tokenized
        identically. Unlike ``split(" ")``, this yields no empty tokens for
        repeated spaces and also splits on newlines and tabs.

        Args:
            text (str): The text to tokenize.

        Returns:
            List[str]: The whitespace-separated tokens of `text`.
        """
        return text.split()

    def process(
        self,
        content_input_path: str,
        chunk_type: str = "chunk_by_title",
        **kwargs: Any,
    ) -> None:
        r"""Processes content from a file or URL, divides it into chunks by
        using `Unstructured IO`, then stores them internally. This method
        must be called before executing queries with the retriever.

        If parsing or chunking yields nothing, the model and stored chunks
        are cleared, so a later :meth:`query` raises instead of returning
        results from previously processed content.

        Args:
            content_input_path (str): File path or URL of the content to be
                processed.
            chunk_type (str): Type of chunking going to apply. Defaults to
                "chunk_by_title".
            **kwargs (Any): Additional keyword arguments for content parsing.
        """
        from rank_bm25 import BM25Okapi

        elements = self.unstructured_modules.parse_file_or_url(
            content_input_path, **kwargs
        )
        # Record the path only after parsing succeeded, so a failed parse
        # does not leave the path out of sync with the indexed chunks.
        self.content_input_path = content_input_path

        chunks = (
            self.unstructured_modules.chunk_elements(
                chunk_type=chunk_type, elements=elements
            )
            if elements
            else []
        )
        if not chunks:
            # BM25Okapi cannot be built from an empty corpus, and stale
            # chunks from an earlier input must not survive.
            self.bm25 = None
            self.chunks = []
            return

        self.chunks = chunks
        tokenized_corpus = [
            self._tokenize(str(chunk)) for chunk in self.chunks
        ]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def query(
        self,
        query: str,
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Executes a query and compiles the results.

        Args:
            query (str): Query string for information retrieval.
            top_k (int, optional): The number of top results to return during
                retrieval. Must be a positive integer. If it exceeds the
                number of stored chunks, every chunk is returned. Defaults to
                `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: The best-matching chunks, ordered from the
                highest to the lowest BM25 score. Each dict has the keys
                ``'similarity score'``, ``'content path'``, ``'metadata'``
                and ``'text'``.

        Raises:
            ValueError: If `top_k` is less than or equal to 0, or if the BM25
                model has not been initialized by calling `process` first.
        """
        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")
        if self.bm25 is None or not self.chunks:
            raise ValueError(
                "BM25 model is not initialized. Call `process` first."
            )

        scores = self.bm25.get_scores(self._tokenize(query))

        # `np.argpartition` raises when kth is out of bounds, so never ask
        # for more results than there are scored chunks.
        k = min(top_k, len(scores))
        top_k_indices = np.argpartition(scores, -k)[-k:]

        formatted_results = [
            {
                'similarity score': scores[i],
                'content path': self.content_input_path,
                'metadata': self.chunks[i].metadata.to_dict(),
                'text': str(self.chunks[i]),
            }
            for i in top_k_indices
        ]

        # argpartition leaves the selected top-k indices in arbitrary order.
        formatted_results.sort(
            key=lambda x: x['similarity score'], reverse=True
        )
        return formatted_results
|
|