|
import torch |
|
|
|
|
|
def bilinear_interpolate(im, pos, batch=None):
    """Sample ``im`` at fractional pixel positions with bilinear interpolation.

    Each position is blended from its four surrounding pixels. Neighbour
    indices are clamped to the image extent, so a position outside the image
    takes the value of the nearest border pixel.

    Args:
        im: Image tensor whose last two dimensions are (height, width).
            With ``batch`` it is indexed as ``im[batch, :, y, x]`` and must
            therefore be 4-D, e.g. (B, C, H, W). Without ``batch`` any number
            of leading dimensions is accepted, e.g. (H, W) or (C, H, W).
        pos: Tensor of shape (N, 2); column 0 is the y (row) coordinate and
            column 1 is the x (column) coordinate.
        batch: Optional long tensor of shape (N,) giving, for every position,
            the index along dim 0 of ``im`` to sample from.

    Returns:
        Tensor with the N samples along dim 0: shape (N, C) when ``batch`` is
        given, otherwise (N, *leading dims of im), e.g. (N, C) for a
        (C, H, W) image and (N,) for an (H, W) image.
    """
    x = pos[:, 1]
    y = pos[:, 0]

    # Integer corners of the cell that contains each position.
    x0 = torch.floor(x).long()
    x1 = x0 + 1
    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # Clamp only the indices used for the lookup; the weights below keep
    # using the unclamped corners so that they always sum to one.
    max_x = im.shape[-1] - 1
    max_y = im.shape[-2] - 1
    x0_int = torch.clamp(x0, 0, max_x)
    x1_int = torch.clamp(x1, 0, max_x)
    y0_int = torch.clamp(y0, 0, max_y)
    y1_int = torch.clamp(y1, 0, max_y)

    if batch is not None:
        # Advanced indices separated by a slice put the sample axis first,
        # giving (N, C) directly.
        def gather(yi, xi):
            return im[batch, :, yi, xi]
    else:
        # The sample axis comes out last; move it to the front. Unlike
        # ``.t()`` this works for any number of leading dimensions.
        def gather(yi, xi):
            return im[..., yi, xi].movedim(-1, 0)

    Ia = gather(y0_int, x0_int)
    Ib = gather(y1_int, x0_int)
    Ic = gather(y0_int, x1_int)
    Id = gather(y1_int, x1_int)

    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    # Shape the per-sample weights as (N, 1, ..., 1) so they broadcast over
    # every trailing dimension of the gathered pixel values.
    w_shape = (wa.shape[0],) + (1,) * (Ia.dim() - 1)
    wa = wa.reshape(w_shape)
    wb = wb.reshape(w_shape)
    wc = wc.reshape(w_shape)
    wd = wd.reshape(w_shape)

    out = wa * Ia + wb * Ib + wc * Ic + wd * Id

    return out
|
|
|
|
|
def main():
    """Run a small batched demo of ``bilinear_interpolate`` and print the results.

    Builds three single-channel 2x4 images (tensor shape (3, 1, 2, 4)),
    samples one position from each image and prints the image shape, the
    image tensor and the sampled values.
    """
    planes = [
        [[0, 0.5, 0, 0], [0.25, 1, 0, 0]],
        [[1, 1, 1, 1], [1, 1, 1, 1]],
        [[2, 2, 2, 2], [2, 2, 2, 2]],
    ]
    # Insert a channel axis of size one between the batch and spatial axes.
    im = torch.tensor(planes, dtype=torch.float).unsqueeze(1)
    print(im.shape)
    print(im)

    # One (y, x) query per image; row i is sampled from image i.
    queries = torch.tensor([[1.0, 0], [0.5, 0.5], [0.5, 0.5]])
    image_ids = torch.arange(3, dtype=torch.long)

    sampled = bilinear_interpolate(im, queries, batch=image_ids)
    print(sampled)
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|