import json
import math
import os
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import wandb
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import ConcatDataset
from torchvision.transforms import ToTensor

from romatch.benchmarks import (
    HpatchesHomogBenchmark,
    Mega1500PoseLibBenchmark,
    MegaDepthPoseEstimationBenchmark,
    MegadepthDenseBenchmark,
    ScanNetBenchmark,
    ScanNetPoselibBenchmark,
)
from romatch.checkpointing import CheckPoint
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.losses.robust_loss_tiny_roma import RobustLosses
from romatch.train.train import train_k_steps
resolutions = {
    "low": (448, 448),
    "medium": (14 * 8 * 5, 14 * 8 * 5),  # (560, 560)
    "high": (14 * 8 * 6, 14 * 8 * 6),    # (672, 672)
    "xfeat": (600, 800),
    "big": (768, 1024),
}
def kde(x, std=0.1):
    # Use a Gaussian kernel over pairwise distances to estimate a per-point density.
    x = x.half()  # Done in half precision. TODO: remove hardcoding
    scores = (-torch.cdist(x, x) ** 2 / (2 * std**2)).exp()
    density = scores.sum(dim=-1)
    return density
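
# Usage sketch (illustrative, not executed): for match coordinates of shape (N, 4),
# kde returns one density estimate per match, used below to rebalance sampling:
#   matches = torch.rand(10_000, 4, device="cuda")
#   density = kde(matches, std=0.1)  # larger values where matches cluster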
class BasicLayer(nn.Module):
    """
    Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu=True):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride, dilation=dilation, bias=bias),
            nn.BatchNorm2d(out_channels, affine=False),
            nn.ReLU(inplace=True) if relu else nn.Identity(),
        )

    def forward(self, x):
        return self.layer(x)
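
# Usage sketch (illustrative): one layer maps (B, C_in, H, W) -> (B, C_out, H, W)
# with 3x3 kernels at stride 1, e.g.
#   layer = BasicLayer(64, 128)
#   y = layer(torch.randn(2, 64, 56, 56))  # -> (2, 128, 56, 56)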
class XFeatModel(nn.Module):
    """
    Implementation of the architecture described in
    "XFeat: Accelerated Features for Lightweight Image Matching" (CVPR 2024).
    """
    def __init__(self, xfeat=None,
                 freeze_xfeat=True,
                 sample_mode="threshold_balanced",
                 symmetric=False,
                 exact_softmax=False):
        super().__init__()
        if xfeat is None:
            xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained=True, top_k=4096).net
            del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
        if freeze_xfeat:
            xfeat.train(False)
            self.xfeat = [xfeat]  # wrapped in a plain list to hide the frozen params from DDP
        else:
            self.xfeat = nn.ModuleList([xfeat])
        self.freeze_xfeat = freeze_xfeat
        match_dim = 256
        self.coarse_matcher = nn.Sequential(
            BasicLayer(64 + 64 + 2, match_dim),
            BasicLayer(match_dim, match_dim),
            BasicLayer(match_dim, match_dim),
            BasicLayer(match_dim, match_dim),
            nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
        fine_match_dim = 64
        self.fine_matcher = nn.Sequential(
            BasicLayer(24 + 24 + 2, fine_match_dim),
            BasicLayer(fine_match_dim, fine_match_dim),
            BasicLayer(fine_match_dim, fine_match_dim),
            BasicLayer(fine_match_dim, fine_match_dim),
            nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0))
        self.sample_mode = sample_mode
        self.sample_thresh = 0.2
        self.symmetric = symmetric
        self.exact_softmax = exact_softmax
    @property
    def device(self):
        # Exposed as a property: match() and match_from_path() read self.device as an attribute.
        return self.fine_matcher[-1].weight.device
    def preprocess_tensor(self, x):
        """Resize so both image sides are divisible by 32 (avoids aliasing artifacts).
        Returns the resized tensor and the height/width rescale ratios."""
        H, W = x.shape[-2:]
        _H, _W = (H // 32) * 32, (W // 32) * 32
        rh, rw = H / _H, W / _W
        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw
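
    # Example (illustrative): a (1, 3, 700, 900) tensor is resized to (1, 3, 672, 896),
    # with rh = 700/672 and rw = 900/896 available to map predictions back to the
    # original resolution.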
    def forward_single(self, x):
        with torch.inference_mode(self.freeze_xfeat or not self.training):
            xfeat = self.xfeat[0]
            with torch.no_grad():
                x = x.mean(dim=1, keepdim=True)  # XFeat expects grayscale input
                x = xfeat.norm(x)

            # main backbone
            x1 = xfeat.block1(x)
            x2 = xfeat.block2(x1 + xfeat.skip1(x))
            x3 = xfeat.block3(x2)
            x4 = xfeat.block4(x3)
            x5 = xfeat.block5(x4)
            x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            feats = xfeat.block_fusion(x3 + x4 + x5)
        if self.freeze_xfeat:
            return x2.clone(), feats.clone()
        return x2, feats
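
    # forward_single returns (fine, coarse) features: with the pretrained hub XFeat,
    # x2 is a 24-channel map at stride 4 and the fused map is 64 channels at stride 8,
    # matching the 24+24+2 and 64+64+2 matcher input widths above.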
    def to_pixel_coordinates(self, coords, H_A, W_A, H_B=None, W_B=None):
        # Check for list/tuple input before touching .shape, which only tensors have.
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        elif coords.shape[-1] == 2:
            return self._to_pixel_coordinates(coords, H_A, W_A)
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)

    def _to_pixel_coordinates(self, coords, H, W):
        kpts = torch.stack((W / 2 * (coords[..., 0] + 1), H / 2 * (coords[..., 1] + 1)), dim=-1)
        return kpts
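
    # Example: a normalized x-coordinate of -1 maps to pixel 0 and +1 maps to W, so a
    # warp in [-1, 1]^2 converts to pixel keypoints via to_pixel_coordinates(warp, H, W).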
    def pos_embed(self, corr_volume: torch.Tensor):
        B, H1, W1, H0, W0 = corr_volume.shape
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1 + 1 / W1, 1 - 1 / W1, W1),
                torch.linspace(-1 + 1 / H1, 1 - 1 / H1, H1),
                indexing="xy"),
            dim=-1).float().to(corr_volume).reshape(H1 * W1, 2)
        down = 4
        if not self.training and not self.exact_softmax:
            # Approximate the full softmax with a 4x-downsampled volume plus the single best match.
            grid_lr = torch.stack(
                torch.meshgrid(
                    torch.linspace(-1 + down / W1, 1 - down / W1, W1 // down),
                    torch.linspace(-1 + down / H1, 1 - down / H1, H1 // down),
                    indexing="xy"),
                dim=-1).float().to(corr_volume).reshape(H1 * W1 // down**2, 2)
            cv = corr_volume
            # max() yields both scores (for the softmax) and indices (for the grid lookup);
            # amax alone would not provide integer indices to index grid with.
            best_match = cv.reshape(B, H1 * W1, H0, W0).max(dim=1)  # values, indices: each (B, H0, W0)
            P_lowres = torch.cat((cv[:, ::down, ::down].reshape(B, H1 * W1 // down**2, H0, W0),
                                  best_match.values[:, None]), dim=1).softmax(dim=1)
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:, :-1], grid_lr)
            pos_embeddings += P_lowres[:, -1:] * grid[best_match.indices].permute(0, 3, 1, 2)
        else:
            P = corr_volume.reshape(B, H1 * W1, H0, W0).softmax(dim=1)  # (B, H1*W1, H0, W0)
            pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
        return pos_embeddings
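
    # pos_embed implements a soft-argmax: per pixel of image A, softmax over all
    # positions in image B, then the expected (x, y) grid location, giving a dense
    # coarse warp in [-1, 1]^2. The downsampled branch avoids the full softmax at inference.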
    def visualize_warp(self, warp, certainty, im_A=None, im_B=None,
                       im_A_path=None, im_B_path=None, symmetric=True, save_path=None, unnormalize=False):
        device = warp.device
        H, W2, _ = warp.shape
        W = W2 // 2 if symmetric else W2
        if im_A is None:
            from PIL import Image
            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
        if not isinstance(im_A, torch.Tensor):
            im_A = im_A.resize((W, H))
            im_B = im_B.resize((W, H))
            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
            if symmetric:
                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
        else:
            if symmetric:
                x_A = im_A
            x_B = im_B
        im_A_transfer_rgb = F.grid_sample(
            x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
        )[0]
        if symmetric:
            im_B_transfer_rgb = F.grid_sample(
                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
            )[0]
            warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
            white_im = torch.ones((H, 2 * W), device=device)
        else:
            warp_im = im_A_transfer_rgb
            white_im = torch.ones((H, W), device=device)
        vis_im = certainty * warp_im + (1 - certainty) * white_im
        if save_path is not None:
            from romatch.utils import tensor_to_pil
            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
        return vis_im
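
    # Usage sketch (illustrative): match() below returns a one-directional warp, so
    # pass symmetric=False when visualizing it:
    #   warp, certainty = model.match("im_A.jpg", "im_B.jpg")
    #   model.visualize_warp(warp, certainty, im_A_path="im_A.jpg", im_B_path="im_B.jpg",
    #                        symmetric=False, save_path="warp_vis.jpg")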
    def corr_volume(self, feat0, feat1):
        """
        input:
            feat0 -> torch.Tensor(B, C, H0, W0)
            feat1 -> torch.Tensor(B, C, H1, W1)
        return:
            corr_volume -> torch.Tensor(B, H1, W1, H0, W0)
        """
        B, C, H0, W0 = feat0.shape
        B, C, H1, W1 = feat1.shape
        feat0 = feat0.view(B, C, H0 * W0)
        feat1 = feat1.view(B, C, H1 * W1)
        corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0, W0) / math.sqrt(C)
        return corr_volume
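
    # The einsum is a batched dot product between every pair of spatial locations:
    # entry [b, i1, j1, i0, j0] is <feat0[b, :, i0, j0], feat1[b, :, i1, j1]> / sqrt(C),
    # i.e. a 4D cost volume with the same 1/sqrt(C) scaling as scaled dot-product attention.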
    def match_from_path(self, im0_path, im1_path):
        device = self.device
        im0 = ToTensor()(Image.open(im0_path))[None].to(device)
        im1 = ToTensor()(Image.open(im1_path))[None].to(device)
        return self.match(im0, im1, batched=False)

    def match(self, im0, im1, *args, batched=True):
        # Accept paths and PIL images as a convenience; convert them to batched tensors.
        if isinstance(im0, (str, Path)):
            return self.match_from_path(im0, im1)
        elif isinstance(im0, Image.Image):
            batched = False
            device = self.device
            im0 = ToTensor()(im0)[None].to(device)
            im1 = ToTensor()(im1)[None].to(device)

        B, C, H0, W0 = im0.shape
        B, C, H1, W1 = im1.shape
        self.train(False)
        corresps = self.forward({"im_A": im0, "im_B": im1})
        flow = F.interpolate(
            corresps[4]["flow"],
            size=(H0, W0),
            mode="bilinear", align_corners=False).permute(0, 2, 3, 1).reshape(B, H0, W0, 2)
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1 + 1 / W0, 1 - 1 / W0, W0),
                torch.linspace(-1 + 1 / H0, 1 - 1 / H0, H0),
                indexing="xy"),
            dim=-1).float().to(flow.device).expand(B, H0, W0, 2)
        certainty = F.interpolate(corresps[4]["certainty"], size=(H0, W0), mode="bilinear", align_corners=False)
        warp, cert = torch.cat((grid, flow), dim=-1), certainty[:, 0].sigmoid()
        if batched:
            return warp, cert
        else:
            return warp[0], cert[0]
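
    # Usage sketch (illustrative):
    #   warp, certainty = model.match("im_A.jpg", "im_B.jpg")  # warp: (H, W, 4) in [-1, 1]
    #   matches, cert = model.sample(warp, certainty)          # (N, 4) normalized matches
    #   kpts_A, kpts_B = model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)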
    def sample(
        self,
        matches,
        certainty,
        num=10000,
    ):
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
        matches, certainty = (
            matches.reshape(-1, 4),
            certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(certainty,
                                         num_samples=min(expansion_factor * num, len(certainty)),
                                         replacement=False)
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
        density = kde(good_matches, std=0.1)
        p = 1 / (density + 1)
        p[density < 10] = 1e-7  # Require at least 10 near-perfect neighbours, or around 100 decent ones
        balanced_samples = torch.multinomial(p,
                                             num_samples=min(num, len(good_certainty)),
                                             replacement=False)
        return good_matches[balanced_samples], good_certainty[balanced_samples]
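
    # "threshold_balanced" sampling: certainties above sample_thresh are clamped to 1 so
    # confident matches are drawn uniformly, then kde-based reweighting (1 / (density + 1))
    # spreads the kept matches across the image instead of letting them pile up on
    # densely textured regions.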
    def forward(self, batch):
        """
        input:
            batch -> dict with "im_A", "im_B": torch.Tensor(B, C, H, W) grayscale or RGB images
        return:
            corresps -> dict keyed by stride {8, 4}, each holding
                        {"flow": (B, 2, h, w), "certainty": (B, 1, h, w)}
        """
        im0 = batch["im_A"]
        im1 = batch["im_B"]
        corresps = {}
        im0, rh0, rw0 = self.preprocess_tensor(im0)
        im1, rh1, rw1 = self.preprocess_tensor(im1)
        B, C, H0, W0 = im0.shape
        B, C, H1, W1 = im1.shape
        to_normalized = torch.tensor((2 / W1, 2 / H1, 1)).to(im0.device)[None, :, None, None]

        if im0.shape[-2:] == im1.shape[-2:]:
            # Same resolution: run both images through the backbone in one batch.
            x = torch.cat([im0, im1], dim=0)
            x = self.forward_single(x)
            feats_x0_c, feats_x1_c = x[1].chunk(2)
            feats_x0_f, feats_x1_f = x[0].chunk(2)
        else:
            feats_x0_f, feats_x0_c = self.forward_single(im0)
            feats_x1_f, feats_x1_c = self.forward_single(im1)

        corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
        coarse_warp = self.pos_embed(corr_volume)
        coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:, -1:])), dim=1)
        feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[..., :2], mode='bilinear', align_corners=False)
        coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
        coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
        corresps[8] = {"flow": coarse_matches[:, :2], "certainty": coarse_matches[:, 2:]}

        coarse_matches_up = F.interpolate(coarse_matches, size=feats_x0_f.shape[-2:], mode="bilinear", align_corners=False)
        coarse_matches_up_detach = coarse_matches_up.detach()  # note the detach: the fine stage does not backprop into the coarse stage
        feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[..., :2], mode='bilinear', align_corners=False)
        fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:, :2]), dim=1))
        fine_matches = coarse_matches_up_detach + fine_matches_delta * to_normalized
        corresps[4] = {"flow": fine_matches[:, :2], "certainty": fine_matches[:, 2:]}
        return corresps
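
    # The two-stage refinement above is a coarse-to-fine scheme: the correlation volume
    # gives an initial warp at stride 8, a CNN refines it from warped coarse features,
    # and a second CNN refines the (detached) upsampled warp at stride 4. Both stages
    # predict a 3-channel delta: (dx, dy) flow plus a certainty logit.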
def train(args):
    rank = 0
    gpus = 1
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = 0
    torch.cuda.set_device(device_id)

    resolution = "big"  # note: fixed here; the --train_resolution flag is not consumed
    wandb_log = not args.dont_log_wandb
    experiment_name = Path(__file__).stem
    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode=wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h, w = resolutions[resolution]
    model = XFeatModel(freeze_xfeat=False).to(device_id)

    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus * batch_size
    romatch.STEP_SIZE = step_size
    N = 2_000_000  # 2M pairs
    # checkpoint every k optimizer steps (~25k pairs)
    k = 25000 // romatch.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True)
    use_horizontal_flip_aug = True
    normalize = False  # don't ImageNet-normalize
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w, normalize=normalize
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w, normalize=normalize
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

    # Loss and optimizer
    depth_loss = RobustLosses(
        ce_weight=0.01,
        local_dist={4: 4},
        depth_interpolation_mode=depth_interpolation_mode,
        alpha={4: 0.15, 8: 0.15},
        c=1e-4,
        epe_mask_prob_th=0.001,
    )
    parameters = [
        {"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},  # lr scales linearly with effective batch size
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9 * N / romatch.STEP_SIZE) // 10])  # decay at 90% of training
    # megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=1000, h=h, w=w)
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter=1, test_every=30)

    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    # megadense_benchmark.benchmark(model)

    # Train k steps at a time, re-drawing a fresh weighted sampler each round,
    # then checkpoint and benchmark.
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples=batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size=batch_size,
                sampler=mega_sampler,
                num_workers=8,
            )
        )
        train_k_steps(
            n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm=grad_clip_norm,
        )
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step=romatch.GLOBAL_STEP)
def test_mega_8_scenes(model, name):
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark(
        "data/megadepth",
        scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                     'mega_8_scenes_0025_0.1_0.3.npz',
                     'mega_8_scenes_0021_0.1_0.3.npz',
                     'mega_8_scenes_0008_0.1_0.3.npz',
                     'mega_8_scenes_0032_0.1_0.3.npz',
                     'mega_8_scenes_1589_0.1_0.3.npz',
                     'mega_8_scenes_0063_0.1_0.3.npz',
                     'mega_8_scenes_0024_0.1_0.3.npz',
                     'mega_8_scenes_0019_0.3_0.5.npz',
                     'mega_8_scenes_0025_0.3_0.5.npz',
                     'mega_8_scenes_0021_0.3_0.5.npz',
                     'mega_8_scenes_0008_0.3_0.5.npz',
                     'mega_8_scenes_0032_0.3_0.5.npz',
                     'mega_8_scenes_1589_0.3_0.5.npz',
                     'mega_8_scenes_0063_0.3_0.5.npz',
                     'mega_8_scenes_0024_0.3_0.5.npz'])
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
    print(mega_8_scenes_results)
    with open(f"results/mega_8_scenes_{name}.json", "w") as f:
        json.dump(mega_8_scenes_results, f)
def test_mega1500(model, name):
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    with open(f"results/mega1500_{name}.json", "w") as f:
        json.dump(mega1500_results, f)

def test_mega1500_poselib(model, name):
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter=1, test_every=1)
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    with open(f"results/mega1500_poselib_{name}.json", "w") as f:
        json.dump(mega1500_results, f)
def test_mega_8_scenes_poselib(model, name):
    mega1500_benchmark = Mega1500PoseLibBenchmark(
        "data/megadepth", num_ransac_iter=1, test_every=1,
        scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                     'mega_8_scenes_0025_0.1_0.3.npz',
                     'mega_8_scenes_0021_0.1_0.3.npz',
                     'mega_8_scenes_0008_0.1_0.3.npz',
                     'mega_8_scenes_0032_0.1_0.3.npz',
                     'mega_8_scenes_1589_0.1_0.3.npz',
                     'mega_8_scenes_0063_0.1_0.3.npz',
                     'mega_8_scenes_0024_0.1_0.3.npz',
                     'mega_8_scenes_0019_0.3_0.5.npz',
                     'mega_8_scenes_0025_0.3_0.5.npz',
                     'mega_8_scenes_0021_0.3_0.5.npz',
                     'mega_8_scenes_0008_0.3_0.5.npz',
                     'mega_8_scenes_0032_0.3_0.5.npz',
                     'mega_8_scenes_1589_0.3_0.5.npz',
                     'mega_8_scenes_0063_0.3_0.5.npz',
                     'mega_8_scenes_0024_0.3_0.5.npz'])
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    with open(f"results/mega_8_scenes_poselib_{name}.json", "w") as f:
        json.dump(mega1500_results, f)
def test_scannet_poselib(model, name):
    scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    with open(f"results/scannet_{name}.json", "w") as f:
        json.dump(scannet_results, f)

def test_scannet(model, name):
    scannet_benchmark = ScanNetBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    with open(f"results/scannet_{name}.json", "w") as f:
        json.dump(scannet_results, f)
if __name__ == "__main__":
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"  # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"
    torch.backends.cudnn.allow_tf32 = True  # allow TF32 on cuDNN
    import romatch
    parser = ArgumentParser()
    parser.add_argument("--only_test", action='store_true')
    parser.add_argument("--debug_mode", action='store_true')
    parser.add_argument("--dont_log_wandb", action='store_true')
    parser.add_argument("--train_resolution", default='medium')
    parser.add_argument("--gpu_batch_size", default=8, type=int)
    parser.add_argument("--wandb_entity", required=False)
    args, _ = parser.parse_known_args()
    romatch.DEBUG_MODE = args.debug_mode
    if not args.only_test:
        train(args)

    experiment_name = "tiny_roma_v1_outdoor"  # Path(__file__).stem
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
    model.load_state_dict(torch.load(f"{experiment_name}.pth"))
    test_mega1500_poselib(model, experiment_name)
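
# Usage sketch (illustrative, assuming the script is saved as train_tiny_roma_v1_outdoor.py
# so its stem matches the experiment name used for wandb and checkpoints):
#   python train_tiny_roma_v1_outdoor.py --gpu_batch_size 8 --wandb_entity my_team
#   python train_tiny_roma_v1_outdoor.py --only_test  # expects tiny_roma_v1_outdoor.pth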