import math

import torch
import torch.nn as nn
from torch.nn.functional import interpolate
from tqdm import tqdm

from modules.feature_extactor import Extractor
from modules.half_warper import HalfWarper
from modules.cupy_module.nedt import NEDT
from modules.flow_models.flow_models import (
    RAFTFineFlow,
    PWCFineFlow,
)
from modules.synthesizer import Synthesis


class FeatureWarper(nn.Module):
    """Warp two frames (plus their edge maps) and a feature pyramid toward time tau.

    Each input frame is augmented with its NEDT edge-distance channel, then the
    image and every pyramid level are half-warped toward the intermediate time
    using z-metric-weighted splatting.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: list[int] | None = None,
    ):
        """
        Args:
            in_channels: channels of each input frame (RGB by default).
            channels: output channels per pyramid level. Defaults to
                [32, 64, 128, 256]; ``None`` avoids a mutable default argument.
        """
        super().__init__()
        if channels is None:
            channels = [32, 64, 128, 256]
        # +1 accounts for the NEDT edge channel concatenated onto each frame.
        channels = [in_channels + 1] + channels
        self.half_warper = HalfWarper()
        self.feature_extractor = Extractor(channels)
        self.nedt = NEDT()

    def forward(
        self,
        I0: torch.Tensor,
        I1: torch.Tensor,
        flow0to1: torch.Tensor,
        flow1to0: torch.Tensor,
        tau: torch.Tensor | None = None,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Return per-scale warps of both frames toward time tau.

        Args:
            I0, I1: the two input frames, shape (batch, C, H, W).
            flow0to1, flow1to0: dense flows between the frames.
            tau: per-sample (batch, 2) interpolation times; defaults to the
                temporal midpoint (0.5 for both directions).

        Returns:
            Two lists (one per frame); element 0 is the warped image+edge
            tensor, subsequent elements are warped feature-pyramid levels.
        """
        # FIX: the original dereferenced tau.shape before any None check, so
        # the declared default of None always crashed. Default to 0.5, matching
        # MultiInputResShift.forward_process.
        if tau is None:
            tau = torch.full((I0.shape[0], 2), 0.5, device=I0.device, dtype=I0.dtype)
        assert tau.shape == (I0.shape[0], 2), "tau shape must be (batch, 2)"

        # Scale each flow by its sample's tau to point at the intermediate time.
        flow0tot = tau[:, 0][:, None, None, None] * flow0to1
        flow1tot = tau[:, 1][:, None, None, None] * flow1to0

        # Append the NEDT edge-distance channel to each frame.
        I0 = torch.cat([I0, self.nedt(I0)], dim=1)
        I1 = torch.cat([I1, self.nedt(I1)], dim=1)

        # Occlusion/consistency weights for the splatting.
        z0to1, z1to0 = HalfWarper.z_metric(I0, I1, flow0to1, flow1to0)

        base0, base1 = self.half_warper(I0, I1, flow0tot, flow1tot, z0to1, z1to0)
        warped0, warped1 = [base0], [base1]

        features0 = self.feature_extractor(I0)
        features1 = self.feature_extractor(I1)
        for feat0, feat1 in zip(features0, features1):
            # NOTE(review): flows are bilinearly resized to the feature
            # resolution without rescaling their magnitudes — confirm that
            # HalfWarper expects flows in full-resolution pixel units.
            f0 = interpolate(flow0tot, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            f1 = interpolate(flow1tot, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            z0 = interpolate(z0to1, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            z1 = interpolate(z1to0, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            w0, w1 = self.half_warper(feat0, feat1, f0, f1, z0, z1)
            warped0.append(w0)
            warped1.append(w1)
        return warped0, warped1


class MultiInputResShift(nn.Module):
    """ResShift-style diffusion for frame interpolation with two warped inputs.

    The forward process shifts the residuals of both warped frames into the
    state with a geometric eta schedule; the reverse process iteratively
    denoises using the synthesis network conditioned on warped features.
    """

    def __init__(
        self,
        kappa: float = 2.0,
        p: float = 0.3,
        min_noise_level: float = 0.04,
        etas_end: float = 0.99,
        timesteps: int = 15,
        flow_model: str = 'raft',
        flow_kwargs: dict | None = None,
        warping_kwargs: dict | None = None,
        synthesis_kwargs: dict | None = None,
    ):
        """
        Args:
            kappa: noise scale multiplier of the ResShift schedule.
            p: exponent shaping the (non-linear) eta schedule.
            min_noise_level: smallest sqrt-eta at t=0.
            etas_end: final eta value at t=timesteps-1.
            timesteps: number of diffusion steps.
            flow_model: 'raft' or 'pwc'.
            flow_kwargs / warping_kwargs / synthesis_kwargs: passed through to
                the corresponding submodules (None avoids mutable defaults).

        Raises:
            ValueError: if ``flow_model`` is not supported.
        """
        super().__init__()
        # Avoid shared mutable default dicts across instances.
        flow_kwargs = flow_kwargs or {}
        warping_kwargs = warping_kwargs or {}
        synthesis_kwargs = synthesis_kwargs or {}

        self.timesteps = timesteps
        self.kappa = kappa
        self.eta_partition = None

        # Geometric sqrt-eta schedule warped by exponent p (ResShift).
        sqrt_eta_1 = min(min_noise_level / kappa, min_noise_level, math.sqrt(0.001))
        b0 = math.exp(1 / float(timesteps - 1) * math.log(etas_end / sqrt_eta_1))
        base = torch.ones(timesteps) * b0
        beta = ((torch.linspace(0, 1, timesteps)) ** p) * (timesteps - 1)
        sqrt_eta = torch.pow(base, beta) * sqrt_eta_1

        self.register_buffer("sqrt_sum_eta", sqrt_eta)
        self.register_buffer("sum_eta", sqrt_eta ** 2)
        # Previous-step cumulative eta (eta_{t-1}); eta_{-1} := 0.
        sum_prev_eta = torch.roll(self.sum_eta, 1)
        sum_prev_eta[0] = 0
        self.register_buffer("sum_prev_eta", sum_prev_eta)
        self.register_buffer("sum_alpha", self.sum_eta - self.sum_prev_eta)
        # Posterior (reverse-step) coefficients.
        self.register_buffer("backward_mean_c1", self.sum_prev_eta / self.sum_eta)
        self.register_buffer("backward_mean_c2", self.sum_alpha / self.sum_eta)
        self.register_buffer(
            "backward_std",
            self.kappa * torch.sqrt(self.sum_prev_eta * self.sum_alpha / self.sum_eta),
        )

        if flow_model == 'raft':
            self.flow_model = RAFTFineFlow(**flow_kwargs)
        elif flow_model == 'pwc':
            self.flow_model = PWCFineFlow(**flow_kwargs)
        else:
            raise ValueError(f"Flow model {flow_model} not supported")
        self.feature_warper = FeatureWarper(**warping_kwargs)
        self.synthesis = Synthesis(**synthesis_kwargs)

    def forward_process(
        self,
        x: torch.Tensor | None,
        Y: list[torch.Tensor],
        tau: torch.Tensor | float | None,
        t: torch.Tensor | int,
    ) -> torch.Tensor:
        """Sample x_t ~ q(x_t | x_0, Y): shift residuals of Y into x, add noise.

        Args:
            x: clean target x_0, or None for a zero start (as in sampling).
            Y: list of conditioning images (the two warped frames).
            tau: (batch, len(Y)) weights, a scalar float (split as
                [tau, 1-tau]), or None for uniform 0.5.
            t: timestep index (long tensor of shape (batch,) or python int).

        Returns:
            The noised state x_t.
        """
        # FIX: resolve x first — the original read x.shape / x.device in the
        # tau/t branches before its own `x is None` fallback, so calling with
        # x=None together with tau=None (or float tau, or int t) crashed.
        if x is None:
            x = torch.zeros_like(Y[0])
        if tau is None:
            tau = torch.full((x.shape[0], len(Y)), 0.5, device=x.device, dtype=x.dtype)
        elif isinstance(tau, float):
            assert tau >= 0 and tau <= 1, "tau must be between 0 and 1"
            tau = torch.cat([
                torch.full((x.shape[0], 1), tau, device=x.device, dtype=x.dtype),
                torch.full((x.shape[0], 1), 1 - tau, device=x.device, dtype=x.dtype),
            ], dim=1)
        if not torch.is_tensor(t):
            t = torch.tensor([t], device=x.device, dtype=torch.long)

        # Per-input eta weights, broadcast over (inputs, batch, C, H, W).
        eta = self.sum_eta[t][:, None] * tau
        eta = eta[:, :, None, None, None].transpose(0, 1)
        e_i = torch.stack([y - x for y in Y])
        mean = x + (eta * e_i).sum(dim=0)
        sqrt_sum_eta = self.sqrt_sum_eta[t][:, None, None, None]
        std = self.kappa * sqrt_sum_eta
        epsilon = torch.randn_like(x)
        return mean + std * epsilon

    @torch.inference_mode()
    def reverse_process(
        self,
        Y: list[torch.Tensor],
        tau: torch.Tensor | float,
        flows: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Sample the interpolated frame by running the full reverse chain.

        Args:
            Y: the two input frames [I0, I1].
            tau: (batch, 2) weights or a scalar in [0, 1] (split [tau, 1-tau]).
            flows: optional precomputed (flow0to1, flow1to0); estimated with
                the flow model when None.

        Returns:
            The synthesized intermediate frame.
        """
        y = Y[0]
        batch, device, dtype = y.shape[0], y.device, y.dtype
        if isinstance(tau, float):
            assert tau >= 0 and tau <= 1, "tau must be between 0 and 1"
            tau = torch.cat([
                torch.full((batch, 1), tau, device=device, dtype=dtype),
                torch.full((batch, 1), 1 - tau, device=device, dtype=dtype),
            ], dim=1)

        if flows is None:
            flow0to1, flow1to0 = self.flow_model(Y[0], Y[1])
        else:
            flow0to1, flow1to0 = flows
        warp0to1, warp1to0 = self.feature_warper(Y[0], Y[1], flow0to1, flow1to0, tau)

        # Initialize x_T from the two warped RGB images (edge channel dropped).
        T = torch.full((batch,), self.timesteps - 1, device=device, dtype=torch.long)
        x = self.forward_process(
            torch.zeros_like(Y[0]),
            [warp0to1[0][:, :3], warp1to0[0][:, :3]],
            tau,
            T,
        )

        with tqdm(total=self.timesteps, desc="Reversing Process") as pbar:
            for i in reversed(range(self.timesteps)):
                t = torch.full((batch,), i, device=device, dtype=torch.long)
                predicted_x0 = self.synthesis(x, warp0to1, warp1to0, t)
                mean_c1 = self.backward_mean_c1[t][:, None, None, None]
                mean_c2 = self.backward_mean_c2[t][:, None, None, None]
                std = self.backward_std[t][:, None, None, None]
                # Current and previous cumulative etas, weighted per input.
                eta = self.sum_eta[t][:, None] * tau
                prev_eta = self.sum_prev_eta[t][:, None] * tau
                eta = eta[:, :, None, None, None].transpose(0, 1)
                prev_eta = prev_eta[:, :, None, None, None].transpose(0, 1)
                e_i = torch.stack([y - predicted_x0 for y in Y])
                # Posterior mean q(x_{t-1} | x_t, x_0, Y).
                mean = (
                    mean_c1 * (x + (eta * e_i).sum(dim=0))
                    + mean_c2 * predicted_x0
                    - (prev_eta * e_i).sum(dim=0)
                )
                x = mean + std * torch.randn_like(x)
                pbar.update(1)
        return x

    # Training Step Only
    def forward(
        self,
        I0: torch.Tensor,
        It: torch.Tensor,
        I1: torch.Tensor,
        flow1to0: torch.Tensor | None = None,
        flow0to1: torch.Tensor | None = None,
        tau: torch.Tensor | None = None,
        t: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Single training step: noise the target and predict it back.

        Args:
            I0, It, I1: first frame, ground-truth intermediate, last frame.
            flow1to0 / flow0to1: optional precomputed flows; estimated if
                either is missing.
            tau: optional (batch, 2) weights; defaults to uniform 0.5.
            t: optional (batch,) timestep indices; sampled uniformly from
                [1, timesteps) when None.

        Returns:
            The synthesis network's prediction of It.
        """
        if tau is None:
            tau = torch.full((It.shape[0], 2), 0.5, device=It.device, dtype=It.dtype)
        if flow0to1 is None or flow1to0 is None:
            flow0to1, flow1to0 = self.flow_model(I0, I1)
        if t is None:
            t = torch.randint(
                low=1, high=self.timesteps, size=(It.shape[0],),
                device=It.device, dtype=torch.long,
            )
        warp0to1, warp1to0 = self.feature_warper(I0, I1, flow0to1, flow1to0, tau)
        # Condition on the warped RGB images only (edge channel dropped).
        x_t = self.forward_process(It, [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, t)
        predicted_It = self.synthesis(x_t, warp0to1, warp1to0, t)
        return predicted_It