import json
from typing import List, Union, Optional, Tuple, Mapping, Dict

import torch
import torch.distributed as dist


def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False):
    """Serialize ``objects`` to a JSON file.

    :param objects: a list or dict to serialize.
    :param path: destination file path.
    :param line_by_line: if True, write one compact JSON object per line
        (JSON Lines format); requires ``objects`` to be a list.
    """
    if line_by_line:
        assert isinstance(objects, list), 'Only list can be saved in line by line format'
    with open(path, 'w', encoding='utf-8') as writer:
        if not line_by_line:
            # ': ' (with a space) is the documented separator to pair with
            # indent=N; the original (',', ':') emitted "key":value with no
            # space despite pretty-printing.
            json.dump(objects, writer, ensure_ascii=False, indent=4, separators=(',', ': '))
        else:
            for obj in objects:
                writer.write(json.dumps(obj, ensure_ascii=False, separators=(',', ':')))
                writer.write('\n')


def move_to_cuda(sample):
    """Recursively move every tensor inside ``sample`` to the current CUDA device.

    Supports nested dicts, Mappings, lists and tuples; non-tensor leaves are
    returned unchanged. An empty container yields ``{}``.
    """
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        elif isinstance(maybe_tensor, dict):
            # NOTE: this branch also catches dict subclasses (e.g. defaultdict)
            # and returns a plain dict — kept as-is for backward compatibility.
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            return tuple(_move_to_cuda(x) for x in maybe_tensor)
        elif isinstance(maybe_tensor, Mapping):
            # Non-dict Mapping: preserve the concrete container type.
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        else:
            return maybe_tensor

    return _move_to_cuda(sample)


def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """All-gather ``t`` from every distributed worker and concatenate along dim 0.

    Returns ``None`` if ``t`` is ``None``. Requires ``torch.distributed`` to be
    initialized. The local rank's slot in the gathered list is replaced with the
    original tensor so gradients still flow through the local shard
    (``all_gather`` outputs do not carry autograd history).
    """
    if t is None:
        return None

    t = t.contiguous()
    all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(all_tensors, t)
    # Re-insert the autograd-tracked local tensor in its own slot.
    all_tensors[dist.get_rank()] = t
    return torch.cat(all_tensors, dim=0)


@torch.no_grad()
def select_grouped_indices(scores: torch.Tensor,
                           group_size: int,
                           start: int = 0) -> torch.Tensor:
    """Build per-row index groups into the columns of ``scores``.

    Row ``i`` of the result is ``[start + i*group_size, ..., start + (i+1)*group_size - 1]``,
    i.e. each query row gets the contiguous block of ``group_size`` candidate
    columns that belongs to it.

    :param scores: 2-D tensor of shape (batch_size, num_candidates); only its
        shape and device are used.
    :param group_size: number of candidates per row.
    :param start: column offset of the first group.
    :return: long tensor of shape (batch_size, group_size) on ``scores.device``.
    """
    assert len(scores.shape) == 2
    batch_size = scores.shape[0]
    assert batch_size * group_size <= scores.shape[1]

    indices = torch.arange(0, group_size, dtype=torch.long)
    indices = indices.repeat(batch_size, 1)
    indices += torch.arange(0, batch_size, dtype=torch.long).unsqueeze(-1) * group_size
    indices += start
    return indices.to(scores.device)


def full_contrastive_scores_and_labels(
        query: torch.Tensor,
        key: torch.Tensor,
        use_all_pairs: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute contrastive similarity scores and the positive-pair labels.

    ``key`` holds ``train_n_passages`` passages per query, laid out so that the
    positive passage for query ``i`` sits at row ``i * train_n_passages``.

    :param query: (batch_size, dim) query embeddings.
    :param key: (batch_size * train_n_passages, dim) passage embeddings.
    :param use_all_pairs: when True, append key->query, query->query and
        key->key similarities (with self-similarities masked to -inf) as extra
        negatives; labels stay valid because the query->key block comes first.
    :return: (scores, labels) where labels index the positive column per row.
    """
    assert key.shape[0] % query.shape[0] == 0, '{} % {} > 0'.format(key.shape[0], query.shape[0])

    train_n_passages = key.shape[0] // query.shape[0]
    labels = torch.arange(0, query.shape[0], dtype=torch.long, device=query.device)
    labels = labels * train_n_passages

    # batch_size x (batch_size x n_psg)
    qk = torch.mm(query, key.t())

    if not use_all_pairs:
        return qk, labels

    # batch_size x dim: only the positive passage of each query
    sliced_key = key.index_select(dim=0, index=labels)
    assert query.shape[0] == sliced_key.shape[0]

    # batch_size x batch_size, diagonals masked so a row never scores itself
    kq = torch.mm(sliced_key, query.t())
    kq.fill_diagonal_(float('-inf'))

    qq = torch.mm(query, query.t())
    qq.fill_diagonal_(float('-inf'))

    kk = torch.mm(sliced_key, sliced_key.t())
    kk.fill_diagonal_(float('-inf'))

    scores = torch.cat([qk, kq, qq, kk], dim=-1)
    return scores, labels


def slice_batch_dict(batch_dict: Dict[str, torch.Tensor], prefix: str) -> dict:
    """Return the entries of ``batch_dict`` whose keys start with ``prefix``,
    with the prefix stripped from each key."""
    return {k[len(prefix):]: v for k, v in batch_dict.items() if k.startswith(prefix)}


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name: str, round_digits: int = 3):
        self.name = name                  # label used by __str__
        self.round_digits = round_digits  # decimal places shown by __str__
        self.reset()

    def reset(self):
        """Clear the running statistics."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Accumulate ``val`` observed ``n`` times and refresh the average."""
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        return '{}: {}'.format(self.name, round(self.avg, self.round_digits))


if __name__ == '__main__':
    query = torch.randn(4, 16)
    key = torch.randn(4 * 3, 16)
    scores, labels = full_contrastive_scores_and_labels(query, key)
    print(scores.shape)
    print(labels)