Spaces:
Sleeping
Sleeping
File size: 4,329 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import json
import torch
import torch.distributed as dist
from typing import List, Union, Optional, Tuple, Mapping, Dict
def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False):
    """Write ``objects`` to ``path`` as UTF-8 JSON.

    Args:
        objects: the value to serialize. Must be a ``list`` when
            ``line_by_line`` is true.
        path: destination file; it is created or truncated.
        line_by_line: when false, dump the whole value as one pretty-printed
            document (4-space indent). When true, write one compact JSON value
            per list element, each followed by a newline (JSON Lines).

    Raises:
        AssertionError: if ``line_by_line`` is true and ``objects`` is not a list.
    """
    # Validate before opening so a bad call does not truncate an existing file.
    if line_by_line:
        assert isinstance(objects, list), 'Only list can be saved in line by line format'

    compact = (',', ':')
    with open(path, 'w', encoding='utf-8') as fh:
        if line_by_line:
            for item in objects:
                fh.write(json.dumps(item, ensure_ascii=False, separators=compact) + '\n')
        else:
            json.dump(objects, fh, ensure_ascii=False, indent=4, separators=compact)
def move_to_cuda(sample):
    """Recursively move every tensor inside ``sample`` onto the current CUDA device.

    ``dict``, ``list``, ``tuple`` and other ``Mapping`` containers are walked
    recursively and rebuilt with their elements converted. Any other leaf
    value is returned unchanged. Tensors are copied with
    ``.cuda(non_blocking=True)``.

    Args:
        sample: a tensor, a plain leaf value, or a (possibly nested)
            container of tensors.

    Returns:
        ``sample`` with every tensor replaced by its CUDA copy. An empty
        non-tensor container (e.g. ``{}`` or ``[]``) yields ``{}``, which
        preserves the historical short-circuit.
    """
    # Restrict the empty-container short-circuit to sized, non-tensor inputs:
    # - len() raises on a 0-dim tensor and on unsized leaves such as ints or None;
    # - an empty-batch tensor (shape[0] == 0) must still come back as a tensor,
    #   not be replaced by a dict.
    if not torch.is_tensor(sample) and hasattr(sample, '__len__') and len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        elif isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            return tuple([_move_to_cuda(x) for x in maybe_tensor])
        elif isinstance(maybe_tensor, Mapping):
            # Non-dict mappings (e.g. collections.UserDict subclasses) are rebuilt
            # with their own type so callers keep the same container class.
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        else:
            return maybe_tensor

    return _move_to_cuda(sample)
def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """All-gather ``t`` from every process and concatenate the pieces along dim 0.

    Requires an initialized ``torch.distributed`` process group. Every rank is
    assumed to pass a tensor of the same shape and dtype, because the receive
    buffers are allocated with ``torch.empty_like`` of the local tensor.

    Args:
        t: this rank's tensor, or ``None``.

    Returns:
        ``None`` if ``t`` is ``None``; otherwise the per-rank tensors
        concatenated in rank order along dim 0.
    """
    if t is None:
        return None

    local = t.contiguous()
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    # Put this rank's own tensor back into its slot instead of the received copy.
    # Presumably this keeps autograd history for the local shard, since
    # all_gather fills plain buffers.
    gathered[dist.get_rank()] = local
    return torch.cat(gathered, dim=0)
@torch.no_grad()
def select_grouped_indices(scores: torch.Tensor,
                           group_size: int,
                           start: int = 0) -> torch.Tensor:
    """Build, for each row of ``scores``, the column indices of that row's group.

    Row ``i`` receives the ``group_size`` consecutive column indices
    ``start + i * group_size`` through ``start + (i + 1) * group_size - 1``.

    Args:
        scores: 2-D tensor; only its shape and device are read.
        group_size: number of consecutive columns belonging to each row.
        start: column offset of row 0's group.

    Returns:
        ``torch.long`` tensor of shape ``(scores.shape[0], group_size)`` on
        ``scores.device``.

    Raises:
        AssertionError: if ``scores`` is not 2-D, ``start`` is negative, or
            the selected columns would run past ``scores.shape[1]``.
    """
    assert len(scores.shape) == 2
    batch_size = scores.shape[0]
    assert start >= 0, 'start must be non-negative, got {}'.format(start)
    # The largest index produced is start + batch_size * group_size - 1, so the
    # offset must be part of the bound; otherwise out-of-range column indices
    # could be returned.
    assert start + batch_size * group_size <= scores.shape[1], \
        '{} + {} * {} > {}'.format(start, batch_size, group_size, scores.shape[1])
    # Consecutive integers reshaped row-major give row i the block
    # [i * group_size, (i + 1) * group_size). Building them on the target
    # device avoids a host-to-device copy.
    indices = torch.arange(batch_size * group_size, dtype=torch.long, device=scores.device)
    return indices.view(batch_size, group_size) + start
def full_contrastive_scores_and_labels(
        query: torch.Tensor,
        key: torch.Tensor,
        use_all_pairs: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute contrastive logits and target column indices.

    ``key`` must have a row count that is a multiple ``n`` of ``query``'s. The
    target (label) for query ``i`` is column ``i * n``, i.e. the first key of
    its group of ``n`` (presumably the positive passage).

    Args:
        query: ``(batch_size, dim)`` matrix.
        key: ``(batch_size * n, dim)`` matrix.
        use_all_pairs: when false, return only the query-key similarities.
            When true, append three ``(batch_size, batch_size)`` blocks:
            - labelled-key vs query;
            - query vs query;
            - labelled-key vs labelled-key.
            Each block has its diagonal set to ``-inf``.

    Returns:
        ``(scores, labels)``. ``scores`` has shape ``(batch_size, batch_size * n)``,
        or ``(batch_size, batch_size * n + 3 * batch_size)`` when
        ``use_all_pairs`` is true. ``labels`` is a ``torch.long`` tensor of
        shape ``(batch_size,)`` on ``query.device``.
    """
    n_query, n_key = query.shape[0], key.shape[0]
    assert n_key % n_query == 0, f'{n_key} % {n_query} > 0'
    n_psg = n_key // n_query
    labels = torch.arange(0, n_query, dtype=torch.long, device=query.device) * n_psg

    # batch_size x (batch_size x n_psg)
    q_vs_k = torch.mm(query, key.t())
    if not use_all_pairs:
        return q_vs_k, labels

    # batch_size x dim: the key each query is labelled with
    anchors = key[labels]
    assert n_query == anchors.shape[0]

    def _masked_sim(left, right):
        # batch_size x batch_size; the diagonal is masked out, so each row
        # excludes its own pairing (for kq, the labelled key already scored
        # in q_vs_k).
        sim = torch.mm(left, right.t())
        sim.fill_diagonal_(float('-inf'))
        return sim

    blocks = [
        q_vs_k,
        _masked_sim(anchors, query),
        _masked_sim(query, query),
        _masked_sim(anchors, anchors),
    ]
    return torch.cat(blocks, dim=-1), labels
def slice_batch_dict(batch_dict: Dict[str, torch.Tensor], prefix: str) -> dict:
    """Return the entries of ``batch_dict`` whose key starts with ``prefix``.

    The prefix is stripped from each returned key. Values are passed through
    untouched (no copies), and the original insertion order is kept.
    """
    cut = len(prefix)
    sliced = {}
    for name, value in batch_dict.items():
        if name.startswith(prefix):
            sliced[name[cut:]] = value
    return sliced
class AverageMeter(object):
    """Track a running weighted mean of a metric and render it as text."""

    def __init__(self, name: str, round_digits: int = 3):
        # Label printed by __str__.
        self.name = name
        # Number of decimals __str__ rounds the average to.
        self.round_digits = round_digits
        self.reset()

    def reset(self):
        """Set the running average, sum and count back to zero."""
        self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Accumulate ``val`` with weight ``n`` and recompute the average."""
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        shown = round(self.avg, self.round_digits)
        return f'{self.name}: {shown}'
if __name__ == '__main__':
    # Smoke test: 4 queries, each with 3 candidate keys, 16-dim random embeddings.
    batch_size, n_passages, dim = 4, 3, 16
    query = torch.randn(batch_size, dim)
    key = torch.randn(batch_size * n_passages, dim)
    scores, labels = full_contrastive_scores_and_labels(query, key)
    print(scores.shape)
    print(labels)
|