Spaces:
Sleeping
Sleeping
import json | |
import torch | |
import torch.distributed as dist | |
from typing import List, Union, Optional, Tuple, Mapping, Dict | |
def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False):
    """Serialize ``objects`` as UTF-8 JSON into ``path``, overwriting the file.

    With ``line_by_line=False`` the whole object is written as a single
    document indented by 4 spaces. With ``line_by_line=True``, ``objects`` must
    be a list; each element is written compactly on its own line, followed by
    a newline. Non-ASCII characters are written as-is (``ensure_ascii=False``).
    Item separators are always ``','`` with no trailing space, and key
    separators are ``':'``.
    """
    if line_by_line:
        assert isinstance(objects, list), 'Only list can be saved in line by line format'

    compact = (',', ':')
    with open(path, 'w', encoding='utf-8') as writer:
        if line_by_line:
            for obj in objects:
                writer.write(json.dumps(obj, ensure_ascii=False, separators=compact) + '\n')
        else:
            json.dump(objects, writer, ensure_ascii=False, indent=4, separators=compact)
def move_to_cuda(sample):
    """Recursively move every tensor inside ``sample`` to the current CUDA device.

    Tensors are copied with ``.cuda(non_blocking=True)``. Containers are
    rebuilt with their contents moved:

    * ``dict``, including subclasses, comes back as a plain ``dict``.
    * ``list`` comes back as a ``list``.
    * ``tuple`` comes back as a ``tuple``; namedtuples keep their own type.
    * Other ``Mapping`` types are rebuilt via ``type(m)({...})``.

    Any other leaf is returned unchanged. An empty non-tensor ``sample`` yields
    ``{}``, which preserves the historical behaviour for empty batches.
    """
    # Only containers get the empty shortcut. ``len`` on a 0-dim tensor raises
    # TypeError, and an empty tensor should still be moved, not replaced by {}.
    if not torch.is_tensor(sample) and len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        elif isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            moved = [_move_to_cuda(x) for x in maybe_tensor]
            # namedtuples take positional fields, not a single iterable
            if hasattr(maybe_tensor, '_fields'):
                return type(maybe_tensor)(*moved)
            return tuple(moved)
        elif isinstance(maybe_tensor, Mapping):
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        else:
            return maybe_tensor

    return _move_to_cuda(sample)
def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Concatenate ``t`` from every process of the default group along dim 0.

    Returns ``None`` immediately when ``t`` is ``None``; no collective call is
    made in that case. Otherwise, slot ``i`` of the gathered list corresponds
    to rank ``i``.

    Receive buffers are allocated with ``torch.empty_like`` of the local
    tensor. Each rank is therefore expected to pass a tensor with the same
    shape and dtype.
    """
    if t is None:
        return None

    local = t.contiguous()
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    # Put the caller's own tensor back into this rank's slot. Presumably this
    # lets autograd flow through the local slice, since the gathered buffers
    # are freshly allocated and carry no graph history.
    gathered[dist.get_rank()] = local
    return torch.cat(gathered, dim=0)
def select_grouped_indices(scores: torch.Tensor,
                           group_size: int,
                           start: int = 0) -> torch.Tensor:
    """Return, for each row of ``scores``, the column indices of its group.

    Row ``i`` receives the ``group_size`` consecutive column indices
    ``start + i * group_size`` up to ``start + (i + 1) * group_size - 1``.

    Args:
        scores: 2-D tensor. Only its shape and device are used.
        group_size: number of consecutive columns assigned to each row.
        start: column offset of row 0's group.

    Returns:
        A ``(scores.shape[0], group_size)`` long tensor on ``scores.device``.
    """
    assert len(scores.shape) == 2
    batch_size = scores.shape[0]
    # The offset counts too: every produced index must be a valid column of
    # ``scores``. Without ``start`` here, out-of-range indices slip through.
    assert start + batch_size * group_size <= scores.shape[1]
    # Row-major layout of one contiguous arange gives row i the slice
    # [start + i*group_size, start + (i+1)*group_size), built directly on device.
    indices = torch.arange(start, start + batch_size * group_size,
                           dtype=torch.long, device=scores.device)
    return indices.view(batch_size, group_size)
def full_contrastive_scores_and_labels(
        query: torch.Tensor,
        key: torch.Tensor,
        use_all_pairs: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build in-batch contrastive score matrices and each query's target column.

    ``key`` is split into ``query.shape[0]`` consecutive groups of
    ``key.shape[0] // query.shape[0]`` rows. ``labels[i]`` is the column of
    the first row of group ``i``.

    Args:
        query: ``(batch_size, dim)`` matrix.
        key: ``(batch_size * n_psg, dim)`` matrix.
        use_all_pairs: if False, return only the query-vs-key scores.

    Returns:
        ``(scores, labels)``.

        When ``use_all_pairs`` is False, ``scores`` is the
        ``(batch_size, batch_size * n_psg)`` block ``query @ key.T``.

        Otherwise that block is followed, along the last dim, by three
        ``(batch_size, batch_size)`` blocks:

        * labelled-key vs query,
        * query vs query,
        * labelled-key vs labelled-key.

        Each of these has its diagonal set to ``-inf``.
    """
    num_queries, num_keys = query.shape[0], key.shape[0]
    assert num_keys % num_queries == 0, '{} % {} > 0'.format(num_keys, num_queries)
    group_size = num_keys // num_queries
    # column of the first key row in each query's group
    labels = torch.arange(0, num_queries, dtype=torch.long, device=query.device) * group_size

    # batch_size x (batch_size x n_psg)
    query_vs_key = torch.mm(query, key.t())
    if not use_all_pairs:
        return query_vs_key, labels

    def _masked_scores(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
        # batch_size x batch_size. The diagonal is either a self-pair or the
        # (labelled key i, query i) pair already scored in query_vs_key,
        # so it is masked out.
        sim = torch.mm(left, right.t())
        sim.fill_diagonal_(float('-inf'))
        return sim

    # batch_size x dim: the key row each label points at
    labelled_keys = key.index_select(dim=0, index=labels)
    assert num_queries == labelled_keys.shape[0]

    blocks = [
        query_vs_key,
        _masked_scores(labelled_keys, query),
        _masked_scores(query, query),
        _masked_scores(labelled_keys, labelled_keys),
    ]
    return torch.cat(blocks, dim=-1), labels
def slice_batch_dict(batch_dict: Dict[str, torch.Tensor], prefix: str) -> dict:
    """Keep only the entries whose key starts with ``prefix``, with the prefix removed.

    For example, ``{'q_ids': a, 'd_ids': b}`` with prefix ``'q_'`` gives
    ``{'ids': a}``. Values are passed through without copying, and the
    original key order is preserved.
    """
    cut = len(prefix)
    sliced = {}
    for name, value in batch_dict.items():
        if name.startswith(prefix):
            sliced[name[cut:]] = value
    return sliced
class AverageMeter:
    """Track a running, count-weighted mean of a value.

    ``str(meter)`` renders as ``'<name>: <avg rounded to round_digits>'``.
    """

    def __init__(self, name: str, round_digits: int = 3):
        self.name = name
        self.round_digits = round_digits
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value of ``n`` new observations."""
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        shown = round(self.avg, self.round_digits)
        return '{}: {}'.format(self.name, shown)
if __name__ == '__main__':
    # Smoke run: 4 queries, 3 keys per query, 16-dimensional random vectors.
    num_queries, keys_per_query, dim = 4, 3, 16
    demo_query = torch.randn(num_queries, dim)
    demo_key = torch.randn(num_queries * keys_per_query, dim)
    demo_scores, demo_labels = full_contrastive_scores_and_labels(demo_query, demo_key)
    print(demo_scores.shape)
    print(demo_labels)