""" General Utils for Models Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) Please cite our work if the code is helpful to you. """ import torch @torch.inference_mode() def offset2bincount(offset): return torch.diff( offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long) ) @torch.inference_mode() def offset2batch(offset): bincount = offset2bincount(offset) return torch.arange( len(bincount), device=offset.device, dtype=torch.long ).repeat_interleave(bincount) @torch.inference_mode() def batch2offset(batch): return torch.cumsum(batch.bincount(), dim=0).long() def off_diagonal(x): # return a flattened view of the off-diagonal elements of a square matrix n, m = x.shape assert n == m return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()