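# Driver script: search a Lucene index (standard inverted, impact / learned
# sparse, or SLIM) over a set of topics and write a run file.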
import argparse
import os

from tqdm import tqdm
from transformers import AutoTokenizer

from pyserini.analysis import JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer
from pyserini.output_writer import OutputFormat, get_output_writer
from pyserini.pyclass import autoclass
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.search import JDisjunctionMaxQueryGenerator
from . import LuceneImpactSearcher, LuceneSearcher, SlimSearcher
from .reranker import ClassifierType, PseudoRelevanceClassifierReranker
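

# BM25 parameters: use explicit --k1/--b when both are given; otherwise fall
# back to tuned defaults for recognized MS MARCO prebuilt indexes.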
def set_bm25_parameters(searcher, index, k1=None, b=None):
    if k1 is not None or b is not None:
        if k1 is None or b is None:
            print('Must set *both* k1 and b for BM25!')
            exit()
        print(f'Setting BM25 parameters: k1={k1}, b={b}')
        searcher.set_bm25(k1, b)
    else:
        if index in ['msmarco-passage', 'msmarco-passage-slim', 'msmarco-v1-passage',
                     'msmarco-v1-passage-slim', 'msmarco-v1-passage-full']:
            print('MS MARCO passage: setting k1=0.82, b=0.68')
            searcher.set_bm25(0.82, 0.68)
        elif index in ['msmarco-passage-expanded', 'msmarco-v1-passage-d2q-t5',
                       'msmarco-v1-passage-d2q-t5-docvectors']:
            print('MS MARCO passage w/ doc2query-T5 expansion: setting k1=2.18, b=0.86')
            searcher.set_bm25(2.18, 0.86)
        elif index in ['msmarco-doc', 'msmarco-doc-slim', 'msmarco-v1-doc',
                       'msmarco-v1-doc-slim', 'msmarco-v1-doc-full']:
            print('MS MARCO doc: setting k1=4.46, b=0.82')
            searcher.set_bm25(4.46, 0.82)
        elif index in ['msmarco-doc-per-passage', 'msmarco-doc-per-passage-slim',
                       'msmarco-v1-doc-segmented', 'msmarco-v1-doc-segmented-slim',
                       'msmarco-v1-doc-segmented-full']:
            print('MS MARCO doc, per passage: setting k1=2.16, b=0.61')
            searcher.set_bm25(2.16, 0.61)
        elif index in ['msmarco-doc-expanded-per-doc', 'msmarco-v1-doc-d2q-t5',
                       'msmarco-v1-doc-d2q-t5-docvectors']:
            print('MS MARCO doc w/ doc2query-T5 (per doc) expansion: setting k1=4.68, b=0.87')
            searcher.set_bm25(4.68, 0.87)
        elif index in ['msmarco-doc-expanded-per-passage', 'msmarco-v1-doc-segmented-d2q-t5',
                       'msmarco-v1-doc-segmented-d2q-t5-docvectors']:
            print('MS MARCO doc w/ doc2query-T5 (per passage) expansion: setting k1=2.56, b=0.59')
            searcher.set_bm25(2.56, 0.59)
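

# Search-related command-line options: index, scoring model, relevance
# feedback, fields, and analyzer settings.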
def define_search_args(parser):
    parser.add_argument('--index', type=str, metavar='path to index or index name', required=True,
                        help="Path to Lucene index or name of prebuilt index.")
    parser.add_argument('--encoded-corpus', type=str, default=None, help="Path to stored sparse vectors.")

    parser.add_argument('--impact', action='store_true', help="Use impact scoring.")
    parser.add_argument('--encoder', type=str, default=None, help="Encoder name.")
    parser.add_argument('--onnx-encoder', type=str, default=None, help="ONNX encoder name.")
    parser.add_argument('--min-idf', type=int, default=0, help="Minimum idf.")

    parser.add_argument('--bm25', action='store_true', default=True, help="Use BM25 (default).")
    parser.add_argument('--k1', type=float, help='BM25 k1 parameter.')
    parser.add_argument('--b', type=float, help='BM25 b parameter.')

    parser.add_argument('--rm3', action='store_true', help="Use RM3.")
    parser.add_argument('--rocchio', action='store_true', help="Use Rocchio.")
    parser.add_argument('--rocchio-use-negative', action='store_true', help="Use nonrelevant labels in Rocchio.")
    parser.add_argument('--qld', action='store_true', help="Use QLD.")

    parser.add_argument('--language', type=str, default='en',
                        help='Language code for BM25, e.g., zh for Chinese.')
    parser.add_argument('--pretokenized', action='store_true', help="Accept pre-tokenized topics.")

    parser.add_argument('--prcl', type=ClassifierType, nargs='+', default=[],
                        help='Specify the classifier PseudoRelevanceClassifierReranker uses.')
    parser.add_argument('--prcl.vectorizer', dest='vectorizer', type=str,
                        help='Type of vectorizer. Available: TfidfVectorizer, BM25Vectorizer.')
    parser.add_argument('--prcl.r', dest='r', type=int, default=10,
                        help='Number of positive labels in pseudo relevance feedback.')
    parser.add_argument('--prcl.n', dest='n', type=int, default=100,
                        help='Number of negative labels in pseudo relevance feedback.')
    parser.add_argument('--prcl.alpha', dest='alpha', type=float, default=0.5,
                        help='Alpha value for interpolation in pseudo relevance feedback.')

    parser.add_argument('--fields', metavar="key=value", nargs='+',
                        help='Fields to search with assigned float weights.')
    parser.add_argument('--dismax', action='store_true', default=False,
                        help='Use disjunction max queries when searching multiple fields.')
    parser.add_argument('--dismax.tiebreaker', dest='tiebreaker', type=float, default=0.0,
                        help='The tiebreaker weight to use in disjunction max queries.')

    parser.add_argument('--stopwords', type=str, help='Path to file with custom stopwords.')
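

# Main entry point: parse arguments, build the appropriate searcher, then run
# all topics and write the results.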
if __name__ == "__main__":
    JLuceneSearcher = autoclass('io.anserini.search.SimpleSearcher')
    parser = argparse.ArgumentParser(description='Search a Lucene index.')
    define_search_args(parser)
    parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
                        help="Name of topics. Available: robust04, robust05, core17, core18.")
    parser.add_argument('--hits', type=int, metavar='num',
                        required=False, default=1000, help="Number of hits.")
    parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
                        help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
    parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
                        help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
    parser.add_argument('--output', type=str, metavar='path',
                        help="Path to output file.")
    parser.add_argument('--max-passage', action='store_true',
                        default=False, help="Select only max passage from document.")
    parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
                        help="Final number of hits when selecting only max passage.")
    parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#',
                        help="Delimiter between docid and passage id.")
    parser.add_argument('--batch-size', type=int, metavar='num', required=False,
                        default=1, help="Specify batch size to search the collection concurrently.")
    parser.add_argument('--threads', type=int, metavar='num', required=False,
                        default=1, help="Maximum number of threads to use.")
    parser.add_argument('--tokenizer', type=str, help='Tokenizer used to preprocess topics.')
    parser.add_argument('--remove-duplicates', action='store_true', default=False, help="Remove duplicate docs.")
    parser.add_argument('--remove-query', action='store_true', default=False, help="Remove query from results list.")
    args = parser.parse_args()

    query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format))
    topics = query_iterator.topics
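
    # Pick a searcher implementation based on --impact / --encoder /
    # --onnx-encoder; --index may be a filesystem path or the name of a
    # prebuilt index to download.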
    if not args.impact:
        if os.path.exists(args.index):
            searcher = LuceneSearcher(args.index)
        else:
            searcher = LuceneSearcher.from_prebuilt_index(args.index)
    else:
        if args.encoder and args.onnx_encoder:
            raise ValueError("Cannot specify both --encoder and --onnx-encoder")
        if args.encoder:
            if os.path.exists(args.index):
                if args.encoded_corpus is not None:
                    searcher = SlimSearcher(args.encoded_corpus, args.index, args.encoder, args.min_idf)
                else:
                    searcher = LuceneImpactSearcher(args.index, args.encoder, args.min_idf)
            else:
                if args.encoded_corpus is not None:
                    searcher = SlimSearcher.from_prebuilt_index(args.encoded_corpus, args.index,
                                                                args.encoder, args.min_idf)
                else:
                    searcher = LuceneImpactSearcher.from_prebuilt_index(args.index, args.encoder, args.min_idf)
        elif args.onnx_encoder:
            if os.path.exists(args.index):
                if args.encoded_corpus is not None:
                    searcher = SlimSearcher(args.encoded_corpus, args.index, args.onnx_encoder, args.min_idf)
                else:
                    searcher = LuceneImpactSearcher(args.index, args.onnx_encoder, args.min_idf, 'onnx')
            else:
                if args.encoded_corpus is not None:
                    searcher = SlimSearcher.from_prebuilt_index(args.encoded_corpus, args.index,
                                                                args.onnx_encoder, args.min_idf)
                else:
                    searcher = LuceneImpactSearcher.from_prebuilt_index(args.index, args.onnx_encoder,
                                                                        args.min_idf, 'onnx')
        elif os.path.exists(args.index):
            # No query encoder given: queries are assumed to be pre-encoded.
            searcher = LuceneImpactSearcher(args.index, args.encoder, args.min_idf)
        else:
            searcher = LuceneImpactSearcher.from_prebuilt_index(args.index, args.encoder, args.min_idf)

    if args.language != 'en':
        searcher.set_language(args.language)

    if not searcher:
        exit()
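
    # Configure the scoring model: QLD if requested, otherwise BM25 (tuned or
    # user-supplied parameters), plus optional RM3/Rocchio relevance feedback.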
    search_rankers = []

    if args.qld:
        search_rankers.append('qld')
        searcher.set_qld()
    elif args.bm25:
        search_rankers.append('bm25')
        set_bm25_parameters(searcher, args.index, args.k1, args.b)

    if args.rm3:
        search_rankers.append('rm3')
        searcher.set_rm3()

    if args.rocchio:
        search_rankers.append('rocchio')
        if args.rocchio_use_negative:
            searcher.set_rocchio(gamma=0.15, use_negative=True)
        else:
            searcher.set_rocchio()

    fields = dict()
    if args.fields:
        fields = dict([pair.split('=') for pair in args.fields])
        print(f'Searching over fields: {fields}')

    query_generator = None
    if args.dismax:
        query_generator = JDisjunctionMaxQueryGenerator(args.tiebreaker)
        print(f'Using dismax query generator with tiebreaker={args.tiebreaker}')

    if args.pretokenized:
        analyzer = JWhiteSpaceAnalyzer()
        searcher.set_analyzer(analyzer)
        if args.tokenizer is not None:
            raise ValueError("--tokenizer is not supported when --pretokenized is set.")

    if args.tokenizer is not None:
        # Topics are tokenized with a Hugging Face tokenizer and re-joined on
        # whitespace, so search with a whitespace analyzer to match.
        analyzer = JWhiteSpaceAnalyzer()
        searcher.set_analyzer(analyzer)
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        print(f'Using {args.tokenizer} to preprocess topics')

    if args.stopwords:
        analyzer = JDefaultEnglishAnalyzer.fromArguments('porter', False, args.stopwords)
        searcher.set_analyzer(analyzer)
        print(f'Using custom stopwords={args.stopwords}')

    use_prcl = args.prcl and len(args.prcl) > 0 and args.alpha > 0
    if use_prcl:
        ranker = PseudoRelevanceClassifierReranker(
            searcher.index_dir, args.vectorizer, args.prcl, r=args.r, n=args.n, alpha=args.alpha)
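
    # If --output is not given, derive a run-file name that encodes the
    # retrieval settings.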
    output_path = args.output
    if output_path is None:
        if use_prcl:
            clf_rankers = []
            for t in args.prcl:
                if t == ClassifierType.LR:
                    clf_rankers.append('lr')
                elif t == ClassifierType.SVM:
                    clf_rankers.append('svm')

            r_str = f'prcl.r_{args.r}'
            n_str = f'prcl.n_{args.n}'
            a_str = f'prcl.alpha_{args.alpha}'
            clf_str = 'prcl_' + '+'.join(clf_rankers)
            tokens1 = ['run', args.topics, '+'.join(search_rankers)]
            tokens2 = [args.vectorizer, clf_str, r_str, n_str, a_str]
            output_path = '.'.join(tokens1) + '-' + '-'.join(tokens2) + '.txt'
        else:
            tokens = ['run', args.topics, '+'.join(search_rankers), 'txt']
            output_path = '.'.join(tokens)

    print(f'Running {args.topics} topics, saving to {output_path}...')
    tag = output_path[:-4] if args.output is None else 'Anserini'

    output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w',
                                      max_hits=args.hits, tag=tag, topics=topics,
                                      use_max_passage=args.max_passage,
                                      max_passage_delimiter=args.max_passage_delimiter,
                                      max_passage_hits=args.max_passage_hits)
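
    # Main retrieval loop: run each topic singly or in batches across threads,
    # then post-process and write the hits for each topic.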
    with output_writer:
        batch_topics = list()
        batch_topic_ids = list()
        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
            if args.tokenizer is not None:
                text = ' '.join(tokenizer.tokenize(text))
            if args.batch_size <= 1 and args.threads <= 1:
                if args.impact:
                    hits = searcher.search(text, args.hits, fields=fields)
                else:
                    hits = searcher.search(text, args.hits, query_generator=query_generator, fields=fields)
                results = [(topic_id, hits)]
            else:
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
                if (index + 1) % args.batch_size == 0 or \
                        index == len(topics.keys()) - 1:
                    if args.impact:
                        results = searcher.batch_search(
                            batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields
                        )
                    else:
                        results = searcher.batch_search(
                            batch_topics, batch_topic_ids, args.hits, args.threads,
                            query_generator=query_generator, fields=fields
                        )
                    results = [(id_, results[id_]) for id_ in batch_topic_ids]
                    batch_topic_ids.clear()
                    batch_topics.clear()
                else:
                    # Batch not yet full; keep accumulating topics.
                    continue
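
            # Post-process each topic's hits: optional classifier-based
            # reranking, deduplication, and removal of the query document.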
            for topic, hits in results:
                # Rerank only when there are enough hits to supply both
                # positive and negative pseudo-relevance feedback examples.
                if use_prcl and len(hits) > (args.r + args.n):
                    docids = [hit.docid.strip() for hit in hits]
                    scores = [hit.score for hit in hits]
                    scores, docids = ranker.rerank(docids, scores)
                    docid_score_map = dict(zip(docids, scores))
                    for hit in hits:
                        hit.score = docid_score_map[hit.docid.strip()]

                if args.remove_duplicates:
                    seen_docids = set()
                    dedup_hits = []
                    for hit in hits:
                        if hit.docid.strip() in seen_docids:
                            continue
                        seen_docids.add(hit.docid.strip())
                        dedup_hits.append(hit)
                    hits = dedup_hits

                # Drop the query document itself from the ranked list.
                if args.remove_query:
                    hits = [hit for hit in hits if hit.docid != topic]

                output_writer.write(topic, hits)

            results.clear()