import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
from utils import make_coord

@register('aliif')
class ALIIF(nn.Module):

    def __init__(self, encoder_spec, pdn_spec=None, basis_spec=None, imnet_spec=None,
                 local_ensemble=True, feat_unfold=True, cell_decode=True):
        super().__init__()
        self.local_ensemble = local_ensemble
        self.feat_unfold = feat_unfold
        self.cell_decode = cell_decode

        self.encoder = models.make(encoder_spec)

        # Optional prediction network that outputs per-query basis coefficients.
        if pdn_spec is not None:
            self.pdn = models.make(pdn_spec)
            self.use_pdn = True
        else:
            self.use_pdn = False

        # Optional shared basis; calling it yields the basis weights B and biases b.
        if basis_spec is not None:
            self.basis = models.make(basis_spec)
            self.use_basis = True
            self.B, self.b = self.basis()
        else:
            self.use_basis = False

        if imnet_spec is not None:
            imnet_in_dim = self.encoder.out_dim
            if self.feat_unfold:
                imnet_in_dim *= 9  # 3x3 neighborhood concatenated per pixel
            imnet_in_dim += 2  # attach relative coord
            if self.cell_decode:
                imnet_in_dim += 2  # attach cell size
            self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
        else:
            self.imnet = None
    def gen_feat(self, inp):
        self.feat = self.encoder(inp)
        return self.feat
    def query_rgb(self, coord, cell=None):
        feat = self.feat

        if self.imnet is None:
            # No decoder: fall back to nearest-neighbor sampling of the feature map.
            ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1),
                                mode='nearest', align_corners=False)[:, :, 0, :] \
                .permute(0, 2, 1)
            return ret

        if self.feat_unfold:
            # Concatenate each pixel's 3x3 neighborhood along the channel dim.
            feat = F.unfold(feat, 3, padding=1).view(
                feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        if self.local_ensemble:
            vx_lst = [-1, 1]
            vy_lst = [-1, 1]
            eps_shift = 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        # Half-cell radius of the feature grid in the global [-1, 1] coord system.
        rx = 2 / feat.shape[-2] / 2
        ry = 2 / feat.shape[-1] / 2

        # Per-pixel center coordinates of the feature map (placed on feat's device
        # rather than hard-coding .cuda(), so CPU inference also works).
        feat_coord = make_coord(feat.shape[-2:], flatten=False).to(feat.device) \
            .permute(2, 0, 1) \
            .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])
        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                # Shift the query toward one of the neighboring latent codes.
                coord_ = coord.clone()
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
                q_feat = F.grid_sample(
                    feat, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                q_coord = F.grid_sample(
                    feat_coord, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)

                # Query coord relative to the sampled latent code, rescaled so
                # one feature cell spans roughly [-1, 1].
                rel_coord = coord - q_coord
                rel_coord[:, :, 0] *= feat.shape[-2]
                rel_coord[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([q_feat, rel_coord], dim=-1)

                if self.cell_decode:
                    rel_cell = cell.clone()
                    rel_cell[:, :, 0] *= feat.shape[-2]
                    rel_cell[:, :, 1] *= feat.shape[-1]
                    inp = torch.cat([inp, rel_cell], dim=-1)

                bs, q = coord.shape[:2]
                if self.use_pdn:
                    coeff = self.pdn(inp)  # (bs, q, K) basis coefficients
                else:
                    # Single unit coefficient; create it on inp's device so later
                    # ops do not mix CPU and GPU tensors.
                    coeff = torch.ones(inp.shape[0], inp.shape[1], 1,
                                       device=inp.device)
                if self.use_basis:
                    pred = self.imnet(inp.view(bs * q, -1),
                                      coeff.view(-1, coeff.shape[2]),
                                      self.B, self.b).view(bs, q, -1)
                else:
                    pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
                preds.append(pred)

                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9)
        tot_area = torch.stack(areas).sum(dim=0)
        if self.local_ensemble:
            # Swap diagonally opposite areas: each prediction is weighted by the
            # area of the rectangle opposite its latent code (bilinear weighting).
            t = areas[0]; areas[0] = areas[3]; areas[3] = t
            t = areas[1]; areas[1] = areas[2]; areas[2] = t
        ret = 0
        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)
        return ret
    def forward(self, inp, coord, cell):
        self.gen_feat(inp)
        return self.query_rgb(coord, cell)
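

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving this model for 2x upsampling of a 48x48 input.
# The spec names ('edsr-baseline', 'mlp') and their arguments are assumptions
# about what is registered in `models`; substitute whatever specs your repo
# actually defines. With pdn_spec and basis_spec left as None, this reduces to
# the plain LIIF-style decoding path.
if __name__ == '__main__':
    model = ALIIF(
        encoder_spec={'name': 'edsr-baseline', 'args': {'no_upsampling': True}},
        imnet_spec={'name': 'mlp',
                    'args': {'out_dim': 3, 'hidden_list': [256, 256, 256, 256]}},
    )
    inp = torch.rand(1, 3, 48, 48)             # low-resolution input image
    coord = make_coord((96, 96)).unsqueeze(0)  # (1, 96*96, 2) queries in [-1, 1]
    cell = torch.ones_like(coord)
    cell[:, :, 0] *= 2 / 96                    # query cell height in [-1, 1] units
    cell[:, :, 1] *= 2 / 96                    # query cell width
    rgb = model(inp, coord, cell)              # (1, 96*96, 3) predicted RGB
    print(rgb.shape)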