# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
import json
import torch
import torch.distributed as dist
from typing import List, Union, Optional, Tuple, Mapping, Dict
def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False):
    """Serialize ``objects`` as UTF-8 JSON into ``path`` (non-ASCII characters kept as-is).

    Args:
        objects: Data to write. Must be a ``list`` when ``line_by_line`` is true.
        path: Destination file; it is truncated/overwritten.
        line_by_line: When true, write one compact JSON document per list element,
            each followed by a newline. Otherwise write a single document indented
            with 4 spaces.

    Raises:
        AssertionError: ``line_by_line`` is true but ``objects`` is not a list.
    """
    if line_by_line:
        assert isinstance(objects, list), 'Only list can be saved in line by line format'

    # Compact separators (no space after ',' or ':') are used in both modes.
    compact = (',', ':')
    with open(path, 'w', encoding='utf-8') as writer:
        if line_by_line:
            for obj in objects:
                writer.write(json.dumps(obj, ensure_ascii=False, separators=compact) + '\n')
        else:
            json.dump(objects, writer, ensure_ascii=False, indent=4, separators=compact)
def move_to_cuda(sample):
    """Recursively move every tensor inside ``sample`` to the current CUDA device.

    Dicts, lists, tuples (including namedtuples) and other ``Mapping`` types are
    traversed and rebuilt. Non-tensor leaves are returned unchanged. Tensors are
    transferred with ``non_blocking=True``.

    Args:
        sample: A tensor, a possibly nested container of tensors, or any other object.

    Returns:
        The same structure with tensors moved to CUDA. An empty non-tensor sized
        container (e.g. ``[]`` or ``{}``) yields ``{}``.
    """
    # The empty shortcut applies only to sized, non-tensor containers. Calling len()
    # on a 0-d tensor or on an object without __len__ (e.g. an int) raises TypeError,
    # and an empty tensor should still be moved rather than replaced by {}.
    if not torch.is_tensor(sample) and hasattr(sample, '__len__') and len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        if isinstance(maybe_tensor, dict):
            # Checked before Mapping: dicts (and subclasses such as defaultdict, which
            # cannot be rebuilt from a plain dict) come back as a plain dict.
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            moved = [_move_to_cuda(x) for x in maybe_tensor]
            # namedtuple constructors take positional fields, not one iterable;
            # rebuild with the same type so field access keeps working.
            if hasattr(maybe_tensor, '_fields'):
                return type(maybe_tensor)(*moved)
            return tuple(moved)
        if isinstance(maybe_tensor, Mapping):
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        return maybe_tensor

    return _move_to_cuda(sample)
def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Concatenate ``t`` from every process along dim 0.

    This rank's slot in the result is the caller's own (contiguous) tensor rather
    than the gathered copy, so autograd history attached to ``t`` is preserved for
    the local slice.

    Args:
        t: Local tensor, or ``None``. The receive buffers are allocated with
            ``torch.empty_like(t)``, so every rank must pass a tensor of the same
            shape and dtype.

    Returns:
        ``None`` if ``t`` is ``None``. Otherwise the rank-ordered concatenation of all
        ranks' tensors. When no process group is initialized (single-process run),
        the local contiguous tensor is returned as-is.
    """
    if t is None:
        return None
    t = t.contiguous()
    # Without an initialized process group there is nothing to gather, and
    # dist.get_world_size() would raise.
    if not (dist.is_available() and dist.is_initialized()):
        return t

    all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(all_tensors, t)
    # The gathered buffers come from empty_like (no autograd history); put the
    # local tensor back so gradients can flow through this rank's slice.
    all_tensors[dist.get_rank()] = t
    return torch.cat(all_tensors, dim=0)
@torch.no_grad()
def select_grouped_indices(scores: torch.Tensor,
                           group_size: int,
                           start: int = 0) -> torch.Tensor:
    """Return, for each row of ``scores``, the column indices of its contiguous group.

    Row ``i`` receives ``[start + i * group_size, ..., start + (i + 1) * group_size - 1]``.

    Args:
        scores: 2-D tensor. Only its shape and device are used.
        group_size: Number of consecutive columns assigned to each row.
        start: Column offset of the first row's group.

    Returns:
        ``torch.long`` tensor of shape ``(scores.shape[0], group_size)`` on ``scores.device``.

    Raises:
        AssertionError: ``scores`` is not 2-D, or the selected columns would run past
            ``scores.shape[1]``.
    """
    assert len(scores.shape) == 2
    batch_size = scores.shape[0]
    # The largest index produced is start + batch_size * group_size - 1, so the
    # offset must be part of the bounds check.
    assert start + batch_size * group_size <= scores.shape[1]
    indices = torch.arange(batch_size * group_size, dtype=torch.long, device=scores.device)
    return indices.view(batch_size, group_size) + start
def full_contrastive_scores_and_labels(
        query: torch.Tensor,
        key: torch.Tensor,
        use_all_pairs: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build contrastive logits and target column indices for a batch of queries.

    ``key`` must hold ``n`` consecutive rows per query (``key.shape[0]`` divisible by
    ``query.shape[0]``). The target for query ``i`` is column ``i * n``, i.e. the
    first key of its group.

    Args:
        query: ``(B, D)`` matrix (``torch.mm`` requires 2-D inputs).
        key: ``(B * n, D)`` matrix.
        use_all_pairs: When false, return only query-vs-key scores. When true, also
            append positive-key-vs-query, query-vs-query and positive-key-vs-positive-key
            scores, each with its diagonal set to ``-inf`` to exclude self-pairs.

    Returns:
        ``(scores, labels)``. ``scores`` is ``(B, B * n)``, or ``(B, B * n + 3 * B)``
        with all pairs. ``labels`` is a ``torch.long`` vector of length ``B`` on
        ``query.device``.
    """
    n_query = query.shape[0]
    assert key.shape[0] % n_query == 0, '{} % {} > 0'.format(key.shape[0], n_query)
    group = key.shape[0] // n_query
    labels = torch.arange(0, n_query, dtype=torch.long, device=query.device) * group

    # (B, B * n): every query against every key in the batch.
    q_vs_k = torch.mm(query, key.t())
    if not use_all_pairs:
        return q_vs_k, labels

    # (B, D): the target key of each query.
    positives = key.index_select(dim=0, index=labels)
    assert n_query == positives.shape[0]

    # Three (B, B) in-batch blocks, in order: key-query, query-query, key-key.
    extra_blocks = []
    for lhs, rhs in ((positives, query), (query, query), (positives, positives)):
        block = torch.mm(lhs, rhs.t())
        block.fill_diagonal_(float('-inf'))
        extra_blocks.append(block)

    return torch.cat([q_vs_k] + extra_blocks, dim=-1), labels
def slice_batch_dict(batch_dict: Dict[str, torch.Tensor], prefix: str) -> dict:
    """Return the entries whose key starts with ``prefix``, with that prefix removed.

    Values are not copied, and insertion order follows ``batch_dict``.
    """
    cut = len(prefix)
    selected = {}
    for name, value in batch_dict.items():
        if name.startswith(prefix):
            selected[name[cut:]] = value
    return selected
class AverageMeter(object):
    """Track a weighted running average and the most recent value of a metric.

    Call ``update(val, n)`` once per step, where ``n`` is the weight of ``val``
    (for example, the batch size). ``avg`` is then ``sum / count``.
    """

    def __init__(self, name: str, round_digits: int = 3):
        self.name = name
        self.round_digits = round_digits
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # weighted mean: sum / count
        self.sum = 0    # sum of val * n over all updates
        self.count = 0  # total weight n over all updates

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update on a fresh meter would otherwise divide by zero.
        if self.count != 0:
            self.avg = self.sum / self.count

    def __str__(self):
        return '{}: {}'.format(self.name, round(self.avg, self.round_digits))
if __name__ == '__main__':
    # Smoke test: 4 queries, 3 keys per query, 16-dimensional embeddings.
    batch_size, n_passages, dim = 4, 3, 16
    query = torch.randn(batch_size, dim)
    key = torch.randn(batch_size * n_passages, dim)
    scores, labels = full_contrastive_scores_and_labels(query, key)
    print(scores.shape)
    print(labels)