|
""" |
|
General Utils for Models |
|
|
|
Author: Xiaoyang Wu ([email protected]) |
|
Please cite our work if the code is helpful to you. |
|
""" |
|
|
|
import torch |
|
|
|
|
|
@torch.inference_mode()
def offset2bincount(offset):
    """Convert cumulative offsets into per-segment element counts.

    Each entry of ``offset`` is treated as the running end index of a
    segment. This is the form produced by ``batch2offset``, which is a
    cumulative sum of counts. Differencing against a leading zero recovers
    the size of every segment.

    Args:
        offset: 1-D integer tensor of cumulative segment end indices.

    Returns:
        1-D tensor of segment sizes with the same length as ``offset``.
    """
    # A single zero in front makes the first count equal offset[0].
    zero_prefix = torch.zeros(1, dtype=torch.long, device=offset.device)
    return offset.diff(prepend=zero_prefix)
|
|
|
|
|
@torch.inference_mode()
def offset2batch(offset):
    """Expand cumulative offsets into a per-element batch index tensor.

    Segment ``i`` has ``offset2bincount(offset)[i]`` elements. Each of those
    elements is labelled with the integer ``i``.

    Args:
        offset: 1-D integer tensor of cumulative segment end indices.

    Returns:
        1-D ``torch.long`` tensor on ``offset.device`` with one batch id per
        element.
    """
    counts = offset2bincount(offset)
    segment_ids = torch.arange(
        counts.shape[0], device=offset.device, dtype=torch.long
    )
    # Repeat each segment id as many times as that segment has elements.
    return torch.repeat_interleave(segment_ids, counts)
|
|
|
|
|
@torch.inference_mode()
def batch2offset(batch):
    """Convert a per-element batch index tensor into cumulative offsets.

    This is the inverse direction of ``offset2batch``. The function counts
    how many elements carry each batch id, then accumulates those counts
    into running end indices.

    Args:
        batch: 1-D tensor of non-negative integer batch ids.

    Returns:
        1-D ``torch.long`` tensor whose ``i``-th entry is the number of
        elements with batch id ``<= i``.
    """
    per_batch = batch.bincount()
    running_total = per_batch.cumsum(dim=0)
    return running_total.long()
|
|
|
|
|
def off_diagonal(x):
    """Return the off-diagonal elements of a square 2-D tensor as a 1-D tensor.

    Elements are returned in row-major order with the main diagonal removed,
    so an ``n x n`` input yields ``n * (n - 1)`` values.

    Args:
        x: square tensor of shape ``(n, n)``.

    Returns:
        1-D tensor of length ``n * (n - 1)``.

    Raises:
        ValueError: if ``x`` is not 2-D or is not square.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``
    # and callers get a clear message rather than a later ``view`` error.
    if x.dim() != 2:
        raise ValueError(
            f"off_diagonal expects a 2-D tensor, got shape {tuple(x.shape)}"
        )
    n, m = x.shape
    if n != m:
        raise ValueError(
            f"off_diagonal expects a square matrix, got shape {(n, m)}"
        )
    # In row-major order the diagonal entries are n + 1 apart. Dropping the
    # final element (the last diagonal entry) and re-viewing the remainder as
    # (n - 1, n + 1) places every other diagonal entry in column 0, so slicing
    # that column away leaves exactly the off-diagonal entries.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
|
|