import os
import ast
import random
import tqdm
import json

from typing import Dict, List, Any
from datasets import load_dataset, Dataset
from dataclasses import dataclass, field

from logger_config import logger
from config import Arguments
from utils import save_json_to_file


@dataclass
class ScoredDoc:
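    """A single retrieved passage for a query, with its rank and score."""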
    qid: str
    pid: str
    rank: int
    score: float = field(default=-1)


def load_qrels(path: str) -> Dict[str, Dict[str, int]]:
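    """Load relevance judgments from a tab-separated file of "qid, _, pid, relevance" lines."""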
    assert path.endswith('.txt')

    # qid -> pid -> score
    qrels = {}
    for line in open(path, 'r', encoding='utf-8'):
        qid, _, pid, score = line.strip().split('\t')
        if qid not in qrels:
            qrels[qid] = {}
        qrels[qid][pid] = int(score)

    logger.info('Load {} queries {} qrels from {}'.format(len(qrels), sum(len(v) for v in qrels.values()), path))
    return qrels


def load_queries(path: str, task_type: str = 'ir') -> Dict[str, str]:
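    """Load queries from a .tsv file into a qid -> query text dict.

    For 'ir' tasks each line is "qid <tab> query"; for 'qa' tasks the file is
    parsed by load_query_answers and the question string itself serves as the qid.
    """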
    assert path.endswith('.tsv')

    if task_type == 'qa':
        qid_to_query = load_query_answers(path)
        qid_to_query = {k: v['query'] for k, v in qid_to_query.items()}
    elif task_type == 'ir':
        qid_to_query = {}
        for line in open(path, 'r', encoding='utf-8'):
            qid, query = line.strip().split('\t')
            qid_to_query[qid] = query
    else:
        raise ValueError('Unknown task type: {}'.format(task_type))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def normalize_qa_text(text: str) -> str:
    # TriviaQA has some weird formats
    # For example: """What breakfast food gets its name from the German word for """"stirrup""""?"""
    while text.startswith('"') and text.endswith('"'):
        text = text[1:-1].replace('""', '"')
    return text


def get_question_key(question: str) -> str:
    # For QA dataset, we'll use normalized question strings as dict key
    return question


def load_query_answers(path: str) -> Dict[str, Dict[str, Any]]:
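    """Load QA queries and answers from a .tsv file of "question <tab> answer list" lines.

    Returns question key -> {'query': ..., 'answers': [...]}; duplicate questions
    are skipped with a warning.
    """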
    assert path.endswith('.tsv')

    qid_to_query = {}
    for line in open(path, 'r', encoding='utf-8'):
        query, answers = line.strip().split('\t')
        query = normalize_qa_text(query)
        answers = normalize_qa_text(answers)
        qid = get_question_key(query)
        if qid in qid_to_query:
            logger.warning('Duplicate question: {} vs {}'.format(query, qid_to_query[qid]['query']))
            continue

        qid_to_query[qid] = {}
        qid_to_query[qid]['query'] = query
        # the answers column is a Python list literal; parse it without eval
        qid_to_query[qid]['answers'] = list(ast.literal_eval(answers))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def load_corpus(path: str) -> Dataset:
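    """Load a jsonl(.gz) corpus into a HuggingFace Dataset."""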
    assert path.endswith('.jsonl') or path.endswith('.jsonl.gz')

    # two fields: id, contents
    corpus = load_dataset('json', data_files=path)['train']
    logger.info('Load {} documents from {} with columns {}'.format(len(corpus), path, corpus.column_names))
    logger.info('A random document: {}'.format(random.choice(corpus)))
    return corpus


def load_msmarco_predictions(path: str) -> Dict[str, List[ScoredDoc]]:
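    """Load ranked predictions from "qid <tab> pid <tab> rank [<tab> score]" lines.

    If the score column is missing, the reciprocal rank (1 / rank) is used as a
    stand-in score. Returns qid -> list of ScoredDoc sorted by rank.
    """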
    assert path.endswith('.txt')

    qid_to_scored_doc = {}
    for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), desc='load prediction', mininterval=3):
        fs = line.strip().split('\t')
        qid, pid, rank = fs[:3]
        rank = int(rank)
        score = round(1 / rank, 4) if len(fs) == 3 else float(fs[3])

        if qid not in qid_to_scored_doc:
            qid_to_scored_doc[qid] = []
        scored_doc = ScoredDoc(qid=qid, pid=pid, rank=rank, score=score)
        qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {qid: sorted(scored_docs, key=lambda sd: sd.rank)
                         for qid, scored_docs in qid_to_scored_doc.items()}

    logger.info('Load {} query predictions from {}'.format(len(qid_to_scored_doc), path))
    return qid_to_scored_doc


def save_preds_to_msmarco_format(preds: Dict[str, List[ScoredDoc]], out_path: str):
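    """Write predictions as "qid <tab> pid <tab> rank <tab> score" lines, with ranks starting from 1."""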
    with open(out_path, 'w', encoding='utf-8') as writer:
        for qid in preds:
            for idx, scored_doc in enumerate(preds[qid]):
                writer.write('{}\t{}\t{}\t{}\n'.format(qid, scored_doc.pid, idx + 1, round(scored_doc.score, 3)))
    logger.info('Successfully saved to {}'.format(out_path))


def save_to_readable_format(in_path: str, corpus: Dataset):
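    """Dump a human-readable copy of a training jsonl file next to the original.

    The top few documents in 'positives' and 'negatives' are expanded with their
    title and contents (doc_id is assumed to index directly into the corpus rows).
    """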
    out_path = '{}/readable_{}'.format(os.path.dirname(in_path), os.path.basename(in_path))
    dataset: Dataset = load_dataset('json', data_files=in_path)['train']

    max_to_keep = 5

    def _create_readable_field(samples: Dict[str, List]) -> List:
        readable_ex = []
        for idx in range(min(len(samples['doc_id']), max_to_keep)):
            doc_id = samples['doc_id'][idx]
            readable_ex.append({'doc_id': doc_id,
                                'title': corpus[int(doc_id)].get('title', ''),
                                'contents': corpus[int(doc_id)]['contents'],
                                'score': samples['score'][idx]})
        return readable_ex

    def _mp_func(ex: Dict) -> Dict:
        ex['positives'] = _create_readable_field(ex['positives'])
        ex['negatives'] = _create_readable_field(ex['negatives'])
        return ex
    dataset = dataset.map(_mp_func, num_proc=8)

    dataset.to_json(out_path, force_ascii=False, lines=False, indent=4)
    logger.info('Done convert {} to readable format in {}'.format(in_path, out_path))


def get_rerank_shard_path(args: Arguments, worker_idx: int) -> str:
    return '{}_shard_{}'.format(args.rerank_out_path, worker_idx)


def merge_rerank_predictions(args: Arguments, gpu_count: int):
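    """Merge per-GPU rerank shards into a single ranking, evaluate, and clean up.

    Documents beyond args.rerank_depth keep their original relative order by
    receiving strictly decreasing scores below the lowest reranked score.
    Metrics are computed and saved if a qrels file exists for the split.
    """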
    from metrics import trec_eval, compute_mrr

    qid_to_scored_doc: Dict[str, List[ScoredDoc]] = {}
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), desc='merge results', mininterval=3):
            fs = line.strip().split('\t')
            qid, pid, _, score = fs
            score = float(score)

            if qid not in qid_to_scored_doc:
                qid_to_scored_doc[qid] = []
            scored_doc = ScoredDoc(qid=qid, pid=pid, rank=-1, score=score)
            qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {k: sorted(v, key=lambda sd: sd.score, reverse=True) for k, v in qid_to_scored_doc.items()}

    ori_preds = load_msmarco_predictions(path=args.rerank_in_path)
    for query_id in list(qid_to_scored_doc.keys()):
        remain_scored_docs = ori_preds[query_id][args.rerank_depth:]
        for idx, sd in enumerate(remain_scored_docs):
            # assign strictly decreasing scores below the lowest reranked score
            # so the non-reranked tail keeps its original relative order
            sd.score = qid_to_scored_doc[query_id][-1].score - idx - 1
        qid_to_scored_doc[query_id] += remain_scored_docs
        assert len(set([sd.pid for sd in qid_to_scored_doc[query_id]])) == len(qid_to_scored_doc[query_id])

    save_preds_to_msmarco_format(qid_to_scored_doc, out_path=args.rerank_out_path)

    path_qrels = '{}/{}_qrels.txt'.format(args.data_dir, args.rerank_split)
    if os.path.exists(path_qrels):
        qrels = load_qrels(path=path_qrels)
        all_metrics = trec_eval(qrels=qrels, predictions=qid_to_scored_doc)
        all_metrics['mrr'] = compute_mrr(qrels=qrels, predictions=qid_to_scored_doc)

        logger.info('{} trec metrics = {}'.format(args.rerank_split, json.dumps(all_metrics, ensure_ascii=False, indent=4)))
        metrics_out_path = '{}/metrics_rerank_{}.json'.format(os.path.dirname(args.rerank_out_path), args.rerank_split)
        save_json_to_file(all_metrics, metrics_out_path)
    else:
        logger.warning('No qrels found for {}'.format(args.rerank_split))

    # cleanup some intermediate results
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        os.remove(path)


if __name__ == '__main__':
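    # smoke test: exercise the loaders on local MS MARCO data files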
    load_qrels('./data/msmarco/dev_qrels.txt')
    load_queries('./data/msmarco/dev_queries.tsv')
    corpus = load_corpus('./data/msmarco/passages.jsonl.gz')
    preds = load_msmarco_predictions('./data/bm25.msmarco.txt')