import torch
import pytrec_eval

from typing import List, Dict, Tuple

from data_utils import ScoredDoc
from logger_config import logger


def trec_eval(qrels: Dict[str, Dict[str, int]],
              predictions: Dict[str, List[ScoredDoc]],
              k_values: Tuple[int, ...] = (10, 50, 100, 200, 1000)) -> Dict[str, float]:
    """Compute NDCG@k, MAP@k and Recall@k with pytrec_eval, averaged over all evaluated queries."""
    ndcg, _map, recall = {}, {}, {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])

    # pytrec_eval expects run results as {query_id: {doc_id: score}}
    results: Dict[str, Dict[str, float]] = {}
    for query_id, scored_docs in predictions.items():
        results[query_id] = {sd.pid: sd.score for sd in scored_docs}

    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string})
    scores = evaluator.evaluate(results)

    # Accumulate per-query scores; averaged in _normalize below
    for query_id in scores:
        for k in k_values:
            ndcg[f"NDCG@{k}"] += scores[query_id][f"ndcg_cut_{k}"]
            _map[f"MAP@{k}"] += scores[query_id][f"map_cut_{k}"]
            recall[f"Recall@{k}"] += scores[query_id][f"recall_{k}"]

    def _normalize(m: dict) -> dict:
        return {key: round(v / len(scores), 5) for key, v in m.items()}

    ndcg = _normalize(ndcg)
    _map = _normalize(_map)
    recall = _normalize(recall)

    all_metrics = {}
    for mt in [ndcg, _map, recall]:
        all_metrics.update(mt)

    return all_metrics


@torch.no_grad()
def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> List[float]:
    """Computes the accuracy (in percent) over the k top predictions for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res


@torch.no_grad()
def batch_mrr(output: torch.Tensor, target: torch.Tensor) -> float:
    """Mean reciprocal rank (scaled by 100) of each row's target index within its sorted scores."""
    assert len(output.shape) == 2
    assert len(target.shape) == 1
    _, sorted_indices = torch.sort(output, dim=-1, descending=True)
    _, rank = torch.nonzero(sorted_indices.eq(target.unsqueeze(-1)).long(), as_tuple=True)
    assert rank.shape[0] == output.shape[0]

    # ranks are 0-based positions; convert to 1-based before taking reciprocals
    rank = rank + 1
    mrr = torch.sum(100 / rank.float()) / rank.shape[0]
    return mrr.item()


def get_rel_threshold(qrels: Dict[str, Dict[str, int]]) -> int:
    # For MS MARCO passage ranking, label >= 1 is relevant;
    # for TREC DL 2019 & 2020, label >= 2 is relevant.
    rel_labels = set()
    for q_id in qrels:
        for label in qrels[q_id].values():
            rel_labels.add(label)

    logger.info('relevance labels: {}'.format(rel_labels))
    return 2 if max(rel_labels) >= 3 else 1


def compute_mrr(qrels: Dict[str, Dict[str, int]],
                predictions: Dict[str, List[ScoredDoc]],
                k: int = 10) -> float:
    """MRR@k (scaled by 100) over all queries in qrels, counting only documents above the relevance threshold."""
    threshold = get_rel_threshold(qrels)
    mrr = 0.0
    for qid in qrels:
        scored_docs = predictions.get(qid, [])
        for idx, scored_doc in enumerate(scored_docs[:k]):
            if scored_doc.pid in qrels[qid] and qrels[qid][scored_doc.pid] >= threshold:
                mrr += 1 / (idx + 1)
                break

    return round(mrr / len(qrels) * 100, 4)
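

if __name__ == '__main__':
    # Illustrative sanity-check sketch for accuracy() and batch_mrr() on toy tensors
    # (the tensor values below are made up for demonstration, not taken from any dataset).
    # Toy scores for a batch of 3 queries over 4 candidates, with the positive candidate
    # index given in `labels`. The top-1 prediction is correct for the first two rows only,
    # so accuracy@1 should be ~66.67, accuracy@2 should be 100.0, and the MRR
    # (100/1 + 100/1 + 100/2) / 3 should be ~83.33.
    logits = torch.tensor([[0.9, 0.1, 0.0, 0.0],
                           [0.2, 0.8, 0.0, 0.0],
                           [0.7, 0.3, 0.0, 0.0]])
    labels = torch.tensor([0, 1, 1])

    print('accuracy@1, accuracy@2:', accuracy(logits, labels, topk=(1, 2)))
    print('batch MRR:', batch_mrr(logits, labels))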