# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Main test function
# --------------------------------------------------------
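
# Example invocation (illustrative: the script path and dataset names below are assumptions,
# see the repository README for the exact commands and supported test datasets):
#   python stereoflow/test.py --model /path/to/checkpoint.pth --dataset <test_dataset> --save metrics visu
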
import argparse
import os
import pickle
import numpy as np
import torch
import utils.misc as misc
from models.croco_downstream import CroCoDownstreamBinocular
from models.head_downstream import PixelwiseTaskWithDPT
from PIL import Image
from stereoflow.criterion import *
from stereoflow.datasets_flow import flowToColor, get_test_datasets_flow
from stereoflow.datasets_stereo import get_test_datasets_stereo, vis_disparity
from stereoflow.engine import tiled_pred
from torch.utils.data import DataLoader
from tqdm import tqdm


def get_args_parser():
parser = argparse.ArgumentParser("Test CroCo models on stereo/flow", add_help=False)
# important arguments
parser.add_argument(
"--model", required=True, type=str, help="Path to the model to evaluate"
)
parser.add_argument(
"--dataset",
required=True,
type=str,
help="test dataset (there can be multiple dataset separated by a +)",
)
# tiling
parser.add_argument(
"--tile_conf_mode",
type=str,
default="",
help="Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint",
)
parser.add_argument(
"--tile_overlap", type=float, default=0.7, help="overlap between tiles"
)
# save (it will automatically go to <model_path>_<dataset_str>/<tile_str>_<save>)
parser.add_argument(
"--save",
type=str,
nargs="+",
default=[],
help="what to save: \
metrics (pickle file), \
pred (raw prediction save as torch tensor), \
visu (visualization in png of each prediction), \
err10 (visualization in png of the error clamp at 10 for each prediction), \
submission (submission file)",
)
# other (no impact)
parser.add_argument("--num_workers", default=4, type=int)
return parser


def _load_model_and_criterion(model_path, do_load_metrics, device):
print("loading model from", model_path)
assert os.path.isfile(model_path)
ckpt = torch.load(model_path, "cpu")
ckpt_args = ckpt["args"]
task = ckpt_args.task
tile_conf_mode = ckpt_args.tile_conf_mode
num_channels = {"stereo": 1, "flow": 2}[task]
with_conf = eval(ckpt_args.criterion).with_conf
if with_conf:
num_channels += 1
print("head: PixelwiseTaskWithDPT()")
head = PixelwiseTaskWithDPT()
head.num_channels = num_channels
print("croco_args:", ckpt_args.croco_args)
model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args)
msg = model.load_state_dict(ckpt["model"], strict=True)
model.eval()
model = model.to(device)
if do_load_metrics:
if task == "stereo":
metrics = StereoDatasetMetrics().to(device)
else:
metrics = FlowDatasetMetrics().to(device)
else:
metrics = None
return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode


def _save_batch(
pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None
):
for i in range(len(pairnames)):
pairname = (
eval(pairnames[i]) if pairnames[i].startswith("(") else pairnames[i]
) # unbatch pairname
fname = os.path.join(outdir, dataset.pairname_to_str(pairname))
os.makedirs(os.path.dirname(fname), exist_ok=True)
predi = pred[i, ...]
if gt is not None:
gti = gt[i, ...]
if "pred" in save:
torch.save(predi.squeeze(0).cpu(), fname + "_pred.pth")
if "visu" in save:
if task == "stereo":
disparity = predi.permute((1, 2, 0)).squeeze(2).cpu().numpy()
                m, M = None, None
                if gt is not None:
                    mask = torch.isfinite(gti)
                    m = gti[mask].min()
                    M = gti[mask].max()
img_disparity = vis_disparity(disparity, m=m, M=M)
Image.fromarray(img_disparity).save(fname + "_pred.png")
else:
# normalize flowToColor according to the maxnorm of gt (or prediction if not available)
flowNorm = (
torch.sqrt(
torch.sum((gti if gt is not None else predi) ** 2, dim=0)
)
.max()
.item()
)
imgflow = flowToColor(
predi.permute((1, 2, 0)).cpu().numpy(), maxflow=flowNorm
)
Image.fromarray(imgflow).save(fname + "_pred.png")
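        # "err10": per-pixel endpoint error, clamped at 10, rendered in the red channel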
if "err10" in save:
assert gt is not None
L2err = torch.sqrt(torch.sum((gti - predi) ** 2, dim=0))
valid = torch.isfinite(gti[0, :, :])
L2err[~valid] = 0.0
L2err = torch.clamp(L2err, max=10.0)
red = (L2err * 255.0 / 10.0).to(dtype=torch.uint8)[:, :, None]
zer = torch.zeros_like(red)
imgerr = torch.cat((red, zer, zer), dim=2).cpu().numpy()
Image.fromarray(imgerr).save(fname + "_err10.png")
if "submission" in save:
assert submission_dir is not None
predi_np = (
predi.permute(1, 2, 0).squeeze(2).cpu().numpy()
) # transform into HxWx2 for flow or HxW for stereo
dataset.submission_save_pairname(pairname, predi_np, submission_dir, time)


def main(args):
# load the pretrained model and metrics
device = (
torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
(
model,
metrics,
cropsize,
with_conf,
task,
tile_conf_mode,
) = _load_model_and_criterion(args.model, "metrics" in args.save, device)
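    # fall back to the confidence aggregation mode stored in the checkpoint when none is given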
if args.tile_conf_mode == "":
args.tile_conf_mode = tile_conf_mode
# load the datasets
datasets = (
get_test_datasets_stereo if task == "stereo" else get_test_datasets_flow
)(args.dataset)
dataloaders = [
DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
)
for dataset in datasets
]
# run
for i, dataloader in enumerate(dataloaders):
dataset = datasets[i]
dstr = args.dataset.split("+")[i]
outdir = args.model + "_" + misc.filename(dstr)
if "metrics" in args.save and len(args.save) == 1:
fname = os.path.join(
outdir, f"conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl"
)
if os.path.isfile(fname) and len(args.save) == 1:
print(" metrics already compute in " + fname)
with open(fname, "rb") as fid:
results = pickle.load(fid)
for k, v in results.items():
print("{:s}: {:.3f}".format(k, v))
continue
if "submission" in args.save:
dirname = (
f"submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}"
)
submission_dir = os.path.join(outdir, dirname)
else:
submission_dir = None
print("")
print("saving {:s} in {:s}".format("+".join(args.save), outdir))
print(repr(dataset))
if metrics is not None:
metrics.reset()
for data_iter_step, (image1, image2, gt, pairnames) in enumerate(
tqdm(dataloader)
):
do_flip = (
task == "stereo"
and dstr.startswith("Spring")
and any("right" in p for p in pairnames)
        ) # we flip the images and will flip the prediction back afterwards, since the model assumes img1 is the left view
image1 = image1.to(device, non_blocking=True)
image2 = image2.to(device, non_blocking=True)
gt = (
gt.to(device, non_blocking=True) if gt.numel() > 0 else None
) # special case for test time
if do_flip:
assert all("right" in p for p in pairnames)
image1 = image1.flip(
dims=[3]
) # this is already the right frame, let's flip it
image2 = image2.flip(dims=[3])
gt = gt # that is ok
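            # tiled inference: the model runs on fixed-size crops, so overlapping tiles are
            # predicted and merged (weighted by confidence when the model predicts one)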
with torch.inference_mode():
pred, _, _, time = tiled_pred(
model,
None,
image1,
image2,
None if dataset.name == "Spring" else gt,
conf_mode=args.tile_conf_mode,
overlap=args.tile_overlap,
crop=cropsize,
with_conf=with_conf,
return_time=True,
)
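            # undo the horizontal flip so the prediction is returned in the original orientation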
if do_flip:
pred = pred.flip(dims=[3])
if metrics is not None:
metrics.add_batch(pred, gt)
if any(k in args.save for k in ["pred", "visu", "err10", "submission"]):
_save_batch(
pred,
gt,
pairnames,
dataset,
task,
args.save,
outdir,
time,
submission_dir=submission_dir,
)
# print
if metrics is not None:
results = metrics.get_results()
for k, v in results.items():
print("{:s}: {:.3f}".format(k, v))
# save if needed
if "metrics" in args.save:
os.makedirs(os.path.dirname(fname), exist_ok=True)
with open(fname, "wb") as fid:
pickle.dump(results, fid)
print("metrics saved in", fname)
# finalize submission if needed
if "submission" in args.save:
dataset.finalize_submission(submission_dir)


if __name__ == "__main__":
args = get_args_parser()
args = args.parse_args()
main(args)