Spaces:
Runtime error
Runtime error
# Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# -------------------------------------------------------- | |
# Losses, metrics per batch, metrics per dataset | |
# -------------------------------------------------------- | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
def _get_gtnorm(gt): | |
if gt.size(1) == 1: # stereo | |
return gt | |
# flow | |
return torch.sqrt(torch.sum(gt**2, dim=1, keepdims=True)) # Bx1xHxW | |
############ losses without confidence | |
class L1Loss(nn.Module):
    """Mean absolute error over valid ground-truth entries (no confidence output).

    Args:
        max_gtnorm: if not None, entries whose ground-truth magnitude (as
            returned by ``_get_gtnorm``) is not strictly below this value are
            excluded from the loss.
    """

    def __init__(self, max_gtnorm=None):
        super().__init__()
        self.max_gtnorm = max_gtnorm
        self.with_conf = False  # forward() does not take a confidence map

    def _error(self, gt, predictions):
        """Element-wise absolute difference between gt and predictions."""
        return (gt - predictions).abs()

    def forward(self, predictions, gt, inspect=False):
        """Return the scalar L1 loss, or the full unmasked error map if ``inspect``.

        Non-finite ground-truth entries are treated as invalid and ignored.
        """
        keep = torch.isfinite(gt)
        if self.max_gtnorm is not None:
            # Broadcast the single-channel magnitude to every gt channel.
            magnitude = _get_gtnorm(gt).expand(-1, gt.size(1), -1, -1)
            keep *= magnitude < self.max_gtnorm
        if not inspect:
            return self._error(gt[keep], predictions[keep]).mean()
        return self._error(gt, predictions)
############## losses with confidence
## there are several parametrizations
class LaplacianLoss(nn.Module):
    """Laplacian negative log-likelihood with an exponential confidence.

    Used for CroCo-Stereo on ETH3D (d' = exp(d)). With predicted confidence
    ``c`` the per-pixel loss is ``sum_channels |gt - pred| / exp(c) + c``; the
    constant ``log(2)`` of the Laplacian density is omitted.
    """

    def __init__(self, max_gtnorm=None):
        super().__init__()
        self.max_gtnorm = max_gtnorm
        self.with_conf = True  # forward() expects a confidence map

    def forward(self, predictions, gt, conf):
        """Return the mean loss over valid pixels.

        A pixel is valid when its first gt channel is finite (and, if
        ``max_gtnorm`` is set, its gt magnitude is below it). ``conf`` is
        squeezed on dim 1 -- presumably Bx1xHxW; confirm with callers.
        """
        valid = torch.isfinite(gt)[:, 0, :, :]
        if self.max_gtnorm is not None:
            valid *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm
        log_scale = conf.squeeze(1)[valid]
        abs_err = (gt - predictions).abs().sum(dim=1)[valid]
        # + log(2) would be a constant, hence dropped
        return (abs_err / log_scale.exp() + log_scale).mean()
class LaplacianLossBounded(nn.Module):
    """Laplacian negative log-likelihood with a sigmoid-bounded scale.

    Used for CroCo-Flow (in the paper's equation, a = 1/b). The raw
    confidence ``c`` is mapped to a scale ``s = a + (b - a) * sigmoid(c)``,
    which lies in (a, b). The per-pixel loss is
    ``sum_channels |gt - pred| / s + log(s)``; the constant ``log(2)`` is omitted.
    """

    def __init__(self, max_gtnorm=10000.0, a=0.25, b=4.0):
        super().__init__()
        self.max_gtnorm = max_gtnorm
        self.with_conf = True  # forward() expects a confidence map
        self.a, self.b = a, b

    def forward(self, predictions, gt, conf):
        """Return the mean loss over valid pixels (finite first gt channel,
        and gt magnitude below ``max_gtnorm`` when it is set)."""
        valid = torch.isfinite(gt)[:, 0, :, :]
        if self.max_gtnorm is not None:
            valid *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm
        scale = self.a + (self.b - self.a) * torch.sigmoid(conf.squeeze(1))
        abs_err = (gt - predictions).abs().sum(dim=1)[valid]
        # + log(2) would be a constant, hence dropped
        return (abs_err / scale[valid] + scale.log()[valid]).mean()
class LaplacianLossBounded2(nn.Module):
    """Laplacian negative log-likelihood with a bounded log-scale.

    Used for CroCo-Stereo except on ETH3D (in the paper's equation, a = b).
    The raw confidence ``c`` is mapped to ``d = 2a * (sigmoid(c / b) - 0.5)``,
    which lies in (-a, a). The per-pixel loss is
    ``sum_channels |gt - pred| / exp(d) + d``; the constant ``log(2)`` is omitted.
    """

    def __init__(self, max_gtnorm=None, a=3.0, b=3.0):
        super().__init__()
        self.max_gtnorm = max_gtnorm
        self.with_conf = True  # forward() expects a confidence map
        self.a, self.b = a, b

    def forward(self, predictions, gt, conf):
        """Return the mean loss over valid pixels (finite first gt channel,
        and gt magnitude below ``max_gtnorm`` when it is set)."""
        valid = torch.isfinite(gt)[:, 0, :, :]
        if self.max_gtnorm is not None:
            valid *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm
        log_scale = 2 * self.a * (torch.sigmoid(conf.squeeze(1) / self.b) - 0.5)
        log_scale = log_scale[valid]
        abs_err = (gt - predictions).abs().sum(dim=1)[valid]
        # + log(2) would be a constant, hence dropped
        return (abs_err / log_scale.exp() + log_scale).mean()
############## metrics per batch | |
class StereoMetrics(nn.Module):
    """Per-batch stereo metrics, averaged over the samples of the batch.

    Returns a dict of scalar tensors: ``avgerr`` (mean absolute error),
    ``rmse`` (root mean squared error), and ``bad@t`` (percentage of valid
    pixels whose absolute error exceeds ``t``) for each threshold in
    ``bad_ths``. Each per-sample value is computed over that sample's finite
    gt pixels and then averaged over the batch.
    """

    def __init__(self, do_quantile=False):
        super().__init__()
        self.bad_ths = [0.5, 1, 2, 3]
        self.do_quantile = do_quantile  # stored but not used by forward()

    def forward(self, predictions, gt):
        batch = predictions.size(0)
        valid = torch.isfinite(gt)
        # Work on a copy where invalid gt is a large finite sentinel, so that
        # multiplying the error by the 0-valued mask yields 0 instead of NaN.
        safe_gt = gt.clone()
        safe_gt[~valid] = 999999.0
        flat_valid = valid.view(batch, -1)
        n_valid = flat_valid.sum(dim=1)
        diff = safe_gt - predictions
        abs_err = (diff.abs() * valid).view(batch, -1)
        sq_err = (diff.square() * valid).view(batch, -1)
        metrics = {
            "avgerr": torch.mean(abs_err.sum(dim=1) / n_valid),
            "rmse": torch.sqrt(sq_err.sum(dim=1) / n_valid).mean(dim=0),
        }
        for th in self.bad_ths:
            frac_bad = ((abs_err > th) * flat_valid).sum(dim=1) / n_valid
            metrics["bad@{:.1f}".format(th)] = frac_bad.mean(dim=0) * 100
        return metrics
class FlowMetrics(nn.Module):
    """Per-batch optical-flow metrics, averaged over the samples of the batch.

    Returns a dict of scalar tensors: ``L1err`` (sum over the two flow
    components of the absolute error), ``EPE`` (end-point error, i.e. the
    Euclidean norm of the flow difference), and ``bad@t`` (percentage of
    valid pixels whose EPE exceeds ``t``) for each threshold in ``bad_ths``.
    """

    def __init__(self):
        super().__init__()
        self.bad_ths = [1, 3, 5]

    def forward(self, predictions, gt):
        batch = predictions.size(0)
        # Validity is read from the x component only; the code assumes x and y
        # are non-finite together.
        valid = torch.isfinite(gt[:, 0, :, :])
        flat_valid = valid.view(batch, -1)
        n_valid = flat_valid.sum(dim=1)
        # Replace invalid gt by a finite sentinel so masked errors become 0, not NaN.
        safe_gt = gt.clone()
        for channel in (0, 1):
            safe_gt[:, channel, :, :][~valid] = 999999.0
        diff = safe_gt - predictions
        l1 = (diff.abs().sum(dim=1) * valid).view(batch, -1)
        epe = (diff.square().sum(dim=1).sqrt() * valid).view(batch, -1)
        metrics = {
            "L1err": torch.mean(l1.sum(dim=1) / n_valid),
            "EPE": torch.mean(epe.sum(dim=1) / n_valid),
        }
        for th in self.bad_ths:
            frac_bad = ((epe > th) * flat_valid).sum(dim=1) / n_valid
            metrics["bad@{:.1f}".format(th)] = frac_bad.mean(dim=0) * 100
        return metrics
############## metrics per dataset
## we update the running averages and keep track of the number of pixels while adding data batch by batch
## at the beginning, call reset()
## after each batch, call add_batch(...)
## at the end, call get_results()
class StereoDatasetMetrics(nn.Module):
    """Dataset-level stereo metrics accumulated batch by batch.

    Usage: ``reset()`` (also performed by the constructor), then
    ``add_batch(predictions, gt)`` for every batch, then ``get_results()``.

    Reported metrics:
        ``L1err``: mean absolute disparity error over all valid pixels seen.
        ``bad@t``: percentage of valid pixels whose error exceeds ``t``, for
            each ``t`` in ``bad_ths``.
    Both are NaN when no valid pixel has been seen.
    """

    def __init__(self):
        super().__init__()
        self.bad_ths = [0.5, 1, 2, 3]
        # Initialise the accumulators so add_batch() works even if the caller
        # does not call reset() explicitly first.
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.agg_N = 0  # number of valid pixels so far
        self.agg_L1err = torch.tensor(0.0)  # running mean of the L1 error so far
        self.agg_Nbad = [0 for _ in self.bad_ths]  # bad-pixel count per threshold
        self._metrics = None  # cached output of _compute_metrics()

    @staticmethod
    def _elementwise_min(errors):
        """Return the element-wise minimum of a non-empty list of tensors."""
        out = errors[0]
        for err in errors[1:]:
            out = torch.minimum(out, err)
        return out

    def add_batch(self, predictions, gt):
        """Accumulate the errors of one batch.

        Args:
            predictions: single-channel disparity prediction (``size(1) == 1``).
            gt: single-channel ground truth; non-finite entries are invalid.
                Special case for Spring: when gt has exactly twice the
                prediction's height and width, each predicted pixel is scored
                against the closest of its 2x2 ground-truth sub-pixels.
        """
        assert predictions.size(1) == 1, predictions.size()
        assert gt.size(1) == 1, gt.size()
        # New data makes any previously computed results stale.
        self._metrics = None
        if (
            gt.size(2) == predictions.size(2) * 2
            and gt.size(3) == predictions.size(3) * 2
        ):  # special case for Spring ...
            L1err = self._elementwise_min(
                [
                    torch.sum(torch.abs(gt[:, :, i::2, j::2] - predictions), dim=1)
                    for i in (0, 1)
                    for j in (0, 1)
                ]
            )
            # Valid when at least one of the 2x2 sub-pixels gives a finite error.
            valid = torch.isfinite(L1err)
        else:
            # Single-channel disparity: non-finite gt marks invalid pixels.
            valid = torch.isfinite(gt[:, 0, :, :])
            L1err = torch.sum(torch.abs(gt - predictions), dim=1)
        N = valid.sum()
        if N == 0:
            # Nothing to add. Without this guard the mean of an empty selection
            # is NaN and would poison the running average for every later batch.
            return
        valid_err = L1err[valid]
        Nnew = self.agg_N + N
        self.agg_L1err = (
            float(self.agg_N) / Nnew * self.agg_L1err
            + valid_err.mean().cpu() * float(N) / Nnew
        )
        self.agg_N = Nnew
        for i, th in enumerate(self.bad_ths):
            self.agg_Nbad[i] += (valid_err > th).sum().cpu()

    def _compute_metrics(self):
        """Fill ``self._metrics`` from the accumulators (no-op if already cached)."""
        if self._metrics is not None:
            return
        n = float(self.agg_N)
        nan = float("nan")
        out = {}
        out["L1err"] = self.agg_L1err.item() if n > 0 else nan
        for th, nbad in zip(self.bad_ths, self.agg_Nbad):
            out["bad@{:.1f}".format(th)] = 100.0 * float(nbad) / n if n > 0 else nan
        self._metrics = out

    def get_results(self):
        """Return the dict of dataset metrics (computed once, then cached)."""
        self._compute_metrics()  # to avoid recomputing them multiple times
        return self._metrics
class FlowDatasetMetrics(nn.Module):
    """Dataset-level optical-flow metrics accumulated batch by batch.

    Usage: ``reset()`` (also performed by the constructor), then
    ``add_batch(predictions, gt)`` for every batch, then ``get_results()``.

    Reported metrics:
        ``L1err``: mean over valid pixels of the summed absolute error of the
            two flow components.
        ``EPE``: mean end-point error (Euclidean norm of the flow difference).
        ``bad@t``: percentage of valid pixels whose EPE exceeds ``t``.
        ``s{lo}-{hi}`` / ``s{lo}+``: mean EPE of valid pixels whose gt flow
            magnitude lies in ``[lo, hi)``.
    ``L1err``/``EPE``/``bad@t`` are NaN when no valid pixel has been seen.
    NOTE: a speed bin that never received a pixel reports 0.0.
    """

    def __init__(self):
        super().__init__()
        self.bad_ths = [0.5, 1, 3, 5]
        self.speed_ths = [(0, 10), (10, 40), (40, torch.inf)]
        # Initialise the accumulators so add_batch() works even if the caller
        # does not call reset() explicitly first.
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.agg_N = 0  # number of valid pixels so far
        self.agg_L1err = torch.tensor(0.0)  # running mean of the L1 error so far
        self.agg_L2err = torch.tensor(0.0)  # running mean of the L2 (=EPE) error so far
        self.agg_Nbad = [0 for _ in self.bad_ths]  # bad-pixel count per threshold
        self.agg_EPEspeed = [
            torch.tensor(0.0) for _ in self.speed_ths
        ]  # running mean EPE per speed bin
        self.agg_Nspeed = [0 for _ in self.speed_ths]  # pixel count per speed bin
        self._metrics = None  # cached output of _compute_metrics()
        self.pairname_results = {}

    @staticmethod
    def _elementwise_min(errors):
        """Return the element-wise minimum of a non-empty list of tensors."""
        out = errors[0]
        for err in errors[1:]:
            out = torch.minimum(out, err)
        return out

    def add_batch(self, predictions, gt):
        """Accumulate the errors of one batch.

        Args:
            predictions: two-channel flow prediction (``size(1) == 2``).
            gt: two-channel flow ground truth; a pixel is invalid when its x
                component is non-finite (the code assumes x and y are
                non-finite together). Special case for Spring: when gt has
                exactly twice the prediction's height and width, each predicted
                pixel is scored against the closest of its 2x2 gt sub-pixels,
                and its speed is the average of the 4 sub-pixel magnitudes.
        """
        assert predictions.size(1) == 2, predictions.size()
        assert gt.size(1) == 2, gt.size()
        # New data makes any previously computed results stale.
        self._metrics = None
        if (
            gt.size(2) == predictions.size(2) * 2
            and gt.size(3) == predictions.size(3) * 2
        ):  # special case for Spring ...
            subpixels = [gt[:, :, i::2, j::2] for i in (0, 1) for j in (0, 1)]
            L1err = self._elementwise_min(
                [torch.sum(torch.abs(s - predictions), dim=1) for s in subpixels]
            )
            L2err = self._elementwise_min(
                [
                    torch.sqrt(torch.sum(torch.square(s - predictions), dim=1))
                    for s in subpixels
                ]
            )
            # Valid when at least one of the 2x2 sub-pixels gives a finite error.
            valid = torch.isfinite(L1err)
            gtspeed = (
                sum(torch.sqrt(torch.sum(torch.square(s), dim=1)) for s in subpixels)
                / 4.0
            )  # let's just average them
        else:
            valid = torch.isfinite(gt[:, 0, :, :])  # both x and y would be infinite
            L1err = torch.sum(torch.abs(gt - predictions), dim=1)
            L2err = torch.sqrt(torch.sum(torch.square(gt - predictions), dim=1))
            gtspeed = torch.sqrt(torch.sum(torch.square(gt), dim=1))
        N = valid.sum()
        if N == 0:
            # Nothing to add. Without this guard the mean of an empty selection
            # is NaN and would poison the running averages for every later batch.
            return
        l1_valid = L1err[valid]
        l2_valid = L2err[valid]
        speed_valid = gtspeed[valid]
        Nnew = self.agg_N + N
        self.agg_L1err = (
            float(self.agg_N) / Nnew * self.agg_L1err
            + l1_valid.mean().cpu() * float(N) / Nnew
        )
        self.agg_L2err = (
            float(self.agg_N) / Nnew * self.agg_L2err
            + l2_valid.mean().cpu() * float(N) / Nnew
        )
        self.agg_N = Nnew
        for i, th in enumerate(self.bad_ths):
            self.agg_Nbad[i] += (l2_valid > th).sum().cpu()
        for i, (th1, th2) in enumerate(self.speed_ths):
            in_bin = (speed_valid >= th1) * (speed_valid < th2)
            iNspeed = in_bin.sum()
            if iNspeed == 0:
                continue
            iNnew = self.agg_Nspeed[i] + iNspeed
            self.agg_EPEspeed[i] = (
                float(self.agg_Nspeed[i]) / iNnew * self.agg_EPEspeed[i]
                + float(iNspeed) / iNnew * l2_valid[in_bin].mean().cpu()
            )
            self.agg_Nspeed[i] = iNnew

    def _compute_metrics(self):
        """Fill ``self._metrics`` from the accumulators (no-op if already cached)."""
        if self._metrics is not None:
            return
        n = float(self.agg_N)
        nan = float("nan")
        out = {}
        out["L1err"] = self.agg_L1err.item() if n > 0 else nan
        out["EPE"] = self.agg_L2err.item() if n > 0 else nan
        for th, nbad in zip(self.bad_ths, self.agg_Nbad):
            out["bad@{:.1f}".format(th)] = 100.0 * float(nbad) / n if n > 0 else nan
        for (th1, th2), epe in zip(self.speed_ths, self.agg_EPEspeed):
            out[
                "s{:d}{:s}".format(th1, "-" + str(th2) if th2 < torch.inf else "+")
            ] = epe.item()
        self._metrics = out

    def get_results(self):
        """Return the dict of dataset metrics (computed once, then cached)."""
        self._compute_metrics()  # to avoid recomputing them multiple times
        return self._metrics