Spaces:
Running
on
L4
Running
on
L4
import torch | |
from torchvision.models.optical_flow import raft_large | |
from modules.flow_models.raft.rfr_new import RAFT | |
def raft_flow(
    I0: torch.Tensor,
    I1: torch.Tensor,
    data_domain: str = "animation",
    device: str = 'cuda'
) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate optical flow between two frames with a RAFT model.

    The model is chosen by ``data_domain``:

    * ``"animation"``    -- the project-local ``RAFT`` network.
    * ``"photorealism"`` -- torchvision's ``raft_large`` with its default
      pretrained weights.

    Args:
        I0: First frame. Cast to ``float32`` and moved to ``device``.
        I1: Second frame. Cast to ``float32`` and moved to ``device``.
        data_domain: Either ``"animation"`` or ``"photorealism"``.
        device: Device on which the model and both inputs are placed.

    Returns:
        For ``"animation"``, whatever the project ``RAFT`` forward pass
        returns (presumably a pair of flow tensors, per the return
        annotation -- TODO confirm against ``rfr_new.RAFT``).
        For ``"photorealism"``, the last element of the ``raft_large``
        output, i.e. a single tensor rather than a tuple.

    Raises:
        ValueError: If ``data_domain`` is not one of the two supported values.

    Note:
        A fresh model is built on every call. Callers in a hot loop may want
        to cache it.
        NOTE(review): torchvision's RAFT weights are documented as expecting
        inputs normalised to [-1, 1] with H and W divisible by 8 -- confirm
        that callers preprocess accordingly.
    """
    # Reject an unknown domain before doing any tensor or model work.
    if data_domain not in ("animation", "photorealism"):
        raise ValueError("data_domain must be either 'animation' or 'photorealism'")

    # The model is placed on `device`, so the inputs have to live there too;
    # `.to` returns the tensor unchanged when dtype and device already match.
    I0 = I0.to(device=device, dtype=torch.float32)
    I1 = I1.to(device=device, dtype=torch.float32)

    if data_domain == "animation":
        raft = RAFT().requires_grad_(False).eval().to(device)
        return raft(I0, I1)

    # Without explicit weights `raft_large()` is randomly initialised and its
    # flow output is meaningless, so request the pretrained checkpoint.
    # Imported here to keep the block self-contained.
    from torchvision.models.optical_flow import Raft_Large_Weights

    raft = raft_large(weights=Raft_Large_Weights.DEFAULT)
    raft = raft.requires_grad_(False).eval().to(device)
    # Keep only the final entry of the returned sequence of predictions.
    return raft(I0, I1)[-1]