# Uploaded by vfontech ("Uploading the app"), commit 587665f (verified), 4.17 kB.
import torch
import torch.nn as nn
from torch.nn.functional import interpolate
from modules.cupy_module import correlation
from modules.half_warper import HalfWarper
from modules.feature_extactor import Extractor
from modules.flow_models.raft.rfr_new import RAFT
class Decoder(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.syntesis = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
)
def forward(self, img1: torch.Tensor, img2: torch.Tensor, residual: torch.Tensor | None) -> torch.Tensor:
width = img1.shape[3] and img2.shape[3]
height = img1.shape[2] and img2.shape[2]
if residual is None:
corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=img2)
main = torch.cat([img1, corr], dim=1)
else:
flow = interpolate(input=residual,
size=(height, width),
mode='bilinear',
align_corners=False) / \
float(residual.shape[3]) * float(width)
backwarp_img = HalfWarper.backward_wrapping(img=img2, flow=flow)
corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=backwarp_img)
main = torch.cat([img1, corr, flow], dim=1)
return self.syntesis(main)
class PWCFineFlow(nn.Module):
def __init__(self, pretrained_path: str | None = None):
super().__init__()
self.feature_extractor = Extractor([3, 16, 32, 64, 96, 128, 192], num_groups=16)
self.decoders = nn.ModuleList([
Decoder(16 + 81 + 2),
Decoder(32 + 81 + 2),
Decoder(64 + 81 + 2),
Decoder(96 + 81 + 2),
Decoder(128 + 81 + 2),
Decoder(192 + 81)
])
if pretrained_path is not None:
self.load_state_dict(torch.load(pretrained_path))
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
width = img1.shape[3] and img2.shape[3]
height = img1.shape[2] and img2.shape[2]
feats1 = self.feature_extractor(img1)
feats2 = self.feature_extractor(img2)
forward = None
backward = None
for i in reversed(range(len(feats1))):
forward = self.decoders[i](feats1[i], feats2[i], forward)
backward = self.decoders[i](feats2[i], feats1[i], backward)
forward = interpolate(input=forward,
size=(height, width),
mode='bilinear',
align_corners=False) * \
(float(width) / float(forward.shape[3]))
backward = interpolate(input=backward,
size=(height, width),
mode='bilinear',
align_corners=False) * \
(float(width) / float(backward.shape[3]))
return forward, backward
class RAFTFineFlow(nn.Module):
def __init__(self, pretrained_path: str | None = None):
super().__init__()
self.raft = RAFT(pretrained_path)
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
forward = self.raft(img1, img2)
backward = self.raft(img2, img1)
return forward, backward