# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# --------------------------------------------------------
# Main test function
# --------------------------------------------------------

import argparse
import os
import pickle

import numpy as np
import torch
import utils.misc as misc
from models.croco_downstream import CroCoDownstreamBinocular
from models.head_downstream import PixelwiseTaskWithDPT
from PIL import Image
from stereoflow.criterion import *
from stereoflow.datasets_flow import flowToColor, get_test_datasets_flow
from stereoflow.datasets_stereo import get_test_datasets_stereo, vis_disparity
from stereoflow.engine import tiled_pred
from torch.utils.data import DataLoader
from tqdm import tqdm


def get_args_parser():
    """Build the command-line parser for the stereo/flow test script.

    Returns:
        An ``argparse.ArgumentParser`` (created with ``add_help=False``)
        exposing ``--model``, ``--dataset``, ``--tile_conf_mode``,
        ``--tile_overlap``, ``--save`` and ``--num_workers``.
    """
    parser = argparse.ArgumentParser("Test CroCo models on stereo/flow", add_help=False)
    # important argument
    parser.add_argument(
        "--model", required=True, type=str, help="Path to the model to evaluate"
    )
    parser.add_argument(
        "--dataset",
        required=True,
        type=str,
        help="test dataset (there can be multiple dataset separated by a +)",
    )
    # tiling
    parser.add_argument(
        "--tile_conf_mode",
        type=str,
        default="",
        help="Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint",
    )
    parser.add_argument(
        "--tile_overlap", type=float, default=0.7, help="overlap between tiles"
    )
    # save (it will automatically go to _/_)
    parser.add_argument(
        "--save",
        type=str,
        nargs="+",
        default=[],
        # Adjacent literals instead of backslash continuations inside the
        # string; argparse collapses whitespace in help text, so the rendered
        # help is unchanged.
        help="what to save: "
        "metrics (pickle file), "
        "pred (raw prediction save as torch tensor), "
        "visu (visualization in png of each prediction), "
        "err10 (visualization in png of the error clamp at 10 for each prediction), "
        "submission (submission file)",
    )
    # other (no impact)
    parser.add_argument("--num_workers", default=4, type=int)
    return parser


def _load_model_and_criterion(model_path, do_load_metrics, device):
    """Load a checkpoint and rebuild the model (and optionally the metrics).

    Args:
        model_path: Path to the checkpoint file; must exist.
        do_load_metrics: When true, instantiate the dataset metrics object
            matching the checkpoint's task; otherwise ``metrics`` is ``None``.
        device: Device the model (and metrics) are moved to.

    Returns:
        A tuple ``(model, metrics, crop, with_conf, task, tile_conf_mode)``
        where ``crop``, ``task`` and ``tile_conf_mode`` are read from the
        arguments stored in the checkpoint.
    """
    print("loading model from", model_path)
    assert os.path.isfile(model_path)
    ckpt = torch.load(model_path, "cpu")

    ckpt_args = ckpt["args"]
    task = ckpt_args.task
    tile_conf_mode = ckpt_args.tile_conf_mode
    num_channels = {"stereo": 1, "flow": 2}[task]
    # NOTE(security): the criterion is rebuilt by eval()-ing a string stored in
    # the checkpoint, so only load checkpoints from trusted sources.
    with_conf = eval(ckpt_args.criterion).with_conf
    if with_conf:
        # one extra output channel when the criterion uses a confidence
        num_channels += 1

    print("head: PixelwiseTaskWithDPT()")
    head = PixelwiseTaskWithDPT()
    head.num_channels = num_channels

    print("croco_args:", ckpt_args.croco_args)
    model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args)
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()
    model = model.to(device)

    if do_load_metrics:
        if task == "stereo":
            metrics = StereoDatasetMetrics().to(device)
        else:
            metrics = FlowDatasetMetrics().to(device)
    else:
        metrics = None

    return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode


def _save_batch(
    pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None
):
    """Write the requested outputs for every sample of a batch.

    Args:
        pred: Batched predictions, indexed per sample along dimension 0.
        gt: Batched ground truth, or ``None`` when unavailable.
        pairnames: One pair name per sample; entries starting with ``"("``
            are string-encoded tuples and are decoded before use.
        dataset: Dataset object providing ``pairname_to_str`` and
            ``submission_save_pairname``.
        task: ``"stereo"`` selects the disparity visualization; any other
            value selects the flow visualization.
        save: Collection of output kinds among ``"pred"``, ``"visu"``,
            ``"err10"`` and ``"submission"``.
        outdir: Root directory for the per-sample output files.
        time: Value forwarded unchanged to ``submission_save_pairname``.
        submission_dir: Submission directory; required for ``"submission"``.

    Raises:
        AssertionError: If ``"err10"`` is requested without ground truth, or
            ``"submission"`` without ``submission_dir``.
    """
    for i, raw_pairname in enumerate(pairnames):
        # unbatch pairname
        # NOTE(security): eval() of a dataset-provided string; pair names are
        # assumed to come from the trusted dataset code.
        pairname = eval(raw_pairname) if raw_pairname.startswith("(") else raw_pairname
        fname = os.path.join(outdir, dataset.pairname_to_str(pairname))
        os.makedirs(os.path.dirname(fname), exist_ok=True)

        predi = pred[i, ...]
        gti = gt[i, ...] if gt is not None else None

        if "pred" in save:
            torch.save(predi.squeeze(0).cpu(), fname + "_pred.pth")

        if "visu" in save:
            if task == "stereo":
                disparity = predi.permute((1, 2, 0)).squeeze(2).cpu().numpy()
                # Fix: `m, M = None` cannot be unpacked and raised TypeError.
                m, M = None, None
                if gti is not None:
                    # Fix: the mask is built from this sample's gt, so it must
                    # index this sample's gt and not the whole batch.
                    mask = torch.isfinite(gti)
                    # .item() hands plain Python floats to vis_disparity, as
                    # the flow branch does for flowToColor.
                    m = gti[mask].min().item()
                    M = gti[mask].max().item()
                img_disparity = vis_disparity(disparity, m=m, M=M)
                Image.fromarray(img_disparity).save(fname + "_pred.png")
            else:
                # normalize flowToColor according to the maxnorm of gt
                # (or prediction if not available)
                reference = gti if gti is not None else predi
                flowNorm = torch.sqrt(torch.sum(reference**2, dim=0)).max().item()
                imgflow = flowToColor(
                    predi.permute((1, 2, 0)).cpu().numpy(), maxflow=flowNorm
                )
                Image.fromarray(imgflow).save(fname + "_pred.png")

        if "err10" in save:
            assert gt is not None
            L2err = torch.sqrt(torch.sum((gti - predi) ** 2, dim=0))
            # pixels whose ground truth is not finite are shown with zero error
            valid = torch.isfinite(gti[0, :, :])
            L2err[~valid] = 0.0
            L2err = torch.clamp(L2err, max=10.0)
            # error in [0, 10] mapped to the red channel in [0, 255]
            red = (L2err * 255.0 / 10.0).to(dtype=torch.uint8)[:, :, None]
            zer = torch.zeros_like(red)
            imgerr = torch.cat((red, zer, zer), dim=2).cpu().numpy()
            Image.fromarray(imgerr).save(fname + "_err10.png")

        if "submission" in save:
            assert submission_dir is not None
            # transform into HxWx2 for flow or HxW for stereo
            predi_np = predi.permute(1, 2, 0).squeeze(2).cpu().numpy()
            dataset.submission_save_pairname(pairname, predi_np, submission_dir, time)


def main(args):
    """Evaluate a checkpoint on every requested test dataset.

    For each dataset named in ``args.dataset`` (separated by ``+``) this runs
    tiled prediction on all pairs, optionally accumulates metrics, and writes
    the outputs selected by ``args.save`` under ``<model path>_<dataset>``.

    Args:
        args: Parsed namespace produced by ``get_args_parser()``.
    """
    # load the pretrained model and metrics
    device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    (
        model,
        metrics,
        cropsize,
        with_conf,
        task,
        tile_conf_mode,
    ) = _load_model_and_criterion(args.model, "metrics" in args.save, device)
    if args.tile_conf_mode == "":
        # fall back to the confidence mode stored in the checkpoint
        args.tile_conf_mode = tile_conf_mode

    # load the datasets
    datasets = (
        get_test_datasets_stereo if task == "stereo" else get_test_datasets_flow
    )(args.dataset)
    dataloaders = [
        DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )
        for dataset in datasets
    ]

    save_per_sample = any(
        k in args.save for k in ["pred", "visu", "err10", "submission"]
    )

    # run
    for i, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
        dstr = args.dataset.split("+")[i]

        outdir = args.model + "_" + misc.filename(dstr)

        # Fix: the metrics file name is needed whenever metrics are saved, not
        # only when "metrics" is the sole --save entry (previously a NameError
        # was raised at save time for e.g. `--save metrics visu`).
        if "metrics" in args.save:
            fname = os.path.join(
                outdir, f"conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl"
            )
            # Only skip the dataset when nothing else has to be produced.
            if os.path.isfile(fname) and len(args.save) == 1:
                print(" metrics already compute in " + fname)
                # NOTE(security): unpickles a file previously written by this
                # script; do not point outdir at untrusted data.
                with open(fname, "rb") as fid:
                    results = pickle.load(fid)
                for k, v in results.items():
                    print("{:s}: {:.3f}".format(k, v))
                continue

        if "submission" in args.save:
            dirname = (
                f"submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}"
            )
            submission_dir = os.path.join(outdir, dirname)
        else:
            submission_dir = None

        print("")
        print("saving {:s} in {:s}".format("+".join(args.save), outdir))
        print(repr(dataset))

        if metrics is not None:
            metrics.reset()

        for data_iter_step, (image1, image2, gt, pairnames) in enumerate(
            tqdm(dataloader)
        ):
            # we flip the images and will flip the prediction after as we
            # assume img1 is on the left
            do_flip = (
                task == "stereo"
                and dstr.startswith("Spring")
                and any("right" in p for p in pairnames)
            )

            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            # special case for test time: an empty gt tensor means no gt
            gt = gt.to(device, non_blocking=True) if gt.numel() > 0 else None

            if do_flip:
                assert all("right" in p for p in pairnames)
                # this is already the right frame, let's flip it
                image1 = image1.flip(dims=[3])
                image2 = image2.flip(dims=[3])
                # gt is deliberately left as is (the original code marked this
                # as fine); the prediction is flipped back below.

            with torch.inference_mode():
                pred, _, _, time = tiled_pred(
                    model,
                    None,
                    image1,
                    image2,
                    None if dataset.name == "Spring" else gt,
                    conf_mode=args.tile_conf_mode,
                    overlap=args.tile_overlap,
                    crop=cropsize,
                    with_conf=with_conf,
                    return_time=True,
                )

                if do_flip:
                    pred = pred.flip(dims=[3])

                if metrics is not None:
                    metrics.add_batch(pred, gt)

                if save_per_sample:
                    _save_batch(
                        pred,
                        gt,
                        pairnames,
                        dataset,
                        task,
                        args.save,
                        outdir,
                        time,
                        submission_dir=submission_dir,
                    )

        # print
        if metrics is not None:
            results = metrics.get_results()
            for k, v in results.items():
                print("{:s}: {:.3f}".format(k, v))

        # save if needed
        if "metrics" in args.save:
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, "wb") as fid:
                pickle.dump(results, fid)
            print("metrics saved in", fname)

        # finalize submission if needed
        if "submission" in args.save:
            dataset.finalize_submission(submission_dir)


if __name__ == "__main__":
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)