Uploading the app
Browse files- .gitattributes +37 -35
- .gitignore +26 -0
- README.md +14 -14
- _data/.gitkeep +0 -0
- _data/example_images/frame1.png +3 -0
- _data/example_images/frame3.png +3 -0
- app.py +76 -0
- config/confg.yaml +64 -0
- model/hub.py +12 -0
- model/model.py +220 -0
- model/train_pipline.py +177 -0
- modules/basic_layers.py +313 -0
- modules/cupy_module/correlation.py +402 -0
- modules/cupy_module/cupy_utils.py +7 -0
- modules/cupy_module/nedt.py +129 -0
- modules/cupy_module/softsplat.py +368 -0
- modules/feature_extactor.py +87 -0
- modules/flow_models/flow_models.py +102 -0
- modules/flow_models/raft/LICENSE +29 -0
- modules/flow_models/raft/corr.py +56 -0
- modules/flow_models/raft/extractor.py +342 -0
- modules/flow_models/raft/rfr_new.py +235 -0
- modules/flow_models/raft/update.py +139 -0
- modules/flow_models/raft/utils.py +81 -0
- modules/half_warper.py +129 -0
- modules/synthesizer.py +277 -0
- requirements.txt +42 -0
- utils/ema.py +32 -0
- utils/inter_frame_idx.py +123 -0
- utils/raft.py +20 -0
- utils/uncertainty.py +49 -0
- utils/utils.py +83 -0
.gitattributes
CHANGED
@@ -1,35 +1,37 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
_data/example_images/frame1.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
_data/example_images/frame3.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.png
|
2 |
+
*.jpg
|
3 |
+
*.jpeg
|
4 |
+
*.gif
|
5 |
+
*.bmp
|
6 |
+
*.tiff
|
7 |
+
*.ico
|
8 |
+
!_data/example_images/frame1.png
|
9 |
+
!_data/example_images/frame3.png
|
10 |
+
|
11 |
+
__pycache__
|
12 |
+
|
13 |
+
*.pyc
|
14 |
+
*.pyo
|
15 |
+
*.pyd
|
16 |
+
*.pyw
|
17 |
+
*.pyz
|
18 |
+
|
19 |
+
*.ckpt
|
20 |
+
*.pt
|
21 |
+
*.pth
|
22 |
+
!metrics/flolpips/weights/v0.1/alex.pth
|
23 |
+
|
24 |
+
*.ipynb
|
25 |
+
|
26 |
+
|
README.md
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
-
---
|
2 |
-
title: Multi Input Res Diffusion VFI
|
3 |
-
emoji: 🚀
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo: green
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.25.2
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: mit
|
11 |
-
short_description: Gradio demo for Multi-Input ResShift Diffusion VFI
|
12 |
-
---
|
13 |
-
|
14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
---
|
2 |
+
title: Multi Input Res Diffusion VFI
|
3 |
+
emoji: 🚀
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.25.2
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
short_description: Gradio demo for Multi-Input ResShift Diffusion VFI
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
_data/.gitkeep
ADDED
File without changes
|
_data/example_images/frame1.png
ADDED
![]() |
Git LFS Details
|
_data/example_images/frame3.png
ADDED
![]() |
Git LFS Details
|
app.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
|
5 |
+
import numpy as np
|
6 |
+
import imageio
|
7 |
+
import tempfile
|
8 |
+
|
9 |
+
from utils.utils import denorm
|
10 |
+
from model.hub import MultiInputResShiftHub
|
11 |
+
|
12 |
+
model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI")
|
13 |
+
model.requires_grad_(False).cuda().eval()
|
14 |
+
|
15 |
+
transform = Compose([
|
16 |
+
Resize((256, 448)),
|
17 |
+
ToTensor(),
|
18 |
+
Normalize(mean=[0.5]*3, std=[0.5]*3),
|
19 |
+
])
|
20 |
+
|
21 |
+
def to_numpy(img_tensor):
|
22 |
+
img_np = denorm(img_tensor, mean=[0.5]*3, std=[0.5]*3).squeeze().permute(1, 2, 0).cpu().numpy()
|
23 |
+
img_np = np.clip(img_np, 0, 1)
|
24 |
+
return (img_np * 255).astype(np.uint8)
|
25 |
+
|
26 |
+
def interpolate(img0_pil, img2_pil, tau, num_samples):
|
27 |
+
img0 = transform(img0_pil.convert("RGB")).unsqueeze(0).cuda()
|
28 |
+
img2 = transform(img2_pil.convert("RGB")).unsqueeze(0).cuda()
|
29 |
+
|
30 |
+
if num_samples == 1:
|
31 |
+
# Unique image
|
32 |
+
img1 = model.reverse_process([img0, img2], tau)
|
33 |
+
return Image.fromarray(to_numpy(img1)), None
|
34 |
+
else:
|
35 |
+
# Múltiples imágenes → video
|
36 |
+
frames = [to_numpy(img0)]
|
37 |
+
for t in np.linspace(0, 1, num_samples):
|
38 |
+
img = model.reverse_process([img0, img2], float(t))
|
39 |
+
frames.append(to_numpy(img))
|
40 |
+
frames.append(to_numpy(img2))
|
41 |
+
|
42 |
+
temp_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
43 |
+
imageio.mimsave(temp_path, frames, fps=8)
|
44 |
+
return None, temp_path
|
45 |
+
|
46 |
+
demo = gr.Interface(
|
47 |
+
fn=interpolate,
|
48 |
+
inputs=[
|
49 |
+
gr.Image(type="pil", label="Initial Image (frame1)"),
|
50 |
+
gr.Image(type="pil", label="Final Image (frame3)"),
|
51 |
+
gr.Slider(0.0, 1.0, step=0.05, value=0.5, label="Tau Value (only if Num Samples = 1)"),
|
52 |
+
gr.Slider(1, 15, step=1, value=1, label="Number of Samples"),
|
53 |
+
],
|
54 |
+
outputs=[
|
55 |
+
gr.Image(label="Interpolated Image (if num_samples = 1)"),
|
56 |
+
gr.Video(label="Interpolation in video (if num_samples > 1)"),
|
57 |
+
],
|
58 |
+
title="Multi-Input ResShift Diffusion VFI",
|
59 |
+
description=(
|
60 |
+
"📄 [arXiv Paper](https://arxiv.org/pdf/2504.05402) • "
|
61 |
+
"🤗 [Model](https://huggingface.co/vfontech/Multiple-Input-Resshift-VFI) • "
|
62 |
+
"🧪 [Colab](https://colab.research.google.com/drive/1MGYycbNMW6Mxu5MUqw_RW_xxiVeHK5Aa#scrollTo=EKaYCioiP3tQ) • "
|
63 |
+
"🌐 [GitHub](https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI)\n\n"
|
64 |
+
"Video interpolation using Conditional Residual Diffusion.\n"
|
65 |
+
"- All images are resized to 256x448.\n"
|
66 |
+
"- If `Number of Samples` = 1, generates only one intermediate image with the given Tau value.\n"
|
67 |
+
"- If `Number of Samples` > 1, ignores Tau and generates a sequence of interpolated images."
|
68 |
+
),
|
69 |
+
examples=[
|
70 |
+
["_data/example_images/frame1.png", "_data/example_images/frame3.png", 0.5],
|
71 |
+
],
|
72 |
+
)
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
demo.queue(max_size=12)
|
76 |
+
demo.launch(max_threads=1)
|
config/confg.yaml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data_confg:
|
2 |
+
train_batch_size: 6
|
3 |
+
val_batch_size: 6
|
4 |
+
test_batch_size: 6
|
5 |
+
flow_method: raft
|
6 |
+
data_domain: animation
|
7 |
+
datamodule_confg:
|
8 |
+
mean: [0.5, 0.5, 0.5]
|
9 |
+
sd: [0.5, 0.5, 0.5]
|
10 |
+
size: [256, 448]
|
11 |
+
amount_augmentations: 1
|
12 |
+
horizontal_flip: 0.5
|
13 |
+
time_flip: True
|
14 |
+
rotation: 0
|
15 |
+
brightness: 0.2
|
16 |
+
contrast: 0.2
|
17 |
+
saturation: 0.2
|
18 |
+
hue: 0.1
|
19 |
+
|
20 |
+
trainer_confg:
|
21 |
+
accumulate_grad_batches: 5
|
22 |
+
gradient_clip_val: 1.0
|
23 |
+
max_epochs: 500
|
24 |
+
num_nodes: 1
|
25 |
+
devices: 2
|
26 |
+
accelerator: gpu
|
27 |
+
strategy: ddp_find_unused_parameters_true
|
28 |
+
|
29 |
+
optim_confg:
|
30 |
+
optimizer_confg: # AdamW
|
31 |
+
lr: 1.0e-4
|
32 |
+
betas: [0.9, 0.999]
|
33 |
+
eps: 1.0e-8
|
34 |
+
scheduler_confg: # ReduceLROnPlateau
|
35 |
+
mode: min
|
36 |
+
factor: 0.5
|
37 |
+
patience: 3
|
38 |
+
verbose: True
|
39 |
+
|
40 |
+
pretrained_model_path: null # Fine-tune model path
|
41 |
+
|
42 |
+
model_confg:
|
43 |
+
kappa: 2.0
|
44 |
+
timesteps: 20
|
45 |
+
p: 0.3
|
46 |
+
etas_end: 0.99
|
47 |
+
min_noise_level: 0.04
|
48 |
+
flow_model: raft
|
49 |
+
flow_kwargs:
|
50 |
+
pretrained_path: null #_pretrain_models/anime_interp_full.ckpt
|
51 |
+
warping_kwargs:
|
52 |
+
in_channels: 3
|
53 |
+
channels: [128, 256, 384, 512]
|
54 |
+
synthesis_kwargs:
|
55 |
+
in_channels: 3
|
56 |
+
channels: [128, 256, 384, 512]
|
57 |
+
temb_channels: 512
|
58 |
+
heads: 1
|
59 |
+
window_size: 8
|
60 |
+
window_attn: True
|
61 |
+
grid_attn: True
|
62 |
+
expansion_rate: 1.5
|
63 |
+
num_conv_blocks: 1
|
64 |
+
dropout: 0.0
|
model/hub.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model.model import MultiInputResShift
|
2 |
+
from huggingface_hub import PyTorchModelHubMixin
|
3 |
+
|
4 |
+
class MultiInputResShiftHub(
|
5 |
+
MultiInputResShift,
|
6 |
+
PyTorchModelHubMixin,
|
7 |
+
repo_url="https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI",
|
8 |
+
paper_url="https://arxiv.org/pdf/2504.05402",
|
9 |
+
language="en",
|
10 |
+
):
|
11 |
+
def __init__(self, *args, **kwargs):
|
12 |
+
super().__init__(*args, **kwargs)
|
model/model.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.functional import interpolate
|
4 |
+
|
5 |
+
import math
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from modules.feature_extactor import Extractor
|
9 |
+
from modules.half_warper import HalfWarper
|
10 |
+
from modules.cupy_module.nedt import NEDT
|
11 |
+
from modules.flow_models.flow_models import (
|
12 |
+
RAFTFineFlow,
|
13 |
+
PWCFineFlow
|
14 |
+
)
|
15 |
+
from modules.synthesizer import Synthesis
|
16 |
+
|
17 |
+
class FeatureWarper(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
in_channels: int = 3,
|
21 |
+
channels: list[int] = [32, 64, 128, 256],
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
channels = [in_channels + 1] + channels
|
25 |
+
|
26 |
+
self.half_warper = HalfWarper()
|
27 |
+
self.feature_extractor = Extractor(channels)
|
28 |
+
self.nedt = NEDT()
|
29 |
+
|
30 |
+
def forward(
|
31 |
+
self,
|
32 |
+
I0: torch.Tensor,
|
33 |
+
I1: torch.Tensor,
|
34 |
+
flow0to1: torch.Tensor,
|
35 |
+
flow1to0: torch.Tensor,
|
36 |
+
tau: torch.Tensor = None
|
37 |
+
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
38 |
+
assert tau.shape == (I0.shape[0], 2), "tau shape must be (batch, 2)"
|
39 |
+
|
40 |
+
flow0tot = tau[:, 0][:, None, None, None] * flow0to1
|
41 |
+
flow1tot = tau[:, 1][:, None, None, None] * flow1to0
|
42 |
+
|
43 |
+
I0 = torch.cat([I0, self.nedt(I0)], dim=1)
|
44 |
+
I1 = torch.cat([I1, self.nedt(I1)], dim=1)
|
45 |
+
|
46 |
+
z0to1, z1to0 = HalfWarper.z_metric(I0, I1, flow0to1, flow1to0)
|
47 |
+
base0, base1 = self.half_warper(I0, I1, flow0tot, flow1tot, z0to1, z1to0)
|
48 |
+
warped0, warped1 = [base0], [base1]
|
49 |
+
|
50 |
+
features0 = self.feature_extractor(I0)
|
51 |
+
features1 = self.feature_extractor(I1)
|
52 |
+
|
53 |
+
for feat0, feat1 in zip(features0, features1):
|
54 |
+
f0 = interpolate(flow0tot, size=feat0.shape[2:], mode='bilinear', align_corners=False)
|
55 |
+
f1 = interpolate(flow1tot, size=feat0.shape[2:], mode='bilinear', align_corners=False)
|
56 |
+
z0 = interpolate(z0to1, size=feat0.shape[2:], mode='bilinear', align_corners=False)
|
57 |
+
z1 = interpolate(z1to0, size=feat0.shape[2:], mode='bilinear', align_corners=False)
|
58 |
+
w0, w1 = self.half_warper(feat0, feat1, f0, f1, z0, z1)
|
59 |
+
warped0.append(w0)
|
60 |
+
warped1.append(w1)
|
61 |
+
return warped0, warped1
|
62 |
+
|
63 |
+
class MultiInputResShift(nn.Module):
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
kappa: float=2.0,
|
67 |
+
p: float =0.3,
|
68 |
+
min_noise_level: float=0.04,
|
69 |
+
etas_end: float=0.99,
|
70 |
+
timesteps: int=15,
|
71 |
+
flow_model: str = 'raft',
|
72 |
+
flow_kwargs: dict = {},
|
73 |
+
warping_kwargs: dict = {},
|
74 |
+
synthesis_kwargs: dict = {}
|
75 |
+
):
|
76 |
+
super().__init__()
|
77 |
+
|
78 |
+
self.timesteps = timesteps
|
79 |
+
self.kappa = kappa
|
80 |
+
self.eta_partition = None
|
81 |
+
|
82 |
+
sqrt_eta_1 = min(min_noise_level / kappa, min_noise_level, math.sqrt(0.001))
|
83 |
+
b0 = math.exp(1/float(timesteps - 1) * math.log(etas_end/sqrt_eta_1))
|
84 |
+
base = torch.ones(timesteps)*b0
|
85 |
+
beta = ((torch.linspace(0,1,timesteps))**p)*(timesteps-1)
|
86 |
+
sqrt_eta = torch.pow(base, beta) * sqrt_eta_1
|
87 |
+
|
88 |
+
self.register_buffer("sqrt_sum_eta", sqrt_eta)
|
89 |
+
self.register_buffer("sum_eta", sqrt_eta**2)
|
90 |
+
|
91 |
+
sum_prev_eta = torch.roll(self.sum_eta, 1)
|
92 |
+
sum_prev_eta[0] = 0
|
93 |
+
self.register_buffer("sum_prev_eta", sum_prev_eta)
|
94 |
+
|
95 |
+
self.register_buffer("sum_alpha", self.sum_eta - self.sum_prev_eta)
|
96 |
+
|
97 |
+
self.register_buffer("backward_mean_c1", self.sum_prev_eta / self.sum_eta)
|
98 |
+
self.register_buffer("backward_mean_c2", self.sum_alpha / self.sum_eta)
|
99 |
+
self.register_buffer("backward_std", self.kappa*torch.sqrt(self.sum_prev_eta*self.sum_alpha/self.sum_eta))
|
100 |
+
|
101 |
+
if flow_model == 'raft':
|
102 |
+
self.flow_model = RAFTFineFlow(**flow_kwargs)
|
103 |
+
elif flow_model == 'pwc':
|
104 |
+
self.flow_model = PWCFineFlow(**flow_kwargs)
|
105 |
+
else:
|
106 |
+
raise ValueError(f"Flow model {flow_model} not supported")
|
107 |
+
|
108 |
+
self.feature_warper = FeatureWarper(**warping_kwargs)
|
109 |
+
self.synthesis = Synthesis(**synthesis_kwargs)
|
110 |
+
|
111 |
+
def forward_process(
|
112 |
+
self,
|
113 |
+
x: torch.Tensor | None,
|
114 |
+
Y: list[torch.Tensor],
|
115 |
+
tau: torch.Tensor | float | None,
|
116 |
+
t: torch.Tensor | int
|
117 |
+
) -> torch.Tensor:
|
118 |
+
if tau is None:
|
119 |
+
tau: torch.Tensor = torch.full((x.shape[0], len(Y)), 0.5, device=x.device, dtype=x.dtype)
|
120 |
+
elif isinstance(tau, float):
|
121 |
+
assert tau >= 0 and tau <= 1, "tau must be between 0 and 1"
|
122 |
+
tau: torch.Tensor = torch.cat([
|
123 |
+
torch.full((x.shape[0], 1), tau, device=x.device, dtype=x.dtype),
|
124 |
+
torch.full((x.shape[0], 1), 1 - tau, device=x.device, dtype=x.dtype)
|
125 |
+
], dim=1)
|
126 |
+
if not torch.is_tensor(t):
|
127 |
+
t: torch.Tensor = torch.tensor([t], device=x.device, dtype=torch.long)
|
128 |
+
if x is None:
|
129 |
+
x: torch.Tensor = torch.zeros_like(Y[0])
|
130 |
+
|
131 |
+
eta = self.sum_eta[t][:, None] * tau
|
132 |
+
eta = eta[:, :, None, None, None].transpose(0, 1)
|
133 |
+
|
134 |
+
e_i = torch.stack([y - x for y in Y])
|
135 |
+
mean = x + (eta*e_i).sum(dim=0)
|
136 |
+
|
137 |
+
sqrt_sum_eta = self.sqrt_sum_eta[t][:, None, None, None]
|
138 |
+
std = self.kappa*sqrt_sum_eta
|
139 |
+
epsilon = torch.randn_like(x)
|
140 |
+
|
141 |
+
return mean + std*epsilon
|
142 |
+
|
143 |
+
@torch.inference_mode()
|
144 |
+
def reverse_process(
|
145 |
+
self,
|
146 |
+
Y: list[torch.Tensor],
|
147 |
+
tau: torch.Tensor | float,
|
148 |
+
flows: list[torch.Tensor] | None = None,
|
149 |
+
) -> torch.Tensor:
|
150 |
+
y = Y[0]
|
151 |
+
batch, device, dtype = y.shape[0], y.device, y.dtype
|
152 |
+
|
153 |
+
if isinstance(tau, float):
|
154 |
+
assert tau >= 0 and tau <= 1, "tau must be between 0 and 1"
|
155 |
+
tau: torch.Tensor = torch.cat([
|
156 |
+
torch.full((batch, 1), tau, device=device, dtype=dtype),
|
157 |
+
torch.full((batch, 1), 1 - tau, device=device, dtype=dtype)
|
158 |
+
], dim=1)
|
159 |
+
if flows is None:
|
160 |
+
flow0to1, flow1to0 = self.flow_model(Y[0], Y[1])
|
161 |
+
else:
|
162 |
+
flow0to1, flow1to0 = flows
|
163 |
+
warp0to1, warp1to0 = self.feature_warper(Y[0], Y[1], flow0to1, flow1to0, tau)
|
164 |
+
|
165 |
+
T = torch.tensor([self.timesteps-1,] * batch, device=device, dtype=torch.long)
|
166 |
+
x = self.forward_process(torch.zeros_like(Y[0]), [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, T)
|
167 |
+
|
168 |
+
pbar = tqdm(total=self.timesteps, desc="Reversing Process")
|
169 |
+
for i in reversed(range(self.timesteps)):
|
170 |
+
t = torch.ones(batch, device = device, dtype=torch.long) * i
|
171 |
+
|
172 |
+
predicted_x0 = self.synthesis(x, warp0to1, warp1to0, t)
|
173 |
+
|
174 |
+
mean_c1 = self.backward_mean_c1[t][:, None, None, None]
|
175 |
+
mean_c2 = self.backward_mean_c2[t][:, None, None, None]
|
176 |
+
std = self.backward_std[t][:, None, None, None]
|
177 |
+
|
178 |
+
eta = self.sum_eta[t][:, None] * tau
|
179 |
+
prev_eta = self.sum_prev_eta[t][:, None] * tau
|
180 |
+
eta = eta[:, :, None, None, None].transpose(0, 1)
|
181 |
+
prev_eta = prev_eta[:, :, None, None, None].transpose(0, 1)
|
182 |
+
e_i = torch.stack([y - predicted_x0 for y in Y])
|
183 |
+
|
184 |
+
mean = (
|
185 |
+
mean_c1*(x + (eta*e_i).sum(dim=0))
|
186 |
+
+ mean_c2*predicted_x0
|
187 |
+
- (prev_eta*e_i).sum(dim=0)
|
188 |
+
)
|
189 |
+
|
190 |
+
x = mean + std*torch.randn_like(x)
|
191 |
+
pbar.update(1)
|
192 |
+
pbar.close()
|
193 |
+
return x
|
194 |
+
|
195 |
+
# Training Step Only
|
196 |
+
def forward(
|
197 |
+
self,
|
198 |
+
I0: torch.Tensor,
|
199 |
+
It: torch.Tensor,
|
200 |
+
I1: torch.Tensor,
|
201 |
+
flow1to0: torch.Tensor | None = None,
|
202 |
+
flow0to1: torch.Tensor | None = None,
|
203 |
+
tau: torch.Tensor | None = None,
|
204 |
+
t: torch.Tensor | None = None
|
205 |
+
) -> torch.Tensor:
|
206 |
+
|
207 |
+
if tau is None:
|
208 |
+
tau = torch.full((It.shape[0], 2), 0.5, device=It.device, dtype=It.dtype)
|
209 |
+
|
210 |
+
if flow0to1 is None or flow1to0 is None:
|
211 |
+
flow0to1, flow1to0 = self.flow_model(I0, I1)
|
212 |
+
|
213 |
+
if t is None:
|
214 |
+
t = torch.randint(low=1, high=self.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long)
|
215 |
+
|
216 |
+
warp0to1, warp1to0 = self.feature_warper(I0, I1, flow0to1, flow1to0, tau)
|
217 |
+
x_t = self.forward_process(It, [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, t)
|
218 |
+
|
219 |
+
predicted_It = self.synthesis(x_t, warp0to1, warp1to0, t)
|
220 |
+
return predicted_It
|
model/train_pipline.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
8 |
+
from torch.optim import AdamW, Optimizer
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from lightning import LightningModule
|
11 |
+
|
12 |
+
from torchmetrics import MetricCollection
|
13 |
+
from torchmetrics.image import PeakSignalNoiseRatio as PSNR
|
14 |
+
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
|
15 |
+
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
|
16 |
+
|
17 |
+
from model.model import MultiInputResShift
|
18 |
+
|
19 |
+
from utils.utils import denorm, make_grid_images#, save_triplet
|
20 |
+
from utils.ema import EMA
|
21 |
+
from utils.inter_frame_idx import get_inter_frame_temp_index
|
22 |
+
from utils.raft import raft_flow
|
23 |
+
|
24 |
+
|
25 |
+
class TrainPipline(LightningModule):
|
26 |
+
def __init__(self,
|
27 |
+
confg: dict,
|
28 |
+
test_dataloader: DataLoader):
|
29 |
+
super(TrainPipline, self).__init__()
|
30 |
+
|
31 |
+
self.test_dataloader = test_dataloader
|
32 |
+
|
33 |
+
self.confg = confg
|
34 |
+
|
35 |
+
self.mean, self.sd = confg["data_confg"]["mean"], confg["data_confg"]["sd"]
|
36 |
+
|
37 |
+
self.model = MultiInputResShift(**confg["model_confg"])
|
38 |
+
self.model.flow_model.requires_grad_(False).eval()
|
39 |
+
|
40 |
+
self.ema = EMA(beta=0.995)
|
41 |
+
self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)
|
42 |
+
|
43 |
+
self.charbonnier_loss = lambda x, y: torch.mean(torch.sqrt((x - y)**2 + 1e-6))
|
44 |
+
self.lpips_loss = LPIPS(net_type='vgg')
|
45 |
+
|
46 |
+
self.train_metrics = MetricCollection({
|
47 |
+
"train_lpips": LPIPS(net_type='alex'),
|
48 |
+
"train_psnr": PSNR(),
|
49 |
+
"train_ssim": SSIM()
|
50 |
+
})
|
51 |
+
self.val_metrics = MetricCollection({
|
52 |
+
"val_lpips": LPIPS(net_type='alex'),
|
53 |
+
"val_psnr": PSNR(),
|
54 |
+
"val_ssim": SSIM()
|
55 |
+
})
|
56 |
+
|
57 |
+
def loss_fn(self,
|
58 |
+
x: torch.Tensor,
|
59 |
+
predicted_x: torch.Tensor) -> torch.Tensor:
|
60 |
+
percep_loss = 0.2 * self.lpips_loss(x, predicted_x.clamp(-1, 1))
|
61 |
+
pix2pix_loss = self.charbonnier_loss(x, predicted_x)
|
62 |
+
return percep_loss + pix2pix_loss
|
63 |
+
|
64 |
+
def sample_t(self,
|
65 |
+
shape: tuple[int, ...],
|
66 |
+
max_t: int,
|
67 |
+
device: torch.device) -> torch.Tensor:
|
68 |
+
p = torch.linspace(1, max_t, steps=max_t, device=device) ** 2
|
69 |
+
p = p / p.sum()
|
70 |
+
t = torch.multinomial(p, num_samples=shape[0], replacement=True)
|
71 |
+
return t
|
72 |
+
|
73 |
+
def forward(self,
|
74 |
+
I0: torch.Tensor,
|
75 |
+
It: torch.Tensor,
|
76 |
+
I1: torch.Tensor) -> torch.Tensor:
|
77 |
+
flow0tot = raft_flow(I0, It, 'animation')
|
78 |
+
flow1tot = raft_flow(I1, It, 'animation')
|
79 |
+
mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype)
|
80 |
+
|
81 |
+
tau = torch.stack([mid_idx, 1 - mid_idx], dim=1)
|
82 |
+
|
83 |
+
if self.current_epoch > 5:
|
84 |
+
t = torch.randint(low=1, high=self.model.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long)
|
85 |
+
else:
|
86 |
+
t = self.sample_t(shape=(It.shape[0],), max_t=self.model.timesteps, device=It.device)
|
87 |
+
|
88 |
+
predicted_It = self.model(I0, It, I1, tau=tau, t=t)
|
89 |
+
return predicted_It
|
90 |
+
|
91 |
+
def get_step_plt_images(self,
|
92 |
+
It: torch.Tensor,
|
93 |
+
predicted_It: torch.Tensor) -> plt.Figure:
|
94 |
+
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
|
95 |
+
ax[0].imshow(denorm(predicted_It.clamp(-1, 1), self.mean, self.sd)[0].permute(1, 2, 0).cpu().numpy())
|
96 |
+
ax[0].axis("off")
|
97 |
+
ax[0].set_title("Predicted")
|
98 |
+
ax[1].imshow(denorm(It, self.mean, self.sd)[0].permute(1, 2, 0).cpu().numpy())
|
99 |
+
ax[1].axis("off")
|
100 |
+
ax[1].set_title("Ground Truth")
|
101 |
+
plt.tight_layout()
|
102 |
+
#img_path = "step_image.png"
|
103 |
+
#fig.savefig(img_path, dpi=300, bbox_inches='tight')
|
104 |
+
plt.close(fig)
|
105 |
+
return fig
|
106 |
+
|
107 |
+
def training_step(self, batch: tuple[torch.Tensor, ...], _) -> torch.Tensor:
|
108 |
+
I0, It, I1 = batch
|
109 |
+
predicted_It = self(I0, It, I1)
|
110 |
+
loss = self.loss_fn(It, predicted_It)
|
111 |
+
|
112 |
+
self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
|
113 |
+
self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
|
114 |
+
|
115 |
+
self.ema.step_ema(self.ema_model, self.model)
|
116 |
+
with torch.inference_mode():
|
117 |
+
fig = self.get_step_plt_images(It, predicted_It)
|
118 |
+
self.logger.experiment.add_figure("Train Predictions", fig, self.global_step)
|
119 |
+
mets = self.train_metrics(It, predicted_It.clamp(-1, 1))
|
120 |
+
self.log_dict(mets, prog_bar=True, on_step=True,on_epoch=False)
|
121 |
+
return loss
|
122 |
+
|
123 |
+
@torch.no_grad()
|
124 |
+
def validation_step(self, batch: tuple[torch.Tensor, ...], _) -> None:
|
125 |
+
I0, It, I1 = batch
|
126 |
+
predicted_It = self(I0, It, I1)
|
127 |
+
loss = self.loss_fn(It, predicted_It)
|
128 |
+
|
129 |
+
self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
|
130 |
+
|
131 |
+
mets = self.val_metrics(It, predicted_It.clamp(-1, 1))
|
132 |
+
self.log_dict(mets, prog_bar=True, on_step=False, on_epoch=True)
|
133 |
+
|
134 |
+
@torch.inference_mode()
|
135 |
+
def on_train_epoch_end(self) -> None:
|
136 |
+
torch.save(self.ema_model.state_dict(),
|
137 |
+
os.path.join("_checkpoint", f"resshift_diff_{self.current_epoch}.pth"))
|
138 |
+
|
139 |
+
batch = next(iter(self.test_dataloader))
|
140 |
+
I0, It, I1 = batch
|
141 |
+
I0, It, I1 = I0.to(self.device), It.to(self.device), I1.to(self.device)
|
142 |
+
|
143 |
+
flow0tot = raft_flow(I0, It, 'animation')
|
144 |
+
flow1tot = raft_flow(I1, It, 'animation')
|
145 |
+
mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype)
|
146 |
+
tau = torch.stack([mid_idx, 1 - mid_idx], dim=1)
|
147 |
+
|
148 |
+
predicted_It = self.ema_model.reverse_process([I0, I1], tau)
|
149 |
+
|
150 |
+
I0 = denorm(I0, self.mean, self.sd)
|
151 |
+
I1 = denorm(I1, self.mean, self.sd)
|
152 |
+
It = denorm(It, self.mean, self.sd)
|
153 |
+
predicted_It = denorm(predicted_It.clamp(-1, 1), self.mean, self.sd)
|
154 |
+
|
155 |
+
#save_triplet([I0, It, predicted_It, I1], f"./_output/target_{self.current_epoch}.png", nrow=1)
|
156 |
+
grid = make_grid_images([I0, It, predicted_It, I1], nrow=1)
|
157 |
+
self.logger.experiment.add_image("Predicted Images", grid, self.global_step)
|
158 |
+
|
159 |
+
def configure_optimizers(self) -> tuple[list[Optimizer], list[dict[str, Any]]]:
|
160 |
+
optimizer = [AdamW(
|
161 |
+
self.model.parameters(),
|
162 |
+
**self.confg["optim_confg"]['optimizer_confg']
|
163 |
+
)]
|
164 |
+
|
165 |
+
scheduler = [{
|
166 |
+
'scheduler': ReduceLROnPlateau(
|
167 |
+
optimizer[0],
|
168 |
+
**self.confg["optim_confg"]['scheduler_confg']
|
169 |
+
),
|
170 |
+
'monitor': 'val_loss',
|
171 |
+
'interval': 'epoch',
|
172 |
+
'frequency': 1,
|
173 |
+
'strict': True,
|
174 |
+
}]
|
175 |
+
|
176 |
+
return optimizer, scheduler
|
177 |
+
|
modules/basic_layers.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from einops import rearrange
|
6 |
+
from einops.layers.torch import Rearrange
|
7 |
+
|
8 |
+
class GroupNorm(nn.Module):
|
9 |
+
def __init__(self, in_channels: int, num_groups: int = 32):
|
10 |
+
super(GroupNorm, self).__init__()
|
11 |
+
self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
12 |
+
|
13 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
14 |
+
return self.gn(x)
|
15 |
+
|
16 |
+
class AdaLayerNorm(nn.Module):
|
17 |
+
def __init__(self, channels: int, cond_channels: int = 0, return_scale_shift: bool = True):
|
18 |
+
super(AdaLayerNorm, self).__init__()
|
19 |
+
self.norm = nn.LayerNorm(channels)
|
20 |
+
self.return_scale_shift = return_scale_shift
|
21 |
+
if cond_channels != 0:
|
22 |
+
if return_scale_shift:
|
23 |
+
self.proj = nn.Linear(cond_channels, channels * 3, bias=False)
|
24 |
+
else:
|
25 |
+
self.proj = nn.Linear(cond_channels, channels * 2, bias=False)
|
26 |
+
nn.init.xavier_uniform_(self.proj.weight)
|
27 |
+
|
28 |
+
def expand_dims(self, tensor: torch.Tensor, dims: list[int]) -> torch.Tensor:
|
29 |
+
for dim in dims:
|
30 |
+
tensor = tensor.unsqueeze(dim)
|
31 |
+
return tensor
|
32 |
+
|
33 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor:
|
34 |
+
x = self.norm(x)
|
35 |
+
if cond is None:
|
36 |
+
return x
|
37 |
+
dims = list(range(1, len(x.shape) - 1))
|
38 |
+
if self.return_scale_shift:
|
39 |
+
gamma, beta, sigma = self.proj(cond).chunk(3, dim=-1)
|
40 |
+
gamma, beta, sigma = [self.expand_dims(t, dims) for t in (gamma, beta, sigma)]
|
41 |
+
return x * (1 + gamma) + beta, sigma
|
42 |
+
else:
|
43 |
+
gamma, beta = self.proj(cond).chunk(2, dim=-1)
|
44 |
+
gamma, beta = [self.expand_dims(t, dims) for t in (gamma, beta)]
|
45 |
+
return x * (1 + gamma) + beta
|
46 |
+
|
47 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
48 |
+
def __init__(self, emb_dim: int = 256):
|
49 |
+
super(SinusoidalPositionalEmbedding, self).__init__()
|
50 |
+
self.channels = emb_dim
|
51 |
+
|
52 |
+
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
53 |
+
inv_freq = 1.0 / (
|
54 |
+
10000
|
55 |
+
** (torch.arange(0, self.channels, 2, device=t.device).float() / self.channels)
|
56 |
+
)
|
57 |
+
pos_enc_a = torch.sin(t.repeat(1, self.channels // 2) * inv_freq)
|
58 |
+
pos_enc_b = torch.cos(t.repeat(1, self.channels // 2) * inv_freq)
|
59 |
+
pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
|
60 |
+
return pos_enc
|
61 |
+
|
62 |
+
class GatedConv2d(nn.Module):
|
63 |
+
def __init__(self,
|
64 |
+
in_channels: int,
|
65 |
+
out_channels: int,
|
66 |
+
kernel_size: int = 3,
|
67 |
+
padding: int = 1,
|
68 |
+
bias: bool = False):
|
69 |
+
super(GatedConv2d, self).__init__()
|
70 |
+
self.gate_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
71 |
+
self.feature_conv = nn.Conv2d(in_channels,
|
72 |
+
out_channels,
|
73 |
+
kernel_size=kernel_size,
|
74 |
+
padding=padding,
|
75 |
+
bias=bias)
|
76 |
+
|
77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
78 |
+
gate = torch.sigmoid(self.gate_conv(x))
|
79 |
+
feature = F.silu(self.feature_conv(x))
|
80 |
+
return gate * feature
|
81 |
+
|
82 |
+
class ResGatedBlock(nn.Module):
|
83 |
+
def __init__(self,
|
84 |
+
in_channels: int,
|
85 |
+
out_channels: int,
|
86 |
+
mid_channels: int | None = None,
|
87 |
+
num_groups: int = 32,
|
88 |
+
residual: bool = True,
|
89 |
+
emb_channels: int | None = None,
|
90 |
+
gated_conv: bool = False):
|
91 |
+
super().__init__()
|
92 |
+
self.residual = residual
|
93 |
+
self.emb_channels = emb_channels
|
94 |
+
if not mid_channels:
|
95 |
+
mid_channels = out_channels
|
96 |
+
|
97 |
+
if gated_conv: conv2d = GatedConv2d
|
98 |
+
else: conv2d = nn.Conv2d
|
99 |
+
|
100 |
+
self.conv1 = conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
|
101 |
+
self.norm1 = GroupNorm(mid_channels, num_groups=num_groups)
|
102 |
+
self.nonlienrity = nn.SiLU()
|
103 |
+
if emb_channels:
|
104 |
+
self.emb_proj = nn.Linear(emb_channels, mid_channels)
|
105 |
+
self.conv2 = conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
|
106 |
+
self.norm2 = GroupNorm(out_channels, num_groups=num_groups)
|
107 |
+
|
108 |
+
if in_channels != out_channels:
|
109 |
+
self.skip = conv2d(in_channels, out_channels, kernel_size=1, padding=0)
|
110 |
+
|
111 |
+
def double_conv(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
|
112 |
+
x = self.conv1(x)
|
113 |
+
x = self.norm1(x)
|
114 |
+
x = self.nonlienrity(x)
|
115 |
+
if emb is not None and self.emb_channels is not None:
|
116 |
+
x = x + self.emb_proj(emb)[:,:,None,None]
|
117 |
+
x = self.conv2(x)
|
118 |
+
return self.norm2(x)
|
119 |
+
|
120 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
|
121 |
+
if self.residual:
|
122 |
+
if hasattr(self, 'skip'):
|
123 |
+
return F.silu(self.skip(x) + self.double_conv(x, emb))
|
124 |
+
return F.silu(x + self.double_conv(x, emb))
|
125 |
+
else:
|
126 |
+
return self.double_conv(x, emb)
|
127 |
+
|
128 |
+
class Downsample(nn.Module):
|
129 |
+
def __init__(self,
|
130 |
+
in_channels: int,
|
131 |
+
out_channels: int,
|
132 |
+
use_conv: bool=True):
|
133 |
+
super().__init__()
|
134 |
+
|
135 |
+
self.use_conv = use_conv
|
136 |
+
if use_conv:
|
137 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
138 |
+
else:
|
139 |
+
assert in_channels == out_channels
|
140 |
+
self.conv = nn.AvgPool2d(kernel_size=2, stride=2)
|
141 |
+
|
142 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
143 |
+
pad = (0, 1, 0, 1)
|
144 |
+
hidden_states = F.pad(x, pad, mode="constant", value=0)
|
145 |
+
return self.conv(hidden_states) if self.use_conv else self.conv(x)
|
146 |
+
|
147 |
+
class Upsample(nn.Module):
|
148 |
+
def __init__(self,
|
149 |
+
in_channels: int,
|
150 |
+
out_channels: int,
|
151 |
+
use_conv: bool=True):
|
152 |
+
super().__init__()
|
153 |
+
|
154 |
+
self.use_conv = use_conv
|
155 |
+
if use_conv:
|
156 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
157 |
+
|
158 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
159 |
+
x = F.interpolate(x,
|
160 |
+
scale_factor = (2, 2) if x.dim() == 4 else (1, 2, 2),
|
161 |
+
mode='nearest')
|
162 |
+
return self.conv(x) if self.use_conv else x
|
163 |
+
|
164 |
+
class FeedForward(nn.Module):
|
165 |
+
def __init__(self,
|
166 |
+
dim: int,
|
167 |
+
emb_channels: int,
|
168 |
+
expansion_rate: int = 4,
|
169 |
+
dropout: float = 0.0):
|
170 |
+
super().__init__()
|
171 |
+
inner_dim = int(dim * expansion_rate)
|
172 |
+
self.norm = AdaLayerNorm(dim, emb_channels)
|
173 |
+
self.net = nn.Sequential(
|
174 |
+
nn.Linear(dim, inner_dim),
|
175 |
+
nn.SiLU(),
|
176 |
+
nn.Dropout(dropout),
|
177 |
+
nn.Linear(inner_dim, dim),
|
178 |
+
nn.Dropout(dropout)
|
179 |
+
)
|
180 |
+
self.__init_weights()
|
181 |
+
|
182 |
+
def __init_weights(self):
|
183 |
+
nn.init.xavier_uniform_(self.net[0].weight)
|
184 |
+
nn.init.xavier_uniform_(self.net[3].weight)
|
185 |
+
|
186 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
|
187 |
+
x, sigma = self.norm(x, emb)
|
188 |
+
return self.net(x) * sigma
|
189 |
+
|
190 |
+
class Attention(nn.Module):
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
dim: int,
|
194 |
+
emb_channels: int = 512,
|
195 |
+
dim_head: int = 32,
|
196 |
+
dropout: float = 0.,
|
197 |
+
window_size: int = 7
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
|
201 |
+
self.heads = dim // dim_head
|
202 |
+
self.scale = dim_head ** -0.5
|
203 |
+
self.norm = AdaLayerNorm(dim, emb_channels)
|
204 |
+
|
205 |
+
self.to_q = nn.Linear(dim, dim, bias = False)
|
206 |
+
self.to_k = nn.Linear(dim, dim, bias = False)
|
207 |
+
self.to_v = nn.Linear(dim, dim, bias = False)
|
208 |
+
|
209 |
+
self.attend = nn.Sequential(
|
210 |
+
nn.Softmax(dim = -1),
|
211 |
+
nn.Dropout(dropout)
|
212 |
+
)
|
213 |
+
self.to_out = nn.Sequential(
|
214 |
+
nn.Linear(dim, dim, bias = False),
|
215 |
+
nn.Dropout(dropout)
|
216 |
+
)
|
217 |
+
|
218 |
+
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
219 |
+
pos = torch.arange(window_size)
|
220 |
+
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
221 |
+
grid = rearrange(grid, 'c i j -> (i j) c')
|
222 |
+
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
|
223 |
+
rel_pos += window_size - 1
|
224 |
+
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
|
225 |
+
|
226 |
+
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
|
227 |
+
|
228 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
|
229 |
+
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
|
230 |
+
|
231 |
+
x, sigma = self.norm(x, emb)
|
232 |
+
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
|
233 |
+
|
234 |
+
q = self.to_q(x)
|
235 |
+
k = self.to_k(x)
|
236 |
+
v = self.to_v(x)
|
237 |
+
|
238 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) # split heads
|
239 |
+
|
240 |
+
q = q * self.scale
|
241 |
+
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) # sim
|
242 |
+
bias = self.rel_pos_bias(self.rel_pos_indices)
|
243 |
+
sim = sim + rearrange(bias, 'i j h -> h i j')# add positional bias
|
244 |
+
attn = self.attend(sim) # attention
|
245 |
+
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) # aggregate
|
246 |
+
|
247 |
+
out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width) # merge heads
|
248 |
+
out = self.to_out(out) # combine heads out
|
249 |
+
return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width) * sigma
|
250 |
+
|
251 |
+
class MaxViTBlock(nn.Module):
|
252 |
+
def __init__(
|
253 |
+
self,
|
254 |
+
channels: int,
|
255 |
+
emb_channels: int = 512,
|
256 |
+
heads: int = 1,
|
257 |
+
window_size: int = 8,
|
258 |
+
window_attn: bool = True,
|
259 |
+
grid_attn: bool = True,
|
260 |
+
expansion_rate: int = 4,
|
261 |
+
dropout: float = 0.0,
|
262 |
+
):
|
263 |
+
super(MaxViTBlock, self).__init__()
|
264 |
+
dim_head = channels // heads
|
265 |
+
layer_dim = dim_head * heads
|
266 |
+
w = window_size
|
267 |
+
|
268 |
+
self.window_attn = window_attn
|
269 |
+
self.grid_attn = grid_attn
|
270 |
+
|
271 |
+
if window_attn:
|
272 |
+
self.wind_rearrange_forward = Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w) # block-like attention
|
273 |
+
self.wind_attn = Attention(
|
274 |
+
dim = layer_dim,
|
275 |
+
emb_channels = emb_channels,
|
276 |
+
dim_head = dim_head,
|
277 |
+
dropout = dropout,
|
278 |
+
window_size = w
|
279 |
+
)
|
280 |
+
|
281 |
+
self.wind_ff = FeedForward(dim = layer_dim,
|
282 |
+
emb_channels = emb_channels,
|
283 |
+
expansion_rate = expansion_rate,
|
284 |
+
dropout = dropout)
|
285 |
+
self.wind_rearrange_backward = Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)')
|
286 |
+
|
287 |
+
if grid_attn:
|
288 |
+
self.grid_rearrange_forward = Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w) # grid-like attention
|
289 |
+
self.grid_attn = Attention(
|
290 |
+
dim = layer_dim,
|
291 |
+
emb_channels = emb_channels,
|
292 |
+
dim_head = dim_head,
|
293 |
+
dropout = dropout,
|
294 |
+
window_size = w
|
295 |
+
)
|
296 |
+
self.grid_ff = FeedForward(dim = layer_dim,
|
297 |
+
emb_channels = emb_channels,
|
298 |
+
expansion_rate = expansion_rate,
|
299 |
+
dropout = dropout)
|
300 |
+
self.grid_rearrange_backward = Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)')
|
301 |
+
|
302 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
|
303 |
+
if self.window_attn:
|
304 |
+
x = self.wind_rearrange_forward(x)
|
305 |
+
x = x + self.wind_attn(x, emb = emb)
|
306 |
+
x = x + self.wind_ff(x, emb = emb)
|
307 |
+
x = self.wind_rearrange_backward(x)
|
308 |
+
if self.grid_attn:
|
309 |
+
x = self.grid_rearrange_forward(x)
|
310 |
+
x = x + self.grid_attn(x, emb = emb)
|
311 |
+
x = x + self.grid_ff(x, emb = emb)
|
312 |
+
x = self.grid_rearrange_backward(x)
|
313 |
+
return x
|
modules/cupy_module/correlation.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import cupy
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import torch
|
7 |
+
|
8 |
+
# Code taken from https://github.com/sniklaus/softmax-splatting/blob/master/correlation/correlation.py
|
9 |
+
|
10 |
+
kernel_Correlation_rearrange = '''
|
11 |
+
extern "C" __global__ void kernel_Correlation_rearrange(
|
12 |
+
const int n,
|
13 |
+
const float* input,
|
14 |
+
float* output
|
15 |
+
) {
|
16 |
+
int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
|
17 |
+
|
18 |
+
if (intIndex >= n) {
|
19 |
+
return;
|
20 |
+
}
|
21 |
+
|
22 |
+
int intSample = blockIdx.z;
|
23 |
+
int intChannel = blockIdx.y;
|
24 |
+
|
25 |
+
float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
|
26 |
+
|
27 |
+
__syncthreads();
|
28 |
+
|
29 |
+
int intPaddedY = (intIndex / SIZE_3(input)) + 4;
|
30 |
+
int intPaddedX = (intIndex % SIZE_3(input)) + 4;
|
31 |
+
int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
|
32 |
+
|
33 |
+
output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
|
34 |
+
}
|
35 |
+
'''
|
36 |
+
|
37 |
+
kernel_Correlation_updateOutput = '''
|
38 |
+
extern "C" __global__ void kernel_Correlation_updateOutput(
|
39 |
+
const int n,
|
40 |
+
const float* rbot0,
|
41 |
+
const float* rbot1,
|
42 |
+
float* top
|
43 |
+
) {
|
44 |
+
extern __shared__ char patch_data_char[];
|
45 |
+
|
46 |
+
float *patch_data = (float *)patch_data_char;
|
47 |
+
|
48 |
+
// First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
|
49 |
+
int x1 = blockIdx.x + 4;
|
50 |
+
int y1 = blockIdx.y + 4;
|
51 |
+
int item = blockIdx.z;
|
52 |
+
int ch_off = threadIdx.x;
|
53 |
+
|
54 |
+
// Load 3D patch into shared shared memory
|
55 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
56 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
57 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
58 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
59 |
+
int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
|
60 |
+
int idxPatchData = ji_off + ch;
|
61 |
+
patch_data[idxPatchData] = rbot0[idx1];
|
62 |
+
}
|
63 |
+
}
|
64 |
+
}
|
65 |
+
|
66 |
+
__syncthreads();
|
67 |
+
|
68 |
+
__shared__ float sum[32];
|
69 |
+
|
70 |
+
// Compute correlation
|
71 |
+
for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
|
72 |
+
sum[ch_off] = 0;
|
73 |
+
|
74 |
+
int s2o = top_channel % 9 - 4;
|
75 |
+
int s2p = top_channel / 9 - 4;
|
76 |
+
|
77 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
78 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
79 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
80 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
81 |
+
int x2 = x1 + s2o;
|
82 |
+
int y2 = y1 + s2p;
|
83 |
+
|
84 |
+
int idxPatchData = ji_off + ch;
|
85 |
+
int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
|
86 |
+
|
87 |
+
sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
|
88 |
+
}
|
89 |
+
}
|
90 |
+
}
|
91 |
+
|
92 |
+
__syncthreads();
|
93 |
+
|
94 |
+
if (ch_off == 0) {
|
95 |
+
float total_sum = 0;
|
96 |
+
for (int idx = 0; idx < 32; idx++) {
|
97 |
+
total_sum += sum[idx];
|
98 |
+
}
|
99 |
+
const int sumelems = SIZE_3(rbot0);
|
100 |
+
const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
|
101 |
+
top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
|
102 |
+
}
|
103 |
+
}
|
104 |
+
}
|
105 |
+
'''
|
106 |
+
|
107 |
+
kernel_Correlation_updateGradOne = '''
|
108 |
+
#define ROUND_OFF 50000
|
109 |
+
|
110 |
+
extern "C" __global__ void kernel_Correlation_updateGradOne(
|
111 |
+
const int n,
|
112 |
+
const int intSample,
|
113 |
+
const float* rbot0,
|
114 |
+
const float* rbot1,
|
115 |
+
const float* gradOutput,
|
116 |
+
float* gradOne,
|
117 |
+
float* gradTwo
|
118 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
119 |
+
int n = intIndex % SIZE_1(gradOne); // channels
|
120 |
+
int l = (intIndex / SIZE_1(gradOne)) % SIZE_3(gradOne) + 4; // w-pos
|
121 |
+
int m = (intIndex / SIZE_1(gradOne) / SIZE_3(gradOne)) % SIZE_2(gradOne) + 4; // h-pos
|
122 |
+
|
123 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
124 |
+
// We use a large offset, for the inner part not to become negative.
|
125 |
+
const int round_off = ROUND_OFF;
|
126 |
+
const int round_off_s1 = round_off;
|
127 |
+
|
128 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
129 |
+
int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
130 |
+
int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
131 |
+
|
132 |
+
// Same here:
|
133 |
+
int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
|
134 |
+
int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
|
135 |
+
|
136 |
+
float sum = 0;
|
137 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
138 |
+
xmin = max(0,xmin);
|
139 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
140 |
+
|
141 |
+
ymin = max(0,ymin);
|
142 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
143 |
+
|
144 |
+
for (int p = -4; p <= 4; p++) {
|
145 |
+
for (int o = -4; o <= 4; o++) {
|
146 |
+
// Get rbot1 data:
|
147 |
+
int s2o = o;
|
148 |
+
int s2p = p;
|
149 |
+
int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
|
150 |
+
float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
|
151 |
+
|
152 |
+
// Index offset for gradOutput in following loops:
|
153 |
+
int op = (p+4) * 9 + (o+4); // index[o,p]
|
154 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
155 |
+
|
156 |
+
for (int y = ymin; y <= ymax; y++) {
|
157 |
+
for (int x = xmin; x <= xmax; x++) {
|
158 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
159 |
+
sum += gradOutput[idxgradOutput] * bot1tmp;
|
160 |
+
}
|
161 |
+
}
|
162 |
+
}
|
163 |
+
}
|
164 |
+
}
|
165 |
+
const int sumelems = SIZE_1(gradOne);
|
166 |
+
const int bot0index = ((n * SIZE_2(gradOne)) + (m-4)) * SIZE_3(gradOne) + (l-4);
|
167 |
+
gradOne[bot0index + intSample*SIZE_1(gradOne)*SIZE_2(gradOne)*SIZE_3(gradOne)] = sum / (float)sumelems;
|
168 |
+
} }
|
169 |
+
'''
|
170 |
+
|
171 |
+
kernel_Correlation_updateGradTwo = '''
|
172 |
+
#define ROUND_OFF 50000
|
173 |
+
|
174 |
+
extern "C" __global__ void kernel_Correlation_updateGradTwo(
|
175 |
+
const int n,
|
176 |
+
const int intSample,
|
177 |
+
const float* rbot0,
|
178 |
+
const float* rbot1,
|
179 |
+
const float* gradOutput,
|
180 |
+
float* gradOne,
|
181 |
+
float* gradTwo
|
182 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
183 |
+
int n = intIndex % SIZE_1(gradTwo); // channels
|
184 |
+
int l = (intIndex / SIZE_1(gradTwo)) % SIZE_3(gradTwo) + 4; // w-pos
|
185 |
+
int m = (intIndex / SIZE_1(gradTwo) / SIZE_3(gradTwo)) % SIZE_2(gradTwo) + 4; // h-pos
|
186 |
+
|
187 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
188 |
+
// We use a large offset, for the inner part not to become negative.
|
189 |
+
const int round_off = ROUND_OFF;
|
190 |
+
const int round_off_s1 = round_off;
|
191 |
+
|
192 |
+
float sum = 0;
|
193 |
+
for (int p = -4; p <= 4; p++) {
|
194 |
+
for (int o = -4; o <= 4; o++) {
|
195 |
+
int s2o = o;
|
196 |
+
int s2p = p;
|
197 |
+
|
198 |
+
//Get X,Y ranges and clamp
|
199 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
200 |
+
int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
|
201 |
+
int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
|
202 |
+
|
203 |
+
// Same here:
|
204 |
+
int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
|
205 |
+
int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
|
206 |
+
|
207 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
208 |
+
xmin = max(0,xmin);
|
209 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
210 |
+
|
211 |
+
ymin = max(0,ymin);
|
212 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
213 |
+
|
214 |
+
// Get rbot0 data:
|
215 |
+
int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
|
216 |
+
float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
|
217 |
+
|
218 |
+
// Index offset for gradOutput in following loops:
|
219 |
+
int op = (p+4) * 9 + (o+4); // index[o,p]
|
220 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
221 |
+
|
222 |
+
for (int y = ymin; y <= ymax; y++) {
|
223 |
+
for (int x = xmin; x <= xmax; x++) {
|
224 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
225 |
+
sum += gradOutput[idxgradOutput] * bot0tmp;
|
226 |
+
}
|
227 |
+
}
|
228 |
+
}
|
229 |
+
}
|
230 |
+
}
|
231 |
+
const int sumelems = SIZE_1(gradTwo);
|
232 |
+
const int bot1index = ((n * SIZE_2(gradTwo)) + (m-4)) * SIZE_3(gradTwo) + (l-4);
|
233 |
+
gradTwo[bot1index + intSample*SIZE_1(gradTwo)*SIZE_2(gradTwo)*SIZE_3(gradTwo)] = sum / (float)sumelems;
|
234 |
+
} }
|
235 |
+
'''
|
236 |
+
|
237 |
+
def cupy_kernel(strFunction, objVariables):
|
238 |
+
strKernel = globals()[strFunction]
|
239 |
+
|
240 |
+
while True:
|
241 |
+
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
242 |
+
|
243 |
+
if objMatch is None:
|
244 |
+
break
|
245 |
+
# end
|
246 |
+
|
247 |
+
intArg = int(objMatch.group(2))
|
248 |
+
|
249 |
+
strTensor = objMatch.group(4)
|
250 |
+
intSizes = objVariables[strTensor].size()
|
251 |
+
|
252 |
+
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
|
253 |
+
|
254 |
+
while True:
|
255 |
+
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
|
256 |
+
|
257 |
+
if objMatch is None:
|
258 |
+
break
|
259 |
+
# end
|
260 |
+
|
261 |
+
intArgs = int(objMatch.group(2))
|
262 |
+
strArgs = objMatch.group(4).split(',')
|
263 |
+
|
264 |
+
strTensor = strArgs[0]
|
265 |
+
intStrides = objVariables[strTensor].stride()
|
266 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
|
267 |
+
|
268 |
+
strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str('+').join(strIndex) + ']')
|
269 |
+
# end
|
270 |
+
|
271 |
+
return strKernel
|
272 |
+
# end
|
273 |
+
|
274 |
+
@cupy.memoize(for_each_device=True)
|
275 |
+
def cupy_launch(strFunction, strKernel):
|
276 |
+
if 'CUDA_HOME' not in os.environ:
|
277 |
+
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
|
278 |
+
# end
|
279 |
+
|
280 |
+
return cupy.RawKernel(strKernel, strFunction, tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include']))
|
281 |
+
# end
|
282 |
+
|
283 |
+
class _FunctionCorrelation(torch.autograd.Function):
|
284 |
+
@staticmethod
|
285 |
+
def forward(self, one, two):
|
286 |
+
rbot0 = one.new_zeros([ one.shape[0], one.shape[2] + 8, one.shape[3] + 8, one.shape[1] ])
|
287 |
+
rbot1 = one.new_zeros([ one.shape[0], one.shape[2] + 8, one.shape[3] + 8, one.shape[1] ])
|
288 |
+
|
289 |
+
one = one.contiguous(); assert(one.is_cuda == True)
|
290 |
+
two = two.contiguous(); assert(two.is_cuda == True)
|
291 |
+
|
292 |
+
output = one.new_zeros([ one.shape[0], 81, one.shape[2], one.shape[3] ])
|
293 |
+
|
294 |
+
if one.is_cuda == True:
|
295 |
+
n = one.shape[2] * one.shape[3]
|
296 |
+
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
|
297 |
+
'input': one,
|
298 |
+
'output': rbot0
|
299 |
+
}))(
|
300 |
+
grid=tuple([ int((n + 16 - 1) / 16), one.shape[1], one.shape[0] ]),
|
301 |
+
block=tuple([ 16, 1, 1 ]),
|
302 |
+
args=[ cupy.int32(n), one.data_ptr(), rbot0.data_ptr() ]
|
303 |
+
)
|
304 |
+
|
305 |
+
n = two.shape[2] * two.shape[3]
|
306 |
+
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
|
307 |
+
'input': two,
|
308 |
+
'output': rbot1
|
309 |
+
}))(
|
310 |
+
grid=tuple([ int((n + 16 - 1) / 16), two.shape[1], two.shape[0] ]),
|
311 |
+
block=tuple([ 16, 1, 1 ]),
|
312 |
+
args=[ cupy.int32(n), two.data_ptr(), rbot1.data_ptr() ]
|
313 |
+
)
|
314 |
+
|
315 |
+
n = output.shape[1] * output.shape[2] * output.shape[3]
|
316 |
+
cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
|
317 |
+
'rbot0': rbot0,
|
318 |
+
'rbot1': rbot1,
|
319 |
+
'top': output
|
320 |
+
}))(
|
321 |
+
grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
|
322 |
+
block=tuple([ 32, 1, 1 ]),
|
323 |
+
shared_mem=one.shape[1] * 4,
|
324 |
+
args=[ cupy.int32(n), rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
|
325 |
+
)
|
326 |
+
|
327 |
+
elif one.is_cuda == False:
|
328 |
+
raise NotImplementedError()
|
329 |
+
|
330 |
+
# end
|
331 |
+
|
332 |
+
self.save_for_backward(one, two, rbot0, rbot1)
|
333 |
+
|
334 |
+
return output
|
335 |
+
# end
|
336 |
+
|
337 |
+
@staticmethod
|
338 |
+
def backward(self, gradOutput):
|
339 |
+
one, two, rbot0, rbot1 = self.saved_tensors
|
340 |
+
|
341 |
+
gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)
|
342 |
+
|
343 |
+
gradOne = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[0] == True else None
|
344 |
+
gradTwo = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[1] == True else None
|
345 |
+
|
346 |
+
if one.is_cuda == True:
|
347 |
+
if gradOne is not None:
|
348 |
+
for intSample in range(one.shape[0]):
|
349 |
+
n = one.shape[1] * one.shape[2] * one.shape[3]
|
350 |
+
cupy_launch('kernel_Correlation_updateGradOne', cupy_kernel('kernel_Correlation_updateGradOne', {
|
351 |
+
'rbot0': rbot0,
|
352 |
+
'rbot1': rbot1,
|
353 |
+
'gradOutput': gradOutput,
|
354 |
+
'gradOne': gradOne,
|
355 |
+
'gradTwo': None
|
356 |
+
}))(
|
357 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
358 |
+
block=tuple([ 512, 1, 1 ]),
|
359 |
+
args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradOne.data_ptr(), None ]
|
360 |
+
)
|
361 |
+
# end
|
362 |
+
# end
|
363 |
+
|
364 |
+
if gradTwo is not None:
|
365 |
+
for intSample in range(one.shape[0]):
|
366 |
+
n = one.shape[1] * one.shape[2] * one.shape[3]
|
367 |
+
cupy_launch('kernel_Correlation_updateGradTwo', cupy_kernel('kernel_Correlation_updateGradTwo', {
|
368 |
+
'rbot0': rbot0,
|
369 |
+
'rbot1': rbot1,
|
370 |
+
'gradOutput': gradOutput,
|
371 |
+
'gradOne': None,
|
372 |
+
'gradTwo': gradTwo
|
373 |
+
}))(
|
374 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
375 |
+
block=tuple([ 512, 1, 1 ]),
|
376 |
+
args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradTwo.data_ptr() ]
|
377 |
+
)
|
378 |
+
# end
|
379 |
+
# end
|
380 |
+
|
381 |
+
elif one.is_cuda == False:
|
382 |
+
raise NotImplementedError()
|
383 |
+
|
384 |
+
# end
|
385 |
+
|
386 |
+
return gradOne, gradTwo
|
387 |
+
# end
|
388 |
+
# end
|
389 |
+
|
390 |
+
def FunctionCorrelation(tenOne, tenTwo):
|
391 |
+
return _FunctionCorrelation.apply(tenOne, tenTwo)
|
392 |
+
# end
|
393 |
+
|
394 |
+
class ModuleCorrelation(torch.nn.Module):
|
395 |
+
def __init__(self):
|
396 |
+
super().__init__()
|
397 |
+
# end
|
398 |
+
|
399 |
+
def forward(self, tenOne, tenTwo):
|
400 |
+
return _FunctionCorrelation.apply(tenOne, tenTwo)
|
401 |
+
# end
|
402 |
+
# end
|
modules/cupy_module/cupy_utils.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cupy
|
2 |
+
|
3 |
+
#@cupy.memoize(for_each_device=True)
|
4 |
+
def cupy_launch(strFunction, strKernel):
|
5 |
+
# return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
|
6 |
+
return cupy.RawKernel(strKernel, strFunction)
|
7 |
+
# end
|
modules/cupy_module/nedt.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import cupy
|
3 |
+
import kornia
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from modules.cupy_module.cupy_utils import cupy_launch
|
7 |
+
# Code taken from https://github.com/ShuhongChen/eisai-anime-interpolator
|
8 |
+
|
9 |
+
_batch_edt_kernel = ('kernel_dt', '''
|
10 |
+
extern "C" __global__ void kernel_dt(
|
11 |
+
const int bs,
|
12 |
+
const int h,
|
13 |
+
const int w,
|
14 |
+
const float diam2,
|
15 |
+
float* data,
|
16 |
+
float* output
|
17 |
+
) {
|
18 |
+
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
19 |
+
if (idx >= bs*h*w) {
|
20 |
+
return;
|
21 |
+
}
|
22 |
+
int pb = idx / (h*w);
|
23 |
+
int pi = (idx - h*w*pb) / w;
|
24 |
+
int pj = (idx - h*w*pb - w*pi);
|
25 |
+
|
26 |
+
float cost;
|
27 |
+
float mincost = diam2;
|
28 |
+
for (int j = 0; j < w; j++) {
|
29 |
+
cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j);
|
30 |
+
if (cost < mincost) {
|
31 |
+
mincost = cost;
|
32 |
+
}
|
33 |
+
}
|
34 |
+
output[idx] = mincost;
|
35 |
+
return;
|
36 |
+
}
|
37 |
+
''')
|
38 |
+
|
39 |
+
class NEDT(nn.Module):
|
40 |
+
def __init__(self):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
def batch_edt(self, img, block=1024):
|
44 |
+
# must initialize cuda/cupy after forking
|
45 |
+
_batch_edt = cupy_launch(*_batch_edt_kernel)
|
46 |
+
|
47 |
+
# bookkeeppingg
|
48 |
+
if len(img.shape)==4:
|
49 |
+
assert img.shape[1]==1
|
50 |
+
img = img.squeeze(1)
|
51 |
+
expand = True
|
52 |
+
else:
|
53 |
+
expand = False
|
54 |
+
bs,h,w = img.shape
|
55 |
+
diam2 = h**2 + w**2
|
56 |
+
odtype = img.dtype
|
57 |
+
grid = (img.nelement()+block-1) // block
|
58 |
+
|
59 |
+
# first pass, y-axis
|
60 |
+
data = ((1-img.type(torch.float32)) * diam2).contiguous()
|
61 |
+
intermed = torch.zeros_like(data)
|
62 |
+
_batch_edt(
|
63 |
+
grid=(grid, 1, 1),
|
64 |
+
block=(block, 1, 1), # < 1024
|
65 |
+
args=[
|
66 |
+
cupy.int32(bs),
|
67 |
+
cupy.int32(h),
|
68 |
+
cupy.int32(w),
|
69 |
+
cupy.float32(diam2),
|
70 |
+
data.data_ptr(),
|
71 |
+
intermed.data_ptr(),
|
72 |
+
],
|
73 |
+
)
|
74 |
+
|
75 |
+
# second pass, x-axis
|
76 |
+
intermed = intermed.permute(0,2,1).contiguous()
|
77 |
+
out = torch.zeros_like(intermed)
|
78 |
+
_batch_edt(
|
79 |
+
grid=(grid, 1, 1),
|
80 |
+
block=(block, 1, 1),
|
81 |
+
args=[
|
82 |
+
cupy.int32(bs),
|
83 |
+
cupy.int32(w),
|
84 |
+
cupy.int32(h),
|
85 |
+
cupy.float32(diam2),
|
86 |
+
intermed.data_ptr(),
|
87 |
+
out.data_ptr(),
|
88 |
+
],
|
89 |
+
)
|
90 |
+
ans = out.permute(0,2,1).sqrt()
|
91 |
+
ans = ans.type(odtype) if odtype!=ans.dtype else ans
|
92 |
+
|
93 |
+
if expand:
|
94 |
+
ans = ans.unsqueeze(1)
|
95 |
+
return ans
|
96 |
+
|
97 |
+
def batch_dog(self, img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
|
98 |
+
# to grayscale if needed
|
99 |
+
bs,ch,h,w = img.shape
|
100 |
+
if ch in [3,4]:
|
101 |
+
img = kornia.color.rgb_to_grayscale(img[:,:3])
|
102 |
+
else:
|
103 |
+
assert ch==1
|
104 |
+
|
105 |
+
# calculate dog
|
106 |
+
kern0 = max(2*int(sigma*kernel_factor)+1, 3)
|
107 |
+
kern1 = max(2*int(sigma*k*kernel_factor)+1, 3)
|
108 |
+
g0 = kornia.filters.gaussian_blur2d(
|
109 |
+
img, (kern0,kern0), (sigma,sigma), border_type='replicate',
|
110 |
+
)
|
111 |
+
g1 = kornia.filters.gaussian_blur2d(
|
112 |
+
img, (kern1,kern1), (sigma*k,sigma*k), border_type='replicate',
|
113 |
+
)
|
114 |
+
out = 0.5 + t*(g1 - g0) - epsilon
|
115 |
+
out = out.clip(0,1) if clip else out
|
116 |
+
return out
|
117 |
+
|
118 |
+
def forward(
|
119 |
+
self, img, t=2.0, sigma_factor=1/540,
|
120 |
+
k=1.6, epsilon=0.01,
|
121 |
+
kernel_factor=4, exp_factor=540/15
|
122 |
+
):
|
123 |
+
dog = self.batch_dog(
|
124 |
+
img, t=t, sigma=img.shape[-2]*sigma_factor, k=k,
|
125 |
+
epsilon=epsilon, kernel_factor=kernel_factor, clip=False,
|
126 |
+
)
|
127 |
+
edt = self.batch_edt((dog > 0.5).float())
|
128 |
+
out = 1 - (-edt*exp_factor / max(edt.shape[-2:])).exp()
|
129 |
+
return out
|
modules/cupy_module/softsplat.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import re
|
3 |
+
import cupy
|
4 |
+
|
5 |
+
from modules.cupy_module.cupy_utils import cupy_launch
|
6 |
+
|
7 |
+
# Code from https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py
|
8 |
+
|
9 |
+
kernel_Softsplat_updateOutput = '''
|
10 |
+
extern "C" __global__ void kernel_Softsplat_updateOutput(
|
11 |
+
const int n,
|
12 |
+
const float* input,
|
13 |
+
const float* flow,
|
14 |
+
float* output
|
15 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
16 |
+
const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
|
17 |
+
const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output);
|
18 |
+
const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output);
|
19 |
+
const int intX = ( intIndex ) % SIZE_3(output);
|
20 |
+
|
21 |
+
float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
|
22 |
+
float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);
|
23 |
+
|
24 |
+
int intNorthwestX = (int) (floor(fltOutputX));
|
25 |
+
int intNorthwestY = (int) (floor(fltOutputY));
|
26 |
+
int intNortheastX = intNorthwestX + 1;
|
27 |
+
int intNortheastY = intNorthwestY;
|
28 |
+
int intSouthwestX = intNorthwestX;
|
29 |
+
int intSouthwestY = intNorthwestY + 1;
|
30 |
+
int intSoutheastX = intNorthwestX + 1;
|
31 |
+
int intSoutheastY = intNorthwestY + 1;
|
32 |
+
|
33 |
+
float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
|
34 |
+
float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
|
35 |
+
float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
|
36 |
+
float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));
|
37 |
+
|
38 |
+
if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) {
|
39 |
+
atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest);
|
40 |
+
}
|
41 |
+
|
42 |
+
if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) {
|
43 |
+
atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast);
|
44 |
+
}
|
45 |
+
|
46 |
+
if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) {
|
47 |
+
atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest);
|
48 |
+
}
|
49 |
+
|
50 |
+
if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) {
|
51 |
+
atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast);
|
52 |
+
}
|
53 |
+
} }
|
54 |
+
'''
|
55 |
+
|
56 |
+
kernel_Softsplat_updateGradInput = '''
|
57 |
+
extern "C" __global__ void kernel_Softsplat_updateGradInput(
|
58 |
+
const int n,
|
59 |
+
const float* input,
|
60 |
+
const float* flow,
|
61 |
+
const float* gradOutput,
|
62 |
+
float* gradInput,
|
63 |
+
float* gradFlow
|
64 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
65 |
+
const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput);
|
66 |
+
const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput);
|
67 |
+
const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput);
|
68 |
+
const int intX = ( intIndex ) % SIZE_3(gradInput);
|
69 |
+
|
70 |
+
float fltGradInput = 0.0;
|
71 |
+
|
72 |
+
float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
|
73 |
+
float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);
|
74 |
+
|
75 |
+
int intNorthwestX = (int) (floor(fltOutputX));
|
76 |
+
int intNorthwestY = (int) (floor(fltOutputY));
|
77 |
+
int intNortheastX = intNorthwestX + 1;
|
78 |
+
int intNortheastY = intNorthwestY;
|
79 |
+
int intSouthwestX = intNorthwestX;
|
80 |
+
int intSouthwestY = intNorthwestY + 1;
|
81 |
+
int intSoutheastX = intNorthwestX + 1;
|
82 |
+
int intSoutheastY = intNorthwestY + 1;
|
83 |
+
|
84 |
+
float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
|
85 |
+
float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
|
86 |
+
float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
|
87 |
+
float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));
|
88 |
+
|
89 |
+
if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
|
90 |
+
fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
|
91 |
+
}
|
92 |
+
|
93 |
+
if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
|
94 |
+
fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
|
95 |
+
}
|
96 |
+
|
97 |
+
if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
|
98 |
+
fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
|
99 |
+
}
|
100 |
+
|
101 |
+
if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
|
102 |
+
fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
|
103 |
+
}
|
104 |
+
|
105 |
+
gradInput[intIndex] = fltGradInput;
|
106 |
+
} }
|
107 |
+
'''
|
108 |
+
|
109 |
+
kernel_Softsplat_updateGradFlow = '''
|
110 |
+
extern "C" __global__ void kernel_Softsplat_updateGradFlow(
|
111 |
+
const int n,
|
112 |
+
const float* input,
|
113 |
+
const float* flow,
|
114 |
+
const float* gradOutput,
|
115 |
+
float* gradInput,
|
116 |
+
float* gradFlow
|
117 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
118 |
+
float fltGradFlow = 0.0;
|
119 |
+
|
120 |
+
const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow);
|
121 |
+
const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow);
|
122 |
+
const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow);
|
123 |
+
const int intX = ( intIndex ) % SIZE_3(gradFlow);
|
124 |
+
|
125 |
+
float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
|
126 |
+
float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);
|
127 |
+
|
128 |
+
int intNorthwestX = (int) (floor(fltOutputX));
|
129 |
+
int intNorthwestY = (int) (floor(fltOutputY));
|
130 |
+
int intNortheastX = intNorthwestX + 1;
|
131 |
+
int intNortheastY = intNorthwestY;
|
132 |
+
int intSouthwestX = intNorthwestX;
|
133 |
+
int intSouthwestY = intNorthwestY + 1;
|
134 |
+
int intSoutheastX = intNorthwestX + 1;
|
135 |
+
int intSoutheastY = intNorthwestY + 1;
|
136 |
+
|
137 |
+
float fltNorthwest = 0.0;
|
138 |
+
float fltNortheast = 0.0;
|
139 |
+
float fltSouthwest = 0.0;
|
140 |
+
float fltSoutheast = 0.0;
|
141 |
+
|
142 |
+
if (intC == 0) {
|
143 |
+
fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY);
|
144 |
+
fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY);
|
145 |
+
fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY));
|
146 |
+
fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY));
|
147 |
+
|
148 |
+
} else if (intC == 1) {
|
149 |
+
fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0));
|
150 |
+
fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0));
|
151 |
+
fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0));
|
152 |
+
fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0));
|
153 |
+
|
154 |
+
}
|
155 |
+
|
156 |
+
for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) {
|
157 |
+
float fltInput = VALUE_4(input, intN, intChannel, intY, intX);
|
158 |
+
|
159 |
+
if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
|
160 |
+
fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest;
|
161 |
+
}
|
162 |
+
|
163 |
+
if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
|
164 |
+
fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast;
|
165 |
+
}
|
166 |
+
|
167 |
+
if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
|
168 |
+
fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest;
|
169 |
+
}
|
170 |
+
|
171 |
+
if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
|
172 |
+
fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast;
|
173 |
+
}
|
174 |
+
}
|
175 |
+
|
176 |
+
gradFlow[intIndex] = fltGradFlow;
|
177 |
+
} }
|
178 |
+
'''
|
179 |
+
|
180 |
+
def cupy_kernel(strFunction, objVariables):
|
181 |
+
strKernel = globals()[strFunction]
|
182 |
+
|
183 |
+
while True:
|
184 |
+
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
185 |
+
|
186 |
+
if objMatch is None:
|
187 |
+
break
|
188 |
+
# end
|
189 |
+
|
190 |
+
intArg = int(objMatch.group(2))
|
191 |
+
|
192 |
+
strTensor = objMatch.group(4)
|
193 |
+
intSizes = objVariables[strTensor].size()
|
194 |
+
|
195 |
+
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
|
196 |
+
# end
|
197 |
+
|
198 |
+
while True:
|
199 |
+
objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)
|
200 |
+
|
201 |
+
if objMatch is None:
|
202 |
+
break
|
203 |
+
# end
|
204 |
+
|
205 |
+
intArgs = int(objMatch.group(2))
|
206 |
+
strArgs = objMatch.group(4).split(',')
|
207 |
+
|
208 |
+
strTensor = strArgs[0]
|
209 |
+
intStrides = objVariables[strTensor].stride()
|
210 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
|
211 |
+
|
212 |
+
strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')')
|
213 |
+
# end
|
214 |
+
|
215 |
+
while True:
|
216 |
+
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
|
217 |
+
|
218 |
+
if objMatch is None:
|
219 |
+
break
|
220 |
+
# end
|
221 |
+
|
222 |
+
intArgs = int(objMatch.group(2))
|
223 |
+
strArgs = objMatch.group(4).split(',')
|
224 |
+
|
225 |
+
strTensor = strArgs[0]
|
226 |
+
intStrides = objVariables[strTensor].stride()
|
227 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
|
228 |
+
|
229 |
+
strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
|
230 |
+
# end
|
231 |
+
|
232 |
+
return strKernel
|
233 |
+
# end
|
234 |
+
|
235 |
+
class _FunctionSoftsplat(torch.autograd.Function):
|
236 |
+
@staticmethod
|
237 |
+
def forward(self, input, flow):
|
238 |
+
intSamples = input.shape[0]
|
239 |
+
intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3]
|
240 |
+
intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3]
|
241 |
+
|
242 |
+
assert(intFlowDepth == 2)
|
243 |
+
assert(intInputHeight == intFlowHeight)
|
244 |
+
assert(intInputWidth == intFlowWidth)
|
245 |
+
|
246 |
+
input = input.contiguous(); assert(input.is_cuda == True)
|
247 |
+
flow = flow.contiguous(); assert(flow.is_cuda == True)
|
248 |
+
|
249 |
+
output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])
|
250 |
+
|
251 |
+
if input.is_cuda == True:
|
252 |
+
n = output.nelement()
|
253 |
+
cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', {
|
254 |
+
'input': input,
|
255 |
+
'flow': flow,
|
256 |
+
'output': output
|
257 |
+
}))(
|
258 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
259 |
+
block=tuple([ 512, 1, 1 ]),
|
260 |
+
args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), output.data_ptr() ]
|
261 |
+
)
|
262 |
+
|
263 |
+
elif input.is_cuda == False:
|
264 |
+
raise NotImplementedError()
|
265 |
+
|
266 |
+
# end
|
267 |
+
|
268 |
+
self.save_for_backward(input, flow)
|
269 |
+
|
270 |
+
return output
|
271 |
+
# end
|
272 |
+
|
273 |
+
@staticmethod
|
274 |
+
def backward(self, gradOutput):
|
275 |
+
input, flow = self.saved_tensors
|
276 |
+
|
277 |
+
intSamples = input.shape[0]
|
278 |
+
intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3]
|
279 |
+
intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3]
|
280 |
+
|
281 |
+
assert(intFlowDepth == 2)
|
282 |
+
assert(intInputHeight == intFlowHeight)
|
283 |
+
assert(intInputWidth == intFlowWidth)
|
284 |
+
|
285 |
+
gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)
|
286 |
+
|
287 |
+
gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) if self.needs_input_grad[0] == True else None
|
288 |
+
gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ]) if self.needs_input_grad[1] == True else None
|
289 |
+
|
290 |
+
if input.is_cuda == True:
|
291 |
+
if gradInput is not None:
|
292 |
+
n = gradInput.nelement()
|
293 |
+
cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', {
|
294 |
+
'input': input,
|
295 |
+
'flow': flow,
|
296 |
+
'gradOutput': gradOutput,
|
297 |
+
'gradInput': gradInput,
|
298 |
+
'gradFlow': gradFlow
|
299 |
+
}))(
|
300 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
301 |
+
block=tuple([ 512, 1, 1 ]),
|
302 |
+
args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ]
|
303 |
+
)
|
304 |
+
# end
|
305 |
+
|
306 |
+
if gradFlow is not None:
|
307 |
+
n = gradFlow.nelement()
|
308 |
+
cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', {
|
309 |
+
'input': input,
|
310 |
+
'flow': flow,
|
311 |
+
'gradOutput': gradOutput,
|
312 |
+
'gradInput': gradInput,
|
313 |
+
'gradFlow': gradFlow
|
314 |
+
}))(
|
315 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
316 |
+
block=tuple([ 512, 1, 1 ]),
|
317 |
+
args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ]
|
318 |
+
)
|
319 |
+
# end
|
320 |
+
|
321 |
+
elif input.is_cuda == False:
|
322 |
+
raise NotImplementedError()
|
323 |
+
|
324 |
+
# end
|
325 |
+
|
326 |
+
return gradInput, gradFlow
|
327 |
+
# end
|
328 |
+
# end
|
329 |
+
|
330 |
+
def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
|
331 |
+
assert(tenMetric is None or tenMetric.shape[1] == 1)
|
332 |
+
assert(strType in ['summation', 'average', 'linear', 'softmax'])
|
333 |
+
|
334 |
+
if strType == 'average':
|
335 |
+
tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1)
|
336 |
+
|
337 |
+
elif strType == 'linear':
|
338 |
+
tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1)
|
339 |
+
|
340 |
+
elif strType == 'softmax':
|
341 |
+
tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1)
|
342 |
+
|
343 |
+
# end
|
344 |
+
|
345 |
+
tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow)
|
346 |
+
|
347 |
+
if strType != 'summation':
|
348 |
+
tenNormalize = tenOutput[:, -1:, :, :]
|
349 |
+
|
350 |
+
tenNormalize[tenNormalize == 0.0] = 1.0
|
351 |
+
|
352 |
+
tenOutput = tenOutput[:, :-1, :, :] / tenNormalize
|
353 |
+
# end
|
354 |
+
|
355 |
+
return tenOutput
|
356 |
+
# end
|
357 |
+
|
358 |
+
class ModuleSoftsplat(torch.nn.Module):
|
359 |
+
def __init__(self, strType):
|
360 |
+
super().__init__()
|
361 |
+
|
362 |
+
self.strType = strType
|
363 |
+
# end
|
364 |
+
|
365 |
+
def forward(self, tenInput, tenFlow, tenMetric):
|
366 |
+
return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType)
|
367 |
+
# end
|
368 |
+
# end
|
modules/feature_extactor.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
import torchvision.models as models
|
5 |
+
|
6 |
+
from modules.basic_layers import GroupNorm
|
7 |
+
|
8 |
+
class Extractor(nn.Module):
|
9 |
+
def __init__(self, channels: list[int], num_groups: int = 32, use_residual: bool = True):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.use_residual = use_residual
|
13 |
+
|
14 |
+
self.layers = nn.ModuleList([
|
15 |
+
nn.Sequential(
|
16 |
+
nn.Conv2d(in_channels=channels[i], out_channels=channels[i + 1], kernel_size=3, stride=2, padding=1),
|
17 |
+
GroupNorm(channels[i + 1], num_groups = num_groups),
|
18 |
+
nn.SiLU(),
|
19 |
+
nn.Conv2d(in_channels=channels[i + 1], out_channels=channels[i + 1], kernel_size=3, stride=1, padding=1),
|
20 |
+
GroupNorm(channels[i + 1], num_groups = num_groups),
|
21 |
+
nn.SiLU()
|
22 |
+
) for i in range(len(channels) - 1)
|
23 |
+
])
|
24 |
+
if self.use_residual:
|
25 |
+
self.residual = nn.ModuleList([
|
26 |
+
nn.Sequential(
|
27 |
+
nn.Conv2d(in_channels=channels[i], out_channels=channels[i + 1], kernel_size=3, stride=2, padding=1),
|
28 |
+
) for i in range(len(channels) - 1)
|
29 |
+
])
|
30 |
+
|
31 |
+
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
32 |
+
features = []
|
33 |
+
for residual, layer in zip(self.residual, self.layers):
|
34 |
+
if self.use_residual:
|
35 |
+
x = layer(x) + residual(x)
|
36 |
+
else:
|
37 |
+
x = layer(x)
|
38 |
+
features.append(x)
|
39 |
+
return features
|
40 |
+
|
41 |
+
|
42 |
+
class ResNetExtractor(nn.Module):
|
43 |
+
def __init__(self, pretrained: bool = True, layers_to_extract: list[str] = ["layer1", "layer2", "layer3"]):
|
44 |
+
super(ResNetExtractor, self).__init__()
|
45 |
+
|
46 |
+
resnet = models.resnet18(pretrained=pretrained)
|
47 |
+
|
48 |
+
self.initial_layers = nn.Sequential(
|
49 |
+
resnet.conv1,
|
50 |
+
resnet.bn1,
|
51 |
+
resnet.relu
|
52 |
+
)
|
53 |
+
|
54 |
+
self.layers = nn.ModuleDict({
|
55 |
+
"layer1": resnet.layer1,
|
56 |
+
"layer2": resnet.layer2,
|
57 |
+
"layer3": resnet.layer3,
|
58 |
+
})
|
59 |
+
|
60 |
+
self.layers_to_extract = layers_to_extract
|
61 |
+
|
62 |
+
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
63 |
+
features = []
|
64 |
+
x = self.initial_layers(x)
|
65 |
+
|
66 |
+
for name, layer in self.layers.items():
|
67 |
+
x = layer(x)
|
68 |
+
if name in self.layers_to_extract:
|
69 |
+
features.append(x)
|
70 |
+
|
71 |
+
return features
|
72 |
+
|
73 |
+
class VGGExtractor(nn.Module):
|
74 |
+
def __init__(self, layers_to_extract: list[int] = [8, 15, 22, 29]):
|
75 |
+
super(VGGExtractor, self).__init__()
|
76 |
+
|
77 |
+
self.vgg = models.vgg16(pretrained=True).features
|
78 |
+
self.layers_to_extract = layers_to_extract
|
79 |
+
self.selected_layers = [self.vgg[i] for i in layers_to_extract]
|
80 |
+
|
81 |
+
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
82 |
+
features = []
|
83 |
+
for i, layer in enumerate(self.vgg):
|
84 |
+
x = layer(x)
|
85 |
+
if i in self.layers_to_extract:
|
86 |
+
features.append(x)
|
87 |
+
return features
|
modules/flow_models/flow_models.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.functional import interpolate
|
4 |
+
|
5 |
+
from modules.cupy_module import correlation
|
6 |
+
from modules.half_warper import HalfWarper
|
7 |
+
from modules.feature_extactor import Extractor
|
8 |
+
|
9 |
+
from modules.flow_models.raft.rfr_new import RAFT
|
10 |
+
|
11 |
+
class Decoder(nn.Module):
|
12 |
+
def __init__(self, in_channels: int):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
self.syntesis = nn.Sequential(
|
16 |
+
nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, stride=1, padding=1),
|
17 |
+
nn.SiLU(),
|
18 |
+
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
19 |
+
nn.SiLU(),
|
20 |
+
nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=1),
|
21 |
+
nn.SiLU(),
|
22 |
+
nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=1),
|
23 |
+
nn.SiLU(),
|
24 |
+
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
|
25 |
+
nn.SiLU(),
|
26 |
+
nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, img1: torch.Tensor, img2: torch.Tensor, residual: torch.Tensor | None) -> torch.Tensor:
|
30 |
+
width = img1.shape[3] and img2.shape[3]
|
31 |
+
height = img1.shape[2] and img2.shape[2]
|
32 |
+
|
33 |
+
if residual is None:
|
34 |
+
corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=img2)
|
35 |
+
main = torch.cat([img1, corr], dim=1)
|
36 |
+
else:
|
37 |
+
flow = interpolate(input=residual,
|
38 |
+
size=(height, width),
|
39 |
+
mode='bilinear',
|
40 |
+
align_corners=False) / \
|
41 |
+
float(residual.shape[3]) * float(width)
|
42 |
+
backwarp_img = HalfWarper.backward_wrapping(img=img2, flow=flow)
|
43 |
+
corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=backwarp_img)
|
44 |
+
main = torch.cat([img1, corr, flow], dim=1)
|
45 |
+
|
46 |
+
return self.syntesis(main)
|
47 |
+
|
48 |
+
class PWCFineFlow(nn.Module):
|
49 |
+
def __init__(self, pretrained_path: str | None = None):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.feature_extractor = Extractor([3, 16, 32, 64, 96, 128, 192], num_groups=16)
|
53 |
+
|
54 |
+
self.decoders = nn.ModuleList([
|
55 |
+
Decoder(16 + 81 + 2),
|
56 |
+
Decoder(32 + 81 + 2),
|
57 |
+
Decoder(64 + 81 + 2),
|
58 |
+
Decoder(96 + 81 + 2),
|
59 |
+
Decoder(128 + 81 + 2),
|
60 |
+
Decoder(192 + 81)
|
61 |
+
])
|
62 |
+
|
63 |
+
if pretrained_path is not None:
|
64 |
+
self.load_state_dict(torch.load(pretrained_path))
|
65 |
+
|
66 |
+
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
67 |
+
width = img1.shape[3] and img2.shape[3]
|
68 |
+
height = img1.shape[2] and img2.shape[2]
|
69 |
+
|
70 |
+
feats1 = self.feature_extractor(img1)
|
71 |
+
feats2 = self.feature_extractor(img2)
|
72 |
+
|
73 |
+
forward = None
|
74 |
+
backward = None
|
75 |
+
|
76 |
+
for i in reversed(range(len(feats1))):
|
77 |
+
forward = self.decoders[i](feats1[i], feats2[i], forward)
|
78 |
+
backward = self.decoders[i](feats2[i], feats1[i], backward)
|
79 |
+
|
80 |
+
forward = interpolate(input=forward,
|
81 |
+
size=(height, width),
|
82 |
+
mode='bilinear',
|
83 |
+
align_corners=False) * \
|
84 |
+
(float(width) / float(forward.shape[3]))
|
85 |
+
backward = interpolate(input=backward,
|
86 |
+
size=(height, width),
|
87 |
+
mode='bilinear',
|
88 |
+
align_corners=False) * \
|
89 |
+
(float(width) / float(backward.shape[3]))
|
90 |
+
|
91 |
+
return forward, backward
|
92 |
+
|
93 |
+
|
94 |
+
class RAFTFineFlow(nn.Module):
|
95 |
+
def __init__(self, pretrained_path: str | None = None):
|
96 |
+
super().__init__()
|
97 |
+
self.raft = RAFT(pretrained_path)
|
98 |
+
|
99 |
+
def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
100 |
+
forward = self.raft(img1, img2)
|
101 |
+
backward = self.raft(img2, img1)
|
102 |
+
return forward, backward
|
modules/flow_models/raft/LICENSE
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2020, princeton-vl
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without
|
7 |
+
modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
* Redistributions of source code must retain the above copyright notice, this
|
10 |
+
list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
this list of conditions and the following disclaimer in the documentation
|
14 |
+
and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
* Neither the name of the copyright holder nor the names of its
|
17 |
+
contributors may be used to endorse or promote products derived from
|
18 |
+
this software without specific prior written permission.
|
19 |
+
|
20 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
23 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
24 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
25 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
26 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
27 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
28 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
modules/flow_models/raft/corr.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from .utils import bilinear_sampler, coords_grid
|
4 |
+
|
5 |
+
|
6 |
+
class CorrBlock:
|
7 |
+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
8 |
+
self.num_levels = num_levels
|
9 |
+
self.radius = radius
|
10 |
+
self.corr_pyramid = []
|
11 |
+
|
12 |
+
# all pairs correlation
|
13 |
+
corr = CorrBlock.corr(fmap1, fmap2)
|
14 |
+
|
15 |
+
batch, h1, w1, dim, h2, w2 = corr.shape
|
16 |
+
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
|
17 |
+
|
18 |
+
self.corr_pyramid.append(corr)
|
19 |
+
for i in range(self.num_levels-1):
|
20 |
+
corr = F.avg_pool2d(corr, 2, stride=2)
|
21 |
+
self.corr_pyramid.append(corr)
|
22 |
+
|
23 |
+
def __call__(self, coords):
|
24 |
+
r = self.radius
|
25 |
+
coords = coords.permute(0, 2, 3, 1)
|
26 |
+
batch, h1, w1, _ = coords.shape
|
27 |
+
|
28 |
+
out_pyramid = []
|
29 |
+
for i in range(self.num_levels):
|
30 |
+
corr = self.corr_pyramid[i]
|
31 |
+
dx = torch.linspace(-r, r, 2*r+1)
|
32 |
+
dy = torch.linspace(-r, r, 2*r+1)
|
33 |
+
delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device)
|
34 |
+
|
35 |
+
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
36 |
+
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
|
37 |
+
coords_lvl = centroid_lvl + delta_lvl
|
38 |
+
|
39 |
+
corr = bilinear_sampler(corr, coords_lvl)
|
40 |
+
corr = corr.view(batch, h1, w1, -1)
|
41 |
+
out_pyramid.append(corr)
|
42 |
+
|
43 |
+
out = torch.cat(out_pyramid, dim=-1)
|
44 |
+
return out.permute(0, 3, 1, 2).contiguous().float()
|
45 |
+
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def corr(fmap1, fmap2):
|
49 |
+
batch, dim, ht, wd = fmap1.shape
|
50 |
+
fmap1 = fmap1.view(batch, dim, ht*wd)
|
51 |
+
fmap2 = fmap2.view(batch, dim, ht*wd)
|
52 |
+
|
53 |
+
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
|
54 |
+
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
55 |
+
return corr / torch.sqrt(torch.tensor(dim).float())
|
56 |
+
|
modules/flow_models/raft/extractor.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class ResidualBlock(nn.Module):
|
7 |
+
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
8 |
+
super(ResidualBlock, self).__init__()
|
9 |
+
|
10 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
|
11 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
12 |
+
self.relu = nn.ReLU(inplace=True)
|
13 |
+
|
14 |
+
num_groups = planes // 8
|
15 |
+
|
16 |
+
if norm_fn == 'group':
|
17 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
18 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
19 |
+
if not stride == 1:
|
20 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
21 |
+
|
22 |
+
elif norm_fn == 'batch':
|
23 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
24 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
25 |
+
if not stride == 1:
|
26 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
27 |
+
|
28 |
+
elif norm_fn == 'instance':
|
29 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
30 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
31 |
+
if not stride == 1:
|
32 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
33 |
+
|
34 |
+
elif norm_fn == 'none':
|
35 |
+
self.norm1 = nn.Sequential()
|
36 |
+
self.norm2 = nn.Sequential()
|
37 |
+
if not stride == 1:
|
38 |
+
self.norm3 = nn.Sequential()
|
39 |
+
|
40 |
+
if stride == 1:
|
41 |
+
self.downsample = None
|
42 |
+
|
43 |
+
else:
|
44 |
+
self.downsample = nn.Sequential(
|
45 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
46 |
+
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
y = x
|
50 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
51 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
52 |
+
|
53 |
+
if self.downsample is not None:
|
54 |
+
x = self.downsample(x)
|
55 |
+
|
56 |
+
return self.relu(x+y)
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
class BottleneckBlock(nn.Module):
|
61 |
+
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
62 |
+
super(BottleneckBlock, self).__init__()
|
63 |
+
|
64 |
+
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
|
65 |
+
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
|
66 |
+
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
|
67 |
+
self.relu = nn.ReLU(inplace=True)
|
68 |
+
|
69 |
+
num_groups = planes // 8
|
70 |
+
|
71 |
+
if norm_fn == 'group':
|
72 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
73 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
74 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
75 |
+
if not stride == 1:
|
76 |
+
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
77 |
+
|
78 |
+
elif norm_fn == 'batch':
|
79 |
+
self.norm1 = nn.BatchNorm2d(planes//4)
|
80 |
+
self.norm2 = nn.BatchNorm2d(planes//4)
|
81 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
82 |
+
if not stride == 1:
|
83 |
+
self.norm4 = nn.BatchNorm2d(planes)
|
84 |
+
|
85 |
+
elif norm_fn == 'instance':
|
86 |
+
self.norm1 = nn.InstanceNorm2d(planes//4)
|
87 |
+
self.norm2 = nn.InstanceNorm2d(planes//4)
|
88 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
89 |
+
if not stride == 1:
|
90 |
+
self.norm4 = nn.InstanceNorm2d(planes)
|
91 |
+
|
92 |
+
elif norm_fn == 'none':
|
93 |
+
self.norm1 = nn.Sequential()
|
94 |
+
self.norm2 = nn.Sequential()
|
95 |
+
self.norm3 = nn.Sequential()
|
96 |
+
if not stride == 1:
|
97 |
+
self.norm4 = nn.Sequential()
|
98 |
+
|
99 |
+
if stride == 1:
|
100 |
+
self.downsample = None
|
101 |
+
|
102 |
+
else:
|
103 |
+
self.downsample = nn.Sequential(
|
104 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
|
105 |
+
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
y = x
|
109 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
110 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
111 |
+
y = self.relu(self.norm3(self.conv3(y)))
|
112 |
+
|
113 |
+
if self.downsample is not None:
|
114 |
+
x = self.downsample(x)
|
115 |
+
|
116 |
+
return self.relu(x+y)
|
117 |
+
|
118 |
+
class BasicEncoder(nn.Module):
|
119 |
+
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
120 |
+
super(BasicEncoder, self).__init__()
|
121 |
+
self.norm_fn = norm_fn
|
122 |
+
|
123 |
+
if self.norm_fn == 'group':
|
124 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
125 |
+
|
126 |
+
elif self.norm_fn == 'batch':
|
127 |
+
self.norm1 = nn.BatchNorm2d(64)
|
128 |
+
|
129 |
+
elif self.norm_fn == 'instance':
|
130 |
+
self.norm1 = nn.InstanceNorm2d(64)
|
131 |
+
|
132 |
+
elif self.norm_fn == 'none':
|
133 |
+
self.norm1 = nn.Sequential()
|
134 |
+
|
135 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
136 |
+
self.relu1 = nn.ReLU(inplace=True)
|
137 |
+
|
138 |
+
self.in_planes = 64
|
139 |
+
self.layer1 = self._make_layer(64, stride=1)
|
140 |
+
self.layer2 = self._make_layer(96, stride=2)
|
141 |
+
self.layer3 = self._make_layer(128, stride=2)
|
142 |
+
|
143 |
+
# output convolution
|
144 |
+
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
|
145 |
+
|
146 |
+
self.dropout = None
|
147 |
+
if dropout > 0:
|
148 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
149 |
+
|
150 |
+
for m in self.modules():
|
151 |
+
if isinstance(m, nn.Conv2d):
|
152 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
153 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
154 |
+
if m.weight is not None:
|
155 |
+
nn.init.constant_(m.weight, 1)
|
156 |
+
if m.bias is not None:
|
157 |
+
nn.init.constant_(m.bias, 0)
|
158 |
+
|
159 |
+
def _make_layer(self, dim, stride=1):
|
160 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
161 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
162 |
+
layers = (layer1, layer2)
|
163 |
+
|
164 |
+
self.in_planes = dim
|
165 |
+
return nn.Sequential(*layers)
|
166 |
+
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
|
170 |
+
# if input is list, combine batch dimension
|
171 |
+
is_list = isinstance(x, tuple) or isinstance(x, list)
|
172 |
+
if is_list:
|
173 |
+
batch_dim = x[0].shape[0]
|
174 |
+
x = torch.cat(x, dim=0)
|
175 |
+
|
176 |
+
x = self.conv1(x)
|
177 |
+
x = self.norm1(x)
|
178 |
+
x = self.relu1(x)
|
179 |
+
|
180 |
+
x = self.layer1(x)
|
181 |
+
x = self.layer2(x)
|
182 |
+
x = self.layer3(x)
|
183 |
+
|
184 |
+
x = self.conv2(x)
|
185 |
+
|
186 |
+
if self.training and self.dropout is not None:
|
187 |
+
x = self.dropout(x)
|
188 |
+
|
189 |
+
if is_list:
|
190 |
+
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
191 |
+
|
192 |
+
return x
|
193 |
+
|
194 |
+
class BasicEncoder1(nn.Module):
|
195 |
+
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
196 |
+
super(BasicEncoder1, self).__init__()
|
197 |
+
self.norm_fn = norm_fn
|
198 |
+
|
199 |
+
if self.norm_fn == 'group':
|
200 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
201 |
+
|
202 |
+
elif self.norm_fn == 'batch':
|
203 |
+
self.norm1 = nn.BatchNorm2d(64)
|
204 |
+
|
205 |
+
elif self.norm_fn == 'instance':
|
206 |
+
self.norm1 = nn.InstanceNorm2d(64)
|
207 |
+
|
208 |
+
elif self.norm_fn == 'none':
|
209 |
+
self.norm1 = nn.Sequential()
|
210 |
+
|
211 |
+
self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3)
|
212 |
+
self.relu1 = nn.ReLU(inplace=True)
|
213 |
+
|
214 |
+
self.in_planes = 64
|
215 |
+
self.layer1 = self._make_layer(64, stride=1)
|
216 |
+
self.layer2 = self._make_layer(96, stride=2)
|
217 |
+
self.layer3 = self._make_layer(128, stride=2)
|
218 |
+
|
219 |
+
# output convolution
|
220 |
+
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
|
221 |
+
|
222 |
+
self.dropout = None
|
223 |
+
if dropout > 0:
|
224 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
225 |
+
|
226 |
+
for m in self.modules():
|
227 |
+
if isinstance(m, nn.Conv2d):
|
228 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
229 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
230 |
+
if m.weight is not None:
|
231 |
+
nn.init.constant_(m.weight, 1)
|
232 |
+
if m.bias is not None:
|
233 |
+
nn.init.constant_(m.bias, 0)
|
234 |
+
|
235 |
+
def _make_layer(self, dim, stride=1):
|
236 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
237 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
238 |
+
layers = (layer1, layer2)
|
239 |
+
|
240 |
+
self.in_planes = dim
|
241 |
+
return nn.Sequential(*layers)
|
242 |
+
|
243 |
+
|
244 |
+
def forward(self, x):
|
245 |
+
|
246 |
+
# if input is list, combine batch dimension
|
247 |
+
is_list = isinstance(x, tuple) or isinstance(x, list)
|
248 |
+
if is_list:
|
249 |
+
batch_dim = x[0].shape[0]
|
250 |
+
x = torch.cat(x, dim=0)
|
251 |
+
|
252 |
+
x = self.conv1(x)
|
253 |
+
x = self.norm1(x)
|
254 |
+
x = self.relu1(x)
|
255 |
+
|
256 |
+
x = self.layer1(x)
|
257 |
+
x = self.layer2(x)
|
258 |
+
x = self.layer3(x)
|
259 |
+
|
260 |
+
x = self.conv2(x)
|
261 |
+
|
262 |
+
if self.training and self.dropout is not None:
|
263 |
+
x = self.dropout(x)
|
264 |
+
|
265 |
+
if is_list:
|
266 |
+
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
267 |
+
|
268 |
+
return x
|
269 |
+
|
270 |
+
class SmallEncoder(nn.Module):
|
271 |
+
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
272 |
+
super(SmallEncoder, self).__init__()
|
273 |
+
self.norm_fn = norm_fn
|
274 |
+
|
275 |
+
if self.norm_fn == 'group':
|
276 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
|
277 |
+
|
278 |
+
elif self.norm_fn == 'batch':
|
279 |
+
self.norm1 = nn.BatchNorm2d(32)
|
280 |
+
|
281 |
+
elif self.norm_fn == 'instance':
|
282 |
+
self.norm1 = nn.InstanceNorm2d(32)
|
283 |
+
|
284 |
+
elif self.norm_fn == 'none':
|
285 |
+
self.norm1 = nn.Sequential()
|
286 |
+
|
287 |
+
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
|
288 |
+
self.relu1 = nn.ReLU(inplace=True)
|
289 |
+
|
290 |
+
self.in_planes = 32
|
291 |
+
self.layer1 = self._make_layer(32, stride=1)
|
292 |
+
self.layer2 = self._make_layer(64, stride=2)
|
293 |
+
self.layer3 = self._make_layer(96, stride=2)
|
294 |
+
|
295 |
+
self.dropout = None
|
296 |
+
if dropout > 0:
|
297 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
298 |
+
|
299 |
+
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
|
300 |
+
|
301 |
+
for m in self.modules():
|
302 |
+
if isinstance(m, nn.Conv2d):
|
303 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
304 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
305 |
+
if m.weight is not None:
|
306 |
+
nn.init.constant_(m.weight, 1)
|
307 |
+
if m.bias is not None:
|
308 |
+
nn.init.constant_(m.bias, 0)
|
309 |
+
|
310 |
+
def _make_layer(self, dim, stride=1):
|
311 |
+
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
312 |
+
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
|
313 |
+
layers = (layer1, layer2)
|
314 |
+
|
315 |
+
self.in_planes = dim
|
316 |
+
return nn.Sequential(*layers)
|
317 |
+
|
318 |
+
|
319 |
+
def forward(self, x):
|
320 |
+
|
321 |
+
# if input is list, combine batch dimension
|
322 |
+
is_list = isinstance(x, tuple) or isinstance(x, list)
|
323 |
+
if is_list:
|
324 |
+
batch_dim = x[0].shape[0]
|
325 |
+
x = torch.cat(x, dim=0)
|
326 |
+
|
327 |
+
x = self.conv1(x)
|
328 |
+
x = self.norm1(x)
|
329 |
+
x = self.relu1(x)
|
330 |
+
|
331 |
+
x = self.layer1(x)
|
332 |
+
x = self.layer2(x)
|
333 |
+
x = self.layer3(x)
|
334 |
+
x = self.conv2(x)
|
335 |
+
|
336 |
+
if self.training and self.dropout is not None:
|
337 |
+
x = self.dropout(x)
|
338 |
+
|
339 |
+
if is_list:
|
340 |
+
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
341 |
+
|
342 |
+
return x
|
modules/flow_models/raft/rfr_new.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##################################################
|
2 |
+
# RFR is implemented based on RAFT optical flow #
|
3 |
+
##################################################
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from argparse import Namespace
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from .update import BasicUpdateBlock, SmallUpdateBlock
|
13 |
+
from .extractor import BasicEncoder, SmallEncoder
|
14 |
+
from .corr import CorrBlock
|
15 |
+
from .utils import bilinear_sampler, coords_grid, upflow8
|
16 |
+
|
17 |
+
try:
|
18 |
+
autocast = torch.amp.autocast
|
19 |
+
except:
|
20 |
+
# dummy autocast for PyTorch < 1.6
|
21 |
+
class autocast:
|
22 |
+
def __init__(self, enabled):
|
23 |
+
pass
|
24 |
+
def __enter__(self):
|
25 |
+
pass
|
26 |
+
def __exit__(self, *args):
|
27 |
+
pass
|
28 |
+
|
29 |
+
def backwarp(img, flow):
|
30 |
+
_, _, H, W = img.size()
|
31 |
+
|
32 |
+
u = flow[:, 0, :, :]
|
33 |
+
v = flow[:, 1, :, :]
|
34 |
+
|
35 |
+
gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
|
36 |
+
|
37 |
+
gridX = torch.tensor(gridX, requires_grad=False,).cuda()
|
38 |
+
gridY = torch.tensor(gridY, requires_grad=False,).cuda()
|
39 |
+
x = gridX.unsqueeze(0).expand_as(u).float() + u
|
40 |
+
y = gridY.unsqueeze(0).expand_as(v).float() + v
|
41 |
+
# range -1 to 1
|
42 |
+
x = 2*(x/(W-1) - 0.5)
|
43 |
+
y = 2*(y/(H-1) - 0.5)
|
44 |
+
# stacking X and Y
|
45 |
+
grid = torch.stack((x,y), dim=3)
|
46 |
+
# Sample pixels using bilinear interpolation.
|
47 |
+
imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=True)
|
48 |
+
|
49 |
+
return imgOut
|
50 |
+
class ErrorAttention(nn.Module):
|
51 |
+
"""A three-layer network for predicting mask"""
|
52 |
+
def __init__(self, input, output):
|
53 |
+
super(ErrorAttention, self).__init__()
|
54 |
+
self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
|
55 |
+
self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
|
56 |
+
self.conv3 = nn.Conv2d(38, output, 3, padding=1)
|
57 |
+
self.prelu1 = nn.PReLU()
|
58 |
+
self.prelu2 = nn.PReLU()
|
59 |
+
|
60 |
+
def forward(self, x1):
|
61 |
+
x = self.prelu1(self.conv1(x1))
|
62 |
+
x = self.prelu2(torch.cat([self.conv2(x), x1], dim=1))
|
63 |
+
x = self.conv3(x)
|
64 |
+
return x
|
65 |
+
|
66 |
+
class RFR(nn.Module):
|
67 |
+
def __init__(self, args):
|
68 |
+
super(RFR, self).__init__()
|
69 |
+
self.attention2 = ErrorAttention(6, 1)
|
70 |
+
self.hidden_dim = hdim = 128
|
71 |
+
self.context_dim = cdim = 128
|
72 |
+
args.corr_levels = 4
|
73 |
+
args.corr_radius = 4
|
74 |
+
args.dropout = 0
|
75 |
+
self.args = args
|
76 |
+
|
77 |
+
# feature network, context network, and update block
|
78 |
+
self.fnet = BasicEncoder(output_dim=256, norm_fn='none', dropout=args.dropout)
|
79 |
+
# self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
|
80 |
+
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
def freeze_bn(self):
|
85 |
+
for m in self.modules():
|
86 |
+
if isinstance(m, nn.BatchNorm2d):
|
87 |
+
m.eval()
|
88 |
+
|
89 |
+
def initialize_flow(self, img):
|
90 |
+
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
|
91 |
+
N, C, H, W = img.shape
|
92 |
+
coords0 = coords_grid(N, H//8, W//8).to(img.device)
|
93 |
+
coords1 = coords_grid(N, H//8, W//8).to(img.device)
|
94 |
+
|
95 |
+
# optical flow computed as difference: flow = coords1 - coords0
|
96 |
+
return coords0, coords1
|
97 |
+
|
98 |
+
def upsample_flow(self, flow, mask):
|
99 |
+
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
|
100 |
+
N, _, H, W = flow.shape
|
101 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
102 |
+
mask = torch.softmax(mask, dim=2)
|
103 |
+
|
104 |
+
up_flow = F.unfold(8 * flow, [3,3], padding=1)
|
105 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
106 |
+
|
107 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
108 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
109 |
+
return up_flow.reshape(N, 2, 8*H, 8*W)
|
110 |
+
|
111 |
+
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
|
112 |
+
H, W = image1.size()[2:4]
|
113 |
+
H8 = H // 8 * 8
|
114 |
+
W8 = W // 8 * 8
|
115 |
+
|
116 |
+
if flow_init is not None:
|
117 |
+
flow_init_resize = F.interpolate(flow_init, size=(H8//8, W8//8), mode='nearest')
|
118 |
+
|
119 |
+
flow_init_resize[:, :1] = flow_init_resize[:, :1].clone() * (W8 // 8 *1.0) / flow_init.size()[3]
|
120 |
+
flow_init_resize[:, 1:] = flow_init_resize[:, 1:].clone() * (H8 // 8*1.0) / flow_init.size()[2]
|
121 |
+
|
122 |
+
if not hasattr(self.args, 'not_use_rfr_mask') or ( hasattr(self.args, 'not_use_rfr_mask') and (not self.args.not_use_rfr_mask)):
|
123 |
+
im18 = F.interpolate(image1, size=(H8//8, W8//8), mode='bilinear')
|
124 |
+
im28 = F.interpolate(image2, size=(H8//8, W8//8), mode='bilinear')
|
125 |
+
|
126 |
+
warp21 = backwarp(im28, flow_init_resize)
|
127 |
+
error21 = torch.sum(torch.abs(warp21 - im18), dim=1, keepdim=True)
|
128 |
+
# print('errormin', error21.min(), error21.max())
|
129 |
+
f12init = torch.exp(- self.attention2(torch.cat([im18, error21, flow_init_resize], dim=1)) ** 2) * flow_init_resize
|
130 |
+
else:
|
131 |
+
flow_init_resize = None
|
132 |
+
flow_init = torch.zeros(image1.size()[0], 2, image1.size()[2]//8, image1.size()[3]//8).cuda()
|
133 |
+
error21 = torch.zeros(image1.size()[0], 1, image1.size()[2]//8, image1.size()[3]//8).cuda()
|
134 |
+
|
135 |
+
f12_init = flow_init
|
136 |
+
# print('None inital flow!')
|
137 |
+
|
138 |
+
image1 = F.interpolate(image1, size=(H8, W8), mode='bilinear')
|
139 |
+
image2 = F.interpolate(image2, size=(H8, W8), mode='bilinear')
|
140 |
+
|
141 |
+
f12s, f12, f12_init = self.forward_pred(image1, image2, iters, flow_init_resize, upsample, test_mode)
|
142 |
+
|
143 |
+
|
144 |
+
if (hasattr(self.args, 'requires_sq_flow') and self.args.requires_sq_flow):
|
145 |
+
for ii in range(len(f12s)):
|
146 |
+
f12s[ii] = F.interpolate(f12s[ii], size=(H, W), mode='bilinear')
|
147 |
+
f12s[ii][:, :1] = f12s[ii][:, :1].clone() / (1.0*W8) * W
|
148 |
+
f12s[ii][:, 1:] = f12s[ii][:, 1:].clone() / (1.0*H8) * H
|
149 |
+
if self.training:
|
150 |
+
return f12s
|
151 |
+
else:
|
152 |
+
return [f12s[-1]], f12_init
|
153 |
+
else:
|
154 |
+
f12[:, :1] = f12[:, :1].clone() / (1.0*W8) * W
|
155 |
+
f12[:, 1:] = f12[:, 1:].clone() / (1.0*H8) * H
|
156 |
+
|
157 |
+
f12 = F.interpolate(f12, size=(H, W), mode='bilinear')
|
158 |
+
# print('wo!!')
|
159 |
+
return f12, f12_init, error21,
|
160 |
+
|
161 |
+
def forward_pred(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
|
162 |
+
""" Estimate optical flow between pair of frames """
|
163 |
+
|
164 |
+
|
165 |
+
image1 = image1.contiguous()
|
166 |
+
image2 = image2.contiguous()
|
167 |
+
|
168 |
+
hdim = self.hidden_dim
|
169 |
+
cdim = self.context_dim
|
170 |
+
|
171 |
+
# run the feature network
|
172 |
+
with autocast("cuda", enabled=self.args.mixed_precision):
|
173 |
+
fmap1, fmap2 = self.fnet([image1, image2])
|
174 |
+
fmap1 = fmap1.float()
|
175 |
+
fmap2 = fmap2.float()
|
176 |
+
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
177 |
+
|
178 |
+
# run the context network
|
179 |
+
with autocast("cuda", enabled=self.args.mixed_precision):
|
180 |
+
cnet = self.fnet(image1)
|
181 |
+
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
|
182 |
+
net = torch.tanh(net)
|
183 |
+
inp = torch.relu(inp)
|
184 |
+
|
185 |
+
coords0, coords1 = self.initialize_flow(image1)
|
186 |
+
|
187 |
+
if flow_init is not None:
|
188 |
+
coords1 = coords1 + flow_init
|
189 |
+
|
190 |
+
flow_predictions = []
|
191 |
+
for itr in range(iters):
|
192 |
+
coords1 = coords1.detach()
|
193 |
+
if itr == 0:
|
194 |
+
if flow_init is not None:
|
195 |
+
coords1 = coords1 + flow_init
|
196 |
+
corr = corr_fn(coords1) # index correlation volume
|
197 |
+
|
198 |
+
flow = coords1 - coords0
|
199 |
+
with autocast("cuda", enabled=self.args.mixed_precision):
|
200 |
+
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
|
201 |
+
|
202 |
+
# F(t+1) = F(t) + \Delta(t)
|
203 |
+
coords1 = coords1 + delta_flow
|
204 |
+
|
205 |
+
# upsample predictions
|
206 |
+
if up_mask is None:
|
207 |
+
flow_up = upflow8(coords1 - coords0)
|
208 |
+
else:
|
209 |
+
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
|
210 |
+
|
211 |
+
flow_predictions.append(flow_up)
|
212 |
+
|
213 |
+
return flow_predictions, flow_up, flow_init
|
214 |
+
|
215 |
+
class RAFT(nn.Module):
|
216 |
+
def __init__(self, path='./_pretrain_models/anime_interp_full.ckpt'):
|
217 |
+
super().__init__()
|
218 |
+
self.raft = RFR(Namespace(
|
219 |
+
small=False,
|
220 |
+
mixed_precision=False,
|
221 |
+
))
|
222 |
+
if path is not None:
|
223 |
+
sd = torch.load(path)['model_state_dict']
|
224 |
+
self.raft.load_state_dict({
|
225 |
+
k[len('module.flownet.'):]: v
|
226 |
+
for k,v in sd.items()
|
227 |
+
if k.startswith('module.flownet.')
|
228 |
+
}, strict=False)
|
229 |
+
return
|
230 |
+
def forward(self, img0, img1, flow0=None, iters=12, return_more=False):
|
231 |
+
if flow0 is not None:
|
232 |
+
flow0 = flow0.flip(dims=(1,))
|
233 |
+
out = self.raft(img0, img1, iters=iters, flow_init=flow0)
|
234 |
+
return out[0].flip(dims=(1,))
|
235 |
+
|
modules/flow_models/raft/update.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class FlowHead(nn.Module):
|
7 |
+
def __init__(self, input_dim=128, hidden_dim=256):
|
8 |
+
super(FlowHead, self).__init__()
|
9 |
+
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
10 |
+
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
|
11 |
+
self.relu = nn.ReLU(inplace=True)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
return self.conv2(self.relu(self.conv1(x)))
|
15 |
+
|
16 |
+
class ConvGRU(nn.Module):
|
17 |
+
def __init__(self, hidden_dim=128, input_dim=192+128):
|
18 |
+
super(ConvGRU, self).__init__()
|
19 |
+
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
20 |
+
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
21 |
+
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
22 |
+
|
23 |
+
def forward(self, h, x):
|
24 |
+
hx = torch.cat([h, x], dim=1)
|
25 |
+
|
26 |
+
z = torch.sigmoid(self.convz(hx))
|
27 |
+
r = torch.sigmoid(self.convr(hx))
|
28 |
+
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
|
29 |
+
|
30 |
+
h = (1-z) * h + z * q
|
31 |
+
return h
|
32 |
+
|
33 |
+
class SepConvGRU(nn.Module):
|
34 |
+
def __init__(self, hidden_dim=128, input_dim=192+128):
|
35 |
+
super(SepConvGRU, self).__init__()
|
36 |
+
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
37 |
+
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
38 |
+
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
39 |
+
|
40 |
+
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
41 |
+
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
42 |
+
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
43 |
+
|
44 |
+
|
45 |
+
def forward(self, h, x):
|
46 |
+
# horizontal
|
47 |
+
hx = torch.cat([h, x], dim=1)
|
48 |
+
z = torch.sigmoid(self.convz1(hx))
|
49 |
+
r = torch.sigmoid(self.convr1(hx))
|
50 |
+
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
|
51 |
+
h = (1-z) * h + z * q
|
52 |
+
|
53 |
+
# vertical
|
54 |
+
hx = torch.cat([h, x], dim=1)
|
55 |
+
z = torch.sigmoid(self.convz2(hx))
|
56 |
+
r = torch.sigmoid(self.convr2(hx))
|
57 |
+
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
|
58 |
+
h = (1-z) * h + z * q
|
59 |
+
|
60 |
+
return h
|
61 |
+
|
62 |
+
class SmallMotionEncoder(nn.Module):
|
63 |
+
def __init__(self, args):
|
64 |
+
super(SmallMotionEncoder, self).__init__()
|
65 |
+
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
|
66 |
+
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
|
67 |
+
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
|
68 |
+
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
|
69 |
+
self.conv = nn.Conv2d(128, 80, 3, padding=1)
|
70 |
+
|
71 |
+
def forward(self, flow, corr):
|
72 |
+
cor = F.relu(self.convc1(corr))
|
73 |
+
flo = F.relu(self.convf1(flow))
|
74 |
+
flo = F.relu(self.convf2(flo))
|
75 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
76 |
+
out = F.relu(self.conv(cor_flo))
|
77 |
+
return torch.cat([out, flow], dim=1)
|
78 |
+
|
79 |
+
class BasicMotionEncoder(nn.Module):
|
80 |
+
def __init__(self, args):
|
81 |
+
super(BasicMotionEncoder, self).__init__()
|
82 |
+
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
|
83 |
+
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
|
84 |
+
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
|
85 |
+
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
|
86 |
+
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
|
87 |
+
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
|
88 |
+
|
89 |
+
def forward(self, flow, corr):
|
90 |
+
cor = F.relu(self.convc1(corr))
|
91 |
+
cor = F.relu(self.convc2(cor))
|
92 |
+
flo = F.relu(self.convf1(flow))
|
93 |
+
flo = F.relu(self.convf2(flo))
|
94 |
+
|
95 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
96 |
+
out = F.relu(self.conv(cor_flo))
|
97 |
+
return torch.cat([out, flow], dim=1)
|
98 |
+
|
99 |
+
class SmallUpdateBlock(nn.Module):
|
100 |
+
def __init__(self, args, hidden_dim=96):
|
101 |
+
super(SmallUpdateBlock, self).__init__()
|
102 |
+
self.encoder = SmallMotionEncoder(args)
|
103 |
+
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
|
104 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
|
105 |
+
|
106 |
+
def forward(self, net, inp, corr, flow):
|
107 |
+
motion_features = self.encoder(flow, corr)
|
108 |
+
inp = torch.cat([inp, motion_features], dim=1)
|
109 |
+
net = self.gru(net, inp)
|
110 |
+
delta_flow = self.flow_head(net)
|
111 |
+
|
112 |
+
return net, None, delta_flow
|
113 |
+
|
114 |
+
class BasicUpdateBlock(nn.Module):
|
115 |
+
def __init__(self, args, hidden_dim=128, input_dim=128):
|
116 |
+
super(BasicUpdateBlock, self).__init__()
|
117 |
+
self.args = args
|
118 |
+
self.encoder = BasicMotionEncoder(args)
|
119 |
+
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
|
120 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
121 |
+
|
122 |
+
self.mask = nn.Sequential(
|
123 |
+
nn.Conv2d(128, 256, 3, padding=1),
|
124 |
+
nn.ReLU(inplace=True),
|
125 |
+
nn.Conv2d(256, 64*9, 1, padding=0))
|
126 |
+
|
127 |
+
def forward(self, net, inp, corr, flow, upsample=True):
|
128 |
+
motion_features = self.encoder(flow, corr)
|
129 |
+
inp = torch.cat([inp, motion_features], dim=1)
|
130 |
+
|
131 |
+
net = self.gru(net, inp)
|
132 |
+
delta_flow = self.flow_head(net)
|
133 |
+
|
134 |
+
# scale mask to balence gradients
|
135 |
+
mask = .25 * self.mask(net)
|
136 |
+
return net, mask, delta_flow
|
137 |
+
|
138 |
+
|
139 |
+
|
modules/flow_models/raft/utils.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy import interpolate
|
5 |
+
|
6 |
+
|
7 |
+
class InputPadder:
|
8 |
+
""" Pads images such that dimensions are divisible by 8 """
|
9 |
+
def __init__(self, dims):
|
10 |
+
self.ht, self.wd = dims[-2:]
|
11 |
+
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
12 |
+
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
13 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
14 |
+
|
15 |
+
def pad(self, *inputs):
|
16 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
17 |
+
|
18 |
+
def unpad(self,x):
|
19 |
+
ht, wd = x.shape[-2:]
|
20 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
21 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
22 |
+
|
23 |
+
def forward_interpolate(flow):
|
24 |
+
flow = flow.detach().cpu().numpy()
|
25 |
+
dx, dy = flow[0], flow[1]
|
26 |
+
|
27 |
+
ht, wd = dx.shape
|
28 |
+
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
|
29 |
+
|
30 |
+
x1 = x0 + dx
|
31 |
+
y1 = y0 + dy
|
32 |
+
|
33 |
+
x1 = x1.reshape(-1)
|
34 |
+
y1 = y1.reshape(-1)
|
35 |
+
dx = dx.reshape(-1)
|
36 |
+
dy = dy.reshape(-1)
|
37 |
+
|
38 |
+
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
|
39 |
+
x1 = x1[valid]
|
40 |
+
y1 = y1[valid]
|
41 |
+
dx = dx[valid]
|
42 |
+
dy = dy[valid]
|
43 |
+
|
44 |
+
flow_x = interpolate.griddata(
|
45 |
+
(x1, y1), dx, (x0, y0), method='cubic', fill_value=0)
|
46 |
+
|
47 |
+
flow_y = interpolate.griddata(
|
48 |
+
(x1, y1), dy, (x0, y0), method='cubic', fill_value=0)
|
49 |
+
|
50 |
+
flow = np.stack([flow_x, flow_y], axis=0)
|
51 |
+
return torch.from_numpy(flow).float()
|
52 |
+
|
53 |
+
|
54 |
+
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
55 |
+
""" Wrapper for grid_sample, uses pixel coordinates """
|
56 |
+
H, W = img.shape[-2:]
|
57 |
+
xgrid, ygrid = coords.split([1,1], dim=-1)
|
58 |
+
xgrid = 2*xgrid/(W-1) - 1
|
59 |
+
ygrid = 2*ygrid/(H-1) - 1
|
60 |
+
|
61 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
62 |
+
# print(img.size())
|
63 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
64 |
+
|
65 |
+
if mask:
|
66 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
67 |
+
return img, mask.float()
|
68 |
+
|
69 |
+
return img
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def coords_grid(batch, ht, wd):
|
74 |
+
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
75 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
76 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
77 |
+
|
78 |
+
|
79 |
+
def upflow8(flow, mode='bilinear'):
|
80 |
+
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
81 |
+
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
modules/half_warper.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from kornia.color import rgb_to_lab
|
5 |
+
|
6 |
+
from utils.utils import morph_open
|
7 |
+
|
8 |
+
from modules.cupy_module.softsplat import FunctionSoftsplat
|
9 |
+
|
10 |
+
class HalfWarper(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def backward_wrapping(
|
16 |
+
img: torch.Tensor,
|
17 |
+
flow: torch.Tensor,
|
18 |
+
resample: str = 'bilinear',
|
19 |
+
padding_mode: str = 'border',
|
20 |
+
align_corners: bool = False
|
21 |
+
) -> torch.Tensor:
|
22 |
+
if len(img.shape) != 4: img = img[None,]
|
23 |
+
if len(flow.shape) != 4: flow = flow[None,]
|
24 |
+
|
25 |
+
q = 2 * flow / torch.tensor([
|
26 |
+
flow.shape[-2], flow.shape[-1],
|
27 |
+
], device=flow.device, dtype=torch.float)[None,:,None,None]
|
28 |
+
|
29 |
+
q = q + torch.stack(torch.meshgrid(
|
30 |
+
torch.linspace(-1, 1, flow.shape[-2]),
|
31 |
+
torch.linspace(-1, 1, flow.shape[-1]),
|
32 |
+
))[None,].to(flow.device)
|
33 |
+
|
34 |
+
if img.dtype != q.dtype:
|
35 |
+
img = img.type(q.dtype)
|
36 |
+
|
37 |
+
return F.grid_sample(
|
38 |
+
img,
|
39 |
+
q.flip(dims=(1,)).permute(0, 2, 3, 1).contiguous(),
|
40 |
+
mode = resample, # nearest, bicubic, bilinear
|
41 |
+
padding_mode = padding_mode, # border, zeros, reflection
|
42 |
+
align_corners = align_corners,
|
43 |
+
)
|
44 |
+
|
45 |
+
@staticmethod
|
46 |
+
def forward_warpping(
|
47 |
+
img: torch.Tensor,
|
48 |
+
flow: torch.Tensor,
|
49 |
+
mode: str = 'softmax',
|
50 |
+
metric: torch.Tensor | None = None,
|
51 |
+
mask: bool = True
|
52 |
+
) -> torch.Tensor:
|
53 |
+
if len(img.shape) != 4: img = img[None,]
|
54 |
+
if len(flow.shape) != 4: flow = flow[None,]
|
55 |
+
if metric is not None and len(metric.shape)!=4: metric = metric[None,]
|
56 |
+
|
57 |
+
flow = flow.flip(dims=(1,))
|
58 |
+
if img.dtype != torch.float32:
|
59 |
+
img = img.type(torch.float32)
|
60 |
+
if flow.dtype != torch.float32:
|
61 |
+
flow = flow.type(torch.float32)
|
62 |
+
if metric is not None and metric.dtype != torch.float32:
|
63 |
+
metric = metric.type(torch.float32)
|
64 |
+
|
65 |
+
assert img.device == flow.device
|
66 |
+
if metric is not None: assert img.device == metric.device
|
67 |
+
if img.device.type=='cpu':
|
68 |
+
img = img.to('cuda')
|
69 |
+
flow = flow.to('cuda')
|
70 |
+
if metric is not None: metric = metric.to('cuda')
|
71 |
+
|
72 |
+
if mask:
|
73 |
+
batch, _, h, w = img.shape
|
74 |
+
img = torch.cat([img, torch.ones(batch, 1, h, w, dtype=img.dtype, device=img.device)], dim=1)
|
75 |
+
|
76 |
+
return FunctionSoftsplat(img, flow, metric, mode)
|
77 |
+
|
78 |
+
@staticmethod
|
79 |
+
def z_metric(
|
80 |
+
img0: torch.Tensor,
|
81 |
+
img1: torch.Tensor,
|
82 |
+
flow0to1: torch.Tensor,
|
83 |
+
flow1to0: torch.Tensor
|
84 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
85 |
+
img0 = rgb_to_lab(img0[:,:3])
|
86 |
+
img1 = rgb_to_lab(img1[:,:3])
|
87 |
+
z1to0 = -0.1*(img1 - HalfWarper.backward_wrapping(img0, flow1to0)).norm(dim=1, keepdim=True)
|
88 |
+
z0to1 = -0.1*(img0 - HalfWarper.backward_wrapping(img1, flow0to1)).norm(dim=1, keepdim=True)
|
89 |
+
return z0to1, z1to0
|
90 |
+
|
91 |
+
def forward(
|
92 |
+
self,
|
93 |
+
I0: torch.Tensor,
|
94 |
+
I1: torch.Tensor,
|
95 |
+
flow0to1: torch.Tensor,
|
96 |
+
flow1to0: torch.Tensor,
|
97 |
+
z0to1: torch.Tensor | None = None,
|
98 |
+
z1to0: torch.Tensor | None = None,
|
99 |
+
tau: float | None = None,
|
100 |
+
morph_kernel_size: int = 5,
|
101 |
+
mask: bool = True
|
102 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
103 |
+
|
104 |
+
if z1to0 is None or z0to1 is None:
|
105 |
+
z0to1, z1to0 = self.z_metric(I0, I1, flow0to1, flow1to0)
|
106 |
+
|
107 |
+
if tau is not None:
|
108 |
+
flow0tot = tau*flow0to1
|
109 |
+
flow1tot = (1 - tau)*flow1to0
|
110 |
+
else:
|
111 |
+
flow0tot = flow0to1
|
112 |
+
flow1tot = flow1to0
|
113 |
+
|
114 |
+
# image warping
|
115 |
+
fw0to1 = HalfWarper.forward_warpping(I0, flow0tot, mode='softmax', metric=z0to1, mask=True)
|
116 |
+
fw1to0 = HalfWarper.forward_warpping(I1, flow1tot, mode='softmax', metric=z1to0, mask=True)
|
117 |
+
|
118 |
+
wrapped_image0tot = fw0to1[:,:-1]
|
119 |
+
wrapped_image1tot = fw1to0[:,:-1]
|
120 |
+
mask0tot = morph_open(fw0to1[:,-1:], k=morph_kernel_size)
|
121 |
+
mask1tot = morph_open(fw1to0[:,-1:], k=morph_kernel_size)
|
122 |
+
|
123 |
+
base0 = mask0tot*wrapped_image0tot + (1 - mask0tot)*wrapped_image1tot
|
124 |
+
base1 = mask1tot*wrapped_image1tot + (1 - mask1tot)*wrapped_image0tot
|
125 |
+
|
126 |
+
if mask:
|
127 |
+
base0 = torch.cat([base0, mask0tot], dim=1)
|
128 |
+
base1 = torch.cat([base1, mask1tot], dim=1)
|
129 |
+
return base0, base1
|
modules/synthesizer.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from modules.basic_layers import (
|
5 |
+
SinusoidalPositionalEmbedding,
|
6 |
+
ResGatedBlock,
|
7 |
+
MaxViTBlock,
|
8 |
+
Downsample,
|
9 |
+
Upsample
|
10 |
+
)
|
11 |
+
|
12 |
+
class UnetDownBlock(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
in_channels: int,
|
16 |
+
out_channels: int,
|
17 |
+
temb_channels: int = 128,
|
18 |
+
heads: int = 1,
|
19 |
+
window_size: int = 7,
|
20 |
+
window_attn: bool = True,
|
21 |
+
grid_attn: bool = True,
|
22 |
+
expansion_rate: int = 4,
|
23 |
+
num_conv_blocks: int = 2,
|
24 |
+
dropout: float = 0.0
|
25 |
+
):
|
26 |
+
super(UnetDownBlock, self).__init__()
|
27 |
+
self.pool = Downsample(
|
28 |
+
in_channels = in_channels,
|
29 |
+
out_channels = in_channels,
|
30 |
+
use_conv = True
|
31 |
+
)
|
32 |
+
in_channels = 3 * in_channels + 2
|
33 |
+
self.conv = nn.ModuleList([
|
34 |
+
ResGatedBlock(
|
35 |
+
in_channels = in_channels if i == 0 else out_channels,
|
36 |
+
out_channels = out_channels,
|
37 |
+
emb_channels = temb_channels,
|
38 |
+
gated_conv = True
|
39 |
+
) for i in range(num_conv_blocks)
|
40 |
+
])
|
41 |
+
self.maxvit = MaxViTBlock(
|
42 |
+
channels = out_channels,
|
43 |
+
#latent_dim = out_channels // 6,
|
44 |
+
heads = heads,
|
45 |
+
window_size = window_size,
|
46 |
+
window_attn = window_attn,
|
47 |
+
grid_attn = grid_attn,
|
48 |
+
expansion_rate = expansion_rate,
|
49 |
+
dropout = dropout,
|
50 |
+
emb_channels = temb_channels
|
51 |
+
)
|
52 |
+
|
53 |
+
def forward(
|
54 |
+
self,
|
55 |
+
x: torch.Tensor,
|
56 |
+
warp0: torch.Tensor,
|
57 |
+
warp1: torch.Tensor,
|
58 |
+
temb: torch.Tensor
|
59 |
+
):
|
60 |
+
x = self.pool(x)
|
61 |
+
x = torch.cat([x, warp0, warp1], dim=1)
|
62 |
+
for conv in self.conv:
|
63 |
+
x = conv(x, temb)
|
64 |
+
x = self.maxvit(x, temb)
|
65 |
+
return x
|
66 |
+
|
67 |
+
class UnetMiddleBlock(nn.Module):
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
in_channels: int,
|
71 |
+
mid_channels: int,
|
72 |
+
out_channels: int,
|
73 |
+
temb_channels: int = 128,
|
74 |
+
heads: int = 1,
|
75 |
+
window_size: int = 7,
|
76 |
+
window_attn: bool = True,
|
77 |
+
grid_attn: bool = True,
|
78 |
+
expansion_rate: int = 4,
|
79 |
+
dropout: float = 0.0
|
80 |
+
):
|
81 |
+
super(UnetMiddleBlock, self).__init__()
|
82 |
+
|
83 |
+
self.middle_blocks = nn.ModuleList([
|
84 |
+
ResGatedBlock(
|
85 |
+
in_channels = in_channels,
|
86 |
+
out_channels = mid_channels,
|
87 |
+
emb_channels = temb_channels,
|
88 |
+
gated_conv = True
|
89 |
+
),
|
90 |
+
MaxViTBlock(
|
91 |
+
channels = mid_channels,
|
92 |
+
#latent_dim = mid_channels // 6,
|
93 |
+
heads = heads,
|
94 |
+
window_size = window_size,
|
95 |
+
window_attn = window_attn,
|
96 |
+
grid_attn = grid_attn,
|
97 |
+
expansion_rate = expansion_rate,
|
98 |
+
dropout = dropout,
|
99 |
+
emb_channels = temb_channels
|
100 |
+
),
|
101 |
+
ResGatedBlock(
|
102 |
+
in_channels = mid_channels,
|
103 |
+
out_channels = out_channels,
|
104 |
+
emb_channels = temb_channels,
|
105 |
+
gated_conv = True
|
106 |
+
)
|
107 |
+
])
|
108 |
+
|
109 |
+
def forward(self, x, temb):
|
110 |
+
for block in self.middle_blocks:
|
111 |
+
x = block(x, temb)
|
112 |
+
return x
|
113 |
+
|
114 |
+
class UnetUpBlock(nn.Module):
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
in_channels: int,
|
118 |
+
out_channels: int,
|
119 |
+
temb_channels: int = 128,
|
120 |
+
heads: int = 1,
|
121 |
+
window_size: int = 7,
|
122 |
+
window_attn: bool = True,
|
123 |
+
grid_attn: bool = True,
|
124 |
+
expansion_rate: int = 4,
|
125 |
+
num_conv_blocks: int = 2,
|
126 |
+
dropout: float = 0.0
|
127 |
+
):
|
128 |
+
super(UnetUpBlock, self).__init__()
|
129 |
+
in_channels = 2 * in_channels
|
130 |
+
self.maxvit = MaxViTBlock(
|
131 |
+
channels = in_channels,
|
132 |
+
#latent_dim = in_channels // 6,
|
133 |
+
heads = heads,
|
134 |
+
window_size = window_size,
|
135 |
+
window_attn = window_attn,
|
136 |
+
grid_attn = grid_attn,
|
137 |
+
expansion_rate = expansion_rate,
|
138 |
+
dropout = dropout,
|
139 |
+
emb_channels = temb_channels
|
140 |
+
)
|
141 |
+
self.upsample = Upsample(
|
142 |
+
in_channels = in_channels,
|
143 |
+
out_channels = in_channels,
|
144 |
+
use_conv = True
|
145 |
+
)
|
146 |
+
self.conv = nn.ModuleList([
|
147 |
+
ResGatedBlock(
|
148 |
+
in_channels if i == 0 else out_channels,
|
149 |
+
out_channels,
|
150 |
+
emb_channels = temb_channels,
|
151 |
+
gated_conv = True
|
152 |
+
) for i in range(num_conv_blocks)
|
153 |
+
])
|
154 |
+
|
155 |
+
def forward(
|
156 |
+
self,
|
157 |
+
x: torch.Tensor,
|
158 |
+
skip_connection: torch.Tensor,
|
159 |
+
temb: torch.Tensor
|
160 |
+
):
|
161 |
+
x = torch.cat([x, skip_connection], dim=1)
|
162 |
+
x = self.maxvit(x, temb)
|
163 |
+
x = self.upsample(x)
|
164 |
+
for conv in self.conv:
|
165 |
+
x = conv(x, temb)
|
166 |
+
return x
|
167 |
+
|
168 |
+
class Synthesis(nn.Module):
|
169 |
+
def __init__(
|
170 |
+
self,
|
171 |
+
in_channels: int,
|
172 |
+
channels: list[int],
|
173 |
+
temb_channels: int,
|
174 |
+
heads: int = 1,
|
175 |
+
window_size: int = 7,
|
176 |
+
window_attn: bool = True,
|
177 |
+
grid_attn: bool = True,
|
178 |
+
expansion_rate: int = 4,
|
179 |
+
num_conv_blocks: int = 2,
|
180 |
+
dropout: float = 0.0
|
181 |
+
):
|
182 |
+
super(Synthesis, self).__init__()
|
183 |
+
|
184 |
+
|
185 |
+
self.t_pos_encoding = SinusoidalPositionalEmbedding(temb_channels)
|
186 |
+
|
187 |
+
self.input_blocks = nn.ModuleList([
|
188 |
+
nn.Conv2d(3*in_channels + 4, channels[0], kernel_size=3, padding=1),
|
189 |
+
ResGatedBlock(
|
190 |
+
in_channels = channels[0],
|
191 |
+
out_channels = channels[0],
|
192 |
+
emb_channels = temb_channels,
|
193 |
+
gated_conv = True
|
194 |
+
)
|
195 |
+
])
|
196 |
+
|
197 |
+
self.down_blocks = nn.ModuleList([
|
198 |
+
UnetDownBlock(
|
199 |
+
#3 * channels[i] + 2,
|
200 |
+
channels[i],
|
201 |
+
channels[i + 1],
|
202 |
+
temb_channels,
|
203 |
+
heads = heads,
|
204 |
+
window_size = window_size,
|
205 |
+
window_attn = window_attn,
|
206 |
+
grid_attn = grid_attn,
|
207 |
+
expansion_rate = expansion_rate,
|
208 |
+
num_conv_blocks = num_conv_blocks,
|
209 |
+
dropout = dropout,
|
210 |
+
) for i in range(len(channels) - 1)
|
211 |
+
])
|
212 |
+
|
213 |
+
self.middle_block = UnetMiddleBlock(
|
214 |
+
in_channels = channels[-1],
|
215 |
+
mid_channels = channels[-1],
|
216 |
+
out_channels = channels[-1],
|
217 |
+
temb_channels = temb_channels,
|
218 |
+
heads = heads,
|
219 |
+
window_size = window_size,
|
220 |
+
window_attn = window_attn,
|
221 |
+
grid_attn = grid_attn,
|
222 |
+
expansion_rate = expansion_rate,
|
223 |
+
dropout = dropout,
|
224 |
+
)
|
225 |
+
|
226 |
+
self.up_blocks = nn.ModuleList([
|
227 |
+
UnetUpBlock(
|
228 |
+
channels[i + 1],
|
229 |
+
channels[i],
|
230 |
+
temb_channels,
|
231 |
+
heads = heads,
|
232 |
+
window_size = window_size,
|
233 |
+
window_attn = window_attn,
|
234 |
+
grid_attn = grid_attn,
|
235 |
+
expansion_rate = expansion_rate,
|
236 |
+
num_conv_blocks = num_conv_blocks,
|
237 |
+
dropout = dropout,
|
238 |
+
) for i in reversed(range(len(channels) - 1))
|
239 |
+
])
|
240 |
+
|
241 |
+
self.output_blocks = nn.ModuleList([
|
242 |
+
ResGatedBlock(
|
243 |
+
in_channels = channels[0],
|
244 |
+
out_channels = channels[0],
|
245 |
+
emb_channels = temb_channels,
|
246 |
+
gated_conv = True
|
247 |
+
),
|
248 |
+
nn.Conv2d(channels[0], in_channels, kernel_size=3, padding=1)
|
249 |
+
])
|
250 |
+
|
251 |
+
def forward(
|
252 |
+
self,
|
253 |
+
x: torch.Tensor,
|
254 |
+
warp0: list[torch.Tensor],
|
255 |
+
warp1: list[torch.Tensor],
|
256 |
+
temb: torch.Tensor
|
257 |
+
):
|
258 |
+
temb = temb.unsqueeze(-1).type(torch.float)
|
259 |
+
temb = self.t_pos_encoding(temb)
|
260 |
+
|
261 |
+
x = self.input_blocks[0](torch.cat([x, warp0[0], warp1[0]], dim=1))
|
262 |
+
x = self.input_blocks[1](x, temb)
|
263 |
+
|
264 |
+
features = []
|
265 |
+
for i, down_block in enumerate(self.down_blocks):
|
266 |
+
x = down_block(x, warp0[i + 1], warp1[i + 1], temb)
|
267 |
+
features.append(x)
|
268 |
+
|
269 |
+
x = self.middle_block(x, temb)
|
270 |
+
|
271 |
+
for i, up_block in enumerate(self.up_blocks):
|
272 |
+
x = up_block(x, features[-(i + 1)], temb)
|
273 |
+
|
274 |
+
x = self.output_blocks[0](x, temb)
|
275 |
+
x = self.output_blocks[1](x)
|
276 |
+
|
277 |
+
return x
|
requirements.txt
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Main dependencies
|
2 |
+
torch>=2.6.0
|
3 |
+
torchvision>=0.21.0
|
4 |
+
lightning>=2.2.4
|
5 |
+
numpy>=1.26.4
|
6 |
+
matplotlib>=3.8.0
|
7 |
+
pyyaml>=6.0.0
|
8 |
+
|
9 |
+
# Huggingface
|
10 |
+
huggingface-hub>=0.30.2
|
11 |
+
|
12 |
+
# Image processing and computer vision
|
13 |
+
kornia>=0.7.2
|
14 |
+
opencv-python>=4.10.0.84
|
15 |
+
opencv-contrib-python>=4.10.0.84
|
16 |
+
einops>=0.8.0
|
17 |
+
|
18 |
+
# Custom cuda implementation /modules/cupy_module/
|
19 |
+
cupy-cuda12x>=12.0.0 # For CUDA 12.4
|
20 |
+
# Note: For cupy, you need to install the specific version for your CUDA version
|
21 |
+
# Examples:
|
22 |
+
# cupy-cuda11x for CUDA 11.x
|
23 |
+
# cupy-cuda12x for CUDA 12.x
|
24 |
+
# cupy-cuda10x for CUDA 10.x
|
25 |
+
|
26 |
+
# Utilities and tools
|
27 |
+
scipy>=1.7.0
|
28 |
+
tensorboard>=2.8.0
|
29 |
+
|
30 |
+
# Project-Specific Dependencies
|
31 |
+
# RAFT (Flow Estimation)
|
32 |
+
# Note: RAFT is included in the project code; no external installation is required.
|
33 |
+
|
34 |
+
# FLOLPIPS (Quality Metrics)
|
35 |
+
# Note: FLOLPIPS is included in the project code; no external installation is required.
|
36 |
+
|
37 |
+
# Gradio
|
38 |
+
gradio>=4.34.0
|
39 |
+
imageio>=2.34.1
|
40 |
+
imageio-ffmpeg>=0.6.0
|
41 |
+
|
42 |
+
|
utils/ema.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class EMA:
|
5 |
+
def __init__(self, beta: float):
|
6 |
+
super().__init__()
|
7 |
+
self.beta = beta
|
8 |
+
self.step = 0
|
9 |
+
|
10 |
+
def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None:
|
11 |
+
for current_params, ema_model in zip(current_model.parameters(), ema_model.parameters()):
|
12 |
+
old_weight, up_weight = ema_model.data, current_params.data
|
13 |
+
ema_model.data = self.update_average(old_weight, up_weight)
|
14 |
+
|
15 |
+
def update_average(self, old: torch.Tensor | None, new: torch.Tensor) -> torch.Tensor:
|
16 |
+
if old is None:
|
17 |
+
return new
|
18 |
+
return old * self.beta + (1 - self.beta) * new
|
19 |
+
|
20 |
+
def step_ema(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None:
|
21 |
+
if self.step < step_start_ema:
|
22 |
+
self.reset_parameters(ema_model, model)
|
23 |
+
self.step += 1
|
24 |
+
return
|
25 |
+
self.update_model_average(ema_model, model)
|
26 |
+
self.step += 1
|
27 |
+
|
28 |
+
def copy_to(self, ema_model: nn.Module, model: nn.Module) -> None:
|
29 |
+
model.load_state_dict(ema_model.state_dict())
|
30 |
+
|
31 |
+
def reset_parameters(self, ema_model: nn.Module, model: nn.Module) -> None:
|
32 |
+
ema_model.load_state_dict(model.state_dict())
|
utils/inter_frame_idx.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.utils import morph_open
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from kornia.color import rgb_to_grayscale
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
class FlowEstimation:
|
10 |
+
def __init__(self, flow_estimator: str = "farneback"):
|
11 |
+
assert flow_estimator in ["farneback", "dualtvl1"], "Flow estimator must be one of [farneback, dualtvl1]"
|
12 |
+
|
13 |
+
if flow_estimator == "farneback":
|
14 |
+
self.flow_estimator = self.OptFlow_Farneback
|
15 |
+
elif flow_estimator == "dualtvl1":
|
16 |
+
self.flow_estimator = self.OptFlow_DualTVL1
|
17 |
+
else:
|
18 |
+
raise NotImplementedError
|
19 |
+
|
20 |
+
def OptFlow_Farneback(self, I0: torch.Tensor, I1: torch.Tensor) -> torch.Tensor:
|
21 |
+
device = I0.device
|
22 |
+
|
23 |
+
I0 = I0.cpu().clamp(0, 1) * 255
|
24 |
+
I1 = I1.cpu().clamp(0, 1) * 255
|
25 |
+
|
26 |
+
batch_size = I0.shape[0]
|
27 |
+
for i in range(batch_size):
|
28 |
+
I0_np = I0[i].permute(1, 2, 0).numpy().astype(np.uint8)
|
29 |
+
I1_np = I1[i].permute(1, 2, 0).numpy().astype(np.uint8)
|
30 |
+
|
31 |
+
I0_gray = cv2.cvtColor(I0_np, cv2.COLOR_BGR2GRAY)
|
32 |
+
I1_gray = cv2.cvtColor(I1_np, cv2.COLOR_BGR2GRAY)
|
33 |
+
|
34 |
+
flow = cv2.calcOpticalFlowFarneback(I0_gray, I1_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
|
35 |
+
flow = torch.from_numpy(flow).permute(2, 0, 1).unsqueeze(0).float()
|
36 |
+
if i == 0:
|
37 |
+
flows = flow
|
38 |
+
else:
|
39 |
+
flows = torch.cat((flows, flow), dim = 0)
|
40 |
+
|
41 |
+
return flows.to(device)
|
42 |
+
|
43 |
+
def OptFlow_DualTVL1(
|
44 |
+
self,
|
45 |
+
I0: torch.Tensor,
|
46 |
+
I1: torch.Tensor,
|
47 |
+
tau: float = 0.25,
|
48 |
+
lambda_: float = 0.15,
|
49 |
+
theta: float = 0.3,
|
50 |
+
scales_number: int = 5,
|
51 |
+
warps: int = 5,
|
52 |
+
epsilon: float = 0.01,
|
53 |
+
inner_iterations: int = 30,
|
54 |
+
outer_iterations: int = 10,
|
55 |
+
scale_step: float = 0.8,
|
56 |
+
gamma: float = 0.0
|
57 |
+
) -> torch.Tensor:
|
58 |
+
optical_flow = cv2.optflow.createOptFlow_DualTVL1()
|
59 |
+
optical_flow.setTau(tau)
|
60 |
+
optical_flow.setLambda(lambda_)
|
61 |
+
optical_flow.setTheta(theta)
|
62 |
+
optical_flow.setScalesNumber(scales_number)
|
63 |
+
optical_flow.setWarpingsNumber(warps)
|
64 |
+
optical_flow.setEpsilon(epsilon)
|
65 |
+
optical_flow.setInnerIterations(inner_iterations)
|
66 |
+
optical_flow.setOuterIterations(outer_iterations)
|
67 |
+
optical_flow.setScaleStep(scale_step)
|
68 |
+
optical_flow.setGamma(gamma)
|
69 |
+
|
70 |
+
device = I0.device
|
71 |
+
|
72 |
+
I0 = I0.cpu().clamp(0, 1) * 255
|
73 |
+
I1 = I1.cpu().clamp(0, 1) * 255
|
74 |
+
|
75 |
+
batch_size = I0.shape[0]
|
76 |
+
for i in range(batch_size):
|
77 |
+
I0_np = I0[i].permute(1, 2, 0).numpy().astype(np.uint8)
|
78 |
+
I1_np = I1[i].permute(1, 2, 0).numpy().astype(np.uint8)
|
79 |
+
|
80 |
+
I0_gray = cv2.cvtColor(I0_np, cv2.COLOR_BGR2GRAY)
|
81 |
+
I1_gray = cv2.cvtColor(I1_np, cv2.COLOR_BGR2GRAY)
|
82 |
+
|
83 |
+
flow = optical_flow.calc(I0_gray, I1_gray, None)
|
84 |
+
flow = torch.from_numpy(flow).permute(2, 0, 1).unsqueeze(0).float()
|
85 |
+
if i == 0:
|
86 |
+
flows = flow
|
87 |
+
else:
|
88 |
+
flows = torch.cat((flows, flow), dim = 0)
|
89 |
+
|
90 |
+
return flows.to(device)
|
91 |
+
|
92 |
+
def __call__(self, I1: torch.Tensor, I0: torch.Tensor) -> torch.Tensor:
|
93 |
+
return self.flow_estimator(I1, I0)
|
94 |
+
|
95 |
+
def get_inter_frame_temp_index(
|
96 |
+
I0: torch.Tensor,
|
97 |
+
It: torch.Tensor,
|
98 |
+
I1: torch.Tensor,
|
99 |
+
flow0tot: torch.Tensor,
|
100 |
+
flow1tot: torch.Tensor,
|
101 |
+
k: int = 5,
|
102 |
+
threshold: float = 2e-2
|
103 |
+
) -> torch.Tensor:
|
104 |
+
|
105 |
+
I0_gray = rgb_to_grayscale(I0)
|
106 |
+
It_gray = rgb_to_grayscale(It)
|
107 |
+
I1_gray = rgb_to_grayscale(I1)
|
108 |
+
|
109 |
+
mask0tot = morph_open(It_gray - I0_gray, k=k)
|
110 |
+
mask1tot = morph_open(I1_gray - It_gray, k=k)
|
111 |
+
|
112 |
+
mask0tot = (abs(mask0tot) > threshold).to(torch.uint8)
|
113 |
+
mask1tot = (abs(mask1tot) > threshold).to(torch.uint8)
|
114 |
+
|
115 |
+
flow_mag0tot = torch.sqrt(flow0tot[:, 0, :, :]**2 + flow0tot[:, 1, :, :]**2).unsqueeze(1)
|
116 |
+
flow_mag1tot = torch.sqrt(flow1tot[:, 0, :, :]**2 + flow1tot[:, 1, :, :]**2).unsqueeze(1)
|
117 |
+
|
118 |
+
norm0tot = (flow_mag0tot*mask0tot).squeeze(1)
|
119 |
+
norm1tot = (flow_mag1tot*mask1tot).squeeze(1)
|
120 |
+
d0tot = torch.sum(norm0tot, dim = (1, 2))
|
121 |
+
d1tot = torch.sum(norm1tot, dim = (1, 2))
|
122 |
+
|
123 |
+
return d0tot / (d0tot + d1tot + 1e-12)
|
utils/raft.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision.models.optical_flow import raft_large
|
3 |
+
from modules.flow_models.raft.rfr_new import RAFT
|
4 |
+
|
5 |
+
def raft_flow(
|
6 |
+
I0: torch.Tensor,
|
7 |
+
I1: torch.Tensor,
|
8 |
+
data_domain: str = "animation",
|
9 |
+
device: str = 'cuda'
|
10 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
11 |
+
if I0.dtype != torch.float32 or I1.dtype != torch.float32:
|
12 |
+
I0 = I0.to(torch.float32)
|
13 |
+
I1 = I1.to(torch.float32)
|
14 |
+
if data_domain == "animation":
|
15 |
+
raft = RAFT().requires_grad_(False).eval().to(device)
|
16 |
+
elif data_domain == "photorealism":
|
17 |
+
raft = raft_large().requires_grad_(False).eval().to(device)
|
18 |
+
else:
|
19 |
+
raise ValueError("data_domain must be either 'animation' or 'photorealism'")
|
20 |
+
return raft(I0, I1) if data_domain == "animation" else raft(I0, I1)[-1]
|
utils/uncertainty.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import itertools
|
3 |
+
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
|
4 |
+
from utils.utils import denorm
|
5 |
+
|
6 |
+
def compute_lpips_variability(samples: torch.Tensor,
|
7 |
+
net: str = 'alex',
|
8 |
+
device: str = 'cuda'
|
9 |
+
) -> float:
|
10 |
+
loss_fn = LPIPS(net_type=net).to(device)
|
11 |
+
loss_fn.eval()
|
12 |
+
|
13 |
+
if samples.min() >= 0.0:
|
14 |
+
samples = samples * 2 - 1 # Convertir [0, 1] → [-1, 1]
|
15 |
+
|
16 |
+
N = samples.size(0)
|
17 |
+
scores = []
|
18 |
+
for i, j in itertools.combinations(range(N), 2):
|
19 |
+
x = samples[i:i+1].to(device)
|
20 |
+
y = samples[j:j+1].to(device)
|
21 |
+
dist = loss_fn(denorm(x.clamp(-1, 1)), denorm(y.clamp(-1, 1)))
|
22 |
+
scores.append(dist.item())
|
23 |
+
|
24 |
+
return sum(scores) / len(scores)
|
25 |
+
|
26 |
+
def compute_pixelwise_correlation(samples: torch.Tensor) -> float:
|
27 |
+
N, C, H, W = samples.shape
|
28 |
+
samples_flat = samples.view(N, C, -1) # (N, C, H*W)
|
29 |
+
|
30 |
+
corrs = []
|
31 |
+
for i, j in itertools.combinations(range(N), 2):
|
32 |
+
x = samples_flat[i] # (C, HW)
|
33 |
+
y = samples_flat[j] # (C, HW)
|
34 |
+
mean_x = x.mean(dim=1, keepdim=True)
|
35 |
+
mean_y = y.mean(dim=1, keepdim=True)
|
36 |
+
x_centered = x - mean_x
|
37 |
+
y_centered = y - mean_y
|
38 |
+
numerator = (x_centered * y_centered).sum(dim=1)
|
39 |
+
denominator = (x_centered.norm(dim=1) * y_centered.norm(dim=1)) + 1e-8
|
40 |
+
corr = numerator / denominator # (C,)
|
41 |
+
corrs.append(corr.mean().item())
|
42 |
+
return sum(corrs) / len(corrs)
|
43 |
+
|
44 |
+
def compute_dynamic_range(samples: torch.Tensor) -> float:
|
45 |
+
max_vals, _ = samples.max(dim=0) # (C, H, W)
|
46 |
+
min_vals, _ = samples.min(dim=0) # (C, H, W)
|
47 |
+
|
48 |
+
dynamic_range = max_vals - min_vals # (C, H, W)
|
49 |
+
return dynamic_range.mean().item()
|
utils/utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
try:
|
7 |
+
from kornia.morphology import opening
|
8 |
+
except ImportError:
|
9 |
+
from kornia.morphology import open as opening
|
10 |
+
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.utils import make_grid, save_image
|
13 |
+
|
14 |
+
from typing import Any
|
15 |
+
|
16 |
+
def exist(val: Any) -> bool:
|
17 |
+
return val is not None
|
18 |
+
|
19 |
+
def morph_open(x: torch.Tensor, k: int) -> torch.Tensor:
|
20 |
+
if k==0:
|
21 |
+
return x
|
22 |
+
else:
|
23 |
+
with torch.no_grad():
|
24 |
+
return opening(x, torch.ones(k,k,device=x.device))
|
25 |
+
|
26 |
+
def make_grid_images(images: list[torch.Tensor], **kwargs) -> torch.Tensor:
|
27 |
+
concatenated_images = torch.cat(images, dim=3)
|
28 |
+
grid_concatenated = make_grid(concatenated_images, **kwargs)
|
29 |
+
return grid_concatenated
|
30 |
+
|
31 |
+
def save_images(images: tuple[torch.Tensor, torch.Tensor], path: str, **kwargs) -> None:
|
32 |
+
gen, real = images
|
33 |
+
concatenated_images = torch.cat((gen, real), dim=3)
|
34 |
+
grid_concatenated = make_grid(concatenated_images, **kwargs)
|
35 |
+
|
36 |
+
ndarr_concatenated = grid_concatenated.permute(1, 2, 0).to("cpu").numpy()
|
37 |
+
ndarr_concatenated = (ndarr_concatenated * 255).astype(np.uint8)
|
38 |
+
|
39 |
+
save_image(torch.from_numpy(ndarr_concatenated).permute(2, 0, 1) / 255, path)
|
40 |
+
|
41 |
+
def save_triplet(images: tuple[torch.Tensor, ...], path: str, **kwargs) -> None:
|
42 |
+
concatenated_images = torch.cat(images, dim=3)
|
43 |
+
grid_concatenated = make_grid(concatenated_images, **kwargs)
|
44 |
+
|
45 |
+
ndarr_concatenated = grid_concatenated.permute(1, 2, 0).to("cpu").numpy()
|
46 |
+
ndarr_concatenated = (ndarr_concatenated * 255).astype(np.uint8)
|
47 |
+
|
48 |
+
save_image(torch.from_numpy(ndarr_concatenated).permute(2, 0, 1) / 255, path)
|
49 |
+
|
50 |
+
def plot_images(images: torch.Tensor) -> None:
|
51 |
+
plt.figure(figsize=(32, 32))
|
52 |
+
plt.imshow(torch.cat([
|
53 |
+
torch.cat([i for i in images.cpu()], dim=-1),
|
54 |
+
], dim=-2).permute(1, 2, 0).cpu())
|
55 |
+
plt.show()
|
56 |
+
|
57 |
+
def make_graphic(metric_name: str, metrics: list[torch.Tensor], path: str) -> None:
|
58 |
+
plt.figure(figsize=(32, 32))
|
59 |
+
metrics = [m.cpu().numpy() for m in metrics]
|
60 |
+
plt.plot(metrics)
|
61 |
+
plt.title(metric_name)
|
62 |
+
plt.xlabel("Epoch")
|
63 |
+
plt.ylabel(metric_name)
|
64 |
+
path = os.path.join(path, f"{metric_name}.png")
|
65 |
+
plt.savefig(path)
|
66 |
+
plt.close()
|
67 |
+
|
68 |
+
def norm(
|
69 |
+
img: torch.Tensor,
|
70 |
+
mean: list[float] = [0.5, 0.5, 0.5],
|
71 |
+
std: list[float] = [0.5, 0.5, 0.5]
|
72 |
+
) -> torch.Tensor:
|
73 |
+
normalize = transforms.Normalize(mean, std)
|
74 |
+
return normalize(img)
|
75 |
+
|
76 |
+
def denorm(
|
77 |
+
img: torch.Tensor,
|
78 |
+
mean: list[float] = [0.5, 0.5, 0.5],
|
79 |
+
std: list[float] = [0.5, 0.5, 0.5]
|
80 |
+
) -> torch.Tensor:
|
81 |
+
mean = torch.tensor(mean, device=img.device)
|
82 |
+
std = torch.tensor(std, device=img.device)
|
83 |
+
return img*std[None][...,None,None] + mean[None][...,None,None]
|