|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import os
from typing import OrderedDict

import numpy as np
from tqdm import tqdm

from pyserini.search import FaissSearcher, BinaryDenseSearcher, TctColBertQueryEncoder, QueryEncoder, \
    DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, AnceQueryEncoder, AggretrieverQueryEncoder, \
    AutoQueryEncoder, DenseVectorAveragePrf, DenseVectorRocchioPrf, DenseVectorAncePrf
from pyserini.encode import PcaEncoder
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.output_writer import get_output_writer, OutputFormat
from pyserini.search.lucene import LuceneSearcher
|
|
|
|
|
|
|
|
|
|
|
# Presumably lets the Intel OpenMP runtime tolerate being loaded more than once in this
# process (e.g. by both the Faiss and PyTorch native libraries) -- TODO confirm still needed.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})
|
|
|
|
|
def define_dsearch_args(parser):
    """Register the dense-search command-line options on ``parser``.

    Covers the index, the query encoder (or pre-encoded queries), the searcher type,
    the pseudo-relevance-feedback (PRF) settings and HNSW tuning.

    Args:
        parser: An ``argparse.ArgumentParser`` to which the options are added.
    """
    parser.add_argument('--index', type=str, metavar='path to index or index name', required=True,
                        help="Path to Faiss index or name of prebuilt index.")
    parser.add_argument('--encoder-class', type=str,
                        metavar='which query encoder class to use. `default` would infer from the args.encoder',
                        required=False,
                        choices=["dkrr", "dpr", "bpr", "tct_colbert", "ance", "sentence", "contriever", "auto",
                                 "aggretriever"],
                        default=None,
                        help='Which query encoder class to use. If omitted, the class is inferred from --encoder.')
    parser.add_argument('--encoder', type=str, metavar='path to query encoder checkpoint or encoder name',
                        required=False,
                        help="Path to query encoder pytorch checkpoint or hgf encoder model name")
    parser.add_argument('--tokenizer', type=str, metavar='name or path',
                        required=False,
                        help="Hugging Face tokenizer name or path.")
    parser.add_argument('--encoded-queries', type=str, metavar='path to query encoded queries dir or queries name',
                        required=False,
                        help="Path to a directory of pre-encoded queries, or the name of a pre-encoded queries set.")
    parser.add_argument('--pca-model', type=str, metavar='path', required=False,
                        default=None, help="Path to a faiss pca model")
    parser.add_argument('--device', type=str, metavar='device to run query encoder', required=False, default='cpu',
                        help="Device to run query encoder, cpu or [cuda:0, cuda:1, ...]")
    parser.add_argument('--query-prefix', type=str, metavar='str', required=False, default=None,
                        help="Query prefix if exists.")
    # Only 'bpr' is special-cased by the search script; every other value means the Faiss searcher.
    parser.add_argument('--searcher', type=str, metavar='str', required=False, default='simple',
                        help="Dense searcher type: 'bpr' for the binary dense searcher; "
                             "any other value uses the Faiss searcher.")

    # Pseudo-relevance feedback (PRF) options.
    parser.add_argument('--prf-depth', type=int, metavar='num of passages used for PRF', required=False, default=0,
                        help="Specify how many passages are used for PRF, 0: Simple retrieval with no PRF, "
                             "> 0: perform PRF")
    parser.add_argument('--prf-method', type=str, metavar='avg, rocchio or ance-prf', required=False, default='avg',
                        help="Choose PRF method: avg, rocchio or ance-prf (ance-prf requires an ANCE query encoder).")
    parser.add_argument('--rocchio-alpha', type=float, metavar='alpha parameter for rocchio', required=False,
                        default=0.9,
                        help="The alpha parameter to control the contribution from the query vector")
    parser.add_argument('--rocchio-beta', type=float, metavar='beta parameter for rocchio', required=False,
                        default=0.1,
                        help="The beta parameter to control the contribution from the average vector of the "
                             "positive PRF passages")
    parser.add_argument('--rocchio-gamma', type=float, metavar='gamma parameter for rocchio', required=False,
                        default=0.1,
                        help="The gamma parameter to control the contribution from the average vector of the "
                             "negative PRF passages")
    parser.add_argument('--rocchio-topk', type=int, metavar='topk passages as positive for rocchio', required=False,
                        default=3,
                        help="Set topk passages as positive PRF passages for rocchio")
    parser.add_argument('--rocchio-bottomk', type=int, metavar='bottomk passages as negative for rocchio',
                        required=False, default=0,
                        help="Set bottomk passages as negative PRF passages for rocchio, "
                             "0: do not use negatives prf passages.")
    parser.add_argument('--sparse-index', type=str, metavar='sparse lucene index containing contents', required=False,
                        help='Path or prebuilt name of a sparse (Lucene) index containing the passage contents; '
                             'required by ance-prf.')
    parser.add_argument('--ance-prf-encoder', type=str, metavar='query encoder path for ANCE-PRF', required=False,
                        help='Path or name of the ANCE-PRF query encoder checkpoint; used by ance-prf.')

    # HNSW tuning.
    parser.add_argument('--ef-search', type=int, metavar='efSearch for HNSW index', required=False, default=None,
                        help="Set efSearch for HNSW index")
|
|
|
|
|
def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, encoded_queries, device, prefix):
    """Create the query encoder used for dense search.

    Resolution order:

    1. ``encoder`` is given: instantiate an encoder class on the fly. The class is
       ``encoder_class`` when provided; otherwise the first key of the class map
       that occurs in the lower-cased ``encoder`` name; otherwise
       ``AutoQueryEncoder``. Names/classes mentioning ``sentence`` or
       ``contriever`` additionally get mean pooling.
    2. ``encoded_queries`` is given: load pre-encoded queries, from disk when the
       value is an existing path and by name otherwise. Values containing
       ``'bpr'`` are loaded through ``BprQueryEncoder``.
    3. Otherwise: load the pre-encoded queries registered for ``topics_name``.

    Args:
        encoder: Query encoder checkpoint path or model name, or None.
        encoder_class: Key of the encoder class map (e.g. ``'dpr'``), or None to infer it.
        tokenizer_name: Tokenizer name or path, forwarded to the encoder class.
        topics_name: Topics name; only used for the step-3 fallback.
        encoded_queries: Path or name of pre-encoded queries, or None.
        device: Device string forwarded to the encoder class.
        prefix: Query prefix forwarded to the encoder class.

    Returns:
        A query encoder instance.

    Raises:
        KeyError: ``encoder_class`` is not a key of the encoder class map.
        ValueError: none of the three sources above applies.
    """
    # Topics with pre-encoded queries available by name (used by step 3).
    encoded_queries_map = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
        'dpr-nq-test': 'dpr_multi-nq-test',
        'dpr-trivia-dev': 'dpr_multi-trivia-dev',
        'dpr-trivia-test': 'dpr_multi-trivia-test',
        'dpr-wq-test': 'dpr_multi-wq-test',
        'dpr-squad-test': 'dpr_multi-squad-test',
        'dpr-curated-test': 'dpr_multi-curated-test'
    }
    # Insertion order matters for name-based inference: the first matching keyword wins,
    # and a name may contain several keywords (e.g. both 'dkrr' and 'dpr').
    encoder_class_map = {
        "dkrr": DkrrDprQueryEncoder,
        "dpr": DprQueryEncoder,
        "bpr": BprQueryEncoder,
        "tct_colbert": TctColBertQueryEncoder,
        "ance": AnceQueryEncoder,
        "sentence": AutoQueryEncoder,
        "contriever": AutoQueryEncoder,
        "aggretriever": AggretrieverQueryEncoder,
        "auto": AutoQueryEncoder,
    }

    if encoder:
        encoder_name = encoder.lower()
        if encoder_class is not None:
            query_encoder_cls = encoder_class_map[encoder_class]
        else:
            query_encoder_cls = next(
                (cls for keyword, cls in encoder_class_map.items() if keyword in encoder_name),
                AutoQueryEncoder)

        kwargs = dict(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device, prefix=prefix)
        # Pooling follows the explicit class key or the encoder name. The name test is
        # case-insensitive, matching the class inference above, so a name such as
        # '.../Contriever-msmarco' also gets mean pooling. A later match overrides an earlier one.
        if encoder_class == "sentence" or "sentence" in encoder_name:
            kwargs.update(pooling='mean', l2_norm=True)
        if encoder_class == "contriever" or "contriever" in encoder_name:
            kwargs.update(pooling='mean', l2_norm=False)
        return query_encoder_cls(**kwargs)

    if encoded_queries:
        is_bpr = 'bpr' in encoded_queries
        if os.path.exists(encoded_queries):
            if is_bpr:
                return BprQueryEncoder(encoded_query_dir=encoded_queries)
            return QueryEncoder(encoded_queries)
        if is_bpr:
            return BprQueryEncoder.load_encoded_queries(encoded_queries)
        return QueryEncoder.load_encoded_queries(encoded_queries)

    if topics_name in encoded_queries_map:
        return QueryEncoder.load_encoded_queries(encoded_queries_map[topics_name])
    raise ValueError(f'No encoded queries for topic {topics_name}')
|
|
|
|
|
def _init_searcher(args, query_encoder):
    """Create the dense searcher requested on the command line.

    ``args.searcher == 'bpr'`` (case-insensitive) selects ``BinaryDenseSearcher``;
    any other value selects ``FaissSearcher``. ``args.index`` is opened as a local
    index when the path exists, otherwise as a prebuilt index name.

    Returns:
        ``(searcher, search_kwargs)`` where ``search_kwargs`` are extra keyword
        arguments for every ``search``/``batch_search`` call (non-empty only for BPR).
        ``searcher`` may be falsy when a prebuilt index cannot be loaded (presumably
        ``from_prebuilt_index`` returns None on failure -- confirm); callers must check.
    """
    if args.searcher.lower() == 'bpr':
        searcher_cls = BinaryDenseSearcher
        search_kwargs = dict(binary_k=args.binary_hits, rerank=args.rerank)
    else:
        searcher_cls = FaissSearcher
        search_kwargs = {}

    if os.path.exists(args.index):
        searcher = searcher_cls(args.index, query_encoder)
    else:
        searcher = searcher_cls.from_prebuilt_index(args.index, query_encoder)
    return searcher, search_kwargs


def _init_prf_rule(args, query_encoder):
    """Build the dense pseudo-relevance-feedback rule named by ``args.prf_method``.

    Supported methods (case-insensitive): ``avg``, ``rocchio`` and ``ance-prf``.

    Raises:
        ValueError: unknown method, or ``ance-prf`` requested without an ANCE query
            encoder or without ``--sparse-index``.
    """
    method = args.prf_method.lower()
    if method == 'avg':
        return DenseVectorAveragePrf()
    if method == 'rocchio':
        return DenseVectorRocchioPrf(args.rocchio_alpha, args.rocchio_beta, args.rocchio_gamma,
                                     args.rocchio_topk, args.rocchio_bottomk)
    if method == 'ance-prf':
        # Exact type check on purpose: a wrapped encoder (e.g. PcaEncoder) is not accepted.
        if type(query_encoder) is not AnceQueryEncoder:
            raise ValueError('PRF method ance-prf requires an ANCE query encoder.')
        if not args.sparse_index:
            raise ValueError('PRF method ance-prf requires --sparse-index.')
        # ANCE-PRF works on passage text, read from a sparse (Lucene) index given by path or prebuilt name.
        if os.path.exists(args.sparse_index):
            sparse_searcher = LuceneSearcher(args.sparse_index)
        else:
            sparse_searcher = LuceneSearcher.from_prebuilt_index(args.sparse_index)
        prf_query_encoder = AnceQueryEncoder(encoder_dir=args.ance_prf_encoder, tokenizer_name=args.tokenizer,
                                             device=args.device)
        return DenseVectorAncePrf(prf_query_encoder, sparse_searcher)
    raise ValueError(f'Unknown PRF method: {args.prf_method}. Choose from avg, rocchio, ance-prf.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Search a Faiss index.')
    parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
                        help="Name of topics. Available: msmarco-passage-dev-subset.")
    parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.")
    parser.add_argument('--binary-hits', type=int, metavar='num', required=False, default=1000,
                        help="Number of binary hits.")
    parser.add_argument("--rerank", action="store_true", help='Whether to rerank bpr sparse results.')
    parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
                        help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
    parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
                        help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
    parser.add_argument('--output', type=str, metavar='path', required=True, help="Path to output file.")
    parser.add_argument('--max-passage', action='store_true',
                        default=False, help="Select only max passage from document.")
    parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
                        help="Final number of hits when selecting only max passage.")
    parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#',
                        help="Delimiter between docid and passage id.")
    parser.add_argument('--batch-size', type=int, metavar='num', required=False, default=1,
                        help="search batch of queries in parallel")
    parser.add_argument('--threads', type=int, metavar='num', required=False, default=1,
                        help="maximum threads to use during search")
    parser.add_argument('--remove-query', action='store_true', default=False, help="Remove query from results list.")
    define_dsearch_args(parser)
    args = parser.parse_args()

    query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format))
    topics = query_iterator.topics
    num_topics = len(topics.keys())

    query_encoder = init_query_encoder(
        args.encoder, args.encoder_class, args.tokenizer, args.topics, args.encoded_queries, args.device,
        args.query_prefix)
    if args.pca_model:
        query_encoder = PcaEncoder(query_encoder, args.pca_model)

    searcher, kwargs = _init_searcher(args, query_encoder)
    # Validate before any use of the searcher, and exit non-zero on failure.
    if not searcher:
        raise SystemExit(f'Unable to initialize a searcher for index {args.index}.')
    if args.ef_search:
        searcher.set_hnsw_ef_search(args.ef_search)

    # PRF is only wired up for the plain FaissSearcher (exact type; subclasses are excluded).
    use_prf = args.prf_depth > 0 and type(searcher) is FaissSearcher
    prf_method = args.prf_method.lower()
    prf_rule = None
    if use_prf:
        try:
            prf_rule = _init_prf_rule(args, query_encoder)
        except ValueError as e:
            parser.error(str(e))
        print(f'Running FaissSearcher with {args.prf_method.upper()} PRF...')

    output_path = args.output
    print(f'Running {args.topics} topics, saving to {output_path}...')
    tag = 'Faiss'
    output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w',
                                      max_hits=args.hits, tag=tag, topics=topics,
                                      use_max_passage=args.max_passage,
                                      max_passage_delimiter=args.max_passage_delimiter,
                                      max_passage_hits=args.max_passage_hits)

    # The batched path is also taken when only --threads > 1; clamp so a batch size <= 0
    # cannot cause a modulo-by-zero below.
    batch_size = max(args.batch_size, 1)

    with output_writer:
        batch_topics = []
        batch_topic_ids = []
        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=num_topics)):
            if args.batch_size <= 1 and args.threads <= 1:
                if use_prf:
                    emb_q, prf_candidates = searcher.search(text, k=args.prf_depth, return_vector=True, **kwargs)
                    # ANCE-PRF takes the raw query text (it has its own PRF query encoder);
                    # the vector-based rules take the first-pass query embedding.
                    if prf_method == 'ance-prf':
                        prf_emb_q = prf_rule.get_prf_q_emb(text, prf_candidates)
                    else:
                        prf_emb_q = prf_rule.get_prf_q_emb(emb_q[0], prf_candidates)
                    # Add a leading axis of size 1 (presumably prf_emb_q is 1-D) and cast to float32.
                    prf_emb_q = np.expand_dims(prf_emb_q, axis=0).astype('float32')
                    hits = searcher.search(prf_emb_q, k=args.hits, **kwargs)
                else:
                    hits = searcher.search(text, args.hits, **kwargs)
                results = [(topic_id, hits)]
            else:
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
                # Search only once the batch is full or the last topic has been queued.
                if (index + 1) % batch_size != 0 and index != num_topics - 1:
                    continue
                if use_prf:
                    q_embs, prf_candidates = searcher.batch_search(batch_topics, batch_topic_ids,
                                                                   k=args.prf_depth, return_vector=True, **kwargs)
                    if prf_method == 'ance-prf':
                        prf_embs_q = prf_rule.get_batch_prf_q_emb(batch_topics, batch_topic_ids, prf_candidates)
                    else:
                        prf_embs_q = prf_rule.get_batch_prf_q_emb(batch_topic_ids, q_embs, prf_candidates)
                    batch_results = searcher.batch_search(prf_embs_q, batch_topic_ids, k=args.hits,
                                                          threads=args.threads, **kwargs)
                else:
                    batch_results = searcher.batch_search(batch_topics, batch_topic_ids, args.hits,
                                                          threads=args.threads, **kwargs)
                # Keep the topic order of the batch in the output.
                results = [(id_, batch_results[id_]) for id_ in batch_topic_ids]
                batch_topic_ids.clear()
                batch_topics.clear()

            for topic, hits in results:
                if args.remove_query:
                    # Compare as strings: single-query topic ids are not stringified and may be
                    # ints, while hit docids are presumably strings -- otherwise nothing is removed.
                    hits = [hit for hit in hits if hit.docid != str(topic)]
                output_writer.write(topic, hits)
|
|