api-demo
/
opencompass-my-api
/build
/lib
/opencompass
/openicl
/icl_retriever
/icl_bm25_retriever.py
"""BM25 Retriever.""" | |
from typing import List, Optional | |
import numpy as np | |
from nltk.tokenize import word_tokenize | |
from rank_bm25 import BM25Okapi | |
from tqdm import trange | |
from opencompass.openicl.icl_retriever import BaseRetriever | |
from opencompass.openicl.utils.logging import get_logger | |
from opencompass.registry import ICL_RETRIEVERS | |
logger = get_logger(__name__) | |
class BM25Retriever(BaseRetriever):
    """Retriever that picks in-context examples by Okapi BM25 relevance.

    Okapi BM25 ("BM" stands for "best matching") is a ranking function used
    by search engines to estimate how relevant a document is to a query; see
    https://en.wikipedia.org/wiki/Okapi_BM25 for background. Here every
    tokenized test example is used as a query against a BM25 index built
    from the tokenized index corpus, and the best-scoring index entries
    become that test example's in-context examples.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instance. Its ``reader``,
            ``train`` and ``test`` attributes will be used.
        ice_separator (`Optional[str]`): Separator placed between in-context
            example templates when the origin `PromptTemplate` is provided.
            Defaults to a newline character.
        ice_eos_token (`Optional[str]`): End-of-sentence token for the
            in-context example template when the origin `PromptTemplate` is
            provided. Defaults to a newline character.
        ice_num (`Optional[int]`): Number of in-context examples to retrieve
            per test example. Defaults to 1.

    Attributes:
        bm25: The ``BM25Okapi`` index built over ``index_corpus``.
        index_corpus: Token lists, one per entry of the index dataset.
        test_corpus: Token lists, one per entry of the test dataset.
    """

    # Class-level placeholders; each is replaced per instance in __init__.
    bm25 = None
    index_corpus = None
    test_corpus = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        """Tokenize both corpora and build the BM25 index.

        All arguments are forwarded unchanged to ``BaseRetriever.__init__``.
        """
        # NOTE(review): dataset_reader, index_ds and test_ds are presumably
        # set up by BaseRetriever.__init__ -- confirm against the base class.
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)

        # Candidate pool: the tokenized input fields of the index dataset.
        index_texts = self.dataset_reader.generate_input_field_corpus(
            self.index_ds)
        self.index_corpus = list(map(word_tokenize, index_texts))
        self.bm25 = BM25Okapi(self.index_corpus)

        # Queries: the tokenized input fields of the test dataset.
        test_texts = self.dataset_reader.generate_input_field_corpus(
            self.test_ds)
        self.test_corpus = list(map(word_tokenize, test_texts))

    def retrieve(self) -> List[List]:
        """Retrieve the in-context example indices for each test example.

        Returns:
            One list per test example, holding up to ``ice_num`` positions
            into the index corpus, ordered from highest to lowest BM25
            score.
        """
        logger.info('Retrieving data for test set...')
        retrieved = []
        # Show the progress bar on the main process only.
        for position in trange(len(self.test_corpus),
                               disable=not self.is_main_process):
            relevance = self.bm25.get_scores(self.test_corpus[position])
            # argsort is ascending, so reverse it to put the best first.
            best_first = np.argsort(relevance)[::-1]
            top_hits = best_first[:self.ice_num]
            # Convert NumPy integers to plain ints for downstream consumers.
            retrieved.append([int(doc_id) for doc_id in top_hits])
        return retrieved