|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import cmd |
|
import json |
|
import random |
|
|
|
from pyserini.search.lucene import LuceneSearcher |
|
from pyserini.search.faiss import FaissSearcher, DprQueryEncoder |
|
from pyserini.search.hybrid import HybridSearcher |
|
from pyserini import search |
|
|
|
|
|
class DPRDemo(cmd.Cmd):
    """Interactive prompt for querying the Wikipedia DPR indexes.

    Input starting with ``/`` is treated as a command (``/help``, ``/k``,
    ``/mode``, ``/random``); anything that does not match a command is sent
    to the currently selected searcher as a query (see ``default``).
    """

    # NOTE: everything below is evaluated when the class body executes, so
    # the topics, indexes and query encoder are loaded as soon as this
    # class is defined, not when an instance is created.

    # Dev-set topics used by /random; each entry is read via its 'title' key.
    nq_dev_topics = list(search.get_topics('dpr-nq-dev').values())
    trivia_dev_topics = list(search.get_topics('dpr-trivia-dev').values())

    # Sparse (Lucene) searcher; also the searcher selected at start-up.
    ssearcher = LuceneSearcher.from_prebuilt_index('wikipedia-dpr')
    searcher = ssearcher

    # Dense (Faiss) searcher driven by the DPR question encoder.
    encoder = DprQueryEncoder("facebook/dpr-question_encoder-multiset-base")
    index = 'wikipedia-dpr-multi-bf'
    dsearcher = FaissSearcher.from_prebuilt_index(
        index,
        encoder
    )

    # Hybrid searcher combining the dense and sparse searchers above.
    hsearcher = HybridSearcher(dsearcher, ssearcher)

    # Number of hits requested per query; changed with /k.
    k = 10
    prompt = '>>> '

    def precmd(self, line):
        """Strip a leading '/' so '/cmd args' dispatches to ``do_cmd``.

        Uses ``startswith`` rather than ``line[0]`` so that an empty input
        line does not raise ``IndexError``.
        """
        if line.startswith('/'):
            line = line[1:]
        return line

    def emptyline(self):
        """Ignore blank input.

        ``cmd.Cmd`` repeats the last non-empty command by default, which
        would silently re-run the previous query or command.
        """
        return False

    def do_help(self, arg):
        """Print the list of supported slash commands."""
        print('/help : returns this message')
        print('/k [NUM] : sets k (number of hits to return) to [NUM]')
        print('/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)')
        print('/random [COLLECTION]: returns results for a random question from the dev subset [COLLECTION] (one of nq, trivia).')

    def do_k(self, arg):
        """Set the number of hits to return; reject non-positive or non-integer input."""
        try:
            new_k = int(arg)
        except ValueError:
            new_k = None
        if new_k is None or new_k <= 0:
            # Report instead of raising so a bad argument does not end the session.
            print(f'k "{arg}" is invalid. k should be a positive integer.')
            return
        print(f'setting k = {new_k}')
        self.k = new_k

    def do_mode(self, arg):
        """Select the active searcher: sparse, dense or hybrid."""
        searchers = {
            'sparse': self.ssearcher,
            'dense': self.dsearcher,
            'hybrid': self.hsearcher,
        }
        if arg not in searchers:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        self.searcher = searchers[arg]
        print(f'setting retriever = {arg}')

    def do_random(self, arg):
        """Run a randomly chosen dev-set question from the nq or trivia topics."""
        collections = {
            'nq': self.nq_dev_topics,
            'trivia': self.trivia_dev_topics,
        }
        if arg not in collections:
            print(
                f'Collection "{arg}" is invalid. Collection should be one of [nq, trivia].')
            return
        q = random.choice(collections[arg])['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        """Return True on end-of-file (Ctrl-D), which stops the command loop."""
        return True

    def default(self, q):
        """Search for ``q`` with the active searcher and print the ranked hits.

        Each output line holds the 1-based rank, the hit score and the
        'contents' field of the hit's JSON document.
        """
        hits = self.searcher.search(q, self.k)

        # Lucene hits carry the raw document themselves; for the other
        # searchers the document is looked up by docid.
        # NOTE(review): assumes every non-Lucene searcher exposes doc(docid);
        # confirm this holds for HybridSearcher in the installed version.
        hit_has_raw = isinstance(self.searcher, LuceneSearcher)

        for rank, hit in enumerate(hits, start=1):
            if hit_has_raw:
                raw_doc = hit.raw
            else:
                doc = self.searcher.doc(hit.docid)
                raw_doc = doc.raw() if doc else None

            if raw_doc is None:
                # json.loads(None) would raise TypeError; still show the hit.
                print(f'{rank:2} {hit.score:.5f} [document {hit.docid} not available]')
                continue

            jsondoc = json.loads(raw_doc)
            print(f'{rank:2} {hit.score:.5f} {jsondoc["contents"]}')
|
|
|
|
|
if __name__ == '__main__':
    # Start the interactive prompt only when executed as a script.
    demo = DPRDemo()
    demo.cmdloop()
|
|