Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# flake8: noqa
import torch | |
import torch.nn.functional as F | |
def get_r_precision(image_text_code, eps=1e-5):
    """Compute top-1 image-to-text retrieval precision over groups of 100.

    ``image_text_code`` is split into two equal halves along ``dim=1``: the
    first half is used as the image codes, the second as the text codes.
    Rows are then processed in consecutive, non-overlapping groups of 100.
    Within each group every image code is compared against all 100 text
    codes by cosine similarity, and a retrieval counts as correct when the
    most similar text code sits at the same row position as the image code.
    Trailing rows that do not fill a complete group of 100 are ignored.

    Args:
        image_text_code: 2-D tensor whose columns hold the image code
            followed by the text code for each row (assumes row ``i`` of
            both halves describes the same sample -- confirm with caller).
        eps: Lower bound on the norm used by ``F.normalize`` to avoid
            division by zero.

    Returns:
        A dict ``{"caption_rprec": value}`` where ``value`` is the number of
        correct top-1 retrievals per group, averaged over all complete
        groups, as a float. Because each group has 100 rows this lies in
        ``[0.0, 100.0]``.

    Raises:
        ValueError: If fewer than 100 rows are supplied, i.e. not even one
            complete group can be formed.
    """
    group_size = 100

    all_image_code, all_text_code = torch.chunk(image_text_code, 2, dim=1)
    num_samples = len(all_image_code)

    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and the function would then fail later with a ZeroDivisionError.
    if num_samples < group_size:
        raise ValueError(
            "get_r_precision needs at least %d samples, got %d"
            % (group_size, num_samples)
        )

    # The expected match for row i of a group is column i; build this once.
    targets = torch.arange(group_size, device=all_image_code.device)

    hits_per_group = []
    # The range bound stops before any incomplete trailing group.
    for start in range(0, num_samples - group_size + 1, group_size):
        stop = start + group_size
        # Unit-normalising both sides turns the dot product into cosine
        # similarity.
        image_code = F.normalize(all_image_code[start:stop], dim=1, eps=eps)
        text_code = F.normalize(all_text_code[start:stop], dim=1, eps=eps)
        cosine_similarities = image_code @ text_code.T
        # Index of the best-matching text code for every image code.
        top1_indices = torch.topk(cosine_similarities, k=1, dim=1)[1][:, 0]
        hits_per_group.append(torch.sum(top1_indices == targets).item())

    return {"caption_rprec": sum(hits_per_group) / len(hits_per_group)}