File size: 805 Bytes
587665f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torchvision.models.optical_flow import raft_large 
from modules.flow_models.raft.rfr_new import RAFT

# One frozen, eval-mode flow model per (data_domain, device), built on first use.
# This avoids re-instantiating the network (and re-loading its weights) on every
# call. Cached models stay resident on their device for the life of the process.
_RAFT_CACHE: dict[tuple[str, str], torch.nn.Module] = {}


def _get_raft(data_domain: str, device: str) -> torch.nn.Module:
    """Return the cached frozen flow model for *data_domain* on *device*, building it once.

    Raises:
        ValueError: If data_domain is not "animation" or "photorealism".
    """
    key = (data_domain, str(device))
    model = _RAFT_CACHE.get(key)
    if model is not None:
        return model
    if data_domain == "animation":
        # NOTE(review): confirm RAFT() loads its trained checkpoint internally;
        # if it does not, this branch also runs with randomly initialized weights.
        model = RAFT()
    elif data_domain == "photorealism":
        from torchvision.models.optical_flow import Raft_Large_Weights

        # raft_large() defaults to weights=None (random init); request the
        # pretrained weights explicitly so the predicted flow is meaningful.
        model = raft_large(weights=Raft_Large_Weights.DEFAULT)
    else:
        raise ValueError("data_domain must be either 'animation' or 'photorealism'")
    model = model.requires_grad_(False).eval().to(device)
    _RAFT_CACHE[key] = model
    return model


def raft_flow(
        I0: torch.Tensor,
        I1: torch.Tensor,
        data_domain: str = "animation",
        device: str = 'cuda'
) -> "tuple[torch.Tensor, torch.Tensor] | torch.Tensor":
    """Estimate optical flow between frames ``I0`` and ``I1`` with a RAFT model.

    Args:
        I0: First frame(s). Cast to float32 and moved to ``device``.
            NOTE(review): presumably (N, 3, H, W). torchvision's RAFT expects
            values in [-1, 1] and H, W divisible by 8. This function does no
            normalization or padding, so confirm the caller handles both.
        I1: Second frame(s), same layout as ``I0``.
        data_domain: "animation" selects the project RAFT (rfr_new);
            "photorealism" selects torchvision's pretrained ``raft_large``.
        device: Device the model runs on. The inputs are moved there too.

    Returns:
        For "animation": whatever the project RAFT returns (annotated by the
        original author as a pair of tensors). For "photorealism": the final,
        most refined flow prediction from torchvision's RAFT.

    Raises:
        ValueError: If ``data_domain`` is not one of the supported values.
    """
    model = _get_raft(data_domain, device)
    # A single .to() both casts and moves; it returns the tensor itself when
    # dtype and device already match.
    I0 = I0.to(device=device, dtype=torch.float32)
    I1 = I1.to(device=device, dtype=torch.float32)
    if data_domain == "animation":
        return model(I0, I1)
    # torchvision's RAFT returns a list of iteratively refined flows; keep the last.
    return model(I0, I1)[-1]