Spaces:
Runtime error
Runtime error
# Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# -------------------------------------------------------- | |
# Main function for training one epoch or testing | |
# -------------------------------------------------------- | |
import math | |
import sys | |
from typing import Iterable | |
import numpy as np | |
import torch | |
import torchvision | |
from utils import misc as misc | |
def split_prediction_conf(predictions, with_conf=False): | |
if not with_conf: | |
return predictions, None | |
conf = predictions[:, -1:, :, :] | |
predictions = predictions[:, :-1, :, :] | |
return predictions, conf | |
def train_one_epoch( | |
model: torch.nn.Module, | |
criterion: torch.nn.Module, | |
metrics: torch.nn.Module, | |
data_loader: Iterable, | |
optimizer: torch.optim.Optimizer, | |
device: torch.device, | |
epoch: int, | |
loss_scaler, | |
log_writer=None, | |
print_freq=20, | |
args=None, | |
): | |
model.train(True) | |
metric_logger = misc.MetricLogger(delimiter=" ") | |
metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) | |
header = "Epoch: [{}]".format(epoch) | |
accum_iter = args.accum_iter | |
optimizer.zero_grad() | |
details = {} | |
if log_writer is not None: | |
print("log_dir: {}".format(log_writer.log_dir)) | |
if args.img_per_epoch: | |
iter_per_epoch = args.img_per_epoch // args.batch_size + int( | |
args.img_per_epoch % args.batch_size > 0 | |
) | |
assert ( | |
len(data_loader) >= iter_per_epoch | |
), "Dataset is too small for so many iterations" | |
len_data_loader = iter_per_epoch | |
else: | |
len_data_loader, iter_per_epoch = len(data_loader), None | |
for data_iter_step, (image1, image2, gt, pairname) in enumerate( | |
metric_logger.log_every( | |
data_loader, print_freq, header, max_iter=iter_per_epoch | |
) | |
): | |
image1 = image1.to(device, non_blocking=True) | |
image2 = image2.to(device, non_blocking=True) | |
gt = gt.to(device, non_blocking=True) | |
# we use a per iteration (instead of per epoch) lr scheduler | |
if data_iter_step % accum_iter == 0: | |
misc.adjust_learning_rate( | |
optimizer, data_iter_step / len_data_loader + epoch, args | |
) | |
with torch.cuda.amp.autocast(enabled=bool(args.amp)): | |
prediction = model(image1, image2) | |
prediction, conf = split_prediction_conf(prediction, criterion.with_conf) | |
batch_metrics = metrics(prediction.detach(), gt) | |
loss = ( | |
criterion(prediction, gt) | |
if conf is None | |
else criterion(prediction, gt, conf) | |
) | |
loss_value = loss.item() | |
if not math.isfinite(loss_value): | |
print("Loss is {}, stopping training".format(loss_value)) | |
sys.exit(1) | |
loss /= accum_iter | |
loss_scaler( | |
loss, | |
optimizer, | |
parameters=model.parameters(), | |
update_grad=(data_iter_step + 1) % accum_iter == 0, | |
) | |
if (data_iter_step + 1) % accum_iter == 0: | |
optimizer.zero_grad() | |
torch.cuda.synchronize() | |
metric_logger.update(loss=loss_value) | |
for k, v in batch_metrics.items(): | |
metric_logger.update(**{k: v.item()}) | |
lr = optimizer.param_groups[0]["lr"] | |
metric_logger.update(lr=lr) | |
# if args.dsitributed: loss_value_reduce = misc.all_reduce_mean(loss_value) | |
time_to_log = (data_iter_step + 1) % ( | |
args.tboard_log_step * accum_iter | |
) == 0 or data_iter_step == len_data_loader - 1 | |
loss_value_reduce = misc.all_reduce_mean(loss_value) | |
if log_writer is not None and time_to_log: | |
epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000) | |
# We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. | |
log_writer.add_scalar("train/loss", loss_value_reduce, epoch_1000x) | |
log_writer.add_scalar("lr", lr, epoch_1000x) | |
for k, v in batch_metrics.items(): | |
log_writer.add_scalar("train/" + k, v.item(), epoch_1000x) | |
# gather the stats from all processes | |
# if args.distributed: metric_logger.synchronize_between_processes() | |
print("Averaged stats:", metric_logger) | |
return {k: meter.global_avg for k, meter in metric_logger.meters.items()} | |
def validate_one_epoch( | |
model: torch.nn.Module, | |
criterion: torch.nn.Module, | |
metrics: torch.nn.Module, | |
data_loaders: list[Iterable], | |
device: torch.device, | |
epoch: int, | |
log_writer=None, | |
args=None, | |
): | |
model.eval() | |
metric_loggers = [] | |
header = "Epoch: [{}]".format(epoch) | |
print_freq = 20 | |
conf_mode = args.tile_conf_mode | |
crop = args.crop | |
if log_writer is not None: | |
print("log_dir: {}".format(log_writer.log_dir)) | |
results = {} | |
dnames = [] | |
image1, image2, gt, prediction = None, None, None, None | |
for didx, data_loader in enumerate(data_loaders): | |
dname = str(data_loader.dataset) | |
dnames.append(dname) | |
metric_loggers.append(misc.MetricLogger(delimiter=" ")) | |
for data_iter_step, (image1, image2, gt, pairname) in enumerate( | |
metric_loggers[didx].log_every(data_loader, print_freq, header) | |
): | |
image1 = image1.to(device, non_blocking=True) | |
image2 = image2.to(device, non_blocking=True) | |
gt = gt.to(device, non_blocking=True) | |
if dname.startswith("Spring"): | |
assert ( | |
gt.size(2) == image1.size(2) * 2 | |
and gt.size(3) == image1.size(3) * 2 | |
) | |
gt = ( | |
gt[:, :, 0::2, 0::2] | |
+ gt[:, :, 0::2, 1::2] | |
+ gt[:, :, 1::2, 0::2] | |
+ gt[:, :, 1::2, 1::2] | |
) / 4.0 # we approximate the gt based on the 2x upsampled ones | |
with torch.inference_mode(): | |
prediction, tiled_loss, c = tiled_pred( | |
model, | |
criterion, | |
image1, | |
image2, | |
gt, | |
conf_mode=conf_mode, | |
overlap=args.val_overlap, | |
crop=crop, | |
with_conf=criterion.with_conf, | |
) | |
batch_metrics = metrics(prediction.detach(), gt) | |
loss = ( | |
criterion(prediction.detach(), gt) | |
if not criterion.with_conf | |
else criterion(prediction.detach(), gt, c) | |
) | |
loss_value = loss.item() | |
metric_loggers[didx].update(loss_tiled=tiled_loss.item()) | |
metric_loggers[didx].update(**{f"loss": loss_value}) | |
for k, v in batch_metrics.items(): | |
metric_loggers[didx].update(**{dname + "_" + k: v.item()}) | |
results = { | |
k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items() | |
} | |
if len(dnames) > 1: | |
for k in batch_metrics.keys(): | |
results["AVG_" + k] = sum( | |
results[dname + "_" + k] for dname in dnames | |
) / len(dnames) | |
if log_writer is not None: | |
epoch_1000x = int((1 + epoch) * 1000) | |
for k, v in results.items(): | |
log_writer.add_scalar("val/" + k, v, epoch_1000x) | |
print("Averaged stats:", results) | |
return results | |
import torch.nn.functional as F | |
def _resize_img(img, new_size): | |
return F.interpolate(img, size=new_size, mode="bicubic", align_corners=False) | |
def _resize_stereo_or_flow(data, new_size): | |
assert data.ndim == 4 | |
assert data.size(1) in [1, 2] | |
scale_x = new_size[1] / float(data.size(3)) | |
out = F.interpolate(data, size=new_size, mode="bicubic", align_corners=False) | |
out[:, 0, :, :] *= scale_x | |
if out.size(1) == 2: | |
scale_y = new_size[0] / float(data.size(2)) | |
out[:, 1, :, :] *= scale_y | |
print(scale_x, new_size, data.shape) | |
return out | |
def tiled_pred( | |
model, | |
criterion, | |
img1, | |
img2, | |
gt, | |
overlap=0.5, | |
bad_crop_thr=0.05, | |
downscale=False, | |
crop=512, | |
ret="loss", | |
conf_mode="conf_expsigmoid_10_5", | |
with_conf=False, | |
return_time=False, | |
): | |
# for each image, we are going to run inference on many overlapping patches | |
# then, all predictions will be weighted-averaged | |
if gt is not None: | |
B, C, H, W = gt.shape | |
else: | |
B, _, H, W = img1.shape | |
C = model.head.num_channels - int(with_conf) | |
win_height, win_width = crop[0], crop[1] | |
# upscale to be larger than the crop | |
do_change_scale = H < win_height or W < win_width | |
if do_change_scale: | |
upscale_factor = max(win_width / W, win_height / W) | |
original_size = (H, W) | |
new_size = (round(H * upscale_factor), round(W * upscale_factor)) | |
img1 = _resize_img(img1, new_size) | |
img2 = _resize_img(img2, new_size) | |
# resize gt just for the computation of tiled losses | |
if gt is not None: | |
gt = _resize_stereo_or_flow(gt, new_size) | |
H, W = img1.shape[2:4] | |
if conf_mode.startswith("conf_expsigmoid_"): # conf_expsigmoid_30_10 | |
beta, betasigmoid = map(float, conf_mode[len("conf_expsigmoid_") :].split("_")) | |
elif conf_mode.startswith("conf_expbeta"): # conf_expbeta3 | |
beta = float(conf_mode[len("conf_expbeta") :]) | |
else: | |
raise NotImplementedError(f"conf_mode {conf_mode} is not implemented") | |
def crop_generator(): | |
for sy in _overlapping(H, win_height, overlap): | |
for sx in _overlapping(W, win_width, overlap): | |
yield sy, sx, sy, sx, True | |
# keep track of weighted sum of prediction*weights and weights | |
accu_pred = img1.new_zeros( | |
(B, C, H, W) | |
) # accumulate the weighted sum of predictions | |
accu_conf = img1.new_zeros((B, H, W)) + 1e-16 # accumulate the weights | |
accu_c = img1.new_zeros( | |
(B, H, W) | |
) # accumulate the weighted sum of confidences ; not so useful except for computing some losses | |
tiled_losses = [] | |
if return_time: | |
start = torch.cuda.Event(enable_timing=True) | |
end = torch.cuda.Event(enable_timing=True) | |
start.record() | |
for sy1, sx1, sy2, sx2, aligned in crop_generator(): | |
# compute optical flow there | |
pred = model(_crop(img1, sy1, sx1), _crop(img2, sy2, sx2)) | |
pred, predconf = split_prediction_conf(pred, with_conf=with_conf) | |
if gt is not None: | |
gtcrop = _crop(gt, sy1, sx1) | |
if criterion is not None and gt is not None: | |
tiled_losses.append( | |
criterion(pred, gtcrop).item() | |
if predconf is None | |
else criterion(pred, gtcrop, predconf).item() | |
) | |
if conf_mode.startswith("conf_expsigmoid_"): | |
conf = torch.exp( | |
-beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5) | |
).view(B, win_height, win_width) | |
elif conf_mode.startswith("conf_expbeta"): | |
conf = torch.exp(-beta * predconf).view(B, win_height, win_width) | |
else: | |
raise NotImplementedError | |
accu_pred[..., sy1, sx1] += pred * conf[:, None, :, :] | |
accu_conf[..., sy1, sx1] += conf | |
accu_c[..., sy1, sx1] += predconf.view(B, win_height, win_width) * conf | |
pred = accu_pred / accu_conf[:, None, :, :] | |
c = accu_c / accu_conf | |
assert not torch.any(torch.isnan(pred)) | |
if return_time: | |
end.record() | |
torch.cuda.synchronize() | |
time = start.elapsed_time(end) / 1000.0 # this was in milliseconds | |
if do_change_scale: | |
pred = _resize_stereo_or_flow(pred, original_size) | |
if return_time: | |
return pred, torch.mean(torch.tensor(tiled_losses)), c, time | |
return pred, torch.mean(torch.tensor(tiled_losses)), c | |
def _overlapping(total, window, overlap=0.5): | |
assert total >= window and 0 <= overlap < 1, (total, window, overlap) | |
num_windows = 1 + int(np.ceil((total - window) / ((1 - overlap) * window))) | |
offsets = np.linspace(0, total - window, num_windows).round().astype(int) | |
yield from (slice(x, x + window) for x in offsets) | |
def _crop(img, sy, sx): | |
B, THREE, H, W = img.shape | |
if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W: | |
return img[:, :, sy, sx] | |
l, r = max(0, -sx.start), max(0, sx.stop - W) | |
t, b = max(0, -sy.start), max(0, sy.stop - H) | |
img = torch.nn.functional.pad(img, (l, r, t, b), mode="constant") | |
return img[:, :, slice(sy.start + t, sy.stop + t), slice(sx.start + l, sx.stop + l)] | |