vfontech committed on
Commit
587665f
·
verified ·
1 Parent(s): 6a0795d

Uploading the app

Browse files
.gitattributes CHANGED
@@ -1,35 +1,37 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ _data/example_images/frame1.png filter=lfs diff=lfs merge=lfs -text
37
+ _data/example_images/frame3.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.png
2
+ *.jpg
3
+ *.jpeg
4
+ *.gif
5
+ *.bmp
6
+ *.tiff
7
+ *.ico
8
+ !_data/example_images/frame1.png
9
+ !_data/example_images/frame3.png
10
+
11
+ __pycache__
12
+
13
+ *.pyc
14
+ *.pyo
15
+ *.pyd
16
+ *.pyw
17
+ *.pyz
18
+
19
+ *.ckpt
20
+ *.pt
21
+ *.pth
22
+ !metrics/flolpips/weights/v0.1/alex.pth
23
+
24
+ *.ipynb
25
+
26
+
README.md CHANGED
@@ -1,14 +1,14 @@
1
- ---
2
- title: Multi Input Res Diffusion VFI
3
- emoji: 🚀
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.25.2
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Gradio demo for Multi-Input ResShift Diffusion VFI
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Multi Input Res Diffusion VFI
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.25.2
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Gradio demo for Multi-Input ResShift Diffusion VFI
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
_data/.gitkeep ADDED
File without changes
_data/example_images/frame1.png ADDED

Git LFS Details

  • SHA256: 6109ca9d74f3bf034fa74fcd744750c704310738e50c345802d539028c31738a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
_data/example_images/frame3.png ADDED

Git LFS Details

  • SHA256: e2d7226ea4642e45a00ea1af1cb8b1e6bd1209deef0a956a360a7eaf848b25dc
  • Pointer size: 132 Bytes
  • Size of remote file: 2.28 MB
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from PIL import Image
4
+ from torchvision.transforms import Compose, ToTensor, Resize, Normalize
5
+ import numpy as np
6
+ import imageio
7
+ import tempfile
8
+
9
+ from utils.utils import denorm
10
+ from model.hub import MultiInputResShiftHub
11
+
12
+ model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI")
13
+ model.requires_grad_(False).cuda().eval()
14
+
15
+ transform = Compose([
16
+ Resize((256, 448)),
17
+ ToTensor(),
18
+ Normalize(mean=[0.5]*3, std=[0.5]*3),
19
+ ])
20
+
21
def to_numpy(img_tensor):
    """Turn a model-space image tensor into a displayable ``uint8`` array.

    The tensor is passed through ``denorm`` with mean/std 0.5 (the same
    statistics the input ``transform`` normalizes with), squeezed, permuted
    from channel-first to channel-last, moved to the CPU, clipped to
    ``[0, 1]`` and finally scaled to ``[0, 255]``.
    """
    restored = denorm(img_tensor, mean=[0.5]*3, std=[0.5]*3)
    channel_last = restored.squeeze().permute(1, 2, 0)
    pixels = np.clip(channel_last.cpu().numpy(), 0, 1)
    scaled = pixels * 255
    return scaled.astype(np.uint8)
25
+
26
def interpolate(img0_pil, img2_pil, tau, num_samples):
    """Gradio callback: interpolate between two frames.

    Args:
        img0_pil: First frame as a PIL image.
        img2_pil: Last frame as a PIL image.
        tau: Interpolation position in ``[0, 1]``; only used when
            ``num_samples`` is 1.
        num_samples: Number of intermediate frames to generate.

    Returns:
        ``(image, None)`` when ``num_samples`` is 1, otherwise
        ``(None, path_to_mp4)`` where the video contains the first frame,
        the generated frames and the last frame.

    Raises:
        gr.Error: If either input image is missing.
    """
    if img0_pil is None or img2_pil is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide both the initial and the final image.")

    # Slider values can arrive as whole-number floats; np.linspace needs an int.
    num_samples = int(num_samples)

    img0 = transform(img0_pil.convert("RGB")).unsqueeze(0).cuda()
    img2 = transform(img2_pil.convert("RGB")).unsqueeze(0).cuda()

    if num_samples == 1:
        # Single image. The slider may deliver 0 or 1 as an int, while the
        # model only converts Python floats into a tau tensor.
        img1 = model.reverse_process([img0, img2], float(tau))
        return Image.fromarray(to_numpy(img1)), None

    # Multiple images -> video
    frames = [to_numpy(img0)]
    for t in np.linspace(0, 1, num_samples):
        img = model.reverse_process([img0, img2], float(t))
        frames.append(to_numpy(img))
    frames.append(to_numpy(img2))

    # Only the file name is needed; close the handle right away so it is not
    # leaked and imageio can reopen the path on every platform.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        temp_path = tmp.name
    imageio.mimsave(temp_path, frames, fps=8)
    return None, temp_path
45
+
46
+ demo = gr.Interface(
47
+ fn=interpolate,
48
+ inputs=[
49
+ gr.Image(type="pil", label="Initial Image (frame1)"),
50
+ gr.Image(type="pil", label="Final Image (frame3)"),
51
+ gr.Slider(0.0, 1.0, step=0.05, value=0.5, label="Tau Value (only if Num Samples = 1)"),
52
+ gr.Slider(1, 15, step=1, value=1, label="Number of Samples"),
53
+ ],
54
+ outputs=[
55
+ gr.Image(label="Interpolated Image (if num_samples = 1)"),
56
+ gr.Video(label="Interpolation in video (if num_samples > 1)"),
57
+ ],
58
+ title="Multi-Input ResShift Diffusion VFI",
59
+ description=(
60
+ "📄 [arXiv Paper](https://arxiv.org/pdf/2504.05402) • "
61
+ "🤗 [Model](https://huggingface.co/vfontech/Multiple-Input-Resshift-VFI) • "
62
+ "🧪 [Colab](https://colab.research.google.com/drive/1MGYycbNMW6Mxu5MUqw_RW_xxiVeHK5Aa#scrollTo=EKaYCioiP3tQ) • "
63
+ "🌐 [GitHub](https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI)\n\n"
64
+ "Video interpolation using Conditional Residual Diffusion.\n"
65
+ "- All images are resized to 256x448.\n"
66
+ "- If `Number of Samples` = 1, generates only one intermediate image with the given Tau value.\n"
67
+ "- If `Number of Samples` > 1, ignores Tau and generates a sequence of interpolated images."
68
+ ),
69
+ examples=[
70
+ ["_data/example_images/frame1.png", "_data/example_images/frame3.png", 0.5],
71
+ ],
72
+ )
73
+
74
+ if __name__ == "__main__":
75
+ demo.queue(max_size=12)
76
+ demo.launch(max_threads=1)
config/confg.yaml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_confg:
2
+ train_batch_size: 6
3
+ val_batch_size: 6
4
+ test_batch_size: 6
5
+ flow_method: raft
6
+ data_domain: animation
7
+ datamodule_confg:
8
+ mean: [0.5, 0.5, 0.5]
9
+ sd: [0.5, 0.5, 0.5]
10
+ size: [256, 448]
11
+ amount_augmentations: 1
12
+ horizontal_flip: 0.5
13
+ time_flip: True
14
+ rotation: 0
15
+ brightness: 0.2
16
+ contrast: 0.2
17
+ saturation: 0.2
18
+ hue: 0.1
19
+
20
+ trainer_confg:
21
+ accumulate_grad_batches: 5
22
+ gradient_clip_val: 1.0
23
+ max_epochs: 500
24
+ num_nodes: 1
25
+ devices: 2
26
+ accelerator: gpu
27
+ strategy: ddp_find_unused_parameters_true
28
+
29
+ optim_confg:
30
+ optimizer_confg: # AdamW
31
+ lr: 1.0e-4
32
+ betas: [0.9, 0.999]
33
+ eps: 1.0e-8
34
+ scheduler_confg: # ReduceLROnPlateau
35
+ mode: min
36
+ factor: 0.5
37
+ patience: 3
38
+ verbose: True
39
+
40
+ pretrained_model_path: null # Fine-tune model path
41
+
42
+ model_confg:
43
+ kappa: 2.0
44
+ timesteps: 20
45
+ p: 0.3
46
+ etas_end: 0.99
47
+ min_noise_level: 0.04
48
+ flow_model: raft
49
+ flow_kwargs:
50
+ pretrained_path: null #_pretrain_models/anime_interp_full.ckpt
51
+ warping_kwargs:
52
+ in_channels: 3
53
+ channels: [128, 256, 384, 512]
54
+ synthesis_kwargs:
55
+ in_channels: 3
56
+ channels: [128, 256, 384, 512]
57
+ temb_channels: 512
58
+ heads: 1
59
+ window_size: 8
60
+ window_attn: True
61
+ grid_attn: True
62
+ expansion_rate: 1.5
63
+ num_conv_blocks: 1
64
+ dropout: 0.0
model/hub.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.model import MultiInputResShift
2
+ from huggingface_hub import PyTorchModelHubMixin
3
+
4
class MultiInputResShiftHub(
    MultiInputResShift,
    PyTorchModelHubMixin,
    repo_url="https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI",
    paper_url="https://arxiv.org/pdf/2504.05402",
    language="en",
):
    # Hub-enabled variant of MultiInputResShift: the only addition is the
    # PyTorchModelHubMixin base, configured with the repository URL, paper URL
    # and language given as class keyword arguments above.
    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged along the MRO.
        super().__init__(*args, **kwargs)
model/model.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.functional import interpolate
4
+
5
+ import math
6
+ from tqdm import tqdm
7
+
8
+ from modules.feature_extactor import Extractor
9
+ from modules.half_warper import HalfWarper
10
+ from modules.cupy_module.nedt import NEDT
11
+ from modules.flow_models.flow_models import (
12
+ RAFTFineFlow,
13
+ PWCFineFlow
14
+ )
15
+ from modules.synthesizer import Synthesis
16
+
17
class FeatureWarper(nn.Module):
    """Warp two frames and their feature pyramids towards an intermediate time.

    Each input image is concatenated with its ``NEDT`` output, passed through
    the feature extractor, and both the augmented image and every extracted
    feature level are warped with ``HalfWarper`` using flows scaled by ``tau``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: list[int] | None = None,
    ):
        """
        Args:
            in_channels: Number of channels of the input images.
            channels: Channel widths handed to the feature extractor;
                defaults to ``[32, 64, 128, 256]``.
        """
        super().__init__()
        if channels is None:
            # Built per instance so no list object is shared between calls.
            channels = [32, 64, 128, 256]
        # forward() concatenates the NEDT output to the image, hence "+ 1"
        # (assumes NEDT yields a single channel — TODO confirm).
        channels = [in_channels + 1] + list(channels)

        self.half_warper = HalfWarper()
        self.feature_extractor = Extractor(channels)
        self.nedt = NEDT()

    def forward(
        self,
        I0: torch.Tensor,
        I1: torch.Tensor,
        flow0to1: torch.Tensor,
        flow1to0: torch.Tensor,
        tau: torch.Tensor | None = None
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Return the warped image/feature lists for both input frames.

        Args:
            I0, I1: The two input frames.
            flow0to1, flow1to0: Flows between the two frames.
            tau: ``(batch, 2)`` tensor; column 0 scales ``flow0to1`` and
                column 1 scales ``flow1to0``. ``None`` uses 0.5 for both.

        Returns:
            ``(warped0, warped1)``: lists whose first entry is the warped
            augmented image followed by one warped tensor per feature level.

        Raises:
            ValueError: If ``tau`` does not have shape ``(batch, 2)``.
        """
        batch = I0.shape[0]
        if tau is None:
            # The signature advertises None as the default, so honor it.
            tau = torch.full((batch, 2), 0.5, device=I0.device, dtype=I0.dtype)
        elif tau.shape != (batch, 2):
            raise ValueError("tau shape must be (batch, 2)")

        flow0tot = tau[:, 0][:, None, None, None] * flow0to1
        flow1tot = tau[:, 1][:, None, None, None] * flow1to0

        I0 = torch.cat([I0, self.nedt(I0)], dim=1)
        I1 = torch.cat([I1, self.nedt(I1)], dim=1)

        # The z metric is computed from the full (unscaled) flows.
        z0to1, z1to0 = HalfWarper.z_metric(I0, I1, flow0to1, flow1to0)
        base0, base1 = self.half_warper(I0, I1, flow0tot, flow1tot, z0to1, z1to0)
        warped0, warped1 = [base0], [base1]

        features0 = self.feature_extractor(I0)
        features1 = self.feature_extractor(I1)

        for feat0, feat1 in zip(features0, features1):
            # Flows and z metrics are resized to this level's resolution.
            # NOTE(review): the resized flow values are not rescaled by the
            # resolution ratio — confirm HalfWarper expects that.
            size = feat0.shape[2:]
            f0 = interpolate(flow0tot, size=size, mode='bilinear', align_corners=False)
            f1 = interpolate(flow1tot, size=size, mode='bilinear', align_corners=False)
            z0 = interpolate(z0to1, size=size, mode='bilinear', align_corners=False)
            z1 = interpolate(z1to0, size=size, mode='bilinear', align_corners=False)
            w0, w1 = self.half_warper(feat0, feat1, f0, f1, z0, z1)
            warped0.append(w0)
            warped1.append(w1)
        return warped0, warped1
62
+
63
+ class MultiInputResShift(nn.Module):
64
+ def __init__(
65
+ self,
66
+ kappa: float=2.0,
67
+ p: float =0.3,
68
+ min_noise_level: float=0.04,
69
+ etas_end: float=0.99,
70
+ timesteps: int=15,
71
+ flow_model: str = 'raft',
72
+ flow_kwargs: dict = {},
73
+ warping_kwargs: dict = {},
74
+ synthesis_kwargs: dict = {}
75
+ ):
76
+ super().__init__()
77
+
78
+ self.timesteps = timesteps
79
+ self.kappa = kappa
80
+ self.eta_partition = None
81
+
82
+ sqrt_eta_1 = min(min_noise_level / kappa, min_noise_level, math.sqrt(0.001))
83
+ b0 = math.exp(1/float(timesteps - 1) * math.log(etas_end/sqrt_eta_1))
84
+ base = torch.ones(timesteps)*b0
85
+ beta = ((torch.linspace(0,1,timesteps))**p)*(timesteps-1)
86
+ sqrt_eta = torch.pow(base, beta) * sqrt_eta_1
87
+
88
+ self.register_buffer("sqrt_sum_eta", sqrt_eta)
89
+ self.register_buffer("sum_eta", sqrt_eta**2)
90
+
91
+ sum_prev_eta = torch.roll(self.sum_eta, 1)
92
+ sum_prev_eta[0] = 0
93
+ self.register_buffer("sum_prev_eta", sum_prev_eta)
94
+
95
+ self.register_buffer("sum_alpha", self.sum_eta - self.sum_prev_eta)
96
+
97
+ self.register_buffer("backward_mean_c1", self.sum_prev_eta / self.sum_eta)
98
+ self.register_buffer("backward_mean_c2", self.sum_alpha / self.sum_eta)
99
+ self.register_buffer("backward_std", self.kappa*torch.sqrt(self.sum_prev_eta*self.sum_alpha/self.sum_eta))
100
+
101
+ if flow_model == 'raft':
102
+ self.flow_model = RAFTFineFlow(**flow_kwargs)
103
+ elif flow_model == 'pwc':
104
+ self.flow_model = PWCFineFlow(**flow_kwargs)
105
+ else:
106
+ raise ValueError(f"Flow model {flow_model} not supported")
107
+
108
+ self.feature_warper = FeatureWarper(**warping_kwargs)
109
+ self.synthesis = Synthesis(**synthesis_kwargs)
110
+
111
+ def forward_process(
112
+ self,
113
+ x: torch.Tensor | None,
114
+ Y: list[torch.Tensor],
115
+ tau: torch.Tensor | float | None,
116
+ t: torch.Tensor | int
117
+ ) -> torch.Tensor:
118
+ if tau is None:
119
+ tau: torch.Tensor = torch.full((x.shape[0], len(Y)), 0.5, device=x.device, dtype=x.dtype)
120
+ elif isinstance(tau, float):
121
+ assert tau >= 0 and tau <= 1, "tau must be between 0 and 1"
122
+ tau: torch.Tensor = torch.cat([
123
+ torch.full((x.shape[0], 1), tau, device=x.device, dtype=x.dtype),
124
+ torch.full((x.shape[0], 1), 1 - tau, device=x.device, dtype=x.dtype)
125
+ ], dim=1)
126
+ if not torch.is_tensor(t):
127
+ t: torch.Tensor = torch.tensor([t], device=x.device, dtype=torch.long)
128
+ if x is None:
129
+ x: torch.Tensor = torch.zeros_like(Y[0])
130
+
131
+ eta = self.sum_eta[t][:, None] * tau
132
+ eta = eta[:, :, None, None, None].transpose(0, 1)
133
+
134
+ e_i = torch.stack([y - x for y in Y])
135
+ mean = x + (eta*e_i).sum(dim=0)
136
+
137
+ sqrt_sum_eta = self.sqrt_sum_eta[t][:, None, None, None]
138
+ std = self.kappa*sqrt_sum_eta
139
+ epsilon = torch.randn_like(x)
140
+
141
+ return mean + std*epsilon
142
+
143
+ @torch.inference_mode()
144
+ def reverse_process(
145
+ self,
146
+ Y: list[torch.Tensor],
147
+ tau: torch.Tensor | float,
148
+ flows: list[torch.Tensor] | None = None,
149
+ ) -> torch.Tensor:
150
+ y = Y[0]
151
+ batch, device, dtype = y.shape[0], y.device, y.dtype
152
+
153
+ if isinstance(tau, float):
154
+ assert tau >= 0 and tau <= 1, "tau must be between 0 and 1"
155
+ tau: torch.Tensor = torch.cat([
156
+ torch.full((batch, 1), tau, device=device, dtype=dtype),
157
+ torch.full((batch, 1), 1 - tau, device=device, dtype=dtype)
158
+ ], dim=1)
159
+ if flows is None:
160
+ flow0to1, flow1to0 = self.flow_model(Y[0], Y[1])
161
+ else:
162
+ flow0to1, flow1to0 = flows
163
+ warp0to1, warp1to0 = self.feature_warper(Y[0], Y[1], flow0to1, flow1to0, tau)
164
+
165
+ T = torch.tensor([self.timesteps-1,] * batch, device=device, dtype=torch.long)
166
+ x = self.forward_process(torch.zeros_like(Y[0]), [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, T)
167
+
168
+ pbar = tqdm(total=self.timesteps, desc="Reversing Process")
169
+ for i in reversed(range(self.timesteps)):
170
+ t = torch.ones(batch, device = device, dtype=torch.long) * i
171
+
172
+ predicted_x0 = self.synthesis(x, warp0to1, warp1to0, t)
173
+
174
+ mean_c1 = self.backward_mean_c1[t][:, None, None, None]
175
+ mean_c2 = self.backward_mean_c2[t][:, None, None, None]
176
+ std = self.backward_std[t][:, None, None, None]
177
+
178
+ eta = self.sum_eta[t][:, None] * tau
179
+ prev_eta = self.sum_prev_eta[t][:, None] * tau
180
+ eta = eta[:, :, None, None, None].transpose(0, 1)
181
+ prev_eta = prev_eta[:, :, None, None, None].transpose(0, 1)
182
+ e_i = torch.stack([y - predicted_x0 for y in Y])
183
+
184
+ mean = (
185
+ mean_c1*(x + (eta*e_i).sum(dim=0))
186
+ + mean_c2*predicted_x0
187
+ - (prev_eta*e_i).sum(dim=0)
188
+ )
189
+
190
+ x = mean + std*torch.randn_like(x)
191
+ pbar.update(1)
192
+ pbar.close()
193
+ return x
194
+
195
+ # Training Step Only
196
+ def forward(
197
+ self,
198
+ I0: torch.Tensor,
199
+ It: torch.Tensor,
200
+ I1: torch.Tensor,
201
+ flow1to0: torch.Tensor | None = None,
202
+ flow0to1: torch.Tensor | None = None,
203
+ tau: torch.Tensor | None = None,
204
+ t: torch.Tensor | None = None
205
+ ) -> torch.Tensor:
206
+
207
+ if tau is None:
208
+ tau = torch.full((It.shape[0], 2), 0.5, device=It.device, dtype=It.dtype)
209
+
210
+ if flow0to1 is None or flow1to0 is None:
211
+ flow0to1, flow1to0 = self.flow_model(I0, I1)
212
+
213
+ if t is None:
214
+ t = torch.randint(low=1, high=self.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long)
215
+
216
+ warp0to1, warp1to0 = self.feature_warper(I0, I1, flow0to1, flow1to0, tau)
217
+ x_t = self.forward_process(It, [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, t)
218
+
219
+ predicted_It = self.synthesis(x_t, warp0to1, warp1to0, t)
220
+ return predicted_It
model/train_pipline.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import matplotlib.pyplot as plt
4
+ from typing import Any
5
+
6
+ import torch
7
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
8
+ from torch.optim import AdamW, Optimizer
9
+ from torch.utils.data import DataLoader
10
+ from lightning import LightningModule
11
+
12
+ from torchmetrics import MetricCollection
13
+ from torchmetrics.image import PeakSignalNoiseRatio as PSNR
14
+ from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
15
+ from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
16
+
17
+ from model.model import MultiInputResShift
18
+
19
+ from utils.utils import denorm, make_grid_images#, save_triplet
20
+ from utils.ema import EMA
21
+ from utils.inter_frame_idx import get_inter_frame_temp_index
22
+ from utils.raft import raft_flow
23
+
24
+
25
class TrainPipline(LightningModule):
    """Lightning module that trains ``MultiInputResShift``.

    It keeps an EMA copy of the model, optimizes a Charbonnier + LPIPS loss,
    logs PSNR/SSIM/LPIPS metrics, and at the end of every training epoch
    saves the EMA weights and logs a sampled prediction grid.
    """

    def __init__(self,
                 confg: dict,
                 test_dataloader: DataLoader):
        """
        Args:
            confg: Configuration dict; reads ``data_confg`` (``mean``, ``sd``),
                ``model_confg`` and ``optim_confg``.
            test_dataloader: Loader whose first batch is visualized at the
                end of each training epoch.
        """
        super().__init__()

        # NOTE(review): this attribute shares its name with Lightning's
        # ``test_dataloader`` hook and shadows it on the instance — confirm
        # ``trainer.test`` is never run with this module.
        self.test_dataloader = test_dataloader

        self.confg = confg

        self.mean, self.sd = confg["data_confg"]["mean"], confg["data_confg"]["sd"]

        self.model = MultiInputResShift(**confg["model_confg"])
        # The flow model is frozen; only the rest of the network is trained.
        self.model.flow_model.requires_grad_(False).eval()

        self.ema = EMA(beta=0.995)
        self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)

        self.lpips_loss = LPIPS(net_type='vgg')

        self.train_metrics = MetricCollection({
            "train_lpips": LPIPS(net_type='alex'),
            "train_psnr": PSNR(),
            "train_ssim": SSIM()
        })
        self.val_metrics = MetricCollection({
            "val_lpips": LPIPS(net_type='alex'),
            "val_psnr": PSNR(),
            "val_ssim": SSIM()
        })

    @staticmethod
    def charbonnier_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Mean of ``sqrt((x - y)**2 + 1e-6)``.

        Defined as a method rather than a lambda attribute so the module
        stays picklable; ``self.charbonnier_loss(x, y)`` works as before.
        """
        return torch.mean(torch.sqrt((x - y)**2 + 1e-6))

    def loss_fn(self,
                x: torch.Tensor,
                predicted_x: torch.Tensor) -> torch.Tensor:
        """Return ``0.2 * LPIPS + Charbonnier`` between target and prediction."""
        # Only the LPIPS input is clamped; the pixel loss sees raw values.
        percep_loss = 0.2 * self.lpips_loss(x, predicted_x.clamp(-1, 1))
        pix2pix_loss = self.charbonnier_loss(x, predicted_x)
        return percep_loss + pix2pix_loss

    def sample_t(self,
                 shape: tuple[int, ...],
                 max_t: int,
                 device: torch.device) -> torch.Tensor:
        """Draw ``shape[0]`` indices in ``[0, max_t)`` with quadratic weights.

        Index ``i`` is drawn with probability proportional to ``(i + 1)**2``,
        so larger indices are favored.
        """
        p = torch.linspace(1, max_t, steps=max_t, device=device) ** 2
        p = p / p.sum()
        t = torch.multinomial(p, num_samples=shape[0], replacement=True)
        return t

    def _estimate_tau(self,
                      I0: torch.Tensor,
                      It: torch.Tensor,
                      I1: torch.Tensor) -> torch.Tensor:
        """Build the ``(batch, 2)`` tau tensor ``[mid_idx, 1 - mid_idx]``.

        ``mid_idx`` comes from ``get_inter_frame_temp_index`` applied to the
        frames and the RAFT flows from each outer frame to ``It``.
        """
        flow0tot = raft_flow(I0, It, 'animation')
        flow1tot = raft_flow(I1, It, 'animation')
        mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype)
        return torch.stack([mid_idx, 1 - mid_idx], dim=1)

    def forward(self,
                I0: torch.Tensor,
                It: torch.Tensor,
                I1: torch.Tensor) -> torch.Tensor:
        """Predict ``It`` at a sampled timestep with an estimated tau."""
        tau = self._estimate_tau(I0, It, I1)

        # Up to and including epoch 5 timesteps come from the quadratically
        # weighted sampler; afterwards they are uniform in [1, timesteps).
        if self.current_epoch > 5:
            t = torch.randint(low=1, high=self.model.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long)
        else:
            t = self.sample_t(shape=(It.shape[0],), max_t=self.model.timesteps, device=It.device)

        predicted_It = self.model(I0, It, I1, tau=tau, t=t)
        return predicted_It

    def get_step_plt_images(self,
                            It: torch.Tensor,
                            predicted_It: torch.Tensor) -> plt.Figure:
        """Return a closed figure showing the first prediction next to its target."""
        fig, ax = plt.subplots(1, 2, figsize=(20, 10))
        # detach(): numpy() refuses tensors that still require grad, which
        # would otherwise make this depend on the caller disabling autograd.
        predicted = denorm(predicted_It.detach().clamp(-1, 1), self.mean, self.sd)[0]
        target = denorm(It.detach(), self.mean, self.sd)[0]
        ax[0].imshow(predicted.permute(1, 2, 0).cpu().numpy())
        ax[0].axis("off")
        ax[0].set_title("Predicted")
        ax[1].imshow(target.permute(1, 2, 0).cpu().numpy())
        ax[1].axis("off")
        ax[1].set_title("Ground Truth")
        plt.tight_layout()
        # Detach the figure from pyplot's registry so figures do not
        # accumulate; the Figure object itself is still returned.
        plt.close(fig)
        return fig

    def training_step(self, batch: tuple[torch.Tensor, ...], _) -> torch.Tensor:
        """Compute the loss, update the EMA model and log figures and metrics."""
        I0, It, I1 = batch
        predicted_It = self(I0, It, I1)
        loss = self.loss_fn(It, predicted_It)

        self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)

        self.ema.step_ema(self.ema_model, self.model)
        with torch.inference_mode():
            fig = self.get_step_plt_images(It, predicted_It)
            self.logger.experiment.add_figure("Train Predictions", fig, self.global_step)
            mets = self.train_metrics(It, predicted_It.clamp(-1, 1))
            self.log_dict(mets, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    @torch.no_grad()
    def validation_step(self, batch: tuple[torch.Tensor, ...], _) -> None:
        """Log the validation loss and metrics for one batch."""
        I0, It, I1 = batch
        predicted_It = self(I0, It, I1)
        loss = self.loss_fn(It, predicted_It)

        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

        mets = self.val_metrics(It, predicted_It.clamp(-1, 1))
        self.log_dict(mets, prog_bar=True, on_step=False, on_epoch=True)

    @torch.inference_mode()
    def on_train_epoch_end(self) -> None:
        """Save the EMA weights and log a sampled prediction grid."""
        # Write from one process only and make sure the directory exists;
        # otherwise every rank writes the same file and a missing folder
        # aborts the run.
        if self.trainer.is_global_zero:
            os.makedirs("_checkpoint", exist_ok=True)
            torch.save(self.ema_model.state_dict(),
                       os.path.join("_checkpoint", f"resshift_diff_{self.current_epoch}.pth"))

        batch = next(iter(self.test_dataloader))
        I0, It, I1 = batch
        I0, It, I1 = I0.to(self.device), It.to(self.device), I1.to(self.device)

        tau = self._estimate_tau(I0, It, I1)

        predicted_It = self.ema_model.reverse_process([I0, I1], tau)

        I0 = denorm(I0, self.mean, self.sd)
        I1 = denorm(I1, self.mean, self.sd)
        It = denorm(It, self.mean, self.sd)
        predicted_It = denorm(predicted_It.clamp(-1, 1), self.mean, self.sd)

        grid = make_grid_images([I0, It, predicted_It, I1], nrow=1)
        self.logger.experiment.add_image("Predicted Images", grid, self.global_step)

    def configure_optimizers(self) -> tuple[list[Optimizer], list[dict[str, Any]]]:
        """Create AdamW plus a ReduceLROnPlateau scheduler monitoring ``val_loss``."""
        optimizer = [AdamW(
            self.model.parameters(),
            **self.confg["optim_confg"]['optimizer_confg']
        )]

        scheduler = [{
            'scheduler': ReduceLROnPlateau(
                optimizer[0],
                **self.confg["optim_confg"]['scheduler_confg']
            ),
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1,
            'strict': True,
        }]

        return optimizer, scheduler
177
+
modules/basic_layers.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+
8
class GroupNorm(nn.Module):
    """``nn.GroupNorm`` with ``eps=1e-6`` and learnable affine parameters."""

    def __init__(self, in_channels: int, num_groups: int = 32):
        super().__init__()
        self.gn = nn.GroupNorm(
            num_groups,
            in_channels,
            eps=1e-6,
            affine=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply group normalization to ``x``."""
        normalized = self.gn(x)
        return normalized
15
+
16
+ class AdaLayerNorm(nn.Module):
17
+ def __init__(self, channels: int, cond_channels: int = 0, return_scale_shift: bool = True):
18
+ super(AdaLayerNorm, self).__init__()
19
+ self.norm = nn.LayerNorm(channels)
20
+ self.return_scale_shift = return_scale_shift
21
+ if cond_channels != 0:
22
+ if return_scale_shift:
23
+ self.proj = nn.Linear(cond_channels, channels * 3, bias=False)
24
+ else:
25
+ self.proj = nn.Linear(cond_channels, channels * 2, bias=False)
26
+ nn.init.xavier_uniform_(self.proj.weight)
27
+
28
+ def expand_dims(self, tensor: torch.Tensor, dims: list[int]) -> torch.Tensor:
29
+ for dim in dims:
30
+ tensor = tensor.unsqueeze(dim)
31
+ return tensor
32
+
33
+ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor:
34
+ x = self.norm(x)
35
+ if cond is None:
36
+ return x
37
+ dims = list(range(1, len(x.shape) - 1))
38
+ if self.return_scale_shift:
39
+ gamma, beta, sigma = self.proj(cond).chunk(3, dim=-1)
40
+ gamma, beta, sigma = [self.expand_dims(t, dims) for t in (gamma, beta, sigma)]
41
+ return x * (1 + gamma) + beta, sigma
42
+ else:
43
+ gamma, beta = self.proj(cond).chunk(2, dim=-1)
44
+ gamma, beta = [self.expand_dims(t, dims) for t in (gamma, beta)]
45
+ return x * (1 + gamma) + beta
46
+
47
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal embedding of scalar positions/timesteps.

    Uses frequencies ``1 / 10000**(2i / emb_dim)``; the output concatenates
    all sines followed by all cosines along the last dimension.
    """

    def __init__(self, emb_dim: int = 256):
        super().__init__()
        self.channels = emb_dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t``.

        Args:
            t: One value per sample, shaped ``(batch, 1)`` or ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, emb_dim)`` for even ``emb_dim``.
        """
        exponents = torch.arange(0, self.channels, 2, device=t.device).float() / self.channels
        inv_freq = 1.0 / (10000 ** exponents)
        # Broadcasting a (batch, 1) column against the frequency vector gives
        # the same values as repeating t, and also accepts a flat (batch,) t.
        angles = t.reshape(-1, 1) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
61
+
62
class GatedConv2d(nn.Module):
    """Convolution whose SiLU-activated features are scaled by a sigmoid gate.

    The gate comes from a separate 1x1 convolution over the same input.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 bias: bool = False):
        super().__init__()
        # The gate is always a 1x1 convolution with nn.Conv2d's default bias;
        # `bias` only applies to the feature convolution.
        self.gate_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.feature_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(gate_conv(x)) * silu(feature_conv(x))``."""
        mask = self.gate_conv(x).sigmoid()
        activated = F.silu(self.feature_conv(x))
        return mask * activated
81
+
82
class ResGatedBlock(nn.Module):
    """Two-convolution block with GroupNorm, optional embedding and residual.

    ``double_conv`` runs conv -> norm -> SiLU -> (+ embedding) -> conv ->
    norm. With ``residual=True`` the input (through a 1x1 skip convolution
    when the channel counts differ) is added before a final SiLU.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 mid_channels: int | None = None,
                 num_groups: int = 32,
                 residual: bool = True,
                 emb_channels: int | None = None,
                 gated_conv: bool = False):
        """
        Args:
            in_channels: Input channels.
            out_channels: Output channels.
            mid_channels: Channels between the two convolutions; defaults to
                ``out_channels``.
            num_groups: Group count for both GroupNorm layers.
            residual: Whether ``forward`` adds the skip connection.
            emb_channels: Size of the optional embedding; ``None`` or 0
                builds no embedding projection.
            gated_conv: Use ``GatedConv2d`` instead of ``nn.Conv2d``.
        """
        super().__init__()
        self.residual = residual
        self.emb_channels = emb_channels
        if not mid_channels:
            mid_channels = out_channels

        conv2d = GatedConv2d if gated_conv else nn.Conv2d

        self.conv1 = conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        self.norm1 = GroupNorm(mid_channels, num_groups=num_groups)
        self.nonlienrity = nn.SiLU()
        if emb_channels:
            self.emb_proj = nn.Linear(emb_channels, mid_channels)
        self.conv2 = conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.norm2 = GroupNorm(out_channels, num_groups=num_groups)

        if in_channels != out_channels:
            self.skip = conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def double_conv(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """Run the two conv/norm stages, adding the projected ``emb`` in between."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.nonlienrity(x)
        # Test for the projection itself: it is only built for a truthy
        # emb_channels, so an `is not None` check would crash for 0.
        if emb is not None and hasattr(self, 'emb_proj'):
            x = x + self.emb_proj(emb)[:,:,None,None]
        x = self.conv2(x)
        return self.norm2(x)

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the block to ``x`` with the optional embedding ``emb``."""
        if not self.residual:
            return self.double_conv(x, emb)
        # The skip convolution only exists when channel counts differ.
        shortcut = self.skip(x) if hasattr(self, 'skip') else x
        return F.silu(shortcut + self.double_conv(x, emb))
127
+
128
class Downsample(nn.Module):
    """Downsample by 2 with a strided 3x3 convolution or 2x2 average pooling."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 use_conv: bool=True):
        """
        Args:
            in_channels: Input channels.
            out_channels: Output channels; must equal ``in_channels`` when
                pooling is used.
            use_conv: Strided convolution if true, average pooling otherwise.
        """
        super().__init__()

        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
        else:
            assert in_channels == out_channels
            self.conv = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` at half the spatial resolution."""
        if not self.use_conv:
            # Pooling uses the raw input, so no padded copy is built.
            return self.conv(x)
        # The convolution has padding=0; one zero row/column on the right and
        # bottom lets the stride-2 3x3 kernel halve even-sized inputs.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
146
+
147
class Upsample(nn.Module):
    """Double the last two dimensions by nearest-neighbour interpolation.

    Args:
        in_channels: Input channels of the optional smoothing convolution.
        out_channels: Output channels of the optional smoothing convolution.
        use_conv: When True, a 3x3 convolution (padding 1) is applied after
            interpolation. When False, no convolution is created and the two
            channel arguments are unused.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 use_conv: bool = True):
        super().__init__()

        self.use_conv = use_conv
        if self.use_conv:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and, if configured, convolve the result."""
        # 4-D input scales both trailing dims; any other rank gets a
        # three-entry factor that leaves the first of them untouched.
        if x.dim() == 4:
            factors = (2, 2)
        else:
            factors = (1, 2, 2)
        upsampled = F.interpolate(x, scale_factor=factors, mode='nearest')
        if not self.use_conv:
            return upsampled
        return self.conv(upsampled)
163
+
164
class FeedForward(nn.Module):
    """Embedding-conditioned two-layer MLP applied on the last dimension.

    Args:
        dim: Size of the feature dimension going in and out.
        emb_channels: Embedding size handed to ``AdaLayerNorm``.
        expansion_rate: Hidden width is ``int(dim * expansion_rate)``.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self,
                 dim: int,
                 emb_channels: int,
                 expansion_rate: int = 4,
                 dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * expansion_rate)
        self.norm = AdaLayerNorm(dim, emb_channels)
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)
        self.__init_weights()

    def __init_weights(self):
        """Xavier-initialise the weights of both linear layers (indices 0 and 3)."""
        for index in (0, 3):
            nn.init.xavier_uniform_(self.net[index].weight)

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """Normalise ``x`` with the embedding, run the MLP and scale the result.

        ``self.norm`` returns a pair; the second item multiplies the MLP
        output (presumably a learned gate -- see ``AdaLayerNorm``).
        """
        normed, sigma = self.norm(x, emb)
        return self.net(normed) * sigma
189
+
190
class Attention(nn.Module):
    """Multi-head self-attention inside windows, with a learned relative position bias.

    Args:
        dim: Feature size of the tokens; must be divisible by ``dim_head``.
        emb_channels: Embedding size handed to ``AdaLayerNorm``.
        dim_head: Feature size per attention head.
        dropout: Dropout probability on the attention weights and the output.
        window_size: Side length used to build the relative-position index
            table. The bias is added to the attention logits, so the incoming
            windows presumably need this same side length -- confirm with the
            caller's rearrange pattern.
    """

    def __init__(
        self,
        dim: int,
        emb_channels: int = 512,
        dim_head: int = 32,
        dropout: float = 0.,
        window_size: int = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5
        self.norm = AdaLayerNorm(dim, emb_channels)

        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
        self.to_out = nn.Sequential(nn.Linear(dim, dim, bias=False), nn.Dropout(dropout))

        # One bias vector (per head) for every possible 2-D offset between
        # two positions of a window: offsets span 2 * window_size - 1 values
        # per axis.
        span = 2 * window_size - 1
        self.rel_pos_bias = nn.Embedding(span ** 2, self.heads)

        axis = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid(axis, axis, indexing='ij'))
        coords = rearrange(coords, 'c i j -> (i j) c')
        # Pairwise offsets, shifted to be non-negative, then flattened into a
        # single embedding index as row_offset * span + col_offset.
        offsets = coords[:, None, :] - coords[None, :, :] + (window_size - 1)
        flat_indices = (offsets * torch.tensor([span, 1])).sum(dim=-1)

        self.register_buffer('rel_pos_indices', flat_indices, persistent=False)

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """Attend within each window of a 6-D ``(b, x, y, w1, w2, d)`` tensor.

        Returns a tensor of the same layout, multiplied by the second value
        returned by ``self.norm``.
        """
        # Unpacking also enforces that the input is exactly 6-dimensional.
        _, height, width, window_height, window_width, _ = x.shape
        heads = self.heads

        x, sigma = self.norm(x, emb)
        tokens = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        # Project and split heads.
        q, k, v = (
            rearrange(proj(tokens), 'b n (h d) -> b h n d', h=heads)
            for proj in (self.to_q, self.to_k, self.to_v)
        )

        sim = torch.einsum('b h i d, b h j d -> b h i j', q * self.scale, k)
        bias = self.rel_pos_bias(self.rel_pos_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')
        attn = self.attend(sim)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        # Merge heads and restore the window layout.
        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1=window_height, w2=window_width)
        out = self.to_out(out)
        out = rearrange(out, '(b x y) ... -> b x y ...', x=height, y=width)
        return out * sigma
250
+
251
class MaxViTBlock(nn.Module):
    """Window attention followed by grid attention, each with its own feed-forward.

    Args:
        channels: Channel count of the incoming ``(b, d, H, W)`` feature map.
        emb_channels: Embedding size forwarded to the attention and
            feed-forward sub-modules.
        heads: Number of attention heads; the per-head size is
            ``channels // heads``.
        window_size: Side length of the attention windows / grid. The
            rearrange patterns split H and W by this value.
        window_attn: Enable the block-local (window) attention stage.
        grid_attn: Enable the dilated (grid) attention stage.
        expansion_rate: Hidden-width multiplier of the feed-forward layers.
        dropout: Dropout probability forwarded to the sub-modules.
    """

    def __init__(
        self,
        channels: int,
        emb_channels: int = 512,
        heads: int = 1,
        window_size: int = 8,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        dim_head = channels // heads
        layer_dim = dim_head * heads

        attn_kwargs = dict(
            dim=layer_dim,
            emb_channels=emb_channels,
            dim_head=dim_head,
            dropout=dropout,
            window_size=window_size,
        )
        ff_kwargs = dict(
            dim=layer_dim,
            emb_channels=emb_channels,
            expansion_rate=expansion_rate,
            dropout=dropout,
        )

        self.window_attn = window_attn
        # NOTE: ``grid_attn`` holds the boolean flag only while the stage is
        # disabled; when enabled the attribute is rebound below to the
        # Attention module itself, and ``forward`` relies on the module being
        # truthy. The attribute name is kept because it is also the
        # sub-module's key.
        self.grid_attn = grid_attn

        if window_attn:
            # Block-like attention: neighbouring pixels share a window.
            self.wind_rearrange_forward = Rearrange(
                'b d (x w1) (y w2) -> b x y w1 w2 d', w1=window_size, w2=window_size)
            self.wind_attn = Attention(**attn_kwargs)
            self.wind_ff = FeedForward(**ff_kwargs)
            self.wind_rearrange_backward = Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)')

        if grid_attn:
            # Grid-like attention: each window gathers strided pixels.
            self.grid_rearrange_forward = Rearrange(
                'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1=window_size, w2=window_size)
            self.grid_attn = Attention(**attn_kwargs)
            self.grid_ff = FeedForward(**ff_kwargs)
            self.grid_rearrange_backward = Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)')

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """Run the enabled attention stages with residual connections."""
        if self.window_attn:
            windows = self.wind_rearrange_forward(x)
            windows = windows + self.wind_attn(windows, emb=emb)
            windows = windows + self.wind_ff(windows, emb=emb)
            x = self.wind_rearrange_backward(windows)
        if self.grid_attn:
            cells = self.grid_rearrange_forward(x)
            cells = cells + self.grid_attn(cells, emb=emb)
            cells = cells + self.grid_ff(cells, emb=emb)
            x = self.grid_rearrange_backward(cells)
        return x
modules/cupy_module/correlation.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import cupy
4
+ import os
5
+ import re
6
+ import torch
7
+
8
+ # Code taken from https://github.com/sniklaus/softmax-splatting/blob/master/correlation/correlation.py
9
+
10
# CUDA source for ``kernel_Correlation_rearrange``.
# Each thread copies one spatial element of ``input`` (indexed as
# sample/channel/row-major position; sample = blockIdx.z, channel =
# blockIdx.y) into ``output`` with the channel as the fastest-varying index
# and the spatial position shifted by 4 in both y and x; the row width of
# ``output`` is taken as SIZE_3(input) + 8.
# ``SIZE_k(name)`` tokens are placeholders replaced with concrete tensor
# sizes by ``cupy_kernel`` before compilation.
kernel_Correlation_rearrange = '''
    extern "C" __global__ void kernel_Correlation_rearrange(
        const int n,
        const float* input,
        float* output
    ) {
        int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;

        if (intIndex >= n) {
            return;
        }

        int intSample = blockIdx.z;
        int intChannel = blockIdx.y;

        float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];

        __syncthreads();

        int intPaddedY = (intIndex / SIZE_3(input)) + 4;
        int intPaddedX = (intIndex % SIZE_3(input)) + 4;
        int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;

        output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
    }
'''
36
+
37
# CUDA source for ``kernel_Correlation_updateOutput`` (forward correlation).
# One thread block handles one output position: x = blockIdx.x,
# y = blockIdx.y, sample = blockIdx.z. Threads stride over the channels in
# steps of 32, so the kernel expects to be launched with 32 threads per block.
# The channel vector of ``rbot0`` at the (+4 shifted) position is first cached
# in dynamically sized shared memory. Then, for every output channel, the
# displacement (top_channel % 9 - 4, top_channel / 9 - 4) selects a position
# in ``rbot1``; the per-thread partial dot products are summed by thread 0 and
# the result divided by the channel count SIZE_3(rbot0) is written to ``top``.
# ``SIZE_k(name)`` tokens are placeholders replaced by ``cupy_kernel``.
kernel_Correlation_updateOutput = '''
    extern "C" __global__ void kernel_Correlation_updateOutput(
        const int n,
        const float* rbot0,
        const float* rbot1,
        float* top
    ) {
        extern __shared__ char patch_data_char[];

        float *patch_data = (float *)patch_data_char;

        // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
        int x1 = blockIdx.x + 4;
        int y1 = blockIdx.y + 4;
        int item = blockIdx.z;
        int ch_off = threadIdx.x;

        // Load 3D patch into shared shared memory
        for (int j = 0; j < 1; j++) { // HEIGHT
            for (int i = 0; i < 1; i++) { // WIDTH
                int ji_off = (j + i) * SIZE_3(rbot0);
                for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
                    int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
                    int idxPatchData = ji_off + ch;
                    patch_data[idxPatchData] = rbot0[idx1];
                }
            }
        }

        __syncthreads();

        __shared__ float sum[32];

        // Compute correlation
        for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
            sum[ch_off] = 0;

            int s2o = top_channel % 9 - 4;
            int s2p = top_channel / 9 - 4;

            for (int j = 0; j < 1; j++) { // HEIGHT
                for (int i = 0; i < 1; i++) { // WIDTH
                    int ji_off = (j + i) * SIZE_3(rbot0);
                    for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
                        int x2 = x1 + s2o;
                        int y2 = y1 + s2p;

                        int idxPatchData = ji_off + ch;
                        int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;

                        sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
                    }
                }
            }

            __syncthreads();

            if (ch_off == 0) {
                float total_sum = 0;
                for (int idx = 0; idx < 32; idx++) {
                    total_sum += sum[idx];
                }
                const int sumelems = SIZE_3(rbot0);
                const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
                top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
            }
        }
    }
'''
106
+
107
# CUDA source for ``kernel_Correlation_updateGradOne`` (gradient w.r.t. the
# first input), processing the single sample selected by ``intSample``.
# Threads loop over the flat index range [0, n) with a stride of
# blockDim.x * gridDim.x. Each index is decoded into a channel and a
# (+4 shifted) x/y position; for all 9 x 9 displacements in [-4, 4] the
# matching ``rbot1`` value is multiplied with the corresponding ``gradOutput``
# entry, and the accumulated sum divided by SIZE_1(gradOne) is stored in
# ``gradOne``. The ``gradTwo`` parameter is not referenced by this kernel.
# ``SIZE_k(name)`` tokens are placeholders replaced by ``cupy_kernel``.
kernel_Correlation_updateGradOne = '''
    #define ROUND_OFF 50000

    extern "C" __global__ void kernel_Correlation_updateGradOne(
        const int n,
        const int intSample,
        const float* rbot0,
        const float* rbot1,
        const float* gradOutput,
        float* gradOne,
        float* gradTwo
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        int n = intIndex % SIZE_1(gradOne); // channels
        int l = (intIndex / SIZE_1(gradOne)) % SIZE_3(gradOne) + 4; // w-pos
        int m = (intIndex / SIZE_1(gradOne) / SIZE_3(gradOne)) % SIZE_2(gradOne) + 4; // h-pos

        // round_off is a trick to enable integer division with ceil, even for negative numbers
        // We use a large offset, for the inner part not to become negative.
        const int round_off = ROUND_OFF;
        const int round_off_s1 = round_off;

        // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
        int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
        int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)

        // Same here:
        int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
        int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)

        float sum = 0;
        if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
            xmin = max(0,xmin);
            xmax = min(SIZE_3(gradOutput)-1,xmax);

            ymin = max(0,ymin);
            ymax = min(SIZE_2(gradOutput)-1,ymax);

            for (int p = -4; p <= 4; p++) {
                for (int o = -4; o <= 4; o++) {
                    // Get rbot1 data:
                    int s2o = o;
                    int s2p = p;
                    int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
                    float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]

                    // Index offset for gradOutput in following loops:
                    int op = (p+4) * 9 + (o+4); // index[o,p]
                    int idxopoffset = (intSample * SIZE_1(gradOutput) + op);

                    for (int y = ymin; y <= ymax; y++) {
                        for (int x = xmin; x <= xmax; x++) {
                            int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
                            sum += gradOutput[idxgradOutput] * bot1tmp;
                        }
                    }
                }
            }
        }
        const int sumelems = SIZE_1(gradOne);
        const int bot0index = ((n * SIZE_2(gradOne)) + (m-4)) * SIZE_3(gradOne) + (l-4);
        gradOne[bot0index + intSample*SIZE_1(gradOne)*SIZE_2(gradOne)*SIZE_3(gradOne)] = sum / (float)sumelems;
    } }
'''
170
+
171
# CUDA source for ``kernel_Correlation_updateGradTwo`` (gradient w.r.t. the
# second input), processing the single sample selected by ``intSample``.
# Threads loop over the flat index range [0, n) with a stride of
# blockDim.x * gridDim.x. For every displacement (o, p) in [-4, 4]^2 the
# kernel looks up the ``gradOutput`` position shifted by (-o, -p), skips it
# when it falls outside ``gradOutput``, and otherwise accumulates
# gradOutput * rbot0 read at the oppositely displaced position. The sum
# divided by SIZE_1(gradTwo) is stored in ``gradTwo``. The ``gradOne``
# parameter is not referenced by this kernel.
# ``SIZE_k(name)`` tokens are placeholders replaced by ``cupy_kernel``.
kernel_Correlation_updateGradTwo = '''
    #define ROUND_OFF 50000

    extern "C" __global__ void kernel_Correlation_updateGradTwo(
        const int n,
        const int intSample,
        const float* rbot0,
        const float* rbot1,
        const float* gradOutput,
        float* gradOne,
        float* gradTwo
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        int n = intIndex % SIZE_1(gradTwo); // channels
        int l = (intIndex / SIZE_1(gradTwo)) % SIZE_3(gradTwo) + 4; // w-pos
        int m = (intIndex / SIZE_1(gradTwo) / SIZE_3(gradTwo)) % SIZE_2(gradTwo) + 4; // h-pos

        // round_off is a trick to enable integer division with ceil, even for negative numbers
        // We use a large offset, for the inner part not to become negative.
        const int round_off = ROUND_OFF;
        const int round_off_s1 = round_off;

        float sum = 0;
        for (int p = -4; p <= 4; p++) {
            for (int o = -4; o <= 4; o++) {
                int s2o = o;
                int s2p = p;

                //Get X,Y ranges and clamp
                // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
                int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
                int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)

                // Same here:
                int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
                int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)

                if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
                    xmin = max(0,xmin);
                    xmax = min(SIZE_3(gradOutput)-1,xmax);

                    ymin = max(0,ymin);
                    ymax = min(SIZE_2(gradOutput)-1,ymax);

                    // Get rbot0 data:
                    int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
                    float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]

                    // Index offset for gradOutput in following loops:
                    int op = (p+4) * 9 + (o+4); // index[o,p]
                    int idxopoffset = (intSample * SIZE_1(gradOutput) + op);

                    for (int y = ymin; y <= ymax; y++) {
                        for (int x = xmin; x <= xmax; x++) {
                            int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
                            sum += gradOutput[idxgradOutput] * bot0tmp;
                        }
                    }
                }
            }
        }
        const int sumelems = SIZE_1(gradTwo);
        const int bot1index = ((n * SIZE_2(gradTwo)) + (m-4)) * SIZE_3(gradTwo) + (l-4);
        gradTwo[bot1index + intSample*SIZE_1(gradTwo)*SIZE_2(gradTwo)*SIZE_3(gradTwo)] = sum / (float)sumelems;
    } }
'''
236
+
237
def cupy_kernel(strFunction, objVariables):
    """Specialise a CUDA kernel template for a concrete set of tensors.

    The template is the module-level string named ``strFunction``. Two kinds
    of placeholders are expanded textually:

    * ``SIZE_k(name)`` becomes the size of dimension ``k`` of
      ``objVariables[name]``.
    * ``VALUE_k(name, i0, ..., i{k-1})`` becomes a flat element access
      ``name[((i0)*stride0)+...]`` built from the tensor's strides. Curly
      braces inside an index expression stand in for parentheses, because the
      pattern stops at the first ``)``.

    Args:
        strFunction: Name of the module-level kernel template.
        objVariables: Mapping from the tensor names used in the template to
            objects providing ``size()`` and ``stride()``.

    Returns:
        The kernel source with all placeholders substituted.

    Raises:
        KeyError: If the template or a referenced tensor name is unknown.
    """
    def _as_int(objValue):
        # size()/stride() entries may be plain ints or tensors; unwrap the latter.
        return objValue.item() if torch.is_tensor(objValue) else objValue

    strKernel = globals()[strFunction]

    while True:
        # Raw strings: "\(" is an invalid escape sequence in a normal literal.
        objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArg = int(objMatch.group(2))

        strTensor = objMatch.group(4)
        intSizes = objVariables[strTensor].size()

        strKernel = strKernel.replace(objMatch.group(), str(_as_int(intSizes[intArg])))
    # end

    while True:
        objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')

        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strIndex = [
            '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(_as_int(intStrides[intArg])) + ')'
            for intArg in range(intArgs)
        ]

        strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + '+'.join(strIndex) + ']')
    # end

    return strKernel
# end
273
+
274
@cupy.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    """Build a ``cupy.RawKernel`` for ``strFunction`` from the source ``strKernel``.

    ``CUDA_HOME`` is filled in from CuPy's detected CUDA path when the
    environment does not define it, and both that directory and its
    ``include`` sub-directory are passed to the compiler as ``-I`` options.
    Results are memoized per device by the decorator.
    """
    strCudaHome = os.environ.get('CUDA_HOME')

    if strCudaHome is None:
        strCudaHome = cupy.cuda.get_cuda_path()
        os.environ['CUDA_HOME'] = strCudaHome
    # end

    objOptions = ('-I ' + strCudaHome, '-I ' + strCudaHome + '/include')

    return cupy.RawKernel(strKernel, strFunction, objOptions)
# end
282
+
283
class _FunctionCorrelation(torch.autograd.Function):
    """Autograd function computing an 81-channel (9 x 9 displacement) correlation volume on the GPU.

    Both static methods receive the autograd context as their first argument,
    which this code names ``self``.
    """

    @staticmethod
    def forward(self, one, two):
        """Correlate ``one`` with ``two``.

        Both inputs are indexed as 4-D tensors and must live on a CUDA
        device. The kernels read them through ``float*`` pointers, so float32
        data is presumably required -- confirm at the call sites.

        Returns a tensor of shape ``(one.shape[0], 81, one.shape[2], one.shape[3])``.
        """
        # Channels-last working copies with room for a 4-pixel border on every
        # side; they start zero-filled. Both are sized from ``one``.
        rbot0 = one.new_zeros([ one.shape[0], one.shape[2] + 8, one.shape[3] + 8, one.shape[1] ])
        rbot1 = one.new_zeros([ one.shape[0], one.shape[2] + 8, one.shape[3] + 8, one.shape[1] ])

        one = one.contiguous(); assert(one.is_cuda == True)
        two = two.contiguous(); assert(two.is_cuda == True)

        output = one.new_zeros([ one.shape[0], 81, one.shape[2], one.shape[3] ])

        if one.is_cuda == True:
            # Fill rbot0 from ``one``: 16 threads per block over the spatial
            # positions, one grid row per channel, one grid slice per sample.
            n = one.shape[2] * one.shape[3]
            cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
                'input': one,
                'output': rbot0
            }))(
                grid=tuple([ int((n + 16 - 1) / 16), one.shape[1], one.shape[0] ]),
                block=tuple([ 16, 1, 1 ]),
                args=[ cupy.int32(n), one.data_ptr(), rbot0.data_ptr() ]
            )

            # Same for ``two`` into rbot1.
            n = two.shape[2] * two.shape[3]
            cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
                'input': two,
                'output': rbot1
            }))(
                grid=tuple([ int((n + 16 - 1) / 16), two.shape[1], two.shape[0] ]),
                block=tuple([ 16, 1, 1 ]),
                args=[ cupy.int32(n), two.data_ptr(), rbot1.data_ptr() ]
            )

            # One block of 32 threads per output position and sample; the
            # dynamic shared memory holds one float (4 bytes) per channel.
            n = output.shape[1] * output.shape[2] * output.shape[3]
            cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
                'rbot0': rbot0,
                'rbot1': rbot1,
                'top': output
            }))(
                grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
                block=tuple([ 32, 1, 1 ]),
                shared_mem=one.shape[1] * 4,
                args=[ cupy.int32(n), rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
            )

        elif one.is_cuda == False:
            # Only reachable when the asserts above are disabled (python -O).
            raise NotImplementedError()

        # end

        self.save_for_backward(one, two, rbot0, rbot1)

        return output
    # end

    @staticmethod
    def backward(self, gradOutput):
        """Return the gradients w.r.t. ``one`` and ``two`` (``None`` for inputs that need none)."""
        one, two, rbot0, rbot1 = self.saved_tensors

        gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)

        # Both gradient buffers take their shape from ``one``.
        gradOne = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[0] == True else None
        gradTwo = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[1] == True else None

        if one.is_cuda == True:
            if gradOne is not None:
                # The gradient kernels handle one sample per launch.
                for intSample in range(one.shape[0]):
                    n = one.shape[1] * one.shape[2] * one.shape[3]
                    cupy_launch('kernel_Correlation_updateGradOne', cupy_kernel('kernel_Correlation_updateGradOne', {
                        'rbot0': rbot0,
                        'rbot1': rbot1,
                        'gradOutput': gradOutput,
                        'gradOne': gradOne,
                        'gradTwo': None
                    }))(
                        grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
                        block=tuple([ 512, 1, 1 ]),
                        args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradOne.data_ptr(), None ]
                    )
                # end
            # end

            if gradTwo is not None:
                for intSample in range(one.shape[0]):
                    n = one.shape[1] * one.shape[2] * one.shape[3]
                    cupy_launch('kernel_Correlation_updateGradTwo', cupy_kernel('kernel_Correlation_updateGradTwo', {
                        'rbot0': rbot0,
                        'rbot1': rbot1,
                        'gradOutput': gradOutput,
                        'gradOne': None,
                        'gradTwo': gradTwo
                    }))(
                        grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
                        block=tuple([ 512, 1, 1 ]),
                        args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradTwo.data_ptr() ]
                    )
                # end
            # end

        elif one.is_cuda == False:
            raise NotImplementedError()

        # end

        return gradOne, gradTwo
    # end
# end
389
+
390
def FunctionCorrelation(tenOne, tenTwo):
    """Functional entry point: apply ``_FunctionCorrelation`` to ``tenOne`` and ``tenTwo``."""
    return _FunctionCorrelation.apply(tenOne, tenTwo)
# end
393
+
394
class ModuleCorrelation(torch.nn.Module):
    """Parameter-free ``nn.Module`` wrapper around ``_FunctionCorrelation``."""

    def __init__(self):
        super().__init__()
    # end

    def forward(self, tenOne, tenTwo):
        """Return ``_FunctionCorrelation.apply(tenOne, tenTwo)``."""
        return _FunctionCorrelation.apply(tenOne, tenTwo)
    # end
# end
modules/cupy_module/cupy_utils.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import cupy
2
+
3
+ #@cupy.memoize(for_each_device=True)
4
def cupy_launch(strFunction, strKernel):
    """Return a ``cupy.RawKernel`` for the function ``strFunction`` defined in the CUDA source ``strKernel``.

    A new ``RawKernel`` object is created on every call; the memoizing
    decorator above this function is commented out.
    """
    # NOTE: an earlier variant went through cupy.cuda.compile_with_cache(strKernel).get_function(strFunction).
    return cupy.RawKernel(strKernel, strFunction)
# end
modules/cupy_module/nedt.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cupy
3
+ import kornia
4
+ import torch.nn as nn
5
+
6
+ from modules.cupy_module.cupy_utils import cupy_launch
7
+ # Code taken from https://github.com/ShuhongChen/eisai-anime-interpolator
8
+
9
# (function name, CUDA source) pair, unpacked into ``cupy_launch``.
# ``kernel_dt`` performs one 1-D pass of a squared distance transform: for
# every element (pb, pi, pj) of a bs x h x w array it scans the w entries of
# the same row and stores min_j(data[pb, pi, j] + (pj - j)^2), with the
# result capped at ``diam2``. One thread handles one output element.
_batch_edt_kernel = ('kernel_dt', '''
    extern "C" __global__ void kernel_dt(
        const int bs,
        const int h,
        const int w,
        const float diam2,
        float* data,
        float* output
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx >= bs*h*w) {
            return;
        }
        int pb = idx / (h*w);
        int pi = (idx - h*w*pb) / w;
        int pj = (idx - h*w*pb - w*pi);

        float cost;
        float mincost = diam2;
        for (int j = 0; j < w; j++) {
            cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j);
            if (cost < mincost) {
                mincost = cost;
            }
        }
        output[idx] = mincost;
        return;
    }
''')
38
+
39
class NEDT(nn.Module):
    """Distance-to-edge map: difference-of-Gaussians edges followed by a Euclidean distance transform."""

    def __init__(self):
        super().__init__()

    def batch_edt(self, img, block=1024):
        """Euclidean distance transform of a batch of masks, computed in two 1-D GPU passes.

        Args:
            img: ``(bs, h, w)`` or ``(bs, 1, h, w)`` tensor. Entries equal to 1
                get cost 0 and entries equal to 0 get the maximal cost, so
                the result is presumably the distance to the nearest 1-valued
                pixel. The kernel receives raw device pointers, so a CUDA
                tensor is presumably required -- confirm with callers.
            block: Threads per CUDA block.

        Returns:
            Distances with the same rank and dtype as ``img``.
        """
        # Must initialize cuda/cupy after forking, hence the lookup per call.
        _batch_edt = cupy_launch(*_batch_edt_kernel)

        # Bookkeeping: drop a singleton channel axis and remember to restore it.
        expand = len(img.shape) == 4
        if expand:
            assert img.shape[1]==1
            img = img.squeeze(1)
        bs, h, w = img.shape
        diam2 = h**2 + w**2
        odtype = img.dtype
        grid = (img.nelement() + block - 1) // block

        def run_pass(rows, cols, src):
            # One kernel launch: minimise along the last axis of ``src``.
            dst = torch.zeros_like(src)
            _batch_edt(
                grid=(grid, 1, 1),
                block=(block, 1, 1),
                args=[
                    cupy.int32(bs),
                    cupy.int32(rows),
                    cupy.int32(cols),
                    cupy.float32(diam2),
                    src.data_ptr(),
                    dst.data_ptr(),
                ],
            )
            return dst

        # First pass: scan along the width (last axis) of the cost image.
        data = ((1 - img.type(torch.float32)) * diam2).contiguous()
        intermed = run_pass(h, w, data)

        # Second pass: transpose so the height becomes the last axis, then scan again.
        intermed = intermed.permute(0, 2, 1).contiguous()
        out = run_pass(w, h, intermed)

        ans = out.permute(0, 2, 1).sqrt()
        if odtype != ans.dtype:
            ans = ans.type(odtype)

        if expand:
            ans = ans.unsqueeze(1)
        return ans

    def batch_dog(self, img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
        """Difference of Gaussians of a 4-D image batch.

        Inputs with 3 or 4 channels are converted to grayscale from their
        first three channels; any other input must have one channel. Returns
        ``0.5 + t * (blur(sigma * k) - blur(sigma)) - epsilon``, clipped to
        [0, 1] when ``clip`` is set.
        """
        _, ch, _, _ = img.shape
        if ch in [3,4]:
            img = kornia.color.rgb_to_grayscale(img[:,:3])
        else:
            assert ch==1

        def blur(std):
            # Odd kernel size proportional to the standard deviation, at least 3.
            size = max(2*int(std*kernel_factor)+1, 3)
            return kornia.filters.gaussian_blur2d(
                img, (size,size), (std,std), border_type='replicate',
            )

        g0 = blur(sigma)
        g1 = blur(sigma*k)
        out = 0.5 + t*(g1 - g0) - epsilon
        if clip:
            out = out.clip(0,1)
        return out

    def forward(
        self, img, t=2.0, sigma_factor=1/540,
        k=1.6, epsilon=0.01,
        kernel_factor=4, exp_factor=540/15
    ):
        """Return ``1 - exp(-edt * exp_factor / max(h, w))`` of the thresholded DoG response.

        The blur sigma scales with the image height (``img.shape[-2] *
        sigma_factor``); pixels whose unclipped DoG value exceeds 0.5 form
        the mask passed to ``batch_edt``.
        """
        dog = self.batch_dog(
            img, t=t, sigma=img.shape[-2]*sigma_factor, k=k,
            epsilon=epsilon, kernel_factor=kernel_factor, clip=False,
        )
        mask = (dog > 0.5).float()
        edt = self.batch_edt(mask)
        return 1 - (-edt*exp_factor / max(edt.shape[-2:])).exp()
modules/cupy_module/softsplat.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+ import cupy
4
+
5
+ from modules.cupy_module.cupy_utils import cupy_launch
6
+
7
+ # Code from https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py
8
+
9
# CUDA source for ``kernel_Softsplat_updateOutput`` (forward splatting).
# Threads loop over the flat index range [0, n) with a stride of
# blockDim.x * gridDim.x. Each index is decoded into (n, c, y, x) of
# ``output``; the source pixel is displaced by the two ``flow`` channels
# (channel 0 added to x, channel 1 added to y) and its value is distributed
# with bilinear weights onto the four integer neighbours of the target
# position via atomicAdd. Neighbours outside ``output`` are skipped.
# ``SIZE_k``, ``VALUE_k`` and ``OFFSET_k`` are placeholders, presumably
# expanded by this module's kernel preprocessing before compilation (that
# code is not visible here -- confirm).
kernel_Softsplat_updateOutput = '''
    extern "C" __global__ void kernel_Softsplat_updateOutput(
        const int n,
        const float* input,
        const float* flow,
        float* output
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
        const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output);
        const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output);
        const int intX = ( intIndex ) % SIZE_3(output);

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
        float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
        float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
        float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));

        if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest);
        }

        if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast);
        }

        if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest);
        }

        if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast);
        }
    } }
'''
55
+
56
# CUDA source for ``kernel_Softsplat_updateGradInput`` (gradient w.r.t. the
# splatted input). Threads loop over the flat index range [0, n) with a
# stride of blockDim.x * gridDim.x. For each (n, c, y, x) of ``gradInput`` the
# kernel recomputes the flow-displaced target position and gathers
# ``gradOutput`` from its four integer neighbours with the same bilinear
# weights used in the forward pass, skipping neighbours outside
# ``gradOutput``. The ``gradFlow`` parameter is not referenced by this kernel.
# ``SIZE_k`` and ``VALUE_k`` are placeholders, presumably expanded by this
# module's kernel preprocessing before compilation (not visible here --
# confirm).
kernel_Softsplat_updateGradInput = '''
    extern "C" __global__ void kernel_Softsplat_updateGradInput(
        const int n,
        const float* input,
        const float* flow,
        const float* gradOutput,
        float* gradInput,
        float* gradFlow
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput);
        const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput);
        const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput);
        const int intX = ( intIndex ) % SIZE_3(gradInput);

        float fltGradInput = 0.0;

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
        float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
        float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
        float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));

        if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
        }

        if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
        }

        if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
        }

        if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
        }

        gradInput[intIndex] = fltGradInput;
    } }
'''
108
+
109
# CUDA template that accumulates the gradient of the splatted output with respect
# to the flow field. One element of gradFlow is produced per loop iteration; channel 0
# uses the x-derivative of the four bilinear weights and channel 1 the y-derivative.
# The SIZE_n(...) / VALUE_n(...) placeholders are rewritten by cupy_kernel() before
# the source is compiled, so the text below must keep those macros intact.
kernel_Softsplat_updateGradFlow = '''
    extern "C" __global__ void kernel_Softsplat_updateGradFlow(
        const int n,
        const float* input,
        const float* flow,
        const float* gradOutput,
        float* gradInput,
        float* gradFlow
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        float fltGradFlow = 0.0;

        const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow);
        const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow)                     ) % SIZE_1(gradFlow);
        const int intY = ( intIndex / SIZE_3(gradFlow)                                        ) % SIZE_2(gradFlow);
        const int intX = ( intIndex                                                           ) % SIZE_3(gradFlow);

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = 0.0;
        float fltNortheast = 0.0;
        float fltSouthwest = 0.0;
        float fltSoutheast = 0.0;

        if (intC == 0) {
            fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY);
            fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY);
            fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY));
            fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY));

        } else if (intC == 1) {
            fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0));
            fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0));
            fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0));
            fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0));

        }

        for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) {
            float fltInput = VALUE_4(input, intN, intChannel, intY, intX);

            if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest;
            }

            if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast;
            }

            if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest;
            }

            if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast;
            }
        }

        gradFlow[intIndex] = fltGradFlow;
    } }
'''
179
+
180
def cupy_kernel(strFunction, objVariables):
    """Expand the size/index macros of a CUDA kernel template.

    The template is looked up by name among this module's globals and the
    following placeholders are replaced textually:

    * ``SIZE_d(tensor)``            -> the literal extent of dimension ``d``
    * ``OFFSET_k(tensor, i0, ...)`` -> ``(((i0)*stride0)+...)`` over ``k`` indices
    * ``VALUE_k(tensor, i0, ...)``  -> ``tensor[((i0)*stride0)+...]``

    Args:
        strFunction: name of the module-level string holding the kernel source.
        objVariables: mapping from the tensor names used in the macros to
            objects exposing ``size()`` and ``stride()``.

    Returns:
        The kernel source with every macro substituted.

    Raises:
        KeyError: if the template name or a referenced tensor name is unknown.
    """
    strKernel = globals()[strFunction]

    # Raw strings: '\(' and '\)' are regex escapes, not Python string escapes.
    while True:
        objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArg = int(objMatch.group(2))

        strTensor = objMatch.group(4)
        intSizes = objVariables[strTensor].size()

        strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
    # end

    # OFFSET_ is fully expanded before VALUE_, exactly as two consecutive loops would do.
    for strMacro in ('OFFSET_', 'VALUE_'):
        while True:
            objMatch = re.search('(' + strMacro + r')([0-4])(\()([^\)]+)(\))', strKernel)

            if objMatch is None:
                break
            # end

            intArgs = int(objMatch.group(2))
            strArgs = objMatch.group(4).split(',')

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            # Curly braces let a template nest parentheses inside an index expression.
            strIndex = [
                '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')'
                for intArg in range(intArgs)
            ]
            strSum = str.join('+', strIndex)

            if strMacro == 'OFFSET_':
                strReplacement = '(' + strSum + ')'
            else:
                strReplacement = strTensor + '[' + strSum + ']'
            # end

            strKernel = strKernel.replace(objMatch.group(0), strReplacement)
        # end
    # end

    return strKernel
# end
234
+
235
class _FunctionSoftsplat(torch.autograd.Function):
    """Autograd function that runs the softsplat CUDA kernels.

    ``forward`` launches ``kernel_Softsplat_updateOutput`` over every output
    element; ``backward`` launches ``kernel_Softsplat_updateGradInput`` and
    ``kernel_Softsplat_updateGradFlow`` for whichever inputs need gradients.
    Only CUDA tensors are supported.
    """

    @staticmethod
    def forward(self, input, flow):
        """Splat ``input`` along ``flow`` and return a tensor shaped like ``input``.

        ``self`` is the autograd context object; ``flow`` must have two
        channels and the same height and width as ``input``.
        """
        intSamples, intInputDepth, intInputHeight, intInputWidth = [ input.shape[intDim] for intDim in range(4) ]
        intFlowDepth, intFlowHeight, intFlowWidth = [ flow.shape[intDim] for intDim in range(1, 4) ]

        assert(intFlowDepth == 2)
        assert(intInputHeight == intFlowHeight)
        assert(intInputWidth == intFlowWidth)

        input = input.contiguous()
        assert(input.is_cuda == True)
        flow = flow.contiguous()
        assert(flow.is_cuda == True)

        output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])

        if input.is_cuda:
            intThreads = 512
            n = output.nelement()
            strSource = cupy_kernel('kernel_Softsplat_updateOutput', {
                'input': input,
                'flow': flow,
                'output': output
            })
            objLaunch = cupy_launch('kernel_Softsplat_updateOutput', strSource)
            objLaunch(
                grid=tuple([ int((n + intThreads - 1) / intThreads), 1, 1 ]),
                block=tuple([ intThreads, 1, 1 ]),
                args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), output.data_ptr() ]
            )

        else:
            raise NotImplementedError()

        # end

        self.save_for_backward(input, flow)

        return output
    # end

    @staticmethod
    def backward(self, gradOutput):
        """Return ``(gradInput, gradFlow)``; an entry is ``None`` when not required."""
        input, flow = self.saved_tensors

        intSamples, intInputDepth, intInputHeight, intInputWidth = [ input.shape[intDim] for intDim in range(4) ]
        intFlowDepth, intFlowHeight, intFlowWidth = [ flow.shape[intDim] for intDim in range(1, 4) ]

        assert(intFlowDepth == 2)
        assert(intInputHeight == intFlowHeight)
        assert(intInputWidth == intFlowWidth)

        gradOutput = gradOutput.contiguous()
        assert(gradOutput.is_cuda == True)

        gradInput = None
        if self.needs_input_grad[0] == True:
            gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])
        # end

        gradFlow = None
        if self.needs_input_grad[1] == True:
            gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ])
        # end

        if input.is_cuda:
            intThreads = 512

            # Both kernels receive the same name -> tensor table for macro expansion.
            objTensors = {
                'input': input,
                'flow': flow,
                'gradOutput': gradOutput,
                'gradInput': gradInput,
                'gradFlow': gradFlow
            }

            if gradInput is not None:
                n = gradInput.nelement()
                objLaunch = cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', objTensors))
                objLaunch(
                    grid=tuple([ int((n + intThreads - 1) / intThreads), 1, 1 ]),
                    block=tuple([ intThreads, 1, 1 ]),
                    args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ]
                )
            # end

            if gradFlow is not None:
                n = gradFlow.nelement()
                objLaunch = cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', objTensors))
                objLaunch(
                    grid=tuple([ int((n + intThreads - 1) / intThreads), 1, 1 ]),
                    block=tuple([ intThreads, 1, 1 ]),
                    args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ]
                )
            # end

        else:
            raise NotImplementedError()

        # end

        return gradInput, gradFlow
    # end
# end
329
+
330
def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
    """Forward-warp ``tenInput`` along ``tenFlow`` using the chosen splatting mode.

    Modes (``strType``):
        * ``'summation'``: splat the input unchanged.
        * ``'average'``: splat an extra channel of ones and divide by it.
        * ``'linear'``: weight the input by ``tenMetric`` and divide by the
          splatted metric.
        * ``'softmax'``: weight the input by ``exp(tenMetric)`` and divide by
          the splatted weights.

    For every mode except ``'summation'`` the last splatted channel is the
    normaliser; entries equal to zero are replaced by one before the division.

    Args:
        tenInput: tensor to splat (indexed as a 4-D tensor).
        tenFlow: flow field passed to ``_FunctionSoftsplat``.
        tenMetric: single-channel weighting tensor, or ``None`` for the modes
            that do not use it.
        strType: one of ``'summation'``, ``'average'``, ``'linear'``, ``'softmax'``.

    Returns:
        The splatted (and, where applicable, normalised) tensor.

    Raises:
        ValueError: if ``strType`` is ``'linear'`` or ``'softmax'`` and
            ``tenMetric`` is ``None``.
    """
    assert(tenMetric is None or tenMetric.shape[1] == 1)
    assert(strType in ['summation', 'average', 'linear', 'softmax'])

    # Fail early with a clear message instead of an opaque error deep in the tensor ops.
    if strType in ('linear', 'softmax') and tenMetric is None:
        raise ValueError('tenMetric is required when strType is \'' + strType + '\'')
    # end

    if strType == 'average':
        tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1)

    elif strType == 'linear':
        tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1)

    elif strType == 'softmax':
        # Evaluate the exponential once and reuse it for both the weighted input and the normaliser.
        tenWeight = tenMetric.exp()
        tenInput = torch.cat([ tenInput * tenWeight, tenWeight ], 1)

    # end

    tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow)

    if strType != 'summation':
        tenNormalize = tenOutput[:, -1:, :, :]

        # Pixels that received no contribution would otherwise divide by zero.
        tenNormalize[tenNormalize == 0.0] = 1.0

        tenOutput = tenOutput[:, :-1, :, :] / tenNormalize
    # end

    return tenOutput
# end
357
+
358
class ModuleSoftsplat(torch.nn.Module):
    """``torch.nn.Module`` wrapper that calls ``FunctionSoftsplat`` with a fixed mode."""

    def __init__(self, strType):
        """Remember the splatting mode that every forward call will use."""
        super().__init__()

        self.strType = strType
    # end

    def forward(self, tenInput, tenFlow, tenMetric):
        """Splat ``tenInput`` along ``tenFlow`` using the stored mode."""
        strMode = self.strType
        return FunctionSoftsplat(tenInput, tenFlow, tenMetric, strMode)
    # end
# end
modules/feature_extactor.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import torchvision.models as models
5
+
6
+ from modules.basic_layers import GroupNorm
7
+
8
class Extractor(nn.Module):
    """Convolutional feature pyramid that downsamples by a stride-2 convolution per stage.

    Each stage is conv(stride 2) -> GroupNorm -> SiLU -> conv -> GroupNorm -> SiLU.
    With ``use_residual`` enabled, a parallel stride-2 convolution of the stage
    input is added to the stage output.
    """

    def __init__(self, channels: list[int], num_groups: int = 32, use_residual: bool = True):
        """Build one stage per consecutive pair of entries in ``channels``.

        Args:
            channels: channel counts, input first; ``len(channels) - 1`` stages are created.
            num_groups: group count forwarded to every ``GroupNorm``.
            use_residual: whether to create and use the strided shortcut convolutions.
        """
        super().__init__()

        self.use_residual = use_residual

        stage_io = list(zip(channels[:-1], channels[1:]))

        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1),
                GroupNorm(c_out, num_groups = num_groups),
                nn.SiLU(),
                nn.Conv2d(in_channels=c_out, out_channels=c_out, kernel_size=3, stride=1, padding=1),
                GroupNorm(c_out, num_groups = num_groups),
                nn.SiLU()
            ) for c_in, c_out in stage_io
        ])
        if self.use_residual:
            self.residual = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1),
                ) for c_in, c_out in stage_io
            ])

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Return the output of every stage, first stage first."""
        # self.residual only exists when use_residual is set, so it must not be
        # touched otherwise (doing so raised AttributeError before).
        shortcuts = self.residual if self.use_residual else [None] * len(self.layers)

        features = []
        for shortcut, layer in zip(shortcuts, self.layers):
            if shortcut is None:
                x = layer(x)
            else:
                x = layer(x) + shortcut(x)
            features.append(x)
        return features
40
+
41
+
42
class ResNetExtractor(nn.Module):
    """Multi-scale feature extractor built from a torchvision ResNet-18."""

    def __init__(self, pretrained: bool = True, layers_to_extract: list[str] = ["layer1", "layer2", "layer3"]):
        """Assemble the stem and the first three residual stages of ResNet-18.

        Args:
            pretrained: forwarded to ``models.resnet18``.
            layers_to_extract: names of the stages whose outputs ``forward`` returns.
        """
        super().__init__()

        resnet = models.resnet18(pretrained=pretrained)

        # NOTE: only conv1/bn1/relu are used for the stem; resnet.maxpool is not included.
        self.initial_layers = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)

        self.layers = nn.ModuleDict({
            "layer1": resnet.layer1,
            "layer2": resnet.layer2,
            "layer3": resnet.layer3,
        })

        # Store a private copy: the default argument is a single shared list object,
        # so keeping a reference to it would let one instance's edits leak into others.
        self.layers_to_extract = list(layers_to_extract)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run the stem and every stage in order, collecting the selected stage outputs."""
        x = self.initial_layers(x)

        features = []
        for name, layer in self.layers.items():
            x = layer(x)
            if name in self.layers_to_extract:
                features.append(x)

        return features
72
+
73
class VGGExtractor(nn.Module):
    """Feature extractor that taps selected layers of a pretrained VGG-16 feature stack."""

    def __init__(self, layers_to_extract: list[int] = [8, 15, 22, 29]):
        """Load ``vgg16(pretrained=True).features`` and record which indices to tap.

        Args:
            layers_to_extract: indices into the VGG feature stack whose outputs
                ``forward`` returns.
        """
        super().__init__()

        self.vgg = models.vgg16(pretrained=True).features
        # Private copy so the shared default list is never aliased by an instance.
        self.layers_to_extract = list(layers_to_extract)
        self.selected_layers = [self.vgg[i] for i in self.layers_to_extract]

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Return the activations of the requested layers, in network order."""
        features = []
        if not self.layers_to_extract:
            return features

        # Nothing past the deepest requested layer contributes to the result.
        last_needed = max(self.layers_to_extract)

        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in self.layers_to_extract:
                features.append(x)
            if i >= last_needed:
                break
        return features
modules/flow_models/flow_models.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.functional import interpolate
4
+
5
+ from modules.cupy_module import correlation
6
+ from modules.half_warper import HalfWarper
7
+ from modules.feature_extactor import Extractor
8
+
9
+ from modules.flow_models.raft.rfr_new import RAFT
10
+
11
+ class Decoder(nn.Module):
12
+ def __init__(self, in_channels: int):
13
+ super().__init__()
14
+
15
+ self.syntesis = nn.Sequential(
16
+ nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, stride=1, padding=1),
17
+ nn.SiLU(),
18
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
19
+ nn.SiLU(),
20
+ nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=1),
21
+ nn.SiLU(),
22
+ nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=1),
23
+ nn.SiLU(),
24
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
25
+ nn.SiLU(),
26
+ nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
27
+ )
28
+
29
+ def forward(self, img1: torch.Tensor, img2: torch.Tensor, residual: torch.Tensor | None) -> torch.Tensor:
30
+ width = img1.shape[3] and img2.shape[3]
31
+ height = img1.shape[2] and img2.shape[2]
32
+
33
+ if residual is None:
34
+ corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=img2)
35
+ main = torch.cat([img1, corr], dim=1)
36
+ else:
37
+ flow = interpolate(input=residual,
38
+ size=(height, width),
39
+ mode='bilinear',
40
+ align_corners=False) / \
41
+ float(residual.shape[3]) * float(width)
42
+ backwarp_img = HalfWarper.backward_wrapping(img=img2, flow=flow)
43
+ corr = correlation.FunctionCorrelation(tenOne=img1, tenTwo=backwarp_img)
44
+ main = torch.cat([img1, corr, flow], dim=1)
45
+
46
+ return self.syntesis(main)
47
+
48
class PWCFineFlow(nn.Module):
    """Coarse-to-fine bidirectional flow estimator built from ``Extractor`` and ``Decoder``."""

    def __init__(self, pretrained_path: str | None = None):
        """Create the six-level feature extractor and one decoder per level.

        Args:
            pretrained_path: optional checkpoint whose state dict is loaded into this module.
        """
        super().__init__()

        self.feature_extractor = Extractor([3, 16, 32, 64, 96, 128, 192], num_groups=16)

        # 81 correlation channels per level; every level except the coarsest
        # also receives the 2-channel flow from the level below it.
        self.decoders = nn.ModuleList([
            Decoder(16 + 81 + 2),
            Decoder(32 + 81 + 2),
            Decoder(64 + 81 + 2),
            Decoder(96 + 81 + 2),
            Decoder(128 + 81 + 2),
            Decoder(192 + 81)
        ])

        if pretrained_path is not None:
            # map_location keeps checkpoints saved on a GPU loadable on CPU-only hosts;
            # load_state_dict copies the values onto whatever device the parameters use.
            self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(forward_flow, backward_flow)`` at the resolution of ``img1``.

        The inputs are expected to share their spatial size.
        """
        height, width = img1.shape[2], img1.shape[3]

        feats1 = self.feature_extractor(img1)
        feats2 = self.feature_extractor(img2)

        flow_fwd = None
        flow_bwd = None

        # Coarsest level first; each finer level refines the previous estimate.
        for level in reversed(range(len(feats1))):
            decoder = self.decoders[level]
            flow_fwd = decoder(feats1[level], feats2[level], flow_fwd)
            flow_bwd = decoder(feats2[level], feats1[level], flow_bwd)

        def to_input_resolution(flow: torch.Tensor) -> torch.Tensor:
            # Resize, then rescale the vectors by the width ratio.
            resized = interpolate(input=flow,
                                  size=(height, width),
                                  mode='bilinear',
                                  align_corners=False)
            return resized * (float(width) / float(flow.shape[3]))

        return to_input_resolution(flow_fwd), to_input_resolution(flow_bwd)
92
+
93
+
94
class RAFTFineFlow(nn.Module):
    """Bidirectional flow estimator that runs one ``RAFT`` model in both directions."""

    def __init__(self, pretrained_path: str | None = None):
        """Instantiate the wrapped ``RAFT`` model, forwarding ``pretrained_path`` to it."""
        super().__init__()
        self.raft = RAFT(pretrained_path)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(raft(img1, img2), raft(img2, img1))``."""
        flow_12 = self.raft(img1, img2)
        flow_21 = self.raft(img2, img1)
        return flow_12, flow_21
modules/flow_models/raft/LICENSE ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2020, princeton-vl
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ * Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ * Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ * Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
modules/flow_models/raft/corr.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from .utils import bilinear_sampler, coords_grid
4
+
5
+
6
class CorrBlock:
    """All-pairs correlation volume with an average-pooled pyramid and windowed lookup."""

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Compute the correlation volume of two feature maps and pool it into a pyramid.

        Args:
            fmap1: feature map of shape ``(batch, dim, ht, wd)``.
            fmap2: feature map with the same shape as ``fmap1``.
            num_levels: number of pyramid levels (each level halves the resolution).
            radius: lookup radius used by ``__call__``.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample a ``(2r+1) x (2r+1)`` window around ``coords`` on every pyramid level.

        Args:
            coords: tensor of shape ``(batch, 2, h1, w1)``.

        Returns:
            Float tensor of shape ``(batch, num_levels * (2r+1)**2, h1, w1)``.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # The offset window is identical for every level, so build it once.
        # indexing='ij' spells out the layout that meshgrid used implicitly.
        dx = torch.linspace(-r, r, 2*r+1)
        dy = torch.linspace(-r, r, 2*r+1)
        delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), dim=-1).to(coords.device)
        delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]

            # Coordinates shrink by a factor of two per pyramid level.
            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the dot product of every feature pair, scaled by ``1 / sqrt(dim)``.

        The result has shape ``(batch, ht, wd, 1, ht, wd)``.
        """
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
56
+
modules/flow_models/raft/extractor.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection and a selectable normalisation."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        """Create the convolutions, the norm layers and, when strided, the 1x1 shortcut.

        ``norm_fn`` is one of ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        """
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8
        strided = not stride == 1

        def make_norm():
            # One fresh normalisation layer of the configured kind.
            if norm_fn == 'group':
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(planes)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(planes)
            return nn.Sequential()

        # Unrecognised names leave the norm attributes undefined.
        if norm_fn in ('group', 'batch', 'instance', 'none'):
            self.norm1 = make_norm()
            self.norm2 = make_norm()
            if strided:
                self.norm3 = make_norm()

        if strided:
            # The shortcut must change resolution and width to match the main branch.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(shortcut(x) + branch(x))``."""
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(shortcut + out)
57
+
58
+
59
+
60
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck with a skip connection and a selectable normalisation."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        """Create the three convolutions, their norm layers and the optional 1x1 shortcut.

        The two inner convolutions work on ``planes // 4`` channels.
        ``norm_fn`` is one of ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        """
        super(BottleneckBlock, self).__init__()

        inner = planes // 4

        self.conv1 = nn.Conv2d(in_planes, inner, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(inner, inner, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(inner, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8
        strided = not stride == 1

        def make_norm(channels):
            # One fresh normalisation layer of the configured kind for `channels` features.
            if norm_fn == 'group':
                return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(channels)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(channels)
            return nn.Sequential()

        # Unrecognised names leave the norm attributes undefined.
        if norm_fn in ('group', 'batch', 'instance', 'none'):
            self.norm1 = make_norm(inner)
            self.norm2 = make_norm(inner)
            self.norm3 = make_norm(planes)
            if strided:
                self.norm4 = make_norm(planes)

        if strided:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(shortcut(x) + branch(x))``."""
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        out = self.relu(self.norm3(self.conv3(out)))

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(shortcut + out)
117
+
118
class BasicEncoder(nn.Module):
    """Residual encoder for 3-channel input: 7x7 stem, three residual stages, 1x1 output conv."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        """Build the network and initialise convolution and normalisation weights.

        ``norm_fn`` is one of ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``;
        ``dropout`` > 0 adds a ``Dropout2d`` that is applied in training mode only.
        """
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalisation; unrecognised names leave norm1 undefined.
        if self.norm_fn in ('group', 'batch', 'instance', 'none'):
            if self.norm_fn == 'group':
                stem_norm = nn.GroupNorm(num_groups=8, num_channels=64)
            elif self.norm_fn == 'batch':
                stem_norm = nn.BatchNorm2d(64)
            elif self.norm_fn == 'instance':
                stem_norm = nn.InstanceNorm2d(64)
            else:
                stem_norm = nn.Sequential()
            self.norm1 = stem_norm

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and updates in_planes as the stages are chained.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if not isinstance(module, norm_types):
                continue
            # Norm layers may have been built without affine parameters.
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Return a stage of two ``ResidualBlock``s; only the first one applies ``stride``."""
        stage = nn.Sequential(
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )

        self.in_planes = dim
        return stage

    def forward(self, x):
        """Encode ``x``; a list/tuple of two tensors is batched together and split again."""
        # if input is list, combine batch dimension
        paired = isinstance(x, (tuple, list))
        if paired:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        pipeline = (self.conv1, self.norm1, self.relu1,
                    self.layer1, self.layer2, self.layer3,
                    self.conv2)
        for stage in pipeline:
            x = stage(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
193
+
194
class BasicEncoder1(nn.Module):
    """Residual encoder for 2-channel input: 7x7 stem, three residual stages, 1x1 output conv."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        """Build the network and initialise convolution and normalisation weights.

        ``norm_fn`` is one of ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``;
        ``dropout`` > 0 adds a ``Dropout2d`` that is applied in training mode only.
        """
        super(BasicEncoder1, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalisation; unrecognised names leave norm1 undefined.
        if self.norm_fn in ('group', 'batch', 'instance', 'none'):
            if self.norm_fn == 'group':
                stem_norm = nn.GroupNorm(num_groups=8, num_channels=64)
            elif self.norm_fn == 'batch':
                stem_norm = nn.BatchNorm2d(64)
            elif self.norm_fn == 'instance':
                stem_norm = nn.InstanceNorm2d(64)
            else:
                stem_norm = nn.Sequential()
            self.norm1 = stem_norm

        # The stem takes two input channels here.
        self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and updates in_planes as the stages are chained.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if not isinstance(module, norm_types):
                continue
            # Norm layers may have been built without affine parameters.
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Return a stage of two ``ResidualBlock``s; only the first one applies ``stride``."""
        stage = nn.Sequential(
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )

        self.in_planes = dim
        return stage

    def forward(self, x):
        """Encode ``x``; a list/tuple of two tensors is batched together and split again."""
        # if input is list, combine batch dimension
        paired = isinstance(x, (tuple, list))
        if paired:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        pipeline = (self.conv1, self.norm1, self.relu1,
                    self.layer1, self.layer2, self.layer3,
                    self.conv2)
        for stage in pipeline:
            x = stage(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
269
+
270
class SmallEncoder(nn.Module):
    """Bottleneck encoder for 3-channel input: 7x7 stem, three bottleneck stages, 1x1 output conv."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        """Build the network and initialise convolution and normalisation weights.

        ``norm_fn`` is one of ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``;
        ``dropout`` > 0 adds a ``Dropout2d`` that is applied in training mode only.
        """
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalisation; unrecognised names leave norm1 undefined.
        if self.norm_fn in ('group', 'batch', 'instance', 'none'):
            if self.norm_fn == 'group':
                stem_norm = nn.GroupNorm(num_groups=8, num_channels=32)
            elif self.norm_fn == 'batch':
                stem_norm = nn.BatchNorm2d(32)
            elif self.norm_fn == 'instance':
                stem_norm = nn.InstanceNorm2d(32)
            else:
                stem_norm = nn.Sequential()
            self.norm1 = stem_norm

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and updates in_planes as the stages are chained.
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if not isinstance(module, norm_types):
                continue
            # Norm layers may have been built without affine parameters.
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Return a stage of two ``BottleneckBlock``s; only the first one applies ``stride``."""
        stage = nn.Sequential(
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )

        self.in_planes = dim
        return stage

    def forward(self, x):
        """Encode ``x``; a list/tuple of two tensors is batched together and split again."""
        # if input is list, combine batch dimension
        paired = isinstance(x, (tuple, list))
        if paired:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        pipeline = (self.conv1, self.norm1, self.relu1,
                    self.layer1, self.layer2, self.layer3,
                    self.conv2)
        for stage in pipeline:
            x = stage(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
modules/flow_models/raft/rfr_new.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##################################################
2
+ # RFR is implemented based on RAFT optical flow #
3
+ ##################################################
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from argparse import Namespace
10
+ import numpy as np
11
+
12
+ from .update import BasicUpdateBlock, SmallUpdateBlock
13
+ from .extractor import BasicEncoder, SmallEncoder
14
+ from .corr import CorrBlock
15
+ from .utils import bilinear_sampler, coords_grid, upflow8
16
+
17
try:
    autocast = torch.amp.autocast
except AttributeError:
    # Fallback for PyTorch builds that do not expose ``torch.amp.autocast``.
    class autocast:
        """No-op stand-in for ``torch.amp.autocast``."""

        def __init__(self, *args, **kwargs):
            # Accept the same call shape as the real context manager, e.g.
            # ``autocast("cuda", enabled=False)``, and ignore every argument.
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            # Returning None lets exceptions raised in the body propagate.
            pass
28
+
29
def backwarp(img, flow):
    """Warp ``img`` backward along ``flow`` using bilinear sampling.

    Output pixel ``(x, y)`` is sampled from ``img`` at
    ``(x + flow[:, 0], y + flow[:, 1])``.

    Args:
        img: Tensor of shape ``(N, C, H, W)``.
        flow: Tensor whose channel 0 holds the horizontal and channel 1 the
            vertical displacement in pixels; its ``H`` and ``W`` must match
            ``img`` (the pixel grid is expanded to the flow's shape).

    Returns:
        The warped image. Samples that fall outside ``img`` use
        ``grid_sample``'s default padding.
    """
    _, _, H, W = img.size()

    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]

    # Build the pixel grid on the flow's device instead of hard-coding CUDA,
    # so the function also works on CPU and on non-default GPUs.
    grid_x = torch.arange(W, device=flow.device, dtype=torch.float32).view(1, 1, W)
    grid_y = torch.arange(H, device=flow.device, dtype=torch.float32).view(1, H, 1)
    x = grid_x.expand_as(u) + u
    y = grid_y.expand_as(v) + v

    # Normalise pixel coordinates to [-1, 1] (align_corners=True convention).
    x = 2 * (x / (W - 1) - 0.5)
    y = 2 * (y / (H - 1) - 0.5)

    # grid_sample expects the last dimension ordered as (x, y).
    grid = torch.stack((x, y), dim=3)

    # Sample pixels using bilinear interpolation.
    imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=True)

    return imgOut
50
class ErrorAttention(nn.Module):
    """A three-layer network for predicting mask.

    Args:
        input: Number of channels of the input tensor.
        output: Number of channels of the predicted map.
    """

    def __init__(self, input, output):
        super().__init__()
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        # conv3 consumes conv2's 32 channels concatenated with the raw input,
        # so its width depends on ``input`` (38 for the 6-channel case).
        self.conv3 = nn.Conv2d(32 + input, output, 3, padding=1)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()

    def forward(self, x1):
        """Return the prediction for ``x1`` (spatial size is preserved)."""
        x = self.prelu1(self.conv1(x1))
        # Skip connection: re-attach the raw input before the last layer.
        x = self.prelu2(torch.cat([self.conv2(x), x1], dim=1))
        x = self.conv3(x)
        return x
65
+
66
class RFR(nn.Module):
    """RAFT-based recurrent flow network that accepts an optional initial flow.

    The constructor overwrites ``args.corr_levels``, ``args.corr_radius`` and
    ``args.dropout`` and keeps ``args`` on the instance. ``forward_pred``
    additionally reads ``args.mixed_precision``; ``forward`` optionally reads
    ``args.not_use_rfr_mask`` and ``args.requires_sq_flow``.
    """

    def __init__(self, args):
        super().__init__()
        self.attention2 = ErrorAttention(6, 1)
        self.hidden_dim = hdim = 128
        self.context_dim = 128
        args.corr_levels = 4
        args.corr_radius = 4
        args.dropout = 0
        self.args = args

        # Feature network and update block. No separate context network is
        # built: forward_pred runs ``self.fnet`` on image1 again and splits
        # its 256 output channels into hidden and context parts.
        self.fnet = BasicEncoder(output_dim=256, norm_fn='none', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put every ``BatchNorm2d`` submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, _, H, W = img.shape
        coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
        coords1 = coords_grid(N, H // 8, W // 8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Flow magnitudes are scaled by 8 together with the resolution.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """Estimate flow from ``image1`` to ``image2`` at the input resolution.

        Both images are resized so height and width are multiples of 8, the
        flow is predicted by :meth:`forward_pred`, then resized and rescaled
        back to the original ``(H, W)``.

        Args:
            image1, image2: Image tensors of shape ``(N, C, H, W)``.
            iters: Number of recurrent refinement iterations.
            flow_init: Optional initial flow; it is resized to 1/8 resolution
                and rescaled before use.
            upsample, test_mode: Forwarded to :meth:`forward_pred`.

        Returns:
            * If ``args.requires_sq_flow`` is truthy: the list of all resized
              predictions in training mode, otherwise
              ``([last_prediction], f12_init)``.
            * Otherwise: ``(f12, f12_init, error21)``.
        """
        H, W = image1.size()[2:4]
        H8 = H // 8 * 8
        W8 = W // 8 * 8

        if flow_init is not None:
            flow_init_resize = F.interpolate(flow_init, size=(H8 // 8, W8 // 8), mode='nearest')

            # Rescale displacement magnitudes to the 1/8-resolution grid.
            flow_init_resize[:, :1] = flow_init_resize[:, :1].clone() * (W8 // 8 * 1.0) / flow_init.size()[3]
            flow_init_resize[:, 1:] = flow_init_resize[:, 1:].clone() * (H8 // 8 * 1.0) / flow_init.size()[2]

            # The mask branch runs unless args.not_use_rfr_mask is set truthy.
            # NOTE(review): when it is skipped, ``error21`` is never assigned
            # and the three-value return below raises -- confirm intent.
            if not getattr(self.args, 'not_use_rfr_mask', False):
                im18 = F.interpolate(image1, size=(H8 // 8, W8 // 8), mode='bilinear')
                im28 = F.interpolate(image2, size=(H8 // 8, W8 // 8), mode='bilinear')

                warp21 = backwarp(im28, flow_init_resize)
                error21 = torch.sum(torch.abs(warp21 - im18), dim=1, keepdim=True)
                # NOTE(review): f12init is computed but not used afterwards;
                # forward_pred receives flow_init_resize. Kept so behaviour
                # (including the attention2 call) is unchanged.
                f12init = torch.exp(- self.attention2(torch.cat([im18, error21, flow_init_resize], dim=1)) ** 2) * flow_init_resize
        else:
            flow_init_resize = None
            # Allocate on the input's device rather than hard-coding CUDA.
            flow_init = torch.zeros(
                image1.size()[0], 2, image1.size()[2] // 8, image1.size()[3] // 8,
                device=image1.device,
            )
            error21 = torch.zeros(
                image1.size()[0], 1, image1.size()[2] // 8, image1.size()[3] // 8,
                device=image1.device,
            )

            f12_init = flow_init

        image1 = F.interpolate(image1, size=(H8, W8), mode='bilinear')
        image2 = F.interpolate(image2, size=(H8, W8), mode='bilinear')

        f12s, f12, f12_init = self.forward_pred(image1, image2, iters, flow_init_resize, upsample, test_mode)

        if getattr(self.args, 'requires_sq_flow', False):
            for ii in range(len(f12s)):
                f12s[ii] = F.interpolate(f12s[ii], size=(H, W), mode='bilinear')
                f12s[ii][:, :1] = f12s[ii][:, :1].clone() / (1.0 * W8) * W
                f12s[ii][:, 1:] = f12s[ii][:, 1:].clone() / (1.0 * H8) * H
            if self.training:
                return f12s
            else:
                return [f12s[-1]], f12_init
        else:
            # Rescale magnitudes from the (H8, W8) grid back to (H, W).
            f12[:, :1] = f12[:, :1].clone() / (1.0 * W8) * W
            f12[:, 1:] = f12[:, 1:].clone() / (1.0 * H8) * H

            f12 = F.interpolate(f12, size=(H, W), mode='bilinear')
            return f12, f12_init, error21

    def forward_pred(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames

        Returns:
            ``(flow_predictions, flow_up, flow_init)``: every upsampled
            prediction, the last one, and the ``flow_init`` argument as given.
            ``upsample`` and ``test_mode`` are accepted but not used here.
        """
        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast("cuda", enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network (the feature network is reused for it)
        with autocast("cuda", enabled=self.args.mixed_precision):
            cnet = self.fnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            if itr == 0:
                # NOTE(review): flow_init is added a second time here (it was
                # already added before the loop). Left untouched because
                # pretrained weights may depend on it -- confirm before
                # changing.
                if flow_init is not None:
                    coords1 = coords1 + flow_init
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast("cuda", enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        return flow_predictions, flow_up, flow_init
214
+
215
class RAFT(nn.Module):
    """Thin wrapper that builds an ``RFR`` network and loads its weights.

    Args:
        path: Checkpoint file to load, or ``None`` to keep the freshly
            initialised weights. The checkpoint must contain a
            ``'model_state_dict'`` entry; only keys starting with
            ``'module.flownet.'`` are used (with that prefix stripped).
    """

    def __init__(self, path='./_pretrain_models/anime_interp_full.ckpt'):
        super().__init__()
        self.raft = RFR(Namespace(
            small=False,
            mixed_precision=False,
        ))
        if path is not None:
            # Deserialise onto CPU so a checkpoint saved from a GPU can be
            # restored on any host; load_state_dict copies the values into
            # the module's own parameters afterwards.
            sd = torch.load(path, map_location='cpu')['model_state_dict']
            prefix = 'module.flownet.'
            # NOTE: strict=False means missing/unexpected keys are ignored.
            self.raft.load_state_dict({
                k[len(prefix):]: v
                for k, v in sd.items()
                if k.startswith(prefix)
            }, strict=False)

    def forward(self, img0, img1, flow0=None, iters=12, return_more=False):
        """Run the wrapped network and return its first output.

        ``flow0`` (if given) and the returned flow have their two channels
        swapped relative to the wrapped network's channel order.
        ``return_more`` is accepted but not used.
        """
        if flow0 is not None:
            flow0 = flow0.flip(dims=(1,))
        out = self.raft(img0, img1, iters=iters, flow_init=flow0)
        return out[0].flip(dims=(1,))
235
+
modules/flow_models/raft/update.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class FlowHead(nn.Module):
    """Two-layer convolutional head producing a 2-channel output."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> ReLU -> conv; spatial size is preserved."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
15
+
16
class ConvGRU(nn.Module):
    """Convolutional GRU cell built from three 3x3 convolutions."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        gate_channels = hidden_dim + input_dim
        self.convz = nn.Conv2d(gate_channels, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(gate_channels, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(gate_channels, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Return the updated hidden state for hidden ``h`` and input ``x``."""
        gate_input = torch.cat([h, x], dim=1)

        update_gate = torch.sigmoid(self.convz(gate_input))
        reset_gate = torch.sigmoid(self.convr(gate_input))
        candidate = torch.tanh(self.convq(torch.cat([reset_gate * h, x], dim=1)))

        # Blend the previous state with the candidate state.
        return (1 - update_gate) * h + update_gate * candidate
32
+
33
class SepConvGRU(nn.Module):
    """GRU cell applying a (1, 5) pass followed by a (5, 1) pass."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        gate_channels = hidden_dim + input_dim

        # horizontal (1x5) gates
        self.convz1 = nn.Conv2d(gate_channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(gate_channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(gate_channels, hidden_dim, (1, 5), padding=(0, 2))

        # vertical (5x1) gates
        self.convz2 = nn.Conv2d(gate_channels, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(gate_channels, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(gate_channels, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _gated_update(h, x, conv_z, conv_r, conv_q):
        """Run one GRU update of ``h`` with the given gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(conv_z(hx))
        r = torch.sigmoid(conv_r(hx))
        q = torch.tanh(conv_q(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """Return the hidden state after the horizontal then vertical pass."""
        # horizontal
        h = self._gated_update(h, x, self.convz1, self.convr1, self.convq1)
        # vertical
        h = self._gated_update(h, x, self.convz2, self.convr2, self.convq2)
        return h
61
+
62
class SmallMotionEncoder(nn.Module):
    """Encode correlation and flow into an 82-channel motion feature map.

    ``args`` must provide ``corr_levels`` and ``corr_radius``; together they
    determine the expected number of correlation channels.
    """

    def __init__(self, args):
        super().__init__()
        window = 2 * args.corr_radius + 1
        cor_planes = args.corr_levels * window**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        """Return 80 fused feature channels followed by the raw ``flow``."""
        corr_feat = F.relu(self.convc1(corr))

        flow_feat = F.relu(self.convf1(flow))
        flow_feat = F.relu(self.convf2(flow_feat))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
78
+
79
class BasicMotionEncoder(nn.Module):
    """Encode correlation and flow into a 128-channel motion feature map.

    ``args`` must provide ``corr_levels`` and ``corr_radius``; together they
    determine the expected number of correlation channels.
    """

    def __init__(self, args):
        super().__init__()
        window = 2 * args.corr_radius + 1
        cor_planes = args.corr_levels * window**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 fused channels + the 2 raw flow channels = 128 outputs.
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        """Return 126 fused feature channels followed by the raw ``flow``."""
        corr_feat = F.relu(self.convc1(corr))
        corr_feat = F.relu(self.convc2(corr_feat))

        flow_feat = F.relu(self.convf1(flow))
        flow_feat = F.relu(self.convf2(flow_feat))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
98
+
99
class SmallUpdateBlock(nn.Module):
    """Update block combining a motion encoder, a ConvGRU and a flow head."""

    def __init__(self, args, hidden_dim=96):
        super().__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        """Return ``(net, None, delta_flow)``; no upsampling mask is produced."""
        gru_input = torch.cat([inp, self.encoder(flow, corr)], dim=1)
        net = self.gru(net, gru_input)
        return net, None, self.flow_head(net)
113
+
114
class BasicUpdateBlock(nn.Module):
    """Update block combining a motion encoder, a SepConvGRU, a flow head
    and a head that predicts the upsampling mask.

    ``input_dim`` is accepted by the constructor but not used.
    """

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super().__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # 64*9 outputs: 9 weights for each cell of an 8x8 upsampling patch.
        mask_layers = [
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0),
        ]
        self.mask = nn.Sequential(*mask_layers)

    def forward(self, net, inp, corr, flow, upsample=True):
        """Return ``(net, mask, delta_flow)``; ``upsample`` is not used."""
        gru_input = torch.cat([inp, self.encoder(flow, corr)], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
137
+
138
+
139
+
modules/flow_models/raft/utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy import interpolate
5
+
6
+
7
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """

    def __init__(self, dims):
        self.ht, self.wd = dims[-2:]
        # (-n) % 8 is the distance from n up to the next multiple of 8.
        pad_ht = (-self.ht) % 8
        pad_wd = (-self.wd) % 8
        left = pad_wd // 2
        # F.pad order: [left, right, top, bottom]; height is padded at the
        # bottom only, width is split between both sides.
        self._pad = [left, pad_wd - left, 0, pad_ht]

    def pad(self, *inputs):
        """Return a list with every input padded (replicate mode)."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop ``x`` back to the region that preceded padding."""
        ht, wd = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:ht - bottom, left:wd - right]
22
+
23
def forward_interpolate(flow):
    """Forward-project a ``(2, H, W)`` flow field and resample it on the grid.

    Every pixel is moved by its own displacement; the displacement values
    at the moved positions are then interpolated (cubic) back onto the
    regular pixel grid, with 0 where no data is available.

    Returns:
        A float32 CPU tensor of shape ``(2, H, W)``.
    """
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # Target positions and displacements, flattened to 1-D.
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # Keep only points that land strictly inside the image bounds.
    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    points = (x1[valid], y1[valid])

    channels = [
        interpolate.griddata(points, values[valid], (x0, y0), method='cubic', fill_value=0)
        for values in (dx, dy)
    ]

    return torch.from_numpy(np.stack(channels, axis=0)).float()
52
+
53
+
54
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates

    Args:
        img: Tensor of shape ``(N, C, H, W)`` to sample from.
        coords: Tensor whose last dimension holds ``(x, y)`` pixel
            coordinates.
        mode: Interpolation mode forwarded to ``F.grid_sample``
            (previously accepted but ignored).
        mask: When true, also return a float mask that is 1 where the
            normalised coordinates lie strictly inside ``(-1, 1)``.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is true.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to the [-1, 1] range used by grid_sample.
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img
70
+
71
+
72
+
73
def coords_grid(batch, ht, wd):
    """Return a ``(batch, 2, ht, wd)`` float grid of pixel coordinates.

    Channel 0 holds the x (column) index and channel 1 the y (row) index.
    """
    # indexing='ij' is the behaviour the original call relied on implicitly;
    # stating it avoids torch.meshgrid's deprecation warning.
    rows_cols = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing='ij')
    # Reverse (row, col) -> (x, y).
    coords = torch.stack(rows_cols[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
77
+
78
+
79
def upflow8(flow, mode='bilinear'):
    """Upsample ``flow`` by 8x spatially and scale its values by 8."""
    height, width = flow.shape[2], flow.shape[3]
    upsampled = F.interpolate(
        flow, size=(8 * height, 8 * width), mode=mode, align_corners=True
    )
    return 8 * upsampled
modules/half_warper.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from kornia.color import rgb_to_lab
5
+
6
+ from utils.utils import morph_open
7
+
8
+ from modules.cupy_module.softsplat import FunctionSoftsplat
9
+
10
+ class HalfWarper(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ @staticmethod
15
+ def backward_wrapping(
16
+ img: torch.Tensor,
17
+ flow: torch.Tensor,
18
+ resample: str = 'bilinear',
19
+ padding_mode: str = 'border',
20
+ align_corners: bool = False
21
+ ) -> torch.Tensor:
22
+ if len(img.shape) != 4: img = img[None,]
23
+ if len(flow.shape) != 4: flow = flow[None,]
24
+
25
+ q = 2 * flow / torch.tensor([
26
+ flow.shape[-2], flow.shape[-1],
27
+ ], device=flow.device, dtype=torch.float)[None,:,None,None]
28
+
29
+ q = q + torch.stack(torch.meshgrid(
30
+ torch.linspace(-1, 1, flow.shape[-2]),
31
+ torch.linspace(-1, 1, flow.shape[-1]),
32
+ ))[None,].to(flow.device)
33
+
34
+ if img.dtype != q.dtype:
35
+ img = img.type(q.dtype)
36
+
37
+ return F.grid_sample(
38
+ img,
39
+ q.flip(dims=(1,)).permute(0, 2, 3, 1).contiguous(),
40
+ mode = resample, # nearest, bicubic, bilinear
41
+ padding_mode = padding_mode, # border, zeros, reflection
42
+ align_corners = align_corners,
43
+ )
44
+
45
+ @staticmethod
46
+ def forward_warpping(
47
+ img: torch.Tensor,
48
+ flow: torch.Tensor,
49
+ mode: str = 'softmax',
50
+ metric: torch.Tensor | None = None,
51
+ mask: bool = True
52
+ ) -> torch.Tensor:
53
+ if len(img.shape) != 4: img = img[None,]
54
+ if len(flow.shape) != 4: flow = flow[None,]
55
+ if metric is not None and len(metric.shape)!=4: metric = metric[None,]
56
+
57
+ flow = flow.flip(dims=(1,))
58
+ if img.dtype != torch.float32:
59
+ img = img.type(torch.float32)
60
+ if flow.dtype != torch.float32:
61
+ flow = flow.type(torch.float32)
62
+ if metric is not None and metric.dtype != torch.float32:
63
+ metric = metric.type(torch.float32)
64
+
65
+ assert img.device == flow.device
66
+ if metric is not None: assert img.device == metric.device
67
+ if img.device.type=='cpu':
68
+ img = img.to('cuda')
69
+ flow = flow.to('cuda')
70
+ if metric is not None: metric = metric.to('cuda')
71
+
72
+ if mask:
73
+ batch, _, h, w = img.shape
74
+ img = torch.cat([img, torch.ones(batch, 1, h, w, dtype=img.dtype, device=img.device)], dim=1)
75
+
76
+ return FunctionSoftsplat(img, flow, metric, mode)
77
+
78
+ @staticmethod
79
+ def z_metric(
80
+ img0: torch.Tensor,
81
+ img1: torch.Tensor,
82
+ flow0to1: torch.Tensor,
83
+ flow1to0: torch.Tensor
84
+ ) -> tuple[torch.Tensor, torch.Tensor]:
85
+ img0 = rgb_to_lab(img0[:,:3])
86
+ img1 = rgb_to_lab(img1[:,:3])
87
+ z1to0 = -0.1*(img1 - HalfWarper.backward_wrapping(img0, flow1to0)).norm(dim=1, keepdim=True)
88
+ z0to1 = -0.1*(img0 - HalfWarper.backward_wrapping(img1, flow0to1)).norm(dim=1, keepdim=True)
89
+ return z0to1, z1to0
90
+
91
+ def forward(
92
+ self,
93
+ I0: torch.Tensor,
94
+ I1: torch.Tensor,
95
+ flow0to1: torch.Tensor,
96
+ flow1to0: torch.Tensor,
97
+ z0to1: torch.Tensor | None = None,
98
+ z1to0: torch.Tensor | None = None,
99
+ tau: float | None = None,
100
+ morph_kernel_size: int = 5,
101
+ mask: bool = True
102
+ ) -> tuple[torch.Tensor, torch.Tensor]:
103
+
104
+ if z1to0 is None or z0to1 is None:
105
+ z0to1, z1to0 = self.z_metric(I0, I1, flow0to1, flow1to0)
106
+
107
+ if tau is not None:
108
+ flow0tot = tau*flow0to1
109
+ flow1tot = (1 - tau)*flow1to0
110
+ else:
111
+ flow0tot = flow0to1
112
+ flow1tot = flow1to0
113
+
114
+ # image warping
115
+ fw0to1 = HalfWarper.forward_warpping(I0, flow0tot, mode='softmax', metric=z0to1, mask=True)
116
+ fw1to0 = HalfWarper.forward_warpping(I1, flow1tot, mode='softmax', metric=z1to0, mask=True)
117
+
118
+ wrapped_image0tot = fw0to1[:,:-1]
119
+ wrapped_image1tot = fw1to0[:,:-1]
120
+ mask0tot = morph_open(fw0to1[:,-1:], k=morph_kernel_size)
121
+ mask1tot = morph_open(fw1to0[:,-1:], k=morph_kernel_size)
122
+
123
+ base0 = mask0tot*wrapped_image0tot + (1 - mask0tot)*wrapped_image1tot
124
+ base1 = mask1tot*wrapped_image1tot + (1 - mask1tot)*wrapped_image0tot
125
+
126
+ if mask:
127
+ base0 = torch.cat([base0, mask0tot], dim=1)
128
+ base1 = torch.cat([base1, mask1tot], dim=1)
129
+ return base0, base1
modules/synthesizer.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from modules.basic_layers import (
5
+ SinusoidalPositionalEmbedding,
6
+ ResGatedBlock,
7
+ MaxViTBlock,
8
+ Downsample,
9
+ Upsample
10
+ )
11
+
12
class UnetDownBlock(nn.Module):
    """Encoder stage: downsample, fuse warped inputs, convolve, attend."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()
        self.pool = Downsample(
            in_channels = in_channels,
            out_channels = in_channels,
            use_conv = True
        )

        # Width of the tensor after concatenating the pooled features with
        # the two warped inputs in forward().
        fused_channels = 3 * in_channels + 2
        conv_blocks = []
        for block_idx in range(num_conv_blocks):
            conv_blocks.append(
                ResGatedBlock(
                    in_channels = fused_channels if block_idx == 0 else out_channels,
                    out_channels = out_channels,
                    emb_channels = temb_channels,
                    gated_conv = True
                )
            )
        self.conv = nn.ModuleList(conv_blocks)

        self.maxvit = MaxViTBlock(
            channels = out_channels,
            heads = heads,
            window_size = window_size,
            window_attn = window_attn,
            grid_attn = grid_attn,
            expansion_rate = expansion_rate,
            dropout = dropout,
            emb_channels = temb_channels
        )

    def forward(
        self,
        x: torch.Tensor,
        warp0: torch.Tensor,
        warp1: torch.Tensor,
        temb: torch.Tensor
    ):
        """Downsample ``x``, concatenate both warps, then refine with ``temb``."""
        pooled = self.pool(x)
        x = torch.cat([pooled, warp0, warp1], dim=1)
        for conv_block in self.conv:
            x = conv_block(x, temb)
        return self.maxvit(x, temb)
66
+
67
class UnetMiddleBlock(nn.Module):
    """Bottleneck stage: gated res block -> MaxViT block -> gated res block."""

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        dropout: float = 0.0
    ):
        super().__init__()

        entry_block = ResGatedBlock(
            in_channels = in_channels,
            out_channels = mid_channels,
            emb_channels = temb_channels,
            gated_conv = True
        )
        attention_block = MaxViTBlock(
            channels = mid_channels,
            heads = heads,
            window_size = window_size,
            window_attn = window_attn,
            grid_attn = grid_attn,
            expansion_rate = expansion_rate,
            dropout = dropout,
            emb_channels = temb_channels
        )
        exit_block = ResGatedBlock(
            in_channels = mid_channels,
            out_channels = out_channels,
            emb_channels = temb_channels,
            gated_conv = True
        )
        self.middle_blocks = nn.ModuleList([entry_block, attention_block, exit_block])

    def forward(self, x, temb):
        """Apply the three blocks in order, each conditioned on ``temb``."""
        for stage in self.middle_blocks:
            x = stage(x, temb)
        return x
113
+
114
class UnetUpBlock(nn.Module):
    """Decoder stage: fuse skip connection, attend, upsample, convolve."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        # forward() concatenates the input with a skip connection, so every
        # layer before the conv stack works on twice the nominal width.
        fused_channels = 2 * in_channels
        self.maxvit = MaxViTBlock(
            channels = fused_channels,
            heads = heads,
            window_size = window_size,
            window_attn = window_attn,
            grid_attn = grid_attn,
            expansion_rate = expansion_rate,
            dropout = dropout,
            emb_channels = temb_channels
        )
        self.upsample = Upsample(
            in_channels = fused_channels,
            out_channels = fused_channels,
            use_conv = True
        )

        conv_blocks = []
        for block_idx in range(num_conv_blocks):
            conv_blocks.append(
                ResGatedBlock(
                    fused_channels if block_idx == 0 else out_channels,
                    out_channels,
                    emb_channels = temb_channels,
                    gated_conv = True
                )
            )
        self.conv = nn.ModuleList(conv_blocks)

    def forward(
        self,
        x: torch.Tensor,
        skip_connection: torch.Tensor,
        temb: torch.Tensor
    ):
        """Fuse ``x`` with ``skip_connection``, then attend, upsample, refine."""
        fused = torch.cat([x, skip_connection], dim=1)
        x = self.upsample(self.maxvit(fused, temb))
        for conv_block in self.conv:
            x = conv_block(x, temb)
        return x
167
+
168
class Synthesis(nn.Module):
    """U-Net style network conditioned on an embedded scalar ``temb``.

    ``channels`` lists the feature width of each resolution level; one down
    block and one up block are created per consecutive pair of entries.
    """

    def __init__(
        self,
        in_channels: int,
        channels: list[int],
        temb_channels: int,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        num_stages = len(channels) - 1
        # Attention settings shared by every stage.
        attn_kwargs = dict(
            heads = heads,
            window_size = window_size,
            window_attn = window_attn,
            grid_attn = grid_attn,
            expansion_rate = expansion_rate,
            dropout = dropout,
        )

        self.t_pos_encoding = SinusoidalPositionalEmbedding(temb_channels)

        self.input_blocks = nn.ModuleList([
            nn.Conv2d(3*in_channels + 4, channels[0], kernel_size=3, padding=1),
            ResGatedBlock(
                in_channels = channels[0],
                out_channels = channels[0],
                emb_channels = temb_channels,
                gated_conv = True
            )
        ])

        down_blocks = []
        for level in range(num_stages):
            down_blocks.append(
                UnetDownBlock(
                    channels[level],
                    channels[level + 1],
                    temb_channels,
                    num_conv_blocks = num_conv_blocks,
                    **attn_kwargs,
                )
            )
        self.down_blocks = nn.ModuleList(down_blocks)

        self.middle_block = UnetMiddleBlock(
            in_channels = channels[-1],
            mid_channels = channels[-1],
            out_channels = channels[-1],
            temb_channels = temb_channels,
            **attn_kwargs,
        )

        # Up blocks are created from the deepest level back to the first.
        up_blocks = []
        for level in reversed(range(num_stages)):
            up_blocks.append(
                UnetUpBlock(
                    channels[level + 1],
                    channels[level],
                    temb_channels,
                    num_conv_blocks = num_conv_blocks,
                    **attn_kwargs,
                )
            )
        self.up_blocks = nn.ModuleList(up_blocks)

        self.output_blocks = nn.ModuleList([
            ResGatedBlock(
                in_channels = channels[0],
                out_channels = channels[0],
                emb_channels = temb_channels,
                gated_conv = True
            ),
            nn.Conv2d(channels[0], in_channels, kernel_size=3, padding=1)
        ])

    def forward(
        self,
        x: torch.Tensor,
        warp0: list[torch.Tensor],
        warp1: list[torch.Tensor],
        temb: torch.Tensor
    ):
        """Run the network.

        ``warp0`` / ``warp1`` are indexed per level: entry 0 is concatenated
        with ``x`` at the input, entry ``i + 1`` feeds down block ``i``.
        """
        temb = self.t_pos_encoding(temb.unsqueeze(-1).type(torch.float))

        stem_conv, stem_block = self.input_blocks
        x = stem_conv(torch.cat([x, warp0[0], warp1[0]], dim=1))
        x = stem_block(x, temb)

        features = []
        for level, down_block in enumerate(self.down_blocks, start=1):
            x = down_block(x, warp0[level], warp1[level], temb)
            features.append(x)

        x = self.middle_block(x, temb)

        # Skip connections are consumed deepest-first.
        for up_block, skip in zip(self.up_blocks, reversed(features)):
            x = up_block(x, skip, temb)

        head_block, head_conv = self.output_blocks
        x = head_block(x, temb)
        return head_conv(x)
requirements.txt ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Main dependencies
2
+ torch>=2.6.0
3
+ torchvision>=0.21.0
4
+ lightning>=2.2.4
5
+ numpy>=1.26.4
6
+ matplotlib>=3.8.0
7
+ pyyaml>=6.0.0
8
+
9
+ # Huggingface
10
+ huggingface-hub>=0.30.2
11
+
12
+ # Image processing and computer vision
13
+ kornia>=0.7.2
14
+ opencv-python>=4.10.0.84
15
+ opencv-contrib-python>=4.10.0.84
16
+ einops>=0.8.0
17
+
18
+ # Custom cuda implementation /modules/cupy_module/
19
+ cupy-cuda12x>=12.0.0 # For CUDA 12.4
20
+ # Note: For cupy, you need to install the specific version for your CUDA version
21
+ # Examples:
22
+ # cupy-cuda11x for CUDA 11.x
23
+ # cupy-cuda12x for CUDA 12.x
24
+ # cupy-cuda102 for CUDA 10.2 (older CuPy releases only; there is no cupy-cuda10x wheel)
25
+
26
+ # Utilities and tools
27
+ scipy>=1.7.0
28
+ tensorboard>=2.8.0
29
+
30
+ # Project-Specific Dependencies
31
+ # RAFT (Flow Estimation)
32
+ # Note: RAFT is included in the project code; no external installation is required.
33
+
34
+ # FLOLPIPS (Quality Metrics)
35
+ # Note: FLOLPIPS is included in the project code; no external installation is required.
36
+
37
+ # Gradio
38
+ gradio>=4.34.0
39
+ imageio>=2.34.1
40
+ imageio-ffmpeg>=0.6.0
41
+
42
+
utils/ema.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class EMA:
5
+ def __init__(self, beta: float):
6
+ super().__init__()
7
+ self.beta = beta
8
+ self.step = 0
9
+
10
+ def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None:
11
+ for current_params, ema_model in zip(current_model.parameters(), ema_model.parameters()):
12
+ old_weight, up_weight = ema_model.data, current_params.data
13
+ ema_model.data = self.update_average(old_weight, up_weight)
14
+
15
+ def update_average(self, old: torch.Tensor | None, new: torch.Tensor) -> torch.Tensor:
16
+ if old is None:
17
+ return new
18
+ return old * self.beta + (1 - self.beta) * new
19
+
20
+ def step_ema(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None:
21
+ if self.step < step_start_ema:
22
+ self.reset_parameters(ema_model, model)
23
+ self.step += 1
24
+ return
25
+ self.update_model_average(ema_model, model)
26
+ self.step += 1
27
+
28
+ def copy_to(self, ema_model: nn.Module, model: nn.Module) -> None:
29
+ model.load_state_dict(ema_model.state_dict())
30
+
31
+ def reset_parameters(self, ema_model: nn.Module, model: nn.Module) -> None:
32
+ ema_model.load_state_dict(model.state_dict())
utils/inter_frame_idx.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.utils import morph_open
2
+
3
+ import torch
4
+ from kornia.color import rgb_to_grayscale
5
+
6
+ import cv2
7
+ import numpy as np
8
+
9
class FlowEstimation:
    """Dense optical flow between two batches of frames using OpenCV.

    The estimator is chosen by name at construction time and invoked through
    ``__call__``. Frames are processed one at a time on the CPU and the
    resulting flow is moved back to the device of the first input.
    """

    def __init__(self, flow_estimator: str = "farneback"):
        """Select the flow algorithm.

        Args:
            flow_estimator: ``"farneback"`` or ``"dualtvl1"``.

        Raises:
            ValueError: If ``flow_estimator`` is not one of the supported names.
        """
        estimators = {
            "farneback": self.OptFlow_Farneback,
            "dualtvl1": self.OptFlow_DualTVL1,
        }
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if flow_estimator not in estimators:
            raise ValueError("Flow estimator must be one of [farneback, dualtvl1]")
        self.flow_estimator = estimators[flow_estimator]

    def _to_gray_uint8(self, frames: torch.Tensor) -> list[np.ndarray]:
        """Turn a batch of channel-first frames into 8-bit grayscale arrays.

        Values are clamped to [0, 1] and scaled to [0, 255] before the cast.
        """
        # ``detach`` lets ``.numpy()`` work even when the input requires grad.
        scaled = frames.detach().cpu().clamp(0, 1) * 255
        # NOTE(review): BGR2GRAY assumes BGR channel order; if callers pass RGB
        # tensors the red/blue weights are swapped -- confirm against callers.
        return [
            cv2.cvtColor(frame.permute(1, 2, 0).numpy().astype(np.uint8), cv2.COLOR_BGR2GRAY)
            for frame in scaled
        ]

    def _batched_flow(self, I0: torch.Tensor, I1: torch.Tensor, estimate) -> torch.Tensor:
        """Run ``estimate(gray0, gray1)`` per batch item and stack the results.

        ``estimate`` must return an (H, W, 2) array; the output has the flow
        channels moved to dimension 1 and is float32 on ``I0.device``.
        """
        device = I0.device
        gray0 = self._to_gray_uint8(I0)
        gray1 = self._to_gray_uint8(I1)

        flows = [
            torch.from_numpy(estimate(gray0[i], gray1[i])).permute(2, 0, 1).float()
            for i in range(len(gray0))
        ]
        if not flows:
            # Empty batch: return an empty flow field instead of failing.
            return torch.empty((0, 2) + tuple(I0.shape[-2:]), dtype=torch.float32, device=device)
        # One stack at the end instead of growing a tensor with repeated ``cat``.
        return torch.stack(flows, dim=0).to(device)

    def OptFlow_Farneback(self, I0: torch.Tensor, I1: torch.Tensor) -> torch.Tensor:
        """Estimate the flow from ``I0`` to ``I1`` with Farneback's algorithm."""
        def estimate(prev_gray, next_gray):
            # Arguments after ``None`` (no initial flow): pyr_scale=0.5,
            # levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0.
            return cv2.calcOpticalFlowFarneback(prev_gray, next_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)

        return self._batched_flow(I0, I1, estimate)

    def OptFlow_DualTVL1(
        self,
        I0: torch.Tensor,
        I1: torch.Tensor,
        tau: float = 0.25,
        lambda_: float = 0.15,
        theta: float = 0.3,
        scales_number: int = 5,
        warps: int = 5,
        epsilon: float = 0.01,
        inner_iterations: int = 30,
        outer_iterations: int = 10,
        scale_step: float = 0.8,
        gamma: float = 0.0
    ) -> torch.Tensor:
        """Estimate the flow from ``I0`` to ``I1`` with the Dual TV-L1 algorithm.

        The keyword arguments are forwarded one-to-one to the corresponding
        setters of the OpenCV ``DualTVL1`` object (``cv2.optflow`` is part of
        the opencv-contrib build).
        """
        optical_flow = cv2.optflow.createOptFlow_DualTVL1()
        optical_flow.setTau(tau)
        optical_flow.setLambda(lambda_)
        optical_flow.setTheta(theta)
        optical_flow.setScalesNumber(scales_number)
        optical_flow.setWarpingsNumber(warps)
        optical_flow.setEpsilon(epsilon)
        optical_flow.setInnerIterations(inner_iterations)
        optical_flow.setOuterIterations(outer_iterations)
        optical_flow.setScaleStep(scale_step)
        optical_flow.setGamma(gamma)

        def estimate(prev_gray, next_gray):
            return optical_flow.calc(prev_gray, next_gray, None)

        return self._batched_flow(I0, I1, estimate)

    def __call__(self, I1: torch.Tensor, I0: torch.Tensor) -> torch.Tensor:
        """Run the selected estimator.

        The arguments are forwarded positionally, so the first argument is the
        frame the flow starts from, regardless of its parameter name.
        """
        return self.flow_estimator(I1, I0)
94
+
95
def get_inter_frame_temp_index(
    I0: torch.Tensor,
    It: torch.Tensor,
    I1: torch.Tensor,
    flow0tot: torch.Tensor,
    flow1tot: torch.Tensor,
    k: int = 5,
    threshold: float = 2e-2
) -> torch.Tensor:
    """Estimate where ``It`` lies between ``I0`` and ``I1`` from masked flow.

    For each side, the grayscale frame difference is morphologically opened
    with a ``k`` x ``k`` kernel and thresholded into a binary mask; the flow
    magnitude (from flow channels 0 and 1) is summed over the masked pixels.
    The result is ``d0 / (d0 + d1 + 1e-12)``, where ``d0`` comes from
    ``I0 -> It`` and ``d1`` from ``It -> I1``.

    Args:
        I0: First frame (RGB, as required by ``rgb_to_grayscale``).
        It: Intermediate frame.
        I1: Last frame.
        flow0tot: Flow associated with the ``I0``/``It`` pair.
        flow1tot: Flow associated with the ``It``/``I1`` pair.
        k: Side of the opening kernel; ``0`` disables the opening.
        threshold: Absolute difference above which a pixel counts as moving.

    Returns:
        The ratio described above (presumably one value per batch item --
        this assumes single-channel masks; TODO confirm).
    """

    def masked_motion(gray_diff: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        # Binary mask of pixels whose opened difference exceeds the threshold.
        opened = morph_open(gray_diff, k=k)
        mask = (abs(opened) > threshold).to(torch.uint8)
        # Per-pixel flow magnitude with a channel axis re-inserted for the mask.
        magnitude = torch.sqrt(flow[:, 0, :, :]**2 + flow[:, 1, :, :]**2).unsqueeze(1)
        return torch.sum((magnitude * mask).squeeze(1), dim=(1, 2))

    gray_first, gray_mid, gray_last = (rgb_to_grayscale(frame) for frame in (I0, It, I1))

    d0tot = masked_motion(gray_mid - gray_first, flow0tot)
    d1tot = masked_motion(gray_last - gray_mid, flow1tot)

    # The small constant keeps the ratio finite when no motion is detected.
    return d0tot / (d0tot + d1tot + 1e-12)
utils/raft.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.models.optical_flow import raft_large
3
+ from modules.flow_models.raft.rfr_new import RAFT
4
+
5
def raft_flow(
    I0: torch.Tensor,
    I1: torch.Tensor,
    data_domain: str = "animation",
    device: str = 'cuda'
) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate optical flow between ``I0`` and ``I1`` with a RAFT model.

    Args:
        I0: First frame batch; cast to float32.
        I1: Second frame batch; cast to float32.
        data_domain: ``"animation"`` uses the project's ``RAFT`` model,
            ``"photorealism"`` uses torchvision's ``raft_large``.
        device: Device the model is moved to. The inputs are NOT moved here,
            so callers must already have them on a matching device.

    Returns:
        For ``"animation"``, whatever the project ``RAFT`` model returns.
        For ``"photorealism"``, the last element of the list returned by the
        torchvision model (its final refinement iteration).

    Raises:
        ValueError: If ``data_domain`` is not one of the two supported values.
    """
    if data_domain not in ("animation", "photorealism"):
        raise ValueError("data_domain must be either 'animation' or 'photorealism'")

    # ``.to`` is a no-op for tensors that are already float32.
    I0 = I0.to(torch.float32)
    I1 = I1.to(torch.float32)

    # NOTE(review): a new model is built on every call; cache it at the call
    # site if this function is used in a loop.
    if data_domain == "animation":
        raft = RAFT().requires_grad_(False).eval().to(device)
        return raft(I0, I1)

    # ``raft_large()`` without ``weights`` is randomly initialised, which makes
    # the predicted flow meaningless; load the pretrained weights explicitly.
    from torchvision.models.optical_flow import Raft_Large_Weights

    raft = raft_large(weights=Raft_Large_Weights.DEFAULT).requires_grad_(False).eval().to(device)
    return raft(I0, I1)[-1]
utils/uncertainty.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import itertools
3
+ from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
4
+ from utils.utils import denorm
5
+
6
def compute_lpips_variability(samples: torch.Tensor,
                              net: str = 'alex',
                              device: str = 'cuda'
                              ) -> float:
    """Return the mean pairwise LPIPS distance between the given samples.

    Args:
        samples: Batch of images; the first dimension indexes the samples.
            If the minimum value is non-negative the batch is treated as
            [0, 1] data and rescaled to [-1, 1].
        net: Backbone name forwarded to the LPIPS metric as ``net_type``.
        device: Device on which the metric is evaluated.

    Returns:
        The average LPIPS distance over all unordered pairs of samples.

    Raises:
        ValueError: If fewer than two samples are given (there is no pair to
            compare).
    """
    N = samples.size(0)
    if N < 2:
        raise ValueError("compute_lpips_variability requires at least two samples")

    loss_fn = LPIPS(net_type=net).to(device)
    loss_fn.eval()

    if samples.min() >= 0.0:
        samples = samples * 2 - 1  # Convert [0, 1] -> [-1, 1]

    # The metric is built with its default ``normalize=False`` and therefore
    # expects inputs in [-1, 1]; the clamped tensors are passed as they are
    # rather than being mapped back to [0, 1] first.
    samples = samples.clamp(-1, 1)

    scores = []
    with torch.no_grad():
        for i, j in itertools.combinations(range(N), 2):
            x = samples[i:i+1].to(device)
            y = samples[j:j+1].to(device)
            scores.append(loss_fn(x, y).item())

    return sum(scores) / len(scores)
25
+
26
def compute_pixelwise_correlation(samples: torch.Tensor) -> float:
    """Return the mean pairwise Pearson correlation between samples.

    For every unordered pair of samples the correlation is computed per
    channel over all pixels, averaged over channels, and finally averaged
    over all pairs.

    Args:
        samples: Tensor of shape (N, C, H, W).

    Returns:
        The average correlation as a Python float.

    Raises:
        ValueError: If ``samples`` is not 4-D or holds fewer than two samples.
    """
    N, C, H, W = samples.shape
    if N < 2:
        raise ValueError("compute_pixelwise_correlation requires at least two samples")

    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
    samples_flat = samples.reshape(N, C, -1)  # (N, C, H*W)

    # Centring and norms depend on a single sample only, so compute them once
    # instead of once per pair.
    centered = samples_flat - samples_flat.mean(dim=2, keepdim=True)
    norms = centered.norm(dim=2)  # (N, C)

    corrs = []
    for i, j in itertools.combinations(range(N), 2):
        numerator = (centered[i] * centered[j]).sum(dim=1)
        denominator = norms[i] * norms[j] + 1e-8  # avoid division by zero on flat channels
        corr = numerator / denominator  # (C,)
        corrs.append(corr.mean().item())
    return sum(corrs) / len(corrs)
43
+
44
def compute_dynamic_range(samples: torch.Tensor) -> float:
    """Return the mean spread (max minus min) across samples.

    The maximum and minimum are taken along dimension 0 (the sample axis),
    and their difference is averaged over all remaining positions.
    """
    highest = samples.max(dim=0).values
    lowest = samples.min(dim=0).values
    return (highest - lowest).mean().item()
utils/utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+ try:
7
+ from kornia.morphology import opening
8
+ except ImportError:
9
+ from kornia.morphology import open as opening
10
+
11
+ from torchvision import transforms
12
+ from torchvision.utils import make_grid, save_image
13
+
14
+ from typing import Any
15
+
16
def exist(val: Any) -> bool:
    """Return ``True`` unless ``val`` is ``None`` (falsy values still count)."""
    if val is None:
        return False
    return True
18
+
19
def morph_open(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply a morphological opening with a ``k`` x ``k`` kernel of ones.

    ``k == 0`` disables the operation and returns ``x`` itself. The opening
    runs without gradient tracking.
    """
    if k == 0:
        return x

    kernel = torch.ones(k, k, device=x.device)
    with torch.no_grad():
        return opening(x, kernel)
25
+
26
def make_grid_images(images: list[torch.Tensor], **kwargs) -> torch.Tensor:
    """Join the batches along dimension 3 and arrange the result in a grid.

    Extra keyword arguments are forwarded to ``make_grid``.
    """
    return make_grid(torch.cat(images, dim=3), **kwargs)
30
+
31
def save_images(images: tuple[torch.Tensor, torch.Tensor], path: str, **kwargs) -> None:
    """Save a generated/real pair of batches as one image file.

    The two batches are joined along dimension 3, arranged with ``make_grid``
    (which receives ``kwargs``), quantised to 8 bits and written to ``path``.
    """
    gen, real = images
    grid = make_grid(torch.cat((gen, real), dim=3), **kwargs)

    # Round-trip through uint8 so the saved tensor only holds 8-bit levels.
    pixels = grid.permute(1, 2, 0).to("cpu").numpy()
    pixels = (pixels * 255).astype(np.uint8)

    quantised = torch.from_numpy(pixels).permute(2, 0, 1) / 255
    save_image(quantised, path)
40
+
41
def save_triplet(images: tuple[torch.Tensor, ...], path: str, **kwargs) -> None:
    """Save any number of batches side by side as one image file.

    The batches are joined along dimension 3, arranged with ``make_grid``
    (which receives ``kwargs``), quantised to 8 bits and written to ``path``.
    """
    grid = make_grid(torch.cat(images, dim=3), **kwargs)

    # Round-trip through uint8 so the saved tensor only holds 8-bit levels.
    pixels = grid.permute(1, 2, 0).to("cpu").numpy()
    pixels = (pixels * 255).astype(np.uint8)

    quantised = torch.from_numpy(pixels).permute(2, 0, 1) / 255
    save_image(quantised, path)
49
+
50
def plot_images(images: torch.Tensor) -> None:
    """Show all images of a batch in a single row with matplotlib.

    The batch is moved to the CPU, its items are joined along the last
    dimension, and the result is displayed channel-last.
    """
    plt.figure(figsize=(32, 32))
    row = torch.cat(tuple(images.cpu()), dim=-1)
    plt.imshow(row.permute(1, 2, 0))
    plt.show()
56
+
57
def make_graphic(metric_name: str, metrics: list[torch.Tensor], path: str) -> None:
    """Plot a metric history and save it as ``<path>/<metric_name>.png``.

    Args:
        metric_name: Used as the title, the y-axis label and the file name.
        metrics: One tensor per point; each is moved to the CPU and converted
            to NumPy before plotting.
        path: Directory in which the PNG is written.
    """
    plt.figure(figsize=(32, 32))

    values = [metric.cpu().numpy() for metric in metrics]
    plt.plot(values)

    plt.title(metric_name)
    plt.xlabel("Epoch")
    plt.ylabel(metric_name)

    output_file = os.path.join(path, f"{metric_name}.png")
    plt.savefig(output_file)
    plt.close()
67
+
68
def norm(
    img: torch.Tensor,
    mean: list[float] = [0.5, 0.5, 0.5],
    std: list[float] = [0.5, 0.5, 0.5]
) -> torch.Tensor:
    """Normalise ``img`` per channel with torchvision's ``Normalize``.

    With the default ``mean`` and ``std`` of 0.5 this maps [0, 1] to [-1, 1].
    """
    return transforms.Normalize(mean, std)(img)
75
+
76
def denorm(
    img: torch.Tensor,
    mean: list[float] = [0.5, 0.5, 0.5],
    std: list[float] = [0.5, 0.5, 0.5]
) -> torch.Tensor:
    """Undo a per-channel normalisation: ``img * std + mean``.

    ``mean`` and ``std`` are turned into tensors on ``img``'s device and given
    a leading axis and two trailing axes so they broadcast over the image.
    """
    def as_broadcastable(values: list[float]) -> torch.Tensor:
        stats = torch.tensor(values, device=img.device)
        return stats.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    scale = as_broadcastable(std)
    shift = as_broadcastable(mean)
    return img * scale + shift