import torch
import torch.nn as nn
from torch.nn.functional import interpolate

from modules.cupy_module import correlation
from modules.half_warper import HalfWarper
from modules.feature_extactor import Extractor
from modules.flow_models.raft.rfr_new import RAFT


class Decoder(nn.Module):
    """One pyramid-level flow decoder.

    Correlates two feature maps (optionally after warping the second one by the
    flow estimated at the coarser level) and regresses a 2-channel flow field
    through a small convolutional head.
    """

    def __init__(self, in_channels: int):
        """
        Args:
            in_channels: channels of the concatenated decoder input
                (features + 81 correlation channels [+ 2 flow channels]).
        """
        super().__init__()
        # NOTE: the attribute name keeps the original spelling ("syntesis")
        # on purpose — renaming it would change state_dict keys and break
        # loading of previously saved checkpoints.
        self.syntesis = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, img1: torch.Tensor, img2: torch.Tensor, residual: torch.Tensor | None) -> torch.Tensor:
        """Predict the flow at this pyramid level.

        Args:
            img1: feature map of the first image, shape (B, C, H, W).
            img2: feature map of the second image, same spatial size as img1.
            residual: flow from the coarser level, or None at the coarsest level.

        Returns:
            A (B, 2, H, W) flow tensor.
        """
        # Both feature maps are expected to share one spatial size.  The
        # original code read it as `img1.shape[3] and img2.shape[3]`, which
        # resolves to img2's size for any non-empty tensor — read it directly.
        height, width = img2.shape[2], img2.shape[3]
        if residual is None:
            # Coarsest level: correlate the two feature maps as-is.
            corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=img2)
            main = torch.cat([img1, corr], dim=1)
        else:
            # Upsample the coarser flow to this level's resolution and rescale
            # its magnitude by the change in resolution.
            # NOTE(review): both flow channels are scaled by the *width* ratio;
            # the y-channel arguably needs the height ratio.  For the
            # power-of-two pyramid used here the two ratios coincide — TODO confirm.
            flow = interpolate(input=residual, size=(height, width), mode='bilinear',
                               align_corners=False) / float(residual.shape[3]) * float(width)
            # Warp the second feature map towards the first before correlating.
            backwarp_img = HalfWarper.backward_wrapping(img=img2, flow=flow)
            corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=backwarp_img)
            main = torch.cat([img1, corr, flow], dim=1)
        return self.syntesis(main)


class PWCFineFlow(nn.Module):
    """PWC-style coarse-to-fine estimator of bidirectional optical flow."""

    def __init__(self, pretrained_path: str | None = None):
        """
        Args:
            pretrained_path: optional path to a saved state_dict to load.
        """
        super().__init__()
        self.feature_extractor = Extractor([3, 16, 32, 64, 96, 128, 192], num_groups=16)
        # One decoder per pyramid level.  Input channels = feature channels
        # + 81 correlation channels + 2 flow channels from the coarser level
        # (the deepest level has no incoming flow, hence no "+ 2").
        self.decoders = nn.ModuleList([
            Decoder(16 + 81 + 2),
            Decoder(32 + 81 + 2),
            Decoder(64 + 81 + 2),
            Decoder(96 + 81 + 2),
            Decoder(128 + 81 + 2),
            Decoder(192 + 81),
        ])
        if pretrained_path is not None:
            self.load_state_dict(torch.load(pretrained_path))

    @staticmethod
    def _upsample_flow(flow: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Upsample a flow field to (height, width), rescaling its magnitude
        by the width ratio (see the NOTE in Decoder.forward about the
        y-channel scale)."""
        up = interpolate(input=flow, size=(height, width), mode='bilinear', align_corners=False)
        return up * (float(width) / float(flow.shape[3]))

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Estimate forward (img1→img2) and backward (img2→img1) flow.

        Args:
            img1: first image batch, shape (B, 3, H, W).
            img2: second image batch, same spatial size as img1.

        Returns:
            Tuple (forward_flow, backward_flow), each of shape (B, 2, H, W).
        """
        height, width = img2.shape[2], img2.shape[3]
        feats1 = self.feature_extractor(img1)
        feats2 = self.feature_extractor(img2)
        forward = None
        backward = None
        # Coarse-to-fine refinement: start at the deepest pyramid level and
        # feed each level's flow into the next finer decoder.
        for i in reversed(range(len(feats1))):
            forward = self.decoders[i](feats1[i], feats2[i], forward)
            backward = self.decoders[i](feats2[i], feats1[i], backward)
        # Bring both final flows up to the input resolution.
        forward = self._upsample_flow(forward, height, width)
        backward = self._upsample_flow(backward, height, width)
        return forward, backward


class RAFTFineFlow(nn.Module):
    """Bidirectional flow estimator backed by a (shared) RAFT model."""

    def __init__(self, pretrained_path: str | None = None):
        """
        Args:
            pretrained_path: optional checkpoint path forwarded to RAFT.
        """
        super().__init__()
        self.raft = RAFT(pretrained_path)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (forward, backward) flows by running RAFT in both directions."""
        forward = self.raft(img1, img2)
        backward = self.raft(img2, img1)
        return forward, backward