Spaces:
Runtime error
Runtime error
File size: 1,180 Bytes
f670afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# flake8: noqa
import torch
import torch.nn.functional as F
def get_r_precision(image_text_code, eps=1e-5, batch_size=100):
    """Compute top-1 image-to-caption retrieval precision (R-precision).

    The image and text embeddings are stored side by side: the first half of
    dimension 1 is the image code and the second half is the text code (the
    input is expected to be a 2-D tensor of shape ``(N, 2 * D)``). Samples
    are grouped into consecutive batches of ``batch_size``. Within each batch,
    every image retrieves the caption with the highest cosine similarity, and
    the retrieval counts as correct when it is the image's own caption. Any
    trailing partial batch (fewer than ``batch_size`` samples) is ignored.

    Args:
        image_text_code: Concatenated ``[image_code, text_code]`` embeddings.
        eps: Small value passed to ``F.normalize`` to avoid division by zero.
        batch_size: Number of candidates per retrieval batch (default 100).

    Returns:
        ``{"caption_rprec": p}`` where ``p`` is the mean per-batch precision
        expressed as a percentage in ``[0, 100]``.

    Raises:
        ValueError: If ``batch_size`` is not positive, or fewer than
            ``batch_size`` samples are provided (no complete batch exists).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    image_code, text_code = torch.chunk(image_text_code, 2, dim=1)
    num_samples = len(image_code)
    num_batches = num_samples // batch_size
    # An explicit check (not `assert`) so the guard survives `python -O`;
    # otherwise zero batches would end in a ZeroDivisionError below.
    if num_batches == 0:
        raise ValueError(
            f"need at least {batch_size} samples, got {num_samples}"
        )

    # Only integer hit counts leave this block, so no autograd graph is needed.
    with torch.no_grad():
        # Row-wise normalization is independent per sample, so do it once.
        image_code = F.normalize(image_code, dim=1, eps=eps)
        text_code = F.normalize(text_code, dim=1, eps=eps)
        # Row i of each batch should retrieve column i (its own caption).
        targets = torch.arange(batch_size, device=image_code.device)
        hits = []
        for start in range(0, num_batches * batch_size, batch_size):
            stop = start + batch_size
            similarities = image_code[start:stop] @ text_code[start:stop].T
            top1 = torch.topk(similarities, k=1, dim=1)[1][:, 0]
            hits.append(torch.sum(top1 == targets).item())

    # Mean of per-batch percentages; for batch_size=100 this equals the mean
    # hit count per batch.
    precision = 100.0 * sum(hits) / (num_batches * batch_size)
    return {"caption_rprec": precision}
|