|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn.functional import interpolate
|
|
|
|
from modules.cupy_module import correlation
|
|
from modules.half_warper import HalfWarper
|
|
from modules.feature_extactor import Extractor
|
|
|
|
from modules.flow_models.raft.rfr_new import RAFT
|
|
|
|
class Decoder(nn.Module):
|
|
def __init__(self, in_channels: int):
|
|
super().__init__()
|
|
|
|
self.syntesis = nn.Sequential(
|
|
nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, stride=1, padding=1),
|
|
nn.SiLU(),
|
|
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
|
nn.SiLU(),
|
|
nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=1),
|
|
nn.SiLU(),
|
|
nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=1),
|
|
nn.SiLU(),
|
|
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
|
|
nn.SiLU(),
|
|
nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
|
|
)
|
|
|
|
def forward(self, img1: torch.Tensor, img2: torch.Tensor, residual: torch.Tensor | None) -> torch.Tensor:
|
|
width = img1.shape[3] and img2.shape[3]
|
|
height = img1.shape[2] and img2.shape[2]
|
|
|
|
if residual is None:
|
|
corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=img2)
|
|
main = torch.cat([img1, corr], dim=1)
|
|
else:
|
|
flow = interpolate(input=residual,
|
|
size=(height, width),
|
|
mode='bilinear',
|
|
align_corners=False) / \
|
|
float(residual.shape[3]) * float(width)
|
|
backwarp_img = HalfWarper.backward_wrapping(img=img2, flow=flow)
|
|
corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=backwarp_img)
|
|
main = torch.cat([img1, corr, flow], dim=1)
|
|
|
|
return self.syntesis(main)
|
|
|
|
class PWCFineFlow(nn.Module):
|
|
def __init__(self, pretrained_path: str | None = None):
|
|
super().__init__()
|
|
|
|
self.feature_extractor = Extractor([3, 16, 32, 64, 96, 128, 192], num_groups=16)
|
|
|
|
self.decoders = nn.ModuleList([
|
|
Decoder(16 + 81 + 2),
|
|
Decoder(32 + 81 + 2),
|
|
Decoder(64 + 81 + 2),
|
|
Decoder(96 + 81 + 2),
|
|
Decoder(128 + 81 + 2),
|
|
Decoder(192 + 81)
|
|
])
|
|
|
|
if pretrained_path is not None:
|
|
self.load_state_dict(torch.load(pretrained_path))
|
|
|
|
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
width = img1.shape[3] and img2.shape[3]
|
|
height = img1.shape[2] and img2.shape[2]
|
|
|
|
feats1 = self.feature_extractor(img1)
|
|
feats2 = self.feature_extractor(img2)
|
|
|
|
forward = None
|
|
backward = None
|
|
|
|
for i in reversed(range(len(feats1))):
|
|
forward = self.decoders[i](feats1[i], feats2[i], forward)
|
|
backward = self.decoders[i](feats2[i], feats1[i], backward)
|
|
|
|
forward = interpolate(input=forward,
|
|
size=(height, width),
|
|
mode='bilinear',
|
|
align_corners=False) * \
|
|
(float(width) / float(forward.shape[3]))
|
|
backward = interpolate(input=backward,
|
|
size=(height, width),
|
|
mode='bilinear',
|
|
align_corners=False) * \
|
|
(float(width) / float(backward.shape[3]))
|
|
|
|
return forward, backward
|
|
|
|
|
|
class RAFTFineFlow(nn.Module):
|
|
def __init__(self, pretrained_path: str | None = None):
|
|
super().__init__()
|
|
self.raft = RAFT(pretrained_path)
|
|
|
|
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
forward = self.raft(img1, img2)
|
|
backward = self.raft(img2, img1)
|
|
return forward, backward |