File size: 4,386 Bytes
d6585f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#
# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# IMPORTANT NOTE: As of January 2021, this script is *DEFUNCT*.
# That is, you *SHOULDN'T* be using it anymore.
# For "standard" batch retrieval runs on MS MARCO, use 'python -m pyserini.search.lucene ...'.

# Standard-library imports.
import sys
import argparse
import time

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
# NOTE: these path insertions must run *before* the pyserini import below so the
# local checkout shadows any pip-installed copy on sys.path.
sys.path.insert(0, './')
sys.path.insert(0, '../pyserini/')

from pyserini.search.lucene import LuceneSearcher


if __name__ == '__main__':
    print('WARNING: this script is defunct. Use python -m pyserini.search.lucene instead.')

    parser = argparse.ArgumentParser(description='Retrieve MS MARCO Passages.')
    # Required args carry no defaults: argparse never consults a default when
    # required=True, so the dead default='' values were dropped.
    parser.add_argument('--queries', required=True, help='Queries file.')
    parser.add_argument('--output', required=True, help='Output run file.')
    parser.add_argument('--index', required=True, help='Index path.')
    parser.add_argument('--hits', default=10, type=int, help='Number of hits to retrieve.')
    parser.add_argument('--k1', default=0.82, type=float, help='BM25 k1 parameter.')
    parser.add_argument('--b', default=0.68, type=float, help='BM25 b parameter.')
    # See our MS MARCO documentation to understand how these parameter values were tuned.
    parser.add_argument('--rm3', action='store_true', default=False, help='Use RM3.')
    parser.add_argument('--fbTerms', default=10, type=int, help='RM3: number of expansion terms.')
    parser.add_argument('--fbDocs', default=10, type=int, help='RM3: number of documents.')
    parser.add_argument('--originalQueryWeight', default=0.5, type=float, help='RM3: weight of original query.')
    parser.add_argument('--threads', default=1, type=int, help='Maximum number of threads.')

    args = parser.parse_args()

    def _load_queries(path):
        """Yield (qid, query) pairs from a tab-separated queries file.

        The file handle is owned by this generator and closed when it is
        exhausted (or garbage-collected), fixing the leaked handles the
        original bare open() calls left behind.
        """
        with open(path, 'r', encoding='utf8') as f:
            for line in f:
                qid, query = line.strip().split('\t')
                yield qid, query

    def _write_run(fout, qid, hits):
        """Write one query's hits to *fout* in qid<TAB>docid<TAB>rank format (1-based rank)."""
        for rank, hit in enumerate(hits, start=1):
            fout.write(f'{qid}\t{hit.docid}\t{rank}\n')

    total_start_time = time.time()

    searcher = LuceneSearcher(args.index)
    searcher.set_bm25(args.k1, args.b)
    print(f'Initializing BM25, setting k1={args.k1} and b={args.b}', flush=True)
    if args.rm3:
        searcher.set_rm3(args.fbTerms, args.fbDocs, args.originalQueryWeight)
        print(f'Initializing RM3, setting fbTerms={args.fbTerms}, fbDocs={args.fbDocs}, ' +
              f' and originalQueryWeight={args.originalQueryWeight}', flush=True)

    if args.threads == 1:
        # Sequential retrieval: search one query at a time, logging throughput
        # every 100 queries.
        with open(args.output, 'w') as fout:
            start_time = time.time()
            for line_number, (qid, query) in enumerate(_load_queries(args.queries)):
                hits = searcher.search(query, args.hits)
                if line_number % 100 == 0:
                    time_per_query = (time.time() - start_time) / (line_number + 1)
                    print(f'Retrieving query {line_number} ({time_per_query:0.3f} s/query)', flush=True)
                _write_run(fout, qid, hits)
    else:
        # Parallel retrieval: collect all queries first, then let Anserini's
        # batch search fan them out across args.threads threads.
        qids = []
        queries = []
        for qid, query in _load_queries(args.queries):
            qids.append(qid)
            queries.append(query)

        results = searcher.batch_search(queries, qids, args.hits, args.threads)

        with open(args.output, 'w') as fout:
            for qid in qids:
                # Indexing (rather than .get) raises a clear KeyError if a qid
                # is unexpectedly missing, instead of a TypeError on None.
                _write_run(fout, qid, results[qid])

    total_time = (time.time() - total_start_time)
    print(f'Total retrieval time: {total_time:0.3f} s')
    print('Done!')