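"""I/O utilities for MS MARCO-style retrieval and reranking: loading qrels,
queries, corpora and ranked predictions, and saving / merging reranker outputs."""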
import os
import random
import tqdm
import json
from typing import Dict, List, Any
from datasets import load_dataset, Dataset
from dataclasses import dataclass, field

from logger_config import logger
from config import Arguments
from utils import save_json_to_file


@dataclass
class ScoredDoc:
    qid: str
    pid: str
    rank: int
    score: float = field(default=-1)


def load_qrels(path: str) -> Dict[str, Dict[str, int]]:
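    """Load relevance judgments from a tab-separated file (qid, <ignored>, pid, relevance)
    into a nested dict: qid -> pid -> relevance score."""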
    assert path.endswith('.txt')

    # qid -> pid -> score
    qrels = {}
    for line in open(path, 'r', encoding='utf-8'):
        qid, _, pid, score = line.strip().split('\t')
        if qid not in qrels:
            qrels[qid] = {}
        qrels[qid][pid] = int(score)

    logger.info('Load {} queries {} qrels from {}'.format(len(qrels), sum(len(v) for v in qrels.values()), path))
    return qrels


def load_queries(path: str, task_type: str = 'ir') -> Dict[str, str]:
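    """Load queries as a qid -> query dict. For task_type 'ir', each line is "qid<TAB>query";
    for 'qa', the file is parsed by load_query_answers and only the query text is kept."""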
    assert path.endswith('.tsv')

    if task_type == 'qa':
        qid_to_query = load_query_answers(path)
        qid_to_query = {k: v['query'] for k, v in qid_to_query.items()}
    elif task_type == 'ir':
        qid_to_query = {}
        for line in open(path, 'r', encoding='utf-8'):
            qid, query = line.strip().split('\t')
            qid_to_query[qid] = query
    else:
        raise ValueError('Unknown task type: {}'.format(task_type))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def normalize_qa_text(text: str) -> str:
    # TriviaQA has some weird formats
    # For example: """What breakfast food gets its name from the German word for """"stirrup""""?"""
    while text.startswith('"') and text.endswith('"'):
        text = text[1:-1].replace('""', '"')
    return text


def get_question_key(question: str) -> str:
    # For QA datasets, we use the normalized question string as the dict key
    return question


def load_query_answers(path: str) -> Dict[str, Dict[str, Any]]:
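    """Load a QA-style tsv file where each line is "question<TAB>answers" and the answers
    column is evaluated into a list. Returns question key -> {'query': ..., 'answers': [...]}."""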
    assert path.endswith('.tsv')

    qid_to_query = {}
    for line in open(path, 'r', encoding='utf-8'):
        query, answers = line.strip().split('\t')
        query = normalize_qa_text(query)
        answers = normalize_qa_text(answers)
        qid = get_question_key(query)
        if qid in qid_to_query:
            logger.warning('Duplicate question: {} vs {}'.format(query, qid_to_query[qid]['query']))
            continue
        qid_to_query[qid] = {}
        qid_to_query[qid]['query'] = query
        qid_to_query[qid]['answers'] = list(eval(answers))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def load_corpus(path: str) -> Dataset:
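    """Load a jsonl / jsonl.gz corpus as a HuggingFace Dataset (expected fields: id, contents)."""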
    assert path.endswith('.jsonl') or path.endswith('.jsonl.gz')

    # two fields: id, contents
    corpus = load_dataset('json', data_files=path)['train']
    logger.info('Load {} documents from {} with columns {}'.format(len(corpus), path, corpus.column_names))
    logger.info('A random document: {}'.format(random.choice(corpus)))
    return corpus


def load_msmarco_predictions(path: str) -> Dict[str, List[ScoredDoc]]:
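    """Load ranked predictions from a tab-separated file (qid, pid, rank, optional score).
    Returns qid -> list of ScoredDoc sorted by rank."""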
    assert path.endswith('.txt')

    qid_to_scored_doc = {}
    for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), desc='load prediction', mininterval=3):
        fs = line.strip().split('\t')
        qid, pid, rank = fs[:3]
        rank = int(rank)
        # If no score column is present, fall back to reciprocal rank as a pseudo-score
        score = round(1 / rank, 4) if len(fs) == 3 else float(fs[3])
        if qid not in qid_to_scored_doc:
            qid_to_scored_doc[qid] = []

        scored_doc = ScoredDoc(qid=qid, pid=pid, rank=rank, score=score)
        qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {qid: sorted(scored_docs, key=lambda sd: sd.rank)
                         for qid, scored_docs in qid_to_scored_doc.items()}

    logger.info('Load {} query predictions from {}'.format(len(qid_to_scored_doc), path))
    return qid_to_scored_doc


def save_preds_to_msmarco_format(preds: Dict[str, List[ScoredDoc]], out_path: str):
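    """Write predictions as tab-separated "qid pid rank score" lines, with rank taken from list order."""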
    with open(out_path, 'w', encoding='utf-8') as writer:
        for qid in preds:
            for idx, scored_doc in enumerate(preds[qid]):
                writer.write('{}\t{}\t{}\t{}\n'.format(qid, scored_doc.pid, idx + 1, round(scored_doc.score, 3)))
    logger.info('Successfully saved to {}'.format(out_path))


def save_to_readable_format(in_path: str, corpus: Dataset):
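    """Convert a json file whose 'positives' / 'negatives' fields hold doc_ids and scores into a
    readable json file that inlines document titles and contents (at most max_to_keep docs per field)."""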
    out_path = '{}/readable_{}'.format(os.path.dirname(in_path), os.path.basename(in_path))
    dataset: Dataset = load_dataset('json', data_files=in_path)['train']
    max_to_keep = 5

    def _create_readable_field(samples: Dict[str, List]) -> List:
        readable_ex = []
        for idx in range(min(len(samples['doc_id']), max_to_keep)):
            doc_id = samples['doc_id'][idx]
            readable_ex.append({'doc_id': doc_id,
                                'title': corpus[int(doc_id)].get('title', ''),
                                'contents': corpus[int(doc_id)]['contents'],
                                'score': samples['score'][idx]})
        return readable_ex

    def _mp_func(ex: Dict) -> Dict:
        ex['positives'] = _create_readable_field(ex['positives'])
        ex['negatives'] = _create_readable_field(ex['negatives'])
        return ex

    dataset = dataset.map(_mp_func, num_proc=8)
    dataset.to_json(out_path, force_ascii=False, lines=False, indent=4)
    logger.info('Done converting {} to readable format in {}'.format(in_path, out_path))


def get_rerank_shard_path(args: Arguments, worker_idx: int) -> str:
    return '{}_shard_{}'.format(args.rerank_out_path, worker_idx)


def merge_rerank_predictions(args: Arguments, gpu_count: int):
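    """Merge per-GPU rerank shards, sort each query's documents by reranker score,
    back-fill documents beyond rerank_depth from the original ranking, save the merged
    run in MS MARCO format, and compute trec metrics / MRR when qrels are available."""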
    from metrics import trec_eval, compute_mrr

    qid_to_scored_doc: Dict[str, List[ScoredDoc]] = {}
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), 'merge results', mininterval=3):
            fs = line.strip().split('\t')
            qid, pid, _, score = fs
            score = float(score)
            if qid not in qid_to_scored_doc:
                qid_to_scored_doc[qid] = []

            scored_doc = ScoredDoc(qid=qid, pid=pid, rank=-1, score=score)
            qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {k: sorted(v, key=lambda sd: sd.score, reverse=True) for k, v in qid_to_scored_doc.items()}

    # Append documents beyond rerank_depth from the original ranking,
    # assigning strictly decreasing scores so their original order is preserved
    ori_preds = load_msmarco_predictions(path=args.rerank_in_path)
    for query_id in list(qid_to_scored_doc.keys()):
        remain_scored_docs = ori_preds[query_id][args.rerank_depth:]
        for idx, sd in enumerate(remain_scored_docs):
            # make sure the order is not broken
            sd.score = qid_to_scored_doc[query_id][-1].score - idx - 1
        qid_to_scored_doc[query_id] += remain_scored_docs

        assert len(set([sd.pid for sd in qid_to_scored_doc[query_id]])) == len(qid_to_scored_doc[query_id])

    save_preds_to_msmarco_format(qid_to_scored_doc, out_path=args.rerank_out_path)

    path_qrels = '{}/{}_qrels.txt'.format(args.data_dir, args.rerank_split)
    if os.path.exists(path_qrels):
        qrels = load_qrels(path=path_qrels)
        all_metrics = trec_eval(qrels=qrels, predictions=qid_to_scored_doc)
        all_metrics['mrr'] = compute_mrr(qrels=qrels, predictions=qid_to_scored_doc)
        logger.info('{} trec metrics = {}'.format(args.rerank_split, json.dumps(all_metrics, ensure_ascii=False, indent=4)))
        metrics_out_path = '{}/metrics_rerank_{}.json'.format(os.path.dirname(args.rerank_out_path), args.rerank_split)
        save_json_to_file(all_metrics, metrics_out_path)
    else:
        logger.warning('No qrels found for {}'.format(args.rerank_split))

    # cleanup some intermediate results
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        os.remove(path)


if __name__ == '__main__':
    load_qrels('./data/msmarco/dev_qrels.txt')
    load_queries('./data/msmarco/dev_queries.tsv')
    corpus = load_corpus('./data/msmarco/passages.jsonl.gz')
    preds = load_msmarco_predictions('./data/bm25.msmarco.txt')