Beijia11 committed on
Commit 3aba902
1 Parent(s): 686bb9b
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +200 -0
  2. .gitmodules +3 -0
  3. app.py +577 -0
  4. config/__init__.py +0 -0
  5. config/base_cfg.py +410 -0
  6. config/ssm_cfg.py +347 -0
  7. config/yacs.py +506 -0
  8. demo.py +206 -0
  9. models/cogvideox_tracking.py +1020 -0
  10. models/pipelines.py +1040 -0
  11. models/spatracker/__init__.py +5 -0
  12. models/spatracker/models/__init__.py +5 -0
  13. models/spatracker/models/build_spatracker.py +51 -0
  14. models/spatracker/models/core/__init__.py +5 -0
  15. models/spatracker/models/core/embeddings.py +250 -0
  16. models/spatracker/models/core/model_utils.py +477 -0
  17. models/spatracker/models/core/spatracker/__init__.py +5 -0
  18. models/spatracker/models/core/spatracker/blocks.py +999 -0
  19. models/spatracker/models/core/spatracker/dpt/__init__.py +0 -0
  20. models/spatracker/models/core/spatracker/dpt/base_model.py +16 -0
  21. models/spatracker/models/core/spatracker/dpt/blocks.py +394 -0
  22. models/spatracker/models/core/spatracker/dpt/midas_net.py +77 -0
  23. models/spatracker/models/core/spatracker/dpt/models.py +231 -0
  24. models/spatracker/models/core/spatracker/dpt/transforms.py +231 -0
  25. models/spatracker/models/core/spatracker/dpt/vit.py +596 -0
  26. models/spatracker/models/core/spatracker/feature_net.py +915 -0
  27. models/spatracker/models/core/spatracker/loftr/__init__.py +1 -0
  28. models/spatracker/models/core/spatracker/loftr/linear_attention.py +81 -0
  29. models/spatracker/models/core/spatracker/loftr/transformer.py +142 -0
  30. models/spatracker/models/core/spatracker/losses.py +90 -0
  31. models/spatracker/models/core/spatracker/softsplat.py +539 -0
  32. models/spatracker/models/core/spatracker/spatracker.py +732 -0
  33. models/spatracker/models/core/spatracker/unet.py +258 -0
  34. models/spatracker/models/core/spatracker/vit/__init__.py +0 -0
  35. models/spatracker/models/core/spatracker/vit/common.py +43 -0
  36. models/spatracker/models/core/spatracker/vit/encoder.py +397 -0
  37. models/spatracker/predictor.py +284 -0
  38. models/spatracker/utils/__init__.py +5 -0
  39. models/spatracker/utils/basic.py +397 -0
  40. models/spatracker/utils/geom.py +547 -0
  41. models/spatracker/utils/improc.py +1447 -0
  42. models/spatracker/utils/misc.py +166 -0
  43. models/spatracker/utils/samp.py +152 -0
  44. models/spatracker/utils/visualizer.py +409 -0
  45. models/spatracker/utils/vox.py +500 -0
  46. requirements.txt +32 -0
  47. submodules/MoGe/.gitignore +425 -0
  48. submodules/MoGe/CHANGELOG.md +15 -0
  49. submodules/MoGe/CODE_OF_CONDUCT.md +9 -0
  50. submodules/MoGe/LICENSE +224 -0
.gitignore ADDED
@@ -0,0 +1,200 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # JetBrains
7
+ .idea
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113
+ .pdm.toml
114
+ .pdm-python
115
+ .pdm-build/
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
166
+
167
+ # manually added
168
+ wandb/
169
+ dump*
170
+
171
+ !requirements.txt
172
+ env/
173
+ datasets/
174
+ validation/
175
+ ckpts/
176
+ .vscode/
177
+ output.mp4
178
+ outputs/
179
+ camctrl_output
180
+ *.code-workspace
181
+
182
+ **/*/.DS_Store
183
+ **/*/__pycache__/*
184
+ .DS_Store
185
+ __pycache__
186
+ vis_results
187
+ checkpoints
188
+ **/*.pth
189
+ **/*.pt
190
+ **/*.mp4
191
+ **/*.npy
192
+
193
+ /assets/**
194
+ ./vis_results/**
195
+ models/monoD/zoeDepth/ckpts/*
196
+ slurm-*.out
197
+ .vscode
198
+
199
+ data/
200
+ tmp/
.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "submodules/MoGe"]
2
+ path = submodules/MoGe
3
+ url = https://github.com/microsoft/MoGe.git
app.py ADDED
@@ -0,0 +1,577 @@
1
+ import os
2
+ import sys
3
+ import gradio as gr
4
+ import torch
5
+ import subprocess
6
+ import argparse
7
+ import glob
8
+
9
+ project_root = os.path.dirname(os.path.abspath(__file__))
10
+ os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
11
+ sys.path.append(project_root)
12
+
13
+ HERE_PATH = os.path.normpath(os.path.dirname(__file__))
14
+ sys.path.insert(0, HERE_PATH)
15
+ from huggingface_hub import hf_hub_download
16
+ hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir=f'{HERE_PATH}/checkpoints/')
17
+
18
+
19
+ # Parse command line arguments
20
+ parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
21
+ parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
22
+ parser.add_argument("--share", action="store_true", help="Share the web UI")
23
+ parser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
24
+ parser.add_argument("--model_path", type=str, default="EXCAI/Diffusion-As-Shader", help="Path to model checkpoint")
25
+ parser.add_argument("--output_dir", type=str, default="tmp", help="Output directory")
26
+ args = parser.parse_args()
27
+
28
+ # Use the original GPU ID throughout the entire code for consistency
29
+ GPU_ID = args.gpu
30
+
31
+ # Set environment variables - this used to remap the GPU, but we're removing this for consistency
32
+ # Instead, we'll pass the original GPU ID to all commands
33
+ # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) # Commented out to ensure consistent GPU ID usage
34
+
35
+ # Check if CUDA is available
36
+ CUDA_AVAILABLE = torch.cuda.is_available()
37
+ if CUDA_AVAILABLE:
38
+ GPU_COUNT = torch.cuda.device_count()
39
+ GPU_NAMES = [f"{i}: {torch.cuda.get_device_name(i)}" for i in range(GPU_COUNT)]
40
+ else:
41
+ GPU_COUNT = 0
42
+ GPU_NAMES = ["CPU (CUDA not available)"]
43
+ GPU_ID = "CPU"
44
+
45
+ DEFAULT_MODEL_PATH = args.model_path
46
+ OUTPUT_DIR = args.output_dir
47
+
48
+ # Create necessary directories
49
+ os.makedirs("outputs", exist_ok=True)
50
+ # Create project tmp directory instead of using system temp
51
+ os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
52
+ os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)
53
+
54
+ def save_uploaded_file(file):
55
+ if file is None:
56
+ return None
57
+
58
+ # Use project tmp directory instead of system temp
59
+ temp_dir = os.path.join(project_root, "tmp")
60
+
61
+ if hasattr(file, 'name'):
62
+ filename = file.name
63
+ else:
64
+ # Generate a unique filename if name attribute is missing
65
+ import uuid
66
+ ext = ".tmp"
67
+ if hasattr(file, 'content_type'):
68
+ if "image" in file.content_type:
69
+ ext = ".png"
70
+ elif "video" in file.content_type:
71
+ ext = ".mp4"
72
+ filename = f"{uuid.uuid4()}{ext}"
73
+
74
+ temp_path = os.path.join(temp_dir, filename)
75
+
76
+ try:
77
+ # Check if file is a FileStorage object or already a path
78
+ if hasattr(file, 'save'):
79
+ file.save(temp_path)
80
+ elif isinstance(file, str):
81
+ # It's already a path
82
+ return file
83
+ else:
84
+ # Try to read and save the file
85
+ with open(temp_path, 'wb') as f:
86
+ f.write(file.read() if hasattr(file, 'read') else file)
87
+ except Exception as e:
88
+ print(f"Error saving file: {e}")
89
+ return None
90
+
91
+ return temp_path
92
+
93
+ def create_run_command(args):
94
+ """Create command based on input parameters"""
95
+ cmd = ["python", "demo.py"]
96
+
97
+ if "prompt" not in args or args["prompt"] is None or args["prompt"] == "":
98
+ args["prompt"] = ""
99
+ if "checkpoint_path" not in args or args["checkpoint_path"] is None or args["checkpoint_path"] == "":
100
+ args["checkpoint_path"] = DEFAULT_MODEL_PATH
101
+
102
+ # Debug output
103
+ print(f"DEBUG: Command args: {args}")
104
+
105
+ for key, value in args.items():
106
+ if value is not None:
107
+ # Handle boolean values correctly - for repaint, we need to pass true/false
108
+ if isinstance(value, bool):
109
+ cmd.append(f"--{key}")
110
+ cmd.append(str(value).lower()) # Convert True/False to true/false
111
+ else:
112
+ cmd.append(f"--{key}")
113
+ cmd.append(str(value))
114
+
115
+ return cmd
116
+
117
+ def run_process(cmd):
118
+ """Run command and return output"""
119
+ print(f"Running command: {' '.join(cmd)}")
120
+ process = subprocess.Popen(
121
+ cmd,
122
+ stdout=subprocess.PIPE,
123
+ stderr=subprocess.PIPE,
124
+ universal_newlines=True
125
+ )
126
+
127
+ output = []
128
+ for line in iter(process.stdout.readline, ""):
129
+ print(line, end="")
130
+ output.append(line)
131
+ if not line:
132
+ break
133
+
134
+ process.stdout.close()
135
+ return_code = process.wait()
136
+
137
+ if return_code:
138
+ stderr = process.stderr.read()
139
+ print(f"Error: {stderr}")
140
+ raise subprocess.CalledProcessError(return_code, cmd, output="\n".join(output), stderr=stderr)
141
+
142
+ return "\n".join(output)
143
+
144
+ # Process functions for each tab
145
+ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
146
+ """Process video motion transfer task"""
147
+ try:
148
+ # Save uploaded files
149
+ input_video_path = save_uploaded_file(source)
150
+ if input_video_path is None:
151
+ return None
152
+
153
+ print(f"DEBUG: Repaint option: {mt_repaint_option}")
154
+ print(f"DEBUG: Repaint image: {mt_repaint_image}")
155
+
156
+ args = {
157
+ "input_path": input_video_path,
158
+ "prompt": f"\"{prompt}\"",
159
+ "checkpoint_path": DEFAULT_MODEL_PATH,
160
+ "output_dir": OUTPUT_DIR,
161
+ "gpu": GPU_ID
162
+ }
163
+
164
+ # Priority: Custom Image > Yes > No
165
+ if mt_repaint_image is not None:
166
+ # Custom image takes precedence if provided
167
+ repaint_path = save_uploaded_file(mt_repaint_image)
168
+ print(f"DEBUG: Repaint path: {repaint_path}")
169
+ args["repaint"] = repaint_path
170
+ elif mt_repaint_option == "Yes":
171
+ # Otherwise use Yes/No selection
172
+ args["repaint"] = "true"
173
+
174
+ # Create and run command
175
+ cmd = create_run_command(args)
176
+ output = run_process(cmd)
177
+
178
+ # Find generated video files
179
+ output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
180
+ if output_files:
181
+ # Sort by modification time, return the latest file
182
+ latest_file = max(output_files, key=os.path.getmtime)
183
+ return latest_file
184
+ else:
185
+ return None
186
+ except Exception as e:
187
+ import traceback
188
+ print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
189
+ return None
190
+
191
+ def process_camera_control(source, prompt, camera_motion, tracking_method):
192
+ """Process camera control task"""
193
+ try:
194
+ # Save uploaded files
195
+ input_media_path = save_uploaded_file(source)
196
+ if input_media_path is None:
197
+ return None
198
+
199
+ print(f"DEBUG: Camera motion: '{camera_motion}'")
200
+ print(f"DEBUG: Tracking method: '{tracking_method}'")
201
+
202
+ args = {
203
+ "input_path": input_media_path,
204
+ "prompt": prompt,
205
+ "checkpoint_path": DEFAULT_MODEL_PATH,
206
+ "output_dir": OUTPUT_DIR,
207
+ "gpu": GPU_ID,
208
+ "tracking_method": tracking_method
209
+ }
210
+
211
+ if camera_motion and camera_motion.strip():
212
+ args["camera_motion"] = camera_motion
213
+
214
+ # Create and run command
215
+ cmd = create_run_command(args)
216
+ output = run_process(cmd)
217
+
218
+ # Find generated video files
219
+ output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
220
+ if output_files:
221
+ # Sort by modification time, return the latest file
222
+ latest_file = max(output_files, key=os.path.getmtime)
223
+ return latest_file
224
+ else:
225
+ return None
226
+ except Exception as e:
227
+ import traceback
228
+ print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
229
+ return None
230
+
231
+ def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
232
+ """Process object manipulation task"""
233
+ try:
234
+ # Save uploaded files
235
+ input_image_path = save_uploaded_file(source)
236
+ if input_image_path is None:
237
+ return None
238
+
239
+ object_mask_path = save_uploaded_file(object_mask)
240
+
241
+ args = {
242
+ "input_path": input_image_path,
243
+ "prompt": prompt,
244
+ "checkpoint_path": DEFAULT_MODEL_PATH,
245
+ "output_dir": OUTPUT_DIR,
246
+ "gpu": GPU_ID,
247
+ "object_motion": object_motion,
248
+ "object_mask": object_mask_path,
249
+ "tracking_method": tracking_method
250
+ }
251
+
252
+ # Create and run command
253
+ cmd = create_run_command(args)
254
+ output = run_process(cmd)
255
+
256
+ # Find generated video files
257
+ output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
258
+ if output_files:
259
+ # Sort by modification time, return the latest file
260
+ latest_file = max(output_files, key=os.path.getmtime)
261
+ return latest_file
262
+ else:
263
+ return None
264
+ except Exception as e:
265
+ import traceback
266
+ print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
267
+ return None
268
+
269
+ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
270
+ """Process mesh animation task"""
271
+ try:
272
+ # Save uploaded files
273
+ input_video_path = save_uploaded_file(source)
274
+ if input_video_path is None:
275
+ return None
276
+
277
+ tracking_video_path = save_uploaded_file(tracking_video)
278
+ if tracking_video_path is None:
279
+ return None
280
+
281
+ args = {
282
+ "input_path": input_video_path,
283
+ "prompt": prompt,
284
+ "checkpoint_path": DEFAULT_MODEL_PATH,
285
+ "output_dir": OUTPUT_DIR,
286
+ "gpu": GPU_ID,
287
+ "tracking_path": tracking_video_path
288
+ }
289
+
290
+ # Priority: Custom Image > Yes > No
291
+ if ma_repaint_image is not None:
292
+ # Custom image takes precedence if provided
293
+ repaint_path = save_uploaded_file(ma_repaint_image)
294
+ args["repaint"] = repaint_path
295
+ elif ma_repaint_option == "Yes":
296
+ # Otherwise use Yes/No selection
297
+ args["repaint"] = "true"
298
+
299
+ # Create and run command
300
+ cmd = create_run_command(args)
301
+ output = run_process(cmd)
302
+
303
+ # Find generated video files
304
+ output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
305
+ if output_files:
306
+ # Sort by modification time, return the latest file
307
+ latest_file = max(output_files, key=os.path.getmtime)
308
+ return latest_file
309
+ else:
310
+ return None
311
+ except Exception as e:
312
+ import traceback
313
+ print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
314
+ return None
315
+
316
+ # Create Gradio interface with updated layout
317
+ with gr.Blocks(title="Diffusion as Shader") as demo:
318
+ gr.Markdown("# Diffusion as Shader Web UI")
319
+ gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)")
320
+
321
+ with gr.Row():
322
+ left_column = gr.Column(scale=1)
323
+ right_column = gr.Column(scale=1)
324
+
325
+ with right_column:
326
+ output_video = gr.Video(label="Generated Video")
327
+
328
+ with left_column:
329
+ source = gr.File(label="Source", file_types=["image", "video"])
330
+ common_prompt = gr.Textbox(label="Prompt", lines=2)
331
+ gr.Markdown(f"**Using GPU: {GPU_ID}**")
332
+
333
+ with gr.Tabs() as task_tabs:
334
+ # Motion Transfer tab
335
+ with gr.TabItem("Motion Transfer"):
336
+ gr.Markdown("## Motion Transfer")
337
+
338
+ # Simplified controls - Radio buttons for Yes/No and separate file upload
339
+ with gr.Row():
340
+ mt_repaint_option = gr.Radio(
341
+ label="Repaint First Frame",
342
+ choices=["No", "Yes"],
343
+ value="No"
344
+ )
345
+ gr.Markdown("### Note: If you want to use your own image as the repainted first frame, please upload it below.")
346
+ # Custom image uploader (always visible)
347
+ mt_repaint_image = gr.File(
348
+ label="Custom Repaint Image",
349
+ file_types=["image"]
350
+ )
351
+
352
+ # Add run button for Motion Transfer tab
353
+ mt_run_btn = gr.Button("Run Motion Transfer", variant="primary", size="lg")
354
+
355
+ # Connect to process function
356
+ mt_run_btn.click(
357
+ fn=process_motion_transfer,
358
+ inputs=[
359
+ source, common_prompt,
360
+ mt_repaint_option, mt_repaint_image
361
+ ],
362
+ outputs=[output_video]
363
+ )
364
+
365
+ # Camera Control tab
366
+ with gr.TabItem("Camera Control"):
367
+ gr.Markdown("## Camera Control")
368
+
369
+ cc_camera_motion = gr.Textbox(
370
+ label="Current Camera Motion Sequence",
371
+ placeholder="Your camera motion sequence will appear here...",
372
+ interactive=False
373
+ )
374
+
375
+ # Use tabs for different motion types
376
+ with gr.Tabs() as cc_motion_tabs:
377
+ # Translation tab
378
+ with gr.TabItem("Translation (trans)"):
379
+ with gr.Row():
380
+ cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
381
+ cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
382
+ cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")
383
+
384
+ with gr.Row():
385
+ cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
386
+ cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
387
+
388
+ cc_trans_note = gr.Markdown("""
389
+ **Translation Notes:**
390
+ - Positive X: Move right, Negative X: Move left
391
+ - Positive Y: Move down, Negative Y: Move up
392
+ - Positive Z: Zoom in, Negative Z: Zoom out
393
+ """)
394
+
395
+ # Add translation button in the Translation tab
396
+ cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")
397
+
398
+ # Function to add translation motion
399
+ def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
400
+ # Format: trans dx dy dz [start_frame end_frame]
401
+ frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
402
+ new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"
403
+
404
+ # Append to existing motion string with semicolon separator if needed
405
+ if current_motion and current_motion.strip():
406
+ updated_motion = f"{current_motion}; {new_motion}"
407
+ else:
408
+ updated_motion = new_motion
409
+
410
+ return updated_motion
411
+
412
+ # Connect translation button
413
+ cc_add_trans.click(
414
+ fn=add_translation_motion,
415
+ inputs=[
416
+ cc_camera_motion,
417
+ cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end
418
+ ],
419
+ outputs=[cc_camera_motion]
420
+ )
421
+
422
+ # Rotation tab
423
+ with gr.TabItem("Rotation (rot)"):
424
+ with gr.Row():
425
+ cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
426
+ cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")
427
+
428
+ with gr.Row():
429
+ cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
430
+ cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
431
+
432
+ cc_rot_note = gr.Markdown("""
433
+ **Rotation Notes:**
434
+ - X-axis rotation: Tilt camera up/down
435
+ - Y-axis rotation: Pan camera left/right
436
+ - Z-axis rotation: Roll camera
437
+ """)
438
+
439
+ # Add rotation button in the Rotation tab
440
+ cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")
441
+
442
+ # Function to add rotation motion
443
+ def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
444
+ # Format: rot axis angle [start_frame end_frame]
445
+ frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
446
+ new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"
447
+
448
+ # Append to existing motion string with semicolon separator if needed
449
+ if current_motion and current_motion.strip():
450
+ updated_motion = f"{current_motion}; {new_motion}"
451
+ else:
452
+ updated_motion = new_motion
453
+
454
+ return updated_motion
455
+
456
+ # Connect rotation button
457
+ cc_add_rot.click(
458
+ fn=add_rotation_motion,
459
+ inputs=[
460
+ cc_camera_motion,
461
+ cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end
462
+ ],
463
+ outputs=[cc_camera_motion]
464
+ )
465
+
466
+ # Add a clear button to reset the motion sequence
467
+ cc_clear_motion = gr.Button("Clear All Motions", variant="stop")
468
+
469
+ def clear_camera_motion():
470
+ return ""
471
+
472
+ cc_clear_motion.click(
473
+ fn=clear_camera_motion,
474
+ inputs=[],
475
+ outputs=[cc_camera_motion]
476
+ )
477
+
478
+ cc_tracking_method = gr.Radio(
479
+ label="Tracking Method",
480
+ choices=["spatracker", "moge"],
481
+ value="moge"
482
+ )
483
+
484
+ # Add run button for Camera Control tab
485
+ cc_run_btn = gr.Button("Run Camera Control", variant="primary", size="lg")
486
+
487
+ # Connect to process function
488
+ cc_run_btn.click(
489
+ fn=process_camera_control,
490
+ inputs=[
491
+ source, common_prompt,
492
+ cc_camera_motion, cc_tracking_method
493
+ ],
494
+ outputs=[output_video]
495
+ )
496
+
497
+ # Object Manipulation tab
498
+ with gr.TabItem("Object Manipulation"):
499
+ gr.Markdown("## Object Manipulation")
500
+ om_object_mask = gr.File(
501
+ label="Object Mask Image",
502
+ file_types=["image"]
503
+ )
504
+ gr.Markdown("Upload a binary mask image; white areas indicate the object to manipulate")
505
+ om_object_motion = gr.Dropdown(
506
+ label="Object Motion Type",
507
+ choices=["up", "down", "left", "right", "front", "back", "rot"],
508
+ value="up"
509
+ )
510
+ om_tracking_method = gr.Radio(
511
+ label="Tracking Method",
512
+ choices=["spatracker", "moge"],
513
+ value="moge"
514
+ )
515
+
516
+ # Add run button for Object Manipulation tab
517
+ om_run_btn = gr.Button("Run Object Manipulation", variant="primary", size="lg")
518
+
519
+ # Connect to process function
520
+ om_run_btn.click(
521
+ fn=process_object_manipulation,
522
+ inputs=[
523
+ source, common_prompt,
524
+ om_object_motion, om_object_mask, om_tracking_method
525
+ ],
526
+ outputs=[output_video]
527
+ )
528
+
529
+ # Animating meshes to video tab
530
+ with gr.TabItem("Animating meshes to video"):
531
+ gr.Markdown("## Mesh Animation to Video")
532
+ gr.Markdown("""
533
+ Note: Currently only supports tracking videos generated with Blender (version > 4.0).
534
+ Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
535
+ """)
536
+ ma_tracking_video = gr.File(
537
+ label="Tracking Video",
538
+ file_types=["video"]
539
+ )
540
+ gr.Markdown("Tracking video needs to be generated from Blender")
541
+
542
+ # Simplified controls - Radio buttons for Yes/No and separate file upload
543
+ with gr.Row():
544
+ ma_repaint_option = gr.Radio(
545
+ label="Repaint First Frame",
546
+ choices=["No", "Yes"],
547
+ value="No"
548
+ )
549
+ gr.Markdown("### Note: If you want to use your own image as the repainted first frame, please upload it below.")
550
+ # Custom image uploader (always visible)
551
+ ma_repaint_image = gr.File(
552
+ label="Custom Repaint Image",
553
+ file_types=["image"]
554
+ )
555
+
556
+ # Add run button for Mesh Animation tab
557
+ ma_run_btn = gr.Button("Run Mesh Animation", variant="primary", size="lg")
558
+
559
+ # Connect to process function
560
+ ma_run_btn.click(
561
+ fn=process_mesh_animation,
562
+ inputs=[
563
+ source, common_prompt,
564
+ ma_tracking_video, ma_repaint_option, ma_repaint_image
565
+ ],
566
+ outputs=[output_video]
567
+ )
568
+
569
+ # Launch interface
570
+ if __name__ == "__main__":
571
+ print(f"Using GPU: {GPU_ID}")
572
+ print(f"Web UI will start on port {args.port}")
573
+ if args.share:
574
+ print("Creating public link for remote access")
575
+
576
+ # Launch interface
577
+ demo.launch(share=args.share, server_port=args.port)
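For readers skimming the diff, here is a minimal, self-contained sketch (not part of the commit) of the two string conventions app.py relies on: the camera-motion sequence composed by add_translation_motion / add_rotation_motion, and the demo.py command line assembled by create_run_command. The helper name build_motion, the upload path, and the prompt below are illustrative assumptions.

```python
# Sketch only: mirrors the string formats used in app.py above; values are made up.

def build_motion(current, new):
    # app.py joins successive motions with "; "
    return f"{current}; {new}" if current and current.strip() else new

# "trans dx dy dz [start end]" -- the frame range is omitted when it spans 0..48
motion = build_motion("", "trans 0.00 0.00 0.50")   # zoom in over the whole clip
motion = build_motion(motion, "rot y 10 0 24")      # pan right during frames 0-24
print(motion)  # trans 0.00 0.00 0.50; rot y 10 0 24

# create_run_command() walks an args dict and emits "--key value" pairs for demo.py
args = {
    "input_path": "tmp/input.mp4",             # hypothetical upload path
    "prompt": "a car driving at night",        # hypothetical prompt
    "checkpoint_path": "EXCAI/Diffusion-As-Shader",
    "output_dir": "tmp",
    "gpu": 0,
    "camera_motion": motion,
    "tracking_method": "moge",
}
cmd = ["python", "demo.py"]
for key, value in args.items():
    cmd += [f"--{key}", str(value)]
print(" ".join(cmd))
```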
config/__init__.py ADDED
File without changes
config/base_cfg.py ADDED
@@ -0,0 +1,410 @@
1
+ #python3.10
2
+ """Hierarchical configuration for different pipelines, using `yacs`
3
+ (referring to https://github.com/rbgirshick/yacs)
4
+
5
+ This project contains the configuration for three aspects:
6
+ the regular config for experiment settings
7
+
8
+ NOTE: Each experiment will be assigned a separate working space, and the
9
+ intermediate results will be saved in that working space. The experiment
10
+ folder structure is as follows:
11
+ {
12
+ /${ROOT_WORK_DIR}/
13
+ └── ${PIPELINES_NAME}/
14
+ ├── ${EXP_NAME}/
15
+ ├── ${CHECKPOINT_DIR}/
16
+ ├── ${RESULT_DIR}/
17
+ ├── meta.json/
18
+ └── ${LOG_DIR}
19
+ }
20
+
21
+ """
22
+
23
+ import os, sys
24
+ from .yacs import CfgNode as CN
25
+ import argparse
26
+ import numpy as np
27
+
28
+ # the parser for boolean
29
+ def bool_parser(arg):
30
+ """Parses an argument to boolean."""
31
+ if isinstance(arg, bool):
32
+ return arg
33
+ if arg is None:
34
+ return False
35
+ if arg.lower() in ['1', 'true', 't', 'yes', 'y']:
36
+ return True
37
+ if arg.lower() in ['0', 'false', 'f', 'no', 'n']:
38
+ return False
39
+ raise ValueError(f'`{arg}` cannot be converted to boolean!')
40
+
41
+ # -----------------------------------------------------------------------------
42
+ # base cfg
43
+ # -----------------------------------------------------------------------------
44
+ cfg = CN()
45
+
46
+ # configuration for basic experiments
47
+ cfg.save_dir = "./checkpoints"
48
+ cfg.restore_ckpt = ""
49
+ cfg.model_name = "cotracker"
50
+ cfg.exp_name = ""
51
+
52
+ # NOTE: configuration for datasets and augmentation
53
+ cfg.dataset_root = ""
54
+ cfg.eval_datasets = [""]
55
+ cfg.dont_use_augs = False
56
+ cfg.crop_size = [384, 512]
57
+ cfg.traj_per_sample = 384
58
+ cfg.sample_vis_1st_frame = False
59
+ cfg.depth_near = 0.01 # meter
60
+ cfg.depth_far = 65.0 # meter
61
+ cfg.sequence_len = 24
62
+
63
+ # NOTE: configuration for network arch
64
+ cfg.sliding_window_len = 8
65
+ cfg.remove_space_attn = False
66
+ cfg.updateformer_hidden_size = 384
67
+ cfg.updateformer_num_heads = 8
68
+ cfg.updateformer_space_depth = 6
69
+ cfg.updateformer_time_depth = 6
70
+ cfg.model_stride = 4
71
+ cfg.train_iters = 4
72
+ cfg.if_ARAP = False
73
+ cfg.Embed3D = False
74
+ cfg.Loss_W_feat = 5e-1
75
+ cfg.Loss_W_cls = 1e-4
76
+ cfg.depth_color = False
77
+ cfg.flash_attn = False
78
+ cfg.corr_dp = True
79
+ cfg.support_grid = 0
80
+ cfg.backbone = "CNN"
81
+ cfg.enc_only = False
82
+ cfg.init_match = False
83
+ cfg.Nblock = 4
84
+
85
+ # NOTE: configuration for training and saving
86
+ cfg.nodes_num = 1
87
+ cfg.batch_size = 1
88
+ cfg.num_workers = 6
89
+ cfg.mixed_precision = False
90
+ cfg.lr = 0.0005
91
+ cfg.wdecay = 0.00001
92
+ cfg.num_steps = 200000
93
+ cfg.evaluate_every_n_epoch = 1
94
+ cfg.save_every_n_epoch = 1
95
+ cfg.validate_at_start = False
96
+ cfg.save_freq = 100
97
+ cfg.eval_max_seq_len = 1000
98
+ cfg.debug = False
99
+ cfg.fine_tune = False
100
+ cfg.aug_wind_sample = False
101
+ cfg.use_video_flip = False
102
+ cfg.fix_backbone = False
103
+ cfg.tune_backbone = False
104
+ cfg.tune_arap = False
105
+ cfg.tune_per_scene = False
106
+ cfg.use_hier_encoder = False
107
+ cfg.scales = [4, 2]
108
+
109
+
110
+ # NOTE: configuration for monocular depth estimator
111
+ cfg.mde_name = "zoedepth_nk"
112
+
113
+ # -----------------------------------------------------------------------------
114
+
115
+ # configurations for the command line
116
+ parser = argparse.ArgumentParser()
117
+
118
+ # config for the basic experiment
119
+ parser.add_argument("--save_dir", default="./checkpoints", type=str ,help="path to save checkpoints")
120
+ parser.add_argument("--restore_ckpt", default="", help="path to restore a checkpoint")
121
+ parser.add_argument("--model_name", default="cotracker", help="model name")
122
+ parser.add_argument("--exp_name", type=str, default="base",
123
+ help="the name for experiment",
124
+ )
125
+ # config for dataset and augmentation
126
+ parser.add_argument(
127
+ "--dataset_root", type=str, help="path to all the datasets (train and eval)"
128
+ )
129
+ parser.add_argument(
130
+ "--eval_datasets", nargs="+", default=["things", "badja"],
131
+ help="what datasets to use for evaluation",
132
+ )
133
+ parser.add_argument(
134
+ "--dont_use_augs", action="store_true", default=False,
135
+ help="don't apply augmentations during training",
136
+ )
137
+ parser.add_argument(
138
+ "--crop_size", type=int, nargs="+", default=[384, 512],
139
+ help="crop videos to this resolution during training",
140
+ )
141
+ parser.add_argument(
142
+ "--traj_per_sample", type=int, default=768,
143
+ help="the number of trajectories to sample for training",
144
+ )
145
+ parser.add_argument(
146
+ "--depth_near", type=float, default=0.01, help="near plane depth"
147
+ )
148
+ parser.add_argument(
149
+ "--depth_far", type=float, default=65.0, help="far plane depth"
150
+ )
151
+ parser.add_argument(
152
+ "--sample_vis_1st_frame",
153
+ action="store_true",
154
+ default=False,
155
+ help="only sample trajectories with points visible on the first frame",
156
+ )
157
+ parser.add_argument(
158
+ "--sequence_len", type=int, default=24, help="train sequence length"
159
+ )
160
+ # configuration for network arch
161
+ parser.add_argument(
162
+ "--sliding_window_len",
163
+ type=int,
164
+ default=8,
165
+ help="length of the CoTracker sliding window",
166
+ )
167
+ parser.add_argument(
168
+ "--remove_space_attn",
169
+ action="store_true",
170
+ default=False,
171
+ help="remove space attention from CoTracker",
172
+ )
173
+ parser.add_argument(
174
+ "--updateformer_hidden_size",
175
+ type=int,
176
+ default=384,
177
+ help="hidden dimension of the CoTracker transformer model",
178
+ )
179
+ parser.add_argument(
180
+ "--updateformer_num_heads",
181
+ type=int,
182
+ default=8,
183
+ help="number of heads of the CoTracker transformer model",
184
+ )
185
+ parser.add_argument(
186
+ "--updateformer_space_depth",
187
+ type=int,
188
+ default=6,
189
+ help="number of group attention layers in the CoTracker transformer model",
190
+ )
191
+ parser.add_argument(
192
+ "--updateformer_time_depth",
193
+ type=int,
194
+ default=6,
195
+ help="number of time attention layers in the CoTracker transformer model",
196
+ )
197
+ parser.add_argument(
198
+ "--model_stride",
199
+ type=int,
200
+ default=4,
201
+ help="stride of the CoTracker feature network",
202
+ )
203
+ parser.add_argument(
204
+ "--train_iters",
205
+ type=int,
206
+ default=4,
207
+ help="number of updates to the disparity field in each forward pass.",
208
+ )
209
+ parser.add_argument(
210
+ "--if_ARAP",
211
+ action="store_true",
212
+ default=False,
213
+ help="if using ARAP loss in the optimization",
214
+ )
215
+ parser.add_argument(
216
+ "--Embed3D",
217
+ action="store_true",
218
+ default=False,
219
+ help="if using the 3D embedding for image",
220
+ )
221
+ parser.add_argument(
222
+ "--Loss_W_feat",
223
+ type=float,
224
+ default=5e-1,
225
+ help="weight for the feature loss",
226
+ )
227
+ parser.add_argument(
228
+ "--Loss_W_cls",
229
+ type=float,
230
+ default=1e-4,
231
+ help="weight for the classification loss",
232
+ )
233
+ parser.add_argument(
234
+ "--depth_color",
235
+ action="store_true",
236
+ default=False,
237
+ help="if using the color for depth",
238
+ )
239
+ parser.add_argument(
240
+ "--flash_attn",
241
+ action="store_true",
242
+ default=False,
243
+ help="if using the flash attention",
244
+ )
245
+ parser.add_argument(
246
+ "--corr_dp",
247
+ action="store_true",
248
+ default=False,
249
+ help="if using the correlation of depth",
250
+ )
251
+ parser.add_argument(
252
+ "--support_grid",
253
+ type=int,
254
+ default=0,
255
+ help="if using the support grid",
256
+ )
257
+ parser.add_argument(
258
+ "--backbone",
259
+ type=str,
260
+ default="CNN",
261
+ help="backbone for the CoTracker feature network",
262
+ )
263
+ parser.add_argument(
264
+ "--enc_only",
265
+ action="store_true",
266
+ default=False,
267
+ help="if using the encoder only",
268
+ )
269
+ parser.add_argument(
270
+ "--init_match",
271
+ action="store_true",
272
+ default=False,
273
+ help="if using the initial matching",
274
+ )
275
+ parser.add_argument(
276
+ "--Nblock",
277
+ type=int,
278
+ default=4,
279
+ help="number of blocks in the CoTracker feature network",
280
+ )
281
+
282
+ # configuration for training and saving
283
+ parser.add_argument(
284
+ "--nodes_num", type=int, default=1, help="number of nodes used for training."
285
+ )
286
+ parser.add_argument(
287
+ "--batch_size", type=int, default=1, help="batch size used during training."
288
+ )
289
+ parser.add_argument(
290
+ "--num_workers", type=int, default=6, help="number of dataloader workers"
291
+ )
292
+
293
+ parser.add_argument(
294
+ "--mixed_precision",
295
+ action="store_true", default=False,
296
+ help="use mixed precision"
297
+ )
298
+ parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
299
+ parser.add_argument(
300
+ "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
301
+ )
302
+ parser.add_argument(
303
+ "--num_steps", type=int, default=200000, help="length of training schedule."
304
+ )
305
+ parser.add_argument(
306
+ "--evaluate_every_n_epoch",
307
+ type=int,
308
+ default=1,
309
+ help="evaluate during training after every n epochs, after every epoch by default",
310
+ )
311
+ parser.add_argument(
312
+ "--save_every_n_epoch",
313
+ type=int,
314
+ default=1,
315
+ help="save checkpoints during training after every n epochs, after every epoch by default",
316
+ )
317
+ parser.add_argument(
318
+ "--validate_at_start",
319
+ action="store_true",
320
+ default=False,
321
+ help="whether to run evaluation before training starts",
322
+ )
323
+ parser.add_argument(
324
+ "--save_freq",
325
+ type=int,
326
+ default=100,
327
+ help="frequency of trajectory visualization during training",
328
+ )
329
+ parser.add_argument(
330
+ "--eval_max_seq_len",
331
+ type=int,
332
+ default=1000,
333
+ help="maximum length of evaluation videos",
334
+ )
335
+ parser.add_argument(
336
+ "--debug",
337
+ action="store_true",
338
+ default=False,
339
+ help="enable debug mode",
340
+ )
341
+ parser.add_argument(
342
+ "--fine_tune",
343
+ action="store_true",
344
+ default=False,
345
+ help="if fine tune the model",
346
+ )
347
+ parser.add_argument(
348
+ "--aug_wind_sample",
349
+ action="store_true",
350
+ default=False,
351
+ help="if using the window sampling",
352
+ )
353
+ parser.add_argument(
354
+ "--use_video_flip",
355
+ action="store_true",
356
+ default=False,
357
+ help="if using the video flip",
358
+ )
359
+ parser.add_argument(
360
+ "--fix_backbone",
361
+ action="store_true",
362
+ default=False,
363
+ help="if fix the backbone",
364
+ )
365
+ parser.add_argument(
366
+ "--tune_backbone",
367
+ action="store_true",
368
+ default=False,
369
+ help="if tune the backbone",
370
+ )
371
+ parser.add_argument(
372
+ "--tune_arap",
373
+ action="store_true",
374
+ default=False,
375
+ help="if tune the ARAP loss",
376
+ )
377
+ parser.add_argument(
378
+ "--tune_per_scene",
379
+ action="store_true",
380
+ default=False,
381
+ help="if tune one scene",
382
+ )
383
+ parser.add_argument(
384
+ "--use_hier_encoder",
385
+ action="store_true",
386
+ default=False,
387
+ help="if using the hierarchical encoder",
388
+ )
389
+ parser.add_argument(
390
+ "--scales",
391
+ type=int,
392
+ nargs="+",
393
+ default=[4, 2],
394
+ help="scales for the CoTracker feature network",
395
+ )
396
+
397
+ # config for monocular depth estimator
398
+ parser.add_argument(
399
+ "--mde_name", type=str, default="zoedepth_nk", help="name of the MDE model"
400
+ )
401
+ args = parser.parse_args()
402
+ args_dict = vars(args)
403
+
404
+ # -----------------------------------------------------------------------------
405
+
406
+ # merge the `args` to the `cfg`
407
+ cfg.merge_from_dict(args_dict)
408
+
409
+ cfg.ckpt_path = os.path.join(args.save_dir, args.model_name, args.exp_name)
410
+
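A minimal sketch of the configuration pattern base_cfg.py follows: defaults live on a yacs CfgNode, argparse values are merged over them, and the checkpoint path is derived from save_dir/model_name/exp_name. It assumes the vendored yacs variant provides merge_from_dict, which base_cfg.py calls but which lies outside the portion of config/yacs.py shown in this view; and since base_cfg.py parses sys.argv at import time, the sketch passes an explicit empty argument list instead.

```python
# Sketch only: reproduces the cfg/argparse merge pattern from config/base_cfg.py.
import os
import argparse

from config.yacs import CfgNode as CN  # the yacs.py added in this commit

cfg = CN()
cfg.save_dir = "./checkpoints"
cfg.model_name = "cotracker"
cfg.exp_name = ""

parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", default="./checkpoints", type=str)
parser.add_argument("--model_name", default="cotracker")
parser.add_argument("--exp_name", type=str, default="base")
args = parser.parse_args([])            # empty list: do not consume the real argv

cfg.merge_from_dict(vars(args))         # same call base_cfg.py makes (assumed API)
cfg.ckpt_path = os.path.join(cfg.save_dir, cfg.model_name, cfg.exp_name)
print(cfg.ckpt_path)                    # ./checkpoints/cotracker/base
```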
config/ssm_cfg.py ADDED
@@ -0,0 +1,347 @@
1
+ #python3.10
2
+ """Hierarchical configuration for different pipelines, using `yacs`
3
+ (referring to https://github.com/rbgirshick/yacs)
4
+
5
+ This project contains the configuration for three aspects:
6
+ the regular config for experiment settings
7
+
8
+ NOTE: Each experiment will be assigned a separate working space, and the
9
+ intermediate results will be saved in that working space. The experiment
10
+ folder structure is as follows:
11
+ {
12
+ /${ROOT_WORK_DIR}/
13
+ └── ${PIPELINES_NAME}/
14
+ ├── ${EXP_NAME}/
15
+ ├── ${CHECKPOINT_DIR}/
16
+ ├── ${RESULT_DIR}/
17
+ ├── meta.json/
18
+ └── ${LOG_DIR}
19
+ }
20
+
21
+ """
22
+
23
+ import os, sys
24
+ from .yacs import CfgNode as CN
25
+ import argparse
26
+ import numpy as np
27
+
28
+ # the parser for boolean
29
+ def bool_parser(arg):
30
+ """Parses an argument to boolean."""
31
+ if isinstance(arg, bool):
32
+ return arg
33
+ if arg is None:
34
+ return False
35
+ if arg.lower() in ['1', 'true', 't', 'yes', 'y']:
36
+ return True
37
+ if arg.lower() in ['0', 'false', 'f', 'no', 'n']:
38
+ return False
39
+ raise ValueError(f'`{arg}` cannot be converted to boolean!')
40
+
41
+ # -----------------------------------------------------------------------------
42
+ # base cfg
43
+ # -----------------------------------------------------------------------------
44
+ cfg = CN()
45
+
46
+ # configuration for basic experiments
47
+ cfg.save_dir = "./checkpoints"
48
+ cfg.restore_ckpt = ""
49
+ cfg.model_name = "cotracker"
50
+ cfg.exp_name = ""
51
+
52
+ # NOTE: configuration for datasets and augmentation
53
+ cfg.dataset_root = ""
54
+ cfg.eval_datasets = [""]
55
+ cfg.dont_use_augs = False
56
+ cfg.crop_size = [384, 512]
57
+ cfg.traj_per_sample = 384
58
+ cfg.sample_vis_1st_frame = False
59
+ cfg.depth_near = 0.01 # meter
60
+ cfg.depth_far = 65.0 # meter
61
+ cfg.sequence_len = 24
62
+
63
+ # NOTE: configuration for network arch
64
+ cfg.hidden_size = 384
65
+ cfg.mamba_depth = 8
66
+ cfg.model_stride = 4
67
+ cfg.train_iters = 4
68
+ cfg.updateformer_num_heads = 8
69
+ cfg.updateformer_hidden_size = 384
70
+ cfg.if_ARAP = False
71
+ cfg.Embed3D = False
72
+ cfg.Loss_W_feat = 5e-1
73
+ cfg.Loss_W_cls = 1e-4
74
+ cfg.depth_color = False
75
+ cfg.flash_attn = False
76
+ cfg.corr_dp = True
77
+ cfg.support_grid = 0
78
+ cfg.backbone = "CNN"
79
+ cfg.enc_only = False
80
+
81
+ # NOTE: configuration for training and saving
82
+ cfg.nodes_num = 1
83
+ cfg.batch_size = 1
84
+ cfg.num_workers = 6
85
+ cfg.mixed_precision = False
86
+ cfg.lr = 0.0005
87
+ cfg.wdecay = 0.00001
88
+ cfg.num_steps = 200000
89
+ cfg.evaluate_every_n_epoch = 1
90
+ cfg.save_every_n_epoch = 1
91
+ cfg.validate_at_start = False
92
+ cfg.save_freq = 100
93
+ cfg.eval_max_seq_len = 1000
94
+ cfg.debug = False
95
+ cfg.fine_tune = False
96
+ cfg.aug_wind_sample = False
97
+ cfg.use_video_flip = False
98
+ cfg.fix_backbone = False
99
+ cfg.tune_backbone = False
100
+
101
+
102
+ # NOTE: configuration for monocular depth estimator
103
+ cfg.mde_name = "zoedepth_nk"
104
+
105
+ # -----------------------------------------------------------------------------
106
+
107
+ # configurations for the command line
108
+ parser = argparse.ArgumentParser()
109
+
110
+ # config for the basic experiment
111
+ parser.add_argument("--save_dir", default="./checkpoints", type=str ,help="path to save checkpoints")
112
+ parser.add_argument("--restore_ckpt", default="", help="path to restore a checkpoint")
113
+ parser.add_argument("--model_name", default="cotracker", help="model name")
114
+ parser.add_argument("--exp_name", type=str, default="base",
115
+ help="the name for experiment",
116
+ )
117
+ # config for dataset and augmentation
118
+ parser.add_argument(
119
+ "--dataset_root", type=str, help="path to all the datasets (train and eval)"
120
+ )
121
+ parser.add_argument(
122
+ "--eval_datasets", nargs="+", default=["things", "badja"],
123
+ help="what datasets to use for evaluation",
124
+ )
125
+ parser.add_argument(
126
+ "--dont_use_augs", action="store_true", default=False,
127
+ help="don't apply augmentations during training",
128
+ )
129
+ parser.add_argument(
130
+ "--crop_size", type=int, nargs="+", default=[384, 512],
131
+ help="crop videos to this resolution during training",
132
+ )
133
+ parser.add_argument(
134
+ "--traj_per_sample", type=int, default=768,
135
+ help="the number of trajectories to sample for training",
136
+ )
137
+ parser.add_argument(
138
+ "--depth_near", type=float, default=0.01, help="near plane depth"
139
+ )
140
+ parser.add_argument(
141
+ "--depth_far", type=float, default=65.0, help="far plane depth"
142
+ )
143
+ parser.add_argument(
144
+ "--sample_vis_1st_frame",
145
+ action="store_true",
146
+ default=False,
147
+ help="only sample trajectories with points visible on the first frame",
148
+ )
149
+ parser.add_argument(
150
+ "--sequence_len", type=int, default=24, help="train sequence length"
151
+ )
152
+ # configuration for network arch
153
+ parser.add_argument(
154
+ "--hidden_size",
155
+ type=int,
156
+ default=384,
157
+ help="hidden dimension of the CoTracker transformer model",
158
+ )
159
+ parser.add_argument(
160
+ "--mamba_depth",
161
+ type=int,
162
+ default=6,
163
+ help="depth (number of layers) of the Mamba update module",
164
+ )
165
+ parser.add_argument(
166
+ "--updateformer_num_heads",
167
+ type=int,
168
+ default=8,
169
+ help="number of heads of the CoTracker transformer model",
170
+ )
171
+ parser.add_argument(
172
+ "--updateformer_hidden_size",
173
+ type=int,
174
+ default=384,
175
+ help="hidden dimension of the CoTracker transformer model",
176
+ )
177
+ parser.add_argument(
178
+ "--model_stride",
179
+ type=int,
180
+ default=4,
181
+ help="stride of the CoTracker feature network",
182
+ )
183
+ parser.add_argument(
184
+ "--train_iters",
185
+ type=int,
186
+ default=4,
187
+ help="number of updates to the disparity field in each forward pass.",
188
+ )
189
+ parser.add_argument(
190
+ "--if_ARAP",
191
+ action="store_true",
192
+ default=False,
193
+ help="if using ARAP loss in the optimization",
194
+ )
195
+ parser.add_argument(
196
+ "--Embed3D",
197
+ action="store_true",
198
+ default=False,
199
+ help="if using the 3D embedding for image",
200
+ )
201
+ parser.add_argument(
202
+ "--Loss_W_feat",
203
+ type=float,
204
+ default=5e-1,
205
+ help="weight for the feature loss",
206
+ )
207
+ parser.add_argument(
208
+ "--Loss_W_cls",
209
+ type=float,
210
+ default=1e-4,
211
+ help="weight for the classification loss",
212
+ )
213
+ parser.add_argument(
214
+ "--depth_color",
215
+ action="store_true",
216
+ default=False,
217
+ help="if using the color for depth",
218
+ )
219
+ parser.add_argument(
220
+ "--flash_attn",
221
+ action="store_true",
222
+ default=False,
223
+ help="if using the flash attention",
224
+ )
225
+ parser.add_argument(
226
+ "--corr_dp",
227
+ action="store_true",
228
+ default=False,
229
+ help="if using the correlation of depth",
230
+ )
231
+ parser.add_argument(
232
+ "--support_grid",
233
+ type=int,
234
+ default=0,
235
+ help="if using the support grid",
236
+ )
237
+ parser.add_argument(
238
+ "--backbone",
239
+ type=str,
240
+ default="CNN",
241
+ help="backbone for the CoTracker feature network",
242
+ )
243
+ parser.add_argument(
244
+ "--enc_only",
245
+ action="store_true",
246
+ default=False,
247
+ help="if using the encoder only",
248
+ )
249
+
250
+ # configuration for training and saving
251
+ parser.add_argument(
252
+ "--nodes_num", type=int, default=1, help="number of nodes used for training."
253
+ )
254
+ parser.add_argument(
255
+ "--batch_size", type=int, default=1, help="batch size used during training."
256
+ )
257
+ parser.add_argument(
258
+ "--num_workers", type=int, default=6, help="number of dataloader workers"
259
+ )
260
+
261
+ parser.add_argument(
262
+ "--mixed_precision",
263
+ action="store_true", default=False,
264
+ help="use mixed precision"
265
+ )
266
+ parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
267
+ parser.add_argument(
268
+ "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
269
+ )
270
+ parser.add_argument(
271
+ "--num_steps", type=int, default=200000, help="length of training schedule."
272
+ )
273
+ parser.add_argument(
274
+ "--evaluate_every_n_epoch",
275
+ type=int,
276
+ default=1,
277
+ help="evaluate during training after every n epochs, after every epoch by default",
278
+ )
279
+ parser.add_argument(
280
+ "--save_every_n_epoch",
281
+ type=int,
282
+ default=1,
283
+ help="save checkpoints during training after every n epochs, after every epoch by default",
284
+ )
285
+ parser.add_argument(
286
+ "--validate_at_start",
287
+ action="store_true",
288
+ default=False,
289
+ help="whether to run evaluation before training starts",
290
+ )
291
+ parser.add_argument(
292
+ "--save_freq",
293
+ type=int,
294
+ default=100,
295
+ help="frequency of trajectory visualization during training",
296
+ )
297
+ parser.add_argument(
298
+ "--eval_max_seq_len",
299
+ type=int,
300
+ default=1000,
301
+ help="maximum length of evaluation videos",
302
+ )
303
+ parser.add_argument(
304
+ "--debug",
305
+ action="store_true",
306
+ default=False,
307
+ help="enable debug mode",
308
+ )
309
+ parser.add_argument(
310
+ "--fine_tune",
311
+ action="store_true",
312
+ default=False,
313
+ help="if fine tune the model",
314
+ )
315
+ parser.add_argument(
316
+ "--aug_wind_sample",
317
+ action="store_true",
318
+ default=False,
319
+ help="if using the window sampling",
320
+ )
321
+ parser.add_argument(
322
+ "--use_video_flip",
323
+ action="store_true",
324
+ default=False,
325
+ help="if using the video flip",
326
+ )
327
+ parser.add_argument(
328
+ "--fix_backbone",
329
+ action="store_true",
330
+ default=False,
331
+ help="if fix the backbone",
332
+ )
333
+
334
+ # config for monocular depth estimator
335
+ parser.add_argument(
336
+ "--mde_name", type=str, default="zoedepth_nk", help="name of the MDE model"
337
+ )
338
+ args = parser.parse_args()
339
+ args_dict = vars(args)
340
+
341
+ # -----------------------------------------------------------------------------
342
+
343
+ # merge the `args` to the `cfg`
344
+ cfg.merge_from_dict(args_dict)
345
+
346
+ cfg.ckpt_path = os.path.join(args.save_dir, args.model_name, args.exp_name)
347
+
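Before the vendored config/yacs.py that follows, a quick sketch of the CfgNode behaviour it provides: nested dicts are converted into CfgNodes with attribute access, assignment is allowed while the node is mutable, and dump() serialises the tree back to YAML. Keys and values here are illustrative.

```python
# Sketch only: exercises the CfgNode class defined in config/yacs.py below.
from config.yacs import CfgNode as CN

node = CN({"model": {"name": "cotracker", "stride": 4}, "lr": 5e-4})
print(node.model.name)   # attribute access on a nested node -> "cotracker"
node.lr = 1e-3           # plain assignment works while the node is not frozen
print(node.dump())       # YAML string, e.g. "lr: 0.001\nmodel:\n  name: cotracker\n  stride: 4\n"
```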
config/yacs.py ADDED
@@ -0,0 +1,506 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ ##############################################################################
15
+
16
+ """YACS -- Yet Another Configuration System is designed to be a simple
17
+ configuration management system for academic and industrial research
18
+ projects.
19
+
20
+ See README.md for usage and examples.
21
+ """
22
+
23
+ import copy
24
+ import io
25
+ import logging
26
+ import os
27
+ from ast import literal_eval
28
+
29
+ import yaml
30
+
31
+
32
+ # Flag for py2 and py3 compatibility to use when separate code paths are necessary
33
+ # When _PY2 is False, we assume Python 3 is in use
34
+ _PY2 = False
35
+
36
+ # Filename extensions for loading configs from files
37
+ _YAML_EXTS = {"", ".yaml", ".yml"}
38
+ _PY_EXTS = {".py"}
39
+
40
+ # py2 and py3 compatibility for checking file object type
41
+ # We simply use this to infer py2 vs py3
42
+ try:
43
+ _FILE_TYPES = (file, io.IOBase)
44
+ _PY2 = True
45
+ except NameError:
46
+ _FILE_TYPES = (io.IOBase,)
47
+
48
+ # CfgNodes can only contain a limited set of valid types
49
+ _VALID_TYPES = {tuple, list, str, int, float, bool}
50
+ # py2 allow for str and unicode
51
+ if _PY2:
52
+ _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
53
+
54
+ # Utilities for importing modules from file paths
55
+ if _PY2:
56
+ # imp is available in both py2 and py3 for now, but is deprecated in py3
57
+ import imp
58
+ else:
59
+ import importlib.util
60
+
61
+ logger = logging.getLogger(__name__)
62
+
63
+
64
+ class CfgNode(dict):
65
+ """
66
+ CfgNode represents an internal node in the configuration tree. It's a simple
67
+ dict-like container that allows for attribute-based access to keys.
68
+ """
69
+
70
+ IMMUTABLE = "__immutable__"
71
+ DEPRECATED_KEYS = "__deprecated_keys__"
72
+ RENAMED_KEYS = "__renamed_keys__"
73
+
74
+ def __init__(self, init_dict=None, key_list=None):
75
+ # Recursively convert nested dictionaries in init_dict into CfgNodes
76
+ init_dict = {} if init_dict is None else init_dict
77
+ key_list = [] if key_list is None else key_list
78
+ for k, v in init_dict.items():
79
+ if type(v) is dict:
80
+ # Convert dict to CfgNode
81
+ init_dict[k] = CfgNode(v, key_list=key_list + [k])
82
+ else:
83
+ # Check for valid leaf type or nested CfgNode
84
+ _assert_with_logging(
85
+ _valid_type(v, allow_cfg_node=True),
86
+ "Key {} with value {} is not a valid type; valid types: {}".format(
87
+ ".".join(key_list + [k]), type(v), _VALID_TYPES
88
+ ),
89
+ )
90
+ super(CfgNode, self).__init__(init_dict)
91
+ # Manage if the CfgNode is frozen or not
92
+ self.__dict__[CfgNode.IMMUTABLE] = False
93
+ # Deprecated options
94
+ # If an option is removed from the code and you don't want to break existing
95
+ # yaml configs, you can add the full config key as a string to the set below.
96
+ self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
97
+ # Renamed options
98
+ # If you rename a config option, record the mapping from the old name to the new
99
+ # name in the dictionary below. Optionally, if the type also changed, you can
100
+ # make the value a tuple that specifies first the renamed key and then
101
+ # instructions for how to edit the config file.
102
+ self.__dict__[CfgNode.RENAMED_KEYS] = {
103
+ # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow
104
+ # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow
105
+ # 'EXAMPLE.NEW.KEY',
106
+ # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
107
+ # + "'foo:bar' -> ('foo', 'bar')"
108
+ # ),
109
+ }
110
+
111
+ def __getattr__(self, name):
112
+ if name in self:
113
+ return self[name]
114
+ else:
115
+ raise AttributeError(name)
116
+
117
+ def __setattr__(self, name, value):
118
+ if self.is_frozen():
119
+ raise AttributeError(
120
+ "Attempted to set {} to {}, but CfgNode is immutable".format(
121
+ name, value
122
+ )
123
+ )
124
+
125
+ _assert_with_logging(
126
+ name not in self.__dict__,
127
+ "Invalid attempt to modify internal CfgNode state: {}".format(name),
128
+ )
129
+ _assert_with_logging(
130
+ _valid_type(value, allow_cfg_node=True),
131
+ "Invalid type {} for key {}; valid types = {}".format(
132
+ type(value), name, _VALID_TYPES
133
+ ),
134
+ )
135
+
136
+ self[name] = value
137
+
138
+ def __str__(self):
139
+ def _indent(s_, num_spaces):
140
+ s = s_.split("\n")
141
+ if len(s) == 1:
142
+ return s_
143
+ first = s.pop(0)
144
+ s = [(num_spaces * " ") + line for line in s]
145
+ s = "\n".join(s)
146
+ s = first + "\n" + s
147
+ return s
148
+
149
+ r = ""
150
+ s = []
151
+ for k, v in sorted(self.items()):
152
+ separator = "\n" if isinstance(v, CfgNode) else " "
153
+ attr_str = "{}:{}{}".format(str(k), separator, str(v))
154
+ attr_str = _indent(attr_str, 2)
155
+ s.append(attr_str)
156
+ r += "\n".join(s)
157
+ return r
158
+
159
+ def __repr__(self):
160
+ return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
161
+
162
+ def dump(self):
163
+ """Dump to a string."""
164
+ self_as_dict = _to_dict(self)
165
+ return yaml.safe_dump(self_as_dict)
166
+
167
+ def merge_from_file(self, cfg_filename):
168
+ """Load a yaml config file and merge it into this CfgNode."""
169
+ with open(cfg_filename, "r") as f:
170
+ cfg = load_cfg(f)
171
+ self.merge_from_other_cfg(cfg)
172
+
173
+ def merge_from_other_cfg(self, cfg_other):
174
+ """Merge `cfg_other` into this CfgNode."""
175
+ _merge_a_into_b(cfg_other, self, self, [])
176
+
177
+ def merge_from_list(self, cfg_list):
178
+ """Merge config (keys, values) in a list (e.g., from command line) into
179
+ this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
180
+ """
181
+ _assert_with_logging(
182
+ len(cfg_list) % 2 == 0,
183
+ "Override list has odd length: {}; it must be a list of pairs".format(
184
+ cfg_list
185
+ ),
186
+ )
187
+ root = self
188
+ for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
189
+ if root.key_is_deprecated(full_key):
190
+ continue
191
+ if root.key_is_renamed(full_key):
192
+ root.raise_key_rename_error(full_key)
193
+ key_list = full_key.split(".")
194
+ d = self
195
+ for subkey in key_list[:-1]:
196
+ _assert_with_logging(
197
+ subkey in d, "Non-existent key: {}".format(full_key)
198
+ )
199
+ d = d[subkey]
200
+ subkey = key_list[-1]
201
+ _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
202
+ value = _decode_cfg_value(v)
203
+ value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
204
+ d[subkey] = value
205
+ def merge_from_dict(self, cfg_dict):
206
+ """Merge config (keys, values) in a dict into this CfgNode."""
207
+ cfg_dict = cfg_dict.items()
208
+ cfg_list = []
209
+ for pair in cfg_dict:
210
+ cfg_list.append(pair[0])
211
+ cfg_list.append(pair[1])
212
+ self.merge_from_list(cfg_list)
213
+
214
+ def freeze(self):
215
+ """Make this CfgNode and all of its children immutable."""
216
+ self._immutable(True)
217
+
218
+ def defrost(self):
219
+ """Make this CfgNode and all of its children mutable."""
220
+ self._immutable(False)
221
+
222
+ def is_frozen(self):
223
+ """Return mutability."""
224
+ return self.__dict__[CfgNode.IMMUTABLE]
225
+
226
+ def _immutable(self, is_immutable):
227
+ """Set immutability to is_immutable and recursively apply the setting
228
+ to all nested CfgNodes.
229
+ """
230
+ self.__dict__[CfgNode.IMMUTABLE] = is_immutable
231
+ # Recursively set immutable state
232
+ for v in self.__dict__.values():
233
+ if isinstance(v, CfgNode):
234
+ v._immutable(is_immutable)
235
+ for v in self.values():
236
+ if isinstance(v, CfgNode):
237
+ v._immutable(is_immutable)
238
+
239
+ def clone(self):
240
+ """Recursively copy this CfgNode."""
241
+ return copy.deepcopy(self)
242
+
243
+ def register_deprecated_key(self, key):
244
+ """Register key (e.g. `FOO.BAR`) as a deprecated option. When merging deprecated
245
+ keys, a warning is generated and the key is ignored.
246
+ """
247
+ _assert_with_logging(
248
+ key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
249
+ "key {} is already registered as a deprecated key".format(key),
250
+ )
251
+ self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
252
+
253
+ def register_renamed_key(self, old_name, new_name, message=None):
254
+ """Register a key as having been renamed from `old_name` to `new_name`.
255
+ When merging a renamed key, an exception is thrown alerting the user to
256
+ the fact that the key has been renamed.
257
+ """
258
+ _assert_with_logging(
259
+ old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
260
+ "key {} is already registered as a renamed cfg key".format(old_name),
261
+ )
262
+ value = new_name
263
+ if message:
264
+ value = (new_name, message)
265
+ self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
266
+
267
+ def key_is_deprecated(self, full_key):
268
+ """Test if a key is deprecated."""
269
+ if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
270
+ logger.warning("Deprecated config key (ignoring): {}".format(full_key))
271
+ return True
272
+ return False
273
+
274
+ def key_is_renamed(self, full_key):
275
+ """Test if a key is renamed."""
276
+ return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
277
+
278
+ def raise_key_rename_error(self, full_key):
279
+ new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
280
+ if isinstance(new_key, tuple):
281
+ msg = " Note: " + new_key[1]
282
+ new_key = new_key[0]
283
+ else:
284
+ msg = ""
285
+ raise KeyError(
286
+ "Key {} was renamed to {}; please update your config.{}".format(
287
+ full_key, new_key, msg
288
+ )
289
+ )
290
+
291
+
292
+ def load_cfg(cfg_file_obj_or_str):
293
+ """Load a cfg. Supports loading from:
294
+ - A file object backed by a YAML file
295
+ - A file object backed by a Python source file that exports an attribute
296
+ "cfg" that is either a dict or a CfgNode
297
+ - A string that can be parsed as valid YAML
298
+ """
299
+ _assert_with_logging(
300
+ isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
301
+ "Expected first argument to be of type {} or {}, but it was {}".format(
302
+ _FILE_TYPES, str, type(cfg_file_obj_or_str)
303
+ ),
304
+ )
305
+ if isinstance(cfg_file_obj_or_str, str):
306
+ return _load_cfg_from_yaml_str(cfg_file_obj_or_str)
307
+ elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
308
+ return _load_cfg_from_file(cfg_file_obj_or_str)
309
+ else:
310
+ raise NotImplementedError("Impossible to reach here (unless there's a bug)")
311
+
312
+
313
+ def _load_cfg_from_file(file_obj):
314
+ """Load a config from a YAML file or a Python source file."""
315
+ _, file_extension = os.path.splitext(file_obj.name)
316
+ if file_extension in _YAML_EXTS:
317
+ return _load_cfg_from_yaml_str(file_obj.read())
318
+ elif file_extension in _PY_EXTS:
319
+ return _load_cfg_py_source(file_obj.name)
320
+ else:
321
+ raise Exception(
322
+ "Attempt to load from an unsupported file type {}; "
323
+ "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
324
+ )
325
+
326
+
327
+ def _load_cfg_from_yaml_str(str_obj):
328
+ """Load a config from a YAML string encoding."""
329
+ cfg_as_dict = yaml.safe_load(str_obj)
330
+ return CfgNode(cfg_as_dict)
331
+
332
+
333
+ def _load_cfg_py_source(filename):
334
+ """Load a config from a Python source file."""
335
+ module = _load_module_from_file("yacs.config.override", filename)
336
+ _assert_with_logging(
337
+ hasattr(module, "cfg"),
338
+ "Python module from file {} must have 'cfg' attr".format(filename),
339
+ )
340
+ VALID_ATTR_TYPES = {dict, CfgNode}
341
+ _assert_with_logging(
342
+ type(module.cfg) in VALID_ATTR_TYPES,
343
+ "Imported module 'cfg' attr must be in {} but is {} instead".format(
344
+ VALID_ATTR_TYPES, type(module.cfg)
345
+ ),
346
+ )
347
+ if type(module.cfg) is dict:
348
+ return CfgNode(module.cfg)
349
+ else:
350
+ return module.cfg
351
+
352
+
353
+ def _to_dict(cfg_node):
354
+ """Recursively convert all CfgNode objects to dict objects."""
355
+
356
+ def convert_to_dict(cfg_node, key_list):
357
+ if not isinstance(cfg_node, CfgNode):
358
+ _assert_with_logging(
359
+ _valid_type(cfg_node),
360
+ "Key {} with value {} is not a valid type; valid types: {}".format(
361
+ ".".join(key_list), type(cfg_node), _VALID_TYPES
362
+ ),
363
+ )
364
+ return cfg_node
365
+ else:
366
+ cfg_dict = dict(cfg_node)
367
+ for k, v in cfg_dict.items():
368
+ cfg_dict[k] = convert_to_dict(v, key_list + [k])
369
+ return cfg_dict
370
+
371
+ return convert_to_dict(cfg_node, [])
372
+
373
+
374
+ def _valid_type(value, allow_cfg_node=False):
375
+ return (type(value) in _VALID_TYPES) or (allow_cfg_node and type(value) == CfgNode)
376
+
377
+
378
+ def _merge_a_into_b(a, b, root, key_list):
379
+ """Merge config dictionary a into config dictionary b, clobbering the
380
+ options in b whenever they are also specified in a.
381
+ """
382
+ _assert_with_logging(
383
+ isinstance(a, CfgNode),
384
+ "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
385
+ )
386
+ _assert_with_logging(
387
+ isinstance(b, CfgNode),
388
+ "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
389
+ )
390
+
391
+ for k, v_ in a.items():
392
+ full_key = ".".join(key_list + [k])
393
+ # a must specify keys that are in b
394
+ if k not in b:
395
+ if root.key_is_deprecated(full_key):
396
+ continue
397
+ elif root.key_is_renamed(full_key):
398
+ root.raise_key_rename_error(full_key)
399
+ else:
400
+ v = copy.deepcopy(v_)
401
+ v = _decode_cfg_value(v)
402
+ b.update({k: v})
403
+ else:
404
+ v = copy.deepcopy(v_)
405
+ v = _decode_cfg_value(v)
406
+ v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
407
+
408
+ # Recursively merge dicts
409
+ if isinstance(v, CfgNode):
410
+ try:
411
+ _merge_a_into_b(v, b[k], root, key_list + [k])
412
+ except BaseException:
413
+ raise
414
+ else:
415
+ b[k] = v
416
+
417
+
418
+ def _decode_cfg_value(v):
419
+ """Decodes a raw config value (e.g., from a yaml config files or command
420
+ line argument) into a Python object.
421
+ """
422
+ # Configs parsed from raw yaml will contain dictionary keys that need to be
423
+ # converted to CfgNode objects
424
+ if isinstance(v, dict):
425
+ return CfgNode(v)
426
+ # All remaining processing is only applied to strings
427
+ if not isinstance(v, str):
428
+ return v
429
+ # Try to interpret `v` as a:
430
+ # string, number, tuple, list, dict, boolean, or None
431
+ try:
432
+ v = literal_eval(v)
433
+ # The following two excepts allow v to pass through when it represents a
434
+ # string.
435
+ #
436
+ # Longer explanation:
437
+ # The type of v is always a string (before calling literal_eval), but
438
+ # sometimes it *represents* a string and other times a data structure, like
439
+ # a list. In the case that v represents a string, what we got back from the
440
+ # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
441
+ # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
442
+ # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
443
+ # will raise a SyntaxError.
444
+ except ValueError:
445
+ pass
446
+ except SyntaxError:
447
+ pass
448
+ return v
449
+
450
+
451
+ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
452
+ """Checks that `replacement`, which is intended to replace `original`, is of
453
+ the right type. The type is correct if it matches exactly or is one of a few
454
+ cases in which the type can be easily coerced.
455
+ """
456
+ original_type = type(original)
457
+ replacement_type = type(replacement)
458
+
459
+ # The types must match (with some exceptions)
460
+ if replacement_type == original_type:
461
+ return replacement
462
+
463
+ # Cast replacement from from_type to to_type if the replacement and original
464
+ # types match from_type and to_type
465
+ def conditional_cast(from_type, to_type):
466
+ if replacement_type == from_type and original_type == to_type:
467
+ return True, to_type(replacement)
468
+ else:
469
+ return False, None
470
+
471
+ # Conditionally casts
472
+ # list <-> tuple
473
+ casts = [(tuple, list), (list, tuple)]
474
+ # For py2: allow converting from str (bytes) to a unicode string
475
+ try:
476
+ casts.append((str, unicode)) # noqa: F821
477
+ except Exception:
478
+ pass
479
+
480
+ for (from_type, to_type) in casts:
481
+ converted, converted_value = conditional_cast(from_type, to_type)
482
+ if converted:
483
+ return converted_value
484
+
485
+ raise ValueError(
486
+ "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
487
+ "key: {}".format(
488
+ original_type, replacement_type, original, replacement, full_key
489
+ )
490
+ )
491
+
492
+
493
+ def _assert_with_logging(cond, msg):
494
+ if not cond:
495
+ logger.debug(msg)
496
+ assert cond, msg
497
+
498
+
499
+ def _load_module_from_file(name, filename):
500
+ if _PY2:
501
+ module = imp.load_source(name, filename)
502
+ else:
503
+ spec = importlib.util.spec_from_file_location(name, filename)
504
+ module = importlib.util.module_from_spec(spec)
505
+ spec.loader.exec_module(module)
506
+ return module
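A minimal usage sketch of the CfgNode API defined in config/yacs.py above (illustrative only, not part of the commit); the import path and the commented-out YAML filename are assumptions:

# Sketch: building, overriding, and freezing a config with the CfgNode class above.
from config.yacs import CfgNode as CN    # assumed import path within this repo

cfg = CN()
cfg.TRAIN = CN()
cfg.TRAIN.LR = 0.001                     # leaf values must be one of _VALID_TYPES
cfg.TRAIN.BATCH_SIZE = 8

# cfg.merge_from_file("configs/experiment.yaml")   # hypothetical YAML override
cfg.merge_from_list(["TRAIN.LR", 0.01])            # command-line style override of an existing key
cfg.freeze()                                       # subsequent attribute writes raise AttributeError
print(cfg.dump())                                  # serialized via yaml.safe_dump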
demo.py ADDED
@@ -0,0 +1,206 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ from PIL import Image
5
+ project_root = os.path.dirname(os.path.abspath(__file__))
6
+ try:
7
+ sys.path.append(os.path.join(project_root, "submodules/MoGe"))
8
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
+ except Exception:
10
+ print("Warning: MoGe not found, motion transfer will not be applied")
11
+
12
+ import torch
13
+ import numpy as np
14
+ from PIL import Image
15
+ import torchvision.transforms as transforms
16
+ from moviepy.editor import VideoFileClip
17
+ from diffusers.utils import load_image, load_video
18
+
19
+ from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
20
+ from submodules.MoGe.moge.model import MoGeModel
21
+
22
+ def load_media(media_path, max_frames=49, transform=None):
23
+ """Load video or image frames and convert to tensor
24
+
25
+ Args:
26
+ media_path (str): Path to video or image file
27
+ max_frames (int): Maximum number of frames to load
28
+ transform (callable): Transform to apply to frames
29
+
30
+ Returns:
31
+ Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and whether the input was a video
32
+ """
33
+ if transform is None:
34
+ transform = transforms.Compose([
35
+ transforms.Resize((480, 720)),
36
+ transforms.ToTensor()
37
+ ])
38
+
39
+ # Determine if input is video or image based on extension
40
+ ext = os.path.splitext(media_path)[1].lower()
41
+ is_video = ext in ['.mp4', '.avi', '.mov']
42
+
43
+ if is_video:
44
+ frames = load_video(media_path)
45
+ fps = len(frames) / VideoFileClip(media_path).duration
46
+ else:
47
+ # Handle image as single frame
48
+ image = load_image(media_path)
49
+ frames = [image]
50
+ fps = 8 # Default fps for images
51
+
52
+ # Ensure we have exactly max_frames
53
+ if len(frames) > max_frames:
54
+ frames = frames[:max_frames]
55
+ elif len(frames) < max_frames:
56
+ last_frame = frames[-1]
57
+ while len(frames) < max_frames:
58
+ frames.append(last_frame.copy())
59
+
60
+ # Convert frames to tensor
61
+ video_tensor = torch.stack([transform(frame) for frame in frames])
62
+
63
+ return video_tensor, fps, is_video
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument('--input_path', type=str, default=None, help='Path to input video/image')
68
+ parser.add_argument('--prompt', type=str, required=True, help='Repaint prompt')
69
+ parser.add_argument('--output_dir', type=str, default='outputs', help='Output directory')
70
+ parser.add_argument('--gpu', type=int, default=0, help='GPU device ID')
71
+ parser.add_argument('--checkpoint_path', type=str, default="EXCAI/Diffusion-As-Shader", help='Path to model checkpoint')
72
+ parser.add_argument('--depth_path', type=str, default=None, help='Path to depth image')
73
+ parser.add_argument('--tracking_path', type=str, default=None, help='Path to tracking video; if provided, camera motion and object manipulation will not be applied')
74
+ parser.add_argument('--repaint', type=str, default=None,
75
+ help='Path to a repainted image, or "true" to perform repainting; if not provided, the original frame is used')
76
+ parser.add_argument('--camera_motion', type=str, default=None,
77
+ help='Camera motion mode: "trans <dx> <dy> <dz>" or "rot <axis> <angle>" or "spiral <radius>"')
78
+ parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
79
+ parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
80
+ parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge'],
81
+ help='Tracking method to use (spatracker or moge)')
82
+ args = parser.parse_args()
83
+
84
+ # Load input video/image
85
+ video_tensor, fps, is_video = load_media(args.input_path)
86
+ if not is_video:
87
+ args.tracking_method = "moge"
88
+ print("Image input detected, using MoGe for tracking video generation.")
89
+
90
+ # Initialize pipeline
91
+ das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
92
+ if args.tracking_method == "moge" and args.tracking_path is None:
93
+ moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
94
+
95
+ # Repaint first frame if requested
96
+ repaint_img_tensor = None
97
+ if args.repaint:
98
+ if args.repaint.lower() == "true":
99
+ repainter = FirstFrameRepainter(gpu_id=args.gpu, output_dir=args.output_dir)
100
+ repaint_img_tensor = repainter.repaint(
101
+ video_tensor[0],
102
+ prompt=args.prompt,
103
+ depth_path=args.depth_path
104
+ )
105
+ else:
106
+ repaint_img_tensor, _, _ = load_media(args.repaint)
107
+ repaint_img_tensor = repaint_img_tensor[0] # Take first frame
108
+
109
+ # Generate tracking if not provided
110
+ tracking_tensor = None
111
+ pred_tracks = None
112
+ cam_motion = CameraMotionGenerator(args.camera_motion)
113
+
114
+ if args.tracking_path:
115
+ tracking_tensor, _, _ = load_media(args.tracking_path)
116
+
117
+ elif args.tracking_method == "moge":
118
+ # Use the first frame from previously loaded video_tensor
119
+ infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
120
+ H, W = infer_result["points"].shape[0:2]
121
+ pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3]
122
+ cam_motion.set_intr(infer_result["intrinsics"])
123
+
124
+ # Apply object motion if specified
125
+ if args.object_motion:
126
+ if args.object_mask is None:
127
+ raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask")
128
+
129
+ # Load mask image
130
+ mask_image = Image.open(args.object_mask).convert('L') # Convert to grayscale
131
+ mask_image = transforms.Resize((480, 720))(mask_image) # Resize to match video size
132
+ # Convert to binary mask
133
+ mask = torch.from_numpy(np.array(mask_image) > 127) # Threshold at 127
134
+
135
+ motion_generator = ObjectMotionGenerator(device=das.device)
136
+
137
+ pred_tracks = motion_generator.apply_motion(
138
+ pred_tracks=pred_tracks,
139
+ mask=mask,
140
+ motion_type=args.object_motion,
141
+ distance=50,
142
+ num_frames=49,
143
+ tracking_method="moge"
144
+ )
145
+ print("Object motion applied")
146
+
147
+ # Apply camera motion if specified
148
+ if args.camera_motion:
149
+ poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
150
+ print("Camera motion applied")
151
+ else:
152
+ # no poses
153
+ poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
154
+ # Convert pred_tracks into screen coordinates
155
+ pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
156
+ pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
157
+ _, tracking_tensor = das.visualize_tracking_moge(
158
+ pred_tracks.cpu().numpy(),
159
+ infer_result["mask"].cpu().numpy()
160
+ )
161
+ print('Tracking video exported via MoGe.')
162
+
163
+ else:
164
+ # Generate tracking points
165
+ pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
166
+
167
+ # Apply camera motion if specified
168
+ if args.camera_motion:
169
+ poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
170
+ pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
171
+ print("Camera motion applied")
172
+
173
+ # Apply object motion if specified
174
+ if args.object_motion:
175
+ if args.object_mask is None:
176
+ raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask")
177
+
178
+ # Load mask image
179
+ mask_image = Image.open(args.object_mask).convert('L') # Convert to grayscale
180
+ mask_image = transforms.Resize((480, 720))(mask_image) # Resize to match video size
181
+ # Convert to binary mask
182
+ mask = torch.from_numpy(np.array(mask_image) > 127) # Threshold at 127
183
+
184
+ motion_generator = ObjectMotionGenerator(device=das.device)
185
+
186
+ pred_tracks = motion_generator.apply_motion(
187
+ pred_tracks=pred_tracks.squeeze(),
188
+ mask=mask,
189
+ motion_type=args.object_motion,
190
+ distance=50,
191
+ num_frames=49,
192
+ tracking_method="spatracker"
193
+ ).unsqueeze(0)
194
+ print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}")
195
+
196
+ # Generate tracking tensor from modified tracks
197
+ _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
198
+
199
+ das.apply_tracking(
200
+ video_tensor=video_tensor,
201
+ fps=8,
202
+ tracking_tensor=tracking_tensor,
203
+ img_cond_tensor=repaint_img_tensor,
204
+ prompt=args.prompt,
205
+ checkpoint_path=args.checkpoint_path
206
+ )
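For orientation, a condensed programmatic sketch of the SpaTracker branch that demo.py above walks through (illustrative only, not part of the commit); it reuses only calls that appear in the script, and the input path and prompt are hypothetical placeholders:

# Sketch: the minimal spatracker path from demo.py, distilled.
from demo import load_media                      # load_media is defined in demo.py above
from models.pipelines import DiffusionAsShaderPipeline

video_tensor, fps, is_video = load_media("assets/example.mp4")   # hypothetical input path
das = DiffusionAsShaderPipeline(gpu_id=0, output_dir="outputs")

# Track points, render the tracking video, then condition generation on it.
pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
_, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks,
                                                       pred_visibility, T_Firsts)
das.apply_tracking(
    video_tensor=video_tensor,
    fps=8,
    tracking_tensor=tracking_tensor,
    img_cond_tensor=None,                        # as in the script when no repaint is requested
    prompt="a placeholder prompt",               # hypothetical prompt
    checkpoint_path="EXCAI/Diffusion-As-Shader", # the script's default checkpoint argument
)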
models/cogvideox_tracking.py ADDED
@@ -0,0 +1,1020 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union, List, Callable
2
+
3
+ import torch, os, math
4
+ from torch import nn
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+
8
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
9
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
10
+ from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
11
+
12
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, CogVideoXPipelineOutput
13
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
14
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
15
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
16
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps
17
+ from transformers import T5EncoderModel, T5Tokenizer
18
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
19
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
20
+ from diffusers.pipelines import DiffusionPipeline
21
+ from diffusers.models.modeling_utils import ModelMixin
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+ class CogVideoXTransformer3DModelTracking(CogVideoXTransformer3DModel, ModelMixin):
26
+ """
27
+ Add tracking maps to the CogVideoX transformer model.
28
+
29
+ Parameters:
30
+ num_tracking_blocks (`int`, defaults to `18`):
31
+ The number of tracking blocks to use. Must be less than or equal to num_layers.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ num_tracking_blocks: Optional[int] = 18,
37
+ num_attention_heads: int = 30,
38
+ attention_head_dim: int = 64,
39
+ in_channels: int = 16,
40
+ out_channels: Optional[int] = 16,
41
+ flip_sin_to_cos: bool = True,
42
+ freq_shift: int = 0,
43
+ time_embed_dim: int = 512,
44
+ text_embed_dim: int = 4096,
45
+ num_layers: int = 30,
46
+ dropout: float = 0.0,
47
+ attention_bias: bool = True,
48
+ sample_width: int = 90,
49
+ sample_height: int = 60,
50
+ sample_frames: int = 49,
51
+ patch_size: int = 2,
52
+ temporal_compression_ratio: int = 4,
53
+ max_text_seq_length: int = 226,
54
+ activation_fn: str = "gelu-approximate",
55
+ timestep_activation_fn: str = "silu",
56
+ norm_elementwise_affine: bool = True,
57
+ norm_eps: float = 1e-5,
58
+ spatial_interpolation_scale: float = 1.875,
59
+ temporal_interpolation_scale: float = 1.0,
60
+ use_rotary_positional_embeddings: bool = False,
61
+ use_learned_positional_embeddings: bool = False,
62
+ **kwargs
63
+ ):
64
+ super().__init__(
65
+ num_attention_heads=num_attention_heads,
66
+ attention_head_dim=attention_head_dim,
67
+ in_channels=in_channels,
68
+ out_channels=out_channels,
69
+ flip_sin_to_cos=flip_sin_to_cos,
70
+ freq_shift=freq_shift,
71
+ time_embed_dim=time_embed_dim,
72
+ text_embed_dim=text_embed_dim,
73
+ num_layers=num_layers,
74
+ dropout=dropout,
75
+ attention_bias=attention_bias,
76
+ sample_width=sample_width,
77
+ sample_height=sample_height,
78
+ sample_frames=sample_frames,
79
+ patch_size=patch_size,
80
+ temporal_compression_ratio=temporal_compression_ratio,
81
+ max_text_seq_length=max_text_seq_length,
82
+ activation_fn=activation_fn,
83
+ timestep_activation_fn=timestep_activation_fn,
84
+ norm_elementwise_affine=norm_elementwise_affine,
85
+ norm_eps=norm_eps,
86
+ spatial_interpolation_scale=spatial_interpolation_scale,
87
+ temporal_interpolation_scale=temporal_interpolation_scale,
88
+ use_rotary_positional_embeddings=use_rotary_positional_embeddings,
89
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
90
+ **kwargs
91
+ )
92
+
93
+ inner_dim = num_attention_heads * attention_head_dim
94
+ self.num_tracking_blocks = num_tracking_blocks
95
+
96
+ # Ensure num_tracking_blocks is not greater than num_layers
97
+ if num_tracking_blocks > num_layers:
98
+ raise ValueError("num_tracking_blocks must be less than or equal to num_layers")
99
+
100
+ # Create linear layers for combining hidden states and tracking maps
101
+ self.combine_linears = nn.ModuleList(
102
+ [nn.Linear(inner_dim, inner_dim) for _ in range(num_tracking_blocks)]
103
+ )
104
+
105
+ # Initialize weights of combine_linears to zero
106
+ for linear in self.combine_linears:
107
+ linear.weight.data.zero_()
108
+ linear.bias.data.zero_()
109
+
110
+ # Create transformer blocks for processing tracking maps
111
+ self.transformer_blocks_copy = nn.ModuleList(
112
+ [
113
+ CogVideoXBlock(
114
+ dim=inner_dim,
115
+ num_attention_heads=self.config.num_attention_heads,
116
+ attention_head_dim=self.config.attention_head_dim,
117
+ time_embed_dim=self.config.time_embed_dim,
118
+ dropout=self.config.dropout,
119
+ activation_fn=self.config.activation_fn,
120
+ attention_bias=self.config.attention_bias,
121
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
122
+ norm_eps=self.config.norm_eps,
123
+ )
124
+ for _ in range(num_tracking_blocks)
125
+ ]
126
+ )
127
+
128
+ # For initial combination of hidden states and tracking maps
129
+ self.initial_combine_linear = nn.Linear(inner_dim, inner_dim)
130
+ self.initial_combine_linear.weight.data.zero_()
131
+ self.initial_combine_linear.bias.data.zero_()
132
+
133
+ # Freeze all parameters
134
+ for param in self.parameters():
135
+ param.requires_grad = False
136
+
137
+ # Unfreeze parameters that need to be trained
138
+ for linear in self.combine_linears:
139
+ for param in linear.parameters():
140
+ param.requires_grad = True
141
+
142
+ for block in self.transformer_blocks_copy:
143
+ for param in block.parameters():
144
+ param.requires_grad = True
145
+
146
+ for param in self.initial_combine_linear.parameters():
147
+ param.requires_grad = True
148
+
149
+ def forward(
150
+ self,
151
+ hidden_states: torch.Tensor,
152
+ encoder_hidden_states: torch.Tensor,
153
+ tracking_maps: torch.Tensor,
154
+ timestep: Union[int, float, torch.LongTensor],
155
+ timestep_cond: Optional[torch.Tensor] = None,
156
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
157
+ attention_kwargs: Optional[Dict[str, Any]] = None,
158
+ return_dict: bool = True,
159
+ ):
160
+ if attention_kwargs is not None:
161
+ attention_kwargs = attention_kwargs.copy()
162
+ lora_scale = attention_kwargs.pop("scale", 1.0)
163
+ else:
164
+ lora_scale = 1.0
165
+
166
+ if USE_PEFT_BACKEND:
167
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
168
+ scale_lora_layers(self, lora_scale)
169
+ else:
170
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
171
+ logger.warning(
172
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
173
+ )
174
+
175
+ batch_size, num_frames, channels, height, width = hidden_states.shape
176
+
177
+ # 1. Time embedding
178
+ timesteps = timestep
179
+ t_emb = self.time_proj(timesteps)
180
+
181
+ # timesteps does not contain any weights and will always return f32 tensors
182
+ # but time_embedding might actually be running in fp16. so we need to cast here.
183
+ # there might be better ways to encapsulate this.
184
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
185
+ emb = self.time_embedding(t_emb, timestep_cond)
186
+
187
+ # 2. Patch embedding
188
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
189
+ hidden_states = self.embedding_dropout(hidden_states)
190
+
191
+ # Process tracking maps
192
+ prompt_embed = encoder_hidden_states.clone()
193
+ tracking_maps_hidden_states = self.patch_embed(prompt_embed, tracking_maps)
194
+ tracking_maps_hidden_states = self.embedding_dropout(tracking_maps_hidden_states)
195
+ del prompt_embed
196
+
197
+ text_seq_length = encoder_hidden_states.shape[1]
198
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
199
+ hidden_states = hidden_states[:, text_seq_length:]
200
+ tracking_maps = tracking_maps_hidden_states[:, text_seq_length:]
201
+
202
+ # Combine hidden states and tracking maps initially
203
+ combined = hidden_states + tracking_maps
204
+ tracking_maps = self.initial_combine_linear(combined)
205
+
206
+ # Process transformer blocks
207
+ for i in range(len(self.transformer_blocks)):
208
+ if self.training and self.gradient_checkpointing:
209
+ # Gradient checkpointing logic for hidden states
210
+ def create_custom_forward(module):
211
+ def custom_forward(*inputs):
212
+ return module(*inputs)
213
+ return custom_forward
214
+
215
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
216
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
217
+ create_custom_forward(self.transformer_blocks[i]),
218
+ hidden_states,
219
+ encoder_hidden_states,
220
+ emb,
221
+ image_rotary_emb,
222
+ **ckpt_kwargs,
223
+ )
224
+ else:
225
+ hidden_states, encoder_hidden_states = self.transformer_blocks[i](
226
+ hidden_states=hidden_states,
227
+ encoder_hidden_states=encoder_hidden_states,
228
+ temb=emb,
229
+ image_rotary_emb=image_rotary_emb,
230
+ )
231
+
232
+ if i < len(self.transformer_blocks_copy):
233
+ if self.training and self.gradient_checkpointing:
234
+ # Gradient checkpointing logic for tracking maps
235
+ tracking_maps, _ = torch.utils.checkpoint.checkpoint(
236
+ create_custom_forward(self.transformer_blocks_copy[i]),
237
+ tracking_maps,
238
+ encoder_hidden_states,
239
+ emb,
240
+ image_rotary_emb,
241
+ **ckpt_kwargs,
242
+ )
243
+ else:
244
+ tracking_maps, _ = self.transformer_blocks_copy[i](
245
+ hidden_states=tracking_maps,
246
+ encoder_hidden_states=encoder_hidden_states,
247
+ temb=emb,
248
+ image_rotary_emb=image_rotary_emb,
249
+ )
250
+
251
+ # Combine hidden states and tracking maps
252
+ tracking_maps = self.combine_linears[i](tracking_maps)
253
+ hidden_states = hidden_states + tracking_maps
254
+
255
+
256
+ if not self.config.use_rotary_positional_embeddings:
257
+ # CogVideoX-2B
258
+ hidden_states = self.norm_final(hidden_states)
259
+ else:
260
+ # CogVideoX-5B
261
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
262
+ hidden_states = self.norm_final(hidden_states)
263
+ hidden_states = hidden_states[:, text_seq_length:]
264
+
265
+ # 4. Final block
266
+ hidden_states = self.norm_out(hidden_states, temb=emb)
267
+ hidden_states = self.proj_out(hidden_states)
268
+
269
+ # 5. Unpatchify
270
+ # Note: we use `-1` instead of `channels`:
271
+ # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
272
+ # - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
273
+ p = self.config.patch_size
274
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
275
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
276
+
277
+ if USE_PEFT_BACKEND:
278
+ # remove `lora_scale` from each PEFT layer
279
+ unscale_lora_layers(self, lora_scale)
280
+
281
+ if not return_dict:
282
+ return (output,)
283
+ return Transformer2DModelOutput(sample=output)
284
+
285
+ @classmethod
286
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
287
+ try:
288
+ model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
289
+ print("Loaded DiffusionAsShader checkpoint directly.")
290
+
291
+ for param in model.parameters():
292
+ param.requires_grad = False
293
+
294
+ for linear in model.combine_linears:
295
+ for param in linear.parameters():
296
+ param.requires_grad = True
297
+
298
+ for block in model.transformer_blocks_copy:
299
+ for param in block.parameters():
300
+ param.requires_grad = True
301
+
302
+ for param in model.initial_combine_linear.parameters():
303
+ param.requires_grad = True
304
+
305
+ return model
306
+
307
+ except Exception as e:
308
+ print(f"Failed to load as DiffusionAsShader: {e}")
309
+ print("Attempting to load as CogVideoXTransformer3DModel and convert...")
310
+
311
+ base_model = CogVideoXTransformer3DModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
312
+
313
+ config = dict(base_model.config)
314
+ config["num_tracking_blocks"] = kwargs.pop("num_tracking_blocks", 18)
315
+
316
+ model = cls(**config)
317
+ model.load_state_dict(base_model.state_dict(), strict=False)
318
+
319
+ model.initial_combine_linear.weight.data.zero_()
320
+ model.initial_combine_linear.bias.data.zero_()
321
+
322
+ for linear in model.combine_linears:
323
+ linear.weight.data.zero_()
324
+ linear.bias.data.zero_()
325
+
326
+ for i in range(model.num_tracking_blocks):
327
+ model.transformer_blocks_copy[i].load_state_dict(model.transformer_blocks[i].state_dict())
328
+
329
+
330
+ for param in model.parameters():
331
+ param.requires_grad = False
332
+
333
+ for linear in model.combine_linears:
334
+ for param in linear.parameters():
335
+ param.requires_grad = True
336
+
337
+ for block in model.transformer_blocks_copy:
338
+ for param in block.parameters():
339
+ param.requires_grad = True
340
+
341
+ for param in model.initial_combine_linear.parameters():
342
+ param.requires_grad = True
343
+
344
+ return model
345
+
346
+ def save_pretrained(
347
+ self,
348
+ save_directory: Union[str, os.PathLike],
349
+ is_main_process: bool = True,
350
+ save_function: Optional[Callable] = None,
351
+ safe_serialization: bool = True,
352
+ variant: Optional[str] = None,
353
+ max_shard_size: Union[int, str] = "5GB",
354
+ push_to_hub: bool = False,
355
+ **kwargs,
356
+ ):
357
+ super().save_pretrained(
358
+ save_directory,
359
+ is_main_process=is_main_process,
360
+ save_function=save_function,
361
+ safe_serialization=safe_serialization,
362
+ variant=variant,
363
+ max_shard_size=max_shard_size,
364
+ push_to_hub=push_to_hub,
365
+ **kwargs,
366
+ )
367
+
368
+ if is_main_process:
369
+ config_dict = dict(self.config)
370
+ config_dict.pop("_name_or_path", None)
371
+ config_dict.pop("_use_default_values", None)
372
+ config_dict["_class_name"] = "CogVideoXTransformer3DModelTracking"
373
+ config_dict["num_tracking_blocks"] = self.num_tracking_blocks
374
+
375
+ os.makedirs(save_directory, exist_ok=True)
376
+ with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
377
+ import json
378
+ json.dump(config_dict, f, indent=2)
379
+
380
+ class CogVideoXPipelineTracking(CogVideoXPipeline, DiffusionPipeline):
381
+
382
+ def __init__(
383
+ self,
384
+ tokenizer: T5Tokenizer,
385
+ text_encoder: T5EncoderModel,
386
+ vae: AutoencoderKLCogVideoX,
387
+ transformer: CogVideoXTransformer3DModelTracking,
388
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
389
+ ):
390
+ super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
391
+
392
+ if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
393
+ raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
394
+
395
+ @torch.no_grad()
396
+ def __call__(
397
+ self,
398
+ prompt: Optional[Union[str, List[str]]] = None,
399
+ negative_prompt: Optional[Union[str, List[str]]] = None,
400
+ height: int = 480,
401
+ width: int = 720,
402
+ num_frames: int = 49,
403
+ num_inference_steps: int = 50,
404
+ timesteps: Optional[List[int]] = None,
405
+ guidance_scale: float = 6,
406
+ use_dynamic_cfg: bool = False,
407
+ num_videos_per_prompt: int = 1,
408
+ eta: float = 0.0,
409
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
410
+ latents: Optional[torch.FloatTensor] = None,
411
+ prompt_embeds: Optional[torch.FloatTensor] = None,
412
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
413
+ output_type: str = "pil",
414
+ return_dict: bool = True,
415
+ attention_kwargs: Optional[Dict[str, Any]] = None,
416
+ callback_on_step_end: Optional[
417
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
418
+ ] = None,
419
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
420
+ max_sequence_length: int = 226,
421
+ tracking_maps: Optional[torch.Tensor] = None,
422
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
423
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
424
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
425
+
426
+ num_videos_per_prompt = 1
427
+
428
+ self.check_inputs(
429
+ prompt,
430
+ height,
431
+ width,
432
+ negative_prompt,
433
+ callback_on_step_end_tensor_inputs,
434
+ prompt_embeds,
435
+ negative_prompt_embeds,
436
+ )
437
+ self._guidance_scale = guidance_scale
438
+ self._attention_kwargs = attention_kwargs
439
+ self._interrupt = False
440
+
441
+ if prompt is not None and isinstance(prompt, str):
442
+ batch_size = 1
443
+ elif prompt is not None and isinstance(prompt, list):
444
+ batch_size = len(prompt)
445
+ else:
446
+ batch_size = prompt_embeds.shape[0]
447
+
448
+ device = self._execution_device
449
+
450
+ do_classifier_free_guidance = guidance_scale > 1.0
451
+
452
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
453
+ prompt,
454
+ negative_prompt,
455
+ do_classifier_free_guidance,
456
+ num_videos_per_prompt=num_videos_per_prompt,
457
+ prompt_embeds=prompt_embeds,
458
+ negative_prompt_embeds=negative_prompt_embeds,
459
+ max_sequence_length=max_sequence_length,
460
+ device=device,
461
+ )
462
+ if do_classifier_free_guidance:
463
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
464
+
465
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
466
+ self._num_timesteps = len(timesteps)
467
+
468
+ latent_channels = self.transformer.config.in_channels
469
+ latents = self.prepare_latents(
470
+ batch_size * num_videos_per_prompt,
471
+ latent_channels,
472
+ num_frames,
473
+ height,
474
+ width,
475
+ prompt_embeds.dtype,
476
+ device,
477
+ generator,
478
+ latents,
479
+ )
480
+
481
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
482
+
483
+ image_rotary_emb = (
484
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
485
+ if self.transformer.config.use_rotary_positional_embeddings
486
+ else None
487
+ )
488
+
489
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
490
+
491
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
492
+ old_pred_original_sample = None
493
+ for i, t in enumerate(timesteps):
494
+ if self.interrupt:
495
+ continue
496
+
497
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
498
+ tracking_maps_latent = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
499
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
500
+
501
+ timestep = t.expand(latent_model_input.shape[0])
502
+
503
+ noise_pred = self.transformer(
504
+ hidden_states=latent_model_input,
505
+ encoder_hidden_states=prompt_embeds,
506
+ timestep=timestep,
507
+ image_rotary_emb=image_rotary_emb,
508
+ attention_kwargs=attention_kwargs,
509
+ tracking_maps=tracking_maps_latent,
510
+ return_dict=False,
511
+ )[0]
512
+ noise_pred = noise_pred.float()
513
+
514
+ if use_dynamic_cfg:
515
+ self._guidance_scale = 1 + guidance_scale * (
516
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
517
+ )
518
+ if do_classifier_free_guidance:
519
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
520
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
521
+
522
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
523
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
524
+ else:
525
+ latents, old_pred_original_sample = self.scheduler.step(
526
+ noise_pred,
527
+ old_pred_original_sample,
528
+ t,
529
+ timesteps[i - 1] if i > 0 else None,
530
+ latents,
531
+ **extra_step_kwargs,
532
+ return_dict=False,
533
+ )
534
+ latents = latents.to(prompt_embeds.dtype)
535
+
536
+ if callback_on_step_end is not None:
537
+ callback_kwargs = {}
538
+ for k in callback_on_step_end_tensor_inputs:
539
+ callback_kwargs[k] = locals()[k]
540
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
541
+
542
+ latents = callback_outputs.pop("latents", latents)
543
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
544
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
545
+
546
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
547
+ progress_bar.update()
548
+
549
+ if not output_type == "latent":
550
+ video = self.decode_latents(latents)
551
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
552
+ else:
553
+ video = latents
554
+
555
+ self.maybe_free_model_hooks()
556
+
557
+ if not return_dict:
558
+ return (video,)
559
+ return CogVideoXPipelineOutput(frames=video)
560
+
561
+ class CogVideoXImageToVideoPipelineTracking(CogVideoXImageToVideoPipeline, DiffusionPipeline):
562
+
563
+ def __init__(
564
+ self,
565
+ tokenizer: T5Tokenizer,
566
+ text_encoder: T5EncoderModel,
567
+ vae: AutoencoderKLCogVideoX,
568
+ transformer: CogVideoXTransformer3DModelTracking,
569
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
570
+ ):
571
+ super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
572
+
573
+ if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
574
+ raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
575
+
576
+ # Print the number of transformer blocks
577
+ print(f"Number of transformer blocks: {len(self.transformer.transformer_blocks)}")
578
+ print(f"Number of tracking transformer blocks: {len(self.transformer.transformer_blocks_copy)}")
579
+ self.transformer = torch.compile(self.transformer)
580
+
581
+ @torch.no_grad()
582
+ def __call__(
583
+ self,
584
+ image: Union[torch.Tensor, Image.Image],
585
+ prompt: Optional[Union[str, List[str]]] = None,
586
+ negative_prompt: Optional[Union[str, List[str]]] = None,
587
+ height: Optional[int] = None,
588
+ width: Optional[int] = None,
589
+ num_frames: int = 49,
590
+ num_inference_steps: int = 50,
591
+ timesteps: Optional[List[int]] = None,
592
+ guidance_scale: float = 6,
593
+ use_dynamic_cfg: bool = False,
594
+ num_videos_per_prompt: int = 1,
595
+ eta: float = 0.0,
596
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
597
+ latents: Optional[torch.FloatTensor] = None,
598
+ prompt_embeds: Optional[torch.FloatTensor] = None,
599
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
600
+ output_type: str = "pil",
601
+ return_dict: bool = True,
602
+ attention_kwargs: Optional[Dict[str, Any]] = None,
603
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
604
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
605
+ max_sequence_length: int = 226,
606
+ tracking_maps: Optional[torch.Tensor] = None,
607
+ tracking_image: Optional[torch.Tensor] = None,
608
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
609
+ # Most of the implementation remains the same as the parent class
610
+ # We will modify the parts that need to handle tracking_maps
611
+
612
+ # 1. Check inputs and set default values
613
+ self.check_inputs(
614
+ image,
615
+ prompt,
616
+ height,
617
+ width,
618
+ negative_prompt,
619
+ callback_on_step_end_tensor_inputs,
620
+ prompt_embeds,
621
+ negative_prompt_embeds,
622
+ )
623
+ self._guidance_scale = guidance_scale
624
+ self._attention_kwargs = attention_kwargs
625
+ self._interrupt = False
626
+
627
+ if prompt is not None and isinstance(prompt, str):
628
+ batch_size = 1
629
+ elif prompt is not None and isinstance(prompt, list):
630
+ batch_size = len(prompt)
631
+ else:
632
+ batch_size = prompt_embeds.shape[0]
633
+
634
+ device = self._execution_device
635
+
636
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
637
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
638
+ # corresponds to doing no classifier free guidance.
639
+ do_classifier_free_guidance = guidance_scale > 1.0
640
+
641
+ # 3. Encode input prompt
642
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
643
+ prompt=prompt,
644
+ negative_prompt=negative_prompt,
645
+ do_classifier_free_guidance=do_classifier_free_guidance,
646
+ num_videos_per_prompt=num_videos_per_prompt,
647
+ prompt_embeds=prompt_embeds,
648
+ negative_prompt_embeds=negative_prompt_embeds,
649
+ max_sequence_length=max_sequence_length,
650
+ device=device,
651
+ )
652
+ if do_classifier_free_guidance:
653
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
654
+ del negative_prompt_embeds
655
+
656
+ # 4. Prepare timesteps
657
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
658
+ self._num_timesteps = len(timesteps)
659
+
660
+ # 5. Prepare latents
661
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
662
+ device, dtype=prompt_embeds.dtype
663
+ )
664
+
665
+ tracking_image = self.video_processor.preprocess(tracking_image, height=height, width=width).to(
666
+ device, dtype=prompt_embeds.dtype
667
+ )
668
+ if self.transformer.config.in_channels != 16:
669
+ latent_channels = self.transformer.config.in_channels // 2
670
+ else:
671
+ latent_channels = self.transformer.config.in_channels
672
+ latents, image_latents = self.prepare_latents(
673
+ image,
674
+ batch_size * num_videos_per_prompt,
675
+ latent_channels,
676
+ num_frames,
677
+ height,
678
+ width,
679
+ prompt_embeds.dtype,
680
+ device,
681
+ generator,
682
+ latents,
683
+ )
684
+ del image
685
+
686
+ _, tracking_image_latents = self.prepare_latents(
687
+ tracking_image,
688
+ batch_size * num_videos_per_prompt,
689
+ latent_channels,
690
+ num_frames,
691
+ height,
692
+ width,
693
+ prompt_embeds.dtype,
694
+ device,
695
+ generator,
696
+ latents=None,
697
+ )
698
+ del tracking_image
699
+
700
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
701
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
702
+
703
+ # 7. Create rotary embeds if required
704
+ image_rotary_emb = (
705
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
706
+ if self.transformer.config.use_rotary_positional_embeddings
707
+ else None
708
+ )
709
+
710
+ # 8. Denoising loop
711
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
712
+
713
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
714
+ old_pred_original_sample = None
715
+ for i, t in enumerate(timesteps):
716
+ if self.interrupt:
717
+ continue
718
+
719
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
720
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
721
+
722
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
723
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
724
+ del latent_image_input
725
+
726
+ # Handle tracking maps
727
+ if tracking_maps is not None:
728
+ latents_tracking_image = torch.cat([tracking_image_latents] * 2) if do_classifier_free_guidance else tracking_image_latents
729
+ tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
730
+ tracking_maps_input = torch.cat([tracking_maps_input, latents_tracking_image], dim=2)
731
+ del latents_tracking_image
732
+ else:
733
+ tracking_maps_input = None
734
+
735
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
736
+ timestep = t.expand(latent_model_input.shape[0])
737
+
738
+ # Predict noise
739
+ self.transformer.to(dtype=latent_model_input.dtype)
740
+ noise_pred = self.transformer(
741
+ hidden_states=latent_model_input,
742
+ encoder_hidden_states=prompt_embeds,
743
+ timestep=timestep,
744
+ image_rotary_emb=image_rotary_emb,
745
+ attention_kwargs=attention_kwargs,
746
+ tracking_maps=tracking_maps_input,
747
+ return_dict=False,
748
+ )[0]
749
+ del latent_model_input
750
+ if tracking_maps_input is not None:
751
+ del tracking_maps_input
752
+ noise_pred = noise_pred.float()
753
+
754
+ # perform guidance
755
+ if use_dynamic_cfg:
756
+ self._guidance_scale = 1 + guidance_scale * (
757
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
758
+ )
759
+ if do_classifier_free_guidance:
760
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
761
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
762
+ del noise_pred_uncond, noise_pred_text
763
+
764
+ # compute the previous noisy sample x_t -> x_t-1
765
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
766
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
767
+ else:
768
+ latents, old_pred_original_sample = self.scheduler.step(
769
+ noise_pred,
770
+ old_pred_original_sample,
771
+ t,
772
+ timesteps[i - 1] if i > 0 else None,
773
+ latents,
774
+ **extra_step_kwargs,
775
+ return_dict=False,
776
+ )
777
+ del noise_pred
778
+ latents = latents.to(prompt_embeds.dtype)
779
+
780
+ # call the callback, if provided
781
+ if callback_on_step_end is not None:
782
+ callback_kwargs = {}
783
+ for k in callback_on_step_end_tensor_inputs:
784
+ callback_kwargs[k] = locals()[k]
785
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
786
+
787
+ latents = callback_outputs.pop("latents", latents)
788
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
789
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
790
+
791
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
792
+ progress_bar.update()
793
+
794
+ # 9. Post-processing
795
+ if not output_type == "latent":
796
+ video = self.decode_latents(latents)
797
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
798
+ else:
799
+ video = latents
800
+
801
+ # Offload all models
802
+ self.maybe_free_model_hooks()
803
+
804
+ if not return_dict:
805
+ return (video,)
806
+
807
+ return CogVideoXPipelineOutput(frames=video)
808
+
809
+ class CogVideoXVideoToVideoPipelineTracking(CogVideoXVideoToVideoPipeline, DiffusionPipeline):
810
+
811
+ def __init__(
812
+ self,
813
+ tokenizer: T5Tokenizer,
814
+ text_encoder: T5EncoderModel,
815
+ vae: AutoencoderKLCogVideoX,
816
+ transformer: CogVideoXTransformer3DModelTracking,
817
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
818
+ ):
819
+ super().__init__(tokenizer, text_encoder, vae, transformer, scheduler)
820
+
821
+ if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking):
822
+ raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking")
823
+
824
+ @torch.no_grad()
825
+ def __call__(
826
+ self,
827
+ video: List[Image.Image] = None,
828
+ prompt: Optional[Union[str, List[str]]] = None,
829
+ negative_prompt: Optional[Union[str, List[str]]] = None,
830
+ height: int = 480,
831
+ width: int = 720,
832
+ num_inference_steps: int = 50,
833
+ timesteps: Optional[List[int]] = None,
834
+ strength: float = 0.8,
835
+ guidance_scale: float = 6,
836
+ use_dynamic_cfg: bool = False,
837
+ num_videos_per_prompt: int = 1,
838
+ eta: float = 0.0,
839
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
840
+ latents: Optional[torch.FloatTensor] = None,
841
+ prompt_embeds: Optional[torch.FloatTensor] = None,
842
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
843
+ output_type: str = "pil",
844
+ return_dict: bool = True,
845
+ attention_kwargs: Optional[Dict[str, Any]] = None,
846
+ callback_on_step_end: Optional[
847
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
848
+ ] = None,
849
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
850
+ max_sequence_length: int = 226,
851
+ tracking_maps: Optional[torch.Tensor] = None,
852
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
853
+
854
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
855
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
856
+
857
+ num_videos_per_prompt = 1
858
+
859
+ # 1. Check inputs. Raise error if not correct
860
+ self.check_inputs(
861
+ prompt=prompt,
862
+ height=height,
863
+ width=width,
864
+ strength=strength,
865
+ negative_prompt=negative_prompt,
866
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
867
+ video=video,
868
+ latents=latents,
869
+ prompt_embeds=prompt_embeds,
870
+ negative_prompt_embeds=negative_prompt_embeds,
871
+ )
872
+ self._guidance_scale = guidance_scale
873
+ self._attention_kwargs = attention_kwargs
874
+ self._interrupt = False
875
+
876
+ # 2. Default call parameters
877
+ if prompt is not None and isinstance(prompt, str):
878
+ batch_size = 1
879
+ elif prompt is not None and isinstance(prompt, list):
880
+ batch_size = len(prompt)
881
+ else:
882
+ batch_size = prompt_embeds.shape[0]
883
+
884
+ device = self._execution_device
885
+
886
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
887
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
888
+ # corresponds to doing no classifier free guidance.
889
+ do_classifier_free_guidance = guidance_scale > 1.0
890
+
891
+ # 3. Encode input prompt
892
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
893
+ prompt,
894
+ negative_prompt,
895
+ do_classifier_free_guidance,
896
+ num_videos_per_prompt=num_videos_per_prompt,
897
+ prompt_embeds=prompt_embeds,
898
+ negative_prompt_embeds=negative_prompt_embeds,
899
+ max_sequence_length=max_sequence_length,
900
+ device=device,
901
+ )
902
+ if do_classifier_free_guidance:
903
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
904
+
905
+ # 4. Prepare timesteps
906
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
907
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
908
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
909
+ self._num_timesteps = len(timesteps)
910
+
911
+ # 5. Prepare latents
912
+ if latents is None:
913
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
914
+ video = video.to(device=device, dtype=prompt_embeds.dtype)
915
+
916
+ latent_channels = self.transformer.config.in_channels
917
+ latents = self.prepare_latents(
918
+ video,
919
+ batch_size * num_videos_per_prompt,
920
+ latent_channels,
921
+ height,
922
+ width,
923
+ prompt_embeds.dtype,
924
+ device,
925
+ generator,
926
+ latents,
927
+ latent_timestep,
928
+ )
929
+
930
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
931
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
932
+
933
+ # 7. Create rotary embeds if required
934
+ image_rotary_emb = (
935
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
936
+ if self.transformer.config.use_rotary_positional_embeddings
937
+ else None
938
+ )
939
+
940
+ # 8. Denoising loop
941
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
942
+
943
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
944
+ # for DPM-solver++
945
+ old_pred_original_sample = None
946
+ for i, t in enumerate(timesteps):
947
+ if self.interrupt:
948
+ continue
949
+
950
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
951
+ tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
952
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
953
+
954
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
955
+ timestep = t.expand(latent_model_input.shape[0])
956
+
957
+ # predict noise model_output
958
+ noise_pred = self.transformer(
959
+ hidden_states=latent_model_input,
960
+ encoder_hidden_states=prompt_embeds,
961
+ timestep=timestep,
962
+ image_rotary_emb=image_rotary_emb,
963
+ attention_kwargs=attention_kwargs,
964
+ tracking_maps=tracking_maps_input,
965
+ return_dict=False,
966
+ )[0]
967
+ noise_pred = noise_pred.float()
968
+
969
+ # perform guidance
970
+ if use_dynamic_cfg:
971
+ self._guidance_scale = 1 + guidance_scale * (
972
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
973
+ )
974
+ if do_classifier_free_guidance:
975
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
976
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
977
+
978
+ # compute the previous noisy sample x_t -> x_t-1
979
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
980
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
981
+ else:
982
+ latents, old_pred_original_sample = self.scheduler.step(
983
+ noise_pred,
984
+ old_pred_original_sample,
985
+ t,
986
+ timesteps[i - 1] if i > 0 else None,
987
+ latents,
988
+ **extra_step_kwargs,
989
+ return_dict=False,
990
+ )
991
+ latents = latents.to(prompt_embeds.dtype)
992
+
993
+ # call the callback, if provided
994
+ if callback_on_step_end is not None:
995
+ callback_kwargs = {}
996
+ for k in callback_on_step_end_tensor_inputs:
997
+ callback_kwargs[k] = locals()[k]
998
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
999
+
1000
+ latents = callback_outputs.pop("latents", latents)
1001
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1002
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1003
+
1004
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1005
+ progress_bar.update()
1006
+
1007
+ if not output_type == "latent":
1008
+ video = self.decode_latents(latents)
1009
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
1010
+ else:
1011
+ video = latents
1012
+
1013
+ # Offload all models
1014
+ self.maybe_free_model_hooks()
1015
+
1016
+ if not return_dict:
1017
+ return (video,)
1018
+
1019
+ return CogVideoXPipelineOutput(frames=video)
1020
+
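+ # Illustrative sketch (not invoked anywhere): one way to drive the video-to-video
+ # tracking pipeline defined above, mirroring how the image-to-video variant is
+ # assembled in models/pipelines.py. The model path and the source of the
+ # pre-encoded `tracking_maps` latents are placeholders, not part of this module.
+ def _example_video_to_video_tracking(frames, tracking_maps, prompt, model_path="path/to/checkpoint"):
+     import torch
+     from transformers import T5EncoderModel, T5Tokenizer
+     tokenizer = T5Tokenizer.from_pretrained(model_path, subfolder="tokenizer")
+     text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
+     vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
+     transformer = CogVideoXTransformer3DModelTracking.from_pretrained(model_path, subfolder="transformer")
+     scheduler = CogVideoXDDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
+     pipe = CogVideoXVideoToVideoPipelineTracking(tokenizer, text_encoder, vae, transformer, scheduler)
+     pipe.to("cuda", dtype=torch.bfloat16)
+     return pipe(
+         video=frames,                 # list of PIL.Image frames
+         prompt=prompt,
+         tracking_maps=tracking_maps,  # VAE-encoded tracking latents (assumed prepared elsewhere)
+         num_inference_steps=50,
+         guidance_scale=6.0,
+     ).frames[0]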
models/pipelines.py ADDED
@@ -0,0 +1,1040 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ from tqdm import tqdm
5
+ from PIL import Image, ImageDraw
6
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
7
+ try:
8
+ sys.path.append(os.path.join(project_root, "submodules/MoGe"))
9
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
10
+ except Exception:
11
+ print("Warning: MoGe not found, motion transfer will not be applied")
12
+
13
+ import torch
14
+ import numpy as np
15
+ from PIL import Image
16
+ import torchvision.transforms as transforms
17
+ from diffusers import FluxControlPipeline, CogVideoXDPMScheduler
18
+ from diffusers.utils import export_to_video, load_image, load_video
19
+
20
+ from models.spatracker.predictor import SpaTrackerPredictor
21
+ from models.spatracker.utils.visualizer import Visualizer
22
+ from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
23
+
24
+ from submodules.MoGe.moge.model import MoGeModel
25
+ from image_gen_aux import DepthPreprocessor
26
+ from moviepy.editor import ImageSequenceClip
27
+
28
+ class DiffusionAsShaderPipeline:
29
+ def __init__(self, gpu_id=0, output_dir='outputs'):
30
+ """Initialize MotionTransfer class
31
+
32
+ Args:
33
+ gpu_id (int): GPU device ID
34
+ output_dir (str): Output directory path
35
+ """
36
+ # video parameters
37
+ self.max_depth = 65.0
38
+ self.fps = 8
39
+
40
+ # camera parameters
41
+ self.camera_motion=None
42
+ self.fov=55
43
+
44
+ # device
45
+ self.device = f"cuda:{gpu_id}"
46
+ torch.cuda.set_device(gpu_id)
47
+
48
+ # files
49
+ self.output_dir = output_dir
50
+ os.makedirs(output_dir, exist_ok=True)
51
+
52
+ # Initialize transform
53
+ self.transform = transforms.Compose([
54
+ transforms.Resize((480, 720)),
55
+ transforms.ToTensor()
56
+ ])
57
+
58
+ @torch.no_grad()
59
+ def _infer(
60
+ self,
61
+ prompt: str,
62
+ model_path: str,
63
+ tracking_tensor: torch.Tensor = None,
64
+ image_tensor: torch.Tensor = None, # [C,H,W] in range [0,1]
65
+ output_path: str = "./output.mp4",
66
+ num_inference_steps: int = 50,
67
+ guidance_scale: float = 6.0,
68
+ num_videos_per_prompt: int = 1,
69
+ dtype: torch.dtype = torch.bfloat16,
70
+ fps: int = 24,
71
+ seed: int = 42,
72
+ ):
73
+ """
74
+ Generates a video based on the given prompt and saves it to the specified path.
75
+
76
+ Parameters:
77
+ - prompt (str): The description of the video to be generated.
78
+ - model_path (str): The path of the pre-trained model to be used.
79
+ - tracking_tensor (torch.Tensor): Tracking video tensor [T, C, H, W] in range [0,1]
80
+ - image_tensor (torch.Tensor): Input image tensor [C, H, W] in range [0,1]
81
+ - output_path (str): The path where the generated video will be saved.
82
+ - num_inference_steps (int): Number of steps for the inference process.
83
+ - guidance_scale (float): The scale for classifier-free guidance.
84
+ - num_videos_per_prompt (int): Number of videos to generate per prompt.
85
+ - dtype (torch.dtype): The data type for computation.
86
+ - seed (int): The seed for reproducibility.
87
+ """
88
+ from transformers import T5EncoderModel, T5Tokenizer
89
+ from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
90
+ from models.cogvideox_tracking import CogVideoXTransformer3DModelTracking
91
+
92
+ vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
93
+ text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
94
+ tokenizer = T5Tokenizer.from_pretrained(model_path, subfolder="tokenizer")
95
+ transformer = CogVideoXTransformer3DModelTracking.from_pretrained(model_path, subfolder="transformer")
96
+ scheduler = CogVideoXDDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
97
+
98
+ pipe = CogVideoXImageToVideoPipelineTracking(
99
+ vae=vae,
100
+ text_encoder=text_encoder,
101
+ tokenizer=tokenizer,
102
+ transformer=transformer,
103
+ scheduler=scheduler
104
+ )
105
+
106
+ # Convert tensor to PIL Image
107
+ image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
108
+ image = Image.fromarray(image_np)
109
+ height, width = image.height, image.width
110
+
111
+ pipe.transformer.eval()
112
+ pipe.text_encoder.eval()
113
+ pipe.vae.eval()
114
+
115
+ # Process tracking tensor
116
+ tracking_maps = tracking_tensor.float() # [T, C, H, W]
117
+ tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)
118
+ tracking_first_frame = tracking_maps[0:1] # Get first frame as [1, C, H, W]
119
+ height, width = tracking_first_frame.shape[2], tracking_first_frame.shape[3]
120
+
121
+ # 2. Set Scheduler.
122
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
123
+
124
+ pipe.to(self.device, dtype=dtype)
125
+ # pipe.enable_sequential_cpu_offload()
126
+
127
+ pipe.vae.enable_slicing()
128
+ pipe.vae.enable_tiling()
129
+ pipe.transformer.eval()
130
+ pipe.text_encoder.eval()
131
+ pipe.vae.eval()
132
+
133
+ pipe.transformer.gradient_checkpointing = False
134
+
135
+ print("Encoding tracking maps")
136
+ tracking_maps = tracking_maps.unsqueeze(0) # [B, T, C, H, W]
137
+ tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
138
+ tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist
139
+ tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor
140
+ tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
141
+
142
+ # 4. Generate the video frames based on the prompt.
143
+ video_generate = pipe(
144
+ prompt=prompt,
145
+ negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.",
146
+ image=image,
147
+ num_videos_per_prompt=num_videos_per_prompt,
148
+ num_inference_steps=num_inference_steps,
149
+ num_frames=49,
150
+ use_dynamic_cfg=True,
151
+ guidance_scale=guidance_scale,
152
+ generator=torch.Generator().manual_seed(seed),
153
+ tracking_maps=tracking_maps,
154
+ tracking_image=tracking_first_frame,
155
+ height=height,
156
+ width=width,
157
+ ).frames[0]
158
+
159
+ # 5. Export the generated frames to a video file (the base model generates video at 8 fps).
160
+ output_path = output_path if output_path else "result.mp4"
161
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
162
+ export_to_video(video_generate, output_path, fps=fps)
163
+
164
+ #========== camera parameters ==========#
165
+
166
+ def _set_camera_motion(self, camera_motion):
167
+ self.camera_motion = camera_motion
168
+
169
+ def _get_intr(self, fov, H=480, W=720):
170
+ fov_rad = math.radians(fov)
171
+ focal_length = (W / 2) / math.tan(fov_rad / 2)
172
+
173
+ cx = W / 2
174
+ cy = H / 2
175
+
176
+ intr = torch.tensor([
177
+ [focal_length, 0, cx],
178
+ [0, focal_length, cy],
179
+ [0, 0, 1]
180
+ ], dtype=torch.float32)
181
+
182
+ return intr
183
+
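+     # Worked example with the defaults used throughout this class (fov=55, W=720, H=480):
+     #   focal = (720 / 2) / tan(55° / 2) ≈ 360 / 0.5206 ≈ 691.6
+     # so _get_intr(55) is approximately [[691.6, 0, 360], [0, 691.6, 240], [0, 0, 1]],
+     # i.e. a pinhole intrinsic matrix with the principal point at the image center.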
184
+ def _apply_poses(self, pts, intr, poses):
185
+ """
186
+ Args:
187
+ pts (torch.Tensor): pointclouds coordinates [T, N, 3]
188
+ intr (torch.Tensor): camera intrinsics [T, 3, 3]
189
+ poses (numpy.ndarray): camera poses [T, 4, 4]
190
+ """
191
+ poses = torch.from_numpy(poses).float().to(self.device)
192
+
193
+ T, N, _ = pts.shape
194
+ ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
195
+ pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
196
+ pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
197
+ pts_cam[:,:, :3] /= pts[:, :, 2:3]
198
+
199
+ # to homogeneous
200
+ pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
201
+
202
+ if poses.shape[0] == 1:
203
+ poses = poses.repeat(T, 1, 1)
204
+ elif poses.shape[0] != T:
205
+ raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
206
+
207
+ pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
208
+
209
+ pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
210
+ pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
211
+
212
+ return pts_proj
213
+
214
+ def apply_traj_on_tracking(self, pred_tracks, camera_motion=None, fov=55, frame_num=49):
215
+ intr = self._get_intr(fov).unsqueeze(0).repeat(frame_num, 1, 1).to(self.device)
216
+ tracking_pts = self._apply_poses(pred_tracks.squeeze(), intr, camera_motion).unsqueeze(0)
217
+ return tracking_pts
218
+
219
+ ##============= SpatialTracker =============##
220
+
221
+ def generate_tracking_spatracker(self, video_tensor, density=70):
222
+ """Generate tracking video
223
+
224
+ Args:
225
+ video_tensor (torch.Tensor): Input video tensor [T, C, H, W] in range [0, 1]
+ density (int): Grid size used to sample tracking points
226
+
227
+ Returns:
228
+ tuple: (pred_tracks, pred_visibility, T_Firsts) predicted by SpaTracker
229
+ """
230
+ print("Loading tracking models...")
231
+ # Load tracking model
232
+ tracker = SpaTrackerPredictor(
233
+ checkpoint=os.path.join(project_root, 'checkpoints/spatracker/spaT_final.pth'),
234
+ interp_shape=(384, 576),
235
+ seq_length=12
236
+ ).to(self.device)
237
+
238
+ # Load depth model
239
+ self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti")
240
+ self.depth_preprocessor.to(self.device)
241
+
242
+ try:
243
+ video = video_tensor.unsqueeze(0).to(self.device)
244
+
245
+ video_depths = []
246
+ for i in range(video_tensor.shape[0]):
247
+ frame = (video_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
248
+ depth = self.depth_preprocessor(Image.fromarray(frame))[0]
249
+ depth_tensor = transforms.ToTensor()(depth) # [1, H, W]
250
+ video_depths.append(depth_tensor)
251
+ video_depth = torch.stack(video_depths, dim=0).to(self.device)
252
+ # print("Video depth shape:", video_depth.shape)
253
+
254
+ segm_mask = np.ones((480, 720), dtype=np.uint8)
255
+
256
+ pred_tracks, pred_visibility, T_Firsts = tracker(
257
+ video * 255,
258
+ video_depth=video_depth,
259
+ grid_size=density,
260
+ backward_tracking=False,
261
+ depth_predictor=None,
262
+ grid_query_frame=0,
263
+ segm_mask=torch.from_numpy(segm_mask)[None, None].to(self.device),
264
+ wind_length=12,
265
+ progressive_tracking=False
266
+ )
267
+
268
+ return pred_tracks, pred_visibility, T_Firsts
269
+
270
+ finally:
271
+ # Clean up GPU memory
272
+ del tracker, self.depth_preprocessor
273
+ torch.cuda.empty_cache()
274
+
275
+ def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
276
+ video = video.unsqueeze(0).to(self.device)
277
+ vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
278
+ msk_query = (T_Firsts == 0)
279
+ pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
280
+ pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
281
+
282
+ tracking_video = vis.visualize(video=video, tracks=pred_tracks,
283
+ visibility=pred_visibility, save_video=False,
284
+ filename="temp")
285
+
286
+ tracking_video = tracking_video.squeeze(0) # [T, C, H, W]
287
+ wide_list = list(tracking_video.unbind(0))
288
+ wide_list = [wide.permute(1, 2, 0).cpu().numpy() for wide in wide_list]
289
+ clip = ImageSequenceClip(wide_list, fps=self.fps)
290
+
291
+ tracking_path = None
292
+ if save_tracking:
293
+ try:
294
+ tracking_path = os.path.join(self.output_dir, "tracking_video.mp4")
295
+ clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
296
+ print(f"Video saved to {tracking_path}")
297
+ except Exception as e:
298
+ print(f"Warning: Failed to save tracking video: {e}")
299
+ tracking_path = None
300
+
301
+ # Convert tracking_video back to tensor in range [0,1]
302
+ tracking_frames = np.array(list(clip.iter_frames())) / 255.0
303
+ tracking_video = torch.from_numpy(tracking_frames).permute(0, 3, 1, 2).float()
304
+
305
+ return tracking_path, tracking_video
306
+
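+     # Minimal end-to-end sketch of the two SpaTracker methods above; it assumes the
+     # checkpoint at checkpoints/spatracker/spaT_final.pth is available and is not
+     # called anywhere in this module.
+     def _example_generate_and_visualize_tracking(self, video_path):
+         frames = load_video(video_path)  # list of PIL images
+         video_tensor = torch.stack([self.transform(f) for f in frames])  # [T, C, 480, 720] in [0, 1]
+         pred_tracks, pred_visibility, T_Firsts = self.generate_tracking_spatracker(video_tensor)
+         return self.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)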
307
+ ##============= MoGe =============##
308
+
309
+ def valid_mask(self, pixels, W, H):
310
+ """Check if pixels are within valid image bounds
311
+
312
+ Args:
313
+ pixels (numpy.ndarray): Pixel coordinates of shape [N, 2]
314
+ W (int): Image width
315
+ H (int): Image height
316
+
317
+ Returns:
318
+ numpy.ndarray: Boolean mask of valid pixels
319
+ """
320
+ return ((pixels[:, 0] >= 0) & (pixels[:, 0] < W) & (pixels[:, 1] > 0) & \
321
+ (pixels[:, 1] < H))
322
+
323
+ def sort_points_by_depth(self, points, depths):
324
+ """Sort points by depth values
325
+
326
+ Args:
327
+ points (numpy.ndarray): Points array of shape [N, 2]
328
+ depths (numpy.ndarray): Depth values of shape [N]
329
+
330
+ Returns:
331
+ tuple: (sorted_points, sorted_depths, sort_index)
332
+ """
333
+ # Combine points and depths into a single array for sorting
334
+ combined = np.hstack((points, depths[:, None])) # Nx3 (points + depth)
335
+ # Sort by depth (last column) in descending order
336
+ sort_index = combined[:, -1].argsort()[::-1]
337
+ sorted_combined = combined[sort_index]
338
+ # Split back into points and depths
339
+ sorted_points = sorted_combined[:, :-1]
340
+ sorted_depths = sorted_combined[:, -1]
341
+ return sorted_points, sorted_depths, sort_index
342
+
343
+ def draw_rectangle(self, rgb, coord, side_length, color=(255, 0, 0)):
344
+ """Draw a rectangle on the image
345
+
346
+ Args:
347
+ rgb (PIL.Image): Image to draw on
348
+ coord (tuple): Center coordinates (x, y)
349
+ side_length (int): Length of rectangle sides
350
+ color (tuple): RGB color tuple
351
+ """
352
+ draw = ImageDraw.Draw(rgb)
353
+ # Calculate the bounding box of the rectangle
354
+ left_up_point = (coord[0] - side_length//2, coord[1] - side_length//2)
355
+ right_down_point = (coord[0] + side_length//2, coord[1] + side_length//2)
356
+ color = tuple(list(color))
357
+
358
+ draw.rectangle(
359
+ [left_up_point, right_down_point],
360
+ fill=tuple(color),
361
+ outline=tuple(color),
362
+ )
363
+
364
+ def visualize_tracking_moge(self, points, mask, save_tracking=True):
365
+ """Visualize tracking results from MoGe model
366
+
367
+ Args:
368
+ points (numpy.ndarray): Points array of shape [T, H, W, 3]
369
+ mask (numpy.ndarray): Binary mask of shape [H, W]
370
+ save_tracking (bool): Whether to save tracking video
371
+
372
+ Returns:
373
+ tuple: (tracking_path, tracking_video)
374
+ - tracking_path (str): Path to saved tracking video, None if save_tracking is False
375
+ - tracking_video (torch.Tensor): Tracking visualization tensor of shape [T, C, H, W] in range [0,1]
376
+ """
377
+ # Create color array
378
+ T, H, W, _ = points.shape
379
+ colors = np.zeros((H, W, 3), dtype=np.uint8)
380
+
381
+ # Set R channel - based on x coordinates (smaller on the left)
382
+ colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))
383
+
384
+ # Set G channel - based on y coordinates (smaller on the top)
385
+ colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T
386
+
387
+ # Set B channel - based on depth
388
+ z_values = points[0, :, :, 2] # get z values
389
+ inv_z = 1 / z_values # calculate 1/z
390
+ # Calculate 2% and 98% percentiles
391
+ p2 = np.percentile(inv_z, 2)
392
+ p98 = np.percentile(inv_z, 98)
393
+ # Normalize to [0,1] range
394
+ normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
395
+ colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
396
+ colors = colors.astype(np.uint8)
397
+ # colors = colors * mask[..., None]
398
+ # points = points * mask[None, :, :, None]
399
+
400
+ points = points.reshape(T, -1, 3)
401
+ colors = colors.reshape(-1, 3)
402
+
403
+ # Initialize list to store frames
404
+ frames = []
405
+
406
+ for i, pts_i in enumerate(tqdm(points)):
407
+ pixels, depths = pts_i[..., :2], pts_i[..., 2]
408
+ pixels[..., 0] = pixels[..., 0] * W
409
+ pixels[..., 1] = pixels[..., 1] * H
410
+ pixels = pixels.astype(int)
411
+
412
+ valid = self.valid_mask(pixels, W, H)
413
+ frame_rgb = colors[valid]
414
+ pixels = pixels[valid]
415
+ depths = depths[valid]
416
+
417
+ img = Image.fromarray(np.uint8(np.zeros([H, W, 3])), mode="RGB")
418
+ sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
419
+ step = 1
420
+ sorted_pixels = sorted_pixels[::step]
421
+ sorted_rgb = frame_rgb[sort_index][::step]
422
+
423
+ for j in range(sorted_pixels.shape[0]):
424
+ self.draw_rectangle(
425
+ img,
426
+ coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
427
+ side_length=2,
428
+ color=sorted_rgb[j],
429
+ )
430
+ frames.append(np.array(img))
431
+
432
+ # Convert frames to video tensor in range [0,1]
433
+ tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
434
+
435
+ tracking_path = None
436
+ if save_tracking:
437
+ try:
438
+ tracking_path = os.path.join(self.output_dir, "tracking_video_moge.mp4")
439
+ # Convert back to uint8 for saving
440
+ uint8_frames = [frame.astype(np.uint8) for frame in frames]
441
+ clip = ImageSequenceClip(uint8_frames, fps=self.fps)
442
+ clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
443
+ print(f"Video saved to {tracking_path}")
444
+ except Exception as e:
445
+ print(f"Warning: Failed to save tracking video: {e}")
446
+ tracking_path = None
447
+
448
+ return tracking_path, tracking_video
449
+
450
+ def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
451
+ """Generate final video with motion transfer
452
+
453
+ Args:
454
+ video_tensor (torch.Tensor): Input video tensor [T,C,H,W]
455
+ fps (float): Input video FPS
456
+ tracking_tensor (torch.Tensor): Tracking video tensor [T,C,H,W]
457
+ img_cond_tensor (torch.Tensor): First frame tensor [C,H,W] to use for generation
458
+ prompt (str): Generation prompt
459
+ checkpoint_path (str): Path to model checkpoint
460
+ """
461
+ self.fps = fps
462
+
463
+ # Use first frame if no image provided
464
+ if img_cond_tensor is None:
465
+ img_cond_tensor = video_tensor[0]
466
+
467
+ # Generate final video
468
+ final_output = os.path.join(os.path.abspath(self.output_dir), "result.mp4")
469
+ self._infer(
470
+ prompt=prompt,
471
+ model_path=checkpoint_path,
472
+ tracking_tensor=tracking_tensor,
473
+ image_tensor=img_cond_tensor,
474
+ output_path=final_output,
475
+ num_inference_steps=50,
476
+ guidance_scale=6.0,
477
+ dtype=torch.bfloat16,
478
+ fps=self.fps
479
+ )
480
+ print(f"Final video generated successfully at: {final_output}")
481
+
482
+ def _set_object_motion(self, motion_type):
483
+ """Set object motion type
484
+
485
+ Args:
486
+ motion_type (str): Motion direction ('up', 'down', 'left', 'right')
487
+ """
488
+ self.object_motion = motion_type
489
+
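+ # Hedged usage sketch of the full DiffusionAsShaderPipeline flow; the tensors and
+ # checkpoint path are assumed to be prepared by the caller, and the function is not
+ # invoked in this module.
+ def _example_apply_tracking(video_tensor, tracking_tensor, prompt, checkpoint_path):
+     das = DiffusionAsShaderPipeline(gpu_id=0, output_dir="outputs")
+     # video_tensor / tracking_tensor: [T, C, H, W] tensors in [0, 1]
+     das.apply_tracking(
+         video_tensor=video_tensor,
+         fps=8,
+         tracking_tensor=tracking_tensor,
+         img_cond_tensor=video_tensor[0],
+         prompt=prompt,
+         checkpoint_path=checkpoint_path,
+     )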
490
+ class FirstFrameRepainter:
491
+ def __init__(self, gpu_id=0, output_dir='outputs'):
492
+ """Initialize FirstFrameRepainter
493
+
494
+ Args:
495
+ gpu_id (int): GPU device ID
496
+ output_dir (str): Output directory path
497
+ """
498
+ self.device = f"cuda:{gpu_id}"
499
+ self.output_dir = output_dir
500
+ self.max_depth = 65.0
501
+ os.makedirs(output_dir, exist_ok=True)
502
+
503
+ def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
504
+ """Repaint first frame using Flux
505
+
506
+ Args:
507
+ image_tensor (torch.Tensor): Input image tensor [C,H,W]
508
+ prompt (str): Repaint prompt
509
+ depth_path (str): Path to depth image
510
+ method (str): depth estimator, "moge" or "dav" or "zoedepth"
511
+
512
+ Returns:
513
+ torch.Tensor: Repainted image tensor [C,H,W]
514
+ """
515
+ print("Loading Flux model...")
516
+ # Load Flux model
517
+ flux_pipe = FluxControlPipeline.from_pretrained(
518
+ "black-forest-labs/FLUX.1-Depth-dev",
519
+ torch_dtype=torch.bfloat16
520
+ ).to(self.device)
521
+
522
+ # Get depth map
523
+ if depth_path is None:
524
+ if method == "moge":
525
+ self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(self.device)
526
+ depth_map = self.moge_model.infer(image_tensor.to(self.device))["depth"]
527
+ depth_map = torch.clamp(depth_map, max=self.max_depth)
528
+ depth_normalized = 1.0 - (depth_map / self.max_depth)
529
+ depth_rgb = (depth_normalized * 255).cpu().numpy().astype(np.uint8)
530
+ control_image = Image.fromarray(depth_rgb).convert("RGB")
531
+ elif method == "zoedepth":
532
+ self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti")
533
+ self.depth_preprocessor.to(self.device)
534
+ image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
535
+ control_image = self.depth_preprocessor(Image.fromarray(image_np))[0].convert("RGB")
536
+ control_image = control_image.point(lambda x: 255 - x) # the zoedepth depth is inverted
537
+ else:
538
+ self.depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")
539
+ self.depth_preprocessor.to(self.device)
540
+ image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
541
+ control_image = self.depth_preprocessor(Image.fromarray(image_np))[0].convert("RGB")
542
+ else:
543
+ control_image = Image.open(depth_path).convert("RGB")
544
+
545
+ try:
546
+ repainted_image = flux_pipe(
547
+ prompt=prompt,
548
+ control_image=control_image,
549
+ height=480,
550
+ width=720,
551
+ num_inference_steps=30,
552
+ guidance_scale=7.5,
553
+ ).images[0]
554
+
555
+ # Save repainted image
556
+ repainted_image.save(os.path.join(self.output_dir, "temp_repainted.png"))
557
+
558
+ # Convert PIL Image to tensor
559
+ transform = transforms.Compose([
560
+ transforms.ToTensor()
561
+ ])
562
+ repainted_tensor = transform(repainted_image)
563
+
564
+ return repainted_tensor
565
+
566
+ finally:
567
+ # Clean up GPU memory
568
+ del flux_pipe
569
+ if method == "moge":
570
+ del self.moge_model
571
+ else:
572
+ del self.depth_preprocessor
573
+ torch.cuda.empty_cache()
574
+
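+ # Hedged usage sketch for FirstFrameRepainter (the depth estimator weights are
+ # fetched from the Hugging Face Hub on first use); not invoked in this module.
+ def _example_repaint_first_frame(image_tensor, prompt):
+     repainter = FirstFrameRepainter(gpu_id=0, output_dir="outputs")
+     # image_tensor: [C, H, W] in [0, 1]; "dav" selects the Depth-Anything-V2 estimator
+     return repainter.repaint(image_tensor, prompt=prompt, method="dav")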
575
+ class CameraMotionGenerator:
576
+ def __init__(self, motion_type, frame_num=49, H=480, W=720, fx=None, fy=None, fov=55, device='cuda'):
577
+ self.motion_type = motion_type
578
+ self.frame_num = frame_num
579
+ self.fov = fov
580
+ self.device = device
581
+ self.W = W
582
+ self.H = H
583
+ self.intr = torch.tensor([
584
+ [0, 0, W / 2],
585
+ [0, 0, H / 2],
586
+ [0, 0, 1]
587
+ ], dtype=torch.float32, device=device)
588
+ # if fx, fy not provided
589
+ if not fx or not fy:
590
+ fov_rad = math.radians(fov)
591
+ fx = fy = (W / 2) / math.tan(fov_rad / 2)
592
+
593
+ self.intr[0, 0] = fx
594
+ self.intr[1, 1] = fy
595
+
596
+ def _apply_poses(self, pts, poses):
597
+ """
598
+ Args:
599
+ pts (torch.Tensor): pointclouds coordinates [T, N, 3]
600
+ (camera intrinsics [3, 3] are read from self.intr)
601
+ poses (numpy.ndarray): camera poses [T, 4, 4]
602
+ """
603
+ if isinstance(poses, np.ndarray):
604
+ poses = torch.from_numpy(poses)
605
+
606
+ intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float)
607
+ T, N, _ = pts.shape
608
+ ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
609
+ pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
610
+ pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
611
+ pts_cam[:,:, :3] *= pts[:, :, 2:3]
612
+
613
+ # to homogeneous
614
+ pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
615
+
616
+ if poses.shape[0] == 1:
617
+ poses = poses.repeat(T, 1, 1)
618
+ elif poses.shape[0] != T:
619
+ raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
620
+
621
+ poses = poses.to(torch.float).to(self.device)
622
+ pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
623
+ pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
624
+ pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
625
+
626
+ return pts_proj
627
+
628
+ def w2s(self, pts, poses):
629
+ if isinstance(poses, np.ndarray):
630
+ poses = torch.from_numpy(poses)
631
+ assert poses.shape[0] == self.frame_num
632
+ poses = poses.to(torch.float32).to(self.device)
633
+ T, N, _ = pts.shape # (T, N, 3)
634
+ intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
635
+ # Step 1: expand the points to (T, N, 4) by appending a 1 in the last dimension (homogeneous coordinates)
636
+ ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
637
+ points_world_h = torch.cat([pts, ones], dim=-1)
638
+ points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
639
+ points_camera = points_camera_h[:, :3, :].permute(0, 2, 1)
640
+
641
+ points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
642
+
643
+ uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
644
+
645
+ # Step 5: extract the depth (Z) and concatenate it with the screen coordinates
646
+ depth = points_camera[:, :, 2:3] # (T, N, 1)
647
+ uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3)
648
+
649
+ return uvd # screen coordinates + depth (T, N, 3)
650
+
651
+ def apply_motion_on_pts(self, pts, camera_motion):
652
+ tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
653
+ return tracking_pts
654
+
655
+ def set_intr(self, K):
656
+ if isinstance(K, np.ndarray):
657
+ K = torch.from_numpy(K)
658
+ self.intr = K.to(self.device)
659
+
660
+ def rot_poses(self, angle, axis='y'):
661
+ """Generate a single rotation matrix
662
+
663
+ Args:
664
+ angle (float): Rotation angle in degrees
665
+ axis (str): Rotation axis ('x', 'y', or 'z')
666
+
667
+ Returns:
668
+ torch.Tensor: Single rotation matrix [4, 4]
669
+ """
670
+ angle_rad = math.radians(angle)
671
+ cos_theta = torch.cos(torch.tensor(angle_rad))
672
+ sin_theta = torch.sin(torch.tensor(angle_rad))
673
+
674
+ if axis == 'x':
675
+ rot_mat = torch.tensor([
676
+ [1, 0, 0, 0],
677
+ [0, cos_theta, -sin_theta, 0],
678
+ [0, sin_theta, cos_theta, 0],
679
+ [0, 0, 0, 1]
680
+ ], dtype=torch.float32)
681
+ elif axis == 'y':
682
+ rot_mat = torch.tensor([
683
+ [cos_theta, 0, sin_theta, 0],
684
+ [0, 1, 0, 0],
685
+ [-sin_theta, 0, cos_theta, 0],
686
+ [0, 0, 0, 1]
687
+ ], dtype=torch.float32)
688
+ elif axis == 'z':
689
+ rot_mat = torch.tensor([
690
+ [cos_theta, -sin_theta, 0, 0],
691
+ [sin_theta, cos_theta, 0, 0],
692
+ [0, 0, 1, 0],
693
+ [0, 0, 0, 1]
694
+ ], dtype=torch.float32)
695
+ else:
696
+ raise ValueError("Invalid axis value. Choose 'x', 'y', or 'z'.")
697
+
698
+ return rot_mat.to(self.device)
699
+
700
+ def trans_poses(self, dx, dy, dz):
701
+ """
702
+ params:
703
+ - dx: float, displacement along x axis.
704
+ - dy: float, displacement along y axis.
705
+ - dz: float, displacement along z axis.
706
+
707
+ ret:
708
+ - trans_mats: torch.Tensor, camera pose matrices [frame_num, 4, 4]
709
+ """
710
+ trans_mats = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1) # (n, 4, 4)
711
+
712
+ delta_x = dx / (self.frame_num - 1)
713
+ delta_y = dy / (self.frame_num - 1)
714
+ delta_z = dz / (self.frame_num - 1)
715
+
716
+ for i in range(self.frame_num):
717
+ trans_mats[i, 0, 3] = i * delta_x
718
+ trans_mats[i, 1, 3] = i * delta_y
719
+ trans_mats[i, 2, 3] = i * delta_z
720
+
721
+ return trans_mats.to(self.device)
722
+
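+     # For example, trans_poses(0, 0, 0.5) yields frame_num identity-rotation poses whose
+     # z-translation grows linearly from 0 at frame 0 to 0.5 at the final frame.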
723
+
724
+ def _look_at(self, camera_position, target_position):
725
+ # look at direction
726
+ direction = target_position - camera_position
727
+ direction /= np.linalg.norm(direction)
728
+ # calculate rotation matrix
729
+ up = np.array([0, 1, 0])
730
+ right = np.cross(up, direction)
731
+ right /= np.linalg.norm(right)
732
+ up = np.cross(direction, right)
733
+ rotation_matrix = np.vstack([right, up, direction])
734
+ rotation_matrix = np.linalg.inv(rotation_matrix)
735
+ return rotation_matrix
736
+
737
+ def spiral_poses(self, radius, forward_ratio = 0.5, backward_ratio = 0.5, rotation_times = 0.1, look_at_times = 0.5):
738
+ """Generate spiral camera poses
739
+
740
+ Args:
741
+ radius (float): Base radius of the spiral
742
+ forward_ratio (float): Scale factor for forward motion
743
+ backward_ratio (float): Scale factor for backward motion
744
+ rotation_times (float): Number of rotations to complete
745
+ look_at_times (float): Scale factor for look-at point distance
746
+
747
+ Returns:
748
+ torch.Tensor: Camera poses of shape [num_frames, 4, 4]
749
+ """
750
+ # Generate spiral trajectory
751
+ t = np.linspace(0, 1, self.frame_num)
752
+ r = np.sin(np.pi * t) * radius * rotation_times
753
+ theta = 2 * np.pi * t
754
+
755
+ # Calculate camera positions
756
+ # Limit y motion for better floor/sky view
757
+ y = r * np.cos(theta) * 0.3
758
+ x = r * np.sin(theta)
759
+ z = -r
760
+ z[z < 0] *= forward_ratio
761
+ z[z > 0] *= backward_ratio
762
+
763
+ # Set look-at target
764
+ target_pos = np.array([0, 0, radius * look_at_times])
765
+ cam_pos = np.vstack([x, y, z]).T
766
+ cam_poses = []
767
+
768
+ for pos in cam_pos:
769
+ rot_mat = self._look_at(pos, target_pos)
770
+ trans_mat = np.eye(4)
771
+ trans_mat[:3, :3] = rot_mat
772
+ trans_mat[:3, 3] = pos
773
+ cam_poses.append(trans_mat[None])
774
+
775
+ camera_poses = np.concatenate(cam_poses, axis=0)
776
+ return torch.from_numpy(camera_poses).to(self.device)
777
+
778
+ def rot(self, pts, angle, axis):
779
+ """
780
+ pts: torch.Tensor, (T, N, 2)
781
+ """
782
+ rot_mats = self.rot_poses(angle, axis)
783
+ pts = self.apply_motion_on_pts(pts, rot_mats)
784
+ return pts
785
+
786
+ def trans(self, pts, dx, dy, dz):
787
+ if pts.shape[-1] != 3:
788
+ raise ValueError("points should be in the 3d coordinate.")
789
+ trans_mats = self.trans_poses(dx, dy, dz)
790
+ pts = self.apply_motion_on_pts(pts, trans_mats)
791
+ return pts
792
+
793
+ def spiral(self, pts, radius):
794
+ spiral_poses = self.spiral_poses(radius)
795
+ pts = self.apply_motion_on_pts(pts, spiral_poses)
796
+ return pts
797
+
798
+ def get_default_motion(self):
799
+ """Parse motion parameters and generate corresponding motion matrices
800
+
801
+ Supported formats:
802
+ - trans <dx> <dy> <dz> [start_frame] [end_frame]: Translation motion
803
+ - rot <axis> <angle> [start_frame] [end_frame]: Rotation motion
804
+ - spiral <radius> [start_frame] [end_frame]: Spiral motion
805
+
806
+ Multiple transformations can be combined using semicolon (;) as separator:
807
+ e.g., "trans 0 0 0.5 0 30; rot x 25 0 30; trans 0.1 0 0 30 48"
808
+
809
+ Note:
810
+ - start_frame and end_frame are optional
811
+ - frame range: 0-49 (will be clamped to this range)
812
+ - if not specified, defaults to 0-49
813
+ - frames after end_frame will maintain the final transformation
814
+ - for combined transformations, they are applied in sequence
815
+
816
+ Returns:
817
+ torch.Tensor: Motion matrices [num_frames, 4, 4]
818
+ """
819
+ if not isinstance(self.motion_type, str):
820
+ raise ValueError(f'camera_motion must be a string, but got {type(self.motion_type)}')
821
+
822
+ # Split combined transformations
823
+ transform_sequences = [s.strip() for s in self.motion_type.split(';')]
824
+
825
+ # Initialize the final motion matrices
826
+ final_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1)
827
+
828
+ # Process each transformation in sequence
829
+ for transform in transform_sequences:
830
+ params = transform.lower().split()
831
+ if not params:
832
+ continue
833
+
834
+ motion_type = params[0]
835
+
836
+ # Default frame range
837
+ start_frame = 0
838
+ end_frame = 48 # 49 frames in total (0-48)
839
+
840
+ if motion_type == 'trans':
841
+ # Parse translation parameters
842
+ if len(params) not in [4, 6]:
843
+ raise ValueError(f"trans motion requires 3 or 5 parameters: 'trans <dx> <dy> <dz>' or 'trans <dx> <dy> <dz> <start_frame> <end_frame>', got: {transform}")
844
+
845
+ dx, dy, dz = map(float, params[1:4])
846
+
847
+ if len(params) == 6:
848
+ start_frame = max(0, min(48, int(params[4])))
849
+ end_frame = max(0, min(48, int(params[5])))
850
+ if start_frame > end_frame:
851
+ start_frame, end_frame = end_frame, start_frame
852
+
853
+ # Generate current transformation
854
+ current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1)
855
+ for frame_idx in range(49):
856
+ if frame_idx < start_frame:
857
+ continue
858
+ elif frame_idx <= end_frame:
859
+ t = (frame_idx - start_frame) / (end_frame - start_frame)
860
+ current_motion[frame_idx, :3, 3] = torch.tensor([dx, dy, dz], device=self.device) * t
861
+ else:
862
+ current_motion[frame_idx] = current_motion[end_frame]
863
+
864
+ # Combine with previous transformations
865
+ final_motion = torch.matmul(final_motion, current_motion)
866
+
867
+ elif motion_type == 'rot':
868
+ # Parse rotation parameters
869
+ if len(params) not in [3, 5]:
870
+ raise ValueError(f"rot motion requires 2 or 4 parameters: 'rot <axis> <angle>' or 'rot <axis> <angle> <start_frame> <end_frame>', got: {transform}")
871
+
872
+ axis = params[1]
873
+ if axis not in ['x', 'y', 'z']:
874
+ raise ValueError(f"Invalid rotation axis '{axis}', must be 'x', 'y' or 'z'")
875
+ angle = float(params[2])
876
+
877
+ if len(params) == 5:
878
+ start_frame = max(0, min(48, int(params[3])))
879
+ end_frame = max(0, min(48, int(params[4])))
880
+ if start_frame > end_frame:
881
+ start_frame, end_frame = end_frame, start_frame
882
+
883
+ current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1)
884
+ for frame_idx in range(49):
885
+ if frame_idx < start_frame:
886
+ continue
887
+ elif frame_idx <= end_frame:
888
+ t = (frame_idx - start_frame) / (end_frame - start_frame)
889
+ current_angle = angle * t
890
+ current_motion[frame_idx] = self.rot_poses(current_angle, axis)
891
+ else:
892
+ current_motion[frame_idx] = current_motion[end_frame]
893
+
894
+ # Combine with previous transformations
895
+ final_motion = torch.matmul(final_motion, current_motion)
896
+
897
+ elif motion_type == 'spiral':
898
+ # Parse spiral motion parameters
899
+ if len(params) not in [2, 4]:
900
+ raise ValueError(f"spiral motion requires 1 or 3 parameters: 'spiral <radius>' or 'spiral <radius> <start_frame> <end_frame>', got: {transform}")
901
+
902
+ radius = float(params[1])
903
+
904
+ if len(params) == 4:
905
+ start_frame = max(0, min(48, int(params[2])))
906
+ end_frame = max(0, min(48, int(params[3])))
907
+ if start_frame > end_frame:
908
+ start_frame, end_frame = end_frame, start_frame
909
+
910
+ current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1)
911
+ spiral_motion = self.spiral_poses(radius)
912
+ for frame_idx in range(49):
913
+ if frame_idx < start_frame:
914
+ continue
915
+ elif frame_idx <= end_frame:
916
+ t = (frame_idx - start_frame) / (end_frame - start_frame)
917
+ idx = int(t * (len(spiral_motion) - 1))
918
+ current_motion[frame_idx] = spiral_motion[idx]
919
+ else:
920
+ current_motion[frame_idx] = current_motion[end_frame]
921
+
922
+ # Combine with previous transformations
923
+ final_motion = torch.matmul(final_motion, current_motion)
924
+
925
+ else:
926
+ raise ValueError(f'camera_motion type must be in [trans, spiral, rot], but got {motion_type}')
927
+
928
+ return final_motion
929
+
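+ # Example of the motion-string format parsed by get_default_motion above: a dolly-in
+ # over frames 0-30 combined with a 25-degree rotation about x, followed by a small
+ # sideways translation for the remaining frames. The values are purely illustrative.
+ def _example_camera_motion_matrices():
+     gen = CameraMotionGenerator("trans 0 0 0.5 0 30; rot x 25 0 30; trans 0.1 0 0 30 48")
+     return gen.get_default_motion()  # [49, 4, 4] camera pose matrices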
930
+ class ObjectMotionGenerator:
931
+ def __init__(self, device="cuda:0"):
932
+ self.device = device
933
+ self.num_frames = 49
934
+
935
+ def _get_points_in_mask(self, pred_tracks, mask):
936
+ """Get points that lie within the mask
937
+
938
+ Args:
939
+ pred_tracks (torch.Tensor): Point trajectories [num_frames, num_points, 3]
940
+ mask (torch.Tensor): Binary mask [H, W]
941
+
942
+ Returns:
943
+ torch.Tensor: Boolean mask for selected points [num_points]
944
+ """
945
+ first_frame_points = pred_tracks[0] # [num_points, 3]
946
+ xy_points = first_frame_points[:, :2] # [num_points, 2]
947
+
948
+ xy_pixels = xy_points.round().long()
949
+ xy_pixels[:, 0].clamp_(0, mask.shape[1] - 1)
950
+ xy_pixels[:, 1].clamp_(0, mask.shape[0] - 1)
951
+
952
+ points_in_mask = mask[xy_pixels[:, 1], xy_pixels[:, 0]]
953
+
954
+ return points_in_mask
955
+
956
+ def apply_motion(self, pred_tracks, mask, motion_type, distance, num_frames=49, tracking_method="spatracker"):
957
+
958
+ self.num_frames = num_frames
959
+ pred_tracks = pred_tracks.to(self.device).float()
960
+ mask = mask.to(self.device)
961
+
962
+ template = {
963
+ 'up': ('trans', torch.tensor([0, -1, 0])),
964
+ 'down': ('trans', torch.tensor([0, 1, 0])),
965
+ 'left': ('trans', torch.tensor([-1, 0, 0])),
966
+ 'right': ('trans', torch.tensor([1, 0, 0])),
967
+ 'front': ('trans', torch.tensor([0, 0, 1])),
968
+ 'back': ('trans', torch.tensor([0, 0, -1])),
969
+ 'rot': ('rot', None) # rotate around y axis
970
+ }
971
+
972
+ if motion_type not in template:
973
+ raise ValueError(f"unknown motion type: {motion_type}")
974
+
975
+ motion_type, base_vec = template[motion_type]
976
+ if base_vec is not None:
977
+ base_vec = base_vec.to(self.device) * distance
978
+
979
+ if tracking_method == "moge":
980
+ T, H, W, _ = pred_tracks.shape
981
+ valid_selected = ~torch.any(torch.isnan(pred_tracks[0]), dim=2) & mask
982
+ points = pred_tracks[0][valid_selected].reshape(-1, 3)
983
+ else:
984
+ points_in_mask = self._get_points_in_mask(pred_tracks, mask)
985
+ points = pred_tracks[0, points_in_mask]
986
+
987
+ center = points.mean(dim=0)
988
+
989
+ motions = []
990
+ for frame_idx in range(num_frames):
991
+ t = frame_idx / (num_frames - 1)
992
+ current_motion = torch.eye(4, device=self.device)
993
+ current_motion[:3, 3] = -center
994
+ motion_mat = torch.eye(4, device=self.device)
995
+ if motion_type == 'trans':
996
+ motion_mat[:3, 3] = base_vec * t
997
+ else: # 'rot'
998
+ angle_rad = torch.deg2rad(torch.tensor(distance * t, device=self.device))
999
+ cos_t = torch.cos(angle_rad)
1000
+ sin_t = torch.sin(angle_rad)
1001
+ motion_mat[0, 0] = cos_t
1002
+ motion_mat[0, 2] = sin_t
1003
+ motion_mat[2, 0] = -sin_t
1004
+ motion_mat[2, 2] = cos_t
1005
+
1006
+ current_motion = motion_mat @ current_motion
1007
+ current_motion[:3, 3] += center
1008
+ motions.append(current_motion)
1009
+
1010
+ motions = torch.stack(motions) # [num_frames, 4, 4]
1011
+
1012
+ if tracking_method == "moge":
1013
+ modified_tracks = pred_tracks.clone().reshape(T, -1, 3)
1014
+ valid_selected = valid_selected.reshape([-1])
1015
+
1016
+ for frame_idx in range(self.num_frames):
1017
+ motion_mat = motions[frame_idx]
1018
+ if W > 1:
1019
+ motion_mat = motion_mat.clone()
1020
+ motion_mat[0, 3] /= W
1021
+ motion_mat[1, 3] /= H
1022
+ points = modified_tracks[frame_idx, valid_selected]
1023
+ points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
1024
+ transformed_points = torch.matmul(points_homo, motion_mat.T)
1025
+ modified_tracks[frame_idx, valid_selected] = transformed_points[:, :3]
1026
+
1027
+ return modified_tracks.reshape(T, H, W, 3)
1028
+
1029
+ else:
1030
+ points_in_mask = self._get_points_in_mask(pred_tracks, mask)
1031
+ modified_tracks = pred_tracks.clone()
1032
+
1033
+ for frame_idx in range(pred_tracks.shape[0]):
1034
+ motion_mat = motions[frame_idx]
1035
+ points = modified_tracks[frame_idx, points_in_mask]
1036
+ points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
1037
+ transformed_points = torch.matmul(points_homo, motion_mat.T)
1038
+ modified_tracks[frame_idx, points_in_mask] = transformed_points[:, :3]
1039
+
1040
+ return modified_tracks
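+ # Hedged sketch of ObjectMotionGenerator with SpaTracker-style trajectories
+ # ([T, N, 3] points plus a binary object mask); not invoked in this module.
+ def _example_object_motion(pred_tracks, mask):
+     gen = ObjectMotionGenerator(device="cuda:0")
+     # Shift the masked object to the right by `distance` units (pixels for SpaTracker tracks).
+     return gen.apply_motion(
+         pred_tracks=pred_tracks,
+         mask=mask,
+         motion_type="right",
+         distance=50,
+         num_frames=49,
+         tracking_method="spatracker",
+     )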
models/spatracker/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
models/spatracker/models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
models/spatracker/models/build_spatracker.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from models.spatracker.models.core.spatracker.spatracker import SpaTracker
10
+
11
+
12
+ def build_spatracker(
13
+ checkpoint: str,
14
+ seq_length: int = 8,
15
+ ):
16
+ model_name = checkpoint.split("/")[-1].split(".")[0]
17
+ return build_spatracker_from_cfg(checkpoint=checkpoint, seq_length=seq_length)
18
+
19
+
20
+
21
+ # model used to produce the results in the paper
22
+ def build_spatracker_from_cfg(checkpoint=None, seq_length=8):
23
+ return _build_spatracker(
24
+ stride=4,
25
+ sequence_len=seq_length,
26
+ checkpoint=checkpoint,
27
+ )
28
+
29
+
30
+ def _build_spatracker(
31
+ stride,
32
+ sequence_len,
33
+ checkpoint=None,
34
+ ):
35
+ spatracker = SpaTracker(
36
+ stride=stride,
37
+ S=sequence_len,
38
+ add_space_attn=True,
39
+ space_depth=6,
40
+ time_depth=6,
41
+ )
42
+ if checkpoint is not None:
43
+ with open(checkpoint, "rb") as f:
44
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
45
+ if "model" in state_dict:
46
+ model_paras = spatracker.state_dict()
47
+ paras_dict = {k: v for k,v in state_dict["model"].items() if k in spatracker.state_dict()}
48
+ model_paras.update(paras_dict)
49
+ state_dict = model_paras
50
+ spatracker.load_state_dict(state_dict)
51
+ return spatracker
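+ # Hedged usage sketch; the checkpoint path mirrors the one referenced in
+ # models/pipelines.py and is an assumption here. Not invoked in this module.
+ def _example_build(checkpoint="checkpoints/spatracker/spaT_final.pth"):
+     model = build_spatracker(checkpoint, seq_length=12)
+     return model.eval()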
models/spatracker/models/core/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
models/spatracker/models/core/embeddings.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import numpy as np
9
+
10
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
11
+ """
12
+ grid_size: int of the grid height and width
13
+ return:
14
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
15
+ """
16
+ if isinstance(grid_size, tuple):
17
+ grid_size_h, grid_size_w = grid_size
18
+ else:
19
+ grid_size_h = grid_size_w = grid_size
20
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
21
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
22
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
23
+ grid = np.stack(grid, axis=0)
24
+
25
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
26
+ pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
27
+ if cls_token and extra_tokens > 0:
28
+ pos_embed = np.concatenate(
29
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
30
+ )
31
+ return pos_embed
32
+
33
+
34
+ def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
35
+ assert embed_dim % 3 == 0
36
+
37
+ # use half of dimensions to encode grid_h
38
+ B, S, N, _ = grid.shape
39
+ gridx = grid[..., 0].view(B*S*N).detach().cpu().numpy()
40
+ gridy = grid[..., 1].view(B*S*N).detach().cpu().numpy()
41
+ gridz = grid[..., 2].view(B*S*N).detach().cpu().numpy()
42
+
43
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx) # (N, D/3)
44
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy) # (N, D/3)
45
+ emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz) # (N, D/3)
46
+
47
+
48
+ emb = np.concatenate([emb_h, emb_w, emb_z], axis=1) # (N, D)
49
+ emb = torch.from_numpy(emb).to(grid.device)
50
+ return emb.view(B, S, N, embed_dim)
51
+
52
+
53
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
54
+ """
55
+ grid_size: int of the grid height and width
56
+ return:
57
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
58
+ """
59
+ if isinstance(grid_size, tuple):
60
+ grid_size_h, grid_size_w = grid_size
61
+ else:
62
+ grid_size_h = grid_size_w = grid_size
63
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
64
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
65
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
66
+ grid = np.stack(grid, axis=0)
67
+
68
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
69
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
70
+ if cls_token and extra_tokens > 0:
71
+ pos_embed = np.concatenate(
72
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
73
+ )
74
+ return pos_embed
75
+
76
+
77
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
78
+ assert embed_dim % 2 == 0
79
+
80
+ # use half of dimensions to encode grid_h
81
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
82
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
83
+
84
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
85
+ return emb
86
+
87
+
88
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
89
+ """
90
+ embed_dim: output dimension for each position
91
+ pos: a list of positions to be encoded: size (M,)
92
+ out: (M, D)
93
+ """
94
+ assert embed_dim % 2 == 0
95
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
96
+ omega /= embed_dim / 2.0
97
+ omega = 1.0 / 10000 ** omega # (D/2,)
98
+
99
+ pos = pos.reshape(-1) # (M,)
100
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
101
+
102
+ emb_sin = np.sin(out) # (M, D/2)
103
+ emb_cos = np.cos(out) # (M, D/2)
104
+
105
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
106
+ return emb
107
+
108
+
109
+ def get_2d_embedding(xy, C, cat_coords=True):
110
+ B, N, D = xy.shape
111
+ assert D == 2
112
+
113
+ x = xy[:, :, 0:1]
114
+ y = xy[:, :, 1:2]
115
+ div_term = (
116
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
117
+ ).reshape(1, 1, int(C / 2))
118
+
119
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
120
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
121
+
122
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
123
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
124
+
125
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
126
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
127
+
128
+ pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*2
129
+ if cat_coords:
130
+ pe = torch.cat([xy, pe], dim=2) # B, N, C*2+2
131
+ return pe
132
+
133
+
134
+ def get_3d_embedding(xyz, C, cat_coords=True):
135
+ B, N, D = xyz.shape
136
+ assert D == 3
137
+
138
+ x = xyz[:, :, 0:1]
139
+ y = xyz[:, :, 1:2]
140
+ z = xyz[:, :, 2:3]
141
+ div_term = (
142
+ torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
143
+ ).reshape(1, 1, int(C / 2))
144
+
145
+ pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
146
+ pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
147
+ pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
148
+
149
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
150
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
151
+
152
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
153
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
154
+
155
+ pe_z[:, :, 0::2] = torch.sin(z * div_term)
156
+ pe_z[:, :, 1::2] = torch.cos(z * div_term)
157
+
158
+ pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
159
+ if cat_coords:
160
+ pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
161
+ return pe
162
+
163
+
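+ # Illustrative calls (not in the original file): per-point sin/cos features, with the raw
+ # coordinates optionally appended.
+ # xy = torch.rand(2, 100, 2) * 64                # pixel coordinates, B=2, N=100
+ # get_2d_embedding(xy, C=32).shape               # (2, 100, 32*2 + 2)
+ # xyz = torch.rand(2, 100, 3)
+ # get_3d_embedding(xyz, C=32).shape              # (2, 100, 32*3 + 3)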
164
+ def get_4d_embedding(xyzw, C, cat_coords=True):
165
+ B, N, D = xyzw.shape
166
+ assert D == 4
167
+
168
+ x = xyzw[:, :, 0:1]
169
+ y = xyzw[:, :, 1:2]
170
+ z = xyzw[:, :, 2:3]
171
+ w = xyzw[:, :, 3:4]
172
+ div_term = (
173
+ torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
174
+ ).reshape(1, 1, int(C / 2))
175
+
176
+ pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
177
+ pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
178
+ pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
179
+ pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
180
+
181
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
182
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
183
+
184
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
185
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
186
+
187
+ pe_z[:, :, 0::2] = torch.sin(z * div_term)
188
+ pe_z[:, :, 1::2] = torch.cos(z * div_term)
189
+
190
+ pe_w[:, :, 0::2] = torch.sin(w * div_term)
191
+ pe_w[:, :, 1::2] = torch.cos(w * div_term)
192
+
193
+ pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*4
194
+ if cat_coords:
195
+ pe = torch.cat([pe, xyzw], dim=2) # B, N, C*4+4
196
+ return pe
197
+
198
+ import torch.nn as nn
199
+ class Embedder_Fourier(nn.Module):
200
+ def __init__(self, input_dim, max_freq_log2, N_freqs,
201
+ log_sampling=True, include_input=True,
202
+ periodic_fns=(torch.sin, torch.cos)):
203
+ '''
204
+ :param input_dim: dimension of input to be embedded
205
+ :param max_freq_log2: log2 of max freq; min freq is 1 by default
206
+ :param N_freqs: number of frequency bands
207
+ :param log_sampling: if True, frequency bands are linearly sampled in log-space
208
+ :param include_input: if True, raw input is included in the embedding
209
+ :param periodic_fns: periodic functions used to embed input
210
+ '''
211
+ super(Embedder_Fourier, self).__init__()
212
+
213
+ self.input_dim = input_dim
214
+ self.include_input = include_input
215
+ self.periodic_fns = periodic_fns
216
+
217
+ self.out_dim = 0
218
+ if self.include_input:
219
+ self.out_dim += self.input_dim
220
+
221
+ self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns)
222
+
223
+ if log_sampling:
224
+ self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
225
+ else:
226
+ self.freq_bands = torch.linspace(
227
+ 2. ** 0., 2. ** max_freq_log2, N_freqs)
228
+
229
+ self.freq_bands = self.freq_bands.numpy().tolist()
230
+
231
+ def forward(self,
232
+ input: torch.Tensor,
233
+ rescale: float = 1.0):
234
+ '''
235
+ :param input: tensor of shape [..., self.input_dim]
236
+ :return: tensor of shape [..., self.out_dim]
237
+ '''
238
+ assert (input.shape[-1] == self.input_dim)
239
+ out = []
240
+ if self.include_input:
241
+ out.append(input/rescale)
242
+
243
+ for i in range(len(self.freq_bands)):
244
+ freq = self.freq_bands[i]
245
+ for p_fn in self.periodic_fns:
246
+ out.append(p_fn(input * freq))
247
+ out = torch.cat(out, dim=-1)
248
+
249
+ assert (out.shape[-1] == self.out_dim)
250
+ return out
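+
+
+ if __name__ == "__main__":
+     # Minimal smoke test, added for illustration only (not part of the original commit).
+     # It assumes `numpy as np` and `torch` are imported at the top of this file, as the
+     # helpers above already require.
+     pos = get_2d_sincos_pos_embed(64, 8)
+     assert pos.shape == (8 * 8, 64)
+     fourier = Embedder_Fourier(input_dim=3, max_freq_log2=9, N_freqs=10)
+     x = torch.rand(4, 3)
+     assert fourier(x).shape == (4, fourier.out_dim)   # out_dim = 3 + 3 * 10 * 2 = 63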
models/spatracker/models/core/model_utils.py ADDED
@@ -0,0 +1,477 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from easydict import EasyDict as edict
10
+ from sklearn.decomposition import PCA
11
+ import matplotlib.pyplot as plt
12
+
13
+ EPS = 1e-6
14
+
15
+ def nearest_sample2d(im, x, y, return_inbounds=False):
16
+ # x and y are each B, N
17
+ # output is B, C, N
18
+ if len(im.shape) == 5:
19
+ B, N, C, H, W = list(im.shape)
20
+ else:
21
+ B, C, H, W = list(im.shape)
22
+ N = list(x.shape)[1]
23
+
24
+ x = x.float()
25
+ y = y.float()
26
+ H_f = torch.tensor(H, dtype=torch.float32)
27
+ W_f = torch.tensor(W, dtype=torch.float32)
28
+
29
+ # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
30
+
31
+ max_y = (H_f - 1).int()
32
+ max_x = (W_f - 1).int()
33
+
34
+ x0 = torch.floor(x).int()
35
+ x1 = x0 + 1
36
+ y0 = torch.floor(y).int()
37
+ y1 = y0 + 1
38
+
39
+ x0_clip = torch.clamp(x0, 0, max_x)
40
+ x1_clip = torch.clamp(x1, 0, max_x)
41
+ y0_clip = torch.clamp(y0, 0, max_y)
42
+ y1_clip = torch.clamp(y1, 0, max_y)
43
+ dim2 = W
44
+ dim1 = W * H
45
+
46
+ base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
47
+ base = torch.reshape(base, [B, 1]).repeat([1, N])
48
+
49
+ base_y0 = base + y0_clip * dim2
50
+ base_y1 = base + y1_clip * dim2
51
+
52
+ idx_y0_x0 = base_y0 + x0_clip
53
+ idx_y0_x1 = base_y0 + x1_clip
54
+ idx_y1_x0 = base_y1 + x0_clip
55
+ idx_y1_x1 = base_y1 + x1_clip
56
+
57
+ # use the indices to lookup pixels in the flat image
58
+ # im is B x C x H x W
59
+ # move C out to last dim
60
+ if len(im.shape) == 5:
61
+ im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
62
+ i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
63
+ 0, 2, 1
64
+ )
65
+ i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
66
+ 0, 2, 1
67
+ )
68
+ i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
69
+ 0, 2, 1
70
+ )
71
+ i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
72
+ 0, 2, 1
73
+ )
74
+ else:
75
+ im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
76
+ i_y0_x0 = im_flat[idx_y0_x0.long()]
77
+ i_y0_x1 = im_flat[idx_y0_x1.long()]
78
+ i_y1_x0 = im_flat[idx_y1_x0.long()]
79
+ i_y1_x1 = im_flat[idx_y1_x1.long()]
80
+
81
+ # Finally compute the bilinear weights and keep the nearest (max-weight) corner.
82
+ x0_f = x0.float()
83
+ x1_f = x1.float()
84
+ y0_f = y0.float()
85
+ y1_f = y1.float()
86
+
87
+ w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
88
+ w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
89
+ w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
90
+ w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
91
+
92
+ # w_yi_xo is B * N * 1
93
+ max_idx = torch.cat([w_y0_x0, w_y0_x1, w_y1_x0, w_y1_x1], dim=-1).max(dim=-1)[1]
94
+ output = torch.stack([i_y0_x0, i_y0_x1, i_y1_x0, i_y1_x1], dim=-1).gather(-1, max_idx[...,None,None].repeat(1,1,C,1)).squeeze(-1)
95
+
96
+ # output is B*N x C
97
+ output = output.view(B, -1, C)
98
+ output = output.permute(0, 2, 1)
99
+ # output is B x C x N
100
+
101
+ if return_inbounds:
102
+ x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
103
+ y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
104
+ inbounds = (x_valid & y_valid).float()
105
+ inbounds = inbounds.reshape(
106
+ B, N
107
+ ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
108
+ return output, inbounds
109
+
110
+ return output # B, C, N
111
+
112
+ def smart_cat(tensor1, tensor2, dim):
113
+ if tensor1 is None:
114
+ return tensor2
115
+ return torch.cat([tensor1, tensor2], dim=dim)
116
+
117
+
118
+ def normalize_single(d):
119
+ # d is a whatever shape torch tensor
120
+ dmin = torch.min(d)
121
+ dmax = torch.max(d)
122
+ d = (d - dmin) / (EPS + (dmax - dmin))
123
+ return d
124
+
125
+
126
+ def normalize(d):
127
+ # d is B x whatever. normalize within each element of the batch
128
+ out = torch.zeros(d.size())
129
+ if d.is_cuda:
130
+ out = out.cuda()
131
+ B = list(d.size())[0]
132
+ for b in list(range(B)):
133
+ out[b] = normalize_single(d[b])
134
+ return out
135
+
136
+
137
+ def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
138
+ # returns a meshgrid sized B x Y x X
139
+
140
+ grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
141
+ grid_y = torch.reshape(grid_y, [1, Y, 1])
142
+ grid_y = grid_y.repeat(B, 1, X)
143
+
144
+ grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
145
+ grid_x = torch.reshape(grid_x, [1, 1, X])
146
+ grid_x = grid_x.repeat(B, Y, 1)
147
+
148
+ if stack:
149
+ # note we stack in xy order
150
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
151
+ grid = torch.stack([grid_x, grid_y], dim=-1)
152
+ return grid
153
+ else:
154
+ return grid_y, grid_x
155
+
156
+
157
+ def reduce_masked_mean(x, mask, dim=None, keepdim=False):
158
+ # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
159
+ # returns shape-1
160
+ # axis can be a list of axes
161
+ for (a, b) in zip(x.size(), mask.size()):
162
+ assert a == b # some shape mismatch!
163
+ prod = x * mask
164
+ if dim is None:
165
+ numer = torch.sum(prod)
166
+ denom = EPS + torch.sum(mask)
167
+ else:
168
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
169
+ denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
170
+
171
+ mean = numer / denom
172
+ return mean
173
+
174
+
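+ # Illustrative example (not in the original file): average x only where mask == 1.
+ # x    = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
+ # mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
+ # reduce_masked_mean(x, mask)          # ~1.5 (the masked-out 3.0 and 4.0 are ignored)
+ # reduce_masked_mean(x, mask, dim=1)   # tensor([1.5])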
175
+ def bilinear_sample2d(im, x, y, return_inbounds=False):
176
+ # x and y are each B, N
177
+ # output is B, C, N
178
+ if len(im.shape) == 5:
179
+ B, N, C, H, W = list(im.shape)
180
+ else:
181
+ B, C, H, W = list(im.shape)
182
+ N = list(x.shape)[1]
183
+
184
+ x = x.float()
185
+ y = y.float()
186
+ H_f = torch.tensor(H, dtype=torch.float32)
187
+ W_f = torch.tensor(W, dtype=torch.float32)
188
+
189
+ # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
190
+
191
+ max_y = (H_f - 1).int()
192
+ max_x = (W_f - 1).int()
193
+
194
+ x0 = torch.floor(x).int()
195
+ x1 = x0 + 1
196
+ y0 = torch.floor(y).int()
197
+ y1 = y0 + 1
198
+
199
+ x0_clip = torch.clamp(x0, 0, max_x)
200
+ x1_clip = torch.clamp(x1, 0, max_x)
201
+ y0_clip = torch.clamp(y0, 0, max_y)
202
+ y1_clip = torch.clamp(y1, 0, max_y)
203
+ dim2 = W
204
+ dim1 = W * H
205
+
206
+ base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
207
+ base = torch.reshape(base, [B, 1]).repeat([1, N])
208
+
209
+ base_y0 = base + y0_clip * dim2
210
+ base_y1 = base + y1_clip * dim2
211
+
212
+ idx_y0_x0 = base_y0 + x0_clip
213
+ idx_y0_x1 = base_y0 + x1_clip
214
+ idx_y1_x0 = base_y1 + x0_clip
215
+ idx_y1_x1 = base_y1 + x1_clip
216
+
217
+ # use the indices to lookup pixels in the flat image
218
+ # im is B x C x H x W
219
+ # move C out to last dim
220
+ if len(im.shape) == 5:
221
+ im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
222
+ i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
223
+ 0, 2, 1
224
+ )
225
+ i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
226
+ 0, 2, 1
227
+ )
228
+ i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
229
+ 0, 2, 1
230
+ )
231
+ i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
232
+ 0, 2, 1
233
+ )
234
+ else:
235
+ im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
236
+ i_y0_x0 = im_flat[idx_y0_x0.long()]
237
+ i_y0_x1 = im_flat[idx_y0_x1.long()]
238
+ i_y1_x0 = im_flat[idx_y1_x0.long()]
239
+ i_y1_x1 = im_flat[idx_y1_x1.long()]
240
+
241
+ # Finally calculate interpolated values.
242
+ x0_f = x0.float()
243
+ x1_f = x1.float()
244
+ y0_f = y0.float()
245
+ y1_f = y1.float()
246
+
247
+ w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
248
+ w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
249
+ w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
250
+ w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
251
+
252
+ output = (
253
+ w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
254
+ )
255
+ # output is B*N x C
256
+ output = output.view(B, -1, C)
257
+ output = output.permute(0, 2, 1)
258
+ # output is B x C x N
259
+
260
+ if return_inbounds:
261
+ x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
262
+ y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
263
+ inbounds = (x_valid & y_valid).float()
264
+ inbounds = inbounds.reshape(
265
+ B, N
266
+ ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
267
+ return output, inbounds
268
+
269
+ return output # B, C, N
270
+
271
+
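+ # Illustrative call (not in the original file): sample per-point features from a
+ # B x C x H x W map at sub-pixel (x, y) locations.
+ # im = torch.rand(1, 16, 32, 32)
+ # x = torch.tensor([[3.5, 10.0]]); y = torch.tensor([[7.25, 20.0]])        # B=1, N=2
+ # bilinear_sample2d(im, x, y).shape                                        # (1, 16, 2)
+ # _, inb = bilinear_sample2d(im, x, y, return_inbounds=True)               # inb: (1, 2)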
272
+ def procrustes_analysis(X0,X1,Weight): # [B,N,3]
273
+ # translation
274
+ t0 = X0.mean(dim=1,keepdim=True)
275
+ t1 = X1.mean(dim=1,keepdim=True)
276
+ X0c = X0-t0
277
+ X1c = X1-t1
278
+ # scale
279
+ # s0 = (X0c**2).sum(dim=-1).mean().sqrt()
280
+ # s1 = (X1c**2).sum(dim=-1).mean().sqrt()
281
+ # X0cs = X0c/s0
282
+ # X1cs = X1c/s1
283
+ # rotation (use double for SVD, float loses precision)
284
+ U,_,V = (X0c.t()@X1c).double().svd(some=True)
285
+ R = (U@V.t()).float()
286
+ if R.det()<0: R[2] *= -1
287
+ # align X1 to X0: X1to0 = (X1-t1)@R.t()+t0
288
+ se3 = edict(t0=t0[0],t1=t1[0],R=R)
289
+
290
+ return se3
291
+
292
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
293
+ r"""Sample a tensor using bilinear interpolation
294
+
295
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
296
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
297
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
298
+ convention.
299
+
300
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
301
+ :math:`B` is the batch size, :math:`C` is the number of channels,
302
+ :math:`H` is the height of the image, and :math:`W` is the width of the
303
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
304
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
305
+
306
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
307
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
308
+ that in this case the order of the components is slightly different
309
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
310
+
311
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
312
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
313
+ left-most image pixel and :math:`W-1` to the center of the right-most
314
+ pixel.
315
+
316
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
317
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
318
+ the left-most pixel and :math:`W` to the right edge of the right-most
319
+ pixel.
320
+
321
+ Similar conventions apply to the :math:`y` for the range
322
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
323
+ :math:`[0,T-1]` and :math:`[0,T]`.
324
+
325
+ Args:
326
+ input (Tensor): batch of input images.
327
+ coords (Tensor): batch of coordinates.
328
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
329
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
330
+
331
+ Returns:
332
+ Tensor: sampled points.
333
+ """
334
+
335
+ sizes = input.shape[2:]
336
+
337
+ assert len(sizes) in [2, 3]
338
+
339
+ if len(sizes) == 3:
340
+ # t x y -> x y t to match dimensions T H W in grid_sample
341
+ coords = coords[..., [1, 2, 0]]
342
+
343
+ if align_corners:
344
+ coords = coords * torch.tensor(
345
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
346
+ )
347
+ else:
348
+ coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
349
+
350
+ coords -= 1
351
+
352
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
353
+
354
+
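+ # Illustrative call (not in the original file): with align_corners=True the coordinates
+ # are plain pixel coordinates, (x, y) in [0, W-1] x [0, H-1].
+ # feat = torch.rand(1, 8, 24, 24)
+ # coords = torch.tensor([[[[3.0, 5.0], [10.5, 7.25]]]])   # (B=1, H_o=1, W_o=2, 2)
+ # bilinear_sampler(feat, coords).shape                    # (1, 8, 1, 2)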
355
+ def sample_features4d(input, coords):
356
+ r"""Sample spatial features
357
+
358
+ `sample_features4d(input, coords)` samples the spatial features
359
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
360
+
361
+ The field is sampled at coordinates :attr:`coords` using bilinear
362
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
363
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
364
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
365
+
366
+ The output tensor has one feature per point, and has shape :math:`(B,
367
+ R, C)`.
368
+
369
+ Args:
370
+ input (Tensor): spatial features.
371
+ coords (Tensor): points.
372
+
373
+ Returns:
374
+ Tensor: sampled features.
375
+ """
376
+
377
+ B, _, _, _ = input.shape
378
+
379
+ # B R 2 -> B R 1 2
380
+ coords = coords.unsqueeze(2)
381
+
382
+ # B C R 1
383
+ feats = bilinear_sampler(input, coords)
384
+
385
+ return feats.permute(0, 2, 1, 3).view(
386
+ B, -1, feats.shape[1] * feats.shape[3]
387
+ ) # B C R 1 -> B R C
388
+
389
+
390
+ def sample_features5d(input, coords):
391
+ r"""Sample spatio-temporal features
392
+
393
+ `sample_features5d(input, coords)` works in the same way as
394
+ :func:`sample_features4d` but for spatio-temporal features and points:
395
+ :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
396
+ a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
397
+ x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
398
+
399
+ Args:
400
+ input (Tensor): spatio-temporal features.
401
+ coords (Tensor): spatio-temporal points.
402
+
403
+ Returns:
404
+ Tensor: sampled features.
405
+ """
406
+
407
+ B, T, _, _, _ = input.shape
408
+
409
+ # B T C H W -> B C T H W
410
+ input = input.permute(0, 2, 1, 3, 4)
411
+
412
+ # B R1 R2 3 -> B R1 R2 1 3
413
+ coords = coords.unsqueeze(3)
414
+
415
+ # B C R1 R2 1
416
+ feats = bilinear_sampler(input, coords)
417
+
418
+ return feats.permute(0, 2, 3, 1, 4).view(
419
+ B, feats.shape[2], feats.shape[3], feats.shape[1]
420
+ ) # B C R1 R2 1 -> B R1 R2 C
421
+
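+ # Illustrative calls (not in the original file), matching the shapes documented above.
+ # feat4d = torch.rand(2, 64, 32, 32)                       # (B, C, H, W)
+ # pts2d = torch.rand(2, 50, 2) * 31                        # (B, R, 2) as (x, y)
+ # sample_features4d(feat4d, pts2d).shape                   # (2, 50, 64)
+ # feat5d = torch.rand(2, 4, 64, 32, 32)                    # (B, T, C, H, W)
+ # pts3d = torch.rand(2, 5, 50, 3) * 3                      # (B, R1, R2, 3) as (t, x, y)
+ # sample_features5d(feat5d, pts3d).shape                   # (2, 5, 50, 64)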
422
+ def vis_PCA(fmaps, save_dir):
423
+ """
424
+ visualize the PCA of the feature maps
425
+ args:
426
+ fmaps: feature maps 1 C H W
427
+ save_dir: the directory to save the PCA visualization
428
+ """
429
+
430
+ pca = PCA(n_components=3)
431
+ fmap_vis = fmaps[0,...]
432
+ fmap_vnorm = (
433
+ (fmap_vis-fmap_vis.min())/
434
+ (fmap_vis.max()-fmap_vis.min()))
435
+ H_vis, W_vis = fmap_vis.shape[1:]
436
+ fmap_vnorm = fmap_vnorm.reshape(fmap_vnorm.shape[0],
437
+ -1).permute(1,0)
438
+ fmap_pca = pca.fit_transform(fmap_vnorm.detach().cpu().numpy())
439
+ pca = fmap_pca.reshape(H_vis,W_vis,3)
440
+ plt.imsave(save_dir,
441
+ (
442
+ (pca-pca.min())/
443
+ (pca.max()-pca.min())
444
+ ))
445
+
446
+
447
+ # debug=False
448
+ # if debug==True:
449
+ # pcd_idx = 60
450
+ # vis_PCA(fmapYZ[0,:1], "./yz.png")
451
+ # vis_PCA(fmapXZ[0,:1], "./xz.png")
452
+ # vis_PCA(fmaps[0,:1], "./xy.png")
453
+ # vis_PCA(fmaps[0,-1:], "./xy_.png")
454
+ # fxy_q = fxy[0,0,pcd_idx:pcd_idx+1, :, None, None]
455
+ # fyz_q = fyz[0,0,pcd_idx:pcd_idx+1, :, None, None]
456
+ # fxz_q = fxz[0,0,pcd_idx:pcd_idx+1, :, None, None]
457
+ # corr_map = (fxy_q*fmaps[0,-1:]).sum(dim=1)
458
+ # corr_map_yz = (fyz_q*fmapYZ[0,-1:]).sum(dim=1)
459
+ # corr_map_xz = (fxz_q*fmapXZ[0,-1:]).sum(dim=1)
460
+ # coord_last = coords[0,-1,pcd_idx:pcd_idx+1]
461
+ # coord_last_neigh = coords[0,-1, self.neigh_indx[pcd_idx]]
462
+ # depth_last = depths_dnG[-1,0]
463
+ # abs_res = (depth_last-coord_last[-1,-1]).abs()
464
+ # abs_res = (abs_res - abs_res.min())/(abs_res.max()-abs_res.min())
465
+ # res_dp = torch.exp(-abs_res)
466
+ # enhance_corr = res_dp*corr_map
467
+ # plt.imsave("./res.png", res_dp.detach().cpu().numpy())
468
+ # plt.imsave("./enhance_corr.png", enhance_corr[0].detach().cpu().numpy())
469
+ # plt.imsave("./corr_map.png", corr_map[0].detach().cpu().numpy())
470
+ # plt.imsave("./corr_map_yz.png", corr_map_yz[0].detach().cpu().numpy())
471
+ # plt.imsave("./corr_map_xz.png", corr_map_xz[0].detach().cpu().numpy())
472
+ # img_feat = cv2.imread("./xy.png")
473
+ # cv2.circle(img_feat, (int(coord_last[0,0]), int(coord_last[0,1])), 2, (0, 0, 255), -1)
474
+ # for p_i in coord_last_neigh:
475
+ # cv2.circle(img_feat, (int(p_i[0]), int(p_i[1])), 1, (0, 255, 0), -1)
476
+ # cv2.imwrite("./xy_coord.png", img_feat)
477
+ # import ipdb; ipdb.set_trace()
models/spatracker/models/core/spatracker/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
models/spatracker/models/core/spatracker/blocks.py ADDED
@@ -0,0 +1,999 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.cuda.amp import autocast
11
+ from einops import rearrange
12
+ import collections
13
+ from functools import partial
14
+ from itertools import repeat
15
+ import torchvision.models as tvm
+ import numpy as np  # needed below by Lie.SO3_to_so3 (np.pi); not imported elsewhere in this file
16
+
17
+ from models.spatracker.models.core.spatracker.vit.encoder import ImageEncoderViT as vitEnc
18
+ from models.spatracker.models.core.spatracker.dpt.models import DPTEncoder
19
+ from models.spatracker.models.core.spatracker.loftr import LocalFeatureTransformer
20
+ # from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
21
+
22
+ # From PyTorch internals
23
+ def _ntuple(n):
24
+ def parse(x):
25
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
26
+ return tuple(x)
27
+ return tuple(repeat(x, n))
28
+
29
+ return parse
30
+
31
+
32
+ def exists(val):
33
+ return val is not None
34
+
35
+
36
+ def default(val, d):
37
+ return val if exists(val) else d
38
+
39
+
40
+ to_2tuple = _ntuple(2)
41
+
42
+
43
+ class Mlp(nn.Module):
44
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
45
+
46
+ def __init__(
47
+ self,
48
+ in_features,
49
+ hidden_features=None,
50
+ out_features=None,
51
+ act_layer=nn.GELU,
52
+ norm_layer=None,
53
+ bias=True,
54
+ drop=0.0,
55
+ use_conv=False,
56
+ ):
57
+ super().__init__()
58
+ out_features = out_features or in_features
59
+ hidden_features = hidden_features or in_features
60
+ bias = to_2tuple(bias)
61
+ drop_probs = to_2tuple(drop)
62
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
63
+
64
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
65
+ self.act = act_layer()
66
+ self.drop1 = nn.Dropout(drop_probs[0])
67
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
68
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
69
+ self.drop2 = nn.Dropout(drop_probs[1])
70
+
71
+ def forward(self, x):
72
+ x = self.fc1(x)
73
+ x = self.act(x)
74
+ x = self.drop1(x)
75
+ x = self.fc2(x)
76
+ x = self.drop2(x)
77
+ return x
78
+
79
+ class Attention(nn.Module):
80
+ def __init__(self, query_dim, context_dim=None,
81
+ num_heads=8, dim_head=48, qkv_bias=False, flash=False):
82
+ super().__init__()
83
+ inner_dim = self.inner_dim = dim_head * num_heads
84
+ context_dim = default(context_dim, query_dim)
85
+ self.scale = dim_head**-0.5
86
+ self.heads = num_heads
87
+ self.flash = flash
88
+
89
+ self.qkv = nn.Linear(query_dim, inner_dim*3, bias=qkv_bias)
90
+ self.proj = nn.Linear(inner_dim, query_dim)
91
+
92
+ def forward(self, x, context=None, attn_bias=None):
93
+ B, N1, _ = x.shape
94
+ C = self.inner_dim
95
+ h = self.heads
96
+ # q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
97
+ # k, v = self.to_kv(context).chunk(2, dim=-1)
98
+ # context = default(context, x)
99
+
100
+ qkv = self.qkv(x).reshape(B, N1, 3, h, C // h)
101
+ q, k, v = qkv[:,:, 0], qkv[:,:, 1], qkv[:,:, 2]
102
+ N2 = x.shape[1]
103
+
104
+ k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
105
+ v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
106
+ q = q.reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
107
+ if self.flash==False:
108
+ sim = (q @ k.transpose(-2, -1)) * self.scale
109
+ if attn_bias is not None:
110
+ sim = sim + attn_bias
111
+ attn = sim.softmax(dim=-1)
112
+ x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
113
+ else:
114
+ input_args = [x.half().contiguous() for x in [q, k, v]]
115
+ x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
116
+
117
+ # return self.to_out(x.float())
118
+ return self.proj(x.float())
119
+
120
+ class ResidualBlock(nn.Module):
121
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
122
+ super(ResidualBlock, self).__init__()
123
+
124
+ self.conv1 = nn.Conv2d(
125
+ in_planes,
126
+ planes,
127
+ kernel_size=3,
128
+ padding=1,
129
+ stride=stride,
130
+ padding_mode="zeros",
131
+ )
132
+ self.conv2 = nn.Conv2d(
133
+ planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
134
+ )
135
+ self.relu = nn.ReLU(inplace=True)
136
+
137
+ num_groups = planes // 8
138
+
139
+ if norm_fn == "group":
140
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
141
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
142
+ if not stride == 1:
143
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
144
+
145
+ elif norm_fn == "batch":
146
+ self.norm1 = nn.BatchNorm2d(planes)
147
+ self.norm2 = nn.BatchNorm2d(planes)
148
+ if not stride == 1:
149
+ self.norm3 = nn.BatchNorm2d(planes)
150
+
151
+ elif norm_fn == "instance":
152
+ self.norm1 = nn.InstanceNorm2d(planes)
153
+ self.norm2 = nn.InstanceNorm2d(planes)
154
+ if not stride == 1:
155
+ self.norm3 = nn.InstanceNorm2d(planes)
156
+
157
+ elif norm_fn == "none":
158
+ self.norm1 = nn.Sequential()
159
+ self.norm2 = nn.Sequential()
160
+ if not stride == 1:
161
+ self.norm3 = nn.Sequential()
162
+
163
+ if stride == 1:
164
+ self.downsample = None
165
+
166
+ else:
167
+ self.downsample = nn.Sequential(
168
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
169
+ )
170
+
171
+ def forward(self, x):
172
+ y = x
173
+ y = self.relu(self.norm1(self.conv1(y)))
174
+ y = self.relu(self.norm2(self.conv2(y)))
175
+
176
+ if self.downsample is not None:
177
+ x = self.downsample(x)
178
+
179
+ return self.relu(x + y)
180
+
181
+
182
+ class BasicEncoder(nn.Module):
183
+ def __init__(
184
+ self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0,
185
+ Embed3D=False
186
+ ):
187
+ super(BasicEncoder, self).__init__()
188
+ self.stride = stride
189
+ self.norm_fn = norm_fn
190
+ self.in_planes = 64
191
+
192
+ if self.norm_fn == "group":
193
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
194
+ self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
195
+
196
+ elif self.norm_fn == "batch":
197
+ self.norm1 = nn.BatchNorm2d(self.in_planes)
198
+ self.norm2 = nn.BatchNorm2d(output_dim * 2)
199
+
200
+ elif self.norm_fn == "instance":
201
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
202
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
203
+
204
+ elif self.norm_fn == "none":
205
+ self.norm1 = nn.Sequential()
206
+
207
+ self.conv1 = nn.Conv2d(
208
+ input_dim,
209
+ self.in_planes,
210
+ kernel_size=7,
211
+ stride=2,
212
+ padding=3,
213
+ padding_mode="zeros",
214
+ )
215
+ self.relu1 = nn.ReLU(inplace=True)
216
+
217
+ self.shallow = False
218
+ if self.shallow:
219
+ self.layer1 = self._make_layer(64, stride=1)
220
+ self.layer2 = self._make_layer(96, stride=2)
221
+ self.layer3 = self._make_layer(128, stride=2)
222
+ self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
223
+ else:
224
+ if Embed3D:
225
+ self.conv_fuse = nn.Conv2d(64+63,
226
+ self.in_planes, kernel_size=3, padding=1)
227
+ self.layer1 = self._make_layer(64, stride=1)
228
+ self.layer2 = self._make_layer(96, stride=2)
229
+ self.layer3 = self._make_layer(128, stride=2)
230
+ self.layer4 = self._make_layer(128, stride=2)
231
+ self.conv2 = nn.Conv2d(
232
+ 128 + 128 + 96 + 64,
233
+ output_dim * 2,
234
+ kernel_size=3,
235
+ padding=1,
236
+ padding_mode="zeros",
237
+ )
238
+ self.relu2 = nn.ReLU(inplace=True)
239
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
240
+
241
+ self.dropout = None
242
+ if dropout > 0:
243
+ self.dropout = nn.Dropout2d(p=dropout)
244
+
245
+ for m in self.modules():
246
+ if isinstance(m, nn.Conv2d):
247
+ nn.init.kaiming_normal_(m.weight, mode="fan_out",
248
+ nonlinearity="relu")
249
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
250
+ if m.weight is not None:
251
+ nn.init.constant_(m.weight, 1)
252
+ if m.bias is not None:
253
+ nn.init.constant_(m.bias, 0)
254
+
255
+ def _make_layer(self, dim, stride=1):
256
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
257
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
258
+ layers = (layer1, layer2)
259
+
260
+ self.in_planes = dim
261
+ return nn.Sequential(*layers)
262
+
263
+ def forward(self, x, feat_PE=None):
264
+ _, _, H, W = x.shape
265
+
266
+ x = self.conv1(x)
267
+ x = self.norm1(x)
268
+ x = self.relu1(x)
269
+
270
+ if self.shallow:
271
+ a = self.layer1(x)
272
+ b = self.layer2(a)
273
+ c = self.layer3(b)
274
+ a = F.interpolate(
275
+ a,
276
+ (H // self.stride, W // self.stride),
277
+ mode="bilinear",
278
+ align_corners=True,
279
+ )
280
+ b = F.interpolate(
281
+ b,
282
+ (H // self.stride, W // self.stride),
283
+ mode="bilinear",
284
+ align_corners=True,
285
+ )
286
+ c = F.interpolate(
287
+ c,
288
+ (H // self.stride, W // self.stride),
289
+ mode="bilinear",
290
+ align_corners=True,
291
+ )
292
+ x = self.conv2(torch.cat([a, b, c], dim=1))
293
+ else:
294
+ if feat_PE is not None:
295
+ x = self.conv_fuse(torch.cat([x, feat_PE], dim=1))
296
+ a = self.layer1(x)
297
+ else:
298
+ a = self.layer1(x)
299
+ b = self.layer2(a)
300
+ c = self.layer3(b)
301
+ d = self.layer4(c)
302
+ a = F.interpolate(
303
+ a,
304
+ (H // self.stride, W // self.stride),
305
+ mode="bilinear",
306
+ align_corners=True,
307
+ )
308
+ b = F.interpolate(
309
+ b,
310
+ (H // self.stride, W // self.stride),
311
+ mode="bilinear",
312
+ align_corners=True,
313
+ )
314
+ c = F.interpolate(
315
+ c,
316
+ (H // self.stride, W // self.stride),
317
+ mode="bilinear",
318
+ align_corners=True,
319
+ )
320
+ d = F.interpolate(
321
+ d,
322
+ (H // self.stride, W // self.stride),
323
+ mode="bilinear",
324
+ align_corners=True,
325
+ )
326
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
327
+ x = self.norm2(x)
328
+ x = self.relu2(x)
329
+ x = self.conv3(x)
330
+
331
+ if self.training and self.dropout is not None:
332
+ x = self.dropout(x)
333
+ return x
334
+
335
+ class VitEncoder(nn.Module):
336
+ def __init__(self, input_dim=4, output_dim=128, stride=4):
337
+ super(VitEncoder, self).__init__()
338
+ self.vit = vitEnc(img_size=512,
339
+ depth=6, num_heads=8, in_chans=input_dim,
340
+ out_chans=output_dim,embed_dim=384).cuda()
341
+ self.stride = stride
342
+ def forward(self, x):
343
+ T, C, H, W = x.shape
344
+ x_resize = F.interpolate(x.view(-1, C, H, W), size=(512, 512),
345
+ mode='bilinear', align_corners=False)
346
+ x_resize = self.vit(x_resize)
347
+ x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride),
348
+ mode='bilinear', align_corners=False)
349
+ return x
350
+
351
+ class DPTEnc(nn.Module):
352
+ def __init__(self, input_dim=3, output_dim=128, stride=2):
353
+ super(DPTEnc, self).__init__()
354
+ self.dpt = DPTEncoder()
355
+ self.stride = stride
356
+ def forward(self, x):
357
+ T, C, H, W = x.shape
358
+ x = (x-0.5)/0.5
359
+ x_resize = F.interpolate(x.view(-1, C, H, W), size=(384, 384),
360
+ mode='bilinear', align_corners=False)
361
+ x_resize = self.dpt(x_resize)
362
+ x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride),
363
+ mode='bilinear', align_corners=False)
364
+ return x
365
+
366
+ # class DPT_DINOv2(nn.Module):
367
+ # def __init__(self, encoder='vits', features=64, out_channels=[48, 96, 192, 384],
368
+ # use_bn=True, use_clstoken=False, localhub=True, stride=2, enc_only=True):
369
+ # super(DPT_DINOv2, self).__init__()
370
+ # self.stride = stride
371
+ # self.enc_only = enc_only
372
+ # assert encoder in ['vits', 'vitb', 'vitl']
373
+
374
+ # if localhub:
375
+ # self.pretrained = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
376
+ # else:
377
+ # self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
378
+
379
+ # state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vits14_pretrain.pth")
380
+ # self.pretrained.load_state_dict(state_dict, strict=True)
381
+ # self.pretrained.requires_grad_(False)
382
+ # dim = self.pretrained.blocks[0].attn.qkv.in_features
383
+ # if enc_only == True:
384
+ # out_channels=[128, 128, 128, 128]
385
+
386
+ # self.DPThead = DPTHeadEnc(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
387
+
388
+
389
+ # def forward(self, x):
390
+ # mean_ = torch.tensor([0.485, 0.456, 0.406],
391
+ # device=x.device).view(1, 3, 1, 1)
392
+ # std_ = torch.tensor([0.229, 0.224, 0.225],
393
+ # device=x.device).view(1, 3, 1, 1)
394
+ # x = (x+1)/2
395
+ # x = (x - mean_)/std_
396
+ # h, w = x.shape[-2:]
397
+ # h_re, w_re = 560, 560
398
+ # x_resize = F.interpolate(x, size=(h_re, w_re),
399
+ # mode='bilinear', align_corners=False)
400
+ # with torch.no_grad():
401
+ # features = self.pretrained.get_intermediate_layers(x_resize, 4, return_class_token=True)
402
+ # patch_h, patch_w = h_re // 14, w_re // 14
403
+ # feat = self.DPThead(features, patch_h, patch_w, self.enc_only)
404
+ # feat = F.interpolate(feat, size=(h//self.stride, w//self.stride), mode="bilinear", align_corners=True)
405
+
406
+ # return feat
407
+
408
+
409
+ class VGG19(nn.Module):
410
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
411
+ super().__init__()
412
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
413
+ self.amp = amp
414
+ self.amp_dtype = amp_dtype
415
+
416
+ def forward(self, x, **kwargs):
417
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
418
+ feats = {}
419
+ scale = 1
420
+ for layer in self.layers:
421
+ if isinstance(layer, nn.MaxPool2d):
422
+ feats[scale] = x
423
+ scale = scale*2
424
+ x = layer(x)
425
+ return feats
426
+
427
+ class CNNandDinov2(nn.Module):
428
+ def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
429
+ super().__init__()
430
+ # in case the Internet connection is not stable, please load the DINOv2 locally
431
+ self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
432
+ 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
433
+
434
+ state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
435
+ self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
436
+
437
+
438
+ cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
439
+ self.cnn = VGG19(**cnn_kwargs)
440
+ self.amp = amp
441
+ self.amp_dtype = amp_dtype
442
+ if self.amp:
443
+ dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
444
+ self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
445
+
446
+
447
+ def train(self, mode: bool = True):
448
+ return self.cnn.train(mode)
449
+
450
+ def forward(self, x, upsample = False):
451
+ B,C,H,W = x.shape
452
+ feature_pyramid = self.cnn(x)
453
+
454
+ if not upsample:
455
+ with torch.no_grad():
456
+ if self.dinov2_vitl14[0].device != x.device:
457
+ self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
458
+ dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
459
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
460
+ del dinov2_features_16
461
+ feature_pyramid[16] = features_16
462
+ return feature_pyramid
463
+
464
+ class Dinov2(nn.Module):
465
+ def __init__(self, amp = True, amp_dtype = torch.float16):
466
+ super().__init__()
467
+ # in case the Internet connection is not stable, please load the DINOv2 locally
468
+ self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
469
+ 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
470
+
471
+ state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
472
+ self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
473
+
474
+ self.amp = amp
475
+ self.amp_dtype = amp_dtype
476
+ if self.amp:
477
+ self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
478
+
479
+ def forward(self, x, upsample = False):
480
+ B,C,H,W = x.shape
481
+ mean_ = torch.tensor([0.485, 0.456, 0.406],
482
+ device=x.device).view(1, 3, 1, 1)
483
+ std_ = torch.tensor([0.229, 0.224, 0.225],
484
+ device=x.device).view(1, 3, 1, 1)
485
+ x = (x+1)/2
486
+ x = (x - mean_)/std_
487
+ h_re, w_re = 560, 560
488
+ x_resize = F.interpolate(x, size=(h_re, w_re),
489
+ mode='bilinear', align_corners=True)
490
+ if not upsample:
491
+ with torch.no_grad():
492
+ dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
493
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
494
+ del dinov2_features_16
495
+ features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
496
+ return features_16
497
+
498
+ class AttnBlock(nn.Module):
499
+ """
500
+ A standard pre-norm Transformer block (self-attention followed by an MLP).
501
+ """
502
+
503
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
504
+ flash=False, **block_kwargs):
505
+ super().__init__()
506
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
507
+ self.flash=flash
508
+
509
+ self.attn = Attention(
510
+ hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
511
+ **block_kwargs
512
+ )
513
+
514
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
515
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
516
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
517
+ self.mlp = Mlp(
518
+ in_features=hidden_size,
519
+ hidden_features=mlp_hidden_dim,
520
+ act_layer=approx_gelu,
521
+ drop=0,
522
+ )
523
+ def forward(self, x):
524
+ x = x + self.attn(self.norm1(x))
525
+ x = x + self.mlp(self.norm2(x))
526
+ return x
527
+
528
+ class CrossAttnBlock(nn.Module):
529
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0,
530
+ flash=True, **block_kwargs):
531
+ super().__init__()
532
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
533
+ self.norm_context = nn.LayerNorm(hidden_size)
534
+
535
+ self.cross_attn = Attention(
536
+ hidden_size, context_dim=context_dim,
537
+ num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash
538
+
539
+ )
540
+
541
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
542
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
543
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
544
+ self.mlp = Mlp(
545
+ in_features=hidden_size,
546
+ hidden_features=mlp_hidden_dim,
547
+ act_layer=approx_gelu,
548
+ drop=0,
549
+ )
550
+
551
+ def forward(self, x, context):
552
+ with autocast():
553
+ x = x + self.cross_attn(
554
+ self.norm1(x), self.norm_context(context)
555
+ )
556
+ x = x + self.mlp(self.norm2(x))
557
+ return x
558
+
559
+
560
+ def bilinear_sampler(img, coords, mode="bilinear", mask=False):
561
+ """Wrapper for grid_sample, uses pixel coordinates"""
562
+ H, W = img.shape[-2:]
563
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
564
+ # go to 0,1 then 0,2 then -1,1
565
+ xgrid = 2 * xgrid / (W - 1) - 1
566
+ ygrid = 2 * ygrid / (H - 1) - 1
567
+
568
+ grid = torch.cat([xgrid, ygrid], dim=-1)
569
+ img = F.grid_sample(img, grid, align_corners=True)
570
+
571
+ if mask:
572
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
573
+ return img, mask.float()
574
+
575
+ return img
576
+
577
+
578
+ class CorrBlock:
579
+ def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
580
+ B, S, C, H_prev, W_prev = fmaps.shape
581
+ self.S, self.C, self.H, self.W = S, C, H_prev, W_prev
582
+
583
+ self.num_levels = num_levels
584
+ self.radius = radius
585
+ self.fmaps_pyramid = []
586
+ self.depth_pyramid = []
587
+ self.fmaps_pyramid.append(fmaps)
588
+ if depths_dnG is not None:
589
+ self.depth_pyramid.append(depths_dnG)
590
+ for i in range(self.num_levels - 1):
591
+ if depths_dnG is not None:
592
+ depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
593
+ depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
594
+ _, _, H, W = depths_dnG_.shape
595
+ depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
596
+ self.depth_pyramid.append(depths_dnG)
597
+ fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
598
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
599
+ _, _, H, W = fmaps_.shape
600
+ fmaps = fmaps_.reshape(B, S, C, H, W)
601
+ H_prev = H
602
+ W_prev = W
603
+ self.fmaps_pyramid.append(fmaps)
604
+
605
+ def sample(self, coords):
606
+ r = self.radius
607
+ B, S, N, D = coords.shape
608
+ assert D == 2
609
+
610
+ H, W = self.H, self.W
611
+ out_pyramid = []
612
+ for i in range(self.num_levels):
613
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
614
+ _, _, _, H, W = corrs.shape
615
+
616
+ dx = torch.linspace(-r, r, 2 * r + 1)
617
+ dy = torch.linspace(-r, r, 2 * r + 1)
618
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
619
+ coords.device
620
+ )
621
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
622
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
623
+ coords_lvl = centroid_lvl + delta_lvl
624
+ corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
625
+ corrs = corrs.view(B, S, N, -1)
626
+ out_pyramid.append(corrs)
627
+
628
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
629
+ return out.contiguous().float()
630
+
631
+ def corr(self, targets):
632
+ B, S, N, C = targets.shape
633
+ assert C == self.C
634
+ assert S == self.S
635
+
636
+ fmap1 = targets
637
+
638
+ self.corrs_pyramid = []
639
+ for fmaps in self.fmaps_pyramid:
640
+ _, _, _, H, W = fmaps.shape
641
+ fmap2s = fmaps.view(B, S, C, H * W)
642
+ corrs = torch.matmul(fmap1, fmap2s)
643
+ corrs = corrs.view(B, S, N, H, W)
644
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
645
+ self.corrs_pyramid.append(corrs)
646
+
647
+ def corr_sample(self, targets, coords, coords_dp=None):
648
+ B, S, N, C = targets.shape
649
+ r = self.radius
650
+ Dim_c = (2*r+1)**2
651
+ assert C == self.C
652
+ assert S == self.S
653
+
654
+ out_pyramid = []
655
+ out_pyramid_dp = []
656
+ for i in range(self.num_levels):
657
+ dx = torch.linspace(-r, r, 2 * r + 1)
658
+ dy = torch.linspace(-r, r, 2 * r + 1)
659
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
660
+ coords.device
661
+ )
662
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
663
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
664
+ coords_lvl = centroid_lvl + delta_lvl
665
+ fmaps = self.fmaps_pyramid[i]
666
+ _, _, _, H, W = fmaps.shape
667
+ fmap2s = fmaps.view(B*S, C, H, W)
668
+ if len(self.depth_pyramid)>0:
669
+ depths_dnG_i = self.depth_pyramid[i]
670
+ depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
671
+ dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
672
+ dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
673
+ out_pyramid_dp.append(dp_corrs)
674
+ fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
675
+ fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
676
+ corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
677
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
678
+ corrs = corrs.view(B, S, N, -1)
679
+ out_pyramid.append(corrs)
680
+
681
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
682
+ if len(self.depth_pyramid)>0:
683
+ out_dp = torch.cat(out_pyramid_dp, dim=-1)
684
+ self.fcorrD = out_dp.contiguous().float()
685
+ else:
686
+ self.fcorrD = torch.zeros_like(out).contiguous().float()
687
+ return out.contiguous().float()
688
+
689
+
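+ # Illustrative usage (not part of the original file): build the correlation pyramid for
+ # per-point query features, then read it out around the current track positions.
+ # fmaps = torch.rand(1, 8, 128, 46, 64)          # (B, S, C, H, W) feature maps
+ # targets = torch.rand(1, 8, 20, 128)            # (B, S, N, C) query features
+ # coords = torch.rand(1, 8, 20, 2) * 30          # (B, S, N, 2) positions in feature-map pixels
+ # cb = CorrBlock(fmaps, num_levels=4, radius=3)
+ # cb.corr(targets)
+ # fcorrs = cb.sample(coords)                     # (1, 8, 20, 4 * (2*3+1)**2)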
690
+ class EUpdateFormer(nn.Module):
691
+ """
692
+ Transformer model that updates track estimates.
693
+ """
694
+
695
+ def __init__(
696
+ self,
697
+ space_depth=12,
698
+ time_depth=12,
699
+ input_dim=320,
700
+ hidden_size=384,
701
+ num_heads=8,
702
+ output_dim=130,
703
+ mlp_ratio=4.0,
704
+ vq_depth=3,
705
+ add_space_attn=True,
706
+ add_time_attn=True,
707
+ flash=True
708
+ ):
709
+ super().__init__()
710
+ self.out_channels = 2
711
+ self.num_heads = num_heads
712
+ self.hidden_size = hidden_size
713
+ self.add_space_attn = add_space_attn
714
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
715
+ self.flash = flash
716
+ self.flow_head = nn.Sequential(
717
+ nn.Linear(hidden_size, output_dim, bias=True),
718
+ nn.ReLU(inplace=True),
719
+ nn.Linear(output_dim, output_dim, bias=True),
720
+ nn.ReLU(inplace=True),
721
+ nn.Linear(output_dim, output_dim, bias=True)
722
+ )
723
+
724
+ cross_attn_kwargs = {
725
+ "d_model": 384,
726
+ "nhead": 4,
727
+ "layer_names": ['self', 'cross'] * 3,
728
+ }
729
+ self.gnn = LocalFeatureTransformer(cross_attn_kwargs)
730
+
731
+ # Attention Modules in the temporal dimension
732
+ self.time_blocks = nn.ModuleList(
733
+ [
734
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash) if add_time_attn else nn.Identity()
735
+ for _ in range(time_depth)
736
+ ]
737
+ )
738
+
739
+ if add_space_attn:
740
+ self.space_blocks = nn.ModuleList(
741
+ [
742
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash)
743
+ for _ in range(space_depth)
744
+ ]
745
+ )
746
+ assert len(self.time_blocks) >= len(self.space_blocks)
747
+
748
+ # Placeholder for the rigid transformation
749
+ self.RigidProj = nn.Linear(self.hidden_size, 128, bias=True)
750
+ self.Proj = nn.Linear(self.hidden_size, 128, bias=True)
751
+
752
+ self.se3_dec = nn.Linear(384, 3, bias=True)
753
+ self.initialize_weights()
754
+
755
+ def initialize_weights(self):
756
+ def _basic_init(module):
757
+ if isinstance(module, nn.Linear):
758
+ torch.nn.init.xavier_uniform_(module.weight)
759
+ if module.bias is not None:
760
+ nn.init.constant_(module.bias, 0)
761
+
762
+ self.apply(_basic_init)
763
+
764
+ def forward(self, input_tensor, se3_feature):
765
+ """ Updating with Transformer
766
+
767
+ Args:
768
+ input_tensor: B, N, T, C
769
+ se3_feature: SE(3) feature tokens, cross-attended with the track tokens via self.gnn
770
+ """
771
+ B, N, T, C = input_tensor.shape
772
+ x = self.input_transform(input_tensor)
773
+ tokens = x
774
+ K = 0
775
+ j = 0
776
+ for i in range(len(self.time_blocks)):
777
+ tokens_time = rearrange(tokens, "b n t c -> (b n) t c", b=B, t=T, n=N+K)
778
+ tokens_time = self.time_blocks[i](tokens_time)
779
+ tokens = rearrange(tokens_time, "(b n) t c -> b n t c ", b=B, t=T, n=N+K)
780
+ if self.add_space_attn and (
781
+ i % (len(self.time_blocks) // len(self.space_blocks)) == 0
782
+ ):
783
+ tokens_space = rearrange(tokens, "b n t c -> (b t) n c ", b=B, t=T, n=N)
784
+ tokens_space = self.space_blocks[j](tokens_space)
785
+ tokens = rearrange(tokens_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
786
+ j += 1
787
+
788
+ B, N, S, _ = tokens.shape
789
+ feat0, feat1 = self.gnn(tokens.view(B*N*S, -1)[None,...], se3_feature[None, ...])
790
+
791
+ so3 = F.tanh(self.se3_dec(feat0.view(B*N*S, -1)[None,...].view(B, N, S, -1))/100)
792
+ flow = self.flow_head(feat0.view(B,N,S,-1))
793
+
794
+ return flow, _, _, feat1, so3
795
+
796
+
797
+ class FusionFormer(nn.Module):
798
+ """
799
+ Fuse the feature tracks info with the low rank motion tokens
800
+ """
801
+ def __init__(
802
+ self,
803
+ d_model=64,
804
+ nhead=8,
805
+ attn_iters=4,
806
+ mlp_ratio=4.0,
807
+ flash=False,
808
+ input_dim=35,
809
+ output_dim=384+3,
810
+ ):
811
+ super().__init__()
812
+ self.flash = flash
813
+ self.in_proj = nn.ModuleList(
814
+ [
815
+ nn.Linear(input_dim, d_model)
816
+ for _ in range(2)
817
+ ]
818
+ )
819
+ self.out_proj = nn.Linear(d_model, output_dim, bias=True)
820
+ self.time_blocks = nn.ModuleList(
821
+ [
822
+ CrossAttnBlock(d_model, d_model, nhead, mlp_ratio=mlp_ratio)
823
+ for _ in range(attn_iters)
824
+ ]
825
+ )
826
+ self.space_blocks = nn.ModuleList(
827
+ [
828
+ AttnBlock(d_model, nhead, mlp_ratio=mlp_ratio, flash=self.flash)
829
+ for _ in range(attn_iters)
830
+ ]
831
+ )
832
+
833
+ self.initialize_weights()
834
+
835
+ def initialize_weights(self):
836
+ def _basic_init(module):
837
+ if isinstance(module, nn.Linear):
838
+ torch.nn.init.xavier_uniform_(module.weight)
839
+ if module.bias is not None:
840
+ nn.init.constant_(module.bias, 0)
841
+ self.apply(_basic_init)
842
+ self.out_proj.weight.data.fill_(0)
843
+ self.out_proj.bias.data.fill_(0)
844
+
845
+ def forward(self, x, token_cls):
846
+ """ Fuse the feature tracks info with the low rank motion tokens
847
+
848
+ Args:
849
+ x: B, S, N, C
850
+ token_cls: B, T, N, C
851
+
852
+ """
853
+ B, S, N, C = x.shape
854
+ _, T, _, _ = token_cls.shape
855
+ x = self.in_proj[0](x)
856
+ token_cls = self.in_proj[1](token_cls)
857
+ token_cls = rearrange(token_cls, 'b t n c -> (b n) t c')
858
+
859
+ for i in range(len(self.space_blocks)):
860
+ x = rearrange(x, 'b s n c -> (b n) s c')
861
+ x = self.time_blocks[i](x, token_cls)
862
+ x = self.space_blocks[i](x.permute(1,0,2))
863
+ x = rearrange(x, '(b s) n c -> b s n c', b=B, s=S, n=N)
864
+
865
+ x = self.out_proj(x)
866
+ delta_xyz = x[..., :3]
867
+ feat_traj = x[..., 3:]
868
+ return delta_xyz, feat_traj
869
+
870
+ class Lie():
871
+ """
872
+ Lie algebra for SO(3) and SE(3) operations in PyTorch
873
+ """
874
+
875
+ def so3_to_SO3(self,w): # [...,3]
876
+ wx = self.skew_symmetric(w)
877
+ theta = w.norm(dim=-1)[...,None,None]
878
+ I = torch.eye(3,device=w.device,dtype=torch.float32)
879
+ A = self.taylor_A(theta)
880
+ B = self.taylor_B(theta)
881
+ R = I+A*wx+B*wx@wx
882
+ return R
883
+
884
+ def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
885
+ trace = R[...,0,0]+R[...,1,1]+R[...,2,2]
886
+ theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi
887
+ lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird
888
+ w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]
889
+ w = torch.stack([w0,w1,w2],dim=-1)
890
+ return w
891
+
892
+ def se3_to_SE3(self,wu): # [...,3]
893
+ w,u = wu.split([3,3],dim=-1)
894
+ wx = self.skew_symmetric(w)
895
+ theta = w.norm(dim=-1)[...,None,None]
896
+ I = torch.eye(3,device=w.device,dtype=torch.float32)
897
+ A = self.taylor_A(theta)
898
+ B = self.taylor_B(theta)
899
+ C = self.taylor_C(theta)
900
+ R = I+A*wx+B*wx@wx
901
+ V = I+B*wx+C*wx@wx
902
+ Rt = torch.cat([R,(V@u[...,None])],dim=-1)
903
+ return Rt
904
+
905
+ def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
906
+ R,t = Rt.split([3,1],dim=-1)
907
+ w = self.SO3_to_so3(R)
908
+ wx = self.skew_symmetric(w)
909
+ theta = w.norm(dim=-1)[...,None,None]
910
+ I = torch.eye(3,device=w.device,dtype=torch.float32)
911
+ A = self.taylor_A(theta)
912
+ B = self.taylor_B(theta)
913
+ invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
914
+ u = (invV@t)[...,0]
915
+ wu = torch.cat([w,u],dim=-1)
916
+ return wu
917
+
918
+ def skew_symmetric(self,w):
919
+ w0,w1,w2 = w.unbind(dim=-1)
920
+ O = torch.zeros_like(w0)
921
+ wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),
922
+ torch.stack([w2,O,-w0],dim=-1),
923
+ torch.stack([-w1,w0,O],dim=-1)],dim=-2)
924
+ return wx
925
+
926
+ def taylor_A(self,x,nth=10):
927
+ # Taylor expansion of sin(x)/x
928
+ ans = torch.zeros_like(x)
929
+ denom = 1.
930
+ for i in range(nth+1):
931
+ if i>0: denom *= (2*i)*(2*i+1)
932
+ ans = ans+(-1)**i*x**(2*i)/denom
933
+ return ans
934
+ def taylor_B(self,x,nth=10):
935
+ # Taylor expansion of (1-cos(x))/x**2
936
+ ans = torch.zeros_like(x)
937
+ denom = 1.
938
+ for i in range(nth+1):
939
+ denom *= (2*i+1)*(2*i+2)
940
+ ans = ans+(-1)**i*x**(2*i)/denom
941
+ return ans
942
+ def taylor_C(self,x,nth=10):
943
+ # Taylor expansion of (x-sin(x))/x**3
944
+ ans = torch.zeros_like(x)
945
+ denom = 1.
946
+ for i in range(nth+1):
947
+ denom *= (2*i+2)*(2*i+3)
948
+ ans = ans+(-1)**i*x**(2*i)/denom
949
+ return ans
950
+
951
+
952
+
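+ # Illustrative round-trip (not in the original file): the exponential and logarithm maps
+ # should invert each other for small rotations (uses the numpy import added above).
+ # lie = Lie()
+ # w = torch.tensor([[0.10, -0.20, 0.05]])        # so(3) axis-angle vector
+ # R = lie.so3_to_SO3(w)                          # (1, 3, 3) rotation matrix
+ # w_back = lie.SO3_to_so3(R)                     # approximately equal to w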
953
+ def pix2cam(coords,
954
+ intr):
955
+ """
956
+ Args:
957
+ coords: [B, T, N, 3]
958
+ intr: [B, T, 3, 3]
959
+ """
960
+ coords=coords.detach()
961
+ B, S, N, _, = coords.shape
962
+ xy_src = coords.reshape(B*S*N, 3)
963
+ intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3)
964
+ xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1)
965
+ xyz_src = (torch.inverse(intr)@xy_src[...,None])[...,0]
966
+ dp_pred = coords[..., 2]
967
+ xyz_src_ = (xyz_src*(dp_pred.reshape(S*N, 1)))
968
+ xyz_src_ = xyz_src_.reshape(B, S, N, 3)
969
+ return xyz_src_
970
+
971
+ def cam2pix(coords,
972
+ intr):
973
+ """
974
+ Args:
975
+ coords: [B, T, N, 3]
976
+ intr: [B, T, 3, 3]
977
+ """
978
+ coords=coords.detach()
979
+ B, S, N, _, = coords.shape
980
+ xy_src = coords.reshape(B*S*N, 3).clone()
981
+ intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3)
982
+ xy_src = xy_src / (xy_src[..., 2:]+1e-5)
983
+ xyz_src = (intr@xy_src[...,None])[...,0]
984
+ dp_pred = coords[..., 2]
985
+ xyz_src[...,2] *= dp_pred.reshape(S*N)
986
+ xyz_src = xyz_src.reshape(B, S, N, 3)
987
+ return xyz_src
988
+
989
+ def edgeMat(traj3d):
990
+ """
991
+ Args:
992
+ traj3d: [B, T, N, 3]
993
+ """
994
+ B, T, N, _ = traj3d.shape
995
+ traj3d = traj3d
996
+ traj3d = traj3d.view(B, T, N, 3)
997
+ traj3d = traj3d[..., None, :] - traj3d[..., None, :, :] # B, T, N, N, 3
998
+ edgeMat = traj3d.norm(dim=-1) # B, T, N, N
999
+ return edgeMat
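The Lie helper above converts between axis-angle vectors and rotation matrices using Taylor-expanded Rodrigues terms. A minimal round-trip sanity check, assuming the Lie class above is in scope (the tensor values are purely illustrative):

import torch

lie = Lie()
w = torch.tensor([[0.10, -0.20, 0.05]])        # [..., 3] axis-angle vector (small rotation)
R = lie.so3_to_SO3(w)                          # [..., 3, 3] rotation matrix
w_back = lie.SO3_to_so3(R)                     # recover the axis-angle vector
print(torch.allclose(w, w_back, atol=1e-4))    # expected: True for angles well below pi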
models/spatracker/models/core/spatracker/dpt/__init__.py ADDED
File without changes
models/spatracker/models/core/spatracker/dpt/base_model.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device("cpu"))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
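BaseModel.load accepts either a bare state_dict or a checkpoint dict that also carries optimizer state (in which case the "model" entry is used). A small usage sketch with a hypothetical subclass (names are illustrative, not from the repo):

import torch
import torch.nn as nn

class TinyHead(BaseModel):                       # hypothetical subclass for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

model = TinyHead()
torch.save(model.state_dict(), "/tmp/tiny_head.pt")   # a plain state_dict checkpoint
model.load("/tmp/tiny_head.pt")                        # loads on CPU, then load_state_dict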
models/spatracker/models/core/spatracker/dpt/blocks.py ADDED
@@ -0,0 +1,394 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from models.spatracker.models.core.spatracker.dpt.vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ _make_pretrained_vit_tiny
10
+ )
11
+
12
+
13
+ def _make_encoder(
14
+ backbone,
15
+ features,
16
+ use_pretrained,
17
+ groups=1,
18
+ expand=False,
19
+ exportable=True,
20
+ hooks=None,
21
+ use_vit_only=False,
22
+ use_readout="ignore",
23
+ enable_attention_hooks=False,
24
+ ):
25
+ if backbone == "vitl16_384":
26
+ pretrained = _make_pretrained_vitl16_384(
27
+ use_pretrained,
28
+ hooks=hooks,
29
+ use_readout=use_readout,
30
+ enable_attention_hooks=enable_attention_hooks,
31
+ )
32
+ scratch = _make_scratch(
33
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
34
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
35
+ elif backbone == "vitb_rn50_384":
36
+ pretrained = _make_pretrained_vitb_rn50_384(
37
+ use_pretrained,
38
+ hooks=hooks,
39
+ use_vit_only=use_vit_only,
40
+ use_readout=use_readout,
41
+ enable_attention_hooks=enable_attention_hooks,
42
+ )
43
+ scratch = _make_scratch(
44
+ [256, 512, 768, 768], features, groups=groups, expand=expand
45
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
46
+ elif backbone == "vitb16_384":
47
+ pretrained = _make_pretrained_vitb16_384(
48
+ use_pretrained,
49
+ hooks=hooks,
50
+ use_readout=use_readout,
51
+ enable_attention_hooks=enable_attention_hooks,
52
+ )
53
+ scratch = _make_scratch(
54
+ [96, 192, 384, 768], features, groups=groups, expand=expand
55
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
56
+ elif backbone == "resnext101_wsl":
57
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
58
+ scratch = _make_scratch(
59
+ [256, 512, 1024, 2048], features, groups=groups, expand=expand
60
+ ) # efficientnet_lite3
61
+ elif backbone == "vit_tiny_r_s16_p8_384":
62
+ pretrained = _make_pretrained_vit_tiny(
63
+ use_pretrained,
64
+ hooks=hooks,
65
+ use_readout=use_readout,
66
+ enable_attention_hooks=enable_attention_hooks,
67
+ )
68
+ scratch = _make_scratch(
69
+ [96, 192, 384, 768], features, groups=groups, expand=expand
70
+ )
71
+ else:
72
+ print(f"Backbone '{backbone}' not implemented")
73
+ assert False
74
+
75
+ return pretrained, scratch
76
+
77
+
78
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
79
+ scratch = nn.Module()
80
+
81
+ out_shape1 = out_shape
82
+ out_shape2 = out_shape
83
+ out_shape3 = out_shape
84
+ out_shape4 = out_shape
85
+ if expand == True:
86
+ out_shape1 = out_shape
87
+ out_shape2 = out_shape * 2
88
+ out_shape3 = out_shape * 4
89
+ out_shape4 = out_shape * 8
90
+
91
+ scratch.layer1_rn = nn.Conv2d(
92
+ in_shape[0],
93
+ out_shape1,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=False,
98
+ groups=groups,
99
+ )
100
+ scratch.layer2_rn = nn.Conv2d(
101
+ in_shape[1],
102
+ out_shape2,
103
+ kernel_size=3,
104
+ stride=1,
105
+ padding=1,
106
+ bias=False,
107
+ groups=groups,
108
+ )
109
+ scratch.layer3_rn = nn.Conv2d(
110
+ in_shape[2],
111
+ out_shape3,
112
+ kernel_size=3,
113
+ stride=1,
114
+ padding=1,
115
+ bias=False,
116
+ groups=groups,
117
+ )
118
+ scratch.layer4_rn = nn.Conv2d(
119
+ in_shape[3],
120
+ out_shape4,
121
+ kernel_size=3,
122
+ stride=1,
123
+ padding=1,
124
+ bias=False,
125
+ groups=groups,
126
+ )
127
+
128
+ return scratch
129
+
130
+
131
+ def _make_resnet_backbone(resnet):
132
+ pretrained = nn.Module()
133
+ pretrained.layer1 = nn.Sequential(
134
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
135
+ )
136
+
137
+ pretrained.layer2 = resnet.layer2
138
+ pretrained.layer3 = resnet.layer3
139
+ pretrained.layer4 = resnet.layer4
140
+
141
+ return pretrained
142
+
143
+
144
+ def _make_pretrained_resnext101_wsl(use_pretrained):
145
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
146
+ return _make_resnet_backbone(resnet)
147
+
148
+
149
+ class Interpolate(nn.Module):
150
+ """Interpolation module."""
151
+
152
+ def __init__(self, scale_factor, mode, align_corners=False):
153
+ """Init.
154
+
155
+ Args:
156
+ scale_factor (float): scaling
157
+ mode (str): interpolation mode
158
+ """
159
+ super(Interpolate, self).__init__()
160
+
161
+ self.interp = nn.functional.interpolate
162
+ self.scale_factor = scale_factor
163
+ self.mode = mode
164
+ self.align_corners = align_corners
165
+
166
+ def forward(self, x):
167
+ """Forward pass.
168
+
169
+ Args:
170
+ x (tensor): input
171
+
172
+ Returns:
173
+ tensor: interpolated data
174
+ """
175
+
176
+ x = self.interp(
177
+ x,
178
+ scale_factor=self.scale_factor,
179
+ mode=self.mode,
180
+ align_corners=self.align_corners,
181
+ )
182
+
183
+ return x
184
+
185
+
186
+ class ResidualConvUnit(nn.Module):
187
+ """Residual convolution module."""
188
+
189
+ def __init__(self, features):
190
+ """Init.
191
+
192
+ Args:
193
+ features (int): number of features
194
+ """
195
+ super().__init__()
196
+
197
+ self.conv1 = nn.Conv2d(
198
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
199
+ )
200
+
201
+ self.conv2 = nn.Conv2d(
202
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
203
+ )
204
+
205
+ self.relu = nn.ReLU(inplace=True)
206
+
207
+ def forward(self, x):
208
+ """Forward pass.
209
+
210
+ Args:
211
+ x (tensor): input
212
+
213
+ Returns:
214
+ tensor: output
215
+ """
216
+ out = self.relu(x)
217
+ out = self.conv1(out)
218
+ out = self.relu(out)
219
+ out = self.conv2(out)
220
+
221
+ return out + x
222
+
223
+
224
+ class FeatureFusionBlock(nn.Module):
225
+ """Feature fusion block."""
226
+
227
+ def __init__(self, features):
228
+ """Init.
229
+
230
+ Args:
231
+ features (int): number of features
232
+ """
233
+ super(FeatureFusionBlock, self).__init__()
234
+
235
+ self.resConfUnit1 = ResidualConvUnit(features)
236
+ self.resConfUnit2 = ResidualConvUnit(features)
237
+
238
+ def forward(self, *xs):
239
+ """Forward pass.
240
+
241
+ Returns:
242
+ tensor: output
243
+ """
244
+ output = xs[0]
245
+
246
+ if len(xs) == 2:
247
+ output += self.resConfUnit1(xs[1])
248
+
249
+ output = self.resConfUnit2(output)
250
+
251
+ output = nn.functional.interpolate(
252
+ output, scale_factor=2, mode="bilinear", align_corners=True
253
+ )
254
+
255
+ return output
256
+
257
+
258
+ class ResidualConvUnit_custom(nn.Module):
259
+ """Residual convolution module."""
260
+
261
+ def __init__(self, features, activation, bn):
262
+ """Init.
263
+
264
+ Args:
265
+ features (int): number of features
266
+ """
267
+ super().__init__()
268
+
269
+ self.bn = bn
270
+
271
+ self.groups = 1
272
+
273
+ self.conv1 = nn.Conv2d(
274
+ features,
275
+ features,
276
+ kernel_size=3,
277
+ stride=1,
278
+ padding=1,
279
+ bias=not self.bn,
280
+ groups=self.groups,
281
+ )
282
+
283
+ self.conv2 = nn.Conv2d(
284
+ features,
285
+ features,
286
+ kernel_size=3,
287
+ stride=1,
288
+ padding=1,
289
+ bias=not self.bn,
290
+ groups=self.groups,
291
+ )
292
+
293
+ if self.bn == True:
294
+ self.bn1 = nn.BatchNorm2d(features)
295
+ self.bn2 = nn.BatchNorm2d(features)
296
+
297
+ self.activation = activation
298
+
299
+ self.skip_add = nn.quantized.FloatFunctional()
300
+
301
+ def forward(self, x):
302
+ """Forward pass.
303
+
304
+ Args:
305
+ x (tensor): input
306
+
307
+ Returns:
308
+ tensor: output
309
+ """
310
+
311
+ out = self.activation(x)
312
+ out = self.conv1(out)
313
+ if self.bn == True:
314
+ out = self.bn1(out)
315
+
316
+ out = self.activation(out)
317
+ out = self.conv2(out)
318
+ if self.bn == True:
319
+ out = self.bn2(out)
320
+
321
+ if self.groups > 1:
322
+ out = self.conv_merge(out)
323
+
324
+ return self.skip_add.add(out, x)
325
+
326
+ # return out + x
327
+
328
+
329
+ class FeatureFusionBlock_custom(nn.Module):
330
+ """Feature fusion block."""
331
+
332
+ def __init__(
333
+ self,
334
+ features,
335
+ activation,
336
+ deconv=False,
337
+ bn=False,
338
+ expand=False,
339
+ align_corners=True,
340
+ ):
341
+ """Init.
342
+
343
+ Args:
344
+ features (int): number of features
345
+ """
346
+ super(FeatureFusionBlock_custom, self).__init__()
347
+
348
+ self.deconv = deconv
349
+ self.align_corners = align_corners
350
+
351
+ self.groups = 1
352
+
353
+ self.expand = expand
354
+ out_features = features
355
+ if self.expand == True:
356
+ out_features = features // 2
357
+
358
+ self.out_conv = nn.Conv2d(
359
+ features,
360
+ out_features,
361
+ kernel_size=1,
362
+ stride=1,
363
+ padding=0,
364
+ bias=True,
365
+ groups=1,
366
+ )
367
+
368
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
369
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
370
+
371
+ self.skip_add = nn.quantized.FloatFunctional()
372
+
373
+ def forward(self, *xs):
374
+ """Forward pass.
375
+
376
+ Returns:
377
+ tensor: output
378
+ """
379
+ output = xs[0]
380
+
381
+ if len(xs) == 2:
382
+ res = self.resConfUnit1(xs[1])
383
+ output = self.skip_add.add(output, res)
384
+ # output += res
385
+
386
+ output = self.resConfUnit2(output)
387
+
388
+ output = nn.functional.interpolate(
389
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
390
+ )
391
+
392
+ output = self.out_conv(output)
393
+
394
+ return output
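FeatureFusionBlock_custom adds an optional skip connection through a residual conv unit, upsamples the result by 2x, and projects it with a 1x1 conv. A shape-only sketch with dummy tensors (not tied to any checkpoint):

import torch
import torch.nn as nn

fuse = FeatureFusionBlock_custom(features=64, activation=nn.ReLU(False), bn=False)
coarse = torch.randn(1, 64, 24, 24)    # decoder feature map
skip = torch.randn(1, 64, 24, 24)      # encoder skip connection at the same resolution
out = fuse(coarse, skip)
print(out.shape)                        # torch.Size([1, 64, 48, 48]): fused, then upsampled 2x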
models/spatracker/models/core/spatracker/dpt/midas_net.py ADDED
@@ -0,0 +1,77 @@
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from models.spatracker.models.core.spatracker.dpt.base_model import BaseModel
9
+ from models.spatracker.models.core.spatracker.dpt.blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_large(BaseModel):
13
+ """Network for monocular depth estimation."""
14
+
15
+ def __init__(self, path=None, features=256, non_negative=True):
16
+ """Init.
17
+
18
+ Args:
19
+ path (str, optional): Path to saved model. Defaults to None.
20
+ features (int, optional): Number of features. Defaults to 256.
21
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
22
+ """
23
+ print("Loading weights: ", path)
24
+
25
+ super(MidasNet_large, self).__init__()
26
+
27
+ use_pretrained = False if path is None else True
28
+
29
+ self.pretrained, self.scratch = _make_encoder(
30
+ backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
31
+ )
32
+
33
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
36
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
37
+
38
+ self.scratch.output_conv = nn.Sequential(
39
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
40
+ Interpolate(scale_factor=2, mode="bilinear"),
41
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
42
+ nn.ReLU(True),
43
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
44
+ nn.ReLU(True) if non_negative else nn.Identity(),
45
+ )
46
+
47
+ if path:
48
+ self.load(path)
49
+
50
+ def forward(self, x):
51
+ """Forward pass.
52
+
53
+ Args:
54
+ x (tensor): input data (image)
55
+
56
+ Returns:
57
+ tensor: depth
58
+ """
59
+
60
+ layer_1 = self.pretrained.layer1(x)
61
+ layer_2 = self.pretrained.layer2(layer_1)
62
+ layer_3 = self.pretrained.layer3(layer_2)
63
+ layer_4 = self.pretrained.layer4(layer_3)
64
+
65
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
66
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
67
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
68
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
69
+
70
+ path_4 = self.scratch.refinenet4(layer_4_rn)
71
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
72
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
73
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
74
+
75
+ out = self.scratch.output_conv(path_1)
76
+
77
+ return torch.squeeze(out, dim=1)
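MidasNet_large wires the ResNeXt101-WSL encoder into four fusion blocks and a depth head. The sketch below is illustrative only: constructing the model fetches the backbone through torch.hub, so it assumes network access or a cached checkpoint.

import torch

net = MidasNet_large(path=None, features=256, non_negative=True)
net.eval()
with torch.no_grad():
    depth = net(torch.randn(1, 3, 384, 384))    # normalized RGB batch
print(depth.shape)                               # torch.Size([1, 384, 384]): one depth map per image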
models/spatracker/models/core/spatracker/dpt/models.py ADDED
@@ -0,0 +1,231 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from models.spatracker.models.core.spatracker.dpt.base_model import BaseModel
6
+ from models.spatracker.models.core.spatracker.dpt.blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=True,
35
+ enable_attention_hooks=False,
36
+ ):
37
+
38
+ super(DPT, self).__init__()
39
+
40
+ self.channels_last = channels_last
41
+
42
+ hooks = {
43
+ "vitb_rn50_384": [0, 1, 8, 11],
44
+ "vitb16_384": [2, 5, 8, 11],
45
+ "vitl16_384": [5, 11, 17, 23],
46
+ "vit_tiny_r_s16_p8_384": [0, 1, 2, 3],
47
+ }
48
+
49
+ # Instantiate backbone and reassemble blocks
50
+ self.pretrained, self.scratch = _make_encoder(
51
+ backbone,
52
+ features,
53
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
54
+ groups=1,
55
+ expand=False,
56
+ exportable=False,
57
+ hooks=hooks[backbone],
58
+ use_readout=readout,
59
+ enable_attention_hooks=enable_attention_hooks,
60
+ )
61
+
62
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
63
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
64
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
65
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
66
+
67
+ self.scratch.output_conv = head
68
+
69
+ self.proj_out = nn.Sequential(
70
+ nn.Conv2d(
71
+ 256+512+384+384,
72
+ 256,
73
+ kernel_size=3,
74
+ padding=1,
75
+ padding_mode="zeros",
76
+ ),
77
+ nn.BatchNorm2d(128 * 2),
78
+ nn.ReLU(True),
79
+ nn.Conv2d(
80
+ 128 * 2,
81
+ 128,
82
+ kernel_size=3,
83
+ padding=1,
84
+ padding_mode="zeros",
85
+ )
86
+ )
87
+
88
+
89
+ def forward(self, x, only_enc=False):
90
+ if self.channels_last == True:
91
+ x.contiguous(memory_format=torch.channels_last)
92
+ if only_enc:
93
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
94
+ a = (layer_1)
95
+ b = (
96
+ F.interpolate(
97
+ layer_2,
98
+ scale_factor=2,
99
+ mode="bilinear",
100
+ align_corners=True,
101
+ )
102
+ )
103
+ c = (
104
+ F.interpolate(
105
+ layer_3,
106
+ scale_factor=8,
107
+ mode="bilinear",
108
+ align_corners=True,
109
+ )
110
+ )
111
+ d = (
112
+ F.interpolate(
113
+ layer_4,
114
+ scale_factor=16,
115
+ mode="bilinear",
116
+ align_corners=True,
117
+ )
118
+ )
119
+ x = self.proj_out(torch.cat([a, b, c, d], dim=1))
120
+ return x
121
+ else:
122
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
123
+
124
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
125
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
126
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
127
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
128
+
129
+ path_4 = self.scratch.refinenet4(layer_4_rn)
130
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
131
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
132
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
133
+
134
+ _,_,H_out,W_out = path_1.size()
135
+ path_2_up = F.interpolate(path_2, size=(H_out,W_out), mode="bilinear", align_corners=True)
136
+ path_3_up = F.interpolate(path_3, size=(H_out,W_out), mode="bilinear", align_corners=True)
137
+ path_4_up = F.interpolate(path_4, size=(H_out,W_out), mode="bilinear", align_corners=True)
138
+
139
+ out = self.scratch.output_conv(path_1+path_2_up+path_3_up+path_4_up)
140
+
141
+ return out
142
+
143
+
144
+ class DPTDepthModel(DPT):
145
+ def __init__(
146
+ self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
147
+ ):
148
+ features = kwargs["features"] if "features" in kwargs else 256
149
+
150
+ self.scale = scale
151
+ self.shift = shift
152
+ self.invert = invert
153
+
154
+ head = nn.Sequential(
155
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
156
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
157
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
158
+ nn.ReLU(True),
159
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
160
+ nn.ReLU(True) if non_negative else nn.Identity(),
161
+ nn.Identity(),
162
+ )
163
+
164
+ super().__init__(head, **kwargs)
165
+
166
+ if path is not None:
167
+ self.load(path)
168
+
169
+ def forward(self, x):
170
+ inv_depth = super().forward(x).squeeze(dim=1)
171
+
172
+ if self.invert:
173
+ depth = self.scale * inv_depth + self.shift
174
+ depth[depth < 1e-8] = 1e-8
175
+ depth = 1.0 / depth
176
+ return depth
177
+ else:
178
+ return inv_depth
179
+
180
+ class DPTEncoder(DPT):
181
+ def __init__(
182
+ self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
183
+ ):
184
+ features = kwargs["features"] if "features" in kwargs else 256
185
+
186
+ self.scale = scale
187
+ self.shift = shift
188
+
189
+ head = nn.Sequential(
190
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
191
+ )
192
+
193
+ super().__init__(head, **kwargs)
194
+
195
+ if path is not None:
196
+ self.load(path)
197
+
198
+ def forward(self, x):
199
+ features = super().forward(x, only_enc=True).squeeze(dim=1)
200
+
201
+ return features
202
+
203
+
204
+ class DPTSegmentationModel(DPT):
205
+ def __init__(self, num_classes, path=None, **kwargs):
206
+
207
+ features = kwargs["features"] if "features" in kwargs else 256
208
+
209
+ kwargs["use_bn"] = True
210
+
211
+ head = nn.Sequential(
212
+ nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
213
+ nn.BatchNorm2d(features),
214
+ nn.ReLU(True),
215
+ nn.Dropout(0.1, False),
216
+ nn.Conv2d(features, num_classes, kernel_size=1),
217
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
218
+ )
219
+
220
+ super().__init__(head, **kwargs)
221
+
222
+ self.auxlayer = nn.Sequential(
223
+ nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
224
+ nn.BatchNorm2d(features),
225
+ nn.ReLU(True),
226
+ nn.Dropout(0.1, False),
227
+ nn.Conv2d(features, num_classes, kernel_size=1),
228
+ )
229
+
230
+ if path is not None:
231
+ self.load(path)
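The encoder-only path (only_enc=True), used by DPTEncoder, fuses the two hooked ResNet stages and the two hooked ViT blocks into a single 1/4-resolution feature map. A shape sketch, assuming timm can build the "vit_small_r26_s32_384" hybrid locally (no pretrained weights are requested):

import torch

enc = DPTEncoder(backbone="vitb_rn50_384", features=256)
enc.eval()
with torch.no_grad():
    feats = enc(torch.randn(1, 3, 384, 384))
print(feats.shape)    # expected: torch.Size([1, 128, 96, 96]), a 1/4-resolution feature map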
models/spatracker/models/core/spatracker/dpt/transforms.py ADDED
@@ -0,0 +1,231 @@
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height)."""
50
+
51
+ def __init__(
52
+ self,
53
+ width,
54
+ height,
55
+ resize_target=True,
56
+ keep_aspect_ratio=False,
57
+ ensure_multiple_of=1,
58
+ resize_method="lower_bound",
59
+ image_interpolation_method=cv2.INTER_AREA,
60
+ ):
61
+ """Init.
62
+
63
+ Args:
64
+ width (int): desired output width
65
+ height (int): desired output height
66
+ resize_target (bool, optional):
67
+ True: Resize the full sample (image, mask, target).
68
+ False: Resize image only.
69
+ Defaults to True.
70
+ keep_aspect_ratio (bool, optional):
71
+ True: Keep the aspect ratio of the input sample.
72
+ Output sample might not have the given width and height, and
73
+ resize behaviour depends on the parameter 'resize_method'.
74
+ Defaults to False.
75
+ ensure_multiple_of (int, optional):
76
+ Output width and height is constrained to be multiple of this parameter.
77
+ Defaults to 1.
78
+ resize_method (str, optional):
79
+ "lower_bound": Output will be at least as large as the given size.
80
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
81
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
82
+ Defaults to "lower_bound".
83
+ """
84
+ self.__width = width
85
+ self.__height = height
86
+
87
+ self.__resize_target = resize_target
88
+ self.__keep_aspect_ratio = keep_aspect_ratio
89
+ self.__multiple_of = ensure_multiple_of
90
+ self.__resize_method = resize_method
91
+ self.__image_interpolation_method = image_interpolation_method
92
+
93
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
94
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
95
+
96
+ if max_val is not None and y > max_val:
97
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
98
+
99
+ if y < min_val:
100
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
101
+
102
+ return y
103
+
104
+ def get_size(self, width, height):
105
+ # determine new height and width
106
+ scale_height = self.__height / height
107
+ scale_width = self.__width / width
108
+
109
+ if self.__keep_aspect_ratio:
110
+ if self.__resize_method == "lower_bound":
111
+ # scale such that output size is lower bound
112
+ if scale_width > scale_height:
113
+ # fit width
114
+ scale_height = scale_width
115
+ else:
116
+ # fit height
117
+ scale_width = scale_height
118
+ elif self.__resize_method == "upper_bound":
119
+ # scale such that output size is upper bound
120
+ if scale_width < scale_height:
121
+ # fit width
122
+ scale_height = scale_width
123
+ else:
124
+ # fit height
125
+ scale_width = scale_height
126
+ elif self.__resize_method == "minimal":
127
+ # scale as little as possible
128
+ if abs(1 - scale_width) < abs(1 - scale_height):
129
+ # fit width
130
+ scale_height = scale_width
131
+ else:
132
+ # fit height
133
+ scale_width = scale_height
134
+ else:
135
+ raise ValueError(
136
+ f"resize_method {self.__resize_method} not implemented"
137
+ )
138
+
139
+ if self.__resize_method == "lower_bound":
140
+ new_height = self.constrain_to_multiple_of(
141
+ scale_height * height, min_val=self.__height
142
+ )
143
+ new_width = self.constrain_to_multiple_of(
144
+ scale_width * width, min_val=self.__width
145
+ )
146
+ elif self.__resize_method == "upper_bound":
147
+ new_height = self.constrain_to_multiple_of(
148
+ scale_height * height, max_val=self.__height
149
+ )
150
+ new_width = self.constrain_to_multiple_of(
151
+ scale_width * width, max_val=self.__width
152
+ )
153
+ elif self.__resize_method == "minimal":
154
+ new_height = self.constrain_to_multiple_of(scale_height * height)
155
+ new_width = self.constrain_to_multiple_of(scale_width * width)
156
+ else:
157
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
158
+
159
+ return (new_width, new_height)
160
+
161
+ def __call__(self, sample):
162
+ width, height = self.get_size(
163
+ sample["image"].shape[1], sample["image"].shape[0]
164
+ )
165
+
166
+ # resize sample
167
+ sample["image"] = cv2.resize(
168
+ sample["image"],
169
+ (width, height),
170
+ interpolation=self.__image_interpolation_method,
171
+ )
172
+
173
+ if self.__resize_target:
174
+ if "disparity" in sample:
175
+ sample["disparity"] = cv2.resize(
176
+ sample["disparity"],
177
+ (width, height),
178
+ interpolation=cv2.INTER_NEAREST,
179
+ )
180
+
181
+ if "depth" in sample:
182
+ sample["depth"] = cv2.resize(
183
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
184
+ )
185
+
186
+ sample["mask"] = cv2.resize(
187
+ sample["mask"].astype(np.float32),
188
+ (width, height),
189
+ interpolation=cv2.INTER_NEAREST,
190
+ )
191
+ sample["mask"] = sample["mask"].astype(bool)
192
+
193
+ return sample
194
+
195
+
196
+ class NormalizeImage(object):
197
+ """Normlize image by given mean and std."""
198
+
199
+ def __init__(self, mean, std):
200
+ self.__mean = mean
201
+ self.__std = std
202
+
203
+ def __call__(self, sample):
204
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
205
+
206
+ return sample
207
+
208
+
209
+ class PrepareForNet(object):
210
+ """Prepare sample for usage as network input."""
211
+
212
+ def __init__(self):
213
+ pass
214
+
215
+ def __call__(self, sample):
216
+ image = np.transpose(sample["image"], (2, 0, 1))
217
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
218
+
219
+ if "mask" in sample:
220
+ sample["mask"] = sample["mask"].astype(np.float32)
221
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
222
+
223
+ if "disparity" in sample:
224
+ disparity = sample["disparity"].astype(np.float32)
225
+ sample["disparity"] = np.ascontiguousarray(disparity)
226
+
227
+ if "depth" in sample:
228
+ depth = sample["depth"].astype(np.float32)
229
+ sample["depth"] = np.ascontiguousarray(depth)
230
+
231
+ return sample
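These transforms operate on a plain sample dict holding HWC numpy arrays. A small end-to-end sketch with an illustrative dummy image (the classes above are assumed to be in scope):

import numpy as np

resize = Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
                ensure_multiple_of=32, resize_method="lower_bound")
normalize = NormalizeImage(mean=np.array([0.5, 0.5, 0.5]), std=np.array([0.5, 0.5, 0.5]))
prepare = PrepareForNet()

sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}  # HWC image in [0, 1]
sample = prepare(normalize(resize(sample)))
print(sample["image"].shape)    # (3, 384, 512): CHW float32, sides snapped to multiples of 32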
models/spatracker/models/core/spatracker/dpt/vit.py ADDED
@@ -0,0 +1,596 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ activations = {}
10
+
11
+
12
+ def get_activation(name):
13
+ def hook(model, input, output):
14
+ activations[name] = output
15
+
16
+ return hook
17
+
18
+
19
+ attention = {}
20
+
21
+
22
+ def get_attention(name):
23
+ def hook(module, input, output):
24
+ x = input[0]
25
+ B, N, C = x.shape
26
+ qkv = (
27
+ module.qkv(x)
28
+ .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29
+ .permute(2, 0, 3, 1, 4)
30
+ )
31
+ q, k, v = (
32
+ qkv[0],
33
+ qkv[1],
34
+ qkv[2],
35
+ ) # make torchscript happy (cannot use tensor as tuple)
36
+
37
+ attn = (q @ k.transpose(-2, -1)) * module.scale
38
+
39
+ attn = attn.softmax(dim=-1) # [:,:,1,1:]
40
+ attention[name] = attn
41
+
42
+ return hook
43
+
44
+
45
+ def get_mean_attention_map(attn, token, shape):
46
+ attn = attn[:, :, token, 1:]
47
+ attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48
+ attn = torch.nn.functional.interpolate(
49
+ attn, size=shape[2:], mode="bicubic", align_corners=False
50
+ ).squeeze(0)
51
+
52
+ all_attn = torch.mean(attn, 0)
53
+
54
+ return all_attn
55
+
56
+
57
+ class Slice(nn.Module):
58
+ def __init__(self, start_index=1):
59
+ super(Slice, self).__init__()
60
+ self.start_index = start_index
61
+
62
+ def forward(self, x):
63
+ return x[:, self.start_index :]
64
+
65
+
66
+ class AddReadout(nn.Module):
67
+ def __init__(self, start_index=1):
68
+ super(AddReadout, self).__init__()
69
+ self.start_index = start_index
70
+
71
+ def forward(self, x):
72
+ if self.start_index == 2:
73
+ readout = (x[:, 0] + x[:, 1]) / 2
74
+ else:
75
+ readout = x[:, 0]
76
+ return x[:, self.start_index :] + readout.unsqueeze(1)
77
+
78
+
79
+ class ProjectReadout(nn.Module):
80
+ def __init__(self, in_features, start_index=1):
81
+ super(ProjectReadout, self).__init__()
82
+ self.start_index = start_index
83
+
84
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85
+
86
+ def forward(self, x):
87
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88
+ features = torch.cat((x[:, self.start_index :], readout), -1)
89
+
90
+ return self.project(features)
91
+
92
+
93
+ class Transpose(nn.Module):
94
+ def __init__(self, dim0, dim1):
95
+ super(Transpose, self).__init__()
96
+ self.dim0 = dim0
97
+ self.dim1 = dim1
98
+
99
+ def forward(self, x):
100
+ x = x.transpose(self.dim0, self.dim1)
101
+ return x
102
+
103
+
104
+ def forward_vit(pretrained, x):
105
+ b, c, h, w = x.shape
106
+
107
+ glob = pretrained.model.forward_flex(x)
108
+
109
+ layer_1 = pretrained.activations["1"]
110
+ layer_2 = pretrained.activations["2"]
111
+ layer_3 = pretrained.activations["3"]
112
+ layer_4 = pretrained.activations["4"]
113
+
114
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
115
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
116
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
117
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
118
+
119
+ unflatten = nn.Sequential(
120
+ nn.Unflatten(
121
+ 2,
122
+ torch.Size(
123
+ [
124
+ h // pretrained.model.patch_size[1],
125
+ w // pretrained.model.patch_size[0],
126
+ ]
127
+ ),
128
+ )
129
+ )
130
+
131
+ if layer_1.ndim == 3:
132
+ layer_1 = unflatten(layer_1)
133
+ if layer_2.ndim == 3:
134
+ layer_2 = unflatten(layer_2)
135
+ if layer_3.ndim == 3:
136
+ layer_3 = unflatten(layer_3)
137
+ if layer_4.ndim == 3:
138
+ layer_4 = unflatten(layer_4)
139
+
140
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
141
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
142
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
143
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
144
+
145
+ return layer_1, layer_2, layer_3, layer_4
146
+
147
+
148
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
149
+ posemb_tok, posemb_grid = (
150
+ posemb[:, : self.start_index],
151
+ posemb[0, self.start_index :],
152
+ )
153
+
154
+ gs_old = int(math.sqrt(len(posemb_grid)))
155
+
156
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
157
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
158
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
159
+
160
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
161
+
162
+ return posemb
163
+
164
+
165
+ def forward_flex(self, x):
166
+ b, c, h, w = x.shape
167
+
168
+ pos_embed = self._resize_pos_embed(
169
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
170
+ )
171
+
172
+ B = x.shape[0]
173
+
174
+ if hasattr(self.patch_embed, "backbone"):
175
+ x = self.patch_embed.backbone(x)
176
+ if isinstance(x, (list, tuple)):
177
+ x = x[-1] # last feature if backbone outputs list/tuple of features
178
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
179
+
180
+ if getattr(self, "dist_token", None) is not None:
181
+ cls_tokens = self.cls_token.expand(
182
+ B, -1, -1
183
+ ) # stole cls_tokens impl from Phil Wang, thanks
184
+ dist_token = self.dist_token.expand(B, -1, -1)
185
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
186
+ else:
187
+ cls_tokens = self.cls_token.expand(
188
+ B, -1, -1
189
+ ) # stole cls_tokens impl from Phil Wang, thanks
190
+ x = torch.cat((cls_tokens, x), dim=1)
191
+
192
+ x = x + pos_embed
193
+ x = self.pos_drop(x)
194
+
195
+ for blk in self.blocks:
196
+ x = blk(x)
197
+
198
+ x = self.norm(x)
199
+
200
+ return x
201
+
202
+
203
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
204
+ if use_readout == "ignore":
205
+ readout_oper = [Slice(start_index)] * len(features)
206
+ elif use_readout == "add":
207
+ readout_oper = [AddReadout(start_index)] * len(features)
208
+ elif use_readout == "project":
209
+ readout_oper = [
210
+ ProjectReadout(vit_features, start_index) for out_feat in features
211
+ ]
212
+ else:
213
+ assert (
214
+ False
215
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
216
+
217
+ return readout_oper
218
+
219
+
220
+ def _make_vit_b16_backbone(
221
+ model,
222
+ features=[96, 192, 384, 768],
223
+ size=[384, 384],
224
+ hooks=[2, 5, 8, 11],
225
+ vit_features=768,
226
+ use_readout="ignore",
227
+ start_index=1,
228
+ enable_attention_hooks=False,
229
+ ):
230
+ pretrained = nn.Module()
231
+
232
+ pretrained.model = model
233
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
234
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
235
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
236
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
237
+
238
+ pretrained.activations = activations
239
+
240
+ if enable_attention_hooks:
241
+ pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
242
+ get_attention("attn_1")
243
+ )
244
+ pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
245
+ get_attention("attn_2")
246
+ )
247
+ pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
248
+ get_attention("attn_3")
249
+ )
250
+ pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
251
+ get_attention("attn_4")
252
+ )
253
+ pretrained.attention = attention
254
+
255
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
256
+
257
+ # 32, 48, 136, 384
258
+ pretrained.act_postprocess1 = nn.Sequential(
259
+ readout_oper[0],
260
+ Transpose(1, 2),
261
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
262
+ nn.Conv2d(
263
+ in_channels=vit_features,
264
+ out_channels=features[0],
265
+ kernel_size=1,
266
+ stride=1,
267
+ padding=0,
268
+ ),
269
+ nn.ConvTranspose2d(
270
+ in_channels=features[0],
271
+ out_channels=features[0],
272
+ kernel_size=4,
273
+ stride=4,
274
+ padding=0,
275
+ bias=True,
276
+ dilation=1,
277
+ groups=1,
278
+ ),
279
+ )
280
+
281
+ pretrained.act_postprocess2 = nn.Sequential(
282
+ readout_oper[1],
283
+ Transpose(1, 2),
284
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
285
+ nn.Conv2d(
286
+ in_channels=vit_features,
287
+ out_channels=features[1],
288
+ kernel_size=1,
289
+ stride=1,
290
+ padding=0,
291
+ ),
292
+ nn.ConvTranspose2d(
293
+ in_channels=features[1],
294
+ out_channels=features[1],
295
+ kernel_size=2,
296
+ stride=2,
297
+ padding=0,
298
+ bias=True,
299
+ dilation=1,
300
+ groups=1,
301
+ ),
302
+ )
303
+
304
+ pretrained.act_postprocess3 = nn.Sequential(
305
+ readout_oper[2],
306
+ Transpose(1, 2),
307
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
308
+ nn.Conv2d(
309
+ in_channels=vit_features,
310
+ out_channels=features[2],
311
+ kernel_size=1,
312
+ stride=1,
313
+ padding=0,
314
+ ),
315
+ )
316
+
317
+ pretrained.act_postprocess4 = nn.Sequential(
318
+ readout_oper[3],
319
+ Transpose(1, 2),
320
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
321
+ nn.Conv2d(
322
+ in_channels=vit_features,
323
+ out_channels=features[3],
324
+ kernel_size=1,
325
+ stride=1,
326
+ padding=0,
327
+ ),
328
+ nn.Conv2d(
329
+ in_channels=features[3],
330
+ out_channels=features[3],
331
+ kernel_size=3,
332
+ stride=2,
333
+ padding=1,
334
+ ),
335
+ )
336
+
337
+ pretrained.model.start_index = start_index
338
+ pretrained.model.patch_size = [16, 16]
339
+
340
+ # We inject this function into the VisionTransformer instances so that
341
+ # we can use it with interpolated position embeddings without modifying the library source.
342
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
343
+ pretrained.model._resize_pos_embed = types.MethodType(
344
+ _resize_pos_embed, pretrained.model
345
+ )
346
+
347
+ return pretrained
348
+
349
+
350
+ def _make_vit_b_rn50_backbone(
351
+ model,
352
+ features=[256, 512, 768, 768],
353
+ size=[384, 384],
354
+ hooks=[0, 1, 8, 11],
355
+ vit_features=384,
356
+ use_vit_only=False,
357
+ use_readout="ignore",
358
+ start_index=1,
359
+ enable_attention_hooks=False,
360
+ ):
361
+ pretrained = nn.Module()
362
+ pretrained.model = model
363
+ pretrained.model.patch_size = [32, 32]
364
+ ps = pretrained.model.patch_size[0]
365
+ if use_vit_only == True:
366
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
367
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
368
+ else:
369
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
370
+ get_activation("1")
371
+ )
372
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
373
+ get_activation("2")
374
+ )
375
+
376
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
377
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
378
+
379
+ if enable_attention_hooks:
380
+ pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
381
+ pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
382
+ pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
383
+ pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
384
+ pretrained.attention = attention
385
+
386
+ pretrained.activations = activations
387
+
388
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
389
+
390
+ if use_vit_only == True:
391
+ pretrained.act_postprocess1 = nn.Sequential(
392
+ readout_oper[0],
393
+ Transpose(1, 2),
394
+ nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
395
+ nn.Conv2d(
396
+ in_channels=vit_features,
397
+ out_channels=features[0],
398
+ kernel_size=1,
399
+ stride=1,
400
+ padding=0,
401
+ ),
402
+ nn.ConvTranspose2d(
403
+ in_channels=features[0],
404
+ out_channels=features[0],
405
+ kernel_size=4,
406
+ stride=4,
407
+ padding=0,
408
+ bias=True,
409
+ dilation=1,
410
+ groups=1,
411
+ ),
412
+ )
413
+
414
+ pretrained.act_postprocess2 = nn.Sequential(
415
+ readout_oper[1],
416
+ Transpose(1, 2),
417
+ nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
418
+ nn.Conv2d(
419
+ in_channels=vit_features,
420
+ out_channels=features[1],
421
+ kernel_size=1,
422
+ stride=1,
423
+ padding=0,
424
+ ),
425
+ nn.ConvTranspose2d(
426
+ in_channels=features[1],
427
+ out_channels=features[1],
428
+ kernel_size=2,
429
+ stride=2,
430
+ padding=0,
431
+ bias=True,
432
+ dilation=1,
433
+ groups=1,
434
+ ),
435
+ )
436
+ else:
437
+ pretrained.act_postprocess1 = nn.Sequential(
438
+ nn.Identity(), nn.Identity(), nn.Identity()
439
+ )
440
+ pretrained.act_postprocess2 = nn.Sequential(
441
+ nn.Identity(), nn.Identity(), nn.Identity()
442
+ )
443
+
444
+ pretrained.act_postprocess3 = nn.Sequential(
445
+ readout_oper[2],
446
+ Transpose(1, 2),
447
+ nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
448
+ nn.Conv2d(
449
+ in_channels=vit_features,
450
+ out_channels=features[2],
451
+ kernel_size=1,
452
+ stride=1,
453
+ padding=0,
454
+ ),
455
+ )
456
+
457
+ pretrained.act_postprocess4 = nn.Sequential(
458
+ readout_oper[3],
459
+ Transpose(1, 2),
460
+ nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
461
+ nn.Conv2d(
462
+ in_channels=vit_features,
463
+ out_channels=features[3],
464
+ kernel_size=1,
465
+ stride=1,
466
+ padding=0,
467
+ ),
468
+ nn.Conv2d(
469
+ in_channels=features[3],
470
+ out_channels=features[3],
471
+ kernel_size=3,
472
+ stride=2,
473
+ padding=1,
474
+ ),
475
+ )
476
+
477
+ pretrained.model.start_index = start_index
478
+ pretrained.model.patch_size = [32, 32]
479
+
480
+ # We inject this function into the VisionTransformer instances so that
481
+ # we can use it with interpolated position embeddings without modifying the library source.
482
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
483
+
484
+ # We inject this function into the VisionTransformer instances so that
485
+ # we can use it with interpolated position embeddings without modifying the library source.
486
+ pretrained.model._resize_pos_embed = types.MethodType(
487
+ _resize_pos_embed, pretrained.model
488
+ )
489
+
490
+ return pretrained
491
+
492
+
493
+ def _make_pretrained_vitb_rn50_384(
494
+ pretrained,
495
+ use_readout="ignore",
496
+ hooks=None,
497
+ use_vit_only=False,
498
+ enable_attention_hooks=False,
499
+ ):
500
+ # model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
501
+ # model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained)
502
+ model = timm.create_model("vit_small_r26_s32_384", pretrained=pretrained)
503
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
504
+ return _make_vit_b_rn50_backbone(
505
+ model,
506
+ features=[128, 256, 384, 384],
507
+ size=[384, 384],
508
+ hooks=hooks,
509
+ use_vit_only=use_vit_only,
510
+ use_readout=use_readout,
511
+ enable_attention_hooks=enable_attention_hooks,
512
+ )
513
+
514
+ def _make_pretrained_vit_tiny(
515
+ pretrained,
516
+ use_readout="ignore",
517
+ hooks=None,
518
+ use_vit_only=False,
519
+ enable_attention_hooks=False,
520
+ ):
521
+ # model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
522
+ model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained)
523
+ import ipdb; ipdb.set_trace()  # NOTE: leftover debugging breakpoint; _make_vit_tiny_backbone called below is not defined in this file
524
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
525
+ return _make_vit_tiny_backbone(
526
+ model,
527
+ features=[256, 512, 768, 768],
528
+ size=[384, 384],
529
+ hooks=hooks,
530
+ use_vit_only=use_vit_only,
531
+ use_readout=use_readout,
532
+ enable_attention_hooks=enable_attention_hooks,
533
+ )
534
+
535
+ def _make_pretrained_vitl16_384(
536
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
537
+ ):
538
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
539
+
540
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
541
+ return _make_vit_b16_backbone(
542
+ model,
543
+ features=[256, 512, 1024, 1024],
544
+ hooks=hooks,
545
+ vit_features=1024,
546
+ use_readout=use_readout,
547
+ enable_attention_hooks=enable_attention_hooks,
548
+ )
549
+
550
+
551
+ def _make_pretrained_vitb16_384(
552
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
553
+ ):
554
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
555
+
556
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
557
+ return _make_vit_b16_backbone(
558
+ model,
559
+ features=[96, 192, 384, 768],
560
+ hooks=hooks,
561
+ use_readout=use_readout,
562
+ enable_attention_hooks=enable_attention_hooks,
563
+ )
564
+
565
+
566
+ def _make_pretrained_deitb16_384(
567
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
568
+ ):
569
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
570
+
571
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
572
+ return _make_vit_b16_backbone(
573
+ model,
574
+ features=[96, 192, 384, 768],
575
+ hooks=hooks,
576
+ use_readout=use_readout,
577
+ enable_attention_hooks=enable_attention_hooks,
578
+ )
579
+
580
+
581
+ def _make_pretrained_deitb16_distil_384(
582
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
583
+ ):
584
+ model = timm.create_model(
585
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
586
+ )
587
+
588
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
589
+ return _make_vit_b16_backbone(
590
+ model,
591
+ features=[96, 192, 384, 768],
592
+ hooks=hooks,
593
+ use_readout=use_readout,
594
+ start_index=2,
595
+ enable_attention_hooks=enable_attention_hooks,
596
+ )
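forward_vit runs the timm backbone once and reads the four hooked activations back through the act_postprocess heads, giving a coarse-to-fine feature pyramid. An illustrative sketch, assuming timm can construct "vit_small_r26_s32_384" without downloading weights (pretrained=False):

import torch

backbone = _make_pretrained_vitb_rn50_384(pretrained=False)
with torch.no_grad():
    l1, l2, l3, l4 = forward_vit(backbone, torch.randn(1, 3, 384, 384))
print(l1.shape, l2.shape, l3.shape, l4.shape)
# roughly: [1, 256, 96, 96], [1, 512, 48, 48], [1, 384, 12, 12], [1, 384, 6, 6]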
models/spatracker/models/core/spatracker/feature_net.py ADDED
@@ -0,0 +1,915 @@
1
+ """
2
+ Adapted from ConvONet
3
+ https://github.com/autonomousvision/convolutional_occupancy_networks/blob/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/encoder/pointnet.py#L1
4
+ """
5
+
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ # from torch_scatter import scatter_mean, scatter_max
11
+ from models.spatracker.models.core.spatracker.unet import UNet
12
+ from models.spatracker.models.core.model_utils import (
13
+ vis_PCA
14
+ )
15
+ from einops import rearrange
16
+
17
+ def compute_iou(occ1, occ2):
18
+ ''' Computes the Intersection over Union (IoU) value for two sets of
19
+ occupancy values.
20
+
21
+ Args:
22
+ occ1 (tensor): first set of occupancy values
23
+ occ2 (tensor): second set of occupancy values
24
+ '''
25
+ occ1 = np.asarray(occ1)
26
+ occ2 = np.asarray(occ2)
27
+
28
+ # Put all data in second dimension
29
+ # Also works for 1-dimensional data
30
+ if occ1.ndim >= 2:
31
+ occ1 = occ1.reshape(occ1.shape[0], -1)
32
+ if occ2.ndim >= 2:
33
+ occ2 = occ2.reshape(occ2.shape[0], -1)
34
+
35
+ # Convert to boolean values
36
+ occ1 = (occ1 >= 0.5)
37
+ occ2 = (occ2 >= 0.5)
38
+
39
+ # Compute IOU
40
+ area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1)
41
+ area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1)
42
+
43
+ iou = (area_intersect / area_union)
44
+
45
+ return iou
46
+
47
+
48
+ def chamfer_distance(points1, points2, use_kdtree=True, give_id=False):
49
+ ''' Returns the chamfer distance for the sets of points.
50
+
51
+ Args:
52
+ points1 (numpy array): first point set
53
+ points2 (numpy array): second point set
54
+ use_kdtree (bool): whether to use a kdtree
55
+ give_id (bool): whether to return the IDs of nearest points
56
+ '''
57
+ if use_kdtree:
58
+ return chamfer_distance_kdtree(points1, points2, give_id=give_id)
59
+ else:
60
+ return chamfer_distance_naive(points1, points2)
61
+
62
+
63
+ def chamfer_distance_naive(points1, points2):
64
+ ''' Naive implementation of the Chamfer distance.
65
+
66
+ Args:
67
+ points1 (numpy array): first point set
68
+ points2 (numpy array): second point set
69
+ '''
70
+ assert(points1.size() == points2.size())
71
+ batch_size, T, _ = points1.size()
72
+
73
+ points1 = points1.view(batch_size, T, 1, 3)
74
+ points2 = points2.view(batch_size, 1, T, 3)
75
+
76
+ distances = (points1 - points2).pow(2).sum(-1)
77
+
78
+ chamfer1 = distances.min(dim=1)[0].mean(dim=1)
79
+ chamfer2 = distances.min(dim=2)[0].mean(dim=1)
80
+
81
+ chamfer = chamfer1 + chamfer2
82
+ return chamfer
83
+
84
+
85
+ def chamfer_distance_kdtree(points1, points2, give_id=False):
86
+ ''' KD-tree based implementation of the Chamfer distance.
87
+
88
+ Args:
89
+ points1 (numpy array): first point set
90
+ points2 (numpy array): second point set
91
+ give_id (bool): whether to return the IDs of the nearest points
92
+ '''
93
+ # Points have size batch_size x T x 3
94
+ batch_size = points1.size(0)
95
+
96
+ # First convert points to numpy
97
+ points1_np = points1.detach().cpu().numpy()
98
+ points2_np = points2.detach().cpu().numpy()
99
+
100
+ # Get list of nearest neighbor indices
101
+ idx_nn_12, _ = get_nearest_neighbors_indices_batch(points1_np, points2_np)
102
+ idx_nn_12 = torch.LongTensor(idx_nn_12).to(points1.device)
103
+ # Expands it as batch_size x 1 x 3
104
+ idx_nn_12_expand = idx_nn_12.view(batch_size, -1, 1).expand_as(points1)
105
+
106
+ # Get list of nearest neighbor indices
107
+ idx_nn_21, _ = get_nearest_neighbors_indices_batch(points2_np, points1_np)
108
+ idx_nn_21 = torch.LongTensor(idx_nn_21).to(points1.device)
109
+ # Expands it as batch_size x T x 3
110
+ idx_nn_21_expand = idx_nn_21.view(batch_size, -1, 1).expand_as(points2)
111
+
112
+ # Compute nearest neighbors in points2 to points in points1
113
+ # points_12[i, j, k] = points2[i, idx_nn_12_expand[i, j, k], k]
114
+ points_12 = torch.gather(points2, dim=1, index=idx_nn_12_expand)
115
+
116
+ # Compute nearest neighbors in points1 to points in points2
117
+ # points_21[i, j, k] = points2[i, idx_nn_21_expand[i, j, k], k]
118
+ points_21 = torch.gather(points1, dim=1, index=idx_nn_21_expand)
119
+
120
+ # Compute chamfer distance
121
+ chamfer1 = (points1 - points_12).pow(2).sum(2).mean(1)
122
+ chamfer2 = (points2 - points_21).pow(2).sum(2).mean(1)
123
+
124
+ # Take sum
125
+ chamfer = chamfer1 + chamfer2
126
+
127
+ # If required, also return nearest neighbors
128
+ if give_id:
129
+ return chamfer1, chamfer2, idx_nn_12, idx_nn_21
130
+
131
+ return chamfer
132
+
133
+
134
+ def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1):
135
+ ''' Returns the nearest neighbors for point sets batchwise.
136
+
137
+ Args:
138
+ points_src (numpy array): source points
139
+ points_tgt (numpy array): target points
140
+ k (int): number of nearest neighbors to return
141
+ '''
142
+ indices = []
143
+ distances = []
144
+
145
+ for (p1, p2) in zip(points_src, points_tgt):
146
+ raise NotImplementedError()
147
+ # kdtree = KDTree(p2)
148
+ dist, idx = kdtree.query(p1, k=k)
149
+ indices.append(idx)
150
+ distances.append(dist)
151
+
152
+ return indices, distances
153
+
154
+
155
+ def make_3d_grid(bb_min, bb_max, shape):
156
+ ''' Makes a 3D grid.
157
+
158
+ Args:
159
+ bb_min (tuple): bounding box minimum
160
+ bb_max (tuple): bounding box maximum
161
+ shape (tuple): output shape
162
+ '''
163
+ size = shape[0] * shape[1] * shape[2]
164
+
165
+ pxs = torch.linspace(bb_min[0], bb_max[0], shape[0])
166
+ pys = torch.linspace(bb_min[1], bb_max[1], shape[1])
167
+ pzs = torch.linspace(bb_min[2], bb_max[2], shape[2])
168
+
169
+ pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
170
+ pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
171
+ pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
172
+ p = torch.stack([pxs, pys, pzs], dim=1)
173
+
174
+ return p
175
+
176
+
177
+ def transform_points(points, transform):
178
+ ''' Transforms points with regard to passed camera information.
179
+
180
+ Args:
181
+ points (tensor): points tensor
182
+ transform (tensor): transformation matrices
183
+ '''
184
+ assert(points.size(2) == 3)
185
+ assert(transform.size(1) == 3)
186
+ assert(points.size(0) == transform.size(0))
187
+
188
+ if transform.size(2) == 4:
189
+ R = transform[:, :, :3]
190
+ t = transform[:, :, 3:]
191
+ points_out = points @ R.transpose(1, 2) + t.transpose(1, 2)
192
+ elif transform.size(2) == 3:
193
+ K = transform
194
+ points_out = points @ K.transpose(1, 2)
195
+
196
+ return points_out
197
+
198
+
199
+ def b_inv(b_mat):
200
+ ''' Performs batch matrix inversion.
201
+
202
+ Arguments:
203
+ b_mat: the batch of matrices that should be inverted
204
+ '''
205
+
206
+ eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
+ # torch.gesv was removed from recent PyTorch releases; torch.linalg.solve
+ # computes the same batched solution of b_mat @ X = I, i.e. the inverse.
+ b_inv = torch.linalg.solve(b_mat, eye)
208
+ return b_inv
209
+
210
+ def project_to_camera(points, transform):
211
+ ''' Projects points to the camera plane.
212
+
213
+ Args:
214
+ points (tensor): points tensor
215
+ transform (tensor): transformation matrices
216
+ '''
217
+ p_camera = transform_points(points, transform)
218
+ p_camera = p_camera[..., :2] / p_camera[..., 2:]
219
+ return p_camera
220
+
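+ def _example_project_to_camera():
+     # Illustrative usage sketch of transform_points / project_to_camera;
+     # the identity extrinsics and point counts are placeholders.
+     points = torch.rand(2, 100, 3) + torch.tensor([0.0, 0.0, 1.0])  # keep z > 0
+     Rt = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=1).expand(2, 3, 4)
+     uv = project_to_camera(points, Rt)  # perspective division by z
+     assert uv.shape == (2, 100, 2)
+     return uv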
221
+
222
+ def fix_Rt_camera(Rt, loc, scale):
223
+ ''' Fixes Rt camera matrix.
224
+
225
+ Args:
226
+ Rt (tensor): Rt camera matrix
227
+ loc (tensor): location
228
+ scale (float): scale
229
+ '''
230
+ # Rt is B x 3 x 4
231
+ # loc is B x 3 and scale is B
232
+ batch_size = Rt.size(0)
233
+ R = Rt[:, :, :3]
234
+ t = Rt[:, :, 3:]
235
+
236
+ scale = scale.view(batch_size, 1, 1)
237
+ R_new = R * scale
238
+ t_new = t + R @ loc.unsqueeze(2)
239
+
240
+ Rt_new = torch.cat([R_new, t_new], dim=2)
241
+
242
+ assert(Rt_new.size() == (batch_size, 3, 4))
243
+ return Rt_new
244
+
245
+ def normalize_coordinate(p, padding=0.1, plane='xz'):
246
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
247
+
248
+ Args:
249
+ p (tensor): point
250
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
251
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
252
+ '''
253
+ # breakpoint()
254
+ if plane == 'xz':
255
+ xy = p[:, :, [0, 2]]
256
+ elif plane =='xy':
257
+ xy = p[:, :, [0, 1]]
258
+ else:
259
+ xy = p[:, :, [1, 2]]
260
+
261
+ xy = torch.clamp(xy, min=1e-6, max=1. - 1e-6)
262
+
263
+ # xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
264
+ # xy_new = xy_new + 0.5 # range (0, 1)
265
+
266
+ # # f there are outliers out of the range
267
+ # if xy_new.max() >= 1:
268
+ # xy_new[xy_new >= 1] = 1 - 10e-6
269
+ # if xy_new.min() < 0:
270
+ # xy_new[xy_new < 0] = 0.0
271
+ # xy_new = (xy + 1.) / 2.
272
+ return xy
273
+
274
+ def normalize_3d_coordinate(p, padding=0.1):
275
+ ''' Normalize coordinate to [0, 1] for unit cube experiments.
276
+ Corresponds to our 3D model
277
+
278
+ Args:
279
+ p (tensor): point
280
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
281
+ '''
282
+
283
+ p_nor = p / (1 + padding + 10e-4) # (-0.5, 0.5)
284
+ p_nor = p_nor + 0.5 # range (0, 1)
285
+ # if there are outliers out of the range
286
+ if p_nor.max() >= 1:
287
+ p_nor[p_nor >= 1] = 1 - 10e-4
288
+ if p_nor.min() < 0:
289
+ p_nor[p_nor < 0] = 0.0
290
+ return p_nor
291
+
292
+ def normalize_coord(p, vol_range, plane='xz'):
293
+ ''' Normalize coordinate to [0, 1] for sliding-window experiments
294
+
295
+ Args:
296
+ p (tensor): point
297
+ vol_range (numpy array): volume boundary
298
+ plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume
299
+ '''
300
+ p[:, 0] = (p[:, 0] - vol_range[0][0]) / (vol_range[1][0] - vol_range[0][0])
301
+ p[:, 1] = (p[:, 1] - vol_range[0][1]) / (vol_range[1][1] - vol_range[0][1])
302
+ p[:, 2] = (p[:, 2] - vol_range[0][2]) / (vol_range[1][2] - vol_range[0][2])
303
+
304
+ if plane == 'xz':
305
+ x = p[:, [0, 2]]
306
+ elif plane =='xy':
307
+ x = p[:, [0, 1]]
308
+ elif plane =='yz':
309
+ x = p[:, [1, 2]]
310
+ else:
311
+ x = p
312
+ return x
313
+
314
+ def coordinate2index(x, reso, coord_type='2d'):
315
+ ''' Convert normalized coordinates in [0, 1] to flattened cell indices
+ on a plane (coord_type='2d') or grid (coord_type='3d') of the given
+ resolution.
317
+
318
+ Args:
319
+ x (tensor): coordinate
320
+ reso (int): defined resolution
321
+ coord_type (str): coordinate type
322
+ '''
323
+ x = (x * reso).long()
324
+ if coord_type == '2d': # plane
325
+ index = x[:, :, 0] + reso * x[:, :, 1]
326
+ elif coord_type == '3d': # grid
327
+ index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
328
+ index = index[:, None, :]
329
+ return index
330
+
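+ def _example_plane_index():
+     # Illustrative usage sketch: normalize_coordinate + coordinate2index map
+     # points (assumed to already lie in [0, 1]) to flattened cell indices of
+     # a 64 x 64 'xz' feature plane; batch and point counts are placeholders.
+     p = torch.rand(4, 2048, 3)                        # B x T x 3
+     xy = normalize_coordinate(p.clone(), plane='xz')  # B x T x 2, clamped to (0, 1)
+     index = coordinate2index(xy, reso=64)             # B x T, values in [0, 64**2)
+     assert index.max() < 64 ** 2
+     return index
+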
331
+ def coord2index(p, vol_range, reso=None, plane='xz'):
332
+ ''' Normalize coordinates to [0, 1] for sliding-window experiments and
+ convert them to flattened plane or grid indices.
334
+
335
+ Args:
336
+ p (tensor): points
337
+ vol_range (numpy array): volume boundary
338
+ reso (int): defined resolution
339
+ plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume
340
+ '''
341
+ # normalize to [0, 1]
342
+ x = normalize_coord(p, vol_range, plane=plane)
343
+
344
+ if isinstance(x, np.ndarray):
345
+ x = np.floor(x * reso).astype(int)
346
+ else: #* pytorch tensor
347
+ x = (x * reso).long()
348
+
349
+ if x.shape[1] == 2:
350
+ index = x[:, 0] + reso * x[:, 1]
351
+ index[index > reso**2] = reso**2
352
+ elif x.shape[1] == 3:
353
+ index = x[:, 0] + reso * (x[:, 1] + reso * x[:, 2])
354
+ index[index > reso**3] = reso**3
355
+
356
+ return index[None]
357
+
358
+ def update_reso(reso, depth):
359
+ ''' Update the defined resolution so that UNet can process.
360
+
361
+ Args:
362
+ reso (int): defined resolution
363
+ depth (int): U-Net number of layers
364
+ '''
365
+ base = 2**(int(depth) - 1)
366
+ if not (reso / base).is_integer(): # when this is not an integer, the U-Net spatial dimensions break
367
+ for i in range(base):
368
+ if ((reso + i) / base).is_integer():
369
+ reso = reso + i
370
+ break
371
+ return reso
372
+
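+ def _example_update_reso():
+     # Worked example for update_reso above: a U-Net with depth 4 needs the
+     # resolution to be divisible by 2**(4-1) = 8, so 70 is rounded up to 72
+     # while 64 is already valid.
+     assert update_reso(70, depth=4) == 72
+     assert update_reso(64, depth=4) == 64
+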
373
+ def decide_total_volume_range(query_vol_metric, recep_field, unit_size, unet_depth):
374
+ ''' Update the defined resolution so that UNet can process.
375
+
376
+ Args:
377
+ query_vol_metric (numpy array): query volume size
378
+ recep_field (int): defined the receptive field for U-Net
379
+ unit_size (float): the defined voxel size
380
+ unet_depth (int): U-Net number of layers
381
+ '''
382
+ reso = query_vol_metric / unit_size + recep_field - 1
383
+ reso = update_reso(int(reso), unet_depth) # make sure input reso can be processed by UNet
384
+ input_vol_metric = reso * unit_size
385
+ p_c = np.array([0.0, 0.0, 0.0]).astype(np.float32)
386
+ lb_input_vol, ub_input_vol = p_c - input_vol_metric/2, p_c + input_vol_metric/2
387
+ lb_query_vol, ub_query_vol = p_c - query_vol_metric/2, p_c + query_vol_metric/2
388
+ input_vol = [lb_input_vol, ub_input_vol]
389
+ query_vol = [lb_query_vol, ub_query_vol]
390
+
391
+ # handle the case when resolution is too large
392
+ if reso > 10000:
393
+ reso = 1
394
+
395
+ return input_vol, query_vol, reso
396
+
397
+ def add_key(base, new, base_name, new_name, device=None):
398
+ ''' Add new keys to the given input
399
+
400
+ Args:
401
+ base (tensor): inputs
402
+ new (tensor): new info for the inputs
403
+ base_name (str): name for the input
404
+ new_name (str): name for the new info
405
+ device (device): pytorch device
406
+ '''
407
+ if (new is not None) and (isinstance(new, dict)):
408
+ if device is not None:
409
+ for key in new.keys():
410
+ new[key] = new[key].to(device)
411
+ base = {base_name: base,
412
+ new_name: new}
413
+ return base
414
+
415
+ class map2local(object):
416
+ ''' Map points into their local per-voxel coordinate system, optionally with positional encoding
417
+
418
+ Args:
419
+ s (float): the defined voxel size
420
+ pos_encoding (str): method for the positional encoding, linear|sin_cos
421
+ '''
422
+ def __init__(self, s, pos_encoding='linear'):
423
+ super().__init__()
424
+ self.s = s
425
+ self.pe = positional_encoding(basis_function=pos_encoding)
426
+
427
+ def __call__(self, p):
428
+ p = torch.remainder(p, self.s) / self.s # always positive
429
+ # p = torch.fmod(p, self.s) / self.s # same sign as input p!
430
+ p = self.pe(p)
431
+ return p
432
+
433
+ class positional_encoding(object):
434
+ ''' Positional Encoding (presented in NeRF)
435
+
436
+ Args:
437
+ basis_function (str): basis function
438
+ '''
439
+ def __init__(self, basis_function='sin_cos'):
440
+ super().__init__()
441
+ self.func = basis_function
442
+
443
+ L = 10
444
+ freq_bands = 2.**(np.linspace(0, L-1, L))
445
+ self.freq_bands = freq_bands * math.pi
446
+
447
+ def __call__(self, p):
448
+ if self.func == 'sin_cos':
449
+ out = []
450
+ p = 2.0 * p - 1.0 # change to the range [-1, 1]
451
+ for freq in self.freq_bands:
452
+ out.append(torch.sin(freq * p))
453
+ out.append(torch.cos(freq * p))
454
+ p = torch.cat(out, dim=2)
455
+ return p
456
+
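+ def _example_positional_encoding():
+     # Illustrative sketch: with the default 'sin_cos' basis and L = 10
+     # frequency bands, each input channel becomes 2 * 10 channels, so 3-D
+     # points are lifted to 60-D features (matching the nn.Linear(60, ...)
+     # used below when pos_encoding='sin_cos'); batch sizes are placeholders.
+     pe = positional_encoding(basis_function='sin_cos')
+     out = pe(torch.rand(2, 512, 3))
+     assert out.shape == (2, 512, 60)
+     return out
+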
457
+ # Resnet Blocks
458
+ class ResnetBlockFC(nn.Module):
459
+ ''' Fully connected ResNet Block class.
460
+
461
+ Args:
462
+ size_in (int): input dimension
463
+ size_out (int): output dimension
464
+ size_h (int): hidden dimension
465
+ '''
466
+
467
+ def __init__(self, size_in, size_out=None, size_h=None):
468
+ super().__init__()
469
+ # Attributes
470
+ if size_out is None:
471
+ size_out = size_in
472
+
473
+ if size_h is None:
474
+ size_h = min(size_in, size_out)
475
+
476
+ self.size_in = size_in
477
+ self.size_h = size_h
478
+ self.size_out = size_out
479
+ # Submodules
480
+ self.fc_0 = nn.Linear(size_in, size_h)
481
+ self.fc_1 = nn.Linear(size_h, size_out)
482
+ self.actvn = nn.ReLU()
483
+
484
+ if size_in == size_out:
485
+ self.shortcut = None
486
+ else:
487
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
488
+ # Initialization
489
+ nn.init.zeros_(self.fc_1.weight)
490
+
491
+ def forward(self, x):
492
+ net = self.fc_0(self.actvn(x))
493
+ dx = self.fc_1(self.actvn(net))
494
+
495
+ if self.shortcut is not None:
496
+ x_s = self.shortcut(x)
497
+ else:
498
+ x_s = x
499
+
500
+ return x_s + dx
501
+
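+ def _example_resnet_block_fc():
+     # Illustrative sketch of ResnetBlockFC: a fully connected residual block
+     # mapping 512-D per-point features to 256-D (sizes are placeholders).
+     block = ResnetBlockFC(size_in=512, size_out=256)
+     y = block(torch.rand(8, 100, 512))
+     assert y.shape == (8, 100, 256)
+     return y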
502
+
503
+
504
+ '''
505
+ ------------------ the key model for Pointnet ----------------------------
506
+ '''
507
+
508
+
509
+ class LocalSoftSplat(nn.Module):
510
+
511
+ def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max',
512
+ unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
513
+ hw=None, grid_resolution=None, plane_type='xz', padding=0.1,
514
+ n_blocks=4, splat_func=None):
515
+ super().__init__()
516
+ c_dim = ch
517
+
518
+ self.c_dim = c_dim
519
+
520
+ self.fc_pos = nn.Linear(dim, 2*hidden_dim)
521
+ self.blocks = nn.ModuleList([
522
+ ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
523
+ ])
524
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
525
+
526
+ self.actvn = nn.ReLU()
527
+ self.hidden_dim = hidden_dim
528
+
529
+ if unet:
530
+ self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
531
+ else:
532
+ self.unet = None
533
+
534
+ # get splat func
535
+ self.splat_func = splat_func
536
+
+ def forward(self, img_feat,
537
+ Fxy2xz, Fxy2yz, Dz, gridxy=None):
538
+ """
539
+ Args:
540
+ img_feat (tensor): image features
541
+ Fxy2xz (tensor): flow field that splats xy-plane features onto the xz plane
542
+ Fxy2yz (tensor): flow field that splats xy-plane features onto the yz plane
543
+ """
544
+ B, T, _, H, W = img_feat.shape
545
+ fea_reshp = rearrange(img_feat, 'b t c h w -> (b h w) t c',
546
+ c=img_feat.shape[2], h=H, w=W)
547
+
548
+ gridyz = gridxy + Fxy2yz
549
+ gridxz = gridxy + Fxy2xz
550
+ # normalize
551
+ gridyz[:, 0, ...] = (gridyz[:, 0, ...] / (H - 1) - 0.5) * 2
552
+ gridyz[:, 1, ...] = (gridyz[:, 1, ...] / (Dz - 1) - 0.5) * 2
553
+ gridxz[:, 0, ...] = (gridxz[:, 0, ...] / (W - 1) - 0.5) * 2
554
+ gridxz[:, 1, ...] = (gridxz[:, 1, ...] / (Dz - 1) - 0.5) * 2
555
+ if len(self.blocks) > 0:
556
+ net = self.fc_pos(fea_reshp)
557
+ net = self.blocks[0](net)
558
+ for block in self.blocks[1:]:
559
+ # splat and fusion
560
+ net_plane = rearrange(net, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W)
561
+
562
+ net_planeYZ = self.splat_func(net_plane, Fxy2yz, None,
563
+ strMode="avg", tenoutH=Dz, tenoutW=H)
564
+
565
+ net_planeXZ = self.splat_func(net_plane, Fxy2xz, None,
566
+ strMode="avg", tenoutH=Dz, tenoutW=W)
567
+
568
+ net_plane = net_plane + (
569
+ F.grid_sample(
570
+ net_planeYZ, gridyz.permute(0,2,3,1), mode='bilinear', padding_mode='border') +
571
+ F.grid_sample(
572
+ net_planeXZ, gridxz.permute(0,2,3,1), mode='bilinear', padding_mode='border')
573
+ )
574
+
575
+ pooled = rearrange(net_plane, 't c h w -> (h w) t c',
576
+ c=net_plane.shape[1], h=H, w=W)
577
+
578
+ net = torch.cat([net, pooled], dim=2)
579
+ net = block(net)
580
+
581
+ c = self.fc_c(net)
582
+ net_plane = rearrange(c, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W)
583
+ else:
584
+ net_plane = rearrange(img_feat, 'b t c h w -> (b t) c h w',
585
+ c=img_feat.shape[2], h=H, w=W)
586
+ net_planeYZ = self.splat_func(net_plane, Fxy2yz, None,
587
+ strMode="avg", tenoutH=Dz, tenoutW=H)
588
+ net_planeXZ = self.splat_func(net_plane, Fxy2xz, None,
589
+ strMode="avg", tenoutH=Dz, tenoutW=W)
590
+
591
+ return net_plane[None], net_planeYZ[None], net_planeXZ[None]
592
+
593
+
594
+
595
+ class LocalPoolPointnet(nn.Module):
596
+ ''' PointNet-based encoder network with ResNet blocks for each point.
597
+ The number of input points is fixed.
598
+
599
+ Args:
600
+ c_dim (int): dimension of latent code c
601
+ dim (int): input points dimension
602
+ hidden_dim (int): hidden dimension of the network
603
+ scatter_type (str): feature aggregation when doing local pooling
604
+ unet (bool): whether to use U-Net
605
+ unet_kwargs (str): U-Net parameters
606
+ unet3d (bool): whether to use 3D U-Net
607
+ unet3d_kwargs (str): 3D U-Net parameters
608
+ plane_resolution (int): defined resolution for plane feature
609
+ grid_resolution (int): defined resolution for grid feature
610
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
611
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
612
+ n_blocks (int): number of ResnetBlockFC layers
613
+ '''
614
+
615
+ def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max',
616
+ unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
617
+ hw=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5):
618
+ super().__init__()
619
+ c_dim = ch
620
+ unet3d = False
621
+ plane_type = ['xy', 'xz', 'yz']
622
+ plane_resolution = hw
623
+
624
+ self.c_dim = c_dim
625
+
626
+ self.fc_pos = nn.Linear(dim, 2*hidden_dim)
627
+ self.blocks = nn.ModuleList([
628
+ ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
629
+ ])
630
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
631
+
632
+ self.actvn = nn.ReLU()
633
+ self.hidden_dim = hidden_dim
634
+
635
+ if unet:
636
+ self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
637
+ else:
638
+ self.unet = None
639
+
640
+ if unet3d:
641
+ # self.unet3d = UNet3D(**unet3d_kwargs)
642
+ raise NotImplementedError()
643
+ else:
644
+ self.unet3d = None
645
+
646
+ self.reso_plane = plane_resolution
647
+ self.reso_grid = grid_resolution
648
+ self.plane_type = plane_type
649
+ self.padding = padding
650
+
651
+ if scatter_type == 'max':
652
+ self.scatter = scatter_max
653
+ elif scatter_type == 'mean':
654
+ self.scatter = scatter_mean
655
+ else:
656
+ raise ValueError('incorrect scatter type')
657
+
658
+ def generate_plane_features(self, p, c, plane='xz'):
659
+ # acquire indices of features in plane
660
+ xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
661
+ index = coordinate2index(xy, self.reso_plane)
662
+
663
+ # scatter plane features from points
664
+ fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
665
+ c = c.permute(0, 2, 1) # B x 512 x T
666
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
667
+ fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x c_dim x reso x reso)
668
+
669
+ # process the plane features with UNet
670
+ if self.unet is not None:
671
+ fea_plane = self.unet(fea_plane)
672
+
673
+ return fea_plane
674
+
675
+ def generate_grid_features(self, p, c):
676
+ p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding)
677
+ index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
678
+ # scatter grid features from points
679
+ fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
680
+ c = c.permute(0, 2, 1)
681
+ fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
682
+ fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparse grid volume (B x c_dim x reso x reso x reso)
683
+
684
+ if self.unet3d is not None:
685
+ fea_grid = self.unet3d(fea_grid)
686
+
687
+ return fea_grid
688
+
689
+ def pool_local(self, xy, index, c):
690
+ bs, fea_dim = c.size(0), c.size(2)
691
+ keys = xy.keys()
692
+
693
+ c_out = 0
694
+ for key in keys:
695
+ # scatter plane features from points
696
+ if key == 'grid':
697
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
698
+ else:
699
+ c_permute = c.permute(0, 2, 1)
700
+ fea = self.scatter(c_permute, index[key], dim_size=self.reso_plane**2)
701
+ if self.scatter == scatter_max:
702
+ fea = fea[0]
703
+ # gather feature back to points
704
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
705
+ c_out = c_out + fea
706
+ return c_out.permute(0, 2, 1)
707
+
708
+
709
+ def forward(self, p_input, img_feats=None):
710
+ """
711
+ Args:
712
+ p_input (tensor): input points T 3 H W
713
+ img_feats (tensor): image features T C H W
714
+ """
715
+ T, _, H, W = img_feats.size()
716
+ p = rearrange(p_input, 't c h w -> (h w) t c', c=3, h=H, w=W)
717
+ fea_reshp = rearrange(img_feats, 't c h w -> (h w) t c',
718
+ c=img_feats.shape[1], h=H, w=W)
719
+
720
+ # acquire the index for each point
721
+ coord = {}
722
+ index = {}
723
+ if 'xz' in self.plane_type:
724
+ coord['xz'] = normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
725
+ index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
726
+ if 'xy' in self.plane_type:
727
+ coord['xy'] = normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
728
+ index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
729
+ if 'yz' in self.plane_type:
730
+ coord['yz'] = normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
731
+ index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
732
+ if 'grid' in self.plane_type:
733
+ coord['grid'] = normalize_3d_coordinate(p.clone(), padding=self.padding)
734
+ index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
735
+
736
+ net = self.fc_pos(p) + fea_reshp
737
+ net = self.blocks[0](net)
738
+ for block in self.blocks[1:]:
739
+ pooled = self.pool_local(coord, index, net)
740
+ net = torch.cat([net, pooled], dim=2)
741
+ net = block(net)
742
+
743
+ c = self.fc_c(net)
744
+
745
+ fea = {}
746
+
747
+ if 'grid' in self.plane_type:
748
+ fea['grid'] = self.generate_grid_features(p, c)
749
+ if 'xz' in self.plane_type:
750
+ fea['xz'] = self.generate_plane_features(p, c, plane='xz')
751
+ if 'xy' in self.plane_type:
752
+ fea['xy'] = self.generate_plane_features(p, c, plane='xy')
753
+ if 'yz' in self.plane_type:
754
+ fea['yz'] = self.generate_plane_features(p, c, plane='yz')
755
+
756
+ ret = torch.stack([fea['xy'], fea['xz'], fea['yz']]).permute((1, 0, 2, 3, 4))
757
+ return ret
758
+
759
+ class PatchLocalPoolPointnet(nn.Module):
760
+ ''' PointNet-based encoder network with ResNet blocks.
761
+ First transform input points to local system based on the given voxel size.
762
+ Supports a non-fixed number of input points, but the indices need to be precomputed
763
+
764
+ Args:
765
+ c_dim (int): dimension of latent code c
766
+ dim (int): input points dimension
767
+ hidden_dim (int): hidden dimension of the network
768
+ scatter_type (str): feature aggregation when doing local pooling
769
+ unet (bool): whether to use U-Net
770
+ unet_kwargs (str): U-Net parameters
771
+ unet3d (bool): whether to use 3D U-Net
772
+ unet3d_kwargs (str): 3D U-Net parameters
773
+ plane_resolution (int): defined resolution for plane feature
774
+ grid_resolution (int): defined resolution for grid feature
775
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
776
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
777
+ n_blocks (int): number of ResnetBlockFC layers
778
+ local_coord (bool): whether to use local coordinate
779
+ pos_encoding (str): method for the positional encoding, linear|sin_cos
780
+ unit_size (float): defined voxel unit size for local system
781
+ '''
782
+
783
+ def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
784
+ unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
785
+ plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
786
+ local_coord=False, pos_encoding='linear', unit_size=0.1):
787
+ super().__init__()
788
+ self.c_dim = c_dim
789
+
790
+ self.blocks = nn.ModuleList([
791
+ ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
792
+ ])
793
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
794
+
795
+ self.actvn = nn.ReLU()
796
+ self.hidden_dim = hidden_dim
797
+ self.reso_plane = plane_resolution
798
+ self.reso_grid = grid_resolution
799
+ self.plane_type = plane_type
800
+ self.padding = padding
801
+
802
+ if unet:
803
+ self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
804
+ else:
805
+ self.unet = None
806
+
807
+ if unet3d:
808
+ # self.unet3d = UNet3D(**unet3d_kwargs)
809
+ raise NotImplementedError()
810
+ else:
811
+ self.unet3d = None
812
+
813
+ if scatter_type == 'max':
814
+ self.scatter = scatter_max
815
+ elif scatter_type == 'mean':
816
+ self.scatter = scatter_mean
817
+ else:
818
+ raise ValueError('incorrect scatter type')
819
+
820
+ if local_coord:
821
+ self.map2local = map2local(unit_size, pos_encoding=pos_encoding)
822
+ else:
823
+ self.map2local = None
824
+
825
+ if pos_encoding == 'sin_cos':
826
+ self.fc_pos = nn.Linear(60, 2*hidden_dim)
827
+ else:
828
+ self.fc_pos = nn.Linear(dim, 2*hidden_dim)
829
+
830
+ def generate_plane_features(self, index, c):
831
+ c = c.permute(0, 2, 1)
832
+ # scatter plane features from points
833
+ if index.max() < self.reso_plane**2:
834
+ fea_plane = c.new_zeros(c.size(0), self.c_dim, self.reso_plane**2)
835
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x c_dim x reso^2
836
+ else:
837
+ fea_plane = scatter_mean(c, index) # B x c_dim x reso^2
838
+ if fea_plane.shape[-1] > self.reso_plane**2: # deal with outliers
839
+ fea_plane = fea_plane[:, :, :-1]
840
+
841
+ fea_plane = fea_plane.reshape(c.size(0), self.c_dim, self.reso_plane, self.reso_plane)
842
+
843
+ # process the plane features with UNet
844
+ if self.unet is not None:
845
+ fea_plane = self.unet(fea_plane)
846
+
847
+ return fea_plane
848
+
849
+ def generate_grid_features(self, index, c):
850
+ # scatter grid features from points
851
+ c = c.permute(0, 2, 1)
852
+ if index.max() < self.reso_grid**3:
853
+ fea_grid = c.new_zeros(c.size(0), self.c_dim, self.reso_grid**3)
854
+ fea_grid = scatter_mean(c, index, out=fea_grid) # B x c_dim x reso^3
855
+ else:
856
+ fea_grid = scatter_mean(c, index) # B x c_dim x reso^3
857
+ if fea_grid.shape[-1] > self.reso_grid**3: # deal with outliers
858
+ fea_grid = fea_grid[:, :, :-1]
859
+ fea_grid = fea_grid.reshape(c.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid)
860
+
861
+ if self.unet3d is not None:
862
+ fea_grid = self.unet3d(fea_grid)
863
+
864
+ return fea_grid
865
+
866
+ def pool_local(self, index, c):
867
+ bs, fea_dim = c.size(0), c.size(2)
868
+ keys = index.keys()
869
+
870
+ c_out = 0
871
+ for key in keys:
872
+ # scatter plane features from points
873
+ if key == 'grid':
874
+ fea = self.scatter(c.permute(0, 2, 1), index[key])
875
+ else:
876
+ fea = self.scatter(c.permute(0, 2, 1), index[key])
877
+ if self.scatter == scatter_max:
878
+ fea = fea[0]
879
+ # gather feature back to points
880
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
881
+ c_out += fea
882
+ return c_out.permute(0, 2, 1)
883
+
884
+
885
+ def forward(self, inputs):
886
+ p = inputs['points']
887
+ index = inputs['index']
888
+
889
+ batch_size, T, D = p.size()
890
+
891
+ if self.map2local:
892
+ pp = self.map2local(p)
893
+ net = self.fc_pos(pp)
894
+ else:
895
+ net = self.fc_pos(p)
896
+
897
+ net = self.blocks[0](net)
898
+ for block in self.blocks[1:]:
899
+ pooled = self.pool_local(index, net)
900
+ net = torch.cat([net, pooled], dim=2)
901
+ net = block(net)
902
+
903
+ c = self.fc_c(net)
904
+
905
+ fea = {}
906
+ if 'grid' in self.plane_type:
907
+ fea['grid'] = self.generate_grid_features(index['grid'], c)
908
+ if 'xz' in self.plane_type:
909
+ fea['xz'] = self.generate_plane_features(index['xz'], c)
910
+ if 'xy' in self.plane_type:
911
+ fea['xy'] = self.generate_plane_features(index['xy'], c)
912
+ if 'yz' in self.plane_type:
913
+ fea['yz'] = self.generate_plane_features(index['yz'], c)
914
+
915
+ return fea
models/spatracker/models/core/spatracker/loftr/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .transformer import LocalFeatureTransformer
models/spatracker/models/core/spatracker/loftr/linear_attention.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
3
+ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
4
+ """
5
+
6
+ import torch
7
+ from torch.nn import Module, Dropout
8
+
9
+
10
+ def elu_feature_map(x):
11
+ return torch.nn.functional.elu(x) + 1
12
+
13
+
14
+ class LinearAttention(Module):
15
+ def __init__(self, eps=1e-6):
16
+ super().__init__()
17
+ self.feature_map = elu_feature_map
18
+ self.eps = eps
19
+
20
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
21
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
22
+ Args:
23
+ queries: [N, L, H, D]
24
+ keys: [N, S, H, D]
25
+ values: [N, S, H, D]
26
+ q_mask: [N, L]
27
+ kv_mask: [N, S]
28
+ Returns:
29
+ queried_values: (N, L, H, D)
30
+ """
31
+ Q = self.feature_map(queries)
32
+ K = self.feature_map(keys)
33
+
34
+ # set padded position to zero
35
+ if q_mask is not None:
36
+ Q = Q * q_mask[:, :, None, None]
37
+ if kv_mask is not None:
38
+ K = K * kv_mask[:, :, None, None]
39
+ values = values * kv_mask[:, :, None, None]
40
+
41
+ v_length = values.size(1)
42
+ values = values / v_length # prevent fp16 overflow
43
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
44
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
45
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
46
+
47
+ return queried_values.contiguous()
48
+
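+ def _example_linear_attention():
+     # Illustrative sketch of LinearAttention with dummy features:
+     # N=2 sequences, L=100 queries, S=120 keys/values, H=4 heads,
+     # D=32 channels per head (all sizes are placeholders).
+     attn = LinearAttention()
+     q = torch.rand(2, 100, 4, 32)
+     k = torch.rand(2, 120, 4, 32)
+     v = torch.rand(2, 120, 4, 32)
+     out = attn(q, k, v)
+     assert out.shape == (2, 100, 4, 32)
+     return out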
49
+
50
+ class FullAttention(Module):
51
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
52
+ super().__init__()
53
+ self.use_dropout = use_dropout
54
+ self.dropout = Dropout(attention_dropout)
55
+
56
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
57
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
58
+ Args:
59
+ queries: [N, L, H, D]
60
+ keys: [N, S, H, D]
61
+ values: [N, S, H, D]
62
+ q_mask: [N, L]
63
+ kv_mask: [N, S]
64
+ Returns:
65
+ queried_values: (N, L, H, D)
66
+ """
67
+
68
+ # Compute the unnormalized attention and apply the masks
69
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
70
+ if kv_mask is not None:
71
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
72
+
73
+ # Compute the attention and the weighted average
74
+ softmax_temp = 1. / queries.size(3)**.5 # 1 / sqrt(D)
75
+ A = torch.softmax(softmax_temp * QK, dim=2)
76
+ if self.use_dropout:
77
+ A = self.dropout(A)
78
+
79
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
80
+
81
+ return queried_values.contiguous()
models/spatracker/models/core/spatracker/loftr/transformer.py ADDED
@@ -0,0 +1,142 @@
1
+ '''
2
+ modified from
3
+ https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
4
+ '''
5
+ import torch
6
+ from torch.nn import Module, Dropout
7
+ import copy
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def elu_feature_map(x):
13
+ return torch.nn.functional.elu(x) + 1
14
+
15
+ class FullAttention(Module):
16
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
17
+ super().__init__()
18
+ self.use_dropout = use_dropout
19
+ self.dropout = Dropout(attention_dropout)
20
+
21
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
22
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
23
+ Args:
24
+ queries: [N, L, H, D]
25
+ keys: [N, S, H, D]
26
+ values: [N, S, H, D]
27
+ q_mask: [N, L]
28
+ kv_mask: [N, S]
29
+ Returns:
30
+ queried_values: (N, L, H, D)
31
+ """
32
+
33
+ # Compute the unnormalized attention and apply the masks
34
+ # QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
35
+ # if kv_mask is not None:
36
+ # QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float(-1e12))
37
+ # softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
38
+ # A = torch.softmax(softmax_temp * QK, dim=2)
39
+ # if self.use_dropout:
40
+ # A = self.dropout(A)
41
+ # queried_values_ = torch.einsum("nlsh,nshd->nlhd", A, values)
42
+
43
+ # Compute the attention and the weighted average
44
+ input_args = [x.half().contiguous() for x in [queries.permute(0,2,1,3), keys.permute(0,2,1,3), values.permute(0,2,1,3)]]
45
+ queried_values = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).float() # type: ignore
46
+
47
+
48
+ return queried_values.contiguous()
49
+
50
+ class TransformerEncoderLayer(nn.Module):
51
+ def __init__(self,
52
+ d_model,
53
+ nhead,):
54
+ super(TransformerEncoderLayer, self).__init__()
55
+
56
+ self.dim = d_model // nhead
57
+ self.nhead = nhead
58
+
59
+ # multi-head attention
60
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
61
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
62
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
63
+ self.attention = FullAttention()
64
+ self.merge = nn.Linear(d_model, d_model, bias=False)
65
+
66
+ # feed-forward network
67
+ self.mlp = nn.Sequential(
68
+ nn.Linear(d_model*2, d_model*2, bias=False),
69
+ nn.ReLU(True),
70
+ nn.Linear(d_model*2, d_model, bias=False),
71
+ )
72
+
73
+ # norm and dropout
74
+ self.norm1 = nn.LayerNorm(d_model)
75
+ self.norm2 = nn.LayerNorm(d_model)
76
+
77
+ def forward(self, x, source, x_mask=None, source_mask=None):
78
+ """
79
+ Args:
80
+ x (torch.Tensor): [N, L, C]
81
+ source (torch.Tensor): [N, S, C]
82
+ x_mask (torch.Tensor): [N, L] (optional)
83
+ source_mask (torch.Tensor): [N, S] (optional)
84
+ """
85
+ bs = x.size(0)
86
+ query, key, value = x, source, source
87
+
88
+ # multi-head attention
89
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
90
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
91
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
92
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
93
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
94
+ message = self.norm1(message)
95
+
96
+ # feed-forward network
97
+ message = self.mlp(torch.cat([x, message], dim=2))
98
+ message = self.norm2(message)
99
+
100
+ return x + message
101
+
102
+ class LocalFeatureTransformer(nn.Module):
103
+ """A Local Feature Transformer module."""
104
+
105
+ def __init__(self, config):
106
+ super(LocalFeatureTransformer, self).__init__()
107
+
108
+ self.config = config
109
+ self.d_model = config['d_model']
110
+ self.nhead = config['nhead']
111
+ self.layer_names = config['layer_names']
112
+ encoder_layer = TransformerEncoderLayer(config['d_model'], config['nhead'])
113
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
114
+ self._reset_parameters()
115
+
116
+ def _reset_parameters(self):
117
+ for p in self.parameters():
118
+ if p.dim() > 1:
119
+ nn.init.xavier_uniform_(p)
120
+
121
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
122
+ """
123
+ Args:
124
+ feat0 (torch.Tensor): [N, L, C]
125
+ feat1 (torch.Tensor): [N, S, C]
126
+ mask0 (torch.Tensor): [N, L] (optional)
127
+ mask1 (torch.Tensor): [N, S] (optional)
128
+ """
129
+
130
+ assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
131
+
132
+ for layer, name in zip(self.layers, self.layer_names):
133
+ if name == 'self':
134
+ feat0 = layer(feat0, feat0, mask0, mask0)
135
+ feat1 = layer(feat1, feat1, mask1, mask1)
136
+ elif name == 'cross':
137
+ feat0 = layer(feat0, feat1, mask0, mask1)
138
+ feat1 = layer(feat1, feat0, mask1, mask0)
139
+ else:
140
+ raise KeyError
141
+
142
+ return feat0, feat1
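+
+ def _example_local_feature_transformer():
+     # Illustrative sketch of the config dict expected by the constructor
+     # above; the concrete values are placeholders.  FullAttention casts to
+     # half precision, so running the forward pass end to end is best done
+     # on a CUDA device.
+     config = {
+         'd_model': 256,
+         'nhead': 8,
+         'layer_names': ['self', 'cross', 'self', 'cross'],
+     }
+     model = LocalFeatureTransformer(config)
+     feat0, feat1 = torch.rand(1, 100, 256), torch.rand(1, 120, 256)
+     return model, feat0, feat1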
models/spatracker/models/core/spatracker/losses.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from models.spatracker.models.core.model_utils import reduce_masked_mean
10
+ from models.spatracker.models.core.spatracker.blocks import (
11
+ pix2cam
12
+ )
13
+ from models.spatracker.models.core.model_utils import (
14
+ bilinear_sample2d
15
+ )
16
+
17
+ EPS = 1e-6
18
+ import torchvision.transforms.functional as TF
19
+
20
+ sigma = 3
21
+ x_grid = torch.arange(-7,8,1)
22
+ y_grid = torch.arange(-7,8,1)
23
+ x_grid, y_grid = torch.meshgrid(x_grid, y_grid, indexing='ij')
24
+ gridxy = torch.stack([x_grid, y_grid], dim=-1).float()
25
+ gs_kernel = torch.exp(-torch.sum(gridxy**2, dim=-1)/(2*sigma**2))
26
+
27
+
28
+ def balanced_ce_loss(pred, gt, valid=None):
29
+ total_balanced_loss = 0.0
30
+ for j in range(len(gt)):
31
+ B, S, N = gt[j].shape
32
+ # pred and gt are the same shape
33
+ for (a, b) in zip(pred[j].size(), gt[j].size()):
34
+ assert a == b # some shape mismatch!
35
+ # if valid is not None:
36
+ for (a, b) in zip(pred[j].size(), valid[j].size()):
37
+ assert a == b # some shape mismatch!
38
+
39
+ pos = (gt[j] > 0.95).float()
40
+ neg = (gt[j] < 0.05).float()
41
+
42
+ label = pos * 2.0 - 1.0
43
+ a = -label * pred[j]
44
+ b = F.relu(a)
45
+ loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
46
+
47
+ pos_loss = reduce_masked_mean(loss, pos * valid[j])
48
+ neg_loss = reduce_masked_mean(loss, neg * valid[j])
49
+ balanced_loss = pos_loss + neg_loss
50
+ total_balanced_loss += balanced_loss / float(N)
51
52
+ return total_balanced_loss
53
+
54
+
55
+ def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8,
56
+ intr=None, trajs_g_all=None):
57
+ """Loss function defined over sequence of flow predictions"""
58
+ total_flow_loss = 0.0
59
+
60
+ for j in range(len(flow_gt)):
61
+ B, S, N, D = flow_gt[j].shape
62
+ # assert D == 3
63
+ B, S1, N = vis[j].shape
64
+ B, S2, N = valids[j].shape
65
+ assert S == S1
66
+ assert S == S2
67
+ n_predictions = len(flow_preds[j])
68
+ if intr is not None:
69
+ intr_i = intr[j]
70
+ flow_loss = 0.0
71
+ for i in range(n_predictions):
72
+ i_weight = gamma ** (n_predictions - i - 1)
73
+ flow_pred = flow_preds[j][i][..., -N:, :D]
74
+ flow_gt_j = flow_gt[j].clone()
75
+ if intr is not None:
76
+ xyz_j_gt = pix2cam(flow_gt_j, intr_i)
77
+ i_loss = (flow_pred - flow_gt_j).abs() # B, S, N, D
81
+ if D==3:
82
+ i_loss[...,2]*=30
83
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
84
+ flow_loss += i_weight * (reduce_masked_mean(i_loss, valids[j]))
85
+
86
+ flow_loss = flow_loss / n_predictions
87
+ total_flow_loss += flow_loss / float(N)
88
+
89
+
90
+ return total_flow_loss
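+
+ def _example_sequence_loss_weights(n_predictions=4, gamma=0.8):
+     # Worked example of the exponential weighting used in sequence_loss:
+     # later refinement iterations get larger weights, e.g. for 4 predictions
+     # and gamma=0.8 the weights are [0.512, 0.64, 0.8, 1.0].
+     return [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]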
models/spatracker/models/core/spatracker/softsplat.py ADDED
@@ -0,0 +1,539 @@
1
+ #!/usr/bin/env python
2
+
3
+ """The code of softsplat function is modified from:
4
+ https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py
5
+
6
+ """
7
+
8
+
9
+ import collections
10
+ import cupy
11
+ import os
12
+ import re
13
+ import torch
14
+ import typing
15
+
16
+
17
+ ##########################################################
18
+
19
+
20
+ objCudacache = {}
21
+
22
+
23
+ def cuda_int32(intIn:int):
24
+ return cupy.int32(intIn)
25
+ # end
26
+
27
+
28
+ def cuda_float32(fltIn:float):
29
+ return cupy.float32(fltIn)
30
+ # end
31
+
32
+
33
+ def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
34
+ if 'device' not in objCudacache:
35
+ objCudacache['device'] = torch.cuda.get_device_name()
36
+ # end
37
+
38
+ strKey = strFunction
39
+
40
+ for strVariable in objVariables:
41
+ objValue = objVariables[strVariable]
42
+
43
+ strKey += strVariable
44
+
45
+ if objValue is None:
46
+ continue
47
+
48
+ elif type(objValue) == int:
49
+ strKey += str(objValue)
50
+
51
+ elif type(objValue) == float:
52
+ strKey += str(objValue)
53
+
54
+ elif type(objValue) == bool:
55
+ strKey += str(objValue)
56
+
57
+ elif type(objValue) == str:
58
+ strKey += objValue
59
+
60
+ elif type(objValue) == torch.Tensor:
61
+ strKey += str(objValue.dtype)
62
+ strKey += str(objValue.shape)
63
+ strKey += str(objValue.stride())
64
+
65
+ elif True:
66
+ print(strVariable, type(objValue))
67
+ assert(False)
68
+
69
+ # end
70
+ # end
71
+
72
+ strKey += objCudacache['device']
73
+
74
+ if strKey not in objCudacache:
75
+ for strVariable in objVariables:
76
+ objValue = objVariables[strVariable]
77
+
78
+ if objValue is None:
79
+ continue
80
+
81
+ elif type(objValue) == int:
82
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
83
+
84
+ elif type(objValue) == float:
85
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
86
+
87
+ elif type(objValue) == bool:
88
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
89
+
90
+ elif type(objValue) == str:
91
+ strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
92
+
93
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
94
+ strKernel = strKernel.replace('{{type}}', 'unsigned char')
95
+
96
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
97
+ strKernel = strKernel.replace('{{type}}', 'half')
98
+
99
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
100
+ strKernel = strKernel.replace('{{type}}', 'float')
101
+
102
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
103
+ strKernel = strKernel.replace('{{type}}', 'double')
104
+
105
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
106
+ strKernel = strKernel.replace('{{type}}', 'int')
107
+
108
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
109
+ strKernel = strKernel.replace('{{type}}', 'long')
110
+
111
+ elif type(objValue) == torch.Tensor:
112
+ print(strVariable, objValue.dtype)
113
+ assert(False)
114
+
115
+ elif True:
116
+ print(strVariable, type(objValue))
117
+ assert(False)
118
+
119
+ # end
120
+ # end
121
+
122
+ while True:
123
+ objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
124
+
125
+ if objMatch is None:
126
+ break
127
+ # end
128
+
129
+ intArg = int(objMatch.group(2))
130
+
131
+ strTensor = objMatch.group(4)
132
+ intSizes = objVariables[strTensor].size()
133
+
134
+ strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
135
+ # end
136
+
137
+ while True:
138
+ objMatch = re.search(r'(OFFSET_)([0-4])(\()', strKernel)
139
+
140
+ if objMatch is None:
141
+ break
142
+ # end
143
+
144
+ intStart = objMatch.span()[1]
145
+ intStop = objMatch.span()[1]
146
+ intParentheses = 1
147
+
148
+ while True:
149
+ intParentheses += 1 if strKernel[intStop] == '(' else 0
150
+ intParentheses -= 1 if strKernel[intStop] == ')' else 0
151
+
152
+ if intParentheses == 0:
153
+ break
154
+ # end
155
+
156
+ intStop += 1
157
+ # end
158
+
159
+ intArgs = int(objMatch.group(2))
160
+ strArgs = strKernel[intStart:intStop].split(',')
161
+
162
+ assert(intArgs == len(strArgs) - 1)
163
+
164
+ strTensor = strArgs[0]
165
+ intStrides = objVariables[strTensor].stride()
166
+
167
+ strIndex = []
168
+
169
+ for intArg in range(intArgs):
170
+ strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
171
+ # end
172
+
173
+ strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
174
+ # end
175
+
176
+ while True:
177
+ objMatch = re.search(r'(VALUE_)([0-4])(\()', strKernel)
178
+
179
+ if objMatch is None:
180
+ break
181
+ # end
182
+
183
+ intStart = objMatch.span()[1]
184
+ intStop = objMatch.span()[1]
185
+ intParentheses = 1
186
+
187
+ while True:
188
+ intParentheses += 1 if strKernel[intStop] == '(' else 0
189
+ intParentheses -= 1 if strKernel[intStop] == ')' else 0
190
+
191
+ if intParentheses == 0:
192
+ break
193
+ # end
194
+
195
+ intStop += 1
196
+ # end
197
+
198
+ intArgs = int(objMatch.group(2))
199
+ strArgs = strKernel[intStart:intStop].split(',')
200
+
201
+ assert(intArgs == len(strArgs) - 1)
202
+
203
+ strTensor = strArgs[0]
204
+ intStrides = objVariables[strTensor].stride()
205
+
206
+ strIndex = []
207
+
208
+ for intArg in range(intArgs):
209
+ strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
210
+ # end
211
+
212
+ strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
213
+ # end
214
+
215
+ objCudacache[strKey] = {
216
+ 'strFunction': strFunction,
217
+ 'strKernel': strKernel
218
+ }
219
+ # end
220
+
221
+ return strKey
222
+ # end
223
+
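+ # For reference, the preprocessing above specialises the kernel source for the
+ # concrete tensors passed in: for a contiguous float32 tensor of shape
+ # (N, C, H, W), SIZE_1(tenIn) is replaced by the literal C, and
+ # VALUE_4(tenIn, intN, intC, intY, intX) is rewritten to the strided read
+ # tenIn[((intN)*C*H*W)+((intC)*H*W)+((intY)*W)+((intX)*1)], so the generated
+ # CUDA code bakes the tensor sizes and strides in as constants.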
224
+
225
+ @cupy.memoize(for_each_device=True)
226
+ def cuda_launch(strKey:str):
227
+ if 'CUDA_HOME' not in os.environ:
228
+ os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
229
+ # end
230
+
231
+ return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'])
232
+ # end
233
+
234
+
235
+ ##########################################################
236
+
237
+
238
+ def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor,
239
+ tenMetric:torch.Tensor, strMode:str, tenoutH=None, tenoutW=None):
240
+ assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
241
+
242
+ if strMode == 'sum': assert(tenMetric is None)
243
+ if strMode == 'avg': assert(tenMetric is None)
244
+ if strMode.split('-')[0] == 'linear': assert(tenMetric is not None)
245
+ if strMode.split('-')[0] == 'soft': assert(tenMetric is not None)
246
+
247
+ if strMode == 'avg':
248
+ tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)
249
+
250
+ elif strMode.split('-')[0] == 'linear':
251
+ tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
252
+
253
+ elif strMode.split('-')[0] == 'soft':
254
+ tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)
255
+
256
+ # end
257
+
258
+ tenOut = softsplat_func.apply(tenIn, tenFlow, tenoutH, tenoutW)
259
+
260
+ if strMode.split('-')[0] in ['avg', 'linear', 'soft']:
261
+ tenNormalize = tenOut[:, -1:, :, :]
262
+
263
+ if len(strMode.split('-')) == 1:
264
+ tenNormalize = tenNormalize + 0.0000001
265
+
266
+ elif strMode.split('-')[1] == 'addeps':
267
+ tenNormalize = tenNormalize + 0.0000001
268
+
269
+ elif strMode.split('-')[1] == 'zeroeps':
270
+ tenNormalize[tenNormalize == 0.0] = 1.0
271
+
272
+ elif strMode.split('-')[1] == 'clipeps':
273
+ tenNormalize = tenNormalize.clip(0.0000001, None)
274
+
275
+ # end
276
+ tenOut = tenOut[:, :-1, :, :] / tenNormalize
277
+ # end
278
+
279
+ return tenOut
280
+ # end
281
+
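+ # Illustrative call (requires a CUDA device and cupy; the shapes below are
+ # placeholders, not values used elsewhere):
+ #   feat = torch.rand(1, 64, 48, 64, device='cuda')
+ #   flow = torch.zeros(1, 2, 48, 64, device='cuda')
+ #   out = softsplat(feat, flow, None, strMode='avg')          # (1, 64, 48, 64)
+ #   out2 = softsplat(feat, flow, None, strMode='avg', tenoutH=96, tenoutW=128)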
282
+
283
+ class softsplat_func(torch.autograd.Function):
284
+ @staticmethod
285
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
286
+ def forward(self, tenIn, tenFlow, H=None, W=None):
287
+ if H is None:
288
+ tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
289
+ else:
290
+ tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W])
291
+
292
+ if tenIn.is_cuda == True:
293
+ cuda_launch(cuda_kernel('softsplat_out', '''
294
+ extern "C" __global__ void __launch_bounds__(512) softsplat_out(
295
+ const int n,
296
+ const {{type}}* __restrict__ tenIn,
297
+ const {{type}}* __restrict__ tenFlow,
298
+ {{type}}* __restrict__ tenOut
299
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
300
+ const int intN = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn);
301
+ const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) ) % SIZE_1(tenIn);
302
+ const int intY = ( intIndex / SIZE_3(tenIn) ) % SIZE_2(tenIn);
303
+ const int intX = ( intIndex ) % SIZE_3(tenIn);
304
+
305
+ assert(SIZE_1(tenFlow) == 2);
306
+
307
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
308
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
309
+
310
+ if (isfinite(fltX) == false) { return; }
311
+ if (isfinite(fltY) == false) { return; }
312
+
313
+ {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
314
+
315
+ int intNorthwestX = (int) (floor(fltX));
316
+ int intNorthwestY = (int) (floor(fltY));
317
+ int intNortheastX = intNorthwestX + 1;
318
+ int intNortheastY = intNorthwestY;
319
+ int intSouthwestX = intNorthwestX;
320
+ int intSouthwestY = intNorthwestY + 1;
321
+ int intSoutheastX = intNorthwestX + 1;
322
+ int intSoutheastY = intNorthwestY + 1;
323
+
324
+ {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
325
+ {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
326
+ {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
327
+ {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
328
+
329
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
330
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
331
+ }
332
+
333
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
334
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
335
+ }
336
+
337
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
338
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
339
+ }
340
+
341
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
342
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
343
+ }
344
+ } }
345
+ ''', {
346
+ 'tenIn': tenIn,
347
+ 'tenFlow': tenFlow,
348
+ 'tenOut': tenOut
349
+ }))(
350
+ grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]),
351
+ block=tuple([512, 1, 1]),
352
+ args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
353
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
354
+ )
355
+
356
+ elif tenIn.is_cuda != True:
357
+ assert(False)
358
+
359
+ # end
360
+
361
+ self.save_for_backward(tenIn, tenFlow)
362
+
363
+ return tenOut
364
+ # end
365
+
366
+ @staticmethod
367
+ @torch.cuda.amp.custom_bwd
368
+ def backward(self, tenOutgrad):
369
+ tenIn, tenFlow = self.saved_tensors
370
+
371
+ tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)
372
+
373
+ tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
374
+ tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None
375
+ Hgrad = None
376
+ Wgrad = None
377
+
378
+ if tenIngrad is not None:
379
+ cuda_launch(cuda_kernel('softsplat_ingrad', '''
380
+ extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
381
+ const int n,
382
+ const {{type}}* __restrict__ tenIn,
383
+ const {{type}}* __restrict__ tenFlow,
384
+ const {{type}}* __restrict__ tenOutgrad,
385
+ {{type}}* __restrict__ tenIngrad,
386
+ {{type}}* __restrict__ tenFlowgrad
387
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
388
+ const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
389
+ const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
390
+ const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
391
+ const int intX = ( intIndex ) % SIZE_3(tenIngrad);
392
+
393
+ assert(SIZE_1(tenFlow) == 2);
394
+
395
+ {{type}} fltIngrad = 0.0f;
396
+
397
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
398
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
399
+
400
+ if (isfinite(fltX) == false) { return; }
401
+ if (isfinite(fltY) == false) { return; }
402
+
403
+ int intNorthwestX = (int) (floor(fltX));
404
+ int intNorthwestY = (int) (floor(fltY));
405
+ int intNortheastX = intNorthwestX + 1;
406
+ int intNortheastY = intNorthwestY;
407
+ int intSouthwestX = intNorthwestX;
408
+ int intSouthwestY = intNorthwestY + 1;
409
+ int intSoutheastX = intNorthwestX + 1;
410
+ int intSoutheastY = intNorthwestY + 1;
411
+
412
+ {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
413
+ {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
414
+ {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
415
+ {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
416
+
417
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
418
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
419
+ }
420
+
421
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
422
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
423
+ }
424
+
425
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
426
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
427
+ }
428
+
429
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
430
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
431
+ }
432
+
433
+ tenIngrad[intIndex] = fltIngrad;
434
+ } }
435
+ ''', {
436
+ 'tenIn': tenIn,
437
+ 'tenFlow': tenFlow,
438
+ 'tenOutgrad': tenOutgrad,
439
+ 'tenIngrad': tenIngrad,
440
+ 'tenFlowgrad': tenFlowgrad
441
+ }))(
442
+ grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
443
+ block=tuple([512, 1, 1]),
444
+ args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
445
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
446
+ )
447
+ # end
448
+
449
+ if tenFlowgrad is not None:
450
+ cuda_launch(cuda_kernel('softsplat_flowgrad', '''
451
+ extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
452
+ const int n,
453
+ const {{type}}* __restrict__ tenIn,
454
+ const {{type}}* __restrict__ tenFlow,
455
+ const {{type}}* __restrict__ tenOutgrad,
456
+ {{type}}* __restrict__ tenIngrad,
457
+ {{type}}* __restrict__ tenFlowgrad
458
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
459
+ const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
460
+ const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
461
+ const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
462
+ const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
463
+
464
+ assert(SIZE_1(tenFlow) == 2);
465
+
466
+ {{type}} fltFlowgrad = 0.0f;
467
+
468
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
469
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
470
+
471
+ if (isfinite(fltX) == false) { return; }
472
+ if (isfinite(fltY) == false) { return; }
473
+
474
+ int intNorthwestX = (int) (floor(fltX));
475
+ int intNorthwestY = (int) (floor(fltY));
476
+ int intNortheastX = intNorthwestX + 1;
477
+ int intNortheastY = intNorthwestY;
478
+ int intSouthwestX = intNorthwestX;
479
+ int intSouthwestY = intNorthwestY + 1;
480
+ int intSoutheastX = intNorthwestX + 1;
481
+ int intSoutheastY = intNorthwestY + 1;
482
+
483
+ {{type}} fltNorthwest = 0.0f;
484
+ {{type}} fltNortheast = 0.0f;
485
+ {{type}} fltSouthwest = 0.0f;
486
+ {{type}} fltSoutheast = 0.0f;
487
+
488
+ if (intC == 0) {
489
+ fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
490
+ fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
491
+ fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
492
+ fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
493
+
494
+ } else if (intC == 1) {
495
+ fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
496
+ fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
497
+ fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
498
+ fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
499
+
500
+ }
501
+
502
+ for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
503
+ {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
504
+
505
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
506
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
507
+ }
508
+
509
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
510
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
511
+ }
512
+
513
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
514
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
515
+ }
516
+
517
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
518
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
519
+ }
520
+ }
521
+
522
+ tenFlowgrad[intIndex] = fltFlowgrad;
523
+ } }
524
+ ''', {
525
+ 'tenIn': tenIn,
526
+ 'tenFlow': tenFlow,
527
+ 'tenOutgrad': tenOutgrad,
528
+ 'tenIngrad': tenIngrad,
529
+ 'tenFlowgrad': tenFlowgrad
530
+ }))(
531
+ grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
532
+ block=tuple([512, 1, 1]),
533
+ args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
534
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
535
+ )
536
+ # end
537
+ return tenIngrad, tenFlowgrad, Hgrad, Wgrad
538
+ # end
539
+ # end
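For context, a minimal usage sketch of the softsplat operator defined in this file, mirroring the call made in spatracker.py later in this commit; the tensor shapes here are illustrative assumptions, and a CUDA device is required since the kernels are compiled and launched on the GPU at call time.

import torch
from models.spatracker.models.core.spatracker.softsplat import softsplat

feat = torch.randn(8, 128, 32, 32, device="cuda")  # (S, C, H, W) features to forward-splat
flow = torch.randn(8, 2, 32, 32, device="cuda")    # per-pixel (dx, dy) displacement field
# "avg" mode normalizes the accumulated features by the accumulated weights; tenoutH/tenoutW
# let the output plane differ from the input plane, as used for the tri-plane splatting below.
warped = softsplat(feat, flow, None, strMode="avg", tenoutH=32, tenoutW=32)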
models/spatracker/models/core/spatracker/spatracker.py ADDED
@@ -0,0 +1,732 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from easydict import EasyDict as edict
10
+ from einops import rearrange
11
+ from sklearn.cluster import SpectralClustering
12
+ from models.spatracker.models.core.spatracker.blocks import Lie
13
+ import matplotlib.pyplot as plt
14
+ import cv2
15
+
16
+ import torch.nn.functional as F
17
+ from models.spatracker.models.core.spatracker.blocks import (
18
+ BasicEncoder,
19
+ CorrBlock,
20
+ EUpdateFormer,
21
+ FusionFormer,
22
+ pix2cam,
23
+ cam2pix,
24
+ edgeMat,
25
+ VitEncoder,
26
+ DPTEnc,
27
+ Dinov2
28
+ )
29
+
30
+ from models.spatracker.models.core.spatracker.feature_net import (
31
+ LocalSoftSplat
32
+ )
33
+
34
+ from models.spatracker.models.core.model_utils import (
35
+ meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA
36
+ )
37
+ from models.spatracker.models.core.embeddings import (
38
+ get_2d_embedding,
39
+ get_3d_embedding,
40
+ get_1d_sincos_pos_embed_from_grid,
41
+ get_2d_sincos_pos_embed,
42
+ get_3d_sincos_pos_embed_from_grid,
43
+ Embedder_Fourier,
44
+ )
45
+ import numpy as np
46
+ from models.spatracker.models.core.spatracker.softsplat import softsplat
47
+
48
+ torch.manual_seed(0)
49
+
50
+
51
+ def get_points_on_a_grid(grid_size, interp_shape,
52
+ grid_center=(0, 0), device="cuda"):
53
+ if grid_size == 1:
54
+ return torch.tensor([interp_shape[1] / 2,
55
+ interp_shape[0] / 2], device=device)[
56
+ None, None
57
+ ]
58
+
59
+ grid_y, grid_x = meshgrid2d(
60
+ 1, grid_size, grid_size, stack=False, norm=False, device=device
61
+ )
62
+ step = interp_shape[1] // 64
63
+ if grid_center[0] != 0 or grid_center[1] != 0:
64
+ grid_y = grid_y - grid_size / 2.0
65
+ grid_x = grid_x - grid_size / 2.0
66
+ grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
67
+ interp_shape[0] - step * 2
68
+ )
69
+ grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
70
+ interp_shape[1] - step * 2
71
+ )
72
+
73
+ grid_y = grid_y + grid_center[0]
74
+ grid_x = grid_x + grid_center[1]
75
+ xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
76
+ return xy
77
+
78
+
79
+ def sample_pos_embed(grid_size, embed_dim, coords):
80
+ if coords.shape[-1] == 2:
81
+ pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim,
82
+ grid_size=grid_size)
83
+ pos_embed = (
84
+ torch.from_numpy(pos_embed)
85
+ .reshape(grid_size[0], grid_size[1], embed_dim)
86
+ .float()
87
+ .unsqueeze(0)
88
+ .to(coords.device)
89
+ )
90
+ sampled_pos_embed = bilinear_sample2d(
91
+ pos_embed.permute(0, 3, 1, 2),
92
+ coords[:, 0, :, 0], coords[:, 0, :, 1]
93
+ )
94
+ elif coords.shape[-1] == 3:
95
+ sampled_pos_embed = get_3d_sincos_pos_embed_from_grid(
96
+ embed_dim, coords[:, :1, ...]
97
+ ).float()[:,0,...].permute(0, 2, 1)
98
+
99
+ return sampled_pos_embed
100
+
101
+
102
+ class SpaTracker(nn.Module):
103
+ def __init__(
104
+ self,
105
+ S=8,
106
+ stride=8,
107
+ add_space_attn=True,
108
+ num_heads=8,
109
+ hidden_size=384,
110
+ space_depth=12,
111
+ time_depth=12,
112
+ args=edict({})
113
+ ):
114
+ super(SpaTracker, self).__init__()
115
+
116
+ # step1: config the arch of the model
117
+ self.args=args
118
+ # step1.1: config the default value of the model
119
+ if getattr(args, "depth_color", None) == None:
120
+ self.args.depth_color = False
121
+ if getattr(args, "if_ARAP", None) == None:
122
+ self.args.if_ARAP = True
123
+ if getattr(args, "flash_attn", None) == None:
124
+ self.args.flash_attn = True
125
+ if getattr(args, "backbone", None) == None:
126
+ self.args.backbone = "CNN"
127
+ if getattr(args, "Nblock", None) == None:
128
+ self.args.Nblock = 0
129
+ if getattr(args, "Embed3D", None) == None:
130
+ self.args.Embed3D = True
131
+
132
+ # step1.2: config the model parameters
133
+ self.S = S
134
+ self.stride = stride
135
+ self.hidden_dim = 256
136
+ self.latent_dim = latent_dim = 128
137
+ self.b_latent_dim = self.latent_dim//3
138
+ self.corr_levels = 4
139
+ self.corr_radius = 3
140
+ self.add_space_attn = add_space_attn
141
+ self.lie = Lie()
142
+
143
+ # step2: config the model components
144
+ # @Encoder
145
+ self.fnet = BasicEncoder(input_dim=3,
146
+ output_dim=self.latent_dim, norm_fn="instance", dropout=0,
147
+ stride=stride, Embed3D=False
148
+ )
149
+
150
+ # conv head for the tri-plane features
151
+ self.headyz = nn.Sequential(
152
+ nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
153
+ nn.ReLU(inplace=True),
154
+ nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
155
+
156
+ self.headxz = nn.Sequential(
157
+ nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
158
+ nn.ReLU(inplace=True),
159
+ nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
160
+
161
+ # @UpdateFormer
162
+ self.updateformer = EUpdateFormer(
163
+ space_depth=space_depth,
164
+ time_depth=time_depth,
165
+ input_dim=456,
166
+ hidden_size=hidden_size,
167
+ num_heads=num_heads,
168
+ output_dim=latent_dim + 3,
169
+ mlp_ratio=4.0,
170
+ add_space_attn=add_space_attn,
171
+ flash=getattr(self.args, "flash_attn", True)
172
+ )
173
+ self.support_features = torch.zeros(100, 384).to("cuda") + 0.1
174
+
175
+ self.norm = nn.GroupNorm(1, self.latent_dim)
176
+
177
+ self.ffeat_updater = nn.Sequential(
178
+ nn.Linear(self.latent_dim, self.latent_dim),
179
+ nn.GELU(),
180
+ )
181
+ self.ffeatyz_updater = nn.Sequential(
182
+ nn.Linear(self.latent_dim, self.latent_dim),
183
+ nn.GELU(),
184
+ )
185
+ self.ffeatxz_updater = nn.Sequential(
186
+ nn.Linear(self.latent_dim, self.latent_dim),
187
+ nn.GELU(),
188
+ )
189
+
190
+ #TODO @NeuralArap: optimize the arap
191
+ self.embed_traj = Embedder_Fourier(
192
+ input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True
193
+ )
194
+ self.embed3d = Embedder_Fourier(
195
+ input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True
196
+ )
197
+ self.embedConv = nn.Conv2d(self.latent_dim+63,
198
+ self.latent_dim, 3, padding=1)
199
+
200
+ # @Vis_predictor
201
+ self.vis_predictor = nn.Sequential(
202
+ nn.Linear(128, 1),
203
+ )
204
+
205
+ self.embedProj = nn.Linear(63, 456)
206
+ self.zeroMLPflow = nn.Linear(195, 130)
207
+
208
+ def prepare_track(self, rgbds, queries):
209
+ """
210
+ NOTE:
211
+ Normalize the rgbs and sort the queries by the time at which they first appear
212
+ Args:
213
+ rgbds: the input rgbd images (B T 4 H W)
214
+ queries: the input queries (B N 4)
215
+ Return:
216
+ rgbds: the normalized rgbds (B T 4 H W)
217
+ queries: the sorted queries (B N 4)
218
+ track_mask:
219
+ """
220
+ assert (rgbds.shape[2]==4) and (queries.shape[2]==4)
221
+ #Step1: normalize the rgbs input
222
+ device = rgbds.device
223
+ rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0
224
+ B, T, C, H, W = rgbds.shape
225
+ B, N, __ = queries.shape
226
+ self.traj_e = torch.zeros((B, T, N, 3), device=device)
227
+ self.vis_e = torch.zeros((B, T, N), device=device)
228
+
229
+ #Step2: sort the points by the time at which they first appear
230
+ first_positive_inds = queries[0, :, 0].long()
231
+ __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False)
232
+ inv_sort_inds = torch.argsort(sort_inds, dim=0)
233
+ first_positive_sorted_inds = first_positive_inds[sort_inds]
234
+ # check that the sorting can be inverted
235
+ assert torch.allclose(
236
+ first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds]
237
+ )
238
+
239
+ # filter out points that never appear during frames 1..T
240
+ ind_array = torch.arange(T, device=device)
241
+ ind_array = ind_array[None, :, None].repeat(B, 1, N)
242
+ track_mask = (ind_array >=
243
+ first_positive_inds[None, None, :]).unsqueeze(-1)
244
+
245
+ # scale the coords_init
246
+ coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat(
247
+ 1, self.S, 1, 1
248
+ )
249
+ coords_init[..., :2] /= float(self.stride)
250
+
251
+ #Step3: initialize the regular grid
252
+ gridx = torch.linspace(0, W//self.stride - 1, W//self.stride)
253
+ gridy = torch.linspace(0, H//self.stride - 1, H//self.stride)
254
+ gridx, gridy = torch.meshgrid(gridx, gridy)
255
+ gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
256
+ 2, 1, 0
257
+ )
258
+ vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
259
+
260
+ # Step4: initialize the trajectories for neural ARAP
261
+ T_series = torch.linspace(0, 5, T).reshape(1, T, 1 , 1).cuda() # 1 T 1 1
262
+ T_series = T_series.repeat(B, 1, N, 1)
263
+ # get the 3d traj in the camera coordinates
264
+ intr_init = self.intrs[:,queries[0,:,0].long()]
265
+ Traj_series = pix2cam(queries[:,:,None,1:].double(), intr_init.double())
266
+ #torch.inverse(intr_init.double())@queries[:,:,1:,None].double() # B N 3 1
267
+ Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float()
268
+ Traj_series = torch.cat([T_series, Traj_series], dim=-1)
269
+ # get the indicator for the neural arap
270
+ Traj_mask = -1e2*torch.ones_like(T_series)
271
+ Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1)
272
+
273
+ return (
274
+ rgbds,
275
+ first_positive_inds,
276
+ first_positive_sorted_inds,
277
+ sort_inds, inv_sort_inds,
278
+ track_mask, gridxy, coords_init[..., sort_inds, :].clone(),
279
+ vis_init, Traj_series[..., sort_inds, :].clone()
280
+ )
281
+
282
+ def sample_trifeat(self, t,
283
+ coords,
284
+ featMapxy,
285
+ featMapyz,
286
+ featMapxz):
287
+ """
288
+ Sample the features from the 5D triplane feature map 3*(B S C H W)
289
+ Args:
290
+ t: the time index
291
+ coords: the coordinates of the points B S N 3
292
+ featMapxy: the feature map B S C Hx Wy
293
+ featMapyz: the feature map B S C Hy Wz
294
+ featMapxz: the feature map B S C Hx Wz
295
+ """
296
+ # get xy_t yz_t xz_t
297
+ queried_t = t.reshape(1, 1, -1, 1)
298
+ xy_t = torch.cat(
299
+ [queried_t, coords[..., [0,1]]],
300
+ dim=-1
301
+ )
302
+ yz_t = torch.cat(
303
+ [queried_t, coords[..., [1, 2]]],
304
+ dim=-1
305
+ )
306
+ xz_t = torch.cat(
307
+ [queried_t, coords[..., [0, 2]]],
308
+ dim=-1
309
+ )
310
+ featxy_init = sample_features5d(featMapxy, xy_t)
311
+
312
+ featyz_init = sample_features5d(featMapyz, yz_t)
313
+ featxz_init = sample_features5d(featMapxz, xz_t)
314
+
315
+ featxy_init = featxy_init.repeat(1, self.S, 1, 1)
316
+ featyz_init = featyz_init.repeat(1, self.S, 1, 1)
317
+ featxz_init = featxz_init.repeat(1, self.S, 1, 1)
318
+
319
+ return featxy_init, featyz_init, featxz_init
320
+
321
+ def neural_arap(self, coords, Traj_arap, intrs_S, T_mark):
322
+ """ calculate the ARAP embedding and offset
323
+ Args:
324
+ coords: the coordinates of the current points 1 S N' 3
325
+ Traj_arap: the trajectory of the points 1 T N' 5
326
+ intrs_S: the camera intrinsics B S 3 3
327
+
328
+ """
329
+ coords_out = coords.clone()
330
+ coords_out[..., :2] *= float(self.stride)
331
+ coords_out[..., 2] = coords_out[..., 2]/self.Dz
332
+ coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near
333
+ intrs_S = intrs_S[:, :, None, ...].repeat(1, 1, coords_out.shape[2], 1, 1)
334
+ B, S, N, D = coords_out.shape
335
+ if S != intrs_S.shape[1]:
336
+ intrs_S = torch.cat(
337
+ [intrs_S, intrs_S[:, -1:].repeat(1, S - intrs_S.shape[1],1,1,1)], dim=1
338
+ )
339
+ T_mark = torch.cat(
340
+ [T_mark, T_mark[:, -1:].repeat(1, S - T_mark.shape[1],1)], dim=1
341
+ )
342
+ xyz_ = pix2cam(coords_out.double(), intrs_S.double()[:,:,0])
343
+ xyz_ = xyz_.float()
344
+ xyz_embed = torch.cat([T_mark[...,None], xyz_,
345
+ torch.zeros_like(T_mark[...,None])], dim=-1)
346
+
347
+ xyz_embed = self.embed_traj(xyz_embed)
348
+ Traj_arap_embed = self.embed_traj(Traj_arap)
349
+ d_xyz,traj_feat = self.arapFormer(xyz_embed, Traj_arap_embed)
350
+ # update in camera coordinate
351
+ xyz_ = xyz_ + d_xyz.clamp(-5, 5)
352
+ # project back to the image plane
353
+ coords_out = cam2pix(xyz_.double(), intrs_S[:,:,0].double()).float()
354
+ # resize back
355
+ coords_out[..., :2] /= float(self.stride)
356
+ coords_out[..., 2] = (coords_out[..., 2] - self.d_near)/(self.d_far-self.d_near)
357
+ coords_out[..., 2] *= self.Dz
358
+
359
+ return xyz_, coords_out, traj_feat
360
+
361
+ def gradient_arap(self, coords, aff_avg=None, aff_std=None, aff_f_sg=None,
362
+ iter=0, iter_num=4, neigh_idx=None, intr=None, msk_track=None):
363
+ with torch.enable_grad():
364
+ coords.requires_grad_(True)
365
+ y = self.ARAP_ln(coords, aff_f_sg=aff_f_sg, neigh_idx=neigh_idx,
366
+ iter=iter, iter_num=iter_num, intr=intr,msk_track=msk_track)
367
+ d_output = torch.ones_like(y, requires_grad=False, device=y.device)
368
+ gradients = torch.autograd.grad(
369
+ outputs=y,
370
+ inputs=coords,
371
+ grad_outputs=d_output,
372
+ create_graph=True,
373
+ retain_graph=True,
374
+ only_inputs=True, allow_unused=True)[0]
375
+
376
+ return gradients.detach()
377
+
378
+ def forward_iteration(
379
+ self,
380
+ fmapXY,
381
+ fmapYZ,
382
+ fmapXZ,
383
+ coords_init,
384
+ feat_init=None,
385
+ vis_init=None,
386
+ track_mask=None,
387
+ iters=4,
388
+ intrs_S=None,
389
+ ):
390
+ B, S_init, N, D = coords_init.shape
391
+ assert D == 3
392
+ assert B == 1
393
+ B, S, __, H8, W8 = fmapXY.shape
394
+ device = fmapXY.device
395
+
396
+ if S_init < S:
397
+ coords = torch.cat(
398
+ [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)],
399
+ dim=1
400
+ )
401
+ vis_init = torch.cat(
402
+ [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
403
+ )
404
+ intrs_S = torch.cat(
405
+ [intrs_S, intrs_S[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
406
+ )
407
+ else:
408
+ coords = coords_init.clone()
409
+
410
+ fcorr_fnXY = CorrBlock(
411
+ fmapXY, num_levels=self.corr_levels, radius=self.corr_radius
412
+ )
413
+ fcorr_fnYZ = CorrBlock(
414
+ fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius
415
+ )
416
+ fcorr_fnXZ = CorrBlock(
417
+ fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius
418
+ )
419
+
420
+ ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1)
421
+ ffeats = [f.squeeze(-1) for f in ffeats]
422
+
423
+ times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
424
+ pos_embed = sample_pos_embed(
425
+ grid_size=(H8, W8),
426
+ embed_dim=456,
427
+ coords=coords[..., :2],
428
+ )
429
+ pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
430
+
431
+ times_embed = (
432
+ torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
433
+ .repeat(B, 1, 1)
434
+ .float()
435
+ .to(device)
436
+ )
437
+ coord_predictions = []
438
+ attn_predictions = []
439
+ Rot_ln = 0
440
+ support_feat = self.support_features
441
+
442
+ for __ in range(iters):
443
+ coords = coords.detach()
444
+ # if self.args.if_ARAP == True:
445
+ # # refine the track with arap
446
+ # xyz_pred, coords, flows_cat0 = self.neural_arap(coords.detach(),
447
+ # Traj_arap.detach(),
448
+ # intrs_S, T_mark)
449
+ with torch.no_grad():
450
+ fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2])
451
+ fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1,2]])
452
+ fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0,2]])
453
+ # fcorrs = fcorrsXY
454
+ fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ
455
+ LRR = fcorrs.shape[3]
456
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
457
+
458
+ flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3)
459
+ flows_cat = get_3d_embedding(flows_, 64, cat_coords=True)
460
+ flows_cat = self.zeroMLPflow(flows_cat)
461
+
462
+
463
+ ffeats_xy = ffeats[0].permute(0,
464
+ 2, 1, 3).reshape(B * N, S, self.latent_dim)
465
+ ffeats_yz = ffeats[1].permute(0,
466
+ 2, 1, 3).reshape(B * N, S, self.latent_dim)
467
+ ffeats_xz = ffeats[2].permute(0,
468
+ 2, 1, 3).reshape(B * N, S, self.latent_dim)
469
+ ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz
470
+
471
+ if track_mask.shape[1] < vis_init.shape[1]:
472
+ track_mask = torch.cat(
473
+ [
474
+ track_mask,
475
+ torch.zeros_like(track_mask[:, 0]).repeat(
476
+ 1, vis_init.shape[1] - track_mask.shape[1], 1, 1
477
+ ),
478
+ ],
479
+ dim=1,
480
+ )
481
+ concat = (
482
+ torch.cat([track_mask, vis_init], dim=2)
483
+ .permute(0, 2, 1, 3)
484
+ .reshape(B * N, S, 2)
485
+ )
486
+
487
+ transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
488
+
489
+ if transformer_input.shape[-1] < pos_embed.shape[-1]:
490
+ # padding the transformer_input to the same dimension as pos_embed
491
+ transformer_input = F.pad(
492
+ transformer_input, (0, pos_embed.shape[-1] - transformer_input.shape[-1]),
493
+ "constant", 0
494
+ )
495
+
496
+ x = transformer_input + pos_embed + times_embed
497
+ x = rearrange(x, "(b n) t d -> b n t d", b=B)
498
+
499
+ delta, AttnMap, so3_dist, delta_se3F, so3 = self.updateformer(x, support_feat)
500
+ support_feat = support_feat + delta_se3F[0]/100
501
+ delta = rearrange(delta, " b n t d -> (b n) t d")
502
+ d_coord = delta[:, :, :3]
503
+ d_feats = delta[:, :, 3:]
504
+
505
+ ffeats_xy = self.ffeat_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xy.reshape(-1, self.latent_dim)
506
+ ffeats_yz = self.ffeatyz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_yz.reshape(-1, self.latent_dim)
507
+ ffeats_xz = self.ffeatxz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xz.reshape(-1, self.latent_dim)
508
+ ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute(
509
+ 0, 2, 1, 3
510
+ ) # B,S,N,C
511
+ ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute(
512
+ 0, 2, 1, 3
513
+ ) # B,S,N,C
514
+ ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute(
515
+ 0, 2, 1, 3
516
+ ) # B,S,N,C
517
+ coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3)
518
+ if torch.isnan(coords).any():
519
+ import ipdb; ipdb.set_trace()
520
+
521
+ coords_out = coords.clone()
522
+ coords_out[..., :2] *= float(self.stride)
523
+
524
+ coords_out[..., 2] = coords_out[..., 2]/self.Dz
525
+ coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near
526
+
527
+ coord_predictions.append(coords_out)
528
+ attn_predictions.append(AttnMap)
529
+
530
+ ffeats_f = ffeats[0] + ffeats[1] + ffeats[2]
531
+ vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape(
532
+ B, S, N
533
+ )
534
+ self.support_features = support_feat.detach()
535
+ return coord_predictions, attn_predictions, vis_e, feat_init, Rot_ln
536
+
537
+
538
+ def forward(self, rgbds, queries, iters=4, feat_init=None,
539
+ is_train=False, intrs=None, wind_S=None):
540
+ self.support_features = torch.zeros(100, 384).to("cuda") + 0.1
541
+ self.is_train=is_train
542
+ B, T, C, H, W = rgbds.shape
543
+ # set the intrinsics, or initialize a default from the image size
544
+ if intrs is None:
545
+ intrs = torch.from_numpy(np.array([[W, 0.0, W//2],
546
+ [0.0, W, H//2],
547
+ [0.0, 0.0, 1.0]]))
548
+ intrs = intrs[None,
549
+ None,...].repeat(B, T, 1, 1).float().to(rgbds.device)
550
+ self.intrs = intrs
551
+
552
+ # prepare the input for tracking
553
+ (
554
+ rgbds,
555
+ first_positive_inds,
556
+ first_positive_sorted_inds, sort_inds,
557
+ inv_sort_inds, track_mask, gridxy,
558
+ coords_init, vis_init, Traj_arap
559
+ ) = self.prepare_track(rgbds.clone(), queries)
560
+ coords_init_ = coords_init.clone()
561
+ vis_init_ = vis_init[:, :, sort_inds].clone()
562
+
563
+ depth_all = rgbds[:, :, 3,...]
564
+ d_near = self.d_near = depth_all[depth_all>0.01].min().item()
565
+ d_far = self.d_far = depth_all[depth_all>0.01].max().item()
566
+
567
+ if wind_S is not None:
568
+ self.S = wind_S
569
+
570
+ B, N, __ = queries.shape
571
+ self.Dz = Dz = W//self.stride
572
+ w_idx_start = 0
573
+ p_idx_end = 0
574
+ p_idx_start = 0
575
+ fmaps_ = None
576
+ vis_predictions = []
577
+ coord_predictions = []
578
+ attn_predictions = []
579
+ p_idx_end_list = []
580
+ Rigid_ln_total = 0
581
+ while w_idx_start < T - self.S // 2:
582
+ curr_wind_points = torch.nonzero(
583
+ first_positive_sorted_inds < w_idx_start + self.S)
584
+ if curr_wind_points.shape[0] == 0:
585
+ w_idx_start = w_idx_start + self.S // 2
586
+ continue
587
+ p_idx_end = curr_wind_points[-1] + 1
588
+ p_idx_end_list.append(p_idx_end)
589
+ # T may not be divisible by self.S
590
+ rgbds_seq = rgbds[:, w_idx_start:w_idx_start + self.S].clone()
591
+ S = S_local = rgbds_seq.shape[1]
592
+ if S < self.S:
593
+ rgbds_seq = torch.cat(
594
+ [rgbds_seq,
595
+ rgbds_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
596
+ dim=1,
597
+ )
598
+ S = rgbds_seq.shape[1]
599
+
600
+ rgbs_ = rgbds_seq.reshape(B * S, C, H, W)[:, :3]
601
+ depths = rgbds_seq.reshape(B * S, C, H, W)[:, 3:].clone()
602
+ # open the mask
603
+ # Traj_arap[:, w_idx_start:w_idx_start + self.S, :p_idx_end, -1] = 0
604
+ #step1: normalize the depth map
605
+
606
+ depths = (depths - d_near)/(d_far-d_near)
607
+ depths_dn = nn.functional.interpolate(
608
+ depths, scale_factor=1.0 / self.stride, mode="nearest")
609
+ depths_dnG = depths_dn*Dz
610
+
611
+ #step2: normalize the coordinate
612
+ coords_init_[:, :, p_idx_start:p_idx_end, 2] = (
613
+ coords_init[:, :, p_idx_start:p_idx_end, 2] - d_near
614
+ )/(d_far-d_near)
615
+ coords_init_[:, :, p_idx_start:p_idx_end, 2] *= Dz
616
+
617
+ # efficient triplane splatting
618
+ gridxyz = torch.cat([gridxy[None,...].repeat(
619
+ depths_dn.shape[0],1,1,1), depths_dnG], dim=1)
620
+ Fxy2yz = gridxyz[:,[1, 2], ...] - gridxyz[:,:2]
621
+ Fxy2xz = gridxyz[:,[0, 2], ...] - gridxyz[:,:2]
622
+ if getattr(self.args, "Embed3D", None) == True:
623
+ gridxyz_nm = gridxyz.clone()
624
+ gridxyz_nm[:,0,...] = (gridxyz_nm[:,0,...]-gridxyz_nm[:,0,...].min())/(gridxyz_nm[:,0,...].max()-gridxyz_nm[:,0,...].min())
625
+ gridxyz_nm[:,1,...] = (gridxyz_nm[:,1,...]-gridxyz_nm[:,1,...].min())/(gridxyz_nm[:,1,...].max()-gridxyz_nm[:,1,...].min())
626
+ gridxyz_nm[:,2,...] = (gridxyz_nm[:,2,...]-gridxyz_nm[:,2,...].min())/(gridxyz_nm[:,2,...].max()-gridxyz_nm[:,2,...].min())
627
+ gridxyz_nm = 2*(gridxyz_nm-0.5)
628
+ _,_,h4,w4 = gridxyz_nm.shape
629
+ gridxyz_nm = gridxyz_nm.permute(0,2,3,1).reshape(S*h4*w4, 3)
630
+ featPE = self.embed3d(gridxyz_nm).view(S, h4, w4, -1).permute(0,3,1,2)
631
+ if fmaps_ is None:
632
+ fmaps_ = torch.cat([self.fnet(rgbs_),featPE], dim=1)
633
+ fmaps_ = self.embedConv(fmaps_)
634
+ else:
635
+ fmaps_new = torch.cat([self.fnet(rgbs_[self.S // 2 :]),featPE[self.S // 2 :]], dim=1)
636
+ fmaps_new = self.embedConv(fmaps_new)
637
+ fmaps_ = torch.cat(
638
+ [fmaps_[self.S // 2 :], fmaps_new], dim=0
639
+ )
640
+ else:
641
+ if fmaps_ is None:
642
+ fmaps_ = self.fnet(rgbs_)
643
+ else:
644
+ fmaps_ = torch.cat(
645
+ [fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
646
+ )
647
+
648
+ fmapXY = fmaps_[:, :self.latent_dim].reshape(
649
+ B, S, self.latent_dim, H // self.stride, W // self.stride
650
+ )
651
+
652
+ fmapYZ = softsplat(fmapXY[0], Fxy2yz, None,
653
+ strMode="avg", tenoutH=self.Dz, tenoutW=H//self.stride)
654
+ fmapXZ = softsplat(fmapXY[0], Fxy2xz, None,
655
+ strMode="avg", tenoutH=self.Dz, tenoutW=W//self.stride)
656
+
657
+ fmapYZ = self.headyz(fmapYZ)[None, ...]
658
+ fmapXZ = self.headxz(fmapXZ)[None, ...]
659
+
660
+ if p_idx_end - p_idx_start > 0:
661
+ queried_t = (first_positive_sorted_inds[p_idx_start:p_idx_end]
662
+ - w_idx_start)
663
+ (featxy_init,
664
+ featyz_init,
665
+ featxz_init) = self.sample_trifeat(
666
+ t=queried_t,featMapxy=fmapXY,
667
+ featMapyz=fmapYZ,featMapxz=fmapXZ,
668
+ coords=coords_init_[:, :1, p_idx_start:p_idx_end]
669
+ )
670
+ # T, S, N, C, 3
671
+ feat_init_curr = torch.stack([featxy_init,
672
+ featyz_init, featxz_init], dim=-1)
673
+ feat_init = smart_cat(feat_init, feat_init_curr, dim=2)
674
+
675
+ if p_idx_start > 0:
676
+ # preprocess the coordinates of last windows
677
+ last_coords = coords[-1][:, self.S // 2 :].clone()
678
+ last_coords[..., :2] /= float(self.stride)
679
+ last_coords[..., 2:] = (last_coords[..., 2:]-d_near)/(d_far-d_near)
680
+ last_coords[..., 2:] = last_coords[..., 2:]*Dz
681
+
682
+ coords_init_[:, : self.S // 2, :p_idx_start] = last_coords
683
+ coords_init_[:, self.S // 2 :, :p_idx_start] = last_coords[
684
+ :, -1
685
+ ].repeat(1, self.S // 2, 1, 1)
686
+
687
+ last_vis = vis[:, self.S // 2 :].unsqueeze(-1)
688
+ vis_init_[:, : self.S // 2, :p_idx_start] = last_vis
689
+ vis_init_[:, self.S // 2 :, :p_idx_start] = last_vis[:, -1].repeat(
690
+ 1, self.S // 2, 1, 1
691
+ )
692
+
693
+ coords, attns, vis, __, Rigid_ln = self.forward_iteration(
694
+ fmapXY=fmapXY,
695
+ fmapYZ=fmapYZ,
696
+ fmapXZ=fmapXZ,
697
+ coords_init=coords_init_[:, :, :p_idx_end],
698
+ feat_init=feat_init[:, :, :p_idx_end],
699
+ vis_init=vis_init_[:, :, :p_idx_end],
700
+ track_mask=track_mask[:, w_idx_start : w_idx_start + self.S, :p_idx_end],
701
+ iters=iters,
702
+ intrs_S=self.intrs[:, w_idx_start : w_idx_start + self.S],
703
+ )
704
+
705
+ Rigid_ln_total+=Rigid_ln
706
+
707
+ if is_train:
708
+ vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
709
+ coord_predictions.append([coord[:, :S_local] for coord in coords])
710
+ attn_predictions.append(attns)
711
+
712
+ self.traj_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = coords[-1][:, :S_local]
713
+ self.vis_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = vis[:, :S_local]
714
+
715
+ track_mask[:, : w_idx_start + self.S, :p_idx_end] = 0.0
716
+ w_idx_start = w_idx_start + self.S // 2
717
+
718
+ p_idx_start = p_idx_end
719
+
720
+ self.traj_e = self.traj_e[:, :, inv_sort_inds]
721
+ self.vis_e = self.vis_e[:, :, inv_sort_inds]
722
+
723
+ self.vis_e = torch.sigmoid(self.vis_e)
724
+ train_data = (
725
+ (vis_predictions, coord_predictions, attn_predictions,
726
+ p_idx_end_list, sort_inds, Rigid_ln_total)
727
+ )
728
+ if self.is_train:
729
+ return self.traj_e, feat_init, self.vis_e, train_data
730
+ else:
731
+ return self.traj_e, feat_init, self.vis_e
732
+
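A hedged sketch of driving SpaTracker.forward directly; the shapes follow the prepare_track docstring above (rgbds: B T 4 H W RGB-D stack, queries: B N 4 as (t, x, y, z)), while the concrete sizes and random inputs are purely illustrative. In this commit the model is normally built and invoked through build_spatracker / SpaTrackerPredictor (see predictor.py below), and a CUDA device is required because several buffers are allocated on "cuda".

import torch
from easydict import EasyDict as edict
from models.spatracker.models.core.spatracker.spatracker import SpaTracker

model = SpaTracker(S=8, stride=8, args=edict({"flash_attn": False})).cuda().eval()
rgbds = torch.rand(1, 16, 4, 256, 256, device="cuda")    # B T 4 H W
rgbds[:, :, :3] *= 255.0                                  # RGB channels in [0, 255]
rgbds[:, :, 3] += 0.5                                     # depth channel kept > 0.01 for d_near/d_far
queries = torch.zeros(1, 64, 4, device="cuda")            # B N 4: (t, x, y, z) per query point
queries[:, :, 1:3] = torch.rand(1, 64, 2, device="cuda") * 255.0
queries[:, :, 3] = 0.7
with torch.no_grad():
    traj_e, feat_init, vis_e = model(rgbds, queries, iters=4)
# traj_e: (B, T, N, 3) estimated trajectories; vis_e: (B, T, N) visibilities in [0, 1]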
models/spatracker/models/core/spatracker/unet.py ADDED
@@ -0,0 +1,258 @@
1
+ '''
2
+ Codes are from:
3
+ https://github.com/jaxony/unet-pytorch/blob/master/model.py
4
+ '''
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.autograd import Variable
10
+ from collections import OrderedDict
11
+ from torch.nn import init
12
+ import numpy as np
13
+
14
+ def conv3x3(in_channels, out_channels, stride=1,
15
+ padding=1, bias=True, groups=1):
16
+ return nn.Conv2d(
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size=3,
20
+ stride=stride,
21
+ padding=padding,
22
+ bias=bias,
23
+ groups=groups)
24
+
25
+ def upconv2x2(in_channels, out_channels, mode='transpose'):
26
+ if mode == 'transpose':
27
+ return nn.ConvTranspose2d(
28
+ in_channels,
29
+ out_channels,
30
+ kernel_size=2,
31
+ stride=2)
32
+ else:
33
+ # out_channels is always going to be the same
34
+ # as in_channels
35
+ return nn.Sequential(
36
+ nn.Upsample(mode='bilinear', scale_factor=2),
37
+ conv1x1(in_channels, out_channels))
38
+
39
+ def conv1x1(in_channels, out_channels, groups=1):
40
+ return nn.Conv2d(
41
+ in_channels,
42
+ out_channels,
43
+ kernel_size=1,
44
+ groups=groups,
45
+ stride=1)
46
+
47
+
48
+ class DownConv(nn.Module):
49
+ """
50
+ A helper Module that performs 2 convolutions and 1 MaxPool.
51
+ A ReLU activation follows each convolution.
52
+ """
53
+ def __init__(self, in_channels, out_channels, pooling=True):
54
+ super(DownConv, self).__init__()
55
+
56
+ self.in_channels = in_channels
57
+ self.out_channels = out_channels
58
+ self.pooling = pooling
59
+
60
+ self.conv1 = conv3x3(self.in_channels, self.out_channels)
61
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
62
+
63
+ if self.pooling:
64
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
65
+
66
+ def forward(self, x):
67
+ x = F.relu(self.conv1(x))
68
+ x = F.relu(self.conv2(x))
69
+ before_pool = x
70
+ if self.pooling:
71
+ x = self.pool(x)
72
+ return x, before_pool
73
+
74
+
75
+ class UpConv(nn.Module):
76
+ """
77
+ A helper Module that performs 2 convolutions and 1 UpConvolution.
78
+ A ReLU activation follows each convolution.
79
+ """
80
+ def __init__(self, in_channels, out_channels,
81
+ merge_mode='concat', up_mode='transpose'):
82
+ super(UpConv, self).__init__()
83
+
84
+ self.in_channels = in_channels
85
+ self.out_channels = out_channels
86
+ self.merge_mode = merge_mode
87
+ self.up_mode = up_mode
88
+
89
+ self.upconv = upconv2x2(self.in_channels, self.out_channels,
90
+ mode=self.up_mode)
91
+
92
+ if self.merge_mode == 'concat':
93
+ self.conv1 = conv3x3(
94
+ 2*self.out_channels, self.out_channels)
95
+ else:
96
+ # num of input channels to conv2 is same
97
+ self.conv1 = conv3x3(self.out_channels, self.out_channels)
98
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
99
+
100
+
101
+ def forward(self, from_down, from_up):
102
+ """ Forward pass
103
+ Arguments:
104
+ from_down: tensor from the encoder pathway
105
+ from_up: upconv'd tensor from the decoder pathway
106
+ """
107
+ from_up = self.upconv(from_up)
108
+ if self.merge_mode == 'concat':
109
+ x = torch.cat((from_up, from_down), 1)
110
+ else:
111
+ x = from_up + from_down
112
+ x = F.relu(self.conv1(x))
113
+ x = F.relu(self.conv2(x))
114
+ return x
115
+
116
+
117
+ class UNet(nn.Module):
118
+ """ `UNet` class is based on https://arxiv.org/abs/1505.04597
119
+
120
+ The U-Net is a convolutional encoder-decoder neural network.
121
+ Contextual spatial information (from the decoding,
122
+ expansive pathway) about an input tensor is merged with
123
+ information representing the localization of details
124
+ (from the encoding, compressive pathway).
125
+
126
+ Modifications to the original paper:
127
+ (1) padding is used in 3x3 convolutions to prevent loss
128
+ of border pixels
129
+ (2) merging outputs does not require cropping due to (1)
130
+ (3) residual connections can be used by specifying
131
+ UNet(merge_mode='add')
132
+ (4) if non-parametric upsampling is used in the decoder
133
+ pathway (specified by upmode='upsample'), then an
134
+ additional 1x1 2d convolution occurs after upsampling
135
+ to reduce channel dimensionality by a factor of 2.
136
+ This channel halving happens with the convolution in
137
+ the transpose convolution (specified by upmode='transpose')
138
+ """
139
+
140
+ def __init__(self, num_classes, in_channels=3, depth=5,
141
+ start_filts=64, up_mode='transpose',
142
+ merge_mode='concat', **kwargs):
143
+ """
144
+ Arguments:
145
+ in_channels: int, number of channels in the input tensor.
146
+ Default is 3 for RGB images.
147
+ depth: int, number of MaxPools in the U-Net.
148
+ start_filts: int, number of convolutional filters for the
149
+ first conv.
150
+ up_mode: string, type of upconvolution. Choices: 'transpose'
151
+ for transpose convolution or 'upsample' for nearest neighbour
152
+ upsampling.
153
+ """
154
+ super(UNet, self).__init__()
155
+
156
+ if up_mode in ('transpose', 'upsample'):
157
+ self.up_mode = up_mode
158
+ else:
159
+ raise ValueError("\"{}\" is not a valid mode for "
160
+ "upsampling. Only \"transpose\" and "
161
+ "\"upsample\" are allowed.".format(up_mode))
162
+
163
+ if merge_mode in ('concat', 'add'):
164
+ self.merge_mode = merge_mode
165
+ else:
166
+ raise ValueError("\"{}\" is not a valid mode for"
167
+ "merging up and down paths. "
168
+ "Only \"concat\" and "
169
+ "\"add\" are allowed.".format(up_mode))
170
+
171
+ # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
172
+ if self.up_mode == 'upsample' and self.merge_mode == 'add':
173
+ raise ValueError("up_mode \"upsample\" is incompatible "
174
+ "with merge_mode \"add\" at the moment "
175
+ "because it doesn't make sense to use "
176
+ "nearest neighbour to reduce "
177
+ "depth channels (by half).")
178
+
179
+ self.num_classes = num_classes
180
+ self.in_channels = in_channels
181
+ self.start_filts = start_filts
182
+ self.depth = depth
183
+
184
+ self.down_convs = []
185
+ self.up_convs = []
186
+
187
+ # create the encoder pathway and add to a list
188
+ for i in range(depth):
189
+ ins = self.in_channels if i == 0 else outs
190
+ outs = self.start_filts*(2**i)
191
+ pooling = True if i < depth-1 else False
192
+
193
+ down_conv = DownConv(ins, outs, pooling=pooling)
194
+ self.down_convs.append(down_conv)
195
+
196
+ # create the decoder pathway and add to a list
197
+ # - careful! decoding only requires depth-1 blocks
198
+ for i in range(depth-1):
199
+ ins = outs
200
+ outs = ins // 2
201
+ up_conv = UpConv(ins, outs, up_mode=up_mode,
202
+ merge_mode=merge_mode)
203
+ self.up_convs.append(up_conv)
204
+
205
+ # add the list of modules to current module
206
+ self.down_convs = nn.ModuleList(self.down_convs)
207
+ self.up_convs = nn.ModuleList(self.up_convs)
208
+
209
+ self.conv_final = conv1x1(outs, self.num_classes)
210
+
211
+ self.reset_params()
212
+
213
+ @staticmethod
214
+ def weight_init(m):
215
+ if isinstance(m, nn.Conv2d):
216
+ init.xavier_normal_(m.weight)
217
+ init.constant_(m.bias, 0)
218
+
219
+
220
+ def reset_params(self):
221
+ for i, m in enumerate(self.modules()):
222
+ self.weight_init(m)
223
+
224
+
225
+ def forward(self, x):
226
+ encoder_outs = []
227
+ # encoder pathway, save outputs for merging
228
+ for i, module in enumerate(self.down_convs):
229
+ x, before_pool = module(x)
230
+ encoder_outs.append(before_pool)
231
+ for i, module in enumerate(self.up_convs):
232
+ before_pool = encoder_outs[-(i+2)]
233
+ x = module(before_pool, x)
234
+
235
+ # No softmax is used. This means you need to use
236
+ # nn.CrossEntropyLoss in your training script,
237
+ # as that loss applies the softmax internally.
238
+ x = self.conv_final(x)
239
+ return x
240
+
241
+ if __name__ == "__main__":
242
+ """
243
+ testing
244
+ """
245
+ model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
246
+ print(model)
247
+ print(sum(p.numel() for p in model.parameters()))
248
+
249
+ reso = 176
250
+ x = np.zeros((1, 1, reso, reso))
251
+ x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
252
+ x = torch.FloatTensor(x)
253
+
254
+ out = model(x)
255
+ print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
256
+
257
+ # loss = torch.sum(out)
258
+ # loss.backward()
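The class docstring above also mentions a residual variant via merge_mode='add'; below is a minimal hedged instantiation of that mode (the channel and image sizes are illustrative, not values used elsewhere in this commit).

import torch
from models.spatracker.models.core.spatracker.unet import UNet

model = UNet(num_classes=2, in_channels=3, depth=4, start_filts=32,
             up_mode='transpose', merge_mode='add')
x = torch.randn(1, 3, 128, 128)
logits = model(x)  # (1, 2, 128, 128); no softmax is applied, so pair with nn.CrossEntropyLoss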
models/spatracker/models/core/spatracker/vit/__init__.py ADDED
File without changes
models/spatracker/models/core/spatracker/vit/common.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from typing import Type
11
+
12
+
13
+ class MLPBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ embedding_dim: int,
17
+ mlp_dim: int,
18
+ act: Type[nn.Module] = nn.GELU,
19
+ ) -> None:
20
+ super().__init__()
21
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23
+ self.act = act()
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ return self.lin2(self.act(self.lin1(x)))
27
+
28
+
29
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31
+ class LayerNorm2d(nn.Module):
32
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33
+ super().__init__()
34
+ self.weight = nn.Parameter(torch.ones(num_channels))
35
+ self.bias = nn.Parameter(torch.zeros(num_channels))
36
+ self.eps = eps
37
+
38
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ u = x.mean(1, keepdim=True)
40
+ s = (x - u).pow(2).mean(1, keepdim=True)
41
+ x = (x - u) / torch.sqrt(s + self.eps)
42
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
43
+ return x
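A tiny hedged sketch of the two helpers above (shapes are illustrative): LayerNorm2d normalizes a channels-first feature map over its channel dimension at every spatial location, and MLPBlock is the standard Linear -> activation -> Linear token MLP.

import torch
from models.spatracker.models.core.spatracker.vit.common import LayerNorm2d, MLPBlock

feat = torch.randn(2, 256, 32, 32)                          # (B, C, H, W)
feat = LayerNorm2d(256)(feat)                               # per-location channel normalization
tokens = torch.randn(2, 196, 768)                           # (B, N, C) token embeddings
tokens = MLPBlock(embedding_dim=768, mlp_dim=3072)(tokens)  # projected back to embedding_dim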
models/spatracker/models/core/spatracker/vit/encoder.py ADDED
@@ -0,0 +1,397 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Optional, Tuple, Type
12
+
13
+ from models.spatracker.models.core.spatracker.vit.common import (
14
+ LayerNorm2d, MLPBlock
15
+ )
16
+
17
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
18
+ class ImageEncoderViT(nn.Module):
19
+ def __init__(
20
+ self,
21
+ img_size: int = 1024,
22
+ patch_size: int = 16,
23
+ in_chans: int = 3,
24
+ embed_dim: int = 768,
25
+ depth: int = 12,
26
+ num_heads: int = 12,
27
+ mlp_ratio: float = 4.0,
28
+ out_chans: int = 256,
29
+ qkv_bias: bool = True,
30
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
31
+ act_layer: Type[nn.Module] = nn.GELU,
32
+ use_abs_pos: bool = True,
33
+ use_rel_pos: bool = False,
34
+ rel_pos_zero_init: bool = True,
35
+ window_size: int = 0,
36
+ global_attn_indexes: Tuple[int, ...] = (),
37
+ ) -> None:
38
+ """
39
+ Args:
40
+ img_size (int): Input image size.
41
+ patch_size (int): Patch size.
42
+ in_chans (int): Number of input image channels.
43
+ embed_dim (int): Patch embedding dimension.
44
+ depth (int): Depth of ViT.
45
+ num_heads (int): Number of attention heads in each ViT block.
46
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
47
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
48
+ norm_layer (nn.Module): Normalization layer.
49
+ act_layer (nn.Module): Activation layer.
50
+ use_abs_pos (bool): If True, use absolute positional embeddings.
51
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
52
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
53
+ window_size (int): Window size for window attention blocks.
54
+ global_attn_indexes (list): Indexes for blocks using global attention.
55
+ """
56
+ super().__init__()
57
+ self.img_size = img_size
58
+
59
+ self.patch_embed = PatchEmbed(
60
+ kernel_size=(patch_size, patch_size),
61
+ stride=(patch_size, patch_size),
62
+ in_chans=in_chans,
63
+ embed_dim=embed_dim,
64
+ )
65
+
66
+ self.pos_embed: Optional[nn.Parameter] = None
67
+ if use_abs_pos:
68
+ # Initialize absolute positional embedding with pretrain image size.
69
+ self.pos_embed = nn.Parameter(
70
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
71
+ )
72
+
73
+ self.blocks = nn.ModuleList()
74
+ for i in range(depth):
75
+ block = Block(
76
+ dim=embed_dim,
77
+ num_heads=num_heads,
78
+ mlp_ratio=mlp_ratio,
79
+ qkv_bias=qkv_bias,
80
+ norm_layer=norm_layer,
81
+ act_layer=act_layer,
82
+ use_rel_pos=use_rel_pos,
83
+ rel_pos_zero_init=rel_pos_zero_init,
84
+ window_size=window_size if i not in global_attn_indexes else 0,
85
+ input_size=(img_size // patch_size, img_size // patch_size),
86
+ )
87
+ self.blocks.append(block)
88
+
89
+ self.neck = nn.Sequential(
90
+ nn.Conv2d(
91
+ embed_dim,
92
+ out_chans,
93
+ kernel_size=1,
94
+ bias=False,
95
+ ),
96
+ LayerNorm2d(out_chans),
97
+ nn.Conv2d(
98
+ out_chans,
99
+ out_chans,
100
+ kernel_size=3,
101
+ padding=1,
102
+ bias=False,
103
+ ),
104
+ LayerNorm2d(out_chans),
105
+ )
106
+
107
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
108
+
109
+ x = self.patch_embed(x)
110
+ if self.pos_embed is not None:
111
+ x = x + self.pos_embed
112
+
113
+ for blk in self.blocks:
114
+ x = blk(x)
115
+
116
+ x = self.neck(x.permute(0, 3, 1, 2))
117
+
118
+ return x
119
+
120
+
121
+ class Block(nn.Module):
122
+ """Transformer blocks with support of window attention and residual propagation blocks"""
123
+
124
+ def __init__(
125
+ self,
126
+ dim: int,
127
+ num_heads: int,
128
+ mlp_ratio: float = 4.0,
129
+ qkv_bias: bool = True,
130
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
131
+ act_layer: Type[nn.Module] = nn.GELU,
132
+ use_rel_pos: bool = False,
133
+ rel_pos_zero_init: bool = True,
134
+ window_size: int = 0,
135
+ input_size: Optional[Tuple[int, int]] = None,
136
+ ) -> None:
137
+ """
138
+ Args:
139
+ dim (int): Number of input channels.
140
+ num_heads (int): Number of attention heads in each ViT block.
141
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
142
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
143
+ norm_layer (nn.Module): Normalization layer.
144
+ act_layer (nn.Module): Activation layer.
145
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
146
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
147
+ window_size (int): Window size for window attention blocks. If it equals 0, then
148
+ use global attention.
149
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
150
+ positional parameter size.
151
+ """
152
+ super().__init__()
153
+ self.norm1 = norm_layer(dim)
154
+ self.attn = Attention(
155
+ dim,
156
+ num_heads=num_heads,
157
+ qkv_bias=qkv_bias,
158
+ use_rel_pos=use_rel_pos,
159
+ rel_pos_zero_init=rel_pos_zero_init,
160
+ input_size=input_size if window_size == 0 else (window_size, window_size),
161
+ )
162
+
163
+ self.norm2 = norm_layer(dim)
164
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
165
+
166
+ self.window_size = window_size
167
+
168
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
169
+ shortcut = x
170
+ x = self.norm1(x)
171
+ # Window partition
172
+ if self.window_size > 0:
173
+ H, W = x.shape[1], x.shape[2]
174
+ x, pad_hw = window_partition(x, self.window_size)
175
+
176
+ x = self.attn(x)
177
+ # Reverse window partition
178
+ if self.window_size > 0:
179
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
180
+
181
+ x = shortcut + x
182
+ x = x + self.mlp(self.norm2(x))
183
+
184
+ return x
185
+
186
+
187
+ class Attention(nn.Module):
188
+ """Multi-head Attention block with relative position embeddings."""
189
+
190
+ def __init__(
191
+ self,
192
+ dim: int,
193
+ num_heads: int = 8,
194
+ qkv_bias: bool = True,
195
+ use_rel_pos: bool = False,
196
+ rel_pos_zero_init: bool = True,
197
+ input_size: Optional[Tuple[int, int]] = None,
198
+ ) -> None:
199
+ """
200
+ Args:
201
+ dim (int): Number of input channels.
202
+ num_heads (int): Number of attention heads.
203
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
204
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
205
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
206
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
207
+ positional parameter size.
208
+ """
209
+ super().__init__()
210
+ self.num_heads = num_heads
211
+ head_dim = dim // num_heads
212
+ self.scale = head_dim**-0.5
213
+
214
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
215
+ self.proj = nn.Linear(dim, dim)
216
+
217
+ self.use_rel_pos = use_rel_pos
218
+ if self.use_rel_pos:
219
+ assert (
220
+ input_size is not None
221
+ ), "Input size must be provided if using relative positional encoding."
222
+ # initialize relative positional embeddings
223
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
224
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
225
+
226
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
227
+ B, H, W, _ = x.shape
228
+ # qkv with shape (3, B, nHead, H * W, C)
229
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
230
+ # q, k, v with shape (B * nHead, H * W, C)
231
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
232
+
233
+ attn = (q * self.scale) @ k.transpose(-2, -1)
234
+
235
+ if self.use_rel_pos:
236
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
237
+
238
+ attn = attn.softmax(dim=-1)
239
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
240
+ x = self.proj(x)
241
+
242
+ return x
243
+
244
+
245
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
246
+ """
247
+ Partition into non-overlapping windows with padding if needed.
248
+ Args:
249
+ x (tensor): input tokens with [B, H, W, C].
250
+ window_size (int): window size.
251
+
252
+ Returns:
253
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
254
+ (Hp, Wp): padded height and width before partition
255
+ """
256
+ B, H, W, C = x.shape
257
+
258
+ pad_h = (window_size - H % window_size) % window_size
259
+ pad_w = (window_size - W % window_size) % window_size
260
+ if pad_h > 0 or pad_w > 0:
261
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
262
+ Hp, Wp = H + pad_h, W + pad_w
263
+
264
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
265
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
266
+ return windows, (Hp, Wp)
267
+
268
+
269
+ def window_unpartition(
270
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
271
+ ) -> torch.Tensor:
272
+ """
273
+ Reverse the window partition back into the original sequence and remove padding.
274
+ Args:
275
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
276
+ window_size (int): window size.
277
+ pad_hw (Tuple): padded height and width (Hp, Wp).
278
+ hw (Tuple): original height and width (H, W) before padding.
279
+
280
+ Returns:
281
+ x: unpartitioned sequences with [B, H, W, C].
282
+ """
283
+ Hp, Wp = pad_hw
284
+ H, W = hw
285
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
286
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
287
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
288
+
289
+ if Hp > H or Wp > W:
290
+ x = x[:, :H, :W, :].contiguous()
291
+ return x
292
+
293
+
294
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
295
+ """
296
+ Get relative positional embeddings according to the relative positions of
297
+ query and key sizes.
298
+ Args:
299
+ q_size (int): size of query q.
300
+ k_size (int): size of key k.
301
+ rel_pos (Tensor): relative position embeddings (L, C).
302
+
303
+ Returns:
304
+ Extracted positional embeddings according to relative positions.
305
+ """
306
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
307
+ # Interpolate rel pos if needed.
308
+ if rel_pos.shape[0] != max_rel_dist:
309
+ # Interpolate rel pos.
310
+ rel_pos_resized = F.interpolate(
311
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
312
+ size=max_rel_dist,
313
+ mode="linear",
314
+ )
315
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
316
+ else:
317
+ rel_pos_resized = rel_pos
318
+
319
+ # Scale the coords with short length if shapes for q and k are different.
320
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
321
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
322
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
323
+
324
+ return rel_pos_resized[relative_coords.long()]
325
+
326
+
327
+ def add_decomposed_rel_pos(
328
+ attn: torch.Tensor,
329
+ q: torch.Tensor,
330
+ rel_pos_h: torch.Tensor,
331
+ rel_pos_w: torch.Tensor,
332
+ q_size: Tuple[int, int],
333
+ k_size: Tuple[int, int],
334
+ ) -> torch.Tensor:
335
+ """
336
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
337
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
338
+ Args:
339
+ attn (Tensor): attention map.
340
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
341
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
342
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
343
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
344
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
345
+
346
+ Returns:
347
+ attn (Tensor): attention map with added relative positional embeddings.
348
+ """
349
+ q_h, q_w = q_size
350
+ k_h, k_w = k_size
351
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
352
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
353
+
354
+ B, _, dim = q.shape
355
+ r_q = q.reshape(B, q_h, q_w, dim)
356
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
357
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
358
+
359
+ attn = (
360
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
361
+ ).view(B, q_h * q_w, k_h * k_w)
362
+
363
+ return attn
364
+
365
+
366
+ class PatchEmbed(nn.Module):
367
+ """
368
+ Image to Patch Embedding.
369
+ """
370
+
371
+ def __init__(
372
+ self,
373
+ kernel_size: Tuple[int, int] = (16, 16),
374
+ stride: Tuple[int, int] = (16, 16),
375
+ padding: Tuple[int, int] = (0, 0),
376
+ in_chans: int = 3,
377
+ embed_dim: int = 768,
378
+ ) -> None:
379
+ """
380
+ Args:
381
+ kernel_size (Tuple): kernel size of the projection layer.
382
+ stride (Tuple): stride of the projection layer.
383
+ padding (Tuple): padding size of the projection layer.
384
+ in_chans (int): Number of input image channels.
385
+ embed_dim (int): Patch embedding dimension.
386
+ """
387
+ super().__init__()
388
+
389
+ self.proj = nn.Conv2d(
390
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
391
+ )
392
+
393
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
394
+ x = self.proj(x)
395
+ # B C H W -> B H W C
396
+ x = x.permute(0, 2, 3, 1)
397
+ return x
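A minimal hedged instantiation of the ViTDet-style image encoder above; the hyperparameters here are deliberately tiny and illustrative rather than the values used by any checkpoint in this commit.

import torch
from models.spatracker.models.core.spatracker.vit.encoder import ImageEncoderViT

enc = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
                      window_size=7, global_attn_indexes=(1,), out_chans=256)
x = torch.randn(1, 3, 224, 224)
feat = enc(x)  # (1, 256, 14, 14): an (img_size // patch_size) grid with the neck's out_chans channels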
models/spatracker/predictor.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import time
10
+
11
+ from tqdm import tqdm
12
+ from models.spatracker.models.core.spatracker.spatracker import get_points_on_a_grid
13
+ from models.spatracker.models.core.model_utils import smart_cat
14
+ from models.spatracker.models.build_spatracker import (
15
+ build_spatracker,
16
+ )
17
+ from models.spatracker.models.core.model_utils import (
18
+ meshgrid2d, bilinear_sample2d, smart_cat
19
+ )
20
+
21
+
22
+ class SpaTrackerPredictor(torch.nn.Module):
23
+ def __init__(
24
+ self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth",
25
+ interp_shape=(384, 512),
26
+ seq_length=16
27
+ ):
28
+ super().__init__()
29
+ self.interp_shape = interp_shape
30
+ self.support_grid_size = 6
31
+ model = build_spatracker(checkpoint, seq_length=seq_length)
32
+
33
+ self.model = model
34
+ self.model.eval()
35
+
36
+ @torch.no_grad()
37
+ def forward(
38
+ self,
39
+ video, # (1, T, 3, H, W)
40
+ video_depth = None, # (T, 1, H, W)
41
+ # input prompt types:
42
+ # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
43
+ # *backward_tracking=True* will compute tracks in both directions.
44
+ # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
45
+ # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
46
+ # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
47
+ queries: torch.Tensor = None,
48
+ segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
49
+ grid_size: int = 0,
50
+ grid_query_frame: int = 0, # only for dense and regular grid tracks
51
+ backward_tracking: bool = False,
52
+ depth_predictor=None,
53
+ wind_length: int = 8,
54
+ progressive_tracking: bool = False,
55
+ ):
56
+ if queries is None and grid_size == 0:
57
+ tracks, visibilities, T_Firsts = self._compute_dense_tracks(
58
+ video,
59
+ grid_query_frame=grid_query_frame,
60
+ backward_tracking=backward_tracking,
61
+ video_depth=video_depth,
62
+ depth_predictor=depth_predictor,
63
+ wind_length=wind_length,
64
+ )
65
+ else:
66
+ tracks, visibilities, T_Firsts = self._compute_sparse_tracks(
67
+ video,
68
+ queries,
69
+ segm_mask,
70
+ grid_size,
71
+ add_support_grid=False, #(grid_size == 0 or segm_mask is not None),
72
+ grid_query_frame=grid_query_frame,
73
+ backward_tracking=backward_tracking,
74
+ video_depth=video_depth,
75
+ depth_predictor=depth_predictor,
76
+ wind_length=wind_length,
77
+ )
78
+
79
+ return tracks, visibilities, T_Firsts
80
+
81
+ def _compute_dense_tracks(
82
+ self, video, grid_query_frame, grid_size=30, backward_tracking=False,
83
+ depth_predictor=None, video_depth=None, wind_length=8
84
+ ):
85
+ *_, H, W = video.shape
86
+ grid_step = W // grid_size
87
+ grid_width = W // grid_step
88
+ grid_height = H // grid_step
89
+ tracks = visibilities = T_Firsts = None
90
+ grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
91
+ grid_pts[0, :, 0] = grid_query_frame
92
+ for offset in tqdm(range(grid_step * grid_step)):
93
+ ox = offset % grid_step
94
+ oy = offset // grid_step
95
+ grid_pts[0, :, 1] = (
96
+ torch.arange(grid_width).repeat(grid_height) * grid_step + ox
97
+ )
98
+ grid_pts[0, :, 2] = (
99
+ torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
100
+ )
101
+ tracks_step, visibilities_step, T_First_step = self._compute_sparse_tracks(
102
+ video=video,
103
+ queries=grid_pts,
104
+ backward_tracking=backward_tracking,
105
+ wind_length=wind_length,
106
+ video_depth=video_depth,
107
+ depth_predictor=depth_predictor,
108
+ )
109
+ tracks = smart_cat(tracks, tracks_step, dim=2)
110
+ visibilities = smart_cat(visibilities, visibilities_step, dim=2)
111
+ T_Firsts = smart_cat(T_Firsts, T_First_step, dim=1)
112
+
113
+
114
+ return tracks, visibilities, T_Firsts
115
+
116
+ def _compute_sparse_tracks(
117
+ self,
118
+ video,
119
+ queries,
120
+ segm_mask=None,
121
+ grid_size=0,
122
+ add_support_grid=False,
123
+ grid_query_frame=0,
124
+ backward_tracking=False,
125
+ depth_predictor=None,
126
+ video_depth=None,
127
+ wind_length=8,
128
+ ):
129
+ B, T, C, H, W = video.shape
130
+ assert B == 1
131
+
132
+ video = video.reshape(B * T, C, H, W)
133
+ video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
134
+ video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
135
+
136
+ if queries is not None:
137
+ queries = queries.clone()
138
+ B, N, D = queries.shape
139
+ assert D == 3
140
+ queries[:, :, 1] *= self.interp_shape[1] / W
141
+ queries[:, :, 2] *= self.interp_shape[0] / H
142
+ elif grid_size > 0:
143
+ grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
144
+ if segm_mask is not None:
145
+ segm_mask = F.interpolate(
146
+ segm_mask, tuple(self.interp_shape), mode="nearest"
147
+ )
148
+ point_mask = segm_mask[0, 0][
149
+ (grid_pts[0, :, 1]).round().long().cpu(),
150
+ (grid_pts[0, :, 0]).round().long().cpu(),
151
+ ].bool()
152
+ grid_pts_extra = grid_pts[:, point_mask]
153
+ else:
154
+ grid_pts_extra = None
155
+ if grid_pts_extra is not None:
156
+ total_num = int(grid_pts_extra.shape[1])
157
+ total_num = min(800, total_num)
158
+ pick_idx = torch.randperm(grid_pts_extra.shape[1])[:total_num]
159
+ grid_pts_extra = grid_pts_extra[:, pick_idx]
160
+ queries_extra = torch.cat(
161
+ [
162
+ torch.ones_like(grid_pts_extra[:, :, :1]) * grid_query_frame,
163
+ grid_pts_extra,
164
+ ],
165
+ dim=2,
166
+ )
167
+
168
+ queries = torch.cat(
169
+ [torch.zeros_like(grid_pts[:, :, :1]), grid_pts],
170
+ dim=2,
171
+ )
172
+
173
+ if add_support_grid:
174
+ grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
175
+ grid_pts = torch.cat(
176
+ [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
177
+ )
178
+ queries = torch.cat([queries, grid_pts], dim=1)
179
+
180
+ ## ----------- estimate the video depth -----------##
181
+ if video_depth is None:
182
+ with torch.no_grad():
183
+ if video[0].shape[0]>30:
184
+ vidDepths = []
185
+ for i in range(video[0].shape[0]//30+1):
186
+ if (i+1)*30 > video[0].shape[0]:
187
+ end_idx = video[0].shape[0]
188
+ else:
189
+ end_idx = (i+1)*30
190
+ if end_idx == i*30:
191
+ break
192
+ video_ = video[0][i*30:end_idx]
193
+ vidDepths.append(depth_predictor.infer(video_/255))
194
+
195
+ video_depth = torch.cat(vidDepths, dim=0)
196
+
197
+ else:
198
+ video_depth = depth_predictor.infer(video[0]/255)
199
+ video_depth = F.interpolate(video_depth,
200
+ tuple(self.interp_shape), mode="nearest")
201
+
202
+ # from PIL import Image
203
+ # import numpy
204
+ # depth_frame = video_depth[0].detach().cpu()
205
+ # depth_frame = depth_frame.squeeze(0)
206
+ # print(depth_frame)
207
+ # print(depth_frame.min(), depth_frame.max())
208
+ # depth_img = (depth_frame * 255).numpy().astype(numpy.uint8)
209
+ # depth_img = Image.fromarray(depth_img, mode='L')
210
+ # depth_img.save('outputs/depth_map.png')
211
+
212
+ # frame = video[0, 0].detach().cpu()
213
+ # frame = frame.permute(1, 2, 0)
214
+ # frame = (frame * 255).numpy().astype(numpy.uint8)
215
+ # frame = Image.fromarray(frame, mode='RGB')
216
+ # frame.save('outputs/frame.png')
217
+
218
+ depths = video_depth
219
+ rgbds = torch.cat([video, depths[None,...]], dim=2)
220
+ # get the 3D queries
221
+ depth_interp=[]
222
+ for i in range(queries.shape[1]):
223
+ depth_interp_i = bilinear_sample2d(video_depth[queries[:, i:i+1, 0].long()],
224
+ queries[:, i:i+1, 1], queries[:, i:i+1, 2])
225
+ depth_interp.append(depth_interp_i)
226
+
227
+ depth_interp = torch.cat(depth_interp, dim=1)
228
+ queries = smart_cat(queries, depth_interp,dim=-1)
229
+
230
+ #NOTE: free the memory of depth_predictor
231
+ del depth_predictor
232
+ torch.cuda.empty_cache()
233
+ t0 = time.time()
234
+ tracks, __, visibilities = self.model(rgbds=rgbds, queries=queries, iters=6, wind_S=wind_length)
235
+ print("Time taken for inference: ", time.time()-t0)
236
+
237
+ if backward_tracking:
238
+ tracks, visibilities = self._compute_backward_tracks(
239
+ rgbds, queries, tracks, visibilities
240
+ )
241
+ if add_support_grid:
242
+ queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
243
+ if add_support_grid:
244
+ tracks = tracks[:, :, : -self.support_grid_size ** 2]
245
+ visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
246
+ thr = 0.9
247
+ visibilities = visibilities > thr
248
+
249
+ # correct query-point predictions
250
+ # see https://github.com/facebookresearch/co-tracker/issues/28
251
+
252
+ # TODO: batchify
253
+ for i in range(len(queries)):
254
+ queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
255
+ arange = torch.arange(0, len(queries_t))
256
+
257
+ # overwrite the predictions with the query points
258
+ tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]
259
+
260
+ # correct visibilities, the query points should be visible
261
+ visibilities[i, queries_t, arange] = True
262
+
263
+ T_First = queries[..., :tracks.size(2), 0].to(torch.uint8)
264
+ tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
265
+ tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
266
+ return tracks, visibilities, T_First
267
+
268
+ def _compute_backward_tracks(self, video, queries, tracks, visibilities):
269
+ inv_video = video.flip(1).clone()
270
+ inv_queries = queries.clone()
271
+ inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
272
+
273
+ inv_tracks, __, inv_visibilities = self.model(
274
+ rgbds=inv_video, queries=inv_queries, iters=6  # track the time-flipped queries built above
275
+ )
276
+
277
+ inv_tracks = inv_tracks.flip(1)
278
+ inv_visibilities = inv_visibilities.flip(1)
279
+
280
+ mask = tracks == 0
281
+
282
+ tracks[mask] = inv_tracks[mask]
283
+ visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
284
+ return tracks, visibilities
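For reference, a minimal usage sketch of SpaTrackerPredictor as defined above. The checkpoint path is hypothetical, the tensors are random, and the shapes are chosen to match interp_shape so no extra resizing is exercised; this is an untested sketch, not the project's documented API:

import torch

tracker = SpaTrackerPredictor(
    checkpoint="checkpoints/spatracker_model.pth",  # hypothetical path
    interp_shape=(384, 512),
    seq_length=16,
).to("cuda")

video = torch.randint(0, 256, (1, 24, 3, 384, 512), device="cuda").float()  # (1, T, 3, H, W), 0-255
depth = torch.rand(24, 1, 384, 512, device="cuda")                          # (T, 1, H, W), precomputed depth

tracks, visibilities, t_firsts = tracker(
    video,
    video_depth=depth,
    grid_size=20,            # 20x20 query grid on frame 0
    grid_query_frame=0,
    backward_tracking=False,
    wind_length=8,
)
# tracks: (1, T, N, ...) with x/y rescaled back to the input resolution;
# visibilities: boolean, thresholded at 0.9 in _compute_sparse_tracks.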
models/spatracker/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
models/spatracker/utils/basic.py ADDED
@@ -0,0 +1,397 @@
1
+ import os
2
+ import numpy as np
3
+ from os.path import isfile
4
+ import torch
5
+ import torch.nn.functional as F
6
+ EPS = 1e-6
7
+ import copy
8
+
9
+ def sub2ind(height, width, y, x):
10
+ return y*width + x
11
+
12
+ def ind2sub(height, width, ind):
13
+ y = ind // width
14
+ x = ind % width
15
+ return y, x
16
+
17
+ def get_lr_str(lr):
18
+ lrn = "%.1e" % lr # e.g., 5.0e-04
19
+ lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4
20
+ return lrn
21
+
22
+ def strnum(x):
23
+ s = '%g' % x
24
+ if '.' in s:
25
+ if x < 1.0:
26
+ s = s[s.index('.'):]
27
+ s = s[:min(len(s),4)]
28
+ return s
29
+
30
+ def assert_same_shape(t1, t2):
31
+ for (x, y) in zip(list(t1.shape), list(t2.shape)):
32
+ assert(x==y)
33
+
34
+ def print_stats(name, tensor):
35
+ shape = tensor.shape
36
+ tensor = tensor.detach().cpu().numpy()
37
+ print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
38
+
39
+ def print_stats_py(name, tensor):
40
+ shape = tensor.shape
41
+ print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
42
+
43
+ def print_(name, tensor):
44
+ tensor = tensor.detach().cpu().numpy()
45
+ print(name, tensor, tensor.shape)
46
+
47
+ def mkdir(path):
48
+ if not os.path.exists(path):
49
+ os.makedirs(path)
50
+
51
+ def normalize_single(d):
52
+ # d is a whatever shape torch tensor
53
+ dmin = torch.min(d)
54
+ dmax = torch.max(d)
55
+ d = (d-dmin)/(EPS+(dmax-dmin))
56
+ return d
57
+
58
+ def normalize(d):
59
+ # d is B x whatever. normalize within each element of the batch
60
+ out = torch.zeros(d.size())
61
+ if d.is_cuda:
62
+ out = out.cuda()
63
+ B = list(d.size())[0]
64
+ for b in list(range(B)):
65
+ out[b] = normalize_single(d[b])
66
+ return out
67
+
68
+ def hard_argmax2d(tensor):
69
+ B, C, Y, X = list(tensor.shape)
70
+ assert(C==1)
71
+
72
+ # flatten the Tensor along the height and width axes
73
+ flat_tensor = tensor.reshape(B, -1)
74
+ # argmax of the flat tensor
75
+ argmax = torch.argmax(flat_tensor, dim=1)
76
+
77
+ # convert the indices into 2d coordinates
78
+ argmax_y = torch.floor(argmax / X) # row
79
+ argmax_x = argmax % X # col
80
+
81
+ argmax_y = argmax_y.reshape(B)
82
+ argmax_x = argmax_x.reshape(B)
83
+ return argmax_y, argmax_x
84
+
85
+ def argmax2d(heat, hard=True):
86
+ B, C, Y, X = list(heat.shape)
87
+ assert(C==1)
88
+
89
+ if hard:
90
+ # hard argmax
91
+ loc_y, loc_x = hard_argmax2d(heat)
92
+ loc_y = loc_y.float()
93
+ loc_x = loc_x.float()
94
+ else:
95
+ heat = heat.reshape(B, Y*X)
96
+ prob = torch.nn.functional.softmax(heat, dim=1)
97
+
98
+ grid_y, grid_x = meshgrid2d(B, Y, X)
99
+
100
+ grid_y = grid_y.reshape(B, -1)
101
+ grid_x = grid_x.reshape(B, -1)
102
+
103
+ loc_y = torch.sum(grid_y*prob, dim=1)
104
+ loc_x = torch.sum(grid_x*prob, dim=1)
105
+ # these are B
106
+
107
+ return loc_y, loc_x
108
+
109
+ def reduce_masked_mean(x, mask, dim=None, keepdim=False):
110
+ # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
111
+ # returns shape-1
112
+ # axis can be a list of axes
113
+ for (a,b) in zip(x.size(), mask.size()):
114
+ # if not b==1:
115
+ assert(a==b) # some shape mismatch!
116
+ # assert(x.size() == mask.size())
117
+ prod = x*mask
118
+ if dim is None:
119
+ numer = torch.sum(prod)
120
+ denom = EPS+torch.sum(mask)
121
+ else:
122
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
123
+ denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)
124
+
125
+ mean = numer/denom
126
+ return mean
127
+
128
+ def reduce_masked_median(x, mask, keep_batch=False):
129
+ # x and mask are the same shape
130
+ assert(x.size() == mask.size())
131
+ device = x.device
132
+
133
+ B = list(x.shape)[0]
134
+ x = x.detach().cpu().numpy()
135
+ mask = mask.detach().cpu().numpy()
136
+
137
+ if keep_batch:
138
+ x = np.reshape(x, [B, -1])
139
+ mask = np.reshape(mask, [B, -1])
140
+ meds = np.zeros([B], np.float32)
141
+ for b in list(range(B)):
142
+ xb = x[b]
143
+ mb = mask[b]
144
+ if np.sum(mb) > 0:
145
+ xb = xb[mb > 0]
146
+ meds[b] = np.median(xb)
147
+ else:
148
+ meds[b] = np.nan
149
+ meds = torch.from_numpy(meds).to(device)
150
+ return meds.float()
151
+ else:
152
+ x = np.reshape(x, [-1])
153
+ mask = np.reshape(mask, [-1])
154
+ if np.sum(mask) > 0:
155
+ x = x[mask > 0]
156
+ med = np.median(x)
157
+ else:
158
+ med = np.nan
159
+ med = np.array([med], np.float32)
160
+ med = torch.from_numpy(med).to(device)
161
+ return med.float()
162
+
163
+ def pack_seqdim(tensor, B):
164
+ shapelist = list(tensor.shape)
165
+ B_, S = shapelist[:2]
166
+ assert(B==B_)
167
+ otherdims = shapelist[2:]
168
+ tensor = torch.reshape(tensor, [B*S]+otherdims)
169
+ return tensor
170
+
171
+ def unpack_seqdim(tensor, B):
172
+ shapelist = list(tensor.shape)
173
+ BS = shapelist[0]
174
+ assert(BS%B==0)
175
+ otherdims = shapelist[1:]
176
+ S = int(BS/B)
177
+ tensor = torch.reshape(tensor, [B,S]+otherdims)
178
+ return tensor
179
+
180
+ def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
181
+ # returns a meshgrid sized B x Y x X
182
+
183
+ grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
184
+ grid_y = torch.reshape(grid_y, [1, Y, 1])
185
+ grid_y = grid_y.repeat(B, 1, X)
186
+
187
+ grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
188
+ grid_x = torch.reshape(grid_x, [1, 1, X])
189
+ grid_x = grid_x.repeat(B, Y, 1)
190
+
191
+ if norm:
192
+ grid_y, grid_x = normalize_grid2d(
193
+ grid_y, grid_x, Y, X)
194
+
195
+ if stack:
196
+ # note we stack in xy order
197
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
198
+ if on_chans:
199
+ grid = torch.stack([grid_x, grid_y], dim=1)
200
+ else:
201
+ grid = torch.stack([grid_x, grid_y], dim=-1)
202
+ return grid
203
+ else:
204
+ return grid_y, grid_x
205
+
206
+ def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'):
207
+ # returns a meshgrid sized B x Z x Y x X
208
+
209
+ grid_z = torch.linspace(0.0, Z-1, Z, device=device)
210
+ grid_z = torch.reshape(grid_z, [1, Z, 1, 1])
211
+ grid_z = grid_z.repeat(B, 1, Y, X)
212
+
213
+ grid_y = torch.linspace(0.0, Y-1, Y, device=device)
214
+ grid_y = torch.reshape(grid_y, [1, 1, Y, 1])
215
+ grid_y = grid_y.repeat(B, Z, 1, X)
216
+
217
+ grid_x = torch.linspace(0.0, X-1, X, device=device)
218
+ grid_x = torch.reshape(grid_x, [1, 1, 1, X])
219
+ grid_x = grid_x.repeat(B, Z, Y, 1)
220
+
221
+ # if cuda:
222
+ # grid_z = grid_z.cuda()
223
+ # grid_y = grid_y.cuda()
224
+ # grid_x = grid_x.cuda()
225
+
226
+ if norm:
227
+ grid_z, grid_y, grid_x = normalize_grid3d(
228
+ grid_z, grid_y, grid_x, Z, Y, X)
229
+
230
+ if stack:
231
+ # note we stack in xyz order
232
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
233
+ grid = torch.stack([grid_x, grid_y, grid_z], dim=-1)
234
+ return grid
235
+ else:
236
+ return grid_z, grid_y, grid_x
237
+
238
+ def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True):
239
+ # make things in [-1,1]
240
+ grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
241
+ grid_x = 2.0*(grid_x / float(X-1)) - 1.0
242
+
243
+ if clamp_extreme:
244
+ grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
245
+ grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)
246
+
247
+ return grid_y, grid_x
248
+
249
+ def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True):
250
+ # make things in [-1,1]
251
+ grid_z = 2.0*(grid_z / float(Z-1)) - 1.0
252
+ grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
253
+ grid_x = 2.0*(grid_x / float(X-1)) - 1.0
254
+
255
+ if clamp_extreme:
256
+ grid_z = torch.clamp(grid_z, min=-2.0, max=2.0)
257
+ grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
258
+ grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)
259
+
260
+ return grid_z, grid_y, grid_x
261
+
262
+ def gridcloud2d(B, Y, X, norm=False, device='cuda'):
263
+ # we want to sample for each location in the grid
264
+ grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
265
+ x = torch.reshape(grid_x, [B, -1])
266
+ y = torch.reshape(grid_y, [B, -1])
267
+ # these are B x N
268
+ xy = torch.stack([x, y], dim=2)
269
+ # this is B x N x 2
270
+ return xy
271
+
272
+ def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'):
273
+ # we want to sample for each location in the grid
274
+ grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device)
275
+ x = torch.reshape(grid_x, [B, -1])
276
+ y = torch.reshape(grid_y, [B, -1])
277
+ z = torch.reshape(grid_z, [B, -1])
278
+ # these are B x N
279
+ xyz = torch.stack([x, y, z], dim=2)
280
+ # this is B x N x 3
281
+ return xyz
282
+
283
+ import re
284
+ def readPFM(file):
285
+ file = open(file, 'rb')
286
+
287
+ color = None
288
+ width = None
289
+ height = None
290
+ scale = None
291
+ endian = None
292
+
293
+ header = file.readline().rstrip()
294
+ if header == b'PF':
295
+ color = True
296
+ elif header == b'Pf':
297
+ color = False
298
+ else:
299
+ raise Exception('Not a PFM file.')
300
+
301
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
302
+ if dim_match:
303
+ width, height = map(int, dim_match.groups())
304
+ else:
305
+ raise Exception('Malformed PFM header.')
306
+
307
+ scale = float(file.readline().rstrip())
308
+ if scale < 0: # little-endian
309
+ endian = '<'
310
+ scale = -scale
311
+ else:
312
+ endian = '>' # big-endian
313
+
314
+ data = np.fromfile(file, endian + 'f')
315
+ shape = (height, width, 3) if color else (height, width)
316
+
317
+ data = np.reshape(data, shape)
318
+ data = np.flipud(data)
319
+ return data
320
+
321
+ def normalize_boxlist2d(boxlist2d, H, W):
322
+ boxlist2d = boxlist2d.clone()
323
+ ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
324
+ ymin = ymin / float(H)
325
+ ymax = ymax / float(H)
326
+ xmin = xmin / float(W)
327
+ xmax = xmax / float(W)
328
+ boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
329
+ return boxlist2d
330
+
331
+ def unnormalize_boxlist2d(boxlist2d, H, W):
332
+ boxlist2d = boxlist2d.clone()
333
+ ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
334
+ ymin = ymin * float(H)
335
+ ymax = ymax * float(H)
336
+ xmin = xmin * float(W)
337
+ xmax = xmax * float(W)
338
+ boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
339
+ return boxlist2d
340
+
341
+ def unnormalize_box2d(box2d, H, W):
342
+ return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
343
+
344
+ def normalize_box2d(box2d, H, W):
345
+ return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
346
+
347
+ def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False):
348
+ C = channels
349
+ xy_grid = gridcloud2d(C, kernel_size, kernel_size) # C x N x 2
350
+
351
+ mean = (kernel_size - 1)/2.0
352
+ variance = sigma**2.0
353
+
354
+ gaussian_kernel = (1.0/(2.0*np.pi*variance)**1.5) * torch.exp(-torch.sum((xy_grid - mean)**2.0, dim=-1) / (2.0*variance)) # C X N
355
+ gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size) # C x 1 x 3 x 3
356
+ kernel_sum = torch.sum(gaussian_kernel, dim=(2,3), keepdim=True)
357
+
358
+ gaussian_kernel = gaussian_kernel / kernel_sum # normalize
359
+
360
+ if mid_one:
361
+ # normalize so that the middle element is 1
362
+ maxval = gaussian_kernel[:,:,(kernel_size//2),(kernel_size//2)].reshape(C, 1, 1, 1)
363
+ gaussian_kernel = gaussian_kernel / maxval
364
+
365
+ return gaussian_kernel
366
+
367
+ def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False):
368
+ B, C, Z, X = input.shape
369
+ kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one)
370
+ if reflect_pad:
371
+ pad = (kernel_size - 1)//2
372
+ out = F.pad(input, (pad, pad, pad, pad), mode='reflect')
373
+ out = F.conv2d(out, kernel, padding=0, groups=C)
374
+ else:
375
+ out = F.conv2d(input, kernel, padding=(kernel_size - 1)//2, groups=C)
376
+ return out
377
+
378
+ def gradient2d(x, absolute=False, square=False, return_sum=False):
379
+ # x should be B x C x H x W
380
+ dh = x[:, :, 1:, :] - x[:, :, :-1, :]
381
+ dw = x[:, :, :, 1:] - x[:, :, :, :-1]
382
+
383
+ zeros = torch.zeros_like(x)
384
+ zero_h = zeros[:, :, 0:1, :]
385
+ zero_w = zeros[:, :, :, 0:1]
386
+ dh = torch.cat([dh, zero_h], axis=2)
387
+ dw = torch.cat([dw, zero_w], axis=3)
388
+ if absolute:
389
+ dh = torch.abs(dh)
390
+ dw = torch.abs(dw)
391
+ if square:
392
+ dh = dh ** 2
393
+ dw = dw ** 2
394
+ if return_sum:
395
+ return dh+dw
396
+ else:
397
+ return dh, dw
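Two small sanity checks for the helpers above, assuming they are called from this module's namespace; the values are illustrative:

import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(reduce_masked_mean(x, mask))        # ~1.5: mean over the masked-in entries only

xy = gridcloud2d(1, 2, 3, device='cpu')   # (1, 6, 2) pixel coordinates in (x, y) order
print(xy[0])
# tensor([[0., 0.], [1., 0.], [2., 0.],
#         [0., 1.], [1., 1.], [2., 1.]])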
models/spatracker/utils/geom.py ADDED
@@ -0,0 +1,547 @@
1
+ import torch
2
+ import models.spatracker.utils.basic
+ import models.spatracker.utils as utils  # bind `utils` so the utils.basic.* calls below resolve
3
+ import numpy as np
4
+ import torchvision.ops as ops
5
+ from models.spatracker.utils.basic import print_
6
+
7
+ def matmul2(mat1, mat2):
8
+ return torch.matmul(mat1, mat2)
9
+
10
+ def matmul3(mat1, mat2, mat3):
11
+ return torch.matmul(mat1, torch.matmul(mat2, mat3))
12
+
13
+ def eye_3x3(B, device='cuda'):
14
+ rt = torch.eye(3, device=torch.device(device)).view(1,3,3).repeat([B, 1, 1])
15
+ return rt
16
+
17
+ def eye_4x4(B, device='cuda'):
18
+ rt = torch.eye(4, device=torch.device(device)).view(1,4,4).repeat([B, 1, 1])
19
+ return rt
20
+
21
+ def safe_inverse(a): #parallel version
22
+ B, _, _ = list(a.shape)
23
+ inv = a.clone()
24
+ r_transpose = a[:, :3, :3].transpose(1,2) #inverse of rotation matrix
25
+
26
+ inv[:, :3, :3] = r_transpose
27
+ inv[:, :3, 3:4] = -torch.matmul(r_transpose, a[:, :3, 3:4])
28
+
29
+ return inv
30
+
31
+ def safe_inverse_single(a):
32
+ r, t = split_rt_single(a)
33
+ t = t.view(3,1)
34
+ r_transpose = r.t()
35
+ inv = torch.cat([r_transpose, -torch.matmul(r_transpose, t)], 1)
36
+ bottom_row = a[3:4, :] # this is [0, 0, 0, 1]
37
+ # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4)
38
+ inv = torch.cat([inv, bottom_row], 0)
39
+ return inv
40
+
41
+ def split_intrinsics(K):
42
+ # K is B x 3 x 3 or B x 4 x 4
43
+ fx = K[:,0,0]
44
+ fy = K[:,1,1]
45
+ x0 = K[:,0,2]
46
+ y0 = K[:,1,2]
47
+ return fx, fy, x0, y0
48
+
49
+ def apply_pix_T_cam(pix_T_cam, xyz):
50
+
51
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
52
+
53
+ # xyz is shaped B x H*W x 3
54
+ # returns xy, shaped B x H*W x 2
55
+
56
+ B, N, C = list(xyz.shape)
57
+ assert(C==3)
58
+
59
+ x, y, z = torch.unbind(xyz, axis=-1)
60
+
61
+ fx = torch.reshape(fx, [B, 1])
62
+ fy = torch.reshape(fy, [B, 1])
63
+ x0 = torch.reshape(x0, [B, 1])
64
+ y0 = torch.reshape(y0, [B, 1])
65
+
66
+ EPS = 1e-4
67
+ z = torch.clamp(z, min=EPS)
68
+ x = (x*fx)/(z)+x0
69
+ y = (y*fy)/(z)+y0
70
+ xy = torch.stack([x, y], axis=-1)
71
+ return xy
72
+
73
+ def apply_pix_T_cam_py(pix_T_cam, xyz):
74
+
75
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
76
+
77
+ # xyz is shaped B x H*W x 3
78
+ # returns xy, shaped B x H*W x 2
79
+
80
+ B, N, C = list(xyz.shape)
81
+ assert(C==3)
82
+
83
+ x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
84
+
85
+ fx = np.reshape(fx, [B, 1])
86
+ fy = np.reshape(fy, [B, 1])
87
+ x0 = np.reshape(x0, [B, 1])
88
+ y0 = np.reshape(y0, [B, 1])
89
+
90
+ EPS = 1e-4
91
+ z = np.clip(z, EPS, None)
92
+ x = (x*fx)/(z)+x0
93
+ y = (y*fy)/(z)+y0
94
+ xy = np.stack([x, y], axis=-1)
95
+ return xy
96
+
97
+ def get_camM_T_camXs(origin_T_camXs, ind=0):
98
+ B, S = list(origin_T_camXs.shape)[0:2]
99
+ camM_T_camXs = torch.zeros_like(origin_T_camXs)
100
+ for b in list(range(B)):
101
+ camM_T_origin = safe_inverse_single(origin_T_camXs[b,ind])
102
+ for s in list(range(S)):
103
+ camM_T_camXs[b,s] = torch.matmul(camM_T_origin, origin_T_camXs[b,s])
104
+ return camM_T_camXs
105
+
106
+ def apply_4x4(RT, xyz):
107
+ B, N, _ = list(xyz.shape)
108
+ ones = torch.ones_like(xyz[:,:,0:1])
109
+ xyz1 = torch.cat([xyz, ones], 2)
110
+ xyz1_t = torch.transpose(xyz1, 1, 2)
111
+ # this is B x 4 x N
112
+ xyz2_t = torch.matmul(RT, xyz1_t)
113
+ xyz2 = torch.transpose(xyz2_t, 1, 2)
114
+ xyz2 = xyz2[:,:,:3]
115
+ return xyz2
116
+
117
+ def apply_4x4_py(RT, xyz):
118
+ # print('RT', RT.shape)
119
+ B, N, _ = list(xyz.shape)
120
+ ones = np.ones_like(xyz[:,:,0:1])
121
+ xyz1 = np.concatenate([xyz, ones], 2)
122
+ # print('xyz1', xyz1.shape)
123
+ xyz1_t = xyz1.transpose(0,2,1)
124
+ # print('xyz1_t', xyz1_t.shape)
125
+ # this is B x 4 x N
126
+ xyz2_t = np.matmul(RT, xyz1_t)
127
+ # print('xyz2_t', xyz2_t.shape)
128
+ xyz2 = xyz2_t.transpose(0,2,1)
129
+ # print('xyz2', xyz2.shape)
130
+ xyz2 = xyz2[:,:,:3]
131
+ return xyz2
132
+
133
+ def apply_3x3(RT, xy):
134
+ B, N, _ = list(xy.shape)
135
+ ones = torch.ones_like(xy[:,:,0:1])
136
+ xy1 = torch.cat([xy, ones], 2)
137
+ xy1_t = torch.transpose(xy1, 1, 2)
138
+ # this is B x 4 x N
139
+ xy2_t = torch.matmul(RT, xy1_t)
140
+ xy2 = torch.transpose(xy2_t, 1, 2)
141
+ xy2 = xy2[:,:,:2]
142
+ return xy2
143
+
144
+ def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts):
145
+ '''
146
+ Start with the center of the polygon at ctr_x, ctr_y,
147
+ Then creates the polygon by sampling points on a circle around the center.
148
+ Random noise is added by varying the angular spacing between sequential points,
149
+ and by varying the radial distance of each point from the centre.
150
+
151
+ Params:
152
+ ctr_x, ctr_y - coordinates of the "centre" of the polygon
153
+ avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude.
154
+ irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts]
155
+ spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r]
156
+ num_verts - number of vertices of the polygon
157
+
158
+ Returns:
159
+ np.array [num_verts, 2] - CCW order.
160
+ '''
161
+ # spikiness
162
+ spikiness = np.clip(spikiness, 0, 1) * avg_r
163
+
164
+ # generate n angle steps
165
+ irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts
166
+ lower = (2*np.pi / num_verts) - irregularity
167
+ upper = (2*np.pi / num_verts) + irregularity
168
+
169
+ # angle steps
170
+ angle_steps = np.random.uniform(lower, upper, num_verts)
171
+ sc = (2 * np.pi) / angle_steps.sum()
172
+ angle_steps *= sc
173
+
174
+ # get all radii
175
+ angle = np.random.uniform(0, 2*np.pi)
176
+ radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r)
177
+
178
+ # compute all points
179
+ points = []
180
+ for i in range(num_verts):
181
+ x = ctr_x + radii[i] * np.cos(angle)
182
+ y = ctr_y + radii[i] * np.sin(angle)
183
+ points.append([x, y])
184
+ angle += angle_steps[i]
185
+
186
+ return np.array(points).astype(int)
187
+
188
+
189
+ def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05):
190
+ '''
191
+ Params:
192
+ rot_min: rotation amount min
193
+ rot_max: rotation amount max
194
+
195
+ tx_min: translation x min
196
+ tx_max: translation x max
197
+
198
+ ty_min: translation y min
199
+ ty_max: translation y max
200
+
201
+ sx_min: scaling x min
202
+ sx_max: scaling x max
203
+
204
+ sy_min: scaling y min
205
+ sy_max: scaling y max
206
+
207
+ shx_min: shear x min
208
+ shx_max: shear x max
209
+
210
+ shy_min: shear y min
211
+ shy_max: shear y max
212
+
213
+ Returns:
214
+ transformation matrix: (B, 3, 3)
215
+ '''
216
+ # rotation
217
+ if rot_max - rot_min != 0:
218
+ rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B)
219
+ rot_amount = np.pi/180.0*rot_amount
220
+ else:
221
+ rot_amount = rot_min
222
+ rotation = np.zeros((B, 3, 3)) # B, 3, 3
223
+ rotation[:, 2, 2] = 1
224
+ rotation[:, 0, 0] = np.cos(rot_amount)
225
+ rotation[:, 0, 1] = -np.sin(rot_amount)
226
+ rotation[:, 1, 0] = np.sin(rot_amount)
227
+ rotation[:, 1, 1] = np.cos(rot_amount)
228
+
229
+ # translation
230
+ translation = np.zeros((B, 3, 3)) # B, 3, 3
231
+ translation[:, [0,1,2], [0,1,2]] = 1
232
+ if (tx_max - tx_min) > 0:
233
+ trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B)
234
+ translation[:, 0, 2] = trans_x
235
+ # else:
236
+ # translation[:, 0, 2] = tx_max
237
+ if ty_max - ty_min != 0:
238
+ trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B)
239
+ translation[:, 1, 2] = trans_y
240
+ # else:
241
+ # translation[:, 1, 2] = ty_max
242
+
243
+ # scaling
244
+ scaling = np.zeros((B, 3, 3)) # B, 3, 3
245
+ scaling[:, [0,1,2], [0,1,2]] = 1
246
+ if (sx_max - sx_min) > 0:
247
+ scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B)
248
+ scaling[:, 0, 0] = scale_x
249
+ # else:
250
+ # scaling[:, 0, 0] = sx_max
251
+ if (sy_max - sy_min) > 0:
252
+ scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B)
253
+ scaling[:, 1, 1] = scale_y
254
+ # else:
255
+ # scaling[:, 1, 1] = sy_max
256
+
257
+ # shear
258
+ shear = np.zeros((B, 3, 3)) # B, 3, 3
259
+ shear[:, [0,1,2], [0,1,2]] = 1
260
+ if (shx_max - shx_min) > 0:
261
+ shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B)
262
+ shear[:, 0, 1] = shear_x
263
+ # else:
264
+ # shear[:, 0, 1] = shx_max
265
+ if (shy_max - shy_min) > 0:
266
+ shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B)
267
+ shear[:, 1, 0] = shear_y
268
+ # else:
269
+ # shear[:, 1, 0] = shy_max
270
+
271
+ # compose all those
272
+ rt = np.einsum("ijk,ikl->ijl", rotation, translation)
273
+ ss = np.einsum("ijk,ikl->ijl", scaling, shear)
274
+ trans = np.einsum("ijk,ikl->ijl", rt, ss)
275
+
276
+ return trans
277
+
278
+ def get_centroid_from_box2d(box2d):
279
+ ymin = box2d[:,0]
280
+ xmin = box2d[:,1]
281
+ ymax = box2d[:,2]
282
+ xmax = box2d[:,3]
283
+ x = (xmin+xmax)/2.0
284
+ y = (ymin+ymax)/2.0
285
+ return y, x
286
+
287
+ def normalize_boxlist2d(boxlist2d, H, W):
288
+ boxlist2d = boxlist2d.clone()
289
+ ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
290
+ ymin = ymin / float(H)
291
+ ymax = ymax / float(H)
292
+ xmin = xmin / float(W)
293
+ xmax = xmax / float(W)
294
+ boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
295
+ return boxlist2d
296
+
297
+ def unnormalize_boxlist2d(boxlist2d, H, W):
298
+ boxlist2d = boxlist2d.clone()
299
+ ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
300
+ ymin = ymin * float(H)
301
+ ymax = ymax * float(H)
302
+ xmin = xmin * float(W)
303
+ xmax = xmax * float(W)
304
+ boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
305
+ return boxlist2d
306
+
307
+ def unnormalize_box2d(box2d, H, W):
308
+ return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
309
+
310
+ def normalize_box2d(box2d, H, W):
311
+ return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
312
+
313
+ def get_size_from_box2d(box2d):
314
+ ymin = box2d[:,0]
315
+ xmin = box2d[:,1]
316
+ ymax = box2d[:,2]
317
+ xmax = box2d[:,3]
318
+ height = ymax-ymin
319
+ width = xmax-xmin
320
+ return height, width
321
+
322
+ def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False):
323
+ B, C, H, W = im.shape
324
+ B2, N, D = boxlist.shape
325
+ assert(B==B2)
326
+ assert(D==4)
327
+ # PH, PW is the size to resize to
328
+
329
+ # output is B,N,C,PH,PW
330
+
331
+ # pt wants xy xy, unnormalized
332
+ if boxlist_is_normalized:
333
+ boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W)
334
+ else:
335
+ boxlist_unnorm = boxlist
336
+
337
+ ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2)
338
+ # boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1)
339
+ boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2)
340
+ # we want a B-len list of K x 4 arrays
341
+
342
+ # print('im', im.shape)
343
+ # print('boxlist', boxlist.shape)
344
+ # print('boxlist_pt', boxlist_pt.shape)
345
+
346
+ # boxlist_pt = list(boxlist_pt.unbind(0))
347
+
348
+ crops = []
349
+ for b in range(B):
350
+ crops_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
351
+ crops.append(crops_b)
352
+ # # crops = im
353
+
354
+ # print('crops', crops.shape)
355
+ # crops = crops.reshape(B,N,C,PH,PW)
356
+
357
+
358
+ # crops = []
359
+ # for b in range(B):
360
+ # crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
361
+ # print('crop_b', crop_b.shape)
362
+ # crops.append(crop_b)
363
+ crops = torch.stack(crops, dim=0)
364
+
365
+ # print('crops', crops.shape)
366
+ # boxlist_list = boxlist_pt.unbind(0)
367
+ # print('rgb_crop', rgb_crop.shape)
368
+
369
+ return crops
370
+
371
+
372
+ # def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True):
373
+ # # cy,cx are both B,N
374
+ # ymin = cy - h/2
375
+ # ymax = cy + h/2
376
+ # xmin = cx - w/2
377
+ # xmax = cx + w/2
378
+
379
+ # box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
380
+ # if clip:
381
+ # box = torch.clamp(box, 0, 1)
382
+ # return box
383
+
384
+
385
+ def get_boxlist_from_centroid_and_size(cy, cx, h, w):#, clip=False):
386
+ # cy,cx are the same shape
387
+ ymin = cy - h/2
388
+ ymax = cy + h/2
389
+ xmin = cx - w/2
390
+ xmax = cx + w/2
391
+
392
+ # if clip:
393
+ # ymin = torch.clamp(ymin, 0, H-1)
394
+ # ymax = torch.clamp(ymax, 0, H-1)
395
+ # xmin = torch.clamp(xmin, 0, W-1)
396
+ # xmax = torch.clamp(xmax, 0, W-1)
397
+
398
+ box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
399
+ return box
400
+
401
+
402
+ def get_box2d_from_mask(mask, normalize=False):
403
+ # mask is B, 1, H, W
404
+
405
+ B, C, H, W = mask.shape
406
+ assert(C==1)
407
+ xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device) # B, H*W, 2
408
+
409
+ box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device)
410
+ for b in range(B):
411
+ xy_b = xy[b] # H*W, 2
412
+ mask_b = mask[b].reshape(H*W)
413
+ xy_ = xy_b[mask_b > 0]
414
+ x_ = xy_[:,0]
415
+ y_ = xy_[:,1]
416
+ ymin = torch.min(y_)
417
+ ymax = torch.max(y_)
418
+ xmin = torch.min(x_)
419
+ xmax = torch.max(x_)
420
+ box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0)
421
+ if normalize:
422
+ box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1)
423
+ return box
424
+
425
+ def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0):
426
+ # box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords
427
+ # ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
428
+ # H, W is the original size of the image
429
+ # mult_padding is relative to object size in pixels
430
+
431
+ # i assume we're rendering an image the same size as the original (H, W)
432
+
433
+ if not mult_padding==1.0:
434
+ y, x = get_centroid_from_box2d(box2d)
435
+ h, w = get_size_from_box2d(box2d)
436
+ box2d = get_box2d_from_centroid_and_size(
437
+ y, x, h*mult_padding, w*mult_padding, clip=False)
438
+
439
+ if use_image_aspect_ratio:
440
+ h, w = get_size_from_box2d(box2d)
441
+ y, x = get_centroid_from_box2d(box2d)
442
+
443
+ # note h,w are relative right now
444
+ # we need to undo this, to see the real ratio
445
+
446
+ h = h*float(H)
447
+ w = w*float(W)
448
+ box_ratio = h/w
449
+ im_ratio = H/float(W)
450
+
451
+ # print('box_ratio:', box_ratio)
452
+ # print('im_ratio:', im_ratio)
453
+
454
+ if box_ratio >= im_ratio:
455
+ w = h/im_ratio
456
+ # print('setting w:', h/im_ratio)
457
+ else:
458
+ h = w*im_ratio
459
+ # print('setting h:', w*im_ratio)
460
+
461
+ box2d = get_box2d_from_centroid_and_size(
462
+ y, x, h/float(H), w/float(W), clip=False)
463
+
464
+ assert(h > 1e-4)
465
+ assert(w > 1e-4)
466
+
467
+ ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
468
+
469
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
470
+
471
+ # the topleft of the new image will now have a different offset from the center of projection
472
+
473
+ new_x0 = x0 - xmin*W
474
+ new_y0 = y0 - ymin*H
475
+
476
+ pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0)
477
+ # this alone will give me an image in original resolution,
478
+ # with its topleft at the box corner
479
+
480
+ box_h, box_w = get_size_from_box2d(box2d)
481
+ # these are normalized, and shaped B. (e.g., [0.4], [0.3])
482
+
483
+ # we are going to scale the image by the inverse of this,
484
+ # since we are zooming into this area
485
+
486
+ sy = 1./box_h
487
+ sx = 1./box_w
488
+
489
+ pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy)
490
+ return pix_T_cam, box2d
491
+
492
+ def pixels2camera(x,y,z,fx,fy,x0,y0):
493
+ # x and y are locations in pixel coordinates, z is a depth in meters
494
+ # they can be images or pointclouds
495
+ # fx, fy, x0, y0 are camera intrinsics
496
+ # returns xyz, sized B x N x 3
497
+
498
+ B = x.shape[0]
499
+
500
+ fx = torch.reshape(fx, [B,1])
501
+ fy = torch.reshape(fy, [B,1])
502
+ x0 = torch.reshape(x0, [B,1])
503
+ y0 = torch.reshape(y0, [B,1])
504
+
505
+ x = torch.reshape(x, [B,-1])
506
+ y = torch.reshape(y, [B,-1])
507
+ z = torch.reshape(z, [B,-1])
508
+
509
+ # unproject
510
+ x = (z/fx)*(x-x0)
511
+ y = (z/fy)*(y-y0)
512
+
513
+ xyz = torch.stack([x,y,z], dim=2)
514
+ # B x N x 3
515
+ return xyz
516
+
517
+ def camera2pixels(xyz, pix_T_cam):
518
+ # xyz is shaped B x H*W x 3
519
+ # returns xy, shaped B x H*W x 2
520
+
521
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
522
+ x, y, z = torch.unbind(xyz, dim=-1)
523
+ B = list(z.shape)[0]
524
+
525
+ fx = torch.reshape(fx, [B,1])
526
+ fy = torch.reshape(fy, [B,1])
527
+ x0 = torch.reshape(x0, [B,1])
528
+ y0 = torch.reshape(y0, [B,1])
529
+ x = torch.reshape(x, [B,-1])
530
+ y = torch.reshape(y, [B,-1])
531
+ z = torch.reshape(z, [B,-1])
532
+
533
+ EPS = 1e-4
534
+ z = torch.clamp(z, min=EPS)
535
+ x = (x*fx)/z + x0
536
+ y = (y*fy)/z + y0
537
+ xy = torch.stack([x, y], dim=-1)
538
+ return xy
539
+
540
+ def depth2pointcloud(z, pix_T_cam):
541
+ B, C, H, W = list(z.shape)
542
+ device = z.device
543
+ y, x = utils.basic.meshgrid2d(B, H, W, device=device)
544
+ z = torch.reshape(z, [B, H, W])
545
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
546
+ xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
547
+ return xyz
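A round-trip sketch with the camera helpers above; the intrinsics are made-up values, and it assumes the utils.basic import at the top of this file resolves:

import torch

B, H, W = 1, 4, 6
pix_T_cam = torch.tensor([[[100.0,   0.0, W / 2.0],
                           [  0.0, 100.0, H / 2.0],
                           [  0.0,   0.0,     1.0]]])   # (1, 3, 3)

depth = torch.full((B, 1, H, W), 2.0)        # constant depth of 2.0
xyz = depth2pointcloud(depth, pix_T_cam)     # (1, H*W, 3) camera-frame points
xy = camera2pixels(xyz, pix_T_cam)           # (1, H*W, 2) back to pixel coordinates
# xy reproduces the original pixel grid up to floating-point error.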
models/spatracker/utils/improc.py ADDED
@@ -0,0 +1,1447 @@
1
+ import torch
2
+ import numpy as np
3
+ import models.spatracker.utils.basic
+ import models.spatracker.utils as utils  # bind `utils` so the utils.basic.* calls below resolve
4
+ from sklearn.decomposition import PCA
5
+ from matplotlib import cm
6
+ import matplotlib.pyplot as plt
7
+ import cv2
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ EPS = 1e-6
11
+
12
+ from skimage.color import (
13
+ rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
14
+ rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)
15
+
16
+ def _convert(input_, type_):
17
+ return {
18
+ 'float': input_.float(),
19
+ 'double': input_.double(),
20
+ }.get(type_, input_)
21
+
22
+ def _generic_transform_sk_3d(transform, in_type='', out_type=''):
23
+ def apply_transform_individual(input_):
24
+ device = input_.device
25
+ input_ = input_.cpu()
26
+ input_ = _convert(input_, in_type)
27
+
28
+ input_ = input_.permute(1, 2, 0).detach().numpy()
29
+ transformed = transform(input_)
30
+ output = torch.from_numpy(transformed).float().permute(2, 0, 1)
31
+ output = _convert(output, out_type)
32
+ return output.to(device)
33
+
34
+ def apply_transform(input_):
35
+ to_stack = []
36
+ for image in input_:
37
+ to_stack.append(apply_transform_individual(image))
38
+ return torch.stack(to_stack)
39
+ return apply_transform
40
+
41
+ hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)
42
+
43
+ def preprocess_color_tf(x):
44
+ import tensorflow as tf
45
+ return tf.cast(x,tf.float32) * 1./255 - 0.5
46
+
47
+ def preprocess_color(x):
48
+ if isinstance(x, np.ndarray):
49
+ return x.astype(np.float32) * 1./255 - 0.5
50
+ else:
51
+ return x.float() * 1./255 - 0.5
52
+
53
+ def pca_embed(emb, keep, valid=None):
54
+ ## emb -- [S,H/2,W/2,C]
55
+ ## keep is the number of principal components to keep
56
+ ## Helper function for reduce_emb.
57
+ emb = emb + EPS
58
+ #emb is B x C x H x W
59
+ emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
60
+
61
+ if valid:
62
+ valid = valid.cpu().detach().numpy().reshape((H*W))
63
+
64
+ emb_reduced = list()
65
+
66
+ B, H, W, C = np.shape(emb)
67
+ for img in emb:
68
+ if np.isnan(img).any():
69
+ emb_reduced.append(np.zeros([H, W, keep]))
70
+ continue
71
+
72
+ pixels_kd = np.reshape(img, (H*W, C))
73
+
74
+ if valid:
75
+ pixels_kd_pca = pixels_kd[valid]
76
+ else:
77
+ pixels_kd_pca = pixels_kd
78
+
79
+ P = PCA(keep)
80
+ P.fit(pixels_kd_pca)
81
+
82
+ if valid:
83
+ pixels3d = P.transform(pixels_kd)*valid
84
+ else:
85
+ pixels3d = P.transform(pixels_kd)
86
+
87
+ out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32)
88
+ if np.isnan(out_img).any():
89
+ emb_reduced.append(np.zeros([H, W, keep]))
90
+ continue
91
+
92
+ emb_reduced.append(out_img)
93
+
94
+ emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
95
+
96
+ return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
97
+
98
+ def pca_embed_together(emb, keep):
99
+ ## emb -- [S,H/2,W/2,C]
100
+ ## keep is the number of principal components to keep
101
+ ## Helper function for reduce_emb.
102
+ emb = emb + EPS
103
+ #emb is B x C x H x W
104
+ emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
105
+
106
+ B, H, W, C = np.shape(emb)
107
+ if np.isnan(emb).any():
108
+ return torch.zeros(B, keep, H, W)
109
+
110
+ pixelskd = np.reshape(emb, (B*H*W, C))
111
+ P = PCA(keep)
112
+ P.fit(pixelskd)
113
+ pixels3d = P.transform(pixelskd)
114
+ out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32)
115
+
116
+ if np.isnan(out_img).any():
117
+ return torch.zeros(B, keep, H, W)
118
+
119
+ return torch.from_numpy(out_img).permute(0, 3, 1, 2)
120
+
121
+ def reduce_emb(emb, valid=None, inbound=None, together=False):
122
+ ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2]
123
+ ## Reduce number of chans to 3 with PCA. For vis.
124
+ # S,H,W,C = emb.shape.as_list()
125
+ S, C, H, W = list(emb.size())
126
+ keep = 3
127
+
128
+ if together:
129
+ reduced_emb = pca_embed_together(emb, keep)
130
+ else:
131
+ reduced_emb = pca_embed(emb, keep, valid) #not im
132
+
133
+ reduced_emb = utils.basic.normalize(reduced_emb) - 0.5
134
+ if inbound is not None:
135
+ emb_inbound = emb*inbound
136
+ else:
137
+ emb_inbound = None
138
+
139
+ return reduced_emb, emb_inbound
140
+
141
+ def get_feat_pca(feat, valid=None):
142
+ B, C, D, W = list(feat.size())
143
+ # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function.
144
+
145
+ pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True)
146
+ # pca is B x 3 x W x D
147
+ return pca
148
+
149
+ def gif_and_tile(ims, just_gif=False):
150
+ S = len(ims)
151
+ # each im is B x H x W x C
152
+ # i want a gif in the left, and the tiled frames on the right
153
+ # for the gif tool, this means making a B x S x H x W tensor
154
+ # where the leftmost part is sequential and the rest is tiled
155
+ gif = torch.stack(ims, dim=1)
156
+ if just_gif:
157
+ return gif
158
+ til = torch.cat(ims, dim=2)
159
+ til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)
160
+ im = torch.cat([gif, til], dim=3)
161
+ return im
162
+
163
+ def back2color(i, blacken_zeros=False):
164
+ if blacken_zeros:
165
+ const = torch.tensor([-0.5])
166
+ i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i)
167
+ return back2color(i)
168
+ else:
169
+ return ((i+0.5)*255).type(torch.ByteTensor)
170
+
171
+ def convert_occ_to_height(occ, reduce_axis=3):
172
+ B, C, D, H, W = list(occ.shape)
173
+ assert(C==1)
174
+ # note that height increases DOWNWARD in the tensor
175
+ # (like pixel/camera coordinates)
176
+
177
+ G = list(occ.shape)[reduce_axis]
178
+ values = torch.linspace(float(G), 1.0, steps=G, dtype=torch.float32, device=occ.device)
179
+ if reduce_axis==2:
180
+ # fro view
181
+ values = values.view(1, 1, G, 1, 1)
182
+ elif reduce_axis==3:
183
+ # top view
184
+ values = values.view(1, 1, 1, G, 1)
185
+ elif reduce_axis==4:
186
+ # lateral view
187
+ values = values.view(1, 1, 1, 1, G)
188
+ else:
189
+ assert(False) # you have to reduce one of the spatial dims (2-4)
190
+ values = torch.max(occ*values, dim=reduce_axis)[0]/float(G)
191
+ # values = values.view([B, C, D, W])
192
+ return values
193
+
194
+ def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False):
195
+ # xy is B x N x 2, containing float x and y coordinates of N things
196
+ # grid_xs and grid_ys are B x N x Y x X
197
+
198
+ B, N, Y, X = list(grid_xs.shape)
199
+
200
+ mu_x = xy[:,:,0].clone()
201
+ mu_y = xy[:,:,1].clone()
202
+
203
+ x_valid = (mu_x>-0.5) & (mu_x<float(X+0.5))
204
+ y_valid = (mu_y>-0.5) & (mu_y<float(Y+0.5))
205
+ not_valid = ~(x_valid & y_valid)
206
+
207
+ mu_x[not_valid] = -10000
208
+ mu_y[not_valid] = -10000
209
+
210
+ mu_x = mu_x.reshape(B, N, 1, 1).repeat(1, 1, Y, X)
211
+ mu_y = mu_y.reshape(B, N, 1, 1).repeat(1, 1, Y, X)
212
+
213
+ sigma_sq = sigma*sigma
214
+ # sigma_sq = (sigma*sigma).reshape(B, N, 1, 1)
215
+ sq_diff_x = (grid_xs - mu_x)**2
216
+ sq_diff_y = (grid_ys - mu_y)**2
217
+
218
+ term1 = 1./2.*np.pi*sigma_sq
219
+ term2 = torch.exp(-(sq_diff_x+sq_diff_y)/(2.*sigma_sq))
220
+ gauss = term1*term2
221
+
222
+ if norm:
223
+ # normalize so each gaussian peaks at 1
224
+ gauss_ = gauss.reshape(B*N, Y, X)
225
+ gauss_ = utils.basic.normalize(gauss_)
226
+ gauss = gauss_.reshape(B, N, Y, X)
227
+
228
+ return gauss
229
+
230
+ def xy2heatmaps(xy, Y, X, sigma=30.0, norm=True):
231
+ # xy is B x N x 2
232
+
233
+ B, N, D = list(xy.shape)
234
+ assert(D==2)
235
+
236
+ device = xy.device
237
+
238
+ grid_y, grid_x = utils.basic.meshgrid2d(B, Y, X, device=device)
239
+ # grid_x and grid_y are B x Y x X
240
+ grid_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1)
241
+ grid_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1)
242
+ heat = xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=norm)
243
+ return heat
244
+
245
+ def draw_circles_at_xy(xy, Y, X, sigma=12.5, round=False):
246
+ B, N, D = list(xy.shape)
247
+ assert(D==2)
248
+ prior = xy2heatmaps(xy, Y, X, sigma=sigma)
249
+ # prior is B x N x Y x X
250
+ if round:
251
+ prior = (prior > 0.5).float()
252
+ return prior
253
+
254
+ def seq2color(im, norm=True, colormap='coolwarm'):
255
+ B, S, H, W = list(im.shape)
256
+ # S is sequential
257
+
258
+ # prep a mask of the valid pixels, so we can blacken the invalids later
259
+ mask = torch.max(im, dim=1, keepdim=True)[0]
260
+
261
+ # turn the S dim into an explicit sequence
262
+ coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S)
263
+
264
+ # # increase the spacing from the center
265
+ # coeffs[:int(S/2)] -= 2.0
266
+ # coeffs[int(S/2)+1:] += 2.0
267
+
268
+ coeffs = torch.from_numpy(coeffs).float().cuda()
269
+ coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W)
270
+ # scale each channel by the right coeff
271
+ im = im * coeffs
272
+ # now im is in [1/S, 1], except for the invalid parts which are 0
273
+ # keep the highest valid coeff at each pixel
274
+ im = torch.max(im, dim=1, keepdim=True)[0]
275
+
276
+ out = []
277
+ for b in range(B):
278
+ im_ = im[b]
279
+ # move channels out to last dim_
280
+ im_ = im_.detach().cpu().numpy()
281
+ im_ = np.squeeze(im_)
282
+ # im_ is H x W
283
+ if colormap=='coolwarm':
284
+ im_ = cm.coolwarm(im_)[:, :, :3]
285
+ elif colormap=='PiYG':
286
+ im_ = cm.PiYG(im_)[:, :, :3]
287
+ elif colormap=='winter':
288
+ im_ = cm.winter(im_)[:, :, :3]
289
+ elif colormap=='spring':
290
+ im_ = cm.spring(im_)[:, :, :3]
291
+ elif colormap=='onediff':
292
+ im_ = np.reshape(im_, (-1))
293
+ im0_ = cm.spring(im_)[:, :3]
294
+ im1_ = cm.winter(im_)[:, :3]
295
+ im1_[im_==1/float(S)] = im0_[im_==1/float(S)]
296
+ im_ = np.reshape(im1_, (H, W, 3))
297
+ else:
298
+ assert(False) # invalid colormap
299
+ # move channels into dim 0
300
+ im_ = np.transpose(im_, [2, 0, 1])
301
+ im_ = torch.from_numpy(im_).float().cuda()
302
+ out.append(im_)
303
+ out = torch.stack(out, dim=0)
304
+
305
+ # blacken the invalid pixels, instead of using the 0-color
306
+ out = out*mask
307
+ # out = out*255.0
308
+
309
+ # put it in [-0.5, 0.5]
310
+ out = out - 0.5
311
+
312
+ return out
313
+
314
+ def colorize(d):
315
+ # this is actually just grayscale right now
316
+
317
+ if d.ndim==2:
318
+ d = d.unsqueeze(dim=0)
319
+ else:
320
+ assert(d.ndim==3)
321
+
322
+ # color_map = cm.get_cmap('plasma')
323
+ color_map = cm.get_cmap('inferno')
324
+ # S1, D = traj.shape
325
+
326
+ # print('d1', d.shape)
327
+ C,H,W = d.shape
328
+ assert(C==1)
329
+ d = d.reshape(-1)
330
+ d = d.detach().cpu().numpy()
331
+ # print('d2', d.shape)
332
+ color = np.array(color_map(d)) * 255 # rgba
333
+ # print('color1', color.shape)
334
+ color = np.reshape(color[:,:3], [H*W, 3])
335
+ # print('color2', color.shape)
336
+ color = torch.from_numpy(color).permute(1,0).reshape(3,H,W)
337
+ # # gather
338
+ # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')
339
+ # if cmap=='RdBu' or cmap=='RdYlGn':
340
+ # colors = cm(np.arange(256))[:, :3]
341
+ # else:
342
+ # colors = cm.colors
343
+ # colors = np.array(colors).astype(np.float32)
344
+ # colors = np.reshape(colors, [-1, 3])
345
+ # colors = tf.constant(colors, dtype=tf.float32)
346
+
347
+ # value = tf.gather(colors, indices)
348
+ # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255)
349
+
350
+ # copy to the three chans
351
+ # d = d.repeat(3, 1, 1)
352
+ return color
353
+
354
+
355
+ def oned2inferno(d, norm=True, do_colorize=False):
356
+ # convert a 1chan input to a 3chan image output
357
+
358
+ # if it's just B x H x W, add a C dim
359
+ if d.ndim==3:
360
+ d = d.unsqueeze(dim=1)
361
+ # d should be B x C x H x W, where C=1
362
+ B, C, H, W = list(d.shape)
363
+ assert(C==1)
364
+
365
+ if norm:
366
+ d = utils.basic.normalize(d)
367
+
368
+ if do_colorize:
369
+ rgb = torch.zeros(B, 3, H, W)
370
+ for b in list(range(B)):
371
+ rgb[b] = colorize(d[b])
372
+ else:
373
+ rgb = d.repeat(1, 3, 1, 1)*255.0
374
+ # rgb = (255.0*rgb).type(torch.ByteTensor)
375
+ rgb = rgb.type(torch.ByteTensor)
376
+
377
+ # rgb = tf.cast(255.0*rgb, tf.uint8)
378
+ # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3])
379
+ # rgb = tf.expand_dims(rgb, axis=0)
380
+ return rgb
381
+
382
+ def oned2gray(d, norm=True):
383
+ # convert a 1chan input to a 3chan image output
384
+
385
+ # if it's just B x H x W, add a C dim
386
+ if d.ndim==3:
387
+ d = d.unsqueeze(dim=1)
388
+ # d should be B x C x H x W, where C=1
389
+ B, C, H, W = list(d.shape)
390
+ assert(C==1)
391
+
392
+ if norm:
393
+ d = utils.basic.normalize(d)
394
+
395
+ rgb = d.repeat(1,3,1,1)
396
+ rgb = (255.0*rgb).type(torch.ByteTensor)
397
+ return rgb
398
+
399
+
400
+ def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20):
401
+
402
+ rgb = vis.detach().cpu().numpy()[0]
403
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
404
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
405
+ color = (255, 255, 255)
406
+ # print('putting frame id', frame_id)
407
+
408
+ frame_str = utils.basic.strnum(frame_id)
409
+
410
+ text_color_bg = (0,0,0)
411
+ font = cv2.FONT_HERSHEY_SIMPLEX
412
+ text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
413
+ text_w, text_h = text_size
414
+ cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
415
+
416
+ cv2.putText(
417
+ rgb,
418
+ frame_str,
419
+ (left, top), # from left, from top
420
+ font,
421
+ scale, # font scale (float)
422
+ color,
423
+ 1) # font thickness (int)
424
+ rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
425
+ vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
426
+ return vis
427
+
428
+ COLORMAP_FILE = "./utils/bremm.png"
429
+ class ColorMap2d:
430
+ def __init__(self, filename=None):
431
+ self._colormap_file = filename or COLORMAP_FILE
432
+ self._img = plt.imread(self._colormap_file)
433
+
434
+ self._height = self._img.shape[0]
435
+ self._width = self._img.shape[1]
436
+
437
+ def __call__(self, X):
438
+ assert len(X.shape) == 2
439
+ output = np.zeros((X.shape[0], 3))
440
+ for i in range(X.shape[0]):
441
+ x, y = X[i, :]
442
+ xp = int((self._width-1) * x)
443
+ yp = int((self._height-1) * y)
444
+ xp = np.clip(xp, 0, self._width-1)
445
+ yp = np.clip(yp, 0, self._height-1)
446
+ output[i, :] = self._img[yp, xp]
447
+ return output
448
+
449
+ def get_n_colors(N, sequential=False):
450
+ label_colors = []
451
+ for ii in range(N):
452
+ if sequential:
453
+ rgb = cm.winter(ii/(N-1))
454
+ rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
455
+ else:
456
+ rgb = np.zeros(3)
457
+ while np.sum(rgb) < 128: # ensure min brightness
458
+ rgb = np.random.randint(0,256,3)
459
+ label_colors.append(rgb)
460
+ return label_colors
461
+
462
+ class Summ_writer(object):
463
+ def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
464
+ self.writer = writer
465
+ self.global_step = global_step
466
+ self.log_freq = log_freq
467
+ self.fps = fps
468
+ self.just_gif = just_gif
469
+ self.maxwidth = 10000
470
+ self.save_this = (self.global_step % self.log_freq == 0)
471
+ self.scalar_freq = max(scalar_freq,1)
472
+
473
+
474
+ def summ_gif(self, name, tensor, blacken_zeros=False):
475
+ # tensor should be in B x S x C x H x W
476
+
477
+ assert tensor.dtype in {torch.uint8,torch.float32}
478
+ shape = list(tensor.shape)
479
+
480
+ if tensor.dtype == torch.float32:
481
+ tensor = back2color(tensor, blacken_zeros=blacken_zeros)
482
+
483
+ video_to_write = tensor[0:1]
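+ # keep only the first batch element for logging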
484
+
485
+ S = video_to_write.shape[1]
486
+ if S==1:
487
+ # video_to_write is 1 x 1 x C x H x W
488
+ self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step)
489
+ else:
490
+ self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)
491
+
492
+ return video_to_write
493
+
494
+ def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1):
495
+ B, C, H, W = list(rgb.shape)
496
+ assert(C==3)
497
+ B2, N, D = list(boxlist.shape)
498
+ assert(B2==B)
499
+ assert(D==4) # ymin, xmin, ymax, xmax
500
+
501
+ rgb = back2color(rgb)
502
+ if scores is None:
503
+ scores = torch.ones(B2, N).float()
504
+ if tids is None:
505
+ tids = torch.arange(N).reshape(1,N).repeat(B2,1).long()
506
+ # tids = torch.zeros(B2, N).long()
507
+ out = self.draw_boxlist2d_on_image_py(
508
+ rgb[0].cpu().detach().numpy(),
509
+ boxlist[0].cpu().detach().numpy(),
510
+ scores[0].cpu().detach().numpy(),
511
+ tids[0].cpu().detach().numpy(),
512
+ linewidth=linewidth)
513
+ out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1)
514
+ out = torch.unsqueeze(out, dim=0)
515
+ out = preprocess_color(out)
516
+ out = torch.reshape(out, [1, C, H, W])
517
+ return out
518
+
519
+ def draw_boxlist2d_on_image_py(self, rgb, boxlist, scores, tids, linewidth=1):
520
+ # all inputs are numpy tensors
521
+ # rgb is H x W x 3
522
+ # boxlist is N x 4
523
+ # scores is N
524
+ # tids is N
525
+
526
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
527
+ # rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
528
+
529
+ rgb = rgb.astype(np.uint8).copy()
530
+
531
+
532
+ H, W, C = rgb.shape
533
+ assert(C==3)
534
+ N, D = boxlist.shape
535
+ assert(D==4)
536
+
537
+ # color_map = cm.get_cmap('tab20')
538
+ # color_map = cm.get_cmap('set1')
539
+ color_map = cm.get_cmap('Accent')
540
+ color_map = color_map.colors
541
+ # print('color_map', color_map)
542
+
543
+ # draw
544
+ for ind, box in enumerate(boxlist):
545
+ # box is 4
546
+ if not np.isclose(scores[ind], 0.0):
547
+ # box = utils.geom.scale_box2d(box, H, W)
548
+ ymin, xmin, ymax, xmax = box
549
+
550
+ # ymin, ymax = ymin*H, ymax*H
551
+ # xmin, xmax = xmin*W, xmax*W
552
+
553
+ # print 'score = %.2f' % scores[ind]
554
+ # color_id = tids[ind] % 20
555
+ color_id = tids[ind]
556
+ color = color_map[color_id]
557
+ color = np.array(color)*255.0
558
+ color = color.round()
559
+ # color = color.astype(np.uint8)
560
+ # color = color[::-1]
561
+ # print('color', color)
562
+
563
+ # print 'tid = %d; score = %.3f' % (tids[ind], scores[ind])
564
+
565
+ # if False:
566
+ if scores[ind] < 1.0: # not gt
567
+ cv2.putText(rgb,
568
+ # '%d (%.2f)' % (tids[ind], scores[ind]),
569
+ '%.2f' % (scores[ind]),
570
+ (int(xmin), int(ymin)),
571
+ cv2.FONT_HERSHEY_SIMPLEX,
572
+ 0.5, # font size
573
+ color,
574
+ 1) # font thickness (int)
575
+
576
+ xmin = np.clip(int(xmin), 0, W-1)
577
+ xmax = np.clip(int(xmax), 0, W-1)
578
+ ymin = np.clip(int(ymin), 0, H-1)
579
+ ymax = np.clip(int(ymax), 0, H-1)
580
+
581
+ cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_AA)
582
+ cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_AA)
583
+ cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_AA)
584
+ cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_AA)
585
+
586
+ # rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
587
+ return rgb
588
+
589
+ def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, only_return=False, linewidth=2):
590
+ B, C, H, W = list(rgb.shape)
591
+ boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth)
592
+ return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, only_return=only_return)
593
+
594
+ def summ_rgbs(self, name, ims, frame_ids=None, blacken_zeros=False, only_return=False):
595
+ if self.save_this:
596
+
597
+ ims = gif_and_tile(ims, just_gif=self.just_gif)
598
+ vis = ims
599
+
600
+ assert vis.dtype in {torch.uint8,torch.float32}
601
+
602
+ if vis.dtype == torch.float32:
603
+ vis = back2color(vis, blacken_zeros)
604
+
605
+ B, S, C, H, W = list(vis.shape)
606
+
607
+ if frame_ids is not None:
608
+ assert(len(frame_ids)==S)
609
+ for s in range(S):
610
+ vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
611
+
612
+ if int(W) > self.maxwidth:
613
+ vis = vis[:,:,:,:self.maxwidth]
614
+
615
+ if only_return:
616
+ return vis
617
+ else:
618
+ return self.summ_gif(name, vis, blacken_zeros)
619
+
620
+ def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, only_return=False, halfres=False):
621
+ if self.save_this:
622
+ assert ims.dtype in {torch.uint8,torch.float32}
623
+
624
+ if ims.dtype == torch.float32:
625
+ ims = back2color(ims, blacken_zeros)
626
+
627
+ #ims is B x C x H x W
628
+ vis = ims[0:1] # just the first one
629
+ B, C, H, W = list(vis.shape)
630
+
631
+ if halfres:
632
+ vis = F.interpolate(vis, scale_factor=0.5)
633
+
634
+ if frame_id is not None:
635
+ vis = draw_frame_id_on_vis(vis, frame_id)
636
+
637
+ if int(W) > self.maxwidth:
638
+ vis = vis[:,:,:,:self.maxwidth]
639
+
640
+ if only_return:
641
+ return vis
642
+ else:
643
+ return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
644
+
645
+ def flow2color(self, flow, clip=50.0):
646
+ """
647
+ :param flow: Optical flow tensor.
648
+ :return: RGB image normalized between 0 and 1.
649
+ """
650
+
651
+ # flow is B x C x H x W
652
+
653
+ B, C, H, W = list(flow.size())
654
+
655
+ flow = flow.clone().detach()
656
+
657
+ abs_image = torch.abs(flow)
658
+ flow_mean = abs_image.mean(dim=[1,2,3])
659
+ flow_std = abs_image.std(dim=[1,2,3])
660
+
661
+ if clip:
662
+ flow = torch.clamp(flow, -clip, clip)/clip
663
+ else:
664
+ # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)
665
+ flow_max = flow_mean + flow_std*2 + 1e-10
666
+ for b in range(B):
667
+ flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)
668
+
669
+ radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W
670
+ radius_clipped = torch.clamp(radius, 0.0, 1.0)
671
+
672
+ angle = torch.atan2(flow[:, 1:], flow[:, 0:1]) / np.pi #B x 1 x H x W
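+ # encode flow in HSV: direction -> hue, clipped magnitude -> value, fixed saturation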
673
+
674
+ hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
675
+ saturation = torch.ones_like(hue) * 0.75
676
+ value = radius_clipped
677
+ hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W
678
+
679
+ #flow = tf.image.hsv_to_rgb(hsv)
680
+ flow = hsv_to_rgb(hsv)
681
+ flow = (flow*255.0).type(torch.ByteTensor)
682
+ return flow
683
+
684
+ def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None):
685
+ # flow is B x C x D x W
686
+ if self.save_this:
687
+ return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id)
688
+ else:
689
+ return None
690
+
691
+ def summ_oneds(self, name, ims, frame_ids=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
692
+ if self.save_this:
693
+ if bev:
694
+ B, C, H, _, W = list(ims[0].shape)
695
+ if reduce_max:
696
+ ims = [torch.max(im, dim=3)[0] for im in ims]
697
+ else:
698
+ ims = [torch.mean(im, dim=3) for im in ims]
699
+ elif fro:
700
+ B, C, _, H, W = list(ims[0].shape)
701
+ if reduce_max:
702
+ ims = [torch.max(im, dim=2)[0] for im in ims]
703
+ else:
704
+ ims = [torch.mean(im, dim=2) for im in ims]
705
+
706
+
707
+ if len(ims) != 1: # sequence
708
+ im = gif_and_tile(ims, just_gif=self.just_gif)
709
+ else:
710
+ im = torch.stack(ims, dim=1) # single frame
711
+
712
+ B, S, C, H, W = list(im.shape)
713
+
714
+ if logvis and max_val:
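+ # log-compress the values, then normalize by log(max_val) so the result lies in [0, 1]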
715
+ max_val = np.log(max_val)
716
+ im = torch.log(torch.clamp(im, 0)+1.0)
717
+ im = torch.clamp(im, 0, max_val)
718
+ im = im/max_val
719
+ norm = False
720
+ elif max_val:
721
+ im = torch.clamp(im, 0, max_val)
722
+ im = im/max_val
723
+ norm = False
724
+
725
+ if norm:
726
+ # normalize before oned2inferno,
727
+ # so that the ranges are similar within B across S
728
+ im = utils.basic.normalize(im)
729
+
730
+ im = im.view(B*S, C, H, W)
731
+ vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
732
+ vis = vis.view(B, S, 3, H, W)
733
+
734
+ if frame_ids is not None:
735
+ assert(len(frame_ids)==S)
736
+ for s in range(S):
737
+ vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
738
+
739
+ if W > self.maxwidth:
740
+ vis = vis[...,:self.maxwidth]
741
+
742
+ if only_return:
743
+ return vis
744
+ else:
745
+ self.summ_gif(name, vis)
746
+
747
+ def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, only_return=False):
748
+ if self.save_this:
749
+
750
+ if bev:
751
+ B, C, H, _, W = list(im.shape)
752
+ if max_along_y:
753
+ im = torch.max(im, dim=3)[0]
754
+ else:
755
+ im = torch.mean(im, dim=3)
756
+ elif fro:
757
+ B, C, _, H, W = list(im.shape)
758
+ if max_along_y:
759
+ im = torch.max(im, dim=2)[0]
760
+ else:
761
+ im = torch.mean(im, dim=2)
762
+ else:
763
+ B, C, H, W = list(im.shape)
764
+
765
+ im = im[0:1] # just the first one
766
+ assert(C==1)
767
+
768
+ if logvis and max_val:
769
+ max_val = np.log(max_val)
770
+ im = torch.log(im)
771
+ im = torch.clamp(im, 0, max_val)
772
+ im = im/max_val
773
+ norm = False
774
+ elif max_val:
775
+ im = torch.clamp(im, 0, max_val)/max_val
776
+ norm = False
777
+
778
+ vis = oned2inferno(im, norm=norm)
779
+ if W > self.maxwidth:
780
+ vis = vis[...,:self.maxwidth]
781
+ return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, only_return=only_return)
782
+
783
+ def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None):
784
+ if self.save_this:
785
+ if valids is not None:
786
+ valids = torch.stack(valids, dim=1)
787
+
788
+ feats = torch.stack(feats, dim=1)
789
+ # feats leads with B x S x C
790
+
791
+ if feats.ndim==6:
792
+
793
+ # feats is B x S x C x D x H x W
794
+ if fro:
795
+ reduce_dim = 3
796
+ else:
797
+ reduce_dim = 4
798
+
799
+ if valids is None:
800
+ feats = torch.mean(feats, dim=reduce_dim)
801
+ else:
802
+ valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
803
+ feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)
804
+
805
+ B, S, C, D, W = list(feats.size())
806
+
807
+ if not pca:
808
+ # feats leads with B x S x C
809
+ feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
810
+ # feats leads with B x S x 1
811
+ feats = torch.unbind(feats, dim=1)
812
+ return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids)
813
+
814
+ else:
815
+ __p = lambda x: utils.basic.pack_seqdim(x, B)
816
+ __u = lambda x: utils.basic.unpack_seqdim(x, B)
817
+
818
+ feats_ = __p(feats)
819
+
820
+ if valids is None:
821
+ feats_pca_ = get_feat_pca(feats_)
822
+ else:
823
+ valids_ = __p(valids)
824
+ feats_pca_ = get_feat_pca(feats_, valids_)
825
+
826
+ feats_pca = __u(feats_pca_)
827
+
828
+ return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids)
829
+
830
+ def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None):
831
+ if self.save_this:
832
+ if feat.ndim==5: # B x C x D x H x W
833
+
834
+ if bev:
835
+ reduce_axis = 3
836
+ elif fro:
837
+ reduce_axis = 2
838
+ else:
839
+ # default to bev
840
+ reduce_axis = 3
841
+
842
+ if valid is None:
843
+ feat = torch.mean(feat, dim=reduce_axis)
844
+ else:
845
+ valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
846
+ feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)
847
+
848
+ B, C, D, W = list(feat.shape)
849
+
850
+ if not pca:
851
+ feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
852
+ # feat is B x 1 x D x W
853
+ return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id)
854
+ else:
855
+ feat_pca = get_feat_pca(feat, valid)
856
+ return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id)
857
+
858
+ def summ_scalar(self, name, value):
859
+ if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()):
860
+ value = value.detach().cpu().numpy()
861
+ if not np.isnan(value):
862
+ if (self.log_freq == 1):
863
+ self.writer.add_scalar(name, value, global_step=self.global_step)
864
+ elif self.save_this or np.mod(self.global_step, self.scalar_freq)==0:
865
+ self.writer.add_scalar(name, value, global_step=self.global_step)
866
+
867
+ def summ_seg(self, name, seg, only_return=False, frame_id=None, colormap='tab20', label_colors=None):
868
+ if not self.save_this:
869
+ return
870
+
871
+ B,H,W = seg.shape
872
+
873
+ if label_colors is None:
874
+ custom_label_colors = False
875
+ # label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True)
876
+ label_colors = cm.get_cmap(colormap).colors
877
+ label_colors = [[int(i*255) for i in l] for l in label_colors]
878
+ else:
879
+ custom_label_colors = True
880
+ # label_colors = matplotlib.cm.get_cmap(colormap).colors
881
+ # label_colors = [[int(i*255) for i in l] for l in label_colors]
882
+ # print('label_colors', label_colors)
883
+
884
+ # label_colors = [
885
+ # (0, 0, 0), # None
886
+ # (70, 70, 70), # Buildings
887
+ # (190, 153, 153), # Fences
888
+ # (72, 0, 90), # Other
889
+ # (220, 20, 60), # Pedestrians
890
+ # (153, 153, 153), # Poles
891
+ # (157, 234, 50), # RoadLines
892
+ # (128, 64, 128), # Roads
893
+ # (244, 35, 232), # Sidewalks
894
+ # (107, 142, 35), # Vegetation
895
+ # (0, 0, 255), # Vehicles
896
+ # (102, 102, 156), # Walls
897
+ # (220, 220, 0) # TrafficSigns
898
+ # ]
899
+
900
+ r = torch.zeros_like(seg,dtype=torch.uint8)
901
+ g = torch.zeros_like(seg,dtype=torch.uint8)
902
+ b = torch.zeros_like(seg,dtype=torch.uint8)
903
+
904
+ for label in range(0,len(label_colors)):
905
+ if (not custom_label_colors):# and (N > 20):
906
+ label_ = label % 20
907
+ else:
908
+ label_ = label
909
+
910
+ idx = (seg == label+1)
911
+ r[idx] = label_colors[label_][0]
912
+ g[idx] = label_colors[label_][1]
913
+ b[idx] = label_colors[label_][2]
914
+
915
+ rgb = torch.stack([r,g,b],axis=1)
916
+ return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id)
917
+
918
+ def summ_pts_on_rgb(self, name, trajs, rgb, valids=None, frame_id=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1):
919
+ # trajs is B, S, N, 2
920
+ # rgbs is B, S, C, H, W
921
+ B, C, H, W = rgb.shape
922
+ B, S, N, D = trajs.shape
923
+
924
+ rgb = rgb[0] # C, H, W
925
+ trajs = trajs[0] # S, N, 2
926
+ if valids is None:
927
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
928
+ else:
929
+ valids = valids[0]
930
+ # print('trajs', trajs.shape)
931
+ # print('valids', valids.shape)
932
+
933
+ rgb = back2color(rgb).detach().cpu().numpy()
934
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
935
+
936
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
937
+ valids = valids.long().detach().cpu().numpy() # S, N
938
+
939
+ rgb = rgb.astype(np.uint8).copy()
940
+
941
+ for i in range(N):
942
+ if cmap=='onediff' and i==0:
943
+ cmap_ = 'spring'
944
+ elif cmap=='onediff':
945
+ cmap_ = 'winter'
946
+ else:
947
+ cmap_ = cmap
948
+ traj = trajs[:,i] # S,2
949
+ valid = valids[:,i] # S
950
+
951
+ color_map = cm.get_cmap(cmap_)
952
+ color = np.array(color_map(i)[:3]) * 255 # rgb
953
+ for s in range(S):
954
+ if valid[s]:
955
+ cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)
956
+ rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
957
+ rgb = preprocess_color(rgb)
958
+ return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id)
959
+
960
+ def summ_pts_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1):
961
+ # trajs is B, S, N, 2
962
+ # rgbs is B, S, C, H, W
963
+ B, S, C, H, W = rgbs.shape
964
+ B, S2, N, D = trajs.shape
965
+ assert(S==S2)
966
+
967
+ rgbs = rgbs[0] # S, C, H, W
968
+ trajs = trajs[0] # S, N, 2
969
+ if valids is None:
970
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
971
+ else:
972
+ valids = valids[0]
973
+ # print('trajs', trajs.shape)
974
+ # print('valids', valids.shape)
975
+
976
+ rgbs_color = []
977
+ for rgb in rgbs:
978
+ rgb = back2color(rgb).detach().cpu().numpy()
979
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
980
+ rgbs_color.append(rgb) # each element 3 x H x W
981
+
982
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
983
+ valids = valids.long().detach().cpu().numpy() # S, N
984
+
985
+ rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]
986
+
987
+ for i in range(N):
988
+ traj = trajs[:,i] # S,2
989
+ valid = valids[:,i] # S
990
+
991
+ color_map = cm.get_cmap(cmap)
992
+ color = np.array(color_map(0)[:3]) * 255 # rgb
993
+ for s in range(S):
994
+ if valid[s]:
995
+ cv2.circle(rgbs_color[s], (traj[s,0], traj[s,1]), linewidth, color, -1)
996
+ rgbs = []
997
+ for rgb in rgbs_color:
998
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
999
+ rgbs.append(preprocess_color(rgb))
1000
+
1001
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)
1002
+
1003
+
1004
+ def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1):
1005
+ # trajs is B, S, N, 2
1006
+ # rgbs is B, S, C, H, W
1007
+ B, S, C, H, W = rgbs.shape
1008
+ B, S2, N, D = trajs.shape
1009
+ assert(S==S2)
1010
+
1011
+ rgbs = rgbs[0] # S, C, H, W
1012
+ trajs = trajs[0] # S, N, 2
1013
+ if valids is None:
1014
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
1015
+ else:
1016
+ valids = valids[0]
1017
+
1018
+ # print('trajs', trajs.shape)
1019
+ # print('valids', valids.shape)
1020
+
1021
+ if vals is not None:
1022
+ vals = vals[0] # N
1023
+ # print('vals', vals.shape)
1024
+
1025
+ rgbs_color = []
1026
+ for rgb in rgbs:
1027
+ rgb = back2color(rgb).detach().cpu().numpy()
1028
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
1029
+ rgbs_color.append(rgb) # each element 3 x H x W
1030
+
1031
+ for i in range(N):
1032
+ if cmap=='onediff' and i==0:
1033
+ cmap_ = 'spring'
1034
+ elif cmap=='onediff':
1035
+ cmap_ = 'winter'
1036
+ else:
1037
+ cmap_ = cmap
1038
+ traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
1039
+ valid = valids[:,i].long().detach().cpu().numpy() # S
1040
+
1041
+ # print('traj', traj.shape)
1042
+ # print('valid', valid.shape)
1043
+
1044
+ if vals is not None:
1045
+ # val = vals[:,i].float().detach().cpu().numpy() # []
1046
+ val = vals[i].float().detach().cpu().numpy() # []
1047
+ # print('val', val.shape)
1048
+ else:
1049
+ val = None
1050
+
1051
+ for t in range(S):
1052
+ # if valid[t]:
1053
+ # traj_seq = traj[max(t-16,0):t+1]
1054
+ traj_seq = traj[max(t-8,0):t+1]
1055
+ val_seq = np.linspace(0,1,len(traj_seq))
1056
+ # if t<2:
1057
+ # val_seq = np.zeros_like(val_seq)
1058
+ # print('val_seq', val_seq)
1059
+ # val_seq = 1.0
1060
+ # val_seq = np.arange(8)/8.0
1061
+ # val_seq = val_seq[-len(traj_seq):]
1062
+ # rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth)
1063
+ rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth)
1064
+ # input()
1065
+
1066
+ for i in range(N):
1067
+ if cmap=='onediff' and i==0:
1068
+ cmap_ = 'spring'
1069
+ elif cmap=='onediff':
1070
+ cmap_ = 'winter'
1071
+ else:
1072
+ cmap_ = cmap
1073
+ traj = trajs[:,i] # S,2
1074
+ # vis = visibles[:,i] # S
1075
+ vis = torch.ones_like(traj[:,0]) # S
1076
+ valid = valids[:,i] # S
1077
+ rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
1078
+
1079
+ rgbs = []
1080
+ for rgb in rgbs_color:
1081
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
1082
+ rgbs.append(preprocess_color(rgb))
1083
+
1084
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)
1085
+
1086
+ def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap=None, linewidth=1):
1087
+ # trajs is B, S, N, 2
1088
+ # rgbs is B, S, C, H, W
1089
+ B, S, C, H, W = rgbs.shape
1090
+ B, S2, N, D = trajs.shape
1091
+ assert(S==S2)
1092
+
1093
+ rgbs = rgbs[0] # S, C, H, W
1094
+ trajs = trajs[0] # S, N, 2
1095
+ visibles = visibles[0] # S, N
1096
+ if valids is None:
1097
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
1098
+ else:
1099
+ valids = valids[0]
1100
+ # print('trajs', trajs.shape)
1101
+ # print('valids', valids.shape)
1102
+
1103
+ rgbs_color = []
1104
+ for rgb in rgbs:
1105
+ rgb = back2color(rgb).detach().cpu().numpy()
1106
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
1107
+ rgbs_color.append(rgb) # each element 3 x H x W
1108
+
1109
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
1110
+ visibles = visibles.float().detach().cpu().numpy() # S, N
1111
+ valids = valids.long().detach().cpu().numpy() # S, N
1112
+
1113
+ for i in range(N):
1114
+ if cmap=='onediff' and i==0:
1115
+ cmap_ = 'spring'
1116
+ elif cmap=='onediff':
1117
+ cmap_ = 'winter'
1118
+ else:
1119
+ cmap_ = cmap
1120
+ traj = trajs[:,i] # S,2
1121
+ vis = visibles[:,i] # S
1122
+ valid = valids[:,i] # S
1123
+ rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
1124
+
1125
+ for i in range(N):
1126
+ if cmap=='onediff' and i==0:
1127
+ cmap_ = 'spring'
1128
+ elif cmap=='onediff':
1129
+ cmap_ = 'winter'
1130
+ else:
1131
+ cmap_ = cmap
1132
+ traj = trajs[:,i] # S,2
1133
+ vis = visibles[:,i] # S
1134
+ valid = valids[:,i] # S
1135
+ if valid[0]:
1136
+ rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)
1137
+
1138
+ rgbs = []
1139
+ for rgb in rgbs_color:
1140
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
1141
+ rgbs.append(preprocess_color(rgb))
1142
+
1143
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)
1144
+
1145
+ def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=False, show_lines=True, frame_id=None, only_return=False, cmap='coolwarm', linewidth=1):
1146
+ # trajs is B, S, N, 2
1147
+ # rgb is B, C, H, W
1148
+ B, C, H, W = rgb.shape
1149
+ B, S, N, D = trajs.shape
1150
+
1151
+ rgb = rgb[0] # S, C, H, W
1152
+ trajs = trajs[0] # S, N, 2
1153
+
1154
+ if valids is None:
1155
+ valids = torch.ones_like(trajs[:,:,0])
1156
+ else:
1157
+ valids = valids[0]
1158
+
1159
+ rgb_color = back2color(rgb).detach().cpu().numpy()
1160
+ rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last
1161
+
1162
+ # using maxdist will dampen the colors for short motions
1163
+ norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N
1164
+ maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()
1165
+ maxdist = None
1166
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
1167
+ valids = valids.long().detach().cpu().numpy() # S, N
1168
+
1169
+ for i in range(N):
1170
+ if cmap=='onediff' and i==0:
1171
+ cmap_ = 'spring'
1172
+ elif cmap=='onediff':
1173
+ cmap_ = 'winter'
1174
+ else:
1175
+ cmap_ = cmap
1176
+ traj = trajs[:,i] # S, 2
1177
+ valid = valids[:,i] # S
1178
+ if valid[0]==1:
1179
+ traj = traj[valid>0]
1180
+ rgb_color = self.draw_traj_on_image_py(
1181
+ rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)
1182
+
1183
+ rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
1184
+ rgb = preprocess_color(rgb_color)
1185
+ return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id)
1186
+
1187
+ def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
1188
+ # all inputs are numpy tensors
1189
+ # rgb is 3 x H x W
1190
+ # traj is S x 2
1191
+
1192
+ H, W, C = rgb.shape
1193
+ assert(C==3)
1194
+
1195
+ rgb = rgb.astype(np.uint8).copy()
1196
+
1197
+ S1, D = traj.shape
1198
+ assert(D==2)
1199
+
1200
+ color_map = cm.get_cmap(cmap)
1201
+ S1, D = traj.shape
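+ # per-segment color: use the provided val if given, else distance from the start point (when maxdist is set), else the time index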
1202
+
1203
+ for s in range(S1):
1204
+ if val is not None:
1205
+ # if len(val) == S1:
1206
+ color = np.array(color_map(val[s])[:3]) * 255 # rgb
1207
+ # else:
1208
+ # color = np.array(color_map(val)[:3]) * 255 # rgb
1209
+ else:
1210
+ if maxdist is not None:
1211
+ val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
1212
+ color = np.array(color_map(val)[:3]) * 255 # rgb
1213
+ else:
1214
+ color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
1215
+
1216
+ if show_lines and s<(S1-1):
1217
+ cv2.line(rgb,
1218
+ (int(traj[s,0]), int(traj[s,1])),
1219
+ (int(traj[s+1,0]), int(traj[s+1,1])),
1220
+ color,
1221
+ linewidth,
1222
+ cv2.LINE_AA)
1223
+ if show_dots:
1224
+ cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, np.array(color_map(1)[:3])*255, -1)
1225
+
1226
+ # if maxdist is not None:
1227
+ # val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)
1228
+ # color = np.array(color_map(val)[:3]) * 255 # rgb
1229
+ # else:
1230
+ # # draw the endpoint of traj, using the next color (which may be the last color)
1231
+ # color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb
1232
+
1233
+ # # emphasize endpoint
1234
+ # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)
1235
+
1236
+ return rgb
1237
+
1238
+
1239
+
1240
+ def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
1241
+ # all inputs are numpy tensors
1242
+ # rgbs is a list of H,W,3
1243
+ # traj is S,2
1244
+ H, W, C = rgbs[0].shape
1245
+ assert(C==3)
1246
+
1247
+ rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
1248
+
1249
+ S1, D = traj.shape
1250
+ assert(D==2)
1251
+
1252
+ x = int(np.clip(traj[0,0], 0, W-1))
1253
+ y = int(np.clip(traj[0,1], 0, H-1))
1254
+ color = rgbs[0][y,x]
1255
+ color = (int(color[0]),int(color[1]),int(color[2]))
1256
+ for s in range(S):
1257
+ # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb
1258
+ # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1)
1259
+ cv2.polylines(rgbs[s],
1260
+ [traj[:s+1]],
1261
+ False,
1262
+ color,
1263
+ linewidth,
1264
+ cv2.LINE_AA)
1265
+ return rgbs
1266
+
1267
+ def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
1268
+ # all inputs are numpy tensors
1269
+ # rgbs is a list of 3,H,W
1270
+ # xy is N,2
1271
+ H, W, C = rgb.shape
1272
+ assert(C==3)
1273
+
1274
+ rgb = rgb.astype(np.uint8).copy()
1275
+
1276
+ N, D = xy.shape
1277
+ assert(D==2)
1278
+
1279
+
1280
+ xy = xy.astype(np.float32)
1281
+ xy[:,0] = np.clip(xy[:,0], 0, W-1)
1282
+ xy[:,1] = np.clip(xy[:,1], 0, H-1)
1283
+ xy = xy.astype(np.int32)
1284
+
1285
+
1286
+
1287
+ if colors is None:
1288
+ colors = get_n_colors(N)
1289
+
1290
+ for n in range(N):
1291
+ color = colors[n]
1292
+ # print('color', color)
1293
+ # color = (color[0]*255).astype(np.uint8)
1294
+ color = (int(color[0]),int(color[1]),int(color[2]))
1295
+
1296
+ # x = int(np.clip(xy[0,0], 0, W-1))
1297
+ # y = int(np.clip(xy[0,1], 0, H-1))
1298
+ # color_ = rgbs[0][y,x]
1299
+ # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
1300
+ # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
1301
+
1302
+ cv2.circle(rgb, (xy[n,0], xy[n,1]), linewidth, color, 3)
1303
+ # vis_color = int(np.squeeze(vis[s])*255)
1304
+ # vis_color = (vis_color,vis_color,vis_color)
1305
+ # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)
1306
+ return rgb
1307
+
1308
+ def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
1309
+ # all inputs are numpy tensors
1310
+ # rgbs is a list of 3,H,W
1311
+ # traj is S,2
1312
+ H, W, C = rgbs[0].shape
1313
+ assert(C==3)
1314
+
1315
+ rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
1316
+
1317
+ S1, D = traj.shape
1318
+ assert(D==2)
1319
+
1320
+ if cmap is None:
1321
+ bremm = ColorMap2d()
1322
+ traj_ = traj[0:1].astype(np.float32)
1323
+ traj_[:,0] /= float(W)
1324
+ traj_[:,1] /= float(H)
1325
+ color = bremm(traj_)
1326
+ # print('color', color)
1327
+ color = (color[0]*255).astype(np.uint8)
1328
+ # color = (int(color[0]),int(color[1]),int(color[2]))
1329
+ color = (int(color[2]),int(color[1]),int(color[0]))
1330
+
1331
+ for s in range(S1):
1332
+ if cmap is not None:
1333
+ color_map = cm.get_cmap(cmap)
1334
+ # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb
1335
+ color = np.array(color_map((s+1)/max(1,float(S-1)))[:3]) * 255 # rgb
1336
+ # color = color.astype(np.uint8)
1337
+ # color = (color[0], color[1], color[2])
1338
+ # print('color', color)
1339
+ # import ipdb; ipdb.set_trace()
1340
+
1341
+ cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, color, -1)
1342
+ # vis_color = int(np.squeeze(vis[s])*255)
1343
+ # vis_color = (vis_color,vis_color,vis_color)
1344
+ # cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)
1345
+
1346
+ return rgbs
1347
+
1348
+ def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, only_return=False, show_circ=False, trajs_g=None, is_g=False):
1349
+ B, S, N, D = trajs_e.shape
1350
+ assert(N==1)
1351
+ assert(D==2)
1352
+
1353
+ rgbs_vis = []
1354
+ n = 0
1355
+ pad_amount = 100
1356
+ trajs_e_py = trajs_e[0].detach().cpu().numpy()
1357
+ # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun
1358
+ trajs_e_py = trajs_e_py + pad_amount
1359
+
1360
+ if trajs_g is not None:
1361
+ trajs_g_py = trajs_g[0].detach().cpu().numpy()
1362
+ trajs_g_py = trajs_g_py + pad_amount
1363
+
1364
+ for s in range(S):
1365
+ rgb = rgbs[0,s].detach().cpu().numpy()
1366
+ # print('orig rgb', rgb.shape)
1367
+ rgb = np.transpose(rgb,(1,2,0)) # H, W, 3
1368
+
1369
+ rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0)))
1370
+ # print('pad rgb', rgb.shape)
1371
+ H, W, C = rgb.shape
1372
+
1373
+ if trajs_g is not None:
1374
+ xy_g = trajs_g_py[s,n]
1375
+ xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount)
1376
+ xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount)
1377
+ rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)
1378
+
1379
+ xy_e = trajs_e_py[s,n]
1380
+ xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount)
1381
+ xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount)
1382
+
1383
+ if show_circ:
1384
+ if is_g:
1385
+ rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3)
1386
+ else:
1387
+ rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3)
1388
+
1389
+
1390
+ xmin = int(xy_e[0])-pad_amount//2
1391
+ xmax = int(xy_e[0])+pad_amount//2
1392
+ ymin = int(xy_e[1])-pad_amount//2
1393
+ ymax = int(xy_e[1])+pad_amount//2
1394
+
1395
+ rgb_ = rgb[ymin:ymax, xmin:xmax]
1396
+
1397
+ H_, W_ = rgb_.shape[:2]
1398
+ # if np.any(rgb_.shape==0):
1399
+ # input()
1400
+ if H_==0 or W_==0:
1401
+ import ipdb; ipdb.set_trace()
1402
+
1403
+ rgb_ = rgb_.transpose(2,0,1)
1404
+ rgb_ = torch.from_numpy(rgb_)
1405
+
1406
+ rgbs_vis.append(rgb_)
1407
+
1408
+ # nrow = int(np.sqrt(S)*(16.0/9)/2.0)
1409
+ nrow = int(np.sqrt(S)*1.5)
1410
+ grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0)
1411
+ # print('grid_img', grid_img.shape)
1412
+ return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, only_return=only_return)
1413
+
1414
+ def summ_occ(self, name, occ, reduce_axes=[3], bev=False, fro=False, pro=False, frame_id=None, only_return=False):
1415
+ if self.save_this:
1416
+ B, C, D, H, W = list(occ.shape)
1417
+ if bev:
1418
+ reduce_axes = [3]
1419
+ elif fro:
1420
+ reduce_axes = [2]
1421
+ elif pro:
1422
+ reduce_axes = [4]
1423
+ for reduce_axis in reduce_axes:
1424
+ height = convert_occ_to_height(occ, reduce_axis=reduce_axis)
1425
+ if reduce_axis == reduce_axes[-1]:
1426
+ return self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return)
1427
+ else:
1428
+ self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return)
1429
+
1430
+ def erode2d(im, times=1, device='cuda'):
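+ # morphological erosion via convolution: a pixel stays 1 only if its full 3x3 neighborhood is 1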
1431
+ weights2d = torch.ones(1, 1, 3, 3, device=device)
1432
+ for time in range(times):
1433
+ im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1)
1434
+ return im
1435
+
1436
+ def dilate2d(im, times=1, device='cuda', mode='square'):
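+ # morphological dilation: a pixel becomes 1 if any pixel in its 3x3 (square or cross) neighborhood is 1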
1437
+ weights2d = torch.ones(1, 1, 3, 3, device=device)
1438
+ if mode=='cross':
1439
+ weights2d[:,:,0,0] = 0.0
1440
+ weights2d[:,:,0,2] = 0.0
1441
+ weights2d[:,:,2,0] = 0.0
1442
+ weights2d[:,:,2,2] = 0.0
1443
+ for time in range(times):
1444
+ im = F.conv2d(im, weights2d, padding=1).clamp(0, 1)
1445
+ return im
1446
+
1447
+
models/spatracker/utils/misc.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import numpy as np
3
+ import math
4
+ from prettytable import PrettyTable
5
+
6
+ def count_parameters(model):
7
+ table = PrettyTable(["Modules", "Parameters"])
8
+ total_params = 0
9
+ for name, parameter in model.named_parameters():
10
+ if not parameter.requires_grad:
11
+ continue
12
+ param = parameter.numel()
13
+ if param > 100000:
14
+ table.add_row([name, param])
15
+ total_params+=param
16
+ print(table)
17
+ print('total params: %.2f M' % (total_params/1000000.0))
18
+ return total_params
19
+
20
+ def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
21
+ device = xy.device
22
+ dtype = xy.dtype
23
+ B, S, D = xy.shape
24
+ assert(D==2)
25
+ x = xy[:,:,0]
26
+ y = xy[:,:,1]
27
+ assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
28
+ omega = torch.arange(C // 4, device=device) / (C // 4 - 1)
29
+ omega = 1. / (temperature ** omega)
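+ # x and y each get C/4 sine and C/4 cosine channels, giving a C-dim embedding per point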
30
+
31
+ y = y.flatten()[:, None] * omega[None, :]
32
+ x = x.flatten()[:, None] * omega[None, :]
33
+ pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
34
+ pe = pe.reshape(B,S,C).type(dtype)
35
+ if cat_coords:
36
+ pe = torch.cat([pe, xy], dim=2) # B,N,C+2
37
+ return pe
38
+
39
+ class SimplePool():
40
+ def __init__(self, pool_size, version='pt'):
41
+ self.pool_size = pool_size
42
+ self.version = version
43
+ self.items = []
44
+
45
+ if not (version=='pt' or version=='np'):
46
+ print('version = %s; please choose pt or np' % version)
47
+ assert(False) # please choose pt or np
48
+
49
+ def __len__(self):
50
+ return len(self.items)
51
+
52
+ def mean(self, min_size=1):
53
+ if min_size=='half':
54
+ pool_size_thresh = self.pool_size/2
55
+ else:
56
+ pool_size_thresh = min_size
57
+
58
+ if self.version=='np':
59
+ if len(self.items) >= pool_size_thresh:
60
+ return np.sum(self.items)/float(len(self.items))
61
+ else:
62
+ return np.nan
63
+ if self.version=='pt':
64
+ if len(self.items) >= pool_size_thresh:
65
+ return torch.stack(self.items).sum()/float(len(self.items))
66
+ else:
67
+ return torch.tensor(float('nan'))
68
+
69
+ def sample(self, with_replacement=True):
70
+ idx = np.random.randint(len(self.items))
71
+ if with_replacement:
72
+ return self.items[idx]
73
+ else:
74
+ return self.items.pop(idx)
75
+
76
+ def fetch(self, num=None):
77
+ if self.version=='pt':
78
+ item_array = torch.stack(self.items)
79
+ elif self.version=='np':
80
+ item_array = np.stack(self.items)
81
+ if num is not None:
82
+ # there better be some items
83
+ assert(len(self.items) >= num)
84
+
85
+ # if there are not that many elements just return however many there are
86
+ if len(self.items) < num:
87
+ return item_array
88
+ else:
89
+ idxs = np.random.randint(len(self.items), size=num)
90
+ return item_array[idxs]
91
+ else:
92
+ return item_array
93
+
94
+ def is_full(self):
95
+ full = len(self.items)==self.pool_size
96
+ return full
97
+
98
+ def empty(self):
99
+ self.items = []
100
+
101
+ def update(self, items):
102
+ for item in items:
103
+ if len(self.items) < self.pool_size:
104
+ # the pool is not full, so let's add this in
105
+ self.items.append(item)
106
+ else:
107
+ # the pool is full
108
+ # pop from the front
109
+ self.items.pop(0)
110
+ # add to the back
111
+ self.items.append(item)
112
+ return self.items
113
+
114
+ def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
115
+ """
116
+ Input:
117
+ xyz: pointcloud data, [B, N, C], where C is probably 3
118
+ npoint: number of samples
119
+ Return:
120
+ inds: sampled pointcloud index, [B, npoint]
121
+ """
122
+ device = xyz.device
123
+ B, N, C = xyz.shape
124
+ xyz = xyz.float()
125
+ inds = torch.zeros(B, npoint, dtype=torch.long).to(device)
126
+ distance = torch.ones(B, N).to(device) * 1e10
127
+ if deterministic:
128
+ farthest = torch.randint(0, 1, (B,), dtype=torch.long).to(device)
129
+ else:
130
+ farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
131
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
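+ # greedy FPS: 'distance' holds each point's squared distance to its nearest selected centroid; pick the argmax each iteration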
132
+ for i in range(npoint):
133
+ if include_ends:
134
+ if i==0:
135
+ farthest = 0
136
+ elif i==1:
137
+ farthest = N-1
138
+ inds[:, i] = farthest
139
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
140
+ dist = torch.sum((xyz - centroid) ** 2, -1)
141
+ mask = dist < distance
142
+ distance[mask] = dist[mask]
143
+ farthest = torch.max(distance, -1)[1]
144
+
145
+ if npoint > N:
146
+ # if we need more samples, make them random
147
+ distance += torch.randn_like(distance)
148
+ return inds
149
+
150
+ def farthest_point_sample_py(xyz, npoint):
151
+ N,C = xyz.shape
152
+ inds = np.zeros(npoint, dtype=np.int32)
153
+ distance = np.ones(N) * 1e10
154
+ farthest = np.random.randint(0, N, dtype=np.int32)
155
+ for i in range(npoint):
156
+ inds[i] = farthest
157
+ centroid = xyz[farthest, :].reshape(1,C)
158
+ dist = np.sum((xyz - centroid) ** 2, -1)
159
+ mask = dist < distance
160
+ distance[mask] = dist[mask]
161
+ farthest = np.argmax(distance, -1)
162
+ if npoint > N:
163
+ # if we need more samples, make them random
164
+ distance += np.random.randn(*distance.shape)
165
+ return inds
166
+
models/spatracker/utils/samp.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import utils.basic
3
+ import torch.nn.functional as F
4
+
5
+ def bilinear_sample2d(im, x, y, return_inbounds=False):
6
+ # x and y are each B, N
7
+ # output is B, C, N
8
+ B, C, H, W = list(im.shape)
9
+ N = list(x.shape)[1]
10
+
11
+ x = x.float()
12
+ y = y.float()
13
+ H_f = torch.tensor(H, dtype=torch.float32)
14
+ W_f = torch.tensor(W, dtype=torch.float32)
15
+
16
+ # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
17
+
18
+ max_y = (H_f - 1).int()
19
+ max_x = (W_f - 1).int()
20
+
21
+ x0 = torch.floor(x).int()
22
+ x1 = x0 + 1
23
+ y0 = torch.floor(y).int()
24
+ y1 = y0 + 1
25
+
26
+ x0_clip = torch.clamp(x0, 0, max_x)
27
+ x1_clip = torch.clamp(x1, 0, max_x)
28
+ y0_clip = torch.clamp(y0, 0, max_y)
29
+ y1_clip = torch.clamp(y1, 0, max_y)
30
+ dim2 = W
31
+ dim1 = W * H
32
+
33
+ base = torch.arange(0, B, dtype=torch.int64, device=x.device)*dim1
34
+ base = torch.reshape(base, [B, 1]).repeat([1, N])
35
+
36
+ base_y0 = base + y0_clip * dim2
37
+ base_y1 = base + y1_clip * dim2
38
+
39
+ idx_y0_x0 = base_y0 + x0_clip
40
+ idx_y0_x1 = base_y0 + x1_clip
41
+ idx_y1_x0 = base_y1 + x0_clip
42
+ idx_y1_x1 = base_y1 + x1_clip
43
+
44
+ # use the indices to lookup pixels in the flat image
45
+ # im is B x C x H x W
46
+ # move C out to last dim
47
+ im_flat = (im.permute(0, 2, 3, 1)).reshape(B*H*W, C)
48
+ i_y0_x0 = im_flat[idx_y0_x0.long()]
49
+ i_y0_x1 = im_flat[idx_y0_x1.long()]
50
+ i_y1_x0 = im_flat[idx_y1_x0.long()]
51
+ i_y1_x1 = im_flat[idx_y1_x1.long()]
52
+
53
+ # Finally calculate interpolated values.
54
+ x0_f = x0.float()
55
+ x1_f = x1.float()
56
+ y0_f = y0.float()
57
+ y1_f = y1.float()
58
+
59
+ w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
60
+ w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
61
+ w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
62
+ w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
63
+
64
+ output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + \
65
+ w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
66
+ # output is B*N x C
67
+ output = output.view(B, -1, C)
68
+ output = output.permute(0, 2, 1)
69
+ # output is B x C x N
70
+
71
+ if return_inbounds:
72
+ x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
73
+ y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
74
+ inbounds = (x_valid & y_valid).float()
75
+ inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
76
+ return output, inbounds
77
+
78
+ return output # B, C, N
79
+
80
+ def paste_crop_on_canvas(crop, box2d_unnorm, H, W, fast=True, mask=None, canvas=None):
81
+ # this is the inverse of crop_and_resize_box2d
82
+ B, C, Y, X = list(crop.shape)
83
+ B2, D = list(box2d_unnorm.shape)
84
+ assert(B == B2)
85
+ assert(D == 4)
86
+
87
+ # here, we want to place the crop into a bigger image,
88
+ # at the location specified by the box2d.
89
+
90
+ if canvas is None:
91
+ canvas = torch.zeros((B, C, H, W), device=crop.device)
92
+ else:
93
+ B2, C2, H2, W2 = canvas.shape
94
+ assert(B==B2)
95
+ assert(C==C2)
96
+ assert(H==H2)
97
+ assert(W==W2)
98
+
99
+ # box2d_unnorm = utils.geom.unnormalize_box2d(box2d, H, W)
100
+
101
+ if fast:
102
+ ymin = box2d_unnorm[:, 0].long()
103
+ xmin = box2d_unnorm[:, 1].long()
104
+ ymax = box2d_unnorm[:, 2].long()
105
+ xmax = box2d_unnorm[:, 3].long()
106
+ w = (xmax - xmin).float()
107
+ h = (ymax - ymin).float()
108
+
109
+ grids = utils.basic.gridcloud2d(B, H, W)
110
+ grids_flat = grids.reshape(B, -1, 2)
111
+ # grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * X
112
+ # grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * Y
113
+
114
+ # for each pixel in the main image,
115
+ # grids_flat tells us where to sample in the crop image
116
+
117
+ # print('grids_flat', grids_flat.shape)
118
+ # print('crop', crop.shape)
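+ # map every canvas pixel into the crop's normalized [-1, 1] coordinate frame expected by F.grid_sample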
119
+
120
+ grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * 2.0 - 1.0
121
+ grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * 2.0 - 1.0
122
+
123
+ grid = grids_flat.reshape(B,H,W,2)
124
+
125
+ canvas = F.grid_sample(crop, grid, align_corners=False)
126
+ # print('canvas', canvas.shape)
127
+
128
+ # if mask is None:
129
+ # crop_resamp, inb = bilinear_sample2d(crop, grids_flat[:, :, 0], grids_flat[:, :, 1], return_inbounds=True)
130
+ # crop_resamp = crop_resamp.reshape(B, C, H, W)
131
+ # inb = inb.reshape(B, 1, H, W)
132
+ # canvas = canvas * (1 - inb) + crop_resamp * inb
133
+ # else:
134
+ # full_resamp = bilinear_sample2d(torch.cat([crop, mask], dim=1), grids_flat[:, :, 0], grids_flat[:, :, 1])
135
+ # full_resamp = full_resamp.reshape(B, C+1, H, W)
136
+ # crop_resamp = full_resamp[:,:3]
137
+ # mask_resamp = full_resamp[:,3:4]
138
+ # canvas = canvas * (1 - mask_resamp) + crop_resamp * mask_resamp
139
+ else:
140
+ for b in range(B):
141
+ ymin = box2d_unnorm[b, 0].long()
142
+ xmin = box2d_unnorm[b, 1].long()
143
+ ymax = box2d_unnorm[b, 2].long()
144
+ xmax = box2d_unnorm[b, 3].long()
145
+
146
+ crop_b = F.interpolate(crop[b:b + 1], (ymax - ymin, xmax - xmin)).squeeze(0)
147
+
148
+ # print('canvas[b,:,...', canvas[b,:,ymin:ymax,xmin:xmax].shape)
149
+ # print('crop_b', crop_b.shape)
150
+
151
+ canvas[b, :, ymin:ymax, xmin:xmax] = crop_b
152
+ return canvas
models/spatracker/utils/visualizer.py ADDED
@@ -0,0 +1,409 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import numpy as np
9
+ import cv2
10
+ import torch
11
+ import flow_vis
12
+
13
+ from matplotlib import cm
14
+ import torch.nn.functional as F
15
+ import torchvision.transforms as transforms
16
+ from moviepy.editor import ImageSequenceClip
17
+ import matplotlib.pyplot as plt
18
+ from tqdm import tqdm
19
+
20
+ def read_video_from_path(path):
21
+ cap = cv2.VideoCapture(path)
22
+ if not cap.isOpened():
23
+ print("Error opening video file")
24
+ else:
25
+ frames = []
26
+ while cap.isOpened():
27
+ ret, frame = cap.read()
28
+ if ret == True:
29
+ frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
30
+ else:
31
+ break
32
+ cap.release()
33
+ return np.stack(frames)
34
+
35
+
36
+ class Visualizer:
37
+ def __init__(
38
+ self,
39
+ save_dir: str = "./results",
40
+ grayscale: bool = False,
41
+ pad_value: int = 0,
42
+ fps: int = 10,
43
+ mode: str = "rainbow", # 'cool', 'optical_flow'
44
+ linewidth: int = 1,
45
+ show_first_frame: int = 10,
46
+ tracks_leave_trace: int = 0, # -1 for infinite
47
+ ):
48
+ self.mode = mode
49
+ self.save_dir = save_dir
50
+ self.vtxt_path = os.path.join(save_dir, "videos.txt")
51
+ self.ttxt_path = os.path.join(save_dir, "trackings.txt")
52
+ if mode == "rainbow":
53
+ self.color_map = cm.get_cmap("gist_rainbow")
54
+ elif mode == "cool":
55
+ self.color_map = cm.get_cmap(mode)
56
+ self.show_first_frame = show_first_frame
57
+ self.grayscale = grayscale
58
+ self.tracks_leave_trace = tracks_leave_trace
59
+ self.pad_value = pad_value
60
+ self.linewidth = linewidth
61
+ self.fps = fps
62
+
63
+ def visualize(
64
+ self,
65
+ video: torch.Tensor, # (B,T,C,H,W)
66
+ tracks: torch.Tensor, # (B,T,N,3), xy pixel coords plus depth
67
+ visibility: torch.Tensor = None, # (B, T, N, 1) bool
68
+ gt_tracks: torch.Tensor = None, # (B,T,N,2)
69
+ segm_mask: torch.Tensor = None, # (B,1,H,W)
70
+ filename: str = "video",
71
+ writer=None, # tensorboard Summary Writer, used for visualization during training
72
+ step: int = 0,
73
+ query_frame: int = 0,
74
+ save_video: bool = True,
75
+ compensate_for_camera_motion: bool = False,
76
+ rigid_part = None,
77
+ video_depth = None # (B,T,C,H,W)
78
+ ):
79
+ if compensate_for_camera_motion:
80
+ assert segm_mask is not None
81
+ if segm_mask is not None:
82
+ coords = tracks[0, query_frame].round().long()
83
+ segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
84
+
85
+ video = F.pad(
86
+ video,
87
+ (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
88
+ "constant",
89
+ 255,
90
+ )
91
+
92
+ if video_depth is not None:
93
+ video_depth = (video_depth*255).cpu().numpy().astype(np.uint8)
94
+ video_depth = ([cv2.applyColorMap(video_depth[0,i,0], cv2.COLORMAP_INFERNO)
95
+ for i in range(video_depth.shape[1])])
96
+ video_depth = np.stack(video_depth, axis=0)
97
+ video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None]
98
+
99
+ tracks = tracks + self.pad_value
100
+
101
+ if self.grayscale:
102
+ transform = transforms.Grayscale()
103
+ video = transform(video)
104
+ video = video.repeat(1, 1, 3, 1, 1)
105
+
106
+ tracking_video = self.draw_tracks_on_video(
107
+ video=video,
108
+ tracks=tracks,
109
+ visibility=visibility,
110
+ segm_mask=segm_mask,
111
+ gt_tracks=gt_tracks,
112
+ query_frame=query_frame,
113
+ compensate_for_camera_motion=compensate_for_camera_motion,
114
+ rigid_part=rigid_part
115
+ )
116
+
117
+ if save_video:
118
+ # import ipdb; ipdb.set_trace()
119
+ tracking_dir = os.path.join(self.save_dir, "tracking")
120
+ if not os.path.exists(tracking_dir):
121
+ os.makedirs(tracking_dir)
122
+ self.save_video(tracking_video, filename=filename+"_tracking",
123
+ savedir=tracking_dir, writer=writer, step=step)
124
+ # with open(self.ttxt_path, 'a') as file:
125
+ # file.write(f"tracking/{filename}_tracking.mp4\n")
126
+
127
+ videos_dir = os.path.join(self.save_dir, "videos")
128
+ if not os.path.exists(videos_dir):
129
+ os.makedirs(videos_dir)
130
+ self.save_video(video, filename=filename,
131
+ savedir=videos_dir, writer=writer, step=step)
132
+ # with open(self.vtxt_path, 'a') as file:
133
+ # file.write(f"videos/{filename}.mp4\n")
134
+ if video_depth is not None:
135
+ self.save_video(video_depth, filename=filename+"_depth",
136
+ savedir=os.path.join(self.save_dir, "depth"), writer=writer, step=step)
137
+ return tracking_video
138
+
139
+ def save_video(self, video, filename, savedir=None, writer=None, step=0):
140
+ if writer is not None:
141
+ writer.add_video(
142
+ f"{filename}",
143
+ video.to(torch.uint8),
144
+ global_step=step,
145
+ fps=self.fps,
146
+ )
147
+ else:
148
+ os.makedirs(self.save_dir, exist_ok=True)
149
+ wide_list = list(video.unbind(1))
150
+ wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
151
+ # clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
152
+ clip = ImageSequenceClip(wide_list, fps=self.fps)
153
+
154
+ # Write the video file
155
+ if savedir is None:
156
+ save_path = os.path.join(self.save_dir, f"{filename}.mp4")
157
+ else:
158
+ save_path = os.path.join(savedir, f"{filename}.mp4")
159
+ clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
160
+
161
+ print(f"Video saved to {save_path}")
162
+
163
+ def draw_tracks_on_video(
164
+ self,
165
+ video: torch.Tensor,
166
+ tracks: torch.Tensor,
167
+ visibility: torch.Tensor = None,
168
+ segm_mask: torch.Tensor = None,
169
+ gt_tracks=None,
170
+ query_frame: int = 0,
171
+ compensate_for_camera_motion=False,
172
+ rigid_part=None,
173
+ ):
174
+ B, T, C, H, W = video.shape
175
+ _, _, N, D = tracks.shape
176
+
177
+ assert D == 3
178
+ assert C == 3
179
+ video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
180
+ tracks = tracks[0].detach().cpu().numpy() # S, N, 3
181
+ if gt_tracks is not None:
182
+ gt_tracks = gt_tracks[0].detach().cpu().numpy()
183
+
184
+ res_video = []
185
+
186
+ # process input video
187
+ # for rgb in video:
188
+ # res_video.append(rgb.copy())
189
+
190
+ # create a blank tensor with the same shape as the video
191
+ for rgb in video:
192
+ black_frame = np.zeros_like(rgb.copy(), dtype=rgb.dtype)
193
+ res_video.append(black_frame)
194
+
195
+ vector_colors = np.zeros((T, N, 3))
196
+
197
+ if self.mode == "optical_flow":
198
+
199
+ vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
200
+
201
+ elif segm_mask is None:
202
+ if self.mode == "rainbow":
203
+ x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
204
+ y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
205
+
206
+ z_inv = 1/tracks[0, :, 2]
207
+ z_min, z_max = np.percentile(z_inv, [2, 98])
208
+
209
+ norm_x = plt.Normalize(x_min, x_max)
210
+ norm_y = plt.Normalize(y_min, y_max)
211
+ norm_z = plt.Normalize(z_min, z_max)
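+ # color each track by its first-frame position: normalized x -> R, y -> G, inverse depth -> B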
212
+
213
+ for n in range(N):
214
+ r = norm_x(tracks[0, n, 0])
215
+ g = norm_y(tracks[0, n, 1])
216
+ # r = 0
217
+ # g = 0
218
+ b = norm_z(1/tracks[0, n, 2])
219
+ color = np.array([r, g, b])[None] * 255
220
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
221
+ else:
222
+ # color changes with time
223
+ for t in range(T):
224
+ color = np.array(self.color_map(t / T)[:3])[None] * 255
225
+ vector_colors[t] = np.repeat(color, N, axis=0)
226
+ else:
227
+ if self.mode == "rainbow":
228
+ vector_colors[:, segm_mask <= 0, :] = 255
229
+
230
+ x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
231
+ y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
232
+ z_min, z_max = tracks[0, :, 2].min(), tracks[0, :, 2].max()
233
+
234
+ norm_x = plt.Normalize(x_min, x_max)
235
+ norm_y = plt.Normalize(y_min, y_max)
236
+ norm_z = plt.Normalize(z_min, z_max)
237
+
238
+ for n in range(N):
239
+ r = norm_x(tracks[0, n, 0])
240
+ g = norm_y(tracks[0, n, 1])
241
+ b = norm_z(tracks[0, n, 2])
242
+ color = np.array([r, g, b])[None] * 255
243
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
244
+
245
+ else:
246
+ # color changes with segm class
247
+ segm_mask = segm_mask.cpu()
248
+ color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
249
+ color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
250
+ color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
251
+ vector_colors = np.repeat(color[None], T, axis=0)
252
+
253
+ # Draw tracks
254
+ if self.tracks_leave_trace != 0:
255
+ for t in range(1, T):
256
+ first_ind = (
257
+ max(0, t - self.tracks_leave_trace)
258
+ if self.tracks_leave_trace >= 0
259
+ else 0
260
+ )
261
+ curr_tracks = tracks[first_ind : t + 1]
262
+ curr_colors = vector_colors[first_ind : t + 1]
263
+ if compensate_for_camera_motion:
264
+ diff = (
265
+ tracks[first_ind : t + 1, segm_mask <= 0]
266
+ - tracks[t : t + 1, segm_mask <= 0]
267
+ ).mean(1)[:, None]
268
+
269
+ curr_tracks = curr_tracks - diff
270
+ curr_tracks = curr_tracks[:, segm_mask > 0]
271
+ curr_colors = curr_colors[:, segm_mask > 0]
272
+
273
+ res_video[t] = self._draw_pred_tracks(
274
+ res_video[t],
275
+ curr_tracks,
276
+ curr_colors,
277
+ )
278
+ if gt_tracks is not None:
279
+ res_video[t] = self._draw_gt_tracks(
280
+ res_video[t], gt_tracks[first_ind : t + 1]
281
+ )
282
+
283
+ if rigid_part is not None:
284
+ cls_label = torch.unique(rigid_part)
285
+ cls_num = len(torch.unique(rigid_part))
286
+ # visualize the clustering results
287
+ cmap = plt.get_cmap('jet') # get the color mapping
288
+ colors = cmap(np.linspace(0, 1, cls_num))
289
+ colors = (colors[:, :3] * 255)
290
+ color_map = {lable.item(): color for lable, color in zip(cls_label, colors)}
291
+
292
+ # Draw points
293
+ for t in tqdm(range(T)):
294
+ # Create a list to store information for each point
295
+ points_info = []
296
+ for i in range(N):
297
+ coord = (int(tracks[t, i, 0]), int(tracks[t, i, 1]))
298
+ depth = tracks[t, i, 2] # assume the third dimension is depth
299
+ visibile = True
300
+ if visibility is not None:
301
+ visibile = visibility[0, t, i]
302
+ if coord[0] != 0 and coord[1] != 0:
303
+ if not compensate_for_camera_motion or (
304
+ compensate_for_camera_motion and segm_mask[i] > 0
305
+ ):
306
+ points_info.append((i, coord, depth, visibile))
307
+
308
+ # Sort points by depth, points with smaller depth (closer) will be drawn later
309
+ points_info.sort(key=lambda x: x[2], reverse=True)
310
+
311
+ for i, coord, _, visibile in points_info:
312
+ if rigid_part is not None:
313
+ color = color_map[rigid_part.squeeze()[i].item()]
314
+ cv2.circle(
315
+ res_video[t],
316
+ coord,
317
+ int(self.linewidth * 2),
318
+ color.tolist(),
319
+ thickness=-1 if visibile else 2 - 1,  # filled (-1) when visible, thin outline otherwise
321
+ )
322
+ else:
323
+ # Determine rectangle width based on the distance between adjacent tracks in the first frame
324
+ if t == 0:
325
+ distances = np.linalg.norm(tracks[0] - tracks[0, i], axis=1)
326
+ distances = distances[distances > 0]
327
+ rect_size = int(np.min(distances))/2
328
+
329
+ # Define coordinates for top-left and bottom-right corners of the rectangle
330
+ top_left = (int(coord[0] - rect_size), int(coord[1] - rect_size/1.5)) # rectangle is 1.5x wider than tall, matching the 1.5:1 video aspect ratio
331
+ bottom_right = (int(coord[0] + rect_size), int(coord[1] + rect_size/1.5))
332
+
333
+ # Draw rectangle
334
+ cv2.rectangle(
335
+ res_video[t],
336
+ top_left,
337
+ bottom_right,
338
+ vector_colors[t, i].tolist(),
339
+ thickness=-1 if visibile else 0 - 1,  # evaluates to -1 either way: the rectangle is always filled
341
+ )
342
+
343
+ # Construct the final rgb sequence
344
+ return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
345
+
346
+ def _draw_pred_tracks(
347
+ self,
348
+ rgb: np.ndarray, # H x W x 3
349
+ tracks: np.ndarray, # T x N x 2
350
+ vector_colors: np.ndarray,
351
+ alpha: float = 0.5,
352
+ ):
353
+ T, N, _ = tracks.shape
354
+
355
+ for s in range(T - 1):
356
+ vector_color = vector_colors[s]
357
+ original = rgb.copy()
358
+ alpha = (s / T) ** 2
359
+ for i in range(N):
360
+ coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
361
+ coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
362
+ if coord_y[0] != 0 and coord_y[1] != 0:
363
+ cv2.line(
364
+ rgb,
365
+ coord_y,
366
+ coord_x,
367
+ vector_color[i].tolist(),
368
+ self.linewidth,
369
+ cv2.LINE_AA,
370
+ )
371
+ if self.tracks_leave_trace > 0:
372
+ rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
373
+ return rgb
374
+
375
+ def _draw_gt_tracks(
376
+ self,
377
+ rgb: np.ndarray, # H x W x 3,
378
+ gt_tracks: np.ndarray, # T x N x 2
379
+ ):
380
+ T, N, _ = gt_tracks.shape
381
+ color = np.array((211.0, 0.0, 0.0))
382
+
383
+ for t in range(T):
384
+ for i in range(N):
385
+ gt_track = gt_tracks[t][i] # a single 2D point; keep gt_tracks intact for later iterations
386
+ # draw a red cross
387
+ if gt_track[0] > 0 and gt_track[1] > 0:
388
+ length = self.linewidth * 3
389
+ coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
390
+ coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
391
+ cv2.line(
392
+ rgb,
393
+ coord_y,
394
+ coord_x,
395
+ color,
396
+ self.linewidth,
397
+ cv2.LINE_AA,
398
+ )
399
+ coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
400
+ coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
401
+ cv2.line(
402
+ rgb,
403
+ coord_y,
404
+ coord_x,
405
+ color,
406
+ self.linewidth,
407
+ cv2.LINE_AA,
408
+ )
409
+ return rgb
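Note: the point-drawing loop above sorts points far-to-near so that closer points are rendered on top. A minimal, self-contained sketch of that trick (the frame size, points, and colors below are made up for illustration; this is not repo code):

```python
import numpy as np
import cv2

frame = np.zeros((240, 320, 3), dtype=np.uint8)       # stand-in video frame
pts = np.array([[100.0, 120.0, 2.0],                   # (x, y, depth)
                [104.0, 122.0, 0.5],
                [200.0,  60.0, 1.0]])
colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]

# far-to-near ordering: larger depth first, so nearer points overwrite farther ones
for idx in np.argsort(-pts[:, 2]):
    x, y, _ = pts[idx]
    cv2.circle(frame, (int(x), int(y)), 4, colors[idx], thickness=-1)
```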
models/spatracker/utils/vox.py ADDED
@@ -0,0 +1,500 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ import utils.geom
+ import utils.basic # gridcloud3d / meshgrid / normalize_grid helpers used below
+ import utils.samp # only referenced by the disabled handwritten sampling branch
6
+
7
+ class Vox_util(object):
8
+ def __init__(self, Z, Y, X, scene_centroid, bounds, pad=None, assert_cube=False):
9
+ self.XMIN, self.XMAX, self.YMIN, self.YMAX, self.ZMIN, self.ZMAX = bounds
10
+ B, D = list(scene_centroid.shape)
11
+ self.Z, self.Y, self.X = Z, Y, X
12
+
13
+ scene_centroid = scene_centroid.detach().cpu().numpy()
14
+ x_centroid, y_centroid, z_centroid = scene_centroid[0]
15
+ self.XMIN += x_centroid
16
+ self.XMAX += x_centroid
17
+ self.YMIN += y_centroid
18
+ self.YMAX += y_centroid
19
+ self.ZMIN += z_centroid
20
+ self.ZMAX += z_centroid
21
+
22
+ self.default_vox_size_X = (self.XMAX-self.XMIN)/float(X)
23
+ self.default_vox_size_Y = (self.YMAX-self.YMIN)/float(Y)
24
+ self.default_vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z)
25
+
26
+ if pad:
27
+ Z_pad, Y_pad, X_pad = pad
28
+ self.ZMIN -= self.default_vox_size_Z * Z_pad
29
+ self.ZMAX += self.default_vox_size_Z * Z_pad
30
+ self.YMIN -= self.default_vox_size_Y * Y_pad
31
+ self.YMAX += self.default_vox_size_Y * Y_pad
32
+ self.XMIN -= self.default_vox_size_X * X_pad
33
+ self.XMAX += self.default_vox_size_X * X_pad
34
+
35
+ if assert_cube:
36
+ # we assume cube voxels
37
+ if (not np.isclose(self.default_vox_size_X, self.default_vox_size_Y)) or (not np.isclose(self.default_vox_size_X, self.default_vox_size_Z)):
38
+ print('Z, Y, X', Z, Y, X)
39
+ print('bounds for this iter:',
40
+ 'X = %.2f to %.2f' % (self.XMIN, self.XMAX),
41
+ 'Y = %.2f to %.2f' % (self.YMIN, self.YMAX),
42
+ 'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX),
43
+ )
44
+ print('self.default_vox_size_X', self.default_vox_size_X)
45
+ print('self.default_vox_size_Y', self.default_vox_size_Y)
46
+ print('self.default_vox_size_Z', self.default_vox_size_Z)
47
+ assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Y))
48
+ assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Z))
49
+
50
+ def Ref2Mem(self, xyz, Z, Y, X, assert_cube=False):
51
+ # xyz is B x N x 3, in ref coordinates
52
+ # transforms ref coordinates into mem coordinates
53
+ B, N, C = list(xyz.shape)
54
+ device = xyz.device
55
+ assert(C==3)
56
+ mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device)
57
+ xyz = utils.geom.apply_4x4(mem_T_ref, xyz)
58
+ return xyz
59
+
60
+ def Mem2Ref(self, xyz_mem, Z, Y, X, assert_cube=False):
61
+ # xyz is B x N x 3, in mem coordinates
62
+ # transforms mem coordinates into ref coordinates
63
+ B, N, C = list(xyz_mem.shape)
64
+ ref_T_mem = self.get_ref_T_mem(B, Z, Y, X, assert_cube=assert_cube, device=xyz_mem.device)
65
+ xyz_ref = utils.geom.apply_4x4(ref_T_mem, xyz_mem)
66
+ return xyz_ref
67
+
68
+ def get_mem_T_ref(self, B, Z, Y, X, assert_cube=False, device='cuda'):
69
+ vox_size_X = (self.XMAX-self.XMIN)/float(X)
70
+ vox_size_Y = (self.YMAX-self.YMIN)/float(Y)
71
+ vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z)
72
+
73
+ if assert_cube:
74
+ if (not np.isclose(vox_size_X, vox_size_Y)) or (not np.isclose(vox_size_X, vox_size_Z)):
75
+ print('Z, Y, X', Z, Y, X)
76
+ print('bounds for this iter:',
77
+ 'X = %.2f to %.2f' % (self.XMIN, self.XMAX),
78
+ 'Y = %.2f to %.2f' % (self.YMIN, self.YMAX),
79
+ 'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX),
80
+ )
81
+ print('vox_size_X', vox_size_X)
82
+ print('vox_size_Y', vox_size_Y)
83
+ print('vox_size_Z', vox_size_Z)
84
+ assert(np.isclose(vox_size_X, vox_size_Y))
85
+ assert(np.isclose(vox_size_X, vox_size_Z))
86
+
87
+ # translation
88
+ # (this makes the left edge of the leftmost voxel correspond to XMIN)
89
+ center_T_ref = utils.geom.eye_4x4(B, device=device)
90
+ center_T_ref[:,0,3] = -self.XMIN-vox_size_X/2.0
91
+ center_T_ref[:,1,3] = -self.YMIN-vox_size_Y/2.0
92
+ center_T_ref[:,2,3] = -self.ZMIN-vox_size_Z/2.0
93
+
94
+ # scaling
95
+ # (this makes the right edge of the rightmost voxel correspond to XMAX)
96
+ mem_T_center = utils.geom.eye_4x4(B, device=device)
97
+ mem_T_center[:,0,0] = 1./vox_size_X
98
+ mem_T_center[:,1,1] = 1./vox_size_Y
99
+ mem_T_center[:,2,2] = 1./vox_size_Z
100
+ mem_T_ref = utils.geom.matmul2(mem_T_center, center_T_ref)
101
+
102
+ return mem_T_ref
103
+
104
+ def get_ref_T_mem(self, B, Z, Y, X, assert_cube=False, device='cuda'):
105
+ mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device)
106
+ # note safe_inverse is inapplicable here,
107
+ # since the transform is nonrigid
108
+ ref_T_mem = mem_T_ref.inverse()
109
+ return ref_T_mem
110
+
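The two transforms above compose a translation with a per-axis scaling, so a reference coordinate x maps to memory coordinate (x - XMIN)/vox_size - 0.5. A quick 1D sanity check with made-up bounds (illustrative only, not repo code):

```python
# ref -> mem along one axis: XMIN lands at -0.5, XMAX at X - 0.5
XMIN, XMAX, X = -2.0, 2.0, 8
vox_size = (XMAX - XMIN) / X                      # 0.5

def ref_to_mem_x(ref_x):
    return (ref_x - XMIN - vox_size / 2.0) / vox_size

print(ref_to_mem_x(XMIN))   # -0.5 (left edge of the leftmost voxel)
print(ref_to_mem_x(XMAX))   #  7.5 (right edge of the rightmost voxel)
print(ref_to_mem_x(0.0))    #  3.5 (scene centroid sits mid-grid)
```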
111
+ def get_inbounds(self, xyz, Z, Y, X, already_mem=False, padding=0.0, assert_cube=False):
112
+ # xyz is B x N x 3
113
+ # padding should be 0 unless you are trying to account for some later cropping
114
+ if not already_mem:
115
+ xyz = self.Ref2Mem(xyz, Z, Y, X, assert_cube=assert_cube)
116
+
117
+ x = xyz[:,:,0]
118
+ y = xyz[:,:,1]
119
+ z = xyz[:,:,2]
120
+
121
+ x_valid = ((x-padding)>-0.5).byte() & ((x+padding)<float(X-0.5)).byte()
122
+ y_valid = ((y-padding)>-0.5).byte() & ((y+padding)<float(Y-0.5)).byte()
123
+ z_valid = ((z-padding)>-0.5).byte() & ((z+padding)<float(Z-0.5)).byte()
124
+ nonzero = (~(z==0.0)).byte()
125
+
126
+ inbounds = x_valid & y_valid & z_valid & nonzero
127
+ return inbounds.bool()
128
+
129
+ def voxelize_xyz(self, xyz_ref, Z, Y, X, already_mem=False, assert_cube=False, clean_eps=0):
130
+ B, N, D = list(xyz_ref.shape)
131
+ assert(D==3)
132
+ if already_mem:
133
+ xyz_mem = xyz_ref
134
+ else:
135
+ xyz_mem = self.Ref2Mem(xyz_ref, Z, Y, X, assert_cube=assert_cube)
136
+ xyz_zero = self.Ref2Mem(xyz_ref[:,0:1]*0, Z, Y, X, assert_cube=assert_cube)
137
+ vox = self.get_occupancy(xyz_mem, Z, Y, X, clean_eps=clean_eps, xyz_zero=xyz_zero)
138
+ return vox
139
+
140
+ def voxelize_xyz_and_feats(self, xyz_ref, feats, Z, Y, X, already_mem=False, assert_cube=False, clean_eps=0):
141
+ B, N, D = list(xyz_ref.shape)
142
+ B2, N2, D2 = list(feats.shape)
143
+ assert(D==3)
144
+ assert(B==B2)
145
+ assert(N==N2)
146
+ if already_mem:
147
+ xyz_mem = xyz_ref
148
+ else:
149
+ xyz_mem = self.Ref2Mem(xyz_ref, Z, Y, X, assert_cube=assert_cube)
150
+ xyz_zero = self.Ref2Mem(xyz_ref[:,0:1]*0, Z, Y, X, assert_cube=assert_cube)
151
+ feats = self.get_feat_occupancy(xyz_mem, feats, Z, Y, X, clean_eps=clean_eps, xyz_zero=xyz_zero)
152
+ return feats
153
+
154
+ def get_occupancy(self, xyz, Z, Y, X, clean_eps=0, xyz_zero=None):
155
+ # xyz is B x N x 3 and in mem coords
156
+ # we want to fill a voxel tensor with 1's at these inds
157
+ B, N, C = list(xyz.shape)
158
+ assert(C==3)
159
+
160
+ # these papers say simple 1/0 occupancy is ok:
161
+ # http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_PIXOR_Real-Time_3d_CVPR_2018_paper.pdf
162
+ # http://openaccess.thecvf.com/content_cvpr_2018/papers/Luo_Fast_and_Furious_CVPR_2018_paper.pdf
163
+ # cont fusion says they do 8-neighbor interp
164
+ # voxelnet does occupancy but with a bit of randomness in terms of the reflectance value i think
165
+
166
+ inbounds = self.get_inbounds(xyz, Z, Y, X, already_mem=True)
167
+ x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
168
+ mask = torch.zeros_like(x)
169
+ mask[inbounds] = 1.0
170
+
171
+ if xyz_zero is not None:
172
+ # only take points that are beyond a thresh of zero
173
+ dist = torch.norm(xyz_zero-xyz, dim=2)
174
+ mask[dist < 0.1] = 0
175
+
176
+ if clean_eps > 0:
177
+ # only take points that are already near centers
178
+ xyz_round = torch.round(xyz) # B, N, 3
179
+ dist = torch.norm(xyz_round - xyz, dim=2)
180
+ mask[dist > clean_eps] = 0
181
+
182
+ # set the invalid guys to zero
183
+ # we then need to zero out 0,0,0
184
+ # (this method seems a bit clumsy)
185
+ x = x*mask
186
+ y = y*mask
187
+ z = z*mask
188
+
189
+ x = torch.round(x)
190
+ y = torch.round(y)
191
+ z = torch.round(z)
192
+ x = torch.clamp(x, 0, X-1).int()
193
+ y = torch.clamp(y, 0, Y-1).int()
194
+ z = torch.clamp(z, 0, Z-1).int()
195
+
196
+ x = x.view(B*N)
197
+ y = y.view(B*N)
198
+ z = z.view(B*N)
199
+
200
+ dim3 = X
201
+ dim2 = X * Y
202
+ dim1 = X * Y * Z
203
+
204
+ base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1
205
+ base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N)
206
+
207
+ vox_inds = base + z * dim2 + y * dim3 + x
208
+ voxels = torch.zeros(B*Z*Y*X, device=xyz.device).float()
209
+ voxels[vox_inds.long()] = 1.0
210
+ # zero out the singularity
211
+ voxels[base.long()] = 0.0
212
+ voxels = voxels.reshape(B, 1, Z, Y, X)
213
+ # B x 1 x Z x Y x X
214
+ return voxels
215
+
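To make the indexing in get_occupancy concrete: a point at integer memory coordinates (x, y, z) in batch b writes a 1 at flat index b*Z*Y*X + z*X*Y + y*X + x, and the per-batch index b*Z*Y*X (the 0,0,0 cell) is zeroed afterwards to drop invalid points that were masked to the origin. A small standalone sketch with toy shapes (not repo code):

```python
import torch

B, Z, Y, X = 2, 4, 4, 4
# integer memory coordinates, ordered (x, y, z); the (0,0,0) entry mimics a masked-out point
xyz = torch.tensor([[[1, 2, 3], [0, 0, 0]],
                    [[3, 1, 2], [2, 2, 2]]])
x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]

base = (torch.arange(B) * Z * Y * X).unsqueeze(1)      # per-batch offset, B x 1
flat = (base + z * (X * Y) + y * X + x).reshape(-1)

vox = torch.zeros(B * Z * Y * X)
vox[flat.long()] = 1.0
vox[(torch.arange(B) * Z * Y * X).long()] = 0.0        # zero out the 0,0,0 "singularity"
vox = vox.reshape(B, 1, Z, Y, X)                       # B x 1 x Z x Y x X occupancy
```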
216
+ def get_feat_occupancy(self, xyz, feat, Z, Y, X, clean_eps=0, xyz_zero=None):
217
+ # xyz is B x N x 3 and in mem coords
218
+ # feat is B x N x D
219
+ # we want to fill a voxel tensor with 1's at these inds
220
+ B, N, C = list(xyz.shape)
221
+ B2, N2, D2 = list(feat.shape)
222
+ assert(C==3)
223
+ assert(B==B2)
224
+ assert(N==N2)
225
+
226
+ # these papers say simple 1/0 occupancy is ok:
227
+ # http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_PIXOR_Real-Time_3d_CVPR_2018_paper.pdf
228
+ # http://openaccess.thecvf.com/content_cvpr_2018/papers/Luo_Fast_and_Furious_CVPR_2018_paper.pdf
229
+ # cont fusion says they do 8-neighbor interp
230
+ # voxelnet does occupancy but with a bit of randomness in terms of the reflectance value i think
231
+
232
+ inbounds = self.get_inbounds(xyz, Z, Y, X, already_mem=True)
233
+ x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
234
+ mask = torch.zeros_like(x)
235
+ mask[inbounds] = 1.0
236
+
237
+ if xyz_zero is not None:
238
+ # only take points that are beyond a thresh of zero
239
+ dist = torch.norm(xyz_zero-xyz, dim=2)
240
+ mask[dist < 0.1] = 0
241
+
242
+ if clean_eps > 0:
243
+ # only take points that are already near centers
244
+ xyz_round = torch.round(xyz) # B, N, 3
245
+ dist = torch.norm(xyz_round - xyz, dim=2)
246
+ mask[dist > clean_eps] = 0
247
+
248
+ # set the invalid guys to zero
249
+ # we then need to zero out 0,0,0
250
+ # (this method seems a bit clumsy)
251
+ x = x*mask # B, N
252
+ y = y*mask
253
+ z = z*mask
254
+ feat = feat*mask.unsqueeze(-1) # B, N, D
255
+
256
+ x = torch.round(x)
257
+ y = torch.round(y)
258
+ z = torch.round(z)
259
+ x = torch.clamp(x, 0, X-1).int()
260
+ y = torch.clamp(y, 0, Y-1).int()
261
+ z = torch.clamp(z, 0, Z-1).int()
262
+
263
+ # permute point orders
264
+ perm = torch.randperm(N)
265
+ x = x[:, perm]
266
+ y = y[:, perm]
267
+ z = z[:, perm]
268
+ feat = feat[:, perm]
269
+
270
+ x = x.view(B*N)
271
+ y = y.view(B*N)
272
+ z = z.view(B*N)
273
+ feat = feat.view(B*N, -1)
274
+
275
+ dim3 = X
276
+ dim2 = X * Y
277
+ dim1 = X * Y * Z
278
+
279
+ base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1
280
+ base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N)
281
+
282
+ vox_inds = base + z * dim2 + y * dim3 + x
283
+ feat_voxels = torch.zeros((B*Z*Y*X, D2), device=xyz.device).float()
284
+ feat_voxels[vox_inds.long()] = feat
285
+ # zero out the singularity
286
+ feat_voxels[base.long()] = 0.0
287
+ feat_voxels = feat_voxels.reshape(B, Z, Y, X, D2).permute(0, 4, 1, 2, 3)
288
+ # B x C x Z x Y x X
289
+ return feat_voxels
290
+
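The random permutation above matters because scatter-style assignment keeps only the last write per index: when several points fall into the same voxel, shuffling them first makes the surviving feature a random one rather than always the last-listed point. A tiny illustration (not repo code):

```python
import torch

feats = torch.zeros(4)
idx = torch.tensor([2, 2, 3])                 # two points collide in cell 2
feats[idx] = torch.tensor([10.0, 20.0, 5.0])
print(feats)                                  # tensor([ 0.,  0., 20.,  5.]) -- 20.0 overwrote 10.0
```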
291
+ def unproject_image_to_mem(self, rgb_camB, pixB_T_camA, camB_T_camA, Z, Y, X, assert_cube=False, xyz_camA=None):
292
+ # rgb_camB is B x C x H x W
293
+ # pixB_T_camA is B x 4 x 4
294
+
295
+ # rgb lives in B pixel coords
296
+ # we want everything in A memory coords
297
+
298
+ # this puts each C-dim pixel in the rgb_camB
299
+ # along a ray in the voxelgrid
300
+ B, C, H, W = list(rgb_camB.shape)
301
+
302
+ if xyz_camA is None:
303
+ xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device)
304
+ xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube)
305
+
306
+ xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA)
307
+ z = xyz_camB[:,:,2]
308
+
309
+ xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA)
310
+ normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2)
311
+ EPS=1e-6
312
+ # z = xyz_pixB[:,:,2]
313
+ xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS)
314
+ # this is B x N x 2
315
+ # this is the (floating point) pixel coordinate of each voxel
316
+ x, y = xy_pixB[:,:,0], xy_pixB[:,:,1]
317
+ # these are B x N
318
+
319
+ x_valid = (x>-0.5).bool() & (x<float(W-0.5)).bool()
320
+ y_valid = (y>-0.5).bool() & (y<float(H-0.5)).bool()
321
+ z_valid = (z>0.0).bool()
322
+ valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float()
323
+
324
+ if (0):
325
+ # handwritten version
326
+ values = torch.zeros([B, C, Z*Y*X], dtype=torch.float32)
327
+ for b in list(range(B)):
328
+ values[b] = utils.samp.bilinear_sample_single(rgb_camB[b], x_pixB[b], y_pixB[b])
329
+ else:
330
+ # native pytorch version
331
+ y_pixB, x_pixB = utils.basic.normalize_grid2d(y, x, H, W)
332
+ # since we want a 3d output, we need 5d tensors
333
+ z_pixB = torch.zeros_like(x)
334
+ xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], axis=2)
335
+ rgb_camB = rgb_camB.unsqueeze(2)
336
+ xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
337
+ values = F.grid_sample(rgb_camB, xyz_pixB, align_corners=False)
338
+
339
+ values = torch.reshape(values, (B, C, Z, Y, X))
340
+ values = values * valid_mem
341
+ return values
342
+
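The sampling step above boils down to: project every voxel center into the image, normalize the pixel coordinates to [-1, 1], and let F.grid_sample do the bilinear lookup. A standalone sketch of that idea with made-up sizes (it uses align_corners=True with a matching normalization; the method above relies on its own normalize_grid2d helper and align_corners=False):

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 3, 8, 8
img = torch.rand(B, C, H, W)

# pretend these are pixel coordinates of N projected voxel centers, B x N x 2 as (x, y)
pix = torch.tensor([[[1.0, 1.0], [6.5, 3.0]]])

grid = torch.empty_like(pix)
grid[..., 0] = pix[..., 0] / (W - 1) * 2 - 1           # x -> [-1, 1]
grid[..., 1] = pix[..., 1] / (H - 1) * 2 - 1           # y -> [-1, 1]

samples = F.grid_sample(img, grid.unsqueeze(2), align_corners=True)  # B x C x N x 1
```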
343
+ def warp_tiled_to_mem(self, rgb_tileB, pixB_T_camA, camB_T_camA, Z, Y, X, DMIN, DMAX, assert_cube=False):
344
+ # rgb_tileB is B,C,D,H,W
345
+ # pixB_T_camA is B,4,4
346
+ # camB_T_camA is B,4,4
347
+
348
+ # rgb_tileB lives in B pixel coords but it has been tiled across the Z dimension
349
+ # we want everything in A memory coords
350
+
351
+ # this resamples rgb_tileB so that each C-dim pixel
352
+ # is put into its correct place in the voxelgrid
353
+ # (using the pinhole camera model)
354
+
355
+ B, C, D, H, W = list(rgb_tileB.shape)
356
+
357
+ xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device)
358
+
359
+ xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube)
360
+
361
+ xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA)
362
+ z_camB = xyz_camB[:,:,2]
363
+
364
+ # rgb_tileB has depth=DMIN in tile 0, and depth=DMAX in tile D-1
365
+ z_tileB = (D-1.0) * (z_camB-float(DMIN)) / float(DMAX-DMIN)
366
+
367
+ xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA)
368
+ normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2)
369
+ EPS=1e-6
370
+ # z = xyz_pixB[:,:,2]
371
+ xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS)
372
+ # this is B x N x 2
373
+ # this is the (floating point) pixel coordinate of each voxel
374
+ x, y = xy_pixB[:,:,0], xy_pixB[:,:,1]
375
+ # these are B x N
376
+
377
+ x_valid = (x>-0.5).bool() & (x<float(W-0.5)).bool()
378
+ y_valid = (y>-0.5).bool() & (y<float(H-0.5)).bool()
379
+ z_valid = (z_camB>0.0).bool()
380
+ valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float()
381
+
382
+ z_tileB, y_pixB, x_pixB = utils.basic.normalize_grid3d(z_tileB, y, x, D, H, W)
383
+ xyz_pixB = torch.stack([x_pixB, y_pixB, z_tileB], axis=2)
384
+ xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
385
+ values = F.grid_sample(rgb_tileB, xyz_pixB, align_corners=False)
386
+
387
+ values = torch.reshape(values, (B, C, Z, Y, X))
388
+ values = values * valid_mem
389
+ return values
390
+
391
+
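The only new ingredient relative to unproject_image_to_mem is the depth-to-tile mapping z_tile = (D-1) * (z - DMIN) / (DMAX - DMIN), which sends metric depth DMIN to tile 0 and DMAX to tile D-1. A quick check with arbitrary numbers (illustrative only, not repo code):

```python
D, DMIN, DMAX = 64, 2.0, 50.0

def depth_to_tile(z):
    return (D - 1.0) * (z - DMIN) / (DMAX - DMIN)

print(depth_to_tile(2.0))    # 0.0  (nearest plane -> first tile)
print(depth_to_tile(50.0))   # 63.0 (farthest plane -> last tile)
print(depth_to_tile(26.0))   # 31.5 (halfway in depth -> middle of the stack)
```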
392
+ def apply_mem_T_ref_to_lrtlist(self, lrtlist_cam, Z, Y, X, assert_cube=False):
393
+ # lrtlist is B x N x 19, in cam coordinates
394
+ # transforms them into mem coordinates, including a scale change for the lengths
395
+ B, N, C = list(lrtlist_cam.shape)
396
+ assert(C==19)
397
+ mem_T_cam = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=lrtlist_cam.device)
398
+
399
+ def xyz2circles(self, xyz, radius, Z, Y, X, soft=True, already_mem=True, also_offset=False, grid=None):
400
+ # xyz is B x N x 3
401
+ # radius is B x N or broadcastably so
402
+ # output is B x N x Z x Y x X
403
+ B, N, D = list(xyz.shape)
404
+ assert(D==3)
405
+ if not already_mem:
406
+ xyz = self.Ref2Mem(xyz, Z, Y, X)
407
+
408
+ if grid is None:
409
+ grid_z, grid_y, grid_x = utils.basic.meshgrid3d(B, Z, Y, X, stack=False, norm=False, device=xyz.device)
410
+ # note the default stack is on -1
411
+ grid = torch.stack([grid_x, grid_y, grid_z], dim=1)
412
+ # this is B x 3 x Z x Y x X
413
+
414
+ xyz = xyz.reshape(B, N, 3, 1, 1, 1)
415
+ grid = grid.reshape(B, 1, 3, Z, Y, X)
416
+ # these broadcast against each other to B x N x 3 x Z x Y x X below
417
+
418
+ # round the xyzs, so that at least one value matches the grid perfectly,
419
+ # and we get a value of 1 there (since exp(0)==1)
420
+ xyz = xyz.round()
421
+
422
+ if torch.is_tensor(radius):
423
+ radius = radius.clamp(min=0.01)
424
+
425
+ if soft:
426
+ off = grid - xyz # B,N,3,Z,Y,X
427
+ # interpret radius as sigma
428
+ dist_grid = torch.sum(off**2, dim=2, keepdim=False)
429
+ # this is B x N x Z x Y x X
430
+ if torch.is_tensor(radius):
431
+ radius = radius.reshape(B, N, 1, 1, 1)
432
+ mask = torch.exp(-dist_grid/(2*radius*radius))
433
+ # zero out near zero
434
+ mask[mask < 0.001] = 0.0
435
+ # h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
436
+ # h[h < np.finfo(h.dtype).eps * h.max()] = 0
437
+ # return h
438
+ if also_offset:
439
+ return mask, off
440
+ else:
441
+ return mask
442
+ else:
443
+ assert(False) # something is wrong with this. come back later to debug
444
+
445
+ dist_grid = torch.norm(grid - xyz, dim=2, keepdim=False)
446
+ # this is 0 at/near the xyz, and increases by 1 for each voxel away
447
+
448
+ radius = radius.reshape(B, N, 1, 1, 1)
449
+
450
+ within_radius_mask = (dist_grid < radius).float()
451
+ within_radius_mask = torch.sum(within_radius_mask, dim=1, keepdim=True).clamp(0, 1)
452
+ return within_radius_mask
453
+
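In the soft branch, each (rounded) point becomes a Gaussian blob exp(-d^2 / (2 r^2)) over the grid, with values below 0.001 clamped to zero. A minimal 2D standalone version on a made-up grid (not repo code):

```python
import torch

Y, X = 16, 16
gy, gx = torch.meshgrid(torch.arange(Y).float(), torch.arange(X).float(), indexing="ij")

cx, cy = 8.0, 8.0            # point already rounded onto the grid
radius = 2.0                 # interpreted as sigma

d2 = (gx - cx) ** 2 + (gy - cy) ** 2
mask = torch.exp(-d2 / (2 * radius * radius))
mask[mask < 0.001] = 0.0     # same near-zero cutoff as above; the peak value is exactly 1 at the center
```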
454
+ def xyz2circles_bev(self, xyz, radius, Z, Y, X, already_mem=True, also_offset=False):
455
+ # xyz is B x N x 3
456
+ # radius is B x N or broadcastably so
457
+ # output is B x N x Z x Y x X
458
+ B, N, D = list(xyz.shape)
459
+ assert(D==3)
460
+ if not already_mem:
461
+ xyz = self.Ref2Mem(xyz, Z, Y, X)
462
+
463
+ xz = torch.stack([xyz[:,:,0], xyz[:,:,2]], dim=2)
464
+
465
+ grid_z, grid_x = utils.basic.meshgrid2d(B, Z, X, stack=False, norm=False, device=xyz.device)
466
+ # note the default stack is on -1
467
+ grid = torch.stack([grid_x, grid_z], dim=1)
468
+ # this is B x 2 x Z x X
469
+
470
+ xz = xz.reshape(B, N, 2, 1, 1)
471
+ grid = grid.reshape(B, 1, 2, Z, X)
472
+ # these are ready to broadcast to B x N x Z x X
473
+
474
+ # round the points, so that at least one value matches the grid perfectly,
475
+ # and we get a value of 1 there (since exp(0)==1)
476
+ xz = xz.round()
477
+
478
+ if torch.is_tensor(radius):
479
+ radius = radius.clamp(min=0.01)
480
+
481
+ off = grid - xz # B,N,2,Z,X
482
+ # interpret radius as sigma
483
+ dist_grid = torch.sum(off**2, dim=2, keepdim=False)
484
+ # this is B x N x Z x X
485
+ if torch.is_tensor(radius):
486
+ radius = radius.reshape(B, N, 1, 1, 1)
487
+ mask = torch.exp(-dist_grid/(2*radius*radius))
488
+ # zero out near zero
489
+ mask[mask < 0.001] = 0.0
490
+
491
+ # add a Y dim
492
+ mask = mask.unsqueeze(-2)
493
+ off = off.unsqueeze(-2)
494
+ # # B,N,2,Z,1,X
495
+
496
+ if also_offset:
497
+ return mask, off
498
+ else:
499
+ return mask
500
+
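A hedged usage sketch of the class defined above. The import path follows this repo's layout and is an assumption; note that vox.py itself does `import utils.geom`, so `models/spatracker` must be on `sys.path` (or the imports adjusted) for it to resolve. Sizes and bounds are arbitrary illustration values.

```python
import torch
from models.spatracker.utils.vox import Vox_util   # path assumed from this repo layout

Z, Y, X = 32, 32, 32
scene_centroid = torch.zeros(1, 3)                              # B x 3
bounds = (-4.0, 4.0, -4.0, 4.0, -4.0, 4.0)                      # XMIN, XMAX, YMIN, YMAX, ZMIN, ZMAX
vox = Vox_util(Z, Y, X, scene_centroid, bounds, assert_cube=True)

xyz = torch.rand(1, 1024, 3) * 8.0 - 4.0                        # B x N x 3 points in ref coords
occ = vox.voxelize_xyz(xyz, Z, Y, X)                            # B x 1 x Z x Y x X occupancy grid
```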
requirements.txt ADDED
@@ -0,0 +1,32 @@
1
+ # spatrack
2
+ easydict==1.13
3
+ opencv-python==4.9.0.80
4
+ moviepy==1.0.3
5
+ flow-vis==0.1
6
+ matplotlib==3.8.3
7
+ einops==0.7.0
8
+ timm==0.6.7
9
+ scikit-image==0.22.0
10
+ scikit-learn==1.4.1.post1
11
+ cupy-cuda11x
12
+ accelerate
13
+ yt-dlp
14
+ pandas
15
+
16
+ # cogvideox
17
+ bitsandbytes
18
+ diffusers>=0.31.2
19
+ transformers>=4.45.2
20
+ hf_transfer>=0.1.8
21
+ peft>=0.12.0
22
+ decord>=0.6.0
23
+ wandb
24
+ torchao>=0.5.0
25
+ sentencepiece>=0.2.0
26
+ imageio-ffmpeg>=0.5.1
27
+ numpy>=1.26.4
28
+ git+https://github.com/asomoza/image_gen_aux.git
29
+ deepspeed
30
+
31
+ # submodules
32
+ -r submodules/MoGe/requirements.txt
submodules/MoGe/.gitignore ADDED
@@ -0,0 +1,425 @@
1
+ ## Ignore Visual Studio temporary files, build results, and
2
+ ## files generated by popular Visual Studio add-ons.
3
+ ##
4
+ ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
5
+
6
+ # User-specific files
7
+ *.rsuser
8
+ *.suo
9
+ *.user
10
+ *.userosscache
11
+ *.sln.docstates
12
+
13
+ # User-specific files (MonoDevelop/Xamarin Studio)
14
+ *.userprefs
15
+
16
+ # Mono auto generated files
17
+ mono_crash.*
18
+
19
+ # Build results
20
+ [Dd]ebug/
21
+ [Dd]ebugPublic/
22
+ [Rr]elease/
23
+ [Rr]eleases/
24
+ x64/
25
+ x86/
26
+ [Ww][Ii][Nn]32/
27
+ [Aa][Rr][Mm]/
28
+ [Aa][Rr][Mm]64/
29
+ bld/
30
+ [Bb]in/
31
+ [Oo]bj/
32
+ [Ll]og/
33
+ [Ll]ogs/
34
+
35
+ # Visual Studio 2015/2017 cache/options directory
36
+ .vs/
37
+ # Uncomment if you have tasks that create the project's static files in wwwroot
38
+ #wwwroot/
39
+
40
+ # Visual Studio 2017 auto generated files
41
+ Generated\ Files/
42
+
43
+ # MSTest test Results
44
+ [Tt]est[Rr]esult*/
45
+ [Bb]uild[Ll]og.*
46
+
47
+ # NUnit
48
+ *.VisualState.xml
49
+ TestResult.xml
50
+ nunit-*.xml
51
+
52
+ # Build Results of an ATL Project
53
+ [Dd]ebugPS/
54
+ [Rr]eleasePS/
55
+ dlldata.c
56
+
57
+ # Benchmark Results
58
+ BenchmarkDotNet.Artifacts/
59
+
60
+ # .NET Core
61
+ project.lock.json
62
+ project.fragment.lock.json
63
+ artifacts/
64
+
65
+ # ASP.NET Scaffolding
66
+ ScaffoldingReadMe.txt
67
+
68
+ # StyleCop
69
+ StyleCopReport.xml
70
+
71
+ # Files built by Visual Studio
72
+ *_i.c
73
+ *_p.c
74
+ *_h.h
75
+ *.ilk
76
+ *.meta
77
+ *.obj
78
+ *.iobj
79
+ *.pch
80
+ *.pdb
81
+ *.ipdb
82
+ *.pgc
83
+ *.pgd
84
+ *.rsp
85
+ *.sbr
86
+ *.tlb
87
+ *.tli
88
+ *.tlh
89
+ *.tmp
90
+ *.tmp_proj
91
+ *_wpftmp.csproj
92
+ *.log
93
+ *.tlog
94
+ *.vspscc
95
+ *.vssscc
96
+ .builds
97
+ *.pidb
98
+ *.svclog
99
+ *.scc
100
+
101
+ # Chutzpah Test files
102
+ _Chutzpah*
103
+
104
+ # Visual C++ cache files
105
+ ipch/
106
+ *.aps
107
+ *.ncb
108
+ *.opendb
109
+ *.opensdf
110
+ *.sdf
111
+ *.cachefile
112
+ *.VC.db
113
+ *.VC.VC.opendb
114
+
115
+ # Visual Studio profiler
116
+ *.psess
117
+ *.vsp
118
+ *.vspx
119
+ *.sap
120
+
121
+ # Visual Studio Trace Files
122
+ *.e2e
123
+
124
+ # TFS 2012 Local Workspace
125
+ $tf/
126
+
127
+ # Guidance Automation Toolkit
128
+ *.gpState
129
+
130
+ # ReSharper is a .NET coding add-in
131
+ _ReSharper*/
132
+ *.[Rr]e[Ss]harper
133
+ *.DotSettings.user
134
+
135
+ # TeamCity is a build add-in
136
+ _TeamCity*
137
+
138
+ # DotCover is a Code Coverage Tool
139
+ *.dotCover
140
+
141
+ # AxoCover is a Code Coverage Tool
142
+ .axoCover/*
143
+ !.axoCover/settings.json
144
+
145
+ # Coverlet is a free, cross platform Code Coverage Tool
146
+ coverage*.json
147
+ coverage*.xml
148
+ coverage*.info
149
+
150
+ # Visual Studio code coverage results
151
+ *.coverage
152
+ *.coveragexml
153
+
154
+ # NCrunch
155
+ _NCrunch_*
156
+ .*crunch*.local.xml
157
+ nCrunchTemp_*
158
+
159
+ # MightyMoose
160
+ *.mm.*
161
+ AutoTest.Net/
162
+
163
+ # Web workbench (sass)
164
+ .sass-cache/
165
+
166
+ # Installshield output folder
167
+ [Ee]xpress/
168
+
169
+ # DocProject is a documentation generator add-in
170
+ DocProject/buildhelp/
171
+ DocProject/Help/*.HxT
172
+ DocProject/Help/*.HxC
173
+ DocProject/Help/*.hhc
174
+ DocProject/Help/*.hhk
175
+ DocProject/Help/*.hhp
176
+ DocProject/Help/Html2
177
+ DocProject/Help/html
178
+
179
+ # Click-Once directory
180
+ publish/
181
+
182
+ # Publish Web Output
183
+ *.[Pp]ublish.xml
184
+ *.azurePubxml
185
+ # Note: Comment the next line if you want to checkin your web deploy settings,
186
+ # but database connection strings (with potential passwords) will be unencrypted
187
+ *.pubxml
188
+ *.publishproj
189
+
190
+ # Microsoft Azure Web App publish settings. Comment the next line if you want to
191
+ # checkin your Azure Web App publish settings, but sensitive information contained
192
+ # in these scripts will be unencrypted
193
+ PublishScripts/
194
+
195
+ # NuGet Packages
196
+ *.nupkg
197
+ # NuGet Symbol Packages
198
+ *.snupkg
199
+ # The packages folder can be ignored because of Package Restore
200
+ **/[Pp]ackages/*
201
+ # except build/, which is used as an MSBuild target.
202
+ !**/[Pp]ackages/build/
203
+ # Uncomment if necessary however generally it will be regenerated when needed
204
+ #!**/[Pp]ackages/repositories.config
205
+ # NuGet v3's project.json files produces more ignorable files
206
+ *.nuget.props
207
+ *.nuget.targets
208
+
209
+ # Microsoft Azure Build Output
210
+ csx/
211
+ *.build.csdef
212
+
213
+ # Microsoft Azure Emulator
214
+ ecf/
215
+ rcf/
216
+
217
+ # Windows Store app package directories and files
218
+ AppPackages/
219
+ BundleArtifacts/
220
+ Package.StoreAssociation.xml
221
+ _pkginfo.txt
222
+ *.appx
223
+ *.appxbundle
224
+ *.appxupload
225
+
226
+ # Visual Studio cache files
227
+ # files ending in .cache can be ignored
228
+ *.[Cc]ache
229
+ # but keep track of directories ending in .cache
230
+ !?*.[Cc]ache/
231
+
232
+ # Others
233
+ ClientBin/
234
+ ~$*
235
+ *~
236
+ *.dbmdl
237
+ *.dbproj.schemaview
238
+ *.jfm
239
+ *.pfx
240
+ *.publishsettings
241
+ orleans.codegen.cs
242
+
243
+ # Including strong name files can present a security risk
244
+ # (https://github.com/github/gitignore/pull/2483#issue-259490424)
245
+ #*.snk
246
+
247
+ # Since there are multiple workflows, uncomment next line to ignore bower_components
248
+ # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
249
+ #bower_components/
250
+
251
+ # RIA/Silverlight projects
252
+ Generated_Code/
253
+
254
+ # Backup & report files from converting an old project file
255
+ # to a newer Visual Studio version. Backup files are not needed,
256
+ # because we have git ;-)
257
+ _UpgradeReport_Files/
258
+ Backup*/
259
+ UpgradeLog*.XML
260
+ UpgradeLog*.htm
261
+ ServiceFabricBackup/
262
+ *.rptproj.bak
263
+
264
+ # SQL Server files
265
+ *.mdf
266
+ *.ldf
267
+ *.ndf
268
+
269
+ # Business Intelligence projects
270
+ *.rdl.data
271
+ *.bim.layout
272
+ *.bim_*.settings
273
+ *.rptproj.rsuser
274
+ *- [Bb]ackup.rdl
275
+ *- [Bb]ackup ([0-9]).rdl
276
+ *- [Bb]ackup ([0-9][0-9]).rdl
277
+
278
+ # Microsoft Fakes
279
+ FakesAssemblies/
280
+
281
+ # GhostDoc plugin setting file
282
+ *.GhostDoc.xml
283
+
284
+ # Node.js Tools for Visual Studio
285
+ .ntvs_analysis.dat
286
+ node_modules/
287
+
288
+ # Visual Studio 6 build log
289
+ *.plg
290
+
291
+ # Visual Studio 6 workspace options file
292
+ *.opt
293
+
294
+ # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
295
+ *.vbw
296
+
297
+ # Visual Studio 6 auto-generated project file (contains which files were open etc.)
298
+ *.vbp
299
+
300
+ # Visual Studio 6 workspace and project file (working project files containing files to include in project)
301
+ *.dsw
302
+ *.dsp
303
+
304
+ # Visual Studio 6 technical files
305
+ *.ncb
306
+ *.aps
307
+
308
+ # Visual Studio LightSwitch build output
309
+ **/*.HTMLClient/GeneratedArtifacts
310
+ **/*.DesktopClient/GeneratedArtifacts
311
+ **/*.DesktopClient/ModelManifest.xml
312
+ **/*.Server/GeneratedArtifacts
313
+ **/*.Server/ModelManifest.xml
314
+ _Pvt_Extensions
315
+
316
+ # Paket dependency manager
317
+ .paket/paket.exe
318
+ paket-files/
319
+
320
+ # FAKE - F# Make
321
+ .fake/
322
+
323
+ # CodeRush personal settings
324
+ .cr/personal
325
+
326
+ # Python Tools for Visual Studio (PTVS)
327
+ __pycache__/
328
+ *.pyc
329
+
330
+ # Cake - Uncomment if you are using it
331
+ # tools/**
332
+ # !tools/packages.config
333
+
334
+ # Tabs Studio
335
+ *.tss
336
+
337
+ # Telerik's JustMock configuration file
338
+ *.jmconfig
339
+
340
+ # BizTalk build output
341
+ *.btp.cs
342
+ *.btm.cs
343
+ *.odx.cs
344
+ *.xsd.cs
345
+
346
+ # OpenCover UI analysis results
347
+ OpenCover/
348
+
349
+ # Azure Stream Analytics local run output
350
+ ASALocalRun/
351
+
352
+ # MSBuild Binary and Structured Log
353
+ *.binlog
354
+
355
+ # NVidia Nsight GPU debugger configuration file
356
+ *.nvuser
357
+
358
+ # MFractors (Xamarin productivity tool) working folder
359
+ .mfractor/
360
+
361
+ # Local History for Visual Studio
362
+ .localhistory/
363
+
364
+ # Visual Studio History (VSHistory) files
365
+ .vshistory/
366
+
367
+ # BeatPulse healthcheck temp database
368
+ healthchecksdb
369
+
370
+ # Backup folder for Package Reference Convert tool in Visual Studio 2017
371
+ MigrationBackup/
372
+
373
+ # Ionide (cross platform F# VS Code tools) working folder
374
+ .ionide/
375
+
376
+ # Fody - auto-generated XML schema
377
+ FodyWeavers.xsd
378
+
379
+ # VS Code files for those working on multiple tools
380
+ .vscode/*
381
+ !.vscode/settings.json
382
+ !.vscode/tasks.json
383
+ !.vscode/launch.json
384
+ !.vscode/extensions.json
385
+ *.code-workspace
386
+
387
+ # Local History for Visual Studio Code
388
+ .history/
389
+
390
+ # Windows Installer files from build outputs
391
+ *.cab
392
+ *.msi
393
+ *.msix
394
+ *.msm
395
+ *.msp
396
+
397
+ # JetBrains Rider
398
+ *.sln.iml
399
+
400
+ # MoGe
401
+ /data
402
+ /download
403
+ /extract
404
+ /view_point_cloud
405
+ /view_depth_map
406
+ /blobcache
407
+ /snapshot
408
+ /reference_embeddings
409
+ /.msra_intern_s_toolkit
410
+ /debug
411
+ /workspace
412
+ /mlruns
413
+ /infer_output
414
+ /video_output
415
+ /eval_output
416
+ /.blobcache
417
+ /test_images
418
+ /test_videos
419
+ /vis
420
+ /videos
421
+ /raid
422
+ /blobmnt
423
+ /eval_dump
424
+ /pretrained
425
+ /.gradio
submodules/MoGe/CHANGELOG.md ADDED
@@ -0,0 +1,15 @@
1
+ ## 2024-11-28
2
+ ### Added
3
+ - Supported user-provided camera FOV. See [scripts/infer.py](scripts/infer.py) --fov_x.
4
+ - Related issues: [#25](https://github.com/microsoft/MoGe/issues/25) and [#24](https://github.com/microsoft/MoGe/issues/24).
5
+ - Added inference scripts for panorama images. See [scripts/infer_panorama.py](scripts/infer_panorama.py).
6
+ - Related issue: [#19](https://github.com/microsoft/MoGe/issues/19).
7
+
8
+ ### Fixed
9
+ - Suppressed unnecessary numpy runtime warnings.
10
+ - Specified recommended versions of requirements.
11
+ - Related issue: [#21](https://github.com/microsoft/MoGe/issues/21).
12
+
13
+ ### Changed
14
+ - Moved `app.py` and `infer.py` to [scripts/](scripts/)
15
+ - Improved edge removal.
submodules/MoGe/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
1
+ # Microsoft Open Source Code of Conduct
2
+
3
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
+
5
+ Resources:
6
+
7
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
+ - Contact [[email protected]](mailto:[email protected]) with questions or concerns
submodules/MoGe/LICENSE ADDED
@@ -0,0 +1,224 @@
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE
22
+
23
+
24
+ Apache License
25
+ Version 2.0, January 2004
26
+ http://www.apache.org/licenses/
27
+
28
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
29
+
30
+ 1. Definitions.
31
+
32
+ "License" shall mean the terms and conditions for use, reproduction,
33
+ and distribution as defined by Sections 1 through 9 of this document.
34
+
35
+ "Licensor" shall mean the copyright owner or entity authorized by
36
+ the copyright owner that is granting the License.
37
+
38
+ "Legal Entity" shall mean the union of the acting entity and all
39
+ other entities that control, are controlled by, or are under common
40
+ control with that entity. For the purposes of this definition,
41
+ "control" means (i) the power, direct or indirect, to cause the
42
+ direction or management of such entity, whether by contract or
43
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
44
+ outstanding shares, or (iii) beneficial ownership of such entity.
45
+
46
+ "You" (or "Your") shall mean an individual or Legal Entity
47
+ exercising permissions granted by this License.
48
+
49
+ "Source" form shall mean the preferred form for making modifications,
50
+ including but not limited to software source code, documentation
51
+ source, and configuration files.
52
+
53
+ "Object" form shall mean any form resulting from mechanical
54
+ transformation or translation of a Source form, including but
55
+ not limited to compiled object code, generated documentation,
56
+ and conversions to other media types.
57
+
58
+ "Work" shall mean the work of authorship, whether in Source or
59
+ Object form, made available under the License, as indicated by a
60
+ copyright notice that is included in or attached to the work
61
+ (an example is provided in the Appendix below).
62
+
63
+ "Derivative Works" shall mean any work, whether in Source or Object
64
+ form, that is based on (or derived from) the Work and for which the
65
+ editorial revisions, annotations, elaborations, or other modifications
66
+ represent, as a whole, an original work of authorship. For the purposes
67
+ of this License, Derivative Works shall not include works that remain
68
+ separable from, or merely link (or bind by name) to the interfaces of,
69
+ the Work and Derivative Works thereof.
70
+
71
+ "Contribution" shall mean any work of authorship, including
72
+ the original version of the Work and any modifications or additions
73
+ to that Work or Derivative Works thereof, that is intentionally
74
+ submitted to Licensor for inclusion in the Work by the copyright owner
75
+ or by an individual or Legal Entity authorized to submit on behalf of
76
+ the copyright owner. For the purposes of this definition, "submitted"
77
+ means any form of electronic, verbal, or written communication sent
78
+ to the Licensor or its representatives, including but not limited to
79
+ communication on electronic mailing lists, source code control systems,
80
+ and issue tracking systems that are managed by, or on behalf of, the
81
+ Licensor for the purpose of discussing and improving the Work, but
82
+ excluding communication that is conspicuously marked or otherwise
83
+ designated in writing by the copyright owner as "Not a Contribution."
84
+
85
+ "Contributor" shall mean Licensor and any individual or Legal Entity
86
+ on behalf of whom a Contribution has been received by Licensor and
87
+ subsequently incorporated within the Work.
88
+
89
+ 2. Grant of Copyright License. Subject to the terms and conditions of
90
+ this License, each Contributor hereby grants to You a perpetual,
91
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
92
+ copyright license to reproduce, prepare Derivative Works of,
93
+ publicly display, publicly perform, sublicense, and distribute the
94
+ Work and such Derivative Works in Source or Object form.
95
+
96
+ 3. Grant of Patent License. Subject to the terms and conditions of
97
+ this License, each Contributor hereby grants to You a perpetual,
98
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
99
+ (except as stated in this section) patent license to make, have made,
100
+ use, offer to sell, sell, import, and otherwise transfer the Work,
101
+ where such license applies only to those patent claims licensable
102
+ by such Contributor that are necessarily infringed by their
103
+ Contribution(s) alone or by combination of their Contribution(s)
104
+ with the Work to which such Contribution(s) was submitted. If You
105
+ institute patent litigation against any entity (including a
106
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
107
+ or a Contribution incorporated within the Work constitutes direct
108
+ or contributory patent infringement, then any patent licenses
109
+ granted to You under this License for that Work shall terminate
110
+ as of the date such litigation is filed.
111
+
112
+ 4. Redistribution. You may reproduce and distribute copies of the
113
+ Work or Derivative Works thereof in any medium, with or without
114
+ modifications, and in Source or Object form, provided that You
115
+ meet the following conditions:
116
+
117
+ (a) You must give any other recipients of the Work or
118
+ Derivative Works a copy of this License; and
119
+
120
+ (b) You must cause any modified files to carry prominent notices
121
+ stating that You changed the files; and
122
+
123
+ (c) You must retain, in the Source form of any Derivative Works
124
+ that You distribute, all copyright, patent, trademark, and
125
+ attribution notices from the Source form of the Work,
126
+ excluding those notices that do not pertain to any part of
127
+ the Derivative Works; and
128
+
129
+ (d) If the Work includes a "NOTICE" text file as part of its
130
+ distribution, then any Derivative Works that You distribute must
131
+ include a readable copy of the attribution notices contained
132
+ within such NOTICE file, excluding those notices that do not
133
+ pertain to any part of the Derivative Works, in at least one
134
+ of the following places: within a NOTICE text file distributed
135
+ as part of the Derivative Works; within the Source form or
136
+ documentation, if provided along with the Derivative Works; or,
137
+ within a display generated by the Derivative Works, if and
138
+ wherever such third-party notices normally appear. The contents
139
+ of the NOTICE file are for informational purposes only and
140
+ do not modify the License. You may add Your own attribution
141
+ notices within Derivative Works that You distribute, alongside
142
+ or as an addendum to the NOTICE text from the Work, provided
143
+ that such additional attribution notices cannot be construed
144
+ as modifying the License.
145
+
146
+ You may add Your own copyright statement to Your modifications and
147
+ may provide additional or different license terms and conditions
148
+ for use, reproduction, or distribution of Your modifications, or
149
+ for any such Derivative Works as a whole, provided Your use,
150
+ reproduction, and distribution of the Work otherwise complies with
151
+ the conditions stated in this License.
152
+
153
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
154
+ any Contribution intentionally submitted for inclusion in the Work
155
+ by You to the Licensor shall be under the terms and conditions of
156
+ this License, without any additional terms or conditions.
157
+ Notwithstanding the above, nothing herein shall supersede or modify
158
+ the terms of any separate license agreement you may have executed
159
+ with Licensor regarding such Contributions.
160
+
161
+ 6. Trademarks. This License does not grant permission to use the trade
162
+ names, trademarks, service marks, or product names of the Licensor,
163
+ except as required for reasonable and customary use in describing the
164
+ origin of the Work and reproducing the content of the NOTICE file.
165
+
166
+ 7. Disclaimer of Warranty. Unless required by applicable law or
167
+ agreed to in writing, Licensor provides the Work (and each
168
+ Contributor provides its Contributions) on an "AS IS" BASIS,
169
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
170
+ implied, including, without limitation, any warranties or conditions
171
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
172
+ PARTICULAR PURPOSE. You are solely responsible for determining the
173
+ appropriateness of using or redistributing the Work and assume any
174
+ risks associated with Your exercise of permissions under this License.
175
+
176
+ 8. Limitation of Liability. In no event and under no legal theory,
177
+ whether in tort (including negligence), contract, or otherwise,
178
+ unless required by applicable law (such as deliberate and grossly
179
+ negligent acts) or agreed to in writing, shall any Contributor be
180
+ liable to You for damages, including any direct, indirect, special,
181
+ incidental, or consequential damages of any character arising as a
182
+ result of this License or out of the use or inability to use the
183
+ Work (including but not limited to damages for loss of goodwill,
184
+ work stoppage, computer failure or malfunction, or any and all
185
+ other commercial damages or losses), even if such Contributor
186
+ has been advised of the possibility of such damages.
187
+
188
+ 9. Accepting Warranty or Additional Liability. While redistributing
189
+ the Work or Derivative Works thereof, You may choose to offer,
190
+ and charge a fee for, acceptance of support, warranty, indemnity,
191
+ or other liability obligations and/or rights consistent with this
192
+ License. However, in accepting such obligations, You may act only
193
+ on Your own behalf and on Your sole responsibility, not on behalf
194
+ of any other Contributor, and only if You agree to indemnify,
195
+ defend, and hold each Contributor harmless for any liability
196
+ incurred by, or claims asserted against, such Contributor by reason
197
+ of your accepting any such warranty or additional liability.
198
+
199
+ END OF TERMS AND CONDITIONS
200
+
201
+ APPENDIX: How to apply the Apache License to your work.
202
+
203
+ To apply the Apache License to your work, attach the following
204
+ boilerplate notice, with the fields enclosed by brackets "[]"
205
+ replaced with your own identifying information. (Don't include
206
+ the brackets!) The text should be enclosed in the appropriate
207
+ comment syntax for the file format. We also recommend that a
208
+ file or class name and description of purpose be included on the
209
+ same "printed page" as the copyright notice for easier
210
+ identification within third-party archives.
211
+
212
+ Copyright [yyyy] [name of copyright owner]
213
+
214
+ Licensed under the Apache License, Version 2.0 (the "License");
215
+ you may not use this file except in compliance with the License.
216
+ You may obtain a copy of the License at
217
+
218
+ http://www.apache.org/licenses/LICENSE-2.0
219
+
220
+ Unless required by applicable law or agreed to in writing, software
221
+ distributed under the License is distributed on an "AS IS" BASIS,
222
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
223
+ See the License for the specific language governing permissions and
224
+ limitations under the License.