File size: 805 Bytes
587665f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import functools

import torch
from torchvision.models.optical_flow import raft_large

from modules.flow_models.raft.rfr_new import RAFT
def raft_flow(
I0: torch.Tensor,
I1: torch.Tensor,
data_domain: str = "animation",
device: str = 'cuda'
) -> tuple[torch.Tensor, torch.Tensor]:
if I0.dtype != torch.float32 or I1.dtype != torch.float32:
I0 = I0.to(torch.float32)
I1 = I1.to(torch.float32)
if data_domain == "animation":
raft = RAFT().requires_grad_(False).eval().to(device)
elif data_domain == "photorealism":
raft = raft_large().requires_grad_(False).eval().to(device)
else:
raise ValueError("data_domain must be either 'animation' or 'photorealism'")
return raft(I0, I1) if data_domain == "animation" else raft(I0, I1)[-1] |