File size: 1,791 Bytes
abd2a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
def bilinear_interpolate(im, pos, batch=None):
    """Bilinearly sample ``im`` at fractional pixel positions.

    Adapted from
    https://gist.github.com/peteflorence/a1da2c759ca1ac2b74af9a83f69ce20e

    Args:
        im: Image tensor whose last two dimensions are ``(H, W)``. When
            ``batch`` is given it is indexed as ``im[batch, :, y, x]``
            (i.e. a 4-D ``(B, C, H, W)`` tensor). When ``batch`` is ``None``
            any number of leading dimensions is accepted, e.g. ``(H, W)``,
            ``(C, H, W)`` or ``(B, C, H, W)``.
        pos: ``(N, 2)`` tensor of sample positions in ``(row, col)`` order:
            column 0 is ``y`` and column 1 is ``x``.
        batch: Optional ``(N,)`` integer tensor selecting, for every
            position, the index along the first dimension of ``im``.

    Returns:
        With ``batch``: a tensor of shape ``(N, C)``.
        Without ``batch``: a tensor of shape ``(N, *im.shape[:-2])`` --
        ``(N, C)`` for a ``(C, H, W)`` image and ``(N,)`` for ``(H, W)``.

    Note:
        Neighbour indices are clamped to the image, while the weights are
        computed from the unclamped neighbours and therefore always sum to
        one. Positions outside the image thus evaluate to the replicated
        border pixels.
    """
    x = pos[:, 1]
    y = pos[:, 0]

    # Integer corners of the pixel cell that contains each position.
    x0 = torch.floor(x).long()
    x1 = x0 + 1
    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # Clamped copies are used for indexing only; the unclamped corners are
    # kept for the weight computation below.
    max_x = im.shape[-1] - 1
    max_y = im.shape[-2] - 1
    x0_int = torch.clamp(x0, 0, max_x)
    x1_int = torch.clamp(x1, 0, max_x)
    y0_int = torch.clamp(y0, 0, max_y)
    y1_int = torch.clamp(y1, 0, max_y)

    def gather(yi, xi):
        """Return pixel values at ``(yi, xi)`` with the sample axis first."""
        if batch is not None:
            return im[batch, :, yi, xi]
        vals = im[..., yi, xi]  # (*leading_dims, N)
        last = vals.dim() - 1
        # Move N to the front. For a (C, N) input this equals ``.t()``, but
        # unlike ``.t()`` it also handles 1-D and >2-D gathered values.
        return vals.permute(last, *range(last))

    Ia = gather(y0_int, x0_int)
    Ib = gather(y1_int, x0_int)
    Ic = gather(y0_int, x1_int)
    Id = gather(y1_int, x1_int)

    # One trailing singleton axis per non-sample axis of the gathered values,
    # so the weights broadcast per sample instead of forming an (N, N) outer
    # product when the image has no channel axis.
    trailing = [1] * (Ia.dim() - 1)

    def weight(w):
        return w.reshape(w.shape[0], *trailing)

    wa = weight((x1.float() - x) * (y1.float() - y))
    wb = weight((x1.float() - x) * (y - y0.float()))
    wc = weight((x - x0.float()) * (y1.float() - y))
    wd = weight((x - x0.float()) * (y - y0.float()))

    out = wa * Ia + wb * Ib + wc * Ic + wd * Id
    return out
def main():
    """Demo: sample one point from each of three single-channel 2x4 images."""
    planes = [
        [[0, 0.5, 0, 0], [0.25, 1, 0, 0]],
        [[1, 1, 1, 1], [1, 1, 1, 1]],
        [[2, 2, 2, 2], [2, 2, 2, 2]],
    ]
    # Insert a channel axis: (3, 2, 4) -> (3, 1, 2, 4).
    images = torch.tensor(planes, dtype=torch.float).unsqueeze(1)
    print(images.shape)
    print(images)

    # One (row, col) query per image; `image_ids` pairs each query with the
    # image it is sampled from.
    queries = torch.tensor([[1.0, 0], [0.5, 0.5], [0.5, 0.5]])
    image_ids = torch.tensor([0, 1, 2], dtype=torch.long)

    sampled = bilinear_interpolate(images, queries, batch=image_ids)
    print(sampled)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|