File size: 3,330 Bytes
d4b77ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch


def flow_prop(feat, flow, mode='forward'):
    """Propagate (align) features along a flow field.

    Args:
        feat: features to be aligned
        flow: the filled current flow
        mode: `forward` or `backward`, indicates the propagation direction

    Returns: the feature tensor produced by `warp` for the given direction

    Raises:
        AssertionError: if `mode` is neither `forward` nor `backward`
            (the check is an `assert`, so it is skipped under `python -O`).

    """
    valid_modes = ('forward', 'backward')
    assert mode in valid_modes, 'Invalid mode: {}'.format(mode)
    return warp(feat, flow, mode)


def warp(feat, flow, mode):
    """Splat `feat` along `flow` onto the four surrounding integer positions.

    Each flow vector is rounded down/up to its four neighbouring integer
    displacements. The feature is scattered to every one of them with a
    Gaussian weight, and the accumulated result is normalised by the
    accumulated weight wherever at least one contribution landed.

    Args:
        feat: feature tensor; its second dimension is the channel count.
        flow: flow tensor indexed as [b, 2, h, w]. Channel 0 is used as the
            `y` displacement and channel 1 as the `x` displacement.
        mode: propagation direction, forwarded unchanged to `sample_one`.

    Returns: the warped feature tensor. Positions that received no
        contribution keep the value 0.

    """
    num_channels = feat.shape[1]

    # Replicate each flow component over the channels -> [b, c, h, w]
    y = flow[:, 0:1, :, :].repeat(1, num_channels, 1, 1)
    x = flow[:, 1:2, :, :].repeat(1, num_channels, 1, 1)

    # Integer displacements bracketing the (fractional) flow
    x_lo = torch.floor(x)
    y_lo = torch.floor(y)
    x_hi = x_lo + 1
    y_hi = y_lo + 1

    corner_weights = get_gaussian_weights(x, y, x_lo, y_lo, x_hi, y_hi)
    corners = ((x_lo, y_lo), (x_lo, y_hi), (x_hi, y_lo), (x_hi, y_hi))

    splats = [
        sample_one(feat, corner_x, corner_y, corner_w, mode)
        for (corner_x, corner_y), corner_w in zip(corners, corner_weights)
    ]
    warped_parts, weight_parts = zip(*splats)

    # Left-to-right accumulation of the four contributions
    feat_o = sum(warped_parts[1:], warped_parts[0])
    total_weight = sum(weight_parts[1:], weight_parts[0])

    # Normalise only where something was accumulated (avoids dividing by 0)
    covered = total_weight > 0
    feat_o[covered] = feat_o[covered] / total_weight[covered]
    return feat_o


def sample_one(feat, shiftx, shifty, weight, mode):
    """Scatter weighted features to integer-shifted positions.

    The element of `feat` at `(n, ch, row, col)` is multiplied by its weight
    and accumulated into the output at `(n, ch, row + shiftx, col + shifty)`
    when `mode == 'forward'`, or at `(n, ch, row - shiftx, col - shifty)` for
    any other mode. Contributions whose target falls outside the `h x w`
    grid are dropped. Several sources landing on the same target are summed.

    Args:
        feat: feature tensor of shape [b, c, h, w].
        shiftx: displacement along the height (row) axis, with b*c*h*w
            elements; converted to integers with `.long()`.
        shifty: displacement along the width (column) axis, same layout as
            `shiftx`.
        weight: per-element weight, with b*c*h*w elements.
        mode: `'forward'` adds the shifts, anything else subtracts them.

    Returns:
        A tuple `(feat_warp, one_warp)` of [b, c, h, w] tensors: the
        accumulated weighted features and the accumulated weights. The
        accumulators use the dtype of the scattered values, so non-float32
        inputs (e.g. float64) are supported.

    """
    device = feat.device
    b, c, h, w = feat.shape
    full_shape = (b, c, h, w)

    # reshape (not view) so non-contiguous shift/weight tensors are accepted
    row_shift = shiftx.reshape(full_shape).long()
    col_shift = shifty.reshape(full_shape).long()
    if mode != 'forward':  # backward propagation
        row_shift = -row_shift
        col_shift = -col_shift

    # Index grids are built directly on the target device and broadcast
    # against the [b, c, h, w] shifts instead of being materialised.
    rows = torch.arange(h, device=device).view(1, 1, h, 1)
    cols = torch.arange(w, device=device).view(1, 1, 1, w)
    dst_rows = row_shift + rows
    dst_cols = col_shift + cols

    # record the shifted pixels inside the image boundaries
    inside = (dst_rows >= 0) & (dst_rows < h) & (dst_cols >= 0) & (dst_cols < w)

    # flat offset of every target in a contiguous [b, c, h, w] buffer
    batch_idx = torch.arange(b, device=device).view(b, 1, 1, 1)
    chan_idx = torch.arange(c, device=device).view(1, c, 1, 1)
    flat_ids = ((batch_idx * c + chan_idx) * h + dst_rows) * w + dst_cols
    dst = flat_ids[inside]

    weight = weight.reshape(full_shape)
    weighted = feat * weight

    # put_ requires the destination and the source to share a dtype, so the
    # accumulators follow the values being scattered rather than the default.
    numel = b * c * h * w
    feat_warp = torch.zeros(numel, dtype=weighted.dtype, device=device)
    feat_warp.put_(dst, weighted[inside], accumulate=True)
    one_warp = torch.zeros(numel, dtype=weight.dtype, device=device)
    one_warp.put_(dst, weight[inside], accumulate=True)
    return feat_warp.view(b, c, h, w), one_warp.view(b, c, h, w)


def get_gaussian_weights(x, y, x1, y1, x2, y2, sigma=1):
    """Compute Gaussian weights of a point towards its four corner positions.

    For every corner `(xi, yj)` the weight is
    `exp(-((x - xi)^2 + (y - yj)^2) / sigma^2)`.

    Args:
        x: first coordinate of the sampled point(s).
        y: second coordinate of the sampled point(s).
        x1: lower corner coordinate paired with `x`.
        y1: lower corner coordinate paired with `y`.
        x2: upper corner coordinate paired with `x`.
        y2: upper corner coordinate paired with `y`.
        sigma: width of the Gaussian; must be non-zero. Defaults to 1,
            which reproduces the previously hard-coded behaviour.

    Returns:
        A tuple `(w11, w12, w21, w22)` with the weights for the corners
        `(x1, y1)`, `(x1, y2)`, `(x2, y1)` and `(x2, y2)` respectively.

    """
    sigma_sq = sigma ** 2

    # Squared offsets along each axis, shared between the four corners
    dx1_sq = (x - x1) ** 2
    dx2_sq = (x - x2) ** 2
    dy1_sq = (y - y1) ** 2
    dy2_sq = (y - y2) ** 2

    w11 = torch.exp(-(dx1_sq + dy1_sq) / sigma_sq)
    w12 = torch.exp(-(dx1_sq + dy2_sq) / sigma_sq)
    w21 = torch.exp(-(dx2_sq + dy1_sq) / sigma_sq)
    w22 = torch.exp(-(dx2_sq + dy2_sq) / sigma_sq)
    return w11, w12, w21, w22