import math

import torch
import torch.nn as nn
from torch.nn.functional import interpolate
from tqdm import tqdm

from modules.feature_extactor import Extractor
from modules.half_warper import HalfWarper
from modules.cupy_module.nedt import NEDT
from modules.flow_models.flow_models import (
    RAFTFineFlow,
    PWCFineFlow,
)
from modules.synthesizer import Synthesis
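

# Warps the two input frames and their feature pyramids toward the target
# time tau, producing per-scale aligned inputs for the synthesis network.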
class FeatureWarper(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        channels: list[int] = [32, 64, 128, 256],
    ):
        super().__init__()
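        # +1 input channel: forward() concatenates an NEDT edge-distance map
        # to each frame before feature extraction.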
        channels = [in_channels + 1] + channels

        self.half_warper = HalfWarper()
        self.feature_extractor = Extractor(channels)
        self.nedt = NEDT()

    def forward(
        self,
        I0: torch.Tensor,
        I1: torch.Tensor,
        flow0to1: torch.Tensor,
        flow1to0: torch.Tensor,
        tau: torch.Tensor | None = None
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        if tau is None:
            tau = torch.full((I0.shape[0], 2), 0.5, device=I0.device, dtype=I0.dtype)
        assert tau.shape == (I0.shape[0], 2), "tau shape must be (batch, 2)"
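
        # Scale the bidirectional flows down to the intermediate time tau.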
        flow0tot = tau[:, 0][:, None, None, None] * flow0to1
        flow1tot = tau[:, 1][:, None, None, None] * flow1to0

        I0 = torch.cat([I0, self.nedt(I0)], dim=1)
        I1 = torch.cat([I1, self.nedt(I1)], dim=1)

        z0to1, z1to0 = HalfWarper.z_metric(I0, I1, flow0to1, flow1to0)
        base0, base1 = self.half_warper(I0, I1, flow0tot, flow1tot, z0to1, z1to0)
        warped0, warped1 = [base0], [base1]

        features0 = self.feature_extractor(I0)
        features1 = self.feature_extractor(I1)
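
        # Warp each feature pyramid level, with flows and z-metrics resized
        # to that level's spatial resolution.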
        for feat0, feat1 in zip(features0, features1):
            f0 = interpolate(flow0tot, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            f1 = interpolate(flow1tot, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            z0 = interpolate(z0to1, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            z1 = interpolate(z1to0, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            w0, w1 = self.half_warper(feat0, feat1, f0, f1, z0, z1)
            warped0.append(w0)
            warped1.append(w1)
        return warped0, warped1


class MultiInputResShift(nn.Module):
    def __init__(
        self,
        kappa: float = 2.0,
        p: float = 0.3,
        min_noise_level: float = 0.04,
        etas_end: float = 0.99,
        timesteps: int = 15,
        flow_model: str = 'raft',
        flow_kwargs: dict | None = None,
        warping_kwargs: dict | None = None,
        synthesis_kwargs: dict | None = None
    ):
        super().__init__()
        flow_kwargs = flow_kwargs or {}
        warping_kwargs = warping_kwargs or {}
        synthesis_kwargs = synthesis_kwargs or {}

        self.timesteps = timesteps
        self.kappa = kappa
        self.eta_partition = None
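
        # Geometric noise schedule (ResShift-style): sqrt(eta_t) grows from
        # sqrt_eta_1 at t = 0 to etas_end at t = T-1, with curvature set by p.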
        sqrt_eta_1 = min(min_noise_level / kappa, min_noise_level, math.sqrt(0.001))
        b0 = math.exp(1 / float(timesteps - 1) * math.log(etas_end / sqrt_eta_1))
        base = torch.ones(timesteps) * b0
        beta = ((torch.linspace(0, 1, timesteps)) ** p) * (timesteps - 1)
        sqrt_eta = torch.pow(base, beta) * sqrt_eta_1

        self.register_buffer("sqrt_sum_eta", sqrt_eta)
        self.register_buffer("sum_eta", sqrt_eta ** 2)

        sum_prev_eta = torch.roll(self.sum_eta, 1)
        sum_prev_eta[0] = 0
        self.register_buffer("sum_prev_eta", sum_prev_eta)

        self.register_buffer("sum_alpha", self.sum_eta - self.sum_prev_eta)
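
        # Coefficients of the backward (denoising) transition: the mean mixes
        # the current state with the predicted clean frame; the std follows
        # the schedule.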
        self.register_buffer("backward_mean_c1", self.sum_prev_eta / self.sum_eta)
        self.register_buffer("backward_mean_c2", self.sum_alpha / self.sum_eta)
        self.register_buffer("backward_std", self.kappa * torch.sqrt(self.sum_prev_eta * self.sum_alpha / self.sum_eta))

        if flow_model == 'raft':
            self.flow_model = RAFTFineFlow(**flow_kwargs)
        elif flow_model == 'pwc':
            self.flow_model = PWCFineFlow(**flow_kwargs)
        else:
            raise ValueError(f"Flow model {flow_model} not supported")

        self.feature_warper = FeatureWarper(**warping_kwargs)
        self.synthesis = Synthesis(**synthesis_kwargs)
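
    # Forward diffusion q(x_t | x_0, Y): shift x_0 toward each warped input
    # y_i with weight eta_t * tau_i, then add noise of std kappa * sqrt(eta_t).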
    def forward_process(
        self,
        x: torch.Tensor | None,
        Y: list[torch.Tensor],
        tau: torch.Tensor | float | None,
        t: torch.Tensor | int
    ) -> torch.Tensor:
        if x is None:
            x = torch.zeros_like(Y[0])
        if tau is None:
            tau = torch.full((x.shape[0], len(Y)), 0.5, device=x.device, dtype=x.dtype)
        elif isinstance(tau, float):
            assert 0 <= tau <= 1, "tau must be between 0 and 1"
            tau = torch.cat([
                torch.full((x.shape[0], 1), tau, device=x.device, dtype=x.dtype),
                torch.full((x.shape[0], 1), 1 - tau, device=x.device, dtype=x.dtype)
            ], dim=1)
        if not torch.is_tensor(t):
            t = torch.tensor([t], device=x.device, dtype=torch.long)

        eta = self.sum_eta[t][:, None] * tau
        eta = eta[:, :, None, None, None].transpose(0, 1)

        e_i = torch.stack([y - x for y in Y])
        mean = x + (eta * e_i).sum(dim=0)

        sqrt_sum_eta = self.sqrt_sum_eta[t][:, None, None, None]
        std = self.kappa * sqrt_sum_eta
        epsilon = torch.randn_like(x)

        return mean + std * epsilon
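
    # Sampling: start from noise around the warped frames at t = T-1 and
    # iteratively denoise, re-predicting the clean frame at every step.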
    @torch.inference_mode()
    def reverse_process(
        self,
        Y: list[torch.Tensor],
        tau: torch.Tensor | float,
        flows: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        y = Y[0]
        batch, device, dtype = y.shape[0], y.device, y.dtype

        if isinstance(tau, float):
            assert 0 <= tau <= 1, "tau must be between 0 and 1"
            tau = torch.cat([
                torch.full((batch, 1), tau, device=device, dtype=dtype),
                torch.full((batch, 1), 1 - tau, device=device, dtype=dtype)
            ], dim=1)
        if flows is None:
            flow0to1, flow1to0 = self.flow_model(Y[0], Y[1])
        else:
            flow0to1, flow1to0 = flows
        warp0to1, warp1to0 = self.feature_warper(Y[0], Y[1], flow0to1, flow1to0, tau)
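
        # Initialise x at the last step T-1 by noising a zero image toward
        # the RGB channels of the coarsest warped frames.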
        T = torch.full((batch,), self.timesteps - 1, device=device, dtype=torch.long)
        x = self.forward_process(torch.zeros_like(Y[0]), [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, T)

        pbar = tqdm(total=self.timesteps, desc="Reversing Process")
        for i in reversed(range(self.timesteps)):
            t = torch.ones(batch, device=device, dtype=torch.long) * i
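
            # Predict the clean frame x_0 from the current noisy state.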
            predicted_x0 = self.synthesis(x, warp0to1, warp1to0, t)

            mean_c1 = self.backward_mean_c1[t][:, None, None, None]
            mean_c2 = self.backward_mean_c2[t][:, None, None, None]
            std = self.backward_std[t][:, None, None, None]
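
            # Posterior mean of x_{t-1}: interpolate between the shifted
            # current state and the predicted x_0, subtracting the shift that
            # remains at the previous step.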
            eta = self.sum_eta[t][:, None] * tau
            prev_eta = self.sum_prev_eta[t][:, None] * tau
            eta = eta[:, :, None, None, None].transpose(0, 1)
            prev_eta = prev_eta[:, :, None, None, None].transpose(0, 1)
            e_i = torch.stack([y - predicted_x0 for y in Y])

            mean = (
                mean_c1 * (x + (eta * e_i).sum(dim=0))
                + mean_c2 * predicted_x0
                - (prev_eta * e_i).sum(dim=0)
            )

            x = mean + std * torch.randn_like(x)
            pbar.update(1)
        pbar.close()
        return x
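
    # Training pass: diffuse the ground-truth frame It to a random step t and
    # predict it back from the warped inputs.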
    def forward(
        self,
        I0: torch.Tensor,
        It: torch.Tensor,
        I1: torch.Tensor,
        flow1to0: torch.Tensor | None = None,
        flow0to1: torch.Tensor | None = None,
        tau: torch.Tensor | None = None,
        t: torch.Tensor | None = None
    ) -> torch.Tensor:
        if tau is None:
            tau = torch.full((It.shape[0], 2), 0.5, device=It.device, dtype=It.dtype)

        if flow0to1 is None or flow1to0 is None:
            flow0to1, flow1to0 = self.flow_model(I0, I1)

        if t is None:
            t = torch.randint(low=1, high=self.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long)
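
        # Noise the ground-truth frame toward the RGB channels of the warped
        # frames at the sampled step t.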
        warp0to1, warp1to0 = self.feature_warper(I0, I1, flow0to1, flow1to0, tau)
        x_t = self.forward_process(It, [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, t)

        predicted_It = self.synthesis(x_t, warp0to1, warp1to0, t)
        return predicted_It