File size: 1,180 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# flake8: noqa

import torch
import torch.nn.functional as F


def get_r_precision(image_text_code, eps=1e-5, batch_size=100):
    r"""Compute caption R-precision from paired image/text embeddings.

    The input is split in half along dim 1 into image codes and text codes;
    row ``i`` of each half is treated as a matched pair. Samples are grouped
    into consecutive, non-overlapping batches of ``batch_size`` rows. Within
    each batch, every image code is matched to the text code with the highest
    cosine similarity. The match counts as correct when it selects the paired
    text (same row within the batch). Trailing samples that do not fill a
    complete batch are ignored.

    Args:
        image_text_code (torch.Tensor): Tensor whose dim 1 holds the image
            code followed by the text code. Presumably shaped ``(N, 2 * D)``
            -- TODO confirm against callers. Dim 1 must split into two equal
            halves for the similarity matmul to work.
        eps (float): Epsilon forwarded to
            :func:`torch.nn.functional.normalize` to avoid division by zero
            for all-zero codes.
        batch_size (int): Number of candidate captions per retrieval round.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        dict: ``{"caption_rprec": value}``, where ``value`` is the mean over
        complete batches of the percentage (0-100) of images whose top-1
        text is their own pair. With the default ``batch_size=100`` this
        equals the mean number of correct matches per batch.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if there are fewer
            than ``batch_size`` samples (no complete batch to score).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    all_image_code, all_text_code = torch.chunk(image_text_code, 2, dim=1)
    num_samples = len(all_image_code)
    # Validate explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O``, which would surface later as a ZeroDivisionError.
    if num_samples < batch_size:
        raise ValueError(
            f"get_r_precision needs at least {batch_size} samples, "
            f"got {num_samples}")

    # Only complete batches are scored; the remainder is dropped.
    num_scored = (num_samples // batch_size) * batch_size
    # Ground-truth index of each image's paired text within a batch.
    targets = torch.arange(batch_size, device=image_text_code.device)

    P_rates = []
    for start in range(0, num_scored, batch_size):
        stop = start + batch_size
        cur_image_code = F.normalize(all_image_code[start:stop], dim=1, eps=eps)
        cur_text_code = F.normalize(all_text_code[start:stop], dim=1, eps=eps)
        # Rows are unit-norm, so the dot product is the cosine similarity.
        cosine_similarities = cur_image_code @ cur_text_code.T
        top1_indices = torch.topk(cosine_similarities, dim=1, k=1)[1][:, 0]
        num_correct = torch.sum(top1_indices == targets).item()
        # Express as a percentage; for batch_size=100 this is the raw count.
        P_rates.append(num_correct * 100.0 / batch_size)

    A_precision = sum(P_rates) / len(P_rates)
    return {"caption_rprec": A_precision}