# --------------------------------------------------------
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Main test function
# --------------------------------------------------------
import argparse
import os
import pickle

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils.misc as misc
from models.croco_downstream import CroCoDownstreamBinocular
from models.head_downstream import PixelwiseTaskWithDPT
from stereoflow.criterion import *
from stereoflow.datasets_flow import flowToColor, get_test_datasets_flow
from stereoflow.datasets_stereo import get_test_datasets_stereo, vis_disparity
from stereoflow.engine import tiled_pred
def get_args_parser():
    """Build the CLI argument parser for evaluating CroCo models on stereo/flow.

    Returns:
        argparse.ArgumentParser: parser with model/dataset selection, tiling
        options, and output-saving options.
    """
    parser = argparse.ArgumentParser("Test CroCo models on stereo/flow", add_help=False)
    # important arguments
    parser.add_argument(
        "--model", required=True, type=str, help="Path to the model to evaluate"
    )
    parser.add_argument(
        "--dataset",
        required=True,
        type=str,
        help="test dataset (there can be multiple dataset separated by a +)",
    )
    # tiling
    parser.add_argument(
        "--tile_conf_mode",
        type=str,
        default="",
        # fixed: help string was missing its closing parenthesis
        help="Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint)",
    )
    parser.add_argument(
        "--tile_overlap", type=float, default=0.7, help="overlap between tiles"
    )
    # save (it will automatically go to <model_path>_<dataset_str>/<tile_str>_<save>)
    parser.add_argument(
        "--save",
        type=str,
        nargs="+",
        default=[],
        help="what to save: \
            metrics (pickle file), \
            pred (raw prediction save as torch tensor), \
            visu (visualization in png of each prediction), \
            err10 (visualization in png of the error clamp at 10 for each prediction), \
            submission (submission file)",
    )
    # other (no impact)
    parser.add_argument("--num_workers", default=4, type=int)
    return parser
def _load_model_and_criterion(model_path, do_load_metrics, device):
    """Load a CroCo downstream checkpoint and, optionally, its metrics object.

    Args:
        model_path: path to the checkpoint file (.pth).
        do_load_metrics: if True, also instantiate the task's metrics accumulator.
        device: torch device the model (and metrics) are moved to.

    Returns:
        tuple: (model, metrics_or_None, crop_size, with_conf, task, tile_conf_mode)

    Raises:
        FileNotFoundError: if model_path does not exist.
    """
    print("loading model from", model_path)
    # explicit check instead of `assert` so it survives `python -O`
    if not os.path.isfile(model_path):
        raise FileNotFoundError(model_path)
    ckpt = torch.load(model_path, "cpu")
    ckpt_args = ckpt["args"]
    task = ckpt_args.task
    tile_conf_mode = ckpt_args.tile_conf_mode
    num_channels = {"stereo": 1, "flow": 2}[task]
    # NOTE(review): eval() executes a string stored in the checkpoint — only
    # load trusted checkpoints (the criterion classes come from
    # stereoflow.criterion via the star import above).
    with_conf = eval(ckpt_args.criterion).with_conf
    if with_conf:
        num_channels += 1  # one extra output channel for the predicted confidence
    print("head: PixelwiseTaskWithDPT()")
    head = PixelwiseTaskWithDPT()
    head.num_channels = num_channels
    print("croco_args:", ckpt_args.croco_args)
    model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args)
    model.load_state_dict(ckpt["model"], strict=True)  # raises on mismatch; result unused
    model.eval()
    model = model.to(device)
    if do_load_metrics:
        if task == "stereo":
            metrics = StereoDatasetMetrics().to(device)
        else:
            metrics = FlowDatasetMetrics().to(device)
    else:
        metrics = None
    return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode
def _save_batch( | |
pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None | |
): | |
for i in range(len(pairnames)): | |
pairname = ( | |
eval(pairnames[i]) if pairnames[i].startswith("(") else pairnames[i] | |
) # unbatch pairname | |
fname = os.path.join(outdir, dataset.pairname_to_str(pairname)) | |
os.makedirs(os.path.dirname(fname), exist_ok=True) | |
predi = pred[i, ...] | |
if gt is not None: | |
gti = gt[i, ...] | |
if "pred" in save: | |
torch.save(predi.squeeze(0).cpu(), fname + "_pred.pth") | |
if "visu" in save: | |
if task == "stereo": | |
disparity = predi.permute((1, 2, 0)).squeeze(2).cpu().numpy() | |
m, M = None | |
if gt is not None: | |
mask = torch.isfinite(gti) | |
m = gt[mask].min() | |
M = gt[mask].max() | |
img_disparity = vis_disparity(disparity, m=m, M=M) | |
Image.fromarray(img_disparity).save(fname + "_pred.png") | |
else: | |
# normalize flowToColor according to the maxnorm of gt (or prediction if not available) | |
flowNorm = ( | |
torch.sqrt( | |
torch.sum((gti if gt is not None else predi) ** 2, dim=0) | |
) | |
.max() | |
.item() | |
) | |
imgflow = flowToColor( | |
predi.permute((1, 2, 0)).cpu().numpy(), maxflow=flowNorm | |
) | |
Image.fromarray(imgflow).save(fname + "_pred.png") | |
if "err10" in save: | |
assert gt is not None | |
L2err = torch.sqrt(torch.sum((gti - predi) ** 2, dim=0)) | |
valid = torch.isfinite(gti[0, :, :]) | |
L2err[~valid] = 0.0 | |
L2err = torch.clamp(L2err, max=10.0) | |
red = (L2err * 255.0 / 10.0).to(dtype=torch.uint8)[:, :, None] | |
zer = torch.zeros_like(red) | |
imgerr = torch.cat((red, zer, zer), dim=2).cpu().numpy() | |
Image.fromarray(imgerr).save(fname + "_err10.png") | |
if "submission" in save: | |
assert submission_dir is not None | |
predi_np = ( | |
predi.permute(1, 2, 0).squeeze(2).cpu().numpy() | |
) # transform into HxWx2 for flow or HxW for stereo | |
dataset.submission_save_pairname(pairname, predi_np, submission_dir, time) | |
def main(args):
    """Evaluate a CroCo stereo/flow model on one or more test datasets.

    Loads the checkpoint, builds one dataloader per dataset (datasets are
    separated by '+' in args.dataset), runs tiled inference, and depending on
    args.save stores metrics / predictions / visualizations / submission files
    in a directory derived from the model path.
    """
    # load the pretrained model and metrics
    device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    (
        model,
        metrics,
        cropsize,
        with_conf,
        task,
        tile_conf_mode,
    ) = _load_model_and_criterion(args.model, "metrics" in args.save, device)
    if args.tile_conf_mode == "":
        # fall back to the tiling-confidence formula stored in the checkpoint
        args.tile_conf_mode = tile_conf_mode

    # load the datasets
    datasets = (
        get_test_datasets_stereo if task == "stereo" else get_test_datasets_flow
    )(args.dataset)
    dataloaders = [
        DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )
        for dataset in datasets
    ]

    # run
    for i, dataloader in enumerate(dataloaders):
        dataset = datasets[i]
        dstr = args.dataset.split("+")[i]

        outdir = args.model + "_" + misc.filename(dstr)
        if "metrics" in args.save:
            # fixed: fname must be defined whenever metrics are saved — the
            # original only assigned it when metrics was the *sole* output,
            # causing a NameError at the final pickle.dump otherwise
            fname = os.path.join(
                outdir, f"conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl"
            )
            # if only metrics are requested and already computed, print and skip
            if os.path.isfile(fname) and len(args.save) == 1:
                print(" metrics already compute in " + fname)
                with open(fname, "rb") as fid:
                    results = pickle.load(fid)
                for k, v in results.items():
                    print("{:s}: {:.3f}".format(k, v))
                continue

        if "submission" in args.save:
            dirname = (
                f"submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}"
            )
            submission_dir = os.path.join(outdir, dirname)
        else:
            submission_dir = None

        print("")
        print("saving {:s} in {:s}".format("+".join(args.save), outdir))
        print(repr(dataset))

        if metrics is not None:
            metrics.reset()

        for data_iter_step, (image1, image2, gt, pairnames) in enumerate(
            tqdm(dataloader)
        ):
            do_flip = (
                task == "stereo"
                and dstr.startswith("Spring")
                and any("right" in p for p in pairnames)
            )  # we flip the images and will flip the prediction after as we assume img1 is on the left

            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            gt = (
                gt.to(device, non_blocking=True) if gt.numel() > 0 else None
            )  # special case for test time
            if do_flip:
                assert all("right" in p for p in pairnames)
                image1 = image1.flip(
                    dims=[3]
                )  # this is already the right frame, let's flip it
                image2 = image2.flip(dims=[3])
                # gt stays as-is: the prediction is flipped back below before comparison

            with torch.inference_mode():
                pred, _, _, time = tiled_pred(
                    model,
                    None,
                    image1,
                    image2,
                    None if dataset.name == "Spring" else gt,
                    conf_mode=args.tile_conf_mode,
                    overlap=args.tile_overlap,
                    crop=cropsize,
                    with_conf=with_conf,
                    return_time=True,
                )
            if do_flip:
                pred = pred.flip(dims=[3])

            if metrics is not None:
                metrics.add_batch(pred, gt)

            if any(k in args.save for k in ["pred", "visu", "err10", "submission"]):
                _save_batch(
                    pred,
                    gt,
                    pairnames,
                    dataset,
                    task,
                    args.save,
                    outdir,
                    time,
                    submission_dir=submission_dir,
                )

        # print results, and save if needed
        if metrics is not None:
            results = metrics.get_results()
            for k, v in results.items():
                print("{:s}: {:.3f}".format(k, v))
            if "metrics" in args.save:
                os.makedirs(os.path.dirname(fname), exist_ok=True)
                with open(fname, "wb") as fid:
                    pickle.dump(results, fid)
                print("metrics saved in", fname)

        # finalize submission if needed
        if "submission" in args.save:
            dataset.finalize_submission(submission_dir)
if __name__ == "__main__": | |
args = get_args_parser() | |
args = args.parse_args() | |
main(args) | |