import torch
import torch.nn as nn
import torch.nn.functional as F
import models
from models import register
@register('diinn')
class DIINN(nn.Module):
    """DIINN: an encoder followed by a coordinate-conditioned implicit decoder."""

    def __init__(self, encoder_spec, mode=3, init_q=False):
        super().__init__()
        self.encoder = models.make(encoder_spec)
        self.decoder = ImplicitDecoder(mode=mode, init_q=init_q)

    def forward(self, x, size, bsize=None):
        # Encode the input image, then decode it at an arbitrary target size.
        x = self.encoder(x)
        x = self.decoder(x, size, bsize)
        return x
class SineAct(nn.Module):
    """Periodic sine activation used in the Q (modulation) branch."""

    def forward(self, x):
        return torch.sin(x)
def patch_norm_2d(x, kernel_size=3):
    """Normalize each pixel by the mean and std of its kernel_size x kernel_size neighborhood."""
    # stride=1 keeps the output the same spatial size as the input; avg_pool2d's
    # default stride equals kernel_size and would silently downsample.
    mean = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    mean_sq = F.avg_pool2d(x ** 2, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    var = mean_sq - mean ** 2
    # Divide by the standard deviation (sqrt of the variance), not the variance itself.
    return (x - mean) / torch.sqrt(var + 1e-6)
class ImplicitDecoder(nn.Module):
    def __init__(self, in_channels=64, hidden_dims=(256, 256, 256, 256), mode=3, init_q=False):
        super().__init__()
        self.mode = mode
        self.init_q = init_q
        last_dim_K = in_channels * 9  # channels of the unfolded 3x3 feature neighborhoods
        if self.init_q:
            # Lift the 3-channel synthesis input (rel. coords + area ratio) to the feature width.
            self.first_layer = nn.Sequential(nn.Conv2d(3, in_channels * 9, 1),
                                             SineAct())
            last_dim_Q = in_channels * 9
        else:
            last_dim_Q = 3
        self.K = nn.ModuleList()
        self.Q = nn.ModuleList()
        # Modes 2-4 build identical layer stacks and differ from mode 1 only in
        # re-injecting the unfolded encoder features at every depth, so the K
        # branches take wider inputs after the first layer.
        for hidden_dim in hidden_dims:
            self.K.append(nn.Sequential(nn.Conv2d(last_dim_K, hidden_dim, 1),
                                        nn.ReLU()))
            self.Q.append(nn.Sequential(nn.Conv2d(last_dim_Q, hidden_dim, 1),
                                        SineAct()))
            last_dim_K = hidden_dim if self.mode == 1 else hidden_dim + in_channels * 9
            last_dim_Q = hidden_dim
        if self.mode == 4:
            self.last_layer = nn.Conv2d(hidden_dims[-1], 3, 3, padding=1, padding_mode='reflect')
        else:
            self.last_layer = nn.Conv2d(hidden_dims[-1], 3, 1)
    def _make_pos_encoding(self, x, size):
        B, C, H, W = x.shape
        H_up, W_up = size
        # Pixel-center coordinates in [-1, 1] for the input grid and the target grid.
        h_idx = -1 + 1 / H + 2 / H * torch.arange(H, device=x.device).float()
        w_idx = -1 + 1 / W + 2 / W * torch.arange(W, device=x.device).float()
        in_grid = torch.stack(torch.meshgrid(h_idx, w_idx, indexing='ij'), dim=0)
        h_idx_up = -1 + 1 / H_up + 2 / H_up * torch.arange(H_up, device=x.device).float()
        w_idx_up = -1 + 1 / W_up + 2 / W_up * torch.arange(W_up, device=x.device).float()
        up_grid = torch.stack(torch.meshgrid(h_idx_up, w_idx_up, indexing='ij'), dim=0)
        # Offset of each target pixel from its nearest input pixel center.
        # important! mode='nearest' gives inconsistent results; 'nearest-exact' is required.
        rel_grid = (up_grid - F.interpolate(in_grid.unsqueeze(0), size=(H_up, W_up),
                                            mode='nearest-exact'))
        # Rescale the offsets to units of input pixels.
        rel_grid[:, 0, :, :] *= H
        rel_grid[:, 1, :, :] *= W
        return rel_grid.contiguous().detach()
    def step(self, x, syn_inp):
        if self.init_q:
            # Lift the coordinate input and modulate the encoder features with it.
            syn_inp = self.first_layer(syn_inp)
            x = syn_inp * x
        k = self.K[0](x)
        q = k * self.Q[0](syn_inp)
        for i in range(1, len(self.K)):
            if self.mode == 1:
                k = self.K[i](k)
            elif self.mode == 2:
                k = self.K[i](torch.cat([k, x], dim=1))
            else:
                # Modes 3 and 4: feed the modulated signal q back in, alongside
                # the raw unfolded features.
                k = self.K[i](torch.cat([q, x], dim=1))
            q = k * self.Q[i](q)
        return self.last_layer(q)
    def batched_step(self, x, syn_inp, bsize):
        # Evaluate the decoder in column chunks of roughly bsize pixels to bound memory.
        with torch.no_grad():
            h, w = syn_inp.shape[-2:]
            ql = 0
            preds = []
            while ql < w:
                # max(1, ...) guards against an infinite loop when bsize < h.
                qr = min(ql + max(1, bsize // h), w)
                pred = self.step(x[:, :, :, ql:qr], syn_inp[:, :, :, ql:qr])
                preds.append(pred)
                ql = qr
            return torch.cat(preds, dim=-1)
    def forward(self, x, size, bsize=None):
        B, C, H_in, W_in = x.shape
        # Relative coordinates (2 channels) plus the area ratio (1 channel) form
        # the 3-channel synthesis input.
        rel_coord = self._make_pos_encoding(x, size).expand(B, -1, *size)
        ratio = x.new_tensor([(H_in * W_in) / (size[0] * size[1])]).view(1, -1, 1, 1).expand(B, -1, *size)
        syn_inp = torch.cat([rel_coord, ratio], dim=1)
        # Unfold each 3x3 neighborhood into channels, then nearest-upsample to the target size.
        x = F.interpolate(F.unfold(x, 3, padding=1).view(B, C * 9, H_in, W_in),
                          size=syn_inp.shape[-2:], mode='nearest-exact')
        if bsize is None:
            pred = self.step(x, syn_inp)
        else:
            pred = self.batched_step(x, syn_inp, bsize)
        return pred
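

if __name__ == '__main__':
    # Minimal smoke test (not part of the original file): a sketch that runs the
    # ImplicitDecoder standalone on random features, sidestepping models.make and
    # the encoder spec. The shapes and bsize below are illustrative assumptions.
    decoder = ImplicitDecoder(in_channels=64, mode=3)
    feats = torch.randn(2, 64, 12, 12)            # B, C, H_in, W_in
    out = decoder(feats, size=(48, 48))           # 4x upsampling to an arbitrary size
    print(out.shape)                              # expected: torch.Size([2, 3, 48, 48])
    # Chunked evaluation runs the same layers, so it should match the full pass.
    out_b = decoder(feats, size=(48, 48), bsize=480)
    print(torch.allclose(out, out_b, atol=1e-6))  # expected: True
    # patch_norm_2d with stride=1 pooling preserves the input's spatial size.
    print(patch_norm_2d(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])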