diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7026aea55168872d1261add521740ffa888913f2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,200 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# JetBrains +.idea + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# manually added +wandb/ +dump* + +!requirements.txt +env/ +datasets/ +validation/ +ckpts/ +.vscode/ +output.mp4 +outputs/ +camctrl_output +*.code-workspace + +**/*/.DS_Store +**/*/__pycache__/* +.DS_Store +__pycache__ +vis_results +checkpoints +**/*/.pth +**/*/.pt +**/*/.mp4 +**/*/.npy + +/assets/** +./vis_results/** */ +models/monoD/zoeDepth/ckpts/* +slurm-*.out +.vscode + +data/ +tmp/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..a30db120944e4b70b9d02f689b196263cc72125c --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "submodules/MoGe"] + path = submodules/MoGe + url = https://github.com/microsoft/MoGe.git diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..186bcfdcef150e21f3d7b4c45c508defae3121ed --- /dev/null +++ b/app.py @@ -0,0 +1,577 @@ +import os +import sys +import gradio as gr +import torch +import subprocess +import argparse +import glob + +project_root = os.path.dirname(os.path.abspath(__file__)) +os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio") +sys.path.append(project_root) + +HERE_PATH = os.path.normpath(os.path.dirname(__file__)) +sys.path.insert(0, HERE_PATH) +from huggingface_hub import hf_hub_download +hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir=f'{HERE_PATH}/checkpoints/') + + +# Parse command line arguments +parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI") +parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on") +parser.add_argument("--share", action="store_true", help="Share the web UI") +parser.add_argument("--gpu", type=int, default=0, help="GPU device ID") +parser.add_argument("--model_path", type=str, default="EXCAI/Diffusion-As-Shader", help="Path to model checkpoint") +parser.add_argument("--output_dir", type=str, default="tmp", help="Output directory") +args = parser.parse_args() + +# Use the original GPU ID throughout the entire code for consistency +GPU_ID = args.gpu + +# Set environment variables - this used to remap the GPU, but we're removing this for consistency +# Instead, we'll pass the original GPU ID to all commands +# os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) # Commented out to ensure consistent GPU ID usage + +# Check if CUDA is available +CUDA_AVAILABLE = torch.cuda.is_available() +if CUDA_AVAILABLE: + GPU_COUNT = torch.cuda.device_count() + GPU_NAMES = [f"{i}: {torch.cuda.get_device_name(i)}" for i in range(GPU_COUNT)] +else: + GPU_COUNT = 0 + GPU_NAMES = ["CPU (CUDA not available)"] + GPU_ID = "CPU" + +DEFAULT_MODEL_PATH = args.model_path +OUTPUT_DIR = args.output_dir + +# Create necessary directories +os.makedirs("outputs", exist_ok=True) +# Create project tmp directory instead of using system temp +os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True) +os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True) + +def save_uploaded_file(file): + if file is None: + return None + + # Use project tmp directory instead of system temp + temp_dir = os.path.join(project_root, "tmp") + + if hasattr(file, 'name'): + filename = file.name + else: + # Generate a unique filename if name attribute is missing + import uuid + ext = ".tmp" + if hasattr(file, 'content_type'): + if "image" in file.content_type: + ext = ".png" + elif "video" in file.content_type: + ext = ".mp4" + filename = f"{uuid.uuid4()}{ext}" + + temp_path = os.path.join(temp_dir, filename) + + try: + # Check if file is a FileStorage object or already a path + if hasattr(file, 'save'): + file.save(temp_path) + elif isinstance(file, str): + # It's already a path + return file + else: + # Try to read and save the file + with open(temp_path, 'wb') as f: + f.write(file.read() if hasattr(file, 'read') else file) + except Exception as e: + print(f"Error saving file: {e}") + return None + + return temp_path + +def create_run_command(args): + """Create command based on input parameters""" + cmd = ["python", "demo.py"] + + if "prompt" not in args or args["prompt"] is None or args["prompt"] == "": + args["prompt"] = "" + if "checkpoint_path" not in args or args["checkpoint_path"] is None or args["checkpoint_path"] == "": + args["checkpoint_path"] = DEFAULT_MODEL_PATH + + # 添加调试输出 + print(f"DEBUG: Command args: {args}") + + for key, value in args.items(): + if value is not None: + # Handle boolean values correctly - for repaint, we need to pass true/false + if isinstance(value, bool): + cmd.append(f"--{key}") + cmd.append(str(value).lower()) # Convert True/False to true/false + else: + cmd.append(f"--{key}") + cmd.append(str(value)) + + return cmd + +def run_process(cmd): + """Run command and return output""" + print(f"Running command: {' '.join(cmd)}") + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True + ) + + output = [] + for line in iter(process.stdout.readline, ""): + print(line, end="") + output.append(line) + if not line: + break + + process.stdout.close() + return_code = process.wait() + + if return_code: + stderr = process.stderr.read() + print(f"Error: {stderr}") + raise subprocess.CalledProcessError(return_code, cmd, output="\n".join(output), stderr=stderr) + + return "\n".join(output) + +# Process functions for each tab +def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image): + """Process video motion transfer task""" + try: + # Save uploaded files + input_video_path = save_uploaded_file(source) + if input_video_path is None: + return None + + print(f"DEBUG: Repaint option: {mt_repaint_option}") + print(f"DEBUG: Repaint image: {mt_repaint_image}") + + args = { + "input_path": input_video_path, + "prompt": f"\"{prompt}\"", + "checkpoint_path": DEFAULT_MODEL_PATH, + "output_dir": OUTPUT_DIR, + "gpu": GPU_ID + } + + # Priority: Custom Image > Yes > No + if mt_repaint_image is not None: + # Custom image takes precedence if provided + repaint_path = save_uploaded_file(mt_repaint_image) + print(f"DEBUG: Repaint path: {repaint_path}") + args["repaint"] = repaint_path + elif mt_repaint_option == "Yes": + # Otherwise use Yes/No selection + args["repaint"] = "true" + + # Create and run command + cmd = create_run_command(args) + output = run_process(cmd) + + # Find generated video files + output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")) + if output_files: + # Sort by modification time, return the latest file + latest_file = max(output_files, key=os.path.getmtime) + return latest_file + else: + return None + except Exception as e: + import traceback + print(f"Processing failed: {str(e)}\n{traceback.format_exc()}") + return None + +def process_camera_control(source, prompt, camera_motion, tracking_method): + """Process camera control task""" + try: + # Save uploaded files + input_media_path = save_uploaded_file(source) + if input_media_path is None: + return None + + print(f"DEBUG: Camera motion: '{camera_motion}'") + print(f"DEBUG: Tracking method: '{tracking_method}'") + + args = { + "input_path": input_media_path, + "prompt": prompt, + "checkpoint_path": DEFAULT_MODEL_PATH, + "output_dir": OUTPUT_DIR, + "gpu": GPU_ID, + "tracking_method": tracking_method + } + + if camera_motion and camera_motion.strip(): + args["camera_motion"] = camera_motion + + # Create and run command + cmd = create_run_command(args) + output = run_process(cmd) + + # Find generated video files + output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")) + if output_files: + # Sort by modification time, return the latest file + latest_file = max(output_files, key=os.path.getmtime) + return latest_file + else: + return None + except Exception as e: + import traceback + print(f"Processing failed: {str(e)}\n{traceback.format_exc()}") + return None + +def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method): + """Process object manipulation task""" + try: + # Save uploaded files + input_image_path = save_uploaded_file(source) + if input_image_path is None: + return None + + object_mask_path = save_uploaded_file(object_mask) + + args = { + "input_path": input_image_path, + "prompt": prompt, + "checkpoint_path": DEFAULT_MODEL_PATH, + "output_dir": OUTPUT_DIR, + "gpu": GPU_ID, + "object_motion": object_motion, + "object_mask": object_mask_path, + "tracking_method": tracking_method + } + + # Create and run command + cmd = create_run_command(args) + output = run_process(cmd) + + # Find generated video files + output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")) + if output_files: + # Sort by modification time, return the latest file + latest_file = max(output_files, key=os.path.getmtime) + return latest_file + else: + return None + except Exception as e: + import traceback + print(f"Processing failed: {str(e)}\n{traceback.format_exc()}") + return None + +def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image): + """Process mesh animation task""" + try: + # Save uploaded files + input_video_path = save_uploaded_file(source) + if input_video_path is None: + return None + + tracking_video_path = save_uploaded_file(tracking_video) + if tracking_video_path is None: + return None + + args = { + "input_path": input_video_path, + "prompt": prompt, + "checkpoint_path": DEFAULT_MODEL_PATH, + "output_dir": OUTPUT_DIR, + "gpu": GPU_ID, + "tracking_path": tracking_video_path + } + + # Priority: Custom Image > Yes > No + if ma_repaint_image is not None: + # Custom image takes precedence if provided + repaint_path = save_uploaded_file(ma_repaint_image) + args["repaint"] = repaint_path + elif ma_repaint_option == "Yes": + # Otherwise use Yes/No selection + args["repaint"] = "true" + + # Create and run command + cmd = create_run_command(args) + output = run_process(cmd) + + # Find generated video files + output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")) + if output_files: + # Sort by modification time, return the latest file + latest_file = max(output_files, key=os.path.getmtime) + return latest_file + else: + return None + except Exception as e: + import traceback + print(f"Processing failed: {str(e)}\n{traceback.format_exc()}") + return None + +# Create Gradio interface with updated layout +with gr.Blocks(title="Diffusion as Shader") as demo: + gr.Markdown("# Diffusion as Shader Web UI") + gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)") + + with gr.Row(): + left_column = gr.Column(scale=1) + right_column = gr.Column(scale=1) + + with right_column: + output_video = gr.Video(label="Generated Video") + + with left_column: + source = gr.File(label="Source", file_types=["image", "video"]) + common_prompt = gr.Textbox(label="Prompt", lines=2) + gr.Markdown(f"**Using GPU: {GPU_ID}**") + + with gr.Tabs() as task_tabs: + # Motion Transfer tab + with gr.TabItem("Motion Transfer"): + gr.Markdown("## Motion Transfer") + + # Simplified controls - Radio buttons for Yes/No and separate file upload + with gr.Row(): + mt_repaint_option = gr.Radio( + label="Repaint First Frame", + choices=["No", "Yes"], + value="No" + ) + gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.") + # Custom image uploader (always visible) + mt_repaint_image = gr.File( + label="Custom Repaint Image", + file_types=["image"] + ) + + # Add run button for Motion Transfer tab + mt_run_btn = gr.Button("Run Motion Transfer", variant="primary", size="lg") + + # Connect to process function + mt_run_btn.click( + fn=process_motion_transfer, + inputs=[ + source, common_prompt, + mt_repaint_option, mt_repaint_image + ], + outputs=[output_video] + ) + + # Camera Control tab + with gr.TabItem("Camera Control"): + gr.Markdown("## Camera Control") + + cc_camera_motion = gr.Textbox( + label="Current Camera Motion Sequence", + placeholder="Your camera motion sequence will appear here...", + interactive=False + ) + + # Use tabs for different motion types + with gr.Tabs() as cc_motion_tabs: + # Translation tab + with gr.TabItem("Translation (trans)"): + with gr.Row(): + cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement") + cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement") + cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)") + + with gr.Row(): + cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0) + cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0) + + cc_trans_note = gr.Markdown(""" + **Translation Notes:** + - Positive X: Move right, Negative X: Move left + - Positive Y: Move down, Negative Y: Move up + - Positive Z: Zoom in, Negative Z: Zoom out + """) + + # Add translation button in the Translation tab + cc_add_trans = gr.Button("Add Camera Translation", variant="secondary") + + # Function to add translation motion + def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end): + # Format: trans dx dy dz [start_frame end_frame] + frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else "" + new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}" + + # Append to existing motion string with semicolon separator if needed + if current_motion and current_motion.strip(): + updated_motion = f"{current_motion}; {new_motion}" + else: + updated_motion = new_motion + + return updated_motion + + # Connect translation button + cc_add_trans.click( + fn=add_translation_motion, + inputs=[ + cc_camera_motion, + cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end + ], + outputs=[cc_camera_motion] + ) + + # Rotation tab + with gr.TabItem("Rotation (rot)"): + with gr.Row(): + cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis") + cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)") + + with gr.Row(): + cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0) + cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0) + + cc_rot_note = gr.Markdown(""" + **Rotation Notes:** + - X-axis rotation: Tilt camera up/down + - Y-axis rotation: Pan camera left/right + - Z-axis rotation: Roll camera + """) + + # Add rotation button in the Rotation tab + cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary") + + # Function to add rotation motion + def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end): + # Format: rot axis angle [start_frame end_frame] + frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else "" + new_motion = f"rot {rot_axis} {rot_angle}{frame_range}" + + # Append to existing motion string with semicolon separator if needed + if current_motion and current_motion.strip(): + updated_motion = f"{current_motion}; {new_motion}" + else: + updated_motion = new_motion + + return updated_motion + + # Connect rotation button + cc_add_rot.click( + fn=add_rotation_motion, + inputs=[ + cc_camera_motion, + cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end + ], + outputs=[cc_camera_motion] + ) + + # Add a clear button to reset the motion sequence + cc_clear_motion = gr.Button("Clear All Motions", variant="stop") + + def clear_camera_motion(): + return "" + + cc_clear_motion.click( + fn=clear_camera_motion, + inputs=[], + outputs=[cc_camera_motion] + ) + + cc_tracking_method = gr.Radio( + label="Tracking Method", + choices=["spatracker", "moge"], + value="moge" + ) + + # Add run button for Camera Control tab + cc_run_btn = gr.Button("Run Camera Control", variant="primary", size="lg") + + # Connect to process function + cc_run_btn.click( + fn=process_camera_control, + inputs=[ + source, common_prompt, + cc_camera_motion, cc_tracking_method + ], + outputs=[output_video] + ) + + # Object Manipulation tab + with gr.TabItem("Object Manipulation"): + gr.Markdown("## Object Manipulation") + om_object_mask = gr.File( + label="Object Mask Image", + file_types=["image"] + ) + gr.Markdown("Upload a binary mask image, white areas indicate the object to manipulate") + om_object_motion = gr.Dropdown( + label="Object Motion Type", + choices=["up", "down", "left", "right", "front", "back", "rot"], + value="up" + ) + om_tracking_method = gr.Radio( + label="Tracking Method", + choices=["spatracker", "moge"], + value="moge" + ) + + # Add run button for Object Manipulation tab + om_run_btn = gr.Button("Run Object Manipulation", variant="primary", size="lg") + + # Connect to process function + om_run_btn.click( + fn=process_object_manipulation, + inputs=[ + source, common_prompt, + om_object_motion, om_object_mask, om_tracking_method + ], + outputs=[output_video] + ) + + # Animating meshes to video tab + with gr.TabItem("Animating meshes to video"): + gr.Markdown("## Mesh Animation to Video") + gr.Markdown(""" + Note: Currently only supports tracking videos generated with Blender (version > 4.0). + Please run the script `scripts/blender.py` in your Blender project to generate tracking videos. + """) + ma_tracking_video = gr.File( + label="Tracking Video", + file_types=["video"] + ) + gr.Markdown("Tracking video needs to be generated from Blender") + + # Simplified controls - Radio buttons for Yes/No and separate file upload + with gr.Row(): + ma_repaint_option = gr.Radio( + label="Repaint First Frame", + choices=["No", "Yes"], + value="No" + ) + gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.") + # Custom image uploader (always visible) + ma_repaint_image = gr.File( + label="Custom Repaint Image", + file_types=["image"] + ) + + # Add run button for Mesh Animation tab + ma_run_btn = gr.Button("Run Mesh Animation", variant="primary", size="lg") + + # Connect to process function + ma_run_btn.click( + fn=process_mesh_animation, + inputs=[ + source, common_prompt, + ma_tracking_video, ma_repaint_option, ma_repaint_image + ], + outputs=[output_video] + ) + +# Launch interface +if __name__ == "__main__": + print(f"Using GPU: {GPU_ID}") + print(f"Web UI will start on port {args.port}") + if args.share: + print("Creating public link for remote access") + + # Launch interface + demo.launch(share=args.share, server_port=args.port) \ No newline at end of file diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/config/base_cfg.py b/config/base_cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5fc1c749a22846654e53d52cc5f0d190a4c043 --- /dev/null +++ b/config/base_cfg.py @@ -0,0 +1,410 @@ +#python3.10 +"""Hierachical configuration for different pipelines, using `yacs` +(refered to https://github.com/rbgirshick/yacs) + +This projects contain the configuration for three aspects: + the regular config for experiment setting + + NOTE: Each experiment will be assigned a seperate working space, and the + intermediate results will be saved in the working space. The experimentes + folder structure is as follows: + { + /${ROOT_WORK_DIR}/ + └── ${PIPELINES_NAME}/ + ├── ${EXP_NAME}/ + ├── ${CHECKPOINT_DIR}/ + ├── ${RESULT_DIR}/ + ├── meta.json/ + └── ${LOG_DIR} + } + +""" + +import os, sys +from .yacs import CfgNode as CN +import argparse +import numpy as np + +# the parser for boolean +def bool_parser(arg): + """Parses an argument to boolean.""" + if isinstance(arg, bool): + return arg + if arg is None: + return False + if arg.lower() in ['1', 'true', 't', 'yes', 'y']: + return True + if arg.lower() in ['0', 'false', 'f', 'no', 'n']: + return False + raise ValueError(f'`{arg}` cannot be converted to boolean!') + +# ----------------------------------------------------------------------------- +# base cfg +# ----------------------------------------------------------------------------- +cfg = CN() + +# configuration for basic experiments +cfg.save_dir = "./checkpoints" +cfg.restore_ckpt = "" +cfg.model_name = "cotracker" +cfg.exp_name = "" + +# NOTE: configuration for datasets and augmentation +cfg.dataset_root = "" +cfg.eval_datasets = [""] +cfg.dont_use_augs = False +cfg.crop_size = [384, 512] +cfg.traj_per_sample = 384 +cfg.sample_vis_1st_frame = False +cfg.depth_near = 0.01 # meter +cfg.depth_far = 65.0 # meter +cfg.sequence_len = 24 + +# NOTE: configuration for network arch +cfg.sliding_window_len = 8 +cfg.remove_space_attn = False +cfg.updateformer_hidden_size = 384 +cfg.updateformer_num_heads = 8 +cfg.updateformer_space_depth = 6 +cfg.updateformer_time_depth = 6 +cfg.model_stride = 4 +cfg.train_iters = 4 +cfg.if_ARAP = False +cfg.Embed3D = False +cfg.Loss_W_feat = 5e-1 +cfg.Loss_W_cls = 1e-4 +cfg.depth_color = False +cfg.flash_attn = False +cfg.corr_dp = True +cfg.support_grid = 0 +cfg.backbone = "CNN" +cfg.enc_only = False +cfg.init_match = False +cfg.Nblock = 4 + +# NOTE: configuration for training and saving +cfg.nodes_num = 1 +cfg.batch_size = 1 +cfg.num_workers = 6 +cfg.mixed_precision = False +cfg.lr = 0.0005 +cfg.wdecay = 0.00001 +cfg.num_steps = 200000 +cfg.evaluate_every_n_epoch = 1 +cfg.save_every_n_epoch = 1 +cfg.validate_at_start = False +cfg.save_freq = 100 +cfg.eval_max_seq_len = 1000 +cfg.debug = False +cfg.fine_tune = False +cfg.aug_wind_sample = False +cfg.use_video_flip = False +cfg.fix_backbone = False +cfg.tune_backbone = False +cfg.tune_arap = False +cfg.tune_per_scene = False +cfg.use_hier_encoder = False +cfg.scales = [4, 2] + + +# NOTE: configuration for monocular depth estimator +cfg.mde_name = "zoedepth_nk" + +# ----------------------------------------------------------------------------- + +# configurations for the command line +parser = argparse.ArgumentParser() + +# config for the basic experiment +parser.add_argument("--save_dir", default="./checkpoints", type=str ,help="path to save checkpoints") +parser.add_argument("--restore_ckpt", default="", help="path to restore a checkpoint") +parser.add_argument("--model_name", default="cotracker", help="model name") +parser.add_argument("--exp_name", type=str, default="base", + help="the name for experiment", + ) +# config for dataset and augmentation +parser.add_argument( + "--dataset_root", type=str, help="path lo all the datasets (train and eval)" +) +parser.add_argument( + "--eval_datasets", nargs="+", default=["things", "badja"], + help="what datasets to use for evaluation", +) +parser.add_argument( + "--dont_use_augs", action="store_true", default=False, + help="don't apply augmentations during training", +) +parser.add_argument( + "--crop_size", type=int, nargs="+", default=[384, 512], + help="crop videos to this resolution during training", +) +parser.add_argument( + "--traj_per_sample", type=int, default=768, + help="the number of trajectories to sample for training", +) +parser.add_argument( + "--depth_near", type=float, default=0.01, help="near plane depth" +) +parser.add_argument( + "--depth_far", type=float, default=65.0, help="far plane depth" +) +parser.add_argument( + "--sample_vis_1st_frame", + action="store_true", + default=False, + help="only sample trajectories with points visible on the first frame", +) +parser.add_argument( + "--sequence_len", type=int, default=24, help="train sequence length" +) +# configuration for network arch +parser.add_argument( + "--sliding_window_len", + type=int, + default=8, + help="length of the CoTracker sliding window", +) +parser.add_argument( + "--remove_space_attn", + action="store_true", + default=False, + help="remove space attention from CoTracker", +) +parser.add_argument( + "--updateformer_hidden_size", + type=int, + default=384, + help="hidden dimension of the CoTracker transformer model", +) +parser.add_argument( + "--updateformer_num_heads", + type=int, + default=8, + help="number of heads of the CoTracker transformer model", +) +parser.add_argument( + "--updateformer_space_depth", + type=int, + default=6, + help="number of group attention layers in the CoTracker transformer model", +) +parser.add_argument( + "--updateformer_time_depth", + type=int, + default=6, + help="number of time attention layers in the CoTracker transformer model", +) +parser.add_argument( + "--model_stride", + type=int, + default=4, + help="stride of the CoTracker feature network", +) +parser.add_argument( + "--train_iters", + type=int, + default=4, + help="number of updates to the disparity field in each forward pass.", +) +parser.add_argument( + "--if_ARAP", + action="store_true", + default=False, + help="if using ARAP loss in the optimization", +) +parser.add_argument( + "--Embed3D", + action="store_true", + default=False, + help="if using the 3D embedding for image", +) +parser.add_argument( + "--Loss_W_feat", + type=float, + default=5e-1, + help="weight for the feature loss", +) +parser.add_argument( + "--Loss_W_cls", + type=float, + default=1e-4, + help="weight for the classification loss", +) +parser.add_argument( + "--depth_color", + action="store_true", + default=False, + help="if using the color for depth", +) +parser.add_argument( + "--flash_attn", + action="store_true", + default=False, + help="if using the flash attention", +) +parser.add_argument( + "--corr_dp", + action="store_true", + default=False, + help="if using the correlation of depth", +) +parser.add_argument( + "--support_grid", + type=int, + default=0, + help="if using the support grid", +) +parser.add_argument( + "--backbone", + type=str, + default="CNN", + help="backbone for the CoTracker feature network", +) +parser.add_argument( + "--enc_only", + action="store_true", + default=False, + help="if using the encoder only", +) +parser.add_argument( + "--init_match", + action="store_true", + default=False, + help="if using the initial matching", +) +parser.add_argument( + "--Nblock", + type=int, + default=4, + help="number of blocks in the CoTracker feature network", +) + +# configuration for training and saving +parser.add_argument( + "--nodes_num", type=int, default=1, help="number of nodes used for training." +) +parser.add_argument( + "--batch_size", type=int, default=1, help="batch size used during training." +) +parser.add_argument( + "--num_workers", type=int, default=6, help="number of dataloader workers" +) + +parser.add_argument( + "--mixed_precision", + action="store_true", default=False, + help="use mixed precision" +) +parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.") +parser.add_argument( + "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer." +) +parser.add_argument( + "--num_steps", type=int, default=200000, help="length of training schedule." +) +parser.add_argument( + "--evaluate_every_n_epoch", + type=int, + default=1, + help="evaluate during training after every n epochs, after every epoch by default", +) +parser.add_argument( + "--save_every_n_epoch", + type=int, + default=1, + help="save checkpoints during training after every n epochs, after every epoch by default", +) +parser.add_argument( + "--validate_at_start", + action="store_true", + default=False, + help="whether to run evaluation before training starts", +) +parser.add_argument( + "--save_freq", + type=int, + default=100, + help="frequency of trajectory visualization during training", +) +parser.add_argument( + "--eval_max_seq_len", + type=int, + default=1000, + help="maximum length of evaluation videos", +) +parser.add_argument( + "--debug", + action="store_true", + default=False, + help="if using the visibility mask", +) +parser.add_argument( + "--fine_tune", + action="store_true", + default=False, + help="if fine tune the model", +) +parser.add_argument( + "--aug_wind_sample", + action="store_true", + default=False, + help="if using the window sampling", +) +parser.add_argument( + "--use_video_flip", + action="store_true", + default=False, + help="if using the video flip", +) +parser.add_argument( + "--fix_backbone", + action="store_true", + default=False, + help="if fix the backbone", +) +parser.add_argument( + "--tune_backbone", + action="store_true", + default=False, + help="if tune the backbone", +) +parser.add_argument( + "--tune_arap", + action="store_true", + default=False, + help="if fix the backbone", +) +parser.add_argument( + "--tune_per_scene", + action="store_true", + default=False, + help="if tune one scene", +) +parser.add_argument( + "--use_hier_encoder", + action="store_true", + default=False, + help="if using the hierarchical encoder", +) +parser.add_argument( + "--scales", + type=int, + nargs="+", + default=[4, 2], + help="scales for the CoTracker feature network", +) + +# config for monocular depth estimator +parser.add_argument( + "--mde_name", type=str, default="zoedepth_nk", help="name of the MDE model" +) +args = parser.parse_args() +args_dict = vars(args) + +# ----------------------------------------------------------------------------- + +# merge the `args` to the `cfg` +cfg.merge_from_dict(args_dict) + +cfg.ckpt_path=os.path.join(args.save_dir, args.model_name ,args.exp_name) + diff --git a/config/ssm_cfg.py b/config/ssm_cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..b4595948e1e107ef8f659cef217d3fb7b3d1ad55 --- /dev/null +++ b/config/ssm_cfg.py @@ -0,0 +1,347 @@ +#python3.10 +"""Hierachical configuration for different pipelines, using `yacs` +(refered to https://github.com/rbgirshick/yacs) + +This projects contain the configuration for three aspects: + the regular config for experiment setting + + NOTE: Each experiment will be assigned a seperate working space, and the + intermediate results will be saved in the working space. The experimentes + folder structure is as follows: + { + /${ROOT_WORK_DIR}/ + └── ${PIPELINES_NAME}/ + ├── ${EXP_NAME}/ + ├── ${CHECKPOINT_DIR}/ + ├── ${RESULT_DIR}/ + ├── meta.json/ + └── ${LOG_DIR} + } + +""" + +import os, sys +from .yacs import CfgNode as CN +import argparse +import numpy as np + +# the parser for boolean +def bool_parser(arg): + """Parses an argument to boolean.""" + if isinstance(arg, bool): + return arg + if arg is None: + return False + if arg.lower() in ['1', 'true', 't', 'yes', 'y']: + return True + if arg.lower() in ['0', 'false', 'f', 'no', 'n']: + return False + raise ValueError(f'`{arg}` cannot be converted to boolean!') + +# ----------------------------------------------------------------------------- +# base cfg +# ----------------------------------------------------------------------------- +cfg = CN() + +# configuration for basic experiments +cfg.save_dir = "./checkpoints" +cfg.restore_ckpt = "" +cfg.model_name = "cotracker" +cfg.exp_name = "" + +# NOTE: configuration for datasets and augmentation +cfg.dataset_root = "" +cfg.eval_datasets = [""] +cfg.dont_use_augs = False +cfg.crop_size = [384, 512] +cfg.traj_per_sample = 384 +cfg.sample_vis_1st_frame = False +cfg.depth_near = 0.01 # meter +cfg.depth_far = 65.0 # meter +cfg.sequence_len = 24 + +# NOTE: configuration for network arch +cfg.hidden_size = 384 +cfg.mamba_depth = 8 +cfg.model_stride = 4 +cfg.train_iters = 4 +cfg.updateformer_num_heads = 8 +cfg.updateformer_hidden_size = 384 +cfg.if_ARAP = False +cfg.Embed3D = False +cfg.Loss_W_feat = 5e-1 +cfg.Loss_W_cls = 1e-4 +cfg.depth_color = False +cfg.flash_attn = False +cfg.corr_dp = True +cfg.support_grid = 0 +cfg.backbone = "CNN" +cfg.enc_only = False + +# NOTE: configuration for training and saving +cfg.nodes_num = 1 +cfg.batch_size = 1 +cfg.num_workers = 6 +cfg.mixed_precision = False +cfg.lr = 0.0005 +cfg.wdecay = 0.00001 +cfg.num_steps = 200000 +cfg.evaluate_every_n_epoch = 1 +cfg.save_every_n_epoch = 1 +cfg.validate_at_start = False +cfg.save_freq = 100 +cfg.eval_max_seq_len = 1000 +cfg.debug = False +cfg.fine_tune = False +cfg.aug_wind_sample = False +cfg.use_video_flip = False +cfg.fix_backbone = False +cfg.tune_backbone = False + + +# NOTE: configuration for monocular depth estimator +cfg.mde_name = "zoedepth_nk" + +# ----------------------------------------------------------------------------- + +# configurations for the command line +parser = argparse.ArgumentParser() + +# config for the basic experiment +parser.add_argument("--save_dir", default="./checkpoints", type=str ,help="path to save checkpoints") +parser.add_argument("--restore_ckpt", default="", help="path to restore a checkpoint") +parser.add_argument("--model_name", default="cotracker", help="model name") +parser.add_argument("--exp_name", type=str, default="base", + help="the name for experiment", + ) +# config for dataset and augmentation +parser.add_argument( + "--dataset_root", type=str, help="path lo all the datasets (train and eval)" +) +parser.add_argument( + "--eval_datasets", nargs="+", default=["things", "badja"], + help="what datasets to use for evaluation", +) +parser.add_argument( + "--dont_use_augs", action="store_true", default=False, + help="don't apply augmentations during training", +) +parser.add_argument( + "--crop_size", type=int, nargs="+", default=[384, 512], + help="crop videos to this resolution during training", +) +parser.add_argument( + "--traj_per_sample", type=int, default=768, + help="the number of trajectories to sample for training", +) +parser.add_argument( + "--depth_near", type=float, default=0.01, help="near plane depth" +) +parser.add_argument( + "--depth_far", type=float, default=65.0, help="far plane depth" +) +parser.add_argument( + "--sample_vis_1st_frame", + action="store_true", + default=False, + help="only sample trajectories with points visible on the first frame", +) +parser.add_argument( + "--sequence_len", type=int, default=24, help="train sequence length" +) +# configuration for network arch +parser.add_argument( + "--hidden_size", + type=int, + default=384, + help="hidden dimension of the CoTracker transformer model", +) +parser.add_argument( + "--mamba_depth", + type=int, + default=6, + help="number of group attention layers in the CoTracker transformer model", +) +parser.add_argument( + "--updateformer_num_heads", + type=int, + default=8, + help="number of heads of the CoTracker transformer model", +) +parser.add_argument( + "--updateformer_hidden_size", + type=int, + default=384, + help="hidden dimension of the CoTracker transformer model", +) +parser.add_argument( + "--model_stride", + type=int, + default=4, + help="stride of the CoTracker feature network", +) +parser.add_argument( + "--train_iters", + type=int, + default=4, + help="number of updates to the disparity field in each forward pass.", +) +parser.add_argument( + "--if_ARAP", + action="store_true", + default=False, + help="if using ARAP loss in the optimization", +) +parser.add_argument( + "--Embed3D", + action="store_true", + default=False, + help="if using the 3D embedding for image", +) +parser.add_argument( + "--Loss_W_feat", + type=float, + default=5e-1, + help="weight for the feature loss", +) +parser.add_argument( + "--Loss_W_cls", + type=float, + default=1e-4, + help="weight for the classification loss", +) +parser.add_argument( + "--depth_color", + action="store_true", + default=False, + help="if using the color for depth", +) +parser.add_argument( + "--flash_attn", + action="store_true", + default=False, + help="if using the flash attention", +) +parser.add_argument( + "--corr_dp", + action="store_true", + default=False, + help="if using the correlation of depth", +) +parser.add_argument( + "--support_grid", + type=int, + default=0, + help="if using the support grid", +) +parser.add_argument( + "--backbone", + type=str, + default="CNN", + help="backbone for the CoTracker feature network", +) +parser.add_argument( + "--enc_only", + action="store_true", + default=False, + help="if using the encoder only", +) + +# configuration for training and saving +parser.add_argument( + "--nodes_num", type=int, default=1, help="number of nodes used for training." +) +parser.add_argument( + "--batch_size", type=int, default=1, help="batch size used during training." +) +parser.add_argument( + "--num_workers", type=int, default=6, help="number of dataloader workers" +) + +parser.add_argument( + "--mixed_precision", + action="store_true", default=False, + help="use mixed precision" +) +parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.") +parser.add_argument( + "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer." +) +parser.add_argument( + "--num_steps", type=int, default=200000, help="length of training schedule." +) +parser.add_argument( + "--evaluate_every_n_epoch", + type=int, + default=1, + help="evaluate during training after every n epochs, after every epoch by default", +) +parser.add_argument( + "--save_every_n_epoch", + type=int, + default=1, + help="save checkpoints during training after every n epochs, after every epoch by default", +) +parser.add_argument( + "--validate_at_start", + action="store_true", + default=False, + help="whether to run evaluation before training starts", +) +parser.add_argument( + "--save_freq", + type=int, + default=100, + help="frequency of trajectory visualization during training", +) +parser.add_argument( + "--eval_max_seq_len", + type=int, + default=1000, + help="maximum length of evaluation videos", +) +parser.add_argument( + "--debug", + action="store_true", + default=False, + help="if using the visibility mask", +) +parser.add_argument( + "--fine_tune", + action="store_true", + default=False, + help="if fine tune the model", +) +parser.add_argument( + "--aug_wind_sample", + action="store_true", + default=False, + help="if using the window sampling", +) +parser.add_argument( + "--use_video_flip", + action="store_true", + default=False, + help="if using the video flip", +) +parser.add_argument( + "--fix_backbone", + action="store_true", + default=False, + help="if fix the backbone", +) + +# config for monocular depth estimator +parser.add_argument( + "--mde_name", type=str, default="zoedepth_nk", help="name of the MDE model" +) +args = parser.parse_args() +args_dict = vars(args) + +# ----------------------------------------------------------------------------- + +# merge the `args` to the `cfg` +cfg.merge_from_dict(args_dict) + +cfg.ckpt_path=os.path.join(args.save_dir, args.model_name ,args.exp_name) + diff --git a/config/yacs.py b/config/yacs.py new file mode 100644 index 0000000000000000000000000000000000000000..4c0632d20a89f1caec302570479269eb4078773e --- /dev/null +++ b/config/yacs.py @@ -0,0 +1,506 @@ +# Copyright (c) 2018-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +"""YACS -- Yet Another Configuration System is designed to be a simple +configuration management system for academic and industrial research +projects. + +See README.md for usage and examples. +""" + +import copy +import io +import logging +import os +from ast import literal_eval + +import yaml + + +# Flag for py2 and py3 compatibility to use when separate code paths are necessary +# When _PY2 is False, we assume Python 3 is in use +_PY2 = False + +# Filename extensions for loading configs from files +_YAML_EXTS = {"", ".yaml", ".yml"} +_PY_EXTS = {".py"} + +# py2 and py3 compatibility for checking file object type +# We simply use this to infer py2 vs py3 +try: + _FILE_TYPES = (file, io.IOBase) + _PY2 = True +except NameError: + _FILE_TYPES = (io.IOBase,) + +# CfgNodes can only contain a limited set of valid types +_VALID_TYPES = {tuple, list, str, int, float, bool} +# py2 allow for str and unicode +if _PY2: + _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 + +# Utilities for importing modules from file paths +if _PY2: + # imp is available in both py2 and py3 for now, but is deprecated in py3 + import imp +else: + import importlib.util + +logger = logging.getLogger(__name__) + + +class CfgNode(dict): + """ + CfgNode represents an internal node in the configuration tree. It's a simple + dict-like container that allows for attribute-based access to keys. + """ + + IMMUTABLE = "__immutable__" + DEPRECATED_KEYS = "__deprecated_keys__" + RENAMED_KEYS = "__renamed_keys__" + + def __init__(self, init_dict=None, key_list=None): + # Recursively convert nested dictionaries in init_dict into CfgNodes + init_dict = {} if init_dict is None else init_dict + key_list = [] if key_list is None else key_list + for k, v in init_dict.items(): + if type(v) is dict: + # Convert dict to CfgNode + init_dict[k] = CfgNode(v, key_list=key_list + [k]) + else: + # Check for valid leaf type or nested CfgNode + _assert_with_logging( + _valid_type(v, allow_cfg_node=True), + "Key {} with value {} is not a valid type; valid types: {}".format( + ".".join(key_list + [k]), type(v), _VALID_TYPES + ), + ) + super(CfgNode, self).__init__(init_dict) + # Manage if the CfgNode is frozen or not + self.__dict__[CfgNode.IMMUTABLE] = False + # Deprecated options + # If an option is removed from the code and you don't want to break existing + # yaml configs, you can add the full config key as a string to the set below. + self.__dict__[CfgNode.DEPRECATED_KEYS] = set() + # Renamed options + # If you rename a config option, record the mapping from the old name to the new + # name in the dictionary below. Optionally, if the type also changed, you can + # make the value a tuple that specifies first the renamed key and then + # instructions for how to edit the config file. + self.__dict__[CfgNode.RENAMED_KEYS] = { + # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow + # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow + # 'EXAMPLE.NEW.KEY', + # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or " + # + "'foo:bar' -> ('foo', 'bar')" + # ), + } + + def __getattr__(self, name): + if name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + if self.is_frozen(): + raise AttributeError( + "Attempted to set {} to {}, but CfgNode is immutable".format( + name, value + ) + ) + + _assert_with_logging( + name not in self.__dict__, + "Invalid attempt to modify internal CfgNode state: {}".format(name), + ) + _assert_with_logging( + _valid_type(value, allow_cfg_node=True), + "Invalid type {} for key {}; valid types = {}".format( + type(value), name, _VALID_TYPES + ), + ) + + self[name] = value + + def __str__(self): + def _indent(s_, num_spaces): + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + r = "" + s = [] + for k, v in sorted(self.items()): + seperator = "\n" if isinstance(v, CfgNode) else " " + attr_str = "{}:{}{}".format(str(k), seperator, str(v)) + attr_str = _indent(attr_str, 2) + s.append(attr_str) + r += "\n".join(s) + return r + + def __repr__(self): + return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) + + def dump(self): + """Dump to a string.""" + self_as_dict = _to_dict(self) + return yaml.safe_dump(self_as_dict) + + def merge_from_file(self, cfg_filename): + """Load a yaml config file and merge it this CfgNode.""" + with open(cfg_filename, "r") as f: + cfg = load_cfg(f) + self.merge_from_other_cfg(cfg) + + def merge_from_other_cfg(self, cfg_other): + """Merge `cfg_other` into this CfgNode.""" + _merge_a_into_b(cfg_other, self, self, []) + + def merge_from_list(self, cfg_list): + """Merge config (keys, values) in a list (e.g., from command line) into + this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`. + """ + _assert_with_logging( + len(cfg_list) % 2 == 0, + "Override list has odd length: {}; it must be a list of pairs".format( + cfg_list + ), + ) + root = self + for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): + if root.key_is_deprecated(full_key): + continue + if root.key_is_renamed(full_key): + root.raise_key_rename_error(full_key) + key_list = full_key.split(".") + d = self + for subkey in key_list[:-1]: + _assert_with_logging( + subkey in d, "Non-existent key: {}".format(full_key) + ) + d = d[subkey] + subkey = key_list[-1] + _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key)) + value = _decode_cfg_value(v) + value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key) + d[subkey] = value + def merge_from_dict(self, cfg_dict): + """Merge config (keys, values) in a dict into this CfgNode.""" + cfg_dict = cfg_dict.items() + cfg_list = [] + for pair in cfg_dict: + cfg_list.append(pair[0]) + cfg_list.append(pair[1]) + self.merge_from_list(cfg_list) + + def freeze(self): + """Make this CfgNode and all of its children immutable.""" + self._immutable(True) + + def defrost(self): + """Make this CfgNode and all of its children mutable.""" + self._immutable(False) + + def is_frozen(self): + """Return mutability.""" + return self.__dict__[CfgNode.IMMUTABLE] + + def _immutable(self, is_immutable): + """Set immutability to is_immutable and recursively apply the setting + to all nested CfgNodes. + """ + self.__dict__[CfgNode.IMMUTABLE] = is_immutable + # Recursively set immutable state + for v in self.__dict__.values(): + if isinstance(v, CfgNode): + v._immutable(is_immutable) + for v in self.values(): + if isinstance(v, CfgNode): + v._immutable(is_immutable) + + def clone(self): + """Recursively copy this CfgNode.""" + return copy.deepcopy(self) + + def register_deprecated_key(self, key): + """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated + keys a warning is generated and the key is ignored. + """ + _assert_with_logging( + key not in self.__dict__[CfgNode.DEPRECATED_KEYS], + "key {} is already registered as a deprecated key".format(key), + ) + self.__dict__[CfgNode.DEPRECATED_KEYS].add(key) + + def register_renamed_key(self, old_name, new_name, message=None): + """Register a key as having been renamed from `old_name` to `new_name`. + When merging a renamed key, an exception is thrown alerting to user to + the fact that the key has been renamed. + """ + _assert_with_logging( + old_name not in self.__dict__[CfgNode.RENAMED_KEYS], + "key {} is already registered as a renamed cfg key".format(old_name), + ) + value = new_name + if message: + value = (new_name, message) + self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value + + def key_is_deprecated(self, full_key): + """Test if a key is deprecated.""" + if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]: + logger.warning("Deprecated config key (ignoring): {}".format(full_key)) + return True + return False + + def key_is_renamed(self, full_key): + """Test if a key is renamed.""" + return full_key in self.__dict__[CfgNode.RENAMED_KEYS] + + def raise_key_rename_error(self, full_key): + new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key] + if isinstance(new_key, tuple): + msg = " Note: " + new_key[1] + new_key = new_key[0] + else: + msg = "" + raise KeyError( + "Key {} was renamed to {}; please update your config.{}".format( + full_key, new_key, msg + ) + ) + + +def load_cfg(cfg_file_obj_or_str): + """Load a cfg. Supports loading from: + - A file object backed by a YAML file + - A file object backed by a Python source file that exports an attribute + "cfg" that is either a dict or a CfgNode + - A string that can be parsed as valid YAML + """ + _assert_with_logging( + isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)), + "Expected first argument to be of type {} or {}, but it was {}".format( + _FILE_TYPES, str, type(cfg_file_obj_or_str) + ), + ) + if isinstance(cfg_file_obj_or_str, str): + return _load_cfg_from_yaml_str(cfg_file_obj_or_str) + elif isinstance(cfg_file_obj_or_str, _FILE_TYPES): + return _load_cfg_from_file(cfg_file_obj_or_str) + else: + raise NotImplementedError("Impossible to reach here (unless there's a bug)") + + +def _load_cfg_from_file(file_obj): + """Load a config from a YAML file or a Python source file.""" + _, file_extension = os.path.splitext(file_obj.name) + if file_extension in _YAML_EXTS: + return _load_cfg_from_yaml_str(file_obj.read()) + elif file_extension in _PY_EXTS: + return _load_cfg_py_source(file_obj.name) + else: + raise Exception( + "Attempt to load from an unsupported file type {}; " + "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS)) + ) + + +def _load_cfg_from_yaml_str(str_obj): + """Load a config from a YAML string encoding.""" + cfg_as_dict = yaml.safe_load(str_obj) + return CfgNode(cfg_as_dict) + + +def _load_cfg_py_source(filename): + """Load a config from a Python source file.""" + module = _load_module_from_file("yacs.config.override", filename) + _assert_with_logging( + hasattr(module, "cfg"), + "Python module from file {} must have 'cfg' attr".format(filename), + ) + VALID_ATTR_TYPES = {dict, CfgNode} + _assert_with_logging( + type(module.cfg) in VALID_ATTR_TYPES, + "Imported module 'cfg' attr must be in {} but is {} instead".format( + VALID_ATTR_TYPES, type(module.cfg) + ), + ) + if type(module.cfg) is dict: + return CfgNode(module.cfg) + else: + return module.cfg + + +def _to_dict(cfg_node): + """Recursively convert all CfgNode objects to dict objects.""" + + def convert_to_dict(cfg_node, key_list): + if not isinstance(cfg_node, CfgNode): + _assert_with_logging( + _valid_type(cfg_node), + "Key {} with value {} is not a valid type; valid types: {}".format( + ".".join(key_list), type(cfg_node), _VALID_TYPES + ), + ) + return cfg_node + else: + cfg_dict = dict(cfg_node) + for k, v in cfg_dict.items(): + cfg_dict[k] = convert_to_dict(v, key_list + [k]) + return cfg_dict + + return convert_to_dict(cfg_node, []) + + +def _valid_type(value, allow_cfg_node=False): + return (type(value) in _VALID_TYPES) or (allow_cfg_node and type(value) == CfgNode) + + +def _merge_a_into_b(a, b, root, key_list): + """Merge config dictionary a into config dictionary b, clobbering the + options in b whenever they are also specified in a. + """ + _assert_with_logging( + isinstance(a, CfgNode), + "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode), + ) + _assert_with_logging( + isinstance(b, CfgNode), + "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode), + ) + + for k, v_ in a.items(): + full_key = ".".join(key_list + [k]) + # a must specify keys that are in b + if k not in b: + if root.key_is_deprecated(full_key): + continue + elif root.key_is_renamed(full_key): + root.raise_key_rename_error(full_key) + else: + v = copy.deepcopy(v_) + v = _decode_cfg_value(v) + b.update({k: v}) + else: + v = copy.deepcopy(v_) + v = _decode_cfg_value(v) + v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) + + # Recursively merge dicts + if isinstance(v, CfgNode): + try: + _merge_a_into_b(v, b[k], root, key_list + [k]) + except BaseException: + raise + else: + b[k] = v + + +def _decode_cfg_value(v): + """Decodes a raw config value (e.g., from a yaml config files or command + line argument) into a Python object. + """ + # Configs parsed from raw yaml will contain dictionary keys that need to be + # converted to CfgNode objects + if isinstance(v, dict): + return CfgNode(v) + # All remaining processing is only applied to strings + if not isinstance(v, str): + return v + # Try to interpret `v` as a: + # string, number, tuple, list, dict, boolean, or None + try: + v = literal_eval(v) + # The following two excepts allow v to pass through when it represents a + # string. + # + # Longer explanation: + # The type of v is always a string (before calling literal_eval), but + # sometimes it *represents* a string and other times a data structure, like + # a list. In the case that v represents a string, what we got back from the + # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is + # ok with '"foo"', but will raise a ValueError if given 'foo'. In other + # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval + # will raise a SyntaxError. + except ValueError: + pass + except SyntaxError: + pass + return v + + +def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): + """Checks that `replacement`, which is intended to replace `original` is of + the right type. The type is correct if it matches exactly or is one of a few + cases in which the type can be easily coerced. + """ + original_type = type(original) + replacement_type = type(replacement) + + # The types must match (with some exceptions) + if replacement_type == original_type: + return replacement + + # Cast replacement from from_type to to_type if the replacement and original + # types match from_type and to_type + def conditional_cast(from_type, to_type): + if replacement_type == from_type and original_type == to_type: + return True, to_type(replacement) + else: + return False, None + + # Conditionally casts + # list <-> tuple + casts = [(tuple, list), (list, tuple)] + # For py2: allow converting from str (bytes) to a unicode string + try: + casts.append((str, unicode)) # noqa: F821 + except Exception: + pass + + for (from_type, to_type) in casts: + converted, converted_value = conditional_cast(from_type, to_type) + if converted: + return converted_value + + raise ValueError( + "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " + "key: {}".format( + original_type, replacement_type, original, replacement, full_key + ) + ) + + +def _assert_with_logging(cond, msg): + if not cond: + logger.debug(msg) + assert cond, msg + + +def _load_module_from_file(name, filename): + if _PY2: + module = imp.load_source(name, filename) + else: + spec = importlib.util.spec_from_file_location(name, filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..3541b61f94d6d8a3b97699c83780208ba4117f21 --- /dev/null +++ b/demo.py @@ -0,0 +1,206 @@ +import os +import sys +import argparse +from PIL import Image +project_root = os.path.dirname(os.path.abspath(__file__)) +try: + sys.path.append(os.path.join(project_root, "submodules/MoGe")) + os.environ["TOKENIZERS_PARALLELISM"] = "false" +except: + print("Warning: MoGe not found, motion transfer will not be applied") + +import torch +import numpy as np +from PIL import Image +import torchvision.transforms as transforms +from moviepy.editor import VideoFileClip +from diffusers.utils import load_image, load_video + +from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator +from submodules.MoGe.moge.model import MoGeModel + +def load_media(media_path, max_frames=49, transform=None): + """Load video or image frames and convert to tensor + + Args: + media_path (str): Path to video or image file + max_frames (int): Maximum number of frames to load + transform (callable): Transform to apply to frames + + Returns: + Tuple[torch.Tensor, float]: Video tensor [T,C,H,W] and FPS + """ + if transform is None: + transform = transforms.Compose([ + transforms.Resize((480, 720)), + transforms.ToTensor() + ]) + + # Determine if input is video or image based on extension + ext = os.path.splitext(media_path)[1].lower() + is_video = ext in ['.mp4', '.avi', '.mov'] + + if is_video: + frames = load_video(media_path) + fps = len(frames) / VideoFileClip(media_path).duration + else: + # Handle image as single frame + image = load_image(media_path) + frames = [image] + fps = 8 # Default fps for images + + # Ensure we have exactly max_frames + if len(frames) > max_frames: + frames = frames[:max_frames] + elif len(frames) < max_frames: + last_frame = frames[-1] + while len(frames) < max_frames: + frames.append(last_frame.copy()) + + # Convert frames to tensor + video_tensor = torch.stack([transform(frame) for frame in frames]) + + return video_tensor, fps, is_video + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input_path', type=str, default=None, help='Path to input video/image') + parser.add_argument('--prompt', type=str, required=True, help='Repaint prompt') + parser.add_argument('--output_dir', type=str, default='outputs', help='Output directory') + parser.add_argument('--gpu', type=int, default=0, help='GPU device ID') + parser.add_argument('--checkpoint_path', type=str, default="EXCAI/Diffusion-As-Shader", help='Path to model checkpoint') + parser.add_argument('--depth_path', type=str, default=None, help='Path to depth image') + parser.add_argument('--tracking_path', type=str, default=None, help='Path to tracking video, if provided, camera motion and object manipulation will not be applied') + parser.add_argument('--repaint', type=str, default=None, + help='Path to repainted image, or "true" to perform repainting, if not provided use original frame') + parser.add_argument('--camera_motion', type=str, default=None, + help='Camera motion mode: "trans " or "rot " or "spiral "') + parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right') + parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)') + parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge'], + help='Tracking method to use (spatracker or moge)') + args = parser.parse_args() + + # Load input video/image + video_tensor, fps, is_video = load_media(args.input_path) + if not is_video: + args.tracking_method = "moge" + print("Image input detected, using MoGe for tracking video generation.") + + # Initialize pipeline + das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir) + if args.tracking_method == "moge" and args.tracking_path is None: + moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device) + + # Repaint first frame if requested + repaint_img_tensor = None + if args.repaint: + if args.repaint.lower() == "true": + repainter = FirstFrameRepainter(gpu_id=args.gpu, output_dir=args.output_dir) + repaint_img_tensor = repainter.repaint( + video_tensor[0], + prompt=args.prompt, + depth_path=args.depth_path + ) + else: + repaint_img_tensor, _, _ = load_media(args.repaint) + repaint_img_tensor = repaint_img_tensor[0] # Take first frame + + # Generate tracking if not provided + tracking_tensor = None + pred_tracks = None + cam_motion = CameraMotionGenerator(args.camera_motion) + + if args.tracking_path: + tracking_tensor, _, _ = load_media(args.tracking_path) + + elif args.tracking_method == "moge": + # Use the first frame from previously loaded video_tensor + infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1] + H, W = infer_result["points"].shape[0:2] + pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3] + cam_motion.set_intr(infer_result["intrinsics"]) + + # Apply object motion if specified + if args.object_motion: + if args.object_mask is None: + raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask") + + # Load mask image + mask_image = Image.open(args.object_mask).convert('L') # Convert to grayscale + mask_image = transforms.Resize((480, 720))(mask_image) # Resize to match video size + # Convert to binary mask + mask = torch.from_numpy(np.array(mask_image) > 127) # Threshold at 127 + + motion_generator = ObjectMotionGenerator(device=das.device) + + pred_tracks = motion_generator.apply_motion( + pred_tracks=pred_tracks, + mask=mask, + motion_type=args.object_motion, + distance=50, + num_frames=49, + tracking_method="moge" + ) + print("Object motion applied") + + # Apply camera motion if specified + if args.camera_motion: + poses = cam_motion.get_default_motion() # shape: [49, 4, 4] + print("Camera motion applied") + else: + # no poses + poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1) + # change pred_tracks into screen coordinate + pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3) + pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3] + _, tracking_tensor = das.visualize_tracking_moge( + pred_tracks.cpu().numpy(), + infer_result["mask"].cpu().numpy() + ) + print('export tracking video via MoGe.') + + else: + # Generate tracking points + pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor) + + # Apply camera motion if specified + if args.camera_motion: + poses = cam_motion.get_default_motion() # shape: [49, 4, 4] + pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses) + print("Camera motion applied") + + # Apply object motion if specified + if args.object_motion: + if args.object_mask is None: + raise ValueError("Object motion specified but no mask provided. Please provide a mask image with --object_mask") + + # Load mask image + mask_image = Image.open(args.object_mask).convert('L') # Convert to grayscale + mask_image = transforms.Resize((480, 720))(mask_image) # Resize to match video size + # Convert to binary mask + mask = torch.from_numpy(np.array(mask_image) > 127) # Threshold at 127 + + motion_generator = ObjectMotionGenerator(device=das.device) + + pred_tracks = motion_generator.apply_motion( + pred_tracks=pred_tracks.squeeze(), + mask=mask, + motion_type=args.object_motion, + distance=50, + num_frames=49, + tracking_method="spatracker" + ).unsqueeze(0) + print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}") + + # Generate tracking tensor from modified tracks + _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts) + + das.apply_tracking( + video_tensor=video_tensor, + fps=8, + tracking_tensor=tracking_tensor, + img_cond_tensor=repaint_img_tensor, + prompt=args.prompt, + checkpoint_path=args.checkpoint_path + ) diff --git a/models/cogvideox_tracking.py b/models/cogvideox_tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..be6adce948b06e15f6f589dd906ac61d6123aa19 --- /dev/null +++ b/models/cogvideox_tracking.py @@ -0,0 +1,1020 @@ +from typing import Any, Dict, Optional, Tuple, Union, List, Callable + +import torch, os, math +from torch import nn +from PIL import Image +from tqdm import tqdm + +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel + +from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, CogVideoXPipelineOutput +from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline +from diffusers.pipelines.cogvideo.pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps +from transformers import T5EncoderModel, T5Tokenizer +from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.pipelines import DiffusionPipeline +from diffusers.models.modeling_utils import ModelMixin + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class CogVideoXTransformer3DModelTracking(CogVideoXTransformer3DModel, ModelMixin): + """ + Add tracking maps to the CogVideoX transformer model. + + Parameters: + num_tracking_blocks (`int`, defaults to `18`): + The number of tracking blocks to use. Must be less than or equal to num_layers. + """ + + def __init__( + self, + num_tracking_blocks: Optional[int] = 18, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + **kwargs + ): + super().__init__( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embed_dim=time_embed_dim, + text_embed_dim=text_embed_dim, + num_layers=num_layers, + dropout=dropout, + attention_bias=attention_bias, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + patch_size=patch_size, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + activation_fn=activation_fn, + timestep_activation_fn=timestep_activation_fn, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_rotary_positional_embeddings=use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + **kwargs + ) + + inner_dim = num_attention_heads * attention_head_dim + self.num_tracking_blocks = num_tracking_blocks + + # Ensure num_tracking_blocks is not greater than num_layers + if num_tracking_blocks > num_layers: + raise ValueError("num_tracking_blocks must be less than or equal to num_layers") + + # Create linear layers for combining hidden states and tracking maps + self.combine_linears = nn.ModuleList( + [nn.Linear(inner_dim, inner_dim) for _ in range(num_tracking_blocks)] + ) + + # Initialize weights of combine_linears to zero + for linear in self.combine_linears: + linear.weight.data.zero_() + linear.bias.data.zero_() + + # Create transformer blocks for processing tracking maps + self.transformer_blocks_copy = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + time_embed_dim=self.config.time_embed_dim, + dropout=self.config.dropout, + activation_fn=self.config.activation_fn, + attention_bias=self.config.attention_bias, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + ) + for _ in range(num_tracking_blocks) + ] + ) + + # For initial combination of hidden states and tracking maps + self.initial_combine_linear = nn.Linear(inner_dim, inner_dim) + self.initial_combine_linear.weight.data.zero_() + self.initial_combine_linear.bias.data.zero_() + + # Freeze all parameters + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze parameters that need to be trained + for linear in self.combine_linears: + for param in linear.parameters(): + param.requires_grad = True + + for block in self.transformer_blocks_copy: + for param in block.parameters(): + param.requires_grad = True + + for param in self.initial_combine_linear.parameters(): + param.requires_grad = True + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + tracking_maps: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + # Process tracking maps + prompt_embed = encoder_hidden_states.clone() + tracking_maps_hidden_states = self.patch_embed(prompt_embed, tracking_maps) + tracking_maps_hidden_states = self.embedding_dropout(tracking_maps_hidden_states) + del prompt_embed + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + tracking_maps = tracking_maps_hidden_states[:, text_seq_length:] + + # Combine hidden states and tracking maps initially + combined = hidden_states + tracking_maps + tracking_maps = self.initial_combine_linear(combined) + + # Process transformer blocks + for i in range(len(self.transformer_blocks)): + if self.training and self.gradient_checkpointing: + # Gradient checkpointing logic for hidden states + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.transformer_blocks[i]), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = self.transformer_blocks[i]( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if i < len(self.transformer_blocks_copy): + if self.training and self.gradient_checkpointing: + # Gradient checkpointing logic for tracking maps + tracking_maps, _ = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.transformer_blocks_copy[i]), + tracking_maps, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + tracking_maps, _ = self.transformer_blocks_copy[i]( + hidden_states=tracking_maps, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + # Combine hidden states and tracking maps + tracking_maps = self.combine_linears[i](tracking_maps) + hidden_states = hidden_states + tracking_maps + + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + try: + model = super().from_pretrained(pretrained_model_name_or_path, **kwargs) + print("Loaded DiffusionAsShader checkpoint directly.") + + for param in model.parameters(): + param.requires_grad = False + + for linear in model.combine_linears: + for param in linear.parameters(): + param.requires_grad = True + + for block in model.transformer_blocks_copy: + for param in block.parameters(): + param.requires_grad = True + + for param in model.initial_combine_linear.parameters(): + param.requires_grad = True + + return model + + except Exception as e: + print(f"Failed to load as DiffusionAsShader: {e}") + print("Attempting to load as CogVideoXTransformer3DModel and convert...") + + base_model = CogVideoXTransformer3DModel.from_pretrained(pretrained_model_name_or_path, **kwargs) + + config = dict(base_model.config) + config["num_tracking_blocks"] = kwargs.pop("num_tracking_blocks", 18) + + model = cls(**config) + model.load_state_dict(base_model.state_dict(), strict=False) + + model.initial_combine_linear.weight.data.zero_() + model.initial_combine_linear.bias.data.zero_() + + for linear in model.combine_linears: + linear.weight.data.zero_() + linear.bias.data.zero_() + + for i in range(model.num_tracking_blocks): + model.transformer_blocks_copy[i].load_state_dict(model.transformer_blocks[i].state_dict()) + + + for param in model.parameters(): + param.requires_grad = False + + for linear in model.combine_linears: + for param in linear.parameters(): + param.requires_grad = True + + for block in model.transformer_blocks_copy: + for param in block.parameters(): + param.requires_grad = True + + for param in model.initial_combine_linear.parameters(): + param.requires_grad = True + + return model + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Optional[Callable] = None, + safe_serialization: bool = True, + variant: Optional[str] = None, + max_shard_size: Union[int, str] = "5GB", + push_to_hub: bool = False, + **kwargs, + ): + super().save_pretrained( + save_directory, + is_main_process=is_main_process, + save_function=save_function, + safe_serialization=safe_serialization, + variant=variant, + max_shard_size=max_shard_size, + push_to_hub=push_to_hub, + **kwargs, + ) + + if is_main_process: + config_dict = dict(self.config) + config_dict.pop("_name_or_path", None) + config_dict.pop("_use_default_values", None) + config_dict["_class_name"] = "CogVideoXTransformer3DModelTracking" + config_dict["num_tracking_blocks"] = self.num_tracking_blocks + + os.makedirs(save_directory, exist_ok=True) + with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f: + import json + json.dump(config_dict, f, indent=2) + +class CogVideoXPipelineTracking(CogVideoXPipeline, DiffusionPipeline): + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModelTracking, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__(tokenizer, text_encoder, vae, transformer, scheduler) + + if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking): + raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking") + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + tracking_maps: Optional[torch.Tensor] = None, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + tracking_maps_latent = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + tracking_maps=tracking_maps_latent, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + return CogVideoXPipelineOutput(frames=video) + +class CogVideoXImageToVideoPipelineTracking(CogVideoXImageToVideoPipeline, DiffusionPipeline): + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModelTracking, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__(tokenizer, text_encoder, vae, transformer, scheduler) + + if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking): + raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking") + + # 打印transformer blocks的数量 + print(f"Number of transformer blocks: {len(self.transformer.transformer_blocks)}") + print(f"Number of tracking transformer blocks: {len(self.transformer.transformer_blocks_copy)}") + self.transformer = torch.compile(self.transformer) + + @torch.no_grad() + def __call__( + self, + image: Union[torch.Tensor, Image.Image], + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + tracking_maps: Optional[torch.Tensor] = None, + tracking_image: Optional[torch.Tensor] = None, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + # Most of the implementation remains the same as the parent class + # We will modify the parts that need to handle tracking_maps + + # 1. Check inputs and set default values + self.check_inputs( + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + del negative_prompt_embeds + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + tracking_image = self.video_processor.preprocess(tracking_image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + if self.transformer.config.in_channels != 16: + latent_channels = self.transformer.config.in_channels // 2 + else: + latent_channels = self.transformer.config.in_channels + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + del image + + _, tracking_image_latents = self.prepare_latents( + tracking_image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents=None, + ) + del tracking_image + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + del latent_image_input + + # Handle tracking maps + if tracking_maps is not None: + latents_tracking_image = torch.cat([tracking_image_latents] * 2) if do_classifier_free_guidance else tracking_image_latents + tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps + tracking_maps_input = torch.cat([tracking_maps_input, latents_tracking_image], dim=2) + del latents_tracking_image + else: + tracking_maps_input = None + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # Predict noise + self.transformer.to(dtype=latent_model_input.dtype) + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + tracking_maps=tracking_maps_input, + return_dict=False, + )[0] + del latent_model_input + if tracking_maps_input is not None: + del tracking_maps_input + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + del noise_pred_uncond, noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + del noise_pred + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 9. Post-processing + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) + +class CogVideoXVideoToVideoPipelineTracking(CogVideoXVideoToVideoPipeline, DiffusionPipeline): + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModelTracking, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__(tokenizer, text_encoder, vae, transformer, scheduler) + + if not isinstance(self.transformer, CogVideoXTransformer3DModelTracking): + raise ValueError("The transformer in this pipeline must be of type CogVideoXTransformer3DModelTracking") + + @torch.no_grad() + def __call__( + self, + video: List[Image.Image] = None, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + strength: float = 0.8, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + tracking_maps: Optional[torch.Tensor] = None, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + strength=strength, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + video=video, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=prompt_embeds.dtype) + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + latent_timestep, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + tracking_maps=tracking_maps_input, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) + diff --git a/models/pipelines.py b/models/pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..1c42826b4c675caa059e0faa7f5a63a31cb02466 --- /dev/null +++ b/models/pipelines.py @@ -0,0 +1,1040 @@ +import os +import sys +import math +from tqdm import tqdm +from PIL import Image, ImageDraw +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +try: + sys.path.append(os.path.join(project_root, "submodules/MoGe")) + os.environ["TOKENIZERS_PARALLELISM"] = "false" +except: + print("Warning: MoGe not found, motion transfer will not be applied") + +import torch +import numpy as np +from PIL import Image +import torchvision.transforms as transforms +from diffusers import FluxControlPipeline, CogVideoXDPMScheduler +from diffusers.utils import export_to_video, load_image, load_video + +from models.spatracker.predictor import SpaTrackerPredictor +from models.spatracker.utils.visualizer import Visualizer +from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking + +from submodules.MoGe.moge.model import MoGeModel +from image_gen_aux import DepthPreprocessor +from moviepy.editor import ImageSequenceClip + +class DiffusionAsShaderPipeline: + def __init__(self, gpu_id=0, output_dir='outputs'): + """Initialize MotionTransfer class + + Args: + gpu_id (int): GPU device ID + output_dir (str): Output directory path + """ + # video parameters + self.max_depth = 65.0 + self.fps = 8 + + # camera parameters + self.camera_motion=None + self.fov=55 + + # device + self.device = f"cuda:{gpu_id}" + torch.cuda.set_device(gpu_id) + + # files + self.output_dir = output_dir + os.makedirs(output_dir, exist_ok=True) + + # Initialize transform + self.transform = transforms.Compose([ + transforms.Resize((480, 720)), + transforms.ToTensor() + ]) + + @torch.no_grad() + def _infer( + self, + prompt: str, + model_path: str, + tracking_tensor: torch.Tensor = None, + image_tensor: torch.Tensor = None, # [C,H,W] in range [0,1] + output_path: str = "./output.mp4", + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int = 1, + dtype: torch.dtype = torch.bfloat16, + fps: int = 24, + seed: int = 42, + ): + """ + Generates a video based on the given prompt and saves it to the specified path. + + Parameters: + - prompt (str): The description of the video to be generated. + - model_path (str): The path of the pre-trained model to be used. + - tracking_tensor (torch.Tensor): Tracking video tensor [T, C, H, W] in range [0,1] + - image_tensor (torch.Tensor): Input image tensor [C, H, W] in range [0,1] + - output_path (str): The path where the generated video will be saved. + - num_inference_steps (int): Number of steps for the inference process. + - guidance_scale (float): The scale for classifier-free guidance. + - num_videos_per_prompt (int): Number of videos to generate per prompt. + - dtype (torch.dtype): The data type for computation. + - seed (int): The seed for reproducibility. + """ + from transformers import T5EncoderModel, T5Tokenizer + from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler + from models.cogvideox_tracking import CogVideoXTransformer3DModelTracking + + vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae") + text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder") + tokenizer = T5Tokenizer.from_pretrained(model_path, subfolder="tokenizer") + transformer = CogVideoXTransformer3DModelTracking.from_pretrained(model_path, subfolder="transformer") + scheduler = CogVideoXDDIMScheduler.from_pretrained(model_path, subfolder="scheduler") + + pipe = CogVideoXImageToVideoPipelineTracking( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler + ) + + # Convert tensor to PIL Image + image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + image = Image.fromarray(image_np) + height, width = image.height, image.width + + pipe.transformer.eval() + pipe.text_encoder.eval() + pipe.vae.eval() + + # Process tracking tensor + tracking_maps = tracking_tensor.float() # [T, C, H, W] + tracking_maps = tracking_maps.to(device=self.device, dtype=dtype) + tracking_first_frame = tracking_maps[0:1] # Get first frame as [1, C, H, W] + height, width = tracking_first_frame.shape[2], tracking_first_frame.shape[3] + + # 2. Set Scheduler. + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + + pipe.to(self.device, dtype=dtype) + # pipe.enable_sequential_cpu_offload() + + pipe.vae.enable_slicing() + pipe.vae.enable_tiling() + pipe.transformer.eval() + pipe.text_encoder.eval() + pipe.vae.eval() + + pipe.transformer.gradient_checkpointing = False + + print("Encoding tracking maps") + tracking_maps = tracking_maps.unsqueeze(0) # [B, T, C, H, W] + tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] + tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist + tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor + tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + + # 4. Generate the video frames based on the prompt. + video_generate = pipe( + prompt=prompt, + negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.", + image=image, + num_videos_per_prompt=num_videos_per_prompt, + num_inference_steps=num_inference_steps, + num_frames=49, + use_dynamic_cfg=True, + guidance_scale=guidance_scale, + generator=torch.Generator().manual_seed(seed), + tracking_maps=tracking_maps, + tracking_image=tracking_first_frame, + height=height, + width=width, + ).frames[0] + + # 5. Export the generated frames to a video file. fps must be 8 for original video. + output_path = output_path if output_path else f"result.mp4" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + export_to_video(video_generate, output_path, fps=fps) + + #========== camera parameters ==========# + + def _set_camera_motion(self, camera_motion): + self.camera_motion = camera_motion + + def _get_intr(self, fov, H=480, W=720): + fov_rad = math.radians(fov) + focal_length = (W / 2) / math.tan(fov_rad / 2) + + cx = W / 2 + cy = H / 2 + + intr = torch.tensor([ + [focal_length, 0, cx], + [0, focal_length, cy], + [0, 0, 1] + ], dtype=torch.float32) + + return intr + + def _apply_poses(self, pts, intr, poses): + """ + Args: + pts (torch.Tensor): pointclouds coordinates [T, N, 3] + intr (torch.Tensor): camera intrinsics [T, 3, 3] + poses (numpy.ndarray): camera poses [T, 4, 4] + """ + poses = torch.from_numpy(poses).float().to(self.device) + + T, N, _ = pts.shape + ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float) + pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3) + pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3) + pts_cam[:,:, :3] /= pts[:, :, 2:3] + + # to homogeneous + pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4) + + if poses.shape[0] == 1: + poses = poses.repeat(T, 1, 1) + elif poses.shape[0] != T: + raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})") + + pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3) + + pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3) + pts_proj[:, :, :2] /= pts_proj[:, :, 2:3] + + return pts_proj + + def apply_traj_on_tracking(self, pred_tracks, camera_motion=None, fov=55, frame_num=49): + intr = self._get_intr(fov).unsqueeze(0).repeat(frame_num, 1, 1).to(self.device) + tracking_pts = self._apply_poses(pred_tracks.squeeze(), intr, camera_motion).unsqueeze(0) + return tracking_pts + + ##============= SpatialTracker =============## + + def generate_tracking_spatracker(self, video_tensor, density=70): + """Generate tracking video + + Args: + video_tensor (torch.Tensor): Input video tensor + + Returns: + str: Path to tracking video + """ + print("Loading tracking models...") + # Load tracking model + tracker = SpaTrackerPredictor( + checkpoint=os.path.join(project_root, 'checkpoints/spatracker/spaT_final.pth'), + interp_shape=(384, 576), + seq_length=12 + ).to(self.device) + + # Load depth model + self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti") + self.depth_preprocessor.to(self.device) + + try: + video = video_tensor.unsqueeze(0).to(self.device) + + video_depths = [] + for i in range(video_tensor.shape[0]): + frame = (video_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + depth = self.depth_preprocessor(Image.fromarray(frame))[0] + depth_tensor = transforms.ToTensor()(depth) # [1, H, W] + video_depths.append(depth_tensor) + video_depth = torch.stack(video_depths, dim=0).to(self.device) + # print("Video depth shape:", video_depth.shape) + + segm_mask = np.ones((480, 720), dtype=np.uint8) + + pred_tracks, pred_visibility, T_Firsts = tracker( + video * 255, + video_depth=video_depth, + grid_size=density, + backward_tracking=False, + depth_predictor=None, + grid_query_frame=0, + segm_mask=torch.from_numpy(segm_mask)[None, None].to(self.device), + wind_length=12, + progressive_tracking=False + ) + + return pred_tracks, pred_visibility, T_Firsts + + finally: + # Clean up GPU memory + del tracker, self.depth_preprocessor + torch.cuda.empty_cache() + + def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True): + video = video.unsqueeze(0).to(self.device) + vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0) + msk_query = (T_Firsts == 0) + pred_tracks = pred_tracks[:,:,msk_query.squeeze()] + pred_visibility = pred_visibility[:,:,msk_query.squeeze()] + + tracking_video = vis.visualize(video=video, tracks=pred_tracks, + visibility=pred_visibility, save_video=False, + filename="temp") + + tracking_video = tracking_video.squeeze(0) # [T, C, H, W] + wide_list = list(tracking_video.unbind(0)) + wide_list = [wide.permute(1, 2, 0).cpu().numpy() for wide in wide_list] + clip = ImageSequenceClip(wide_list, fps=self.fps) + + tracking_path = None + if save_tracking: + try: + tracking_path = os.path.join(self.output_dir, "tracking_video.mp4") + clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None) + print(f"Video saved to {tracking_path}") + except Exception as e: + print(f"Warning: Failed to save tracking video: {e}") + tracking_path = None + + # Convert tracking_video back to tensor in range [0,1] + tracking_frames = np.array(list(clip.iter_frames())) / 255.0 + tracking_video = torch.from_numpy(tracking_frames).permute(0, 3, 1, 2).float() + + return tracking_path, tracking_video + + ##============= MoGe =============## + + def valid_mask(self, pixels, W, H): + """Check if pixels are within valid image bounds + + Args: + pixels (numpy.ndarray): Pixel coordinates of shape [N, 2] + W (int): Image width + H (int): Image height + + Returns: + numpy.ndarray: Boolean mask of valid pixels + """ + return ((pixels[:, 0] >= 0) & (pixels[:, 0] < W) & (pixels[:, 1] > 0) & \ + (pixels[:, 1] < H)) + + def sort_points_by_depth(self, points, depths): + """Sort points by depth values + + Args: + points (numpy.ndarray): Points array of shape [N, 2] + depths (numpy.ndarray): Depth values of shape [N] + + Returns: + tuple: (sorted_points, sorted_depths, sort_index) + """ + # Combine points and depths into a single array for sorting + combined = np.hstack((points, depths[:, None])) # Nx3 (points + depth) + # Sort by depth (last column) in descending order + sort_index = combined[:, -1].argsort()[::-1] + sorted_combined = combined[sort_index] + # Split back into points and depths + sorted_points = sorted_combined[:, :-1] + sorted_depths = sorted_combined[:, -1] + return sorted_points, sorted_depths, sort_index + + def draw_rectangle(self, rgb, coord, side_length, color=(255, 0, 0)): + """Draw a rectangle on the image + + Args: + rgb (PIL.Image): Image to draw on + coord (tuple): Center coordinates (x, y) + side_length (int): Length of rectangle sides + color (tuple): RGB color tuple + """ + draw = ImageDraw.Draw(rgb) + # Calculate the bounding box of the rectangle + left_up_point = (coord[0] - side_length//2, coord[1] - side_length//2) + right_down_point = (coord[0] + side_length//2, coord[1] + side_length//2) + color = tuple(list(color)) + + draw.rectangle( + [left_up_point, right_down_point], + fill=tuple(color), + outline=tuple(color), + ) + + def visualize_tracking_moge(self, points, mask, save_tracking=True): + """Visualize tracking results from MoGe model + + Args: + points (numpy.ndarray): Points array of shape [T, H, W, 3] + mask (numpy.ndarray): Binary mask of shape [H, W] + save_tracking (bool): Whether to save tracking video + + Returns: + tuple: (tracking_path, tracking_video) + - tracking_path (str): Path to saved tracking video, None if save_tracking is False + - tracking_video (torch.Tensor): Tracking visualization tensor of shape [T, C, H, W] in range [0,1] + """ + # Create color array + T, H, W, _ = points.shape + colors = np.zeros((H, W, 3), dtype=np.uint8) + + # Set R channel - based on x coordinates (smaller on the left) + colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1)) + + # Set G channel - based on y coordinates (smaller on the top) + colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T + + # Set B channel - based on depth + z_values = points[0, :, :, 2] # get z values + inv_z = 1 / z_values # calculate 1/z + # Calculate 2% and 98% percentiles + p2 = np.percentile(inv_z, 2) + p98 = np.percentile(inv_z, 98) + # Normalize to [0,1] range + normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1) + colors[:, :, 2] = (normalized_z * 255).astype(np.uint8) + colors = colors.astype(np.uint8) + # colors = colors * mask[..., None] + # points = points * mask[None, :, :, None] + + points = points.reshape(T, -1, 3) + colors = colors.reshape(-1, 3) + + # Initialize list to store frames + frames = [] + + for i, pts_i in enumerate(tqdm(points)): + pixels, depths = pts_i[..., :2], pts_i[..., 2] + pixels[..., 0] = pixels[..., 0] * W + pixels[..., 1] = pixels[..., 1] * H + pixels = pixels.astype(int) + + valid = self.valid_mask(pixels, W, H) + frame_rgb = colors[valid] + pixels = pixels[valid] + depths = depths[valid] + + img = Image.fromarray(np.uint8(np.zeros([H, W, 3])), mode="RGB") + sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths) + step = 1 + sorted_pixels = sorted_pixels[::step] + sorted_rgb = frame_rgb[sort_index][::step] + + for j in range(sorted_pixels.shape[0]): + self.draw_rectangle( + img, + coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]), + side_length=2, + color=sorted_rgb[j], + ) + frames.append(np.array(img)) + + # Convert frames to video tensor in range [0,1] + tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0 + + tracking_path = None + if save_tracking: + try: + tracking_path = os.path.join(self.output_dir, "tracking_video_moge.mp4") + # Convert back to uint8 for saving + uint8_frames = [frame.astype(np.uint8) for frame in frames] + clip = ImageSequenceClip(uint8_frames, fps=self.fps) + clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None) + print(f"Video saved to {tracking_path}") + except Exception as e: + print(f"Warning: Failed to save tracking video: {e}") + tracking_path = None + + return tracking_path, tracking_video + + def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None): + """Generate final video with motion transfer + + Args: + video_tensor (torch.Tensor): Input video tensor [T,C,H,W] + fps (float): Input video FPS + tracking_tensor (torch.Tensor): Tracking video tensor [T,C,H,W] + image_tensor (torch.Tensor): First frame tensor [C,H,W] to use for generation + prompt (str): Generation prompt + checkpoint_path (str): Path to model checkpoint + """ + self.fps = fps + + # Use first frame if no image provided + if img_cond_tensor is None: + img_cond_tensor = video_tensor[0] + + # Generate final video + final_output = os.path.join(os.path.abspath(self.output_dir), "result.mp4") + self._infer( + prompt=prompt, + model_path=checkpoint_path, + tracking_tensor=tracking_tensor, + image_tensor=img_cond_tensor, + output_path=final_output, + num_inference_steps=50, + guidance_scale=6.0, + dtype=torch.bfloat16, + fps=self.fps + ) + print(f"Final video generated successfully at: {final_output}") + + def _set_object_motion(self, motion_type): + """Set object motion type + + Args: + motion_type (str): Motion direction ('up', 'down', 'left', 'right') + """ + self.object_motion = motion_type + +class FirstFrameRepainter: + def __init__(self, gpu_id=0, output_dir='outputs'): + """Initialize FirstFrameRepainter + + Args: + gpu_id (int): GPU device ID + output_dir (str): Output directory path + """ + self.device = f"cuda:{gpu_id}" + self.output_dir = output_dir + self.max_depth = 65.0 + os.makedirs(output_dir, exist_ok=True) + + def repaint(self, image_tensor, prompt, depth_path=None, method="dav"): + """Repaint first frame using Flux + + Args: + image_tensor (torch.Tensor): Input image tensor [C,H,W] + prompt (str): Repaint prompt + depth_path (str): Path to depth image + method (str): depth estimator, "moge" or "dav" or "zoedepth" + + Returns: + torch.Tensor: Repainted image tensor [C,H,W] + """ + print("Loading Flux model...") + # Load Flux model + flux_pipe = FluxControlPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Depth-dev", + torch_dtype=torch.bfloat16 + ).to(self.device) + + # Get depth map + if depth_path is None: + if method == "moge": + self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(self.device) + depth_map = self.moge_model.infer(image_tensor.to(self.device))["depth"] + depth_map = torch.clamp(depth_map, max=self.max_depth) + depth_normalized = 1.0 - (depth_map / self.max_depth) + depth_rgb = (depth_normalized * 255).cpu().numpy().astype(np.uint8) + control_image = Image.fromarray(depth_rgb).convert("RGB") + elif method == "zoedepth": + self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti") + self.depth_preprocessor.to(self.device) + image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + control_image = self.depth_preprocessor(Image.fromarray(image_np))[0].convert("RGB") + control_image = control_image.point(lambda x: 255 - x) # the zoedepth depth is inverted + else: + self.depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf") + self.depth_preprocessor.to(self.device) + image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + control_image = self.depth_preprocessor(Image.fromarray(image_np))[0].convert("RGB") + else: + control_image = Image.open(depth_path).convert("RGB") + + try: + repainted_image = flux_pipe( + prompt=prompt, + control_image=control_image, + height=480, + width=720, + num_inference_steps=30, + guidance_scale=7.5, + ).images[0] + + # Save repainted image + repainted_image.save(os.path.join(self.output_dir, "temp_repainted.png")) + + # Convert PIL Image to tensor + transform = transforms.Compose([ + transforms.ToTensor() + ]) + repainted_tensor = transform(repainted_image) + + return repainted_tensor + + finally: + # Clean up GPU memory + del flux_pipe + if method == "moge": + del self.moge_model + else: + del self.depth_preprocessor + torch.cuda.empty_cache() + +class CameraMotionGenerator: + def __init__(self, motion_type, frame_num=49, H=480, W=720, fx=None, fy=None, fov=55, device='cuda'): + self.motion_type = motion_type + self.frame_num = frame_num + self.fov = fov + self.device = device + self.W = W + self.H = H + self.intr = torch.tensor([ + [0, 0, W / 2], + [0, 0, H / 2], + [0, 0, 1] + ], dtype=torch.float32, device=device) + # if fx, fy not provided + if not fx or not fy: + fov_rad = math.radians(fov) + fx = fy = (W / 2) / math.tan(fov_rad / 2) + + self.intr[0, 0] = fx + self.intr[1, 1] = fy + + def _apply_poses(self, pts, poses): + """ + Args: + pts (torch.Tensor): pointclouds coordinates [T, N, 3] + intr (torch.Tensor): camera intrinsics [T, 3, 3] + poses (numpy.ndarray): camera poses [T, 4, 4] + """ + if isinstance(poses, np.ndarray): + poses = torch.from_numpy(poses) + + intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float) + T, N, _ = pts.shape + ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float) + pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3) + pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3) + pts_cam[:,:, :3] *= pts[:, :, 2:3] + + # to homogeneous + pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4) + + if poses.shape[0] == 1: + poses = poses.repeat(T, 1, 1) + elif poses.shape[0] != T: + raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})") + + poses = poses.to(torch.float).to(self.device) + pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3) + pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3) + pts_proj[:, :, :2] /= pts_proj[:, :, 2:3] + + return pts_proj + + def w2s(self, pts, poses): + if isinstance(poses, np.ndarray): + poses = torch.from_numpy(poses) + assert poses.shape[0] == self.frame_num + poses = poses.to(torch.float32).to(self.device) + T, N, _ = pts.shape # (T, N, 3) + intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1) + # Step 1: 扩展点的维度,使其变成 (T, N, 4),最后一维填充1 (齐次坐标) + ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype) + points_world_h = torch.cat([pts, ones], dim=-1) + points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1)) + points_camera = points_camera_h[:, :3, :].permute(0, 2, 1) + + points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1)) + + uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3] + + # Step 5: 提取深度 (Z) 并拼接 + depth = points_camera[:, :, 2:3] # (T, N, 1) + uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3) + + return uvd # 屏幕坐标 + 深度 (T, N, 3) + + def apply_motion_on_pts(self, pts, camera_motion): + tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0) + return tracking_pts + + def set_intr(self, K): + if isinstance(K, np.ndarray): + K = torch.from_numpy(K) + self.intr = K.to(self.device) + + def rot_poses(self, angle, axis='y'): + """Generate a single rotation matrix + + Args: + angle (float): Rotation angle in degrees + axis (str): Rotation axis ('x', 'y', or 'z') + + Returns: + torch.Tensor: Single rotation matrix [4, 4] + """ + angle_rad = math.radians(angle) + cos_theta = torch.cos(torch.tensor(angle_rad)) + sin_theta = torch.sin(torch.tensor(angle_rad)) + + if axis == 'x': + rot_mat = torch.tensor([ + [1, 0, 0, 0], + [0, cos_theta, -sin_theta, 0], + [0, sin_theta, cos_theta, 0], + [0, 0, 0, 1] + ], dtype=torch.float32) + elif axis == 'y': + rot_mat = torch.tensor([ + [cos_theta, 0, sin_theta, 0], + [0, 1, 0, 0], + [-sin_theta, 0, cos_theta, 0], + [0, 0, 0, 1] + ], dtype=torch.float32) + elif axis == 'z': + rot_mat = torch.tensor([ + [cos_theta, -sin_theta, 0, 0], + [sin_theta, cos_theta, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1] + ], dtype=torch.float32) + else: + raise ValueError("Invalid axis value. Choose 'x', 'y', or 'z'.") + + return rot_mat.to(self.device) + + def trans_poses(self, dx, dy, dz): + """ + params: + - dx: float, displacement along x axis。 + - dy: float, displacement along y axis。 + - dz: float, displacement along z axis。 + + ret: + - matrices: torch.Tensor + """ + trans_mats = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1) # (n, 4, 4) + + delta_x = dx / (self.frame_num - 1) + delta_y = dy / (self.frame_num - 1) + delta_z = dz / (self.frame_num - 1) + + for i in range(self.frame_num): + trans_mats[i, 0, 3] = i * delta_x + trans_mats[i, 1, 3] = i * delta_y + trans_mats[i, 2, 3] = i * delta_z + + return trans_mats.to(self.device) + + + def _look_at(self, camera_position, target_position): + # look at direction + direction = target_position - camera_position + direction /= np.linalg.norm(direction) + # calculate rotation matrix + up = np.array([0, 1, 0]) + right = np.cross(up, direction) + right /= np.linalg.norm(right) + up = np.cross(direction, right) + rotation_matrix = np.vstack([right, up, direction]) + rotation_matrix = np.linalg.inv(rotation_matrix) + return rotation_matrix + + def spiral_poses(self, radius, forward_ratio = 0.5, backward_ratio = 0.5, rotation_times = 0.1, look_at_times = 0.5): + """Generate spiral camera poses + + Args: + radius (float): Base radius of the spiral + forward_ratio (float): Scale factor for forward motion + backward_ratio (float): Scale factor for backward motion + rotation_times (float): Number of rotations to complete + look_at_times (float): Scale factor for look-at point distance + + Returns: + torch.Tensor: Camera poses of shape [num_frames, 4, 4] + """ + # Generate spiral trajectory + t = np.linspace(0, 1, self.frame_num) + r = np.sin(np.pi * t) * radius * rotation_times + theta = 2 * np.pi * t + + # Calculate camera positions + # Limit y motion for better floor/sky view + y = r * np.cos(theta) * 0.3 + x = r * np.sin(theta) + z = -r + z[z < 0] *= forward_ratio + z[z > 0] *= backward_ratio + + # Set look-at target + target_pos = np.array([0, 0, radius * look_at_times]) + cam_pos = np.vstack([x, y, z]).T + cam_poses = [] + + for pos in cam_pos: + rot_mat = self._look_at(pos, target_pos) + trans_mat = np.eye(4) + trans_mat[:3, :3] = rot_mat + trans_mat[:3, 3] = pos + cam_poses.append(trans_mat[None]) + + camera_poses = np.concatenate(cam_poses, axis=0) + return torch.from_numpy(camera_poses).to(self.device) + + def rot(self, pts, angle, axis): + """ + pts: torch.Tensor, (T, N, 2) + """ + rot_mats = self.rot_poses(angle, axis) + pts = self.apply_motion_on_pts(pts, rot_mats) + return pts + + def trans(self, pts, dx, dy, dz): + if pts.shape[-1] != 3: + raise ValueError("points should be in the 3d coordinate.") + trans_mats = self.trans_poses(dx, dy, dz) + pts = self.apply_motion_on_pts(pts, trans_mats) + return pts + + def spiral(self, pts, radius): + spiral_poses = self.spiral_poses(radius) + pts = self.apply_motion_on_pts(pts, spiral_poses) + return pts + + def get_default_motion(self): + """Parse motion parameters and generate corresponding motion matrices + + Supported formats: + - trans [start_frame] [end_frame]: Translation motion + - rot [start_frame] [end_frame]: Rotation motion + - spiral [start_frame] [end_frame]: Spiral motion + + Multiple transformations can be combined using semicolon (;) as separator: + e.g., "trans 0 0 0.5 0 30; rot x 25 0 30; trans 0.1 0 0 30 48" + + Note: + - start_frame and end_frame are optional + - frame range: 0-49 (will be clamped to this range) + - if not specified, defaults to 0-49 + - frames after end_frame will maintain the final transformation + - for combined transformations, they are applied in sequence + + Returns: + torch.Tensor: Motion matrices [num_frames, 4, 4] + """ + if not isinstance(self.motion_type, str): + raise ValueError(f'camera_motion must be a string, but got {type(self.motion_type)}') + + # Split combined transformations + transform_sequences = [s.strip() for s in self.motion_type.split(';')] + + # Initialize the final motion matrices + final_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1) + + # Process each transformation in sequence + for transform in transform_sequences: + params = transform.lower().split() + if not params: + continue + + motion_type = params[0] + + # Default frame range + start_frame = 0 + end_frame = 48 # 49 frames in total (0-48) + + if motion_type == 'trans': + # Parse translation parameters + if len(params) not in [4, 6]: + raise ValueError(f"trans motion requires 3 or 5 parameters: 'trans ' or 'trans ', got: {transform}") + + dx, dy, dz = map(float, params[1:4]) + + if len(params) == 6: + start_frame = max(0, min(48, int(params[4]))) + end_frame = max(0, min(48, int(params[5]))) + if start_frame > end_frame: + start_frame, end_frame = end_frame, start_frame + + # Generate current transformation + current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1) + for frame_idx in range(49): + if frame_idx < start_frame: + continue + elif frame_idx <= end_frame: + t = (frame_idx - start_frame) / (end_frame - start_frame) + current_motion[frame_idx, :3, 3] = torch.tensor([dx, dy, dz], device=self.device) * t + else: + current_motion[frame_idx] = current_motion[end_frame] + + # Combine with previous transformations + final_motion = torch.matmul(final_motion, current_motion) + + elif motion_type == 'rot': + # Parse rotation parameters + if len(params) not in [3, 5]: + raise ValueError(f"rot motion requires 2 or 4 parameters: 'rot ' or 'rot ', got: {transform}") + + axis = params[1] + if axis not in ['x', 'y', 'z']: + raise ValueError(f"Invalid rotation axis '{axis}', must be 'x', 'y' or 'z'") + angle = float(params[2]) + + if len(params) == 5: + start_frame = max(0, min(48, int(params[3]))) + end_frame = max(0, min(48, int(params[4]))) + if start_frame > end_frame: + start_frame, end_frame = end_frame, start_frame + + current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1) + for frame_idx in range(49): + if frame_idx < start_frame: + continue + elif frame_idx <= end_frame: + t = (frame_idx - start_frame) / (end_frame - start_frame) + current_angle = angle * t + current_motion[frame_idx] = self.rot_poses(current_angle, axis) + else: + current_motion[frame_idx] = current_motion[end_frame] + + # Combine with previous transformations + final_motion = torch.matmul(final_motion, current_motion) + + elif motion_type == 'spiral': + # Parse spiral motion parameters + if len(params) not in [2, 4]: + raise ValueError(f"spiral motion requires 1 or 3 parameters: 'spiral ' or 'spiral ', got: {transform}") + + radius = float(params[1]) + + if len(params) == 4: + start_frame = max(0, min(48, int(params[2]))) + end_frame = max(0, min(48, int(params[3]))) + if start_frame > end_frame: + start_frame, end_frame = end_frame, start_frame + + current_motion = torch.eye(4, device=self.device).unsqueeze(0).repeat(49, 1, 1) + spiral_motion = self.spiral_poses(radius) + for frame_idx in range(49): + if frame_idx < start_frame: + continue + elif frame_idx <= end_frame: + t = (frame_idx - start_frame) / (end_frame - start_frame) + idx = int(t * (len(spiral_motion) - 1)) + current_motion[frame_idx] = spiral_motion[idx] + else: + current_motion[frame_idx] = current_motion[end_frame] + + # Combine with previous transformations + final_motion = torch.matmul(final_motion, current_motion) + + else: + raise ValueError(f'camera_motion type must be in [trans, spiral, rot], but got {motion_type}') + + return final_motion + +class ObjectMotionGenerator: + def __init__(self, device="cuda:0"): + self.device = device + self.num_frames = 49 + + def _get_points_in_mask(self, pred_tracks, mask): + """Get points that lie within the mask + + Args: + pred_tracks (torch.Tensor): Point trajectories [num_frames, num_points, 3] + mask (torch.Tensor): Binary mask [H, W] + + Returns: + torch.Tensor: Boolean mask for selected points [num_points] + """ + first_frame_points = pred_tracks[0] # [num_points, 3] + xy_points = first_frame_points[:, :2] # [num_points, 2] + + xy_pixels = xy_points.round().long() + xy_pixels[:, 0].clamp_(0, mask.shape[1] - 1) + xy_pixels[:, 1].clamp_(0, mask.shape[0] - 1) + + points_in_mask = mask[xy_pixels[:, 1], xy_pixels[:, 0]] + + return points_in_mask + + def apply_motion(self, pred_tracks, mask, motion_type, distance, num_frames=49, tracking_method="spatracker"): + + self.num_frames = num_frames + pred_tracks = pred_tracks.to(self.device).float() + mask = mask.to(self.device) + + template = { + 'up': ('trans', torch.tensor([0, -1, 0])), + 'down': ('trans', torch.tensor([0, 1, 0])), + 'left': ('trans', torch.tensor([-1, 0, 0])), + 'right': ('trans', torch.tensor([1, 0, 0])), + 'front': ('trans', torch.tensor([0, 0, 1])), + 'back': ('trans', torch.tensor([0, 0, -1])), + 'rot': ('rot', None) # rotate around y axis + } + + if motion_type not in template: + raise ValueError(f"unknown motion type: {motion_type}") + + motion_type, base_vec = template[motion_type] + if base_vec is not None: + base_vec = base_vec.to(self.device) * distance + + if tracking_method == "moge": + T, H, W, _ = pred_tracks.shape + valid_selected = ~torch.any(torch.isnan(pred_tracks[0]), dim=2) & mask + points = pred_tracks[0][valid_selected].reshape(-1, 3) + else: + points_in_mask = self._get_points_in_mask(pred_tracks, mask) + points = pred_tracks[0, points_in_mask] + + center = points.mean(dim=0) + + motions = [] + for frame_idx in range(num_frames): + t = frame_idx / (num_frames - 1) + current_motion = torch.eye(4, device=self.device) + current_motion[:3, 3] = -center + motion_mat = torch.eye(4, device=self.device) + if motion_type == 'trans': + motion_mat[:3, 3] = base_vec * t + else: # 'rot' + angle_rad = torch.deg2rad(torch.tensor(distance * t, device=self.device)) + cos_t = torch.cos(angle_rad) + sin_t = torch.sin(angle_rad) + motion_mat[0, 0] = cos_t + motion_mat[0, 2] = sin_t + motion_mat[2, 0] = -sin_t + motion_mat[2, 2] = cos_t + + current_motion = motion_mat @ current_motion + current_motion[:3, 3] += center + motions.append(current_motion) + + motions = torch.stack(motions) # [num_frames, 4, 4] + + if tracking_method == "moge": + modified_tracks = pred_tracks.clone().reshape(T, -1, 3) + valid_selected = valid_selected.reshape([-1]) + + for frame_idx in range(self.num_frames): + motion_mat = motions[frame_idx] + if W > 1: + motion_mat = motion_mat.clone() + motion_mat[0, 3] /= W + motion_mat[1, 3] /= H + points = modified_tracks[frame_idx, valid_selected] + points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1) + transformed_points = torch.matmul(points_homo, motion_mat.T) + modified_tracks[frame_idx, valid_selected] = transformed_points[:, :3] + + return modified_tracks.reshape(T, H, W, 3) + + else: + points_in_mask = self._get_points_in_mask(pred_tracks, mask) + modified_tracks = pred_tracks.clone() + + for frame_idx in range(pred_tracks.shape[0]): + motion_mat = motions[frame_idx] + points = modified_tracks[frame_idx, points_in_mask] + points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1) + transformed_points = torch.matmul(points_homo, motion_mat.T) + modified_tracks[frame_idx, points_in_mask] = transformed_points[:, :3] + + return modified_tracks \ No newline at end of file diff --git a/models/spatracker/__init__.py b/models/spatracker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/models/spatracker/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/spatracker/models/__init__.py b/models/spatracker/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/models/spatracker/models/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/spatracker/models/build_spatracker.py b/models/spatracker/models/build_spatracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb0dcb9dd6ce198aeb25d1d27c976fee989c45d --- /dev/null +++ b/models/spatracker/models/build_spatracker.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from models.spatracker.models.core.spatracker.spatracker import SpaTracker + + +def build_spatracker( + checkpoint: str, + seq_length: int = 8, +): + model_name = checkpoint.split("/")[-1].split(".")[0] + return build_spatracker_from_cfg(checkpoint=checkpoint, seq_length=seq_length) + + + +# model used to produce the results in the paper +def build_spatracker_from_cfg(checkpoint=None, seq_length=8): + return _build_spatracker( + stride=4, + sequence_len=seq_length, + checkpoint=checkpoint, + ) + + +def _build_spatracker( + stride, + sequence_len, + checkpoint=None, +): + spatracker = SpaTracker( + stride=stride, + S=sequence_len, + add_space_attn=True, + space_depth=6, + time_depth=6, + ) + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + if "model" in state_dict: + model_paras = spatracker.state_dict() + paras_dict = {k: v for k,v in state_dict["model"].items() if k in spatracker.state_dict()} + model_paras.update(paras_dict) + state_dict = model_paras + spatracker.load_state_dict(state_dict) + return spatracker diff --git a/models/spatracker/models/core/__init__.py b/models/spatracker/models/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/models/spatracker/models/core/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/spatracker/models/core/embeddings.py b/models/spatracker/models/core/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1b84c0db9b4351623d16660bde0b98b3252fce6d --- /dev/null +++ b/models/spatracker/models/core/embeddings.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np + +def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = np.arange(grid_size_h, dtype=np.float32) + grid_w = np.arange(grid_size_w, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 3 == 0 + + # use half of dimensions to encode grid_h + B, S, N, _ = grid.shape + gridx = grid[..., 0].view(B*S*N).detach().cpu().numpy() + gridy = grid[..., 1].view(B*S*N).detach().cpu().numpy() + gridz = grid[..., 2].view(B*S*N).detach().cpu().numpy() + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx) # (N, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy) # (N, D/3) + emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz) # (N, D/3) + + + emb = np.concatenate([emb_h, emb_w, emb_z], axis=1) # (N, D) + emb = torch.from_numpy(emb).to(grid.device) + return emb.view(B, S, N, embed_dim) + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = np.arange(grid_size_h, dtype=np.float32) + grid_w = np.arange(grid_size_w, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_embedding(xy, C, cat_coords=True): + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = ( + torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3 + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3 + return pe + + +def get_3d_embedding(xyz, C, cat_coords=True): + B, N, D = xyz.shape + assert D == 3 + + x = xyz[:, :, 0:1] + y = xyz[:, :, 1:2] + z = xyz[:, :, 2:3] + div_term = ( + torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) + pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe_z[:, :, 0::2] = torch.sin(z * div_term) + pe_z[:, :, 1::2] = torch.cos(z * div_term) + + pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3 + if cat_coords: + pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3 + return pe + + +def get_4d_embedding(xyzw, C, cat_coords=True): + B, N, D = xyzw.shape + assert D == 4 + + x = xyzw[:, :, 0:1] + y = xyzw[:, :, 1:2] + z = xyzw[:, :, 2:3] + w = xyzw[:, :, 3:4] + div_term = ( + torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) + pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) + pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe_z[:, :, 0::2] = torch.sin(z * div_term) + pe_z[:, :, 1::2] = torch.cos(z * div_term) + + pe_w[:, :, 0::2] = torch.sin(w * div_term) + pe_w[:, :, 1::2] = torch.cos(w * div_term) + + pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3 + if cat_coords: + pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3 + return pe + +import torch.nn as nn +class Embedder_Fourier(nn.Module): + def __init__(self, input_dim, max_freq_log2, N_freqs, + log_sampling=True, include_input=True, + periodic_fns=(torch.sin, torch.cos)): + ''' + :param input_dim: dimension of input to be embedded + :param max_freq_log2: log2 of max freq; min freq is 1 by default + :param N_freqs: number of frequency bands + :param log_sampling: if True, frequency bands are linerly sampled in log-space + :param include_input: if True, raw input is included in the embedding + :param periodic_fns: periodic functions used to embed input + ''' + super(Embedder_Fourier, self).__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.out_dim = 0 + if self.include_input: + self.out_dim += self.input_dim + + self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace( + 2. ** 0., 2. ** max_freq_log2, N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, + input: torch.Tensor, + rescale: float = 1.0): + ''' + :param input: tensor of shape [..., self.input_dim] + :return: tensor of shape [..., self.out_dim] + ''' + assert (input.shape[-1] == self.input_dim) + out = [] + if self.include_input: + out.append(input/rescale) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + out = torch.cat(out, dim=-1) + + assert (out.shape[-1] == self.out_dim) + return out \ No newline at end of file diff --git a/models/spatracker/models/core/model_utils.py b/models/spatracker/models/core/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3eda98a4555fbeb8da192958485c5881c7f40461 --- /dev/null +++ b/models/spatracker/models/core/model_utils.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from easydict import EasyDict as edict +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt + +EPS = 1e-6 + +def nearest_sample2d(im, x, y, return_inbounds=False): + # x and y are each B, N + # output is B, C, N + if len(im.shape) == 5: + B, N, C, H, W = list(im.shape) + else: + B, C, H, W = list(im.shape) + N = list(x.shape)[1] + + x = x.float() + y = y.float() + H_f = torch.tensor(H, dtype=torch.float32) + W_f = torch.tensor(W, dtype=torch.float32) + + # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() + y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() + inbounds = (x_valid & y_valid).float() + inbounds = inbounds.reshape( + B, N + ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) + return output, inbounds + + return output # B, C, N + +def smart_cat(tensor1, tensor2, dim): + if tensor1 is None: + return tensor2 + return torch.cat([tensor1, tensor2], dim=dim) + + +def normalize_single(d): + # d is a whatever shape torch tensor + dmin = torch.min(d) + dmax = torch.max(d) + d = (d - dmin) / (EPS + (dmax - dmin)) + return d + + +def normalize(d): + # d is B x whatever. normalize within each element of the batch + out = torch.zeros(d.size()) + if d.is_cuda: + out = out.cuda() + B = list(d.size())[0] + for b in list(range(B)): + out[b] = normalize_single(d[b]) + return out + + +def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"): + # returns a meshgrid sized B x Y x X + + grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) + grid_y = torch.reshape(grid_y, [1, Y, 1]) + grid_y = grid_y.repeat(B, 1, X) + + grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device)) + grid_x = torch.reshape(grid_x, [1, 1, X]) + grid_x = grid_x.repeat(B, Y, 1) + + if stack: + # note we stack in xy order + # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) + grid = torch.stack([grid_x, grid_y], dim=-1) + return grid + else: + return grid_y, grid_x + + +def reduce_masked_mean(x, mask, dim=None, keepdim=False): + # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting + # returns shape-1 + # axis can be a list of axes + for (a, b) in zip(x.size(), mask.size()): + assert a == b # some shape mismatch! + prod = x * mask + if dim is None: + numer = torch.sum(prod) + denom = EPS + torch.sum(mask) + else: + numer = torch.sum(prod, dim=dim, keepdim=keepdim) + denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim) + + mean = numer / denom + return mean + + +def bilinear_sample2d(im, x, y, return_inbounds=False): + # x and y are each B, N + # output is B, C, N + if len(im.shape) == 5: + B, N, C, H, W = list(im.shape) + else: + B, C, H, W = list(im.shape) + N = list(x.shape)[1] + + x = x.float() + y = y.float() + H_f = torch.tensor(H, dtype=torch.float32) + W_f = torch.tensor(W, dtype=torch.float32) + + # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() + y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() + inbounds = (x_valid & y_valid).float() + inbounds = inbounds.reshape( + B, N + ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) + return output, inbounds + + return output # B, C, N + + +def procrustes_analysis(X0,X1,Weight): # [B,N,3] + # translation + t0 = X0.mean(dim=1,keepdim=True) + t1 = X1.mean(dim=1,keepdim=True) + X0c = X0-t0 + X1c = X1-t1 + # scale + # s0 = (X0c**2).sum(dim=-1).mean().sqrt() + # s1 = (X1c**2).sum(dim=-1).mean().sqrt() + # X0cs = X0c/s0 + # X1cs = X1c/s1 + # rotation (use double for SVD, float loses precision) + U,_,V = (X0c.t()@X1c).double().svd(some=True) + R = (U@V.t()).float() + if R.det()<0: R[2] *= -1 + # align X1 to X0: X1to0 = (X1-t1)/@R.t()+t0 + se3 = edict(t0=t0[0],t1=t1[0],R=R) + + return se3 + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. + """ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + coords = coords * torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device + ) + else: + coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) + + coords -= 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view( + B, -1, feats.shape[1] * feats.shape[3] + ) # B C R 1 -> B R C + + +def sample_features5d(input, coords): + r"""Sample spatio-temporal features + + `sample_features5d(input, coords)` works in the same way as + :func:`sample_features4d` but for spatio-temporal features and points: + :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is + a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i, + x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`. + + Args: + input (Tensor): spatio-temporal features. + coords (Tensor): spatio-temporal points. + + Returns: + Tensor: sampled features. + """ + + B, T, _, _, _ = input.shape + + # B T C H W -> B C T H W + input = input.permute(0, 2, 1, 3, 4) + + # B R1 R2 3 -> B R1 R2 1 3 + coords = coords.unsqueeze(3) + + # B C R1 R2 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 3, 1, 4).view( + B, feats.shape[2], feats.shape[3], feats.shape[1] + ) # B C R1 R2 1 -> B R1 R2 C + +def vis_PCA(fmaps, save_dir): + """ + visualize the PCA of the feature maps + args: + fmaps: feature maps 1 C H W + save_dir: the directory to save the PCA visualization + """ + + pca = PCA(n_components=3) + fmap_vis = fmaps[0,...] + fmap_vnorm = ( + (fmap_vis-fmap_vis.min())/ + (fmap_vis.max()-fmap_vis.min())) + H_vis, W_vis = fmap_vis.shape[1:] + fmap_vnorm = fmap_vnorm.reshape(fmap_vnorm.shape[0], + -1).permute(1,0) + fmap_pca = pca.fit_transform(fmap_vnorm.detach().cpu().numpy()) + pca = fmap_pca.reshape(H_vis,W_vis,3) + plt.imsave(save_dir, + ( + (pca-pca.min())/ + (pca.max()-pca.min()) + )) + + + # debug=False + # if debug==True: + # pcd_idx = 60 + # vis_PCA(fmapYZ[0,:1], "./yz.png") + # vis_PCA(fmapXZ[0,:1], "./xz.png") + # vis_PCA(fmaps[0,:1], "./xy.png") + # vis_PCA(fmaps[0,-1:], "./xy_.png") + # fxy_q = fxy[0,0,pcd_idx:pcd_idx+1, :, None, None] + # fyz_q = fyz[0,0,pcd_idx:pcd_idx+1, :, None, None] + # fxz_q = fxz[0,0,pcd_idx:pcd_idx+1, :, None, None] + # corr_map = (fxy_q*fmaps[0,-1:]).sum(dim=1) + # corr_map_yz = (fyz_q*fmapYZ[0,-1:]).sum(dim=1) + # corr_map_xz = (fxz_q*fmapXZ[0,-1:]).sum(dim=1) + # coord_last = coords[0,-1,pcd_idx:pcd_idx+1] + # coord_last_neigh = coords[0,-1, self.neigh_indx[pcd_idx]] + # depth_last = depths_dnG[-1,0] + # abs_res = (depth_last-coord_last[-1,-1]).abs() + # abs_res = (abs_res - abs_res.min())/(abs_res.max()-abs_res.min()) + # res_dp = torch.exp(-abs_res) + # enhance_corr = res_dp*corr_map + # plt.imsave("./res.png", res_dp.detach().cpu().numpy()) + # plt.imsave("./enhance_corr.png", enhance_corr[0].detach().cpu().numpy()) + # plt.imsave("./corr_map.png", corr_map[0].detach().cpu().numpy()) + # plt.imsave("./corr_map_yz.png", corr_map_yz[0].detach().cpu().numpy()) + # plt.imsave("./corr_map_xz.png", corr_map_xz[0].detach().cpu().numpy()) + # img_feat = cv2.imread("./xy.png") + # cv2.circle(img_feat, (int(coord_last[0,0]), int(coord_last[0,1])), 2, (0, 0, 255), -1) + # for p_i in coord_last_neigh: + # cv2.circle(img_feat, (int(p_i[0]), int(p_i[1])), 1, (0, 255, 0), -1) + # cv2.imwrite("./xy_coord.png", img_feat) + # import ipdb; ipdb.set_trace() \ No newline at end of file diff --git a/models/spatracker/models/core/spatracker/__init__.py b/models/spatracker/models/core/spatracker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/models/spatracker/models/core/spatracker/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/spatracker/models/core/spatracker/blocks.py b/models/spatracker/models/core/spatracker/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2c88f3da86443771878d9a4d31c1b57da54f54d5 --- /dev/null +++ b/models/spatracker/models/core/spatracker/blocks.py @@ -0,0 +1,999 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast +from einops import rearrange +import collections +from functools import partial +from itertools import repeat +import torchvision.models as tvm + +from models.spatracker.models.core.spatracker.vit.encoder import ImageEncoderViT as vitEnc +from models.spatracker.models.core.spatracker.dpt.models import DPTEncoder +from models.spatracker.models.core.spatracker.loftr import LocalFeatureTransformer +# from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim=None, + num_heads=8, dim_head=48, qkv_bias=False, flash=False): + super().__init__() + inner_dim = self.inner_dim = dim_head * num_heads + context_dim = default(context_dim, query_dim) + self.scale = dim_head**-0.5 + self.heads = num_heads + self.flash = flash + + self.qkv = nn.Linear(query_dim, inner_dim*3, bias=qkv_bias) + self.proj = nn.Linear(inner_dim, query_dim) + + def forward(self, x, context=None, attn_bias=None): + B, N1, _ = x.shape + C = self.inner_dim + h = self.heads + # q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3) + # k, v = self.to_kv(context).chunk(2, dim=-1) + # context = default(context, x) + + qkv = self.qkv(x).reshape(B, N1, 3, h, C // h) + q, k, v = qkv[:,:, 0], qkv[:,:, 1], qkv[:,:, 2] + N2 = x.shape[1] + + k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) + v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) + q = q.reshape(B, N1, h, C // h).permute(0, 2, 1, 3) + if self.flash==False: + sim = (q @ k.transpose(-2, -1)) * self.scale + if attn_bias is not None: + sim = sim + attn_bias + attn = sim.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N1, C) + else: + input_args = [x.half().contiguous() for x in [q, k, v]] + x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore + + # return self.to_out(x.float()) + return self.proj(x.float()) + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=3, + padding=1, + stride=stride, + padding_mode="zeros", + ) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, padding_mode="zeros" + ) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__( + self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0, + Embed3D=False + ): + super(BasicEncoder, self).__init__() + self.stride = stride + self.norm_fn = norm_fn + self.in_planes = 64 + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + self.norm2 = nn.BatchNorm2d(output_dim * 2) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d( + input_dim, + self.in_planes, + kernel_size=7, + stride=2, + padding=3, + padding_mode="zeros", + ) + self.relu1 = nn.ReLU(inplace=True) + + self.shallow = False + if self.shallow: + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1) + else: + if Embed3D: + self.conv_fuse = nn.Conv2d(64+63, + self.in_planes, kernel_size=3, padding=1) + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + self.layer4 = self._make_layer(128, stride=2) + self.conv2 = nn.Conv2d( + 128 + 128 + 96 + 64, + output_dim * 2, + kernel_size=3, + padding=1, + padding_mode="zeros", + ) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", + nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x, feat_PE=None): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + if self.shallow: + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + a = F.interpolate( + a, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + b = F.interpolate( + b, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + c = F.interpolate( + c, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + x = self.conv2(torch.cat([a, b, c], dim=1)) + else: + if feat_PE is not None: + x = self.conv_fuse(torch.cat([x, feat_PE], dim=1)) + a = self.layer1(x) + else: + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + d = self.layer4(c) + a = F.interpolate( + a, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + b = F.interpolate( + b, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + c = F.interpolate( + c, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + d = F.interpolate( + d, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + x = self.conv2(torch.cat([a, b, c, d], dim=1)) + x = self.norm2(x) + x = self.relu2(x) + x = self.conv3(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + return x + +class VitEncoder(nn.Module): + def __init__(self, input_dim=4, output_dim=128, stride=4): + super(VitEncoder, self).__init__() + self.vit = vitEnc(img_size=512, + depth=6, num_heads=8, in_chans=input_dim, + out_chans=output_dim,embed_dim=384).cuda() + self.stride = stride + def forward(self, x): + T, C, H, W = x.shape + x_resize = F.interpolate(x.view(-1, C, H, W), size=(512, 512), + mode='bilinear', align_corners=False) + x_resize = self.vit(x_resize) + x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride), + mode='bilinear', align_corners=False) + return x + +class DPTEnc(nn.Module): + def __init__(self, input_dim=3, output_dim=128, stride=2): + super(DPTEnc, self).__init__() + self.dpt = DPTEncoder() + self.stride = stride + def forward(self, x): + T, C, H, W = x.shape + x = (x-0.5)/0.5 + x_resize = F.interpolate(x.view(-1, C, H, W), size=(384, 384), + mode='bilinear', align_corners=False) + x_resize = self.dpt(x_resize) + x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride), + mode='bilinear', align_corners=False) + return x + +# class DPT_DINOv2(nn.Module): +# def __init__(self, encoder='vits', features=64, out_channels=[48, 96, 192, 384], +# use_bn=True, use_clstoken=False, localhub=True, stride=2, enc_only=True): +# super(DPT_DINOv2, self).__init__() +# self.stride = stride +# self.enc_only = enc_only +# assert encoder in ['vits', 'vitb', 'vitl'] + +# if localhub: +# self.pretrained = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False) +# else: +# self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder)) + +# state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vits14_pretrain.pth") +# self.pretrained.load_state_dict(state_dict, strict=True) +# self.pretrained.requires_grad_(False) +# dim = self.pretrained.blocks[0].attn.qkv.in_features +# if enc_only == True: +# out_channels=[128, 128, 128, 128] + +# self.DPThead = DPTHeadEnc(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + +# def forward(self, x): +# mean_ = torch.tensor([0.485, 0.456, 0.406], +# device=x.device).view(1, 3, 1, 1) +# std_ = torch.tensor([0.229, 0.224, 0.225], +# device=x.device).view(1, 3, 1, 1) +# x = (x+1)/2 +# x = (x - mean_)/std_ +# h, w = x.shape[-2:] +# h_re, w_re = 560, 560 +# x_resize = F.interpolate(x, size=(h_re, w_re), +# mode='bilinear', align_corners=False) +# with torch.no_grad(): +# features = self.pretrained.get_intermediate_layers(x_resize, 4, return_class_token=True) +# patch_h, patch_w = h_re // 14, w_re // 14 +# feat = self.DPThead(features, patch_h, patch_w, self.enc_only) +# feat = F.interpolate(feat, size=(h//self.stride, w//self.stride), mode="bilinear", align_corners=True) + +# return feat + + +class VGG19(nn.Module): + def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None: + super().__init__() + self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) + self.amp = amp + self.amp_dtype = amp_dtype + + def forward(self, x, **kwargs): + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + feats = {} + scale = 1 + for layer in self.layers: + if isinstance(layer, nn.MaxPool2d): + feats[scale] = x + scale = scale*2 + x = layer(x) + return feats + +class CNNandDinov2(nn.Module): + def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16): + super().__init__() + # in case the Internet connection is not stable, please load the DINOv2 locally + self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', + 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False) + + state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth") + self.dinov2_vitl14.load_state_dict(state_dict, strict=True) + + + cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {} + self.cnn = VGG19(**cnn_kwargs) + self.amp = amp + self.amp_dtype = amp_dtype + if self.amp: + dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype) + self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP + + + def train(self, mode: bool = True): + return self.cnn.train(mode) + + def forward(self, x, upsample = False): + B,C,H,W = x.shape + feature_pyramid = self.cnn(x) + + if not upsample: + with torch.no_grad(): + if self.dinov2_vitl14[0].device != x.device: + self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype) + dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype)) + features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14) + del dinov2_features_16 + feature_pyramid[16] = features_16 + return feature_pyramid + +class Dinov2(nn.Module): + def __init__(self, amp = True, amp_dtype = torch.float16): + super().__init__() + # in case the Internet connection is not stable, please load the DINOv2 locally + self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', + 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False) + + state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth") + self.dinov2_vitl14.load_state_dict(state_dict, strict=True) + + self.amp = amp + self.amp_dtype = amp_dtype + if self.amp: + self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype) + + def forward(self, x, upsample = False): + B,C,H,W = x.shape + mean_ = torch.tensor([0.485, 0.456, 0.406], + device=x.device).view(1, 3, 1, 1) + std_ = torch.tensor([0.229, 0.224, 0.225], + device=x.device).view(1, 3, 1, 1) + x = (x+1)/2 + x = (x - mean_)/std_ + h_re, w_re = 560, 560 + x_resize = F.interpolate(x, size=(h_re, w_re), + mode='bilinear', align_corners=True) + if not upsample: + with torch.no_grad(): + dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype)) + features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14) + del dinov2_features_16 + features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True) + return features_16 + +class AttnBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, + flash=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.flash=flash + + self.attn = Attention( + hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash, + **block_kwargs + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, + flash=True, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm_context = nn.LayerNorm(hidden_size) + + self.cross_attn = Attention( + hidden_size, context_dim=context_dim, + num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash + + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + + def forward(self, x, context): + with autocast(): + x = x + self.cross_attn( + self.norm1(x), self.norm_context(context) + ) + x = x + self.mlp(self.norm2(x)) + return x + + +def bilinear_sampler(img, coords, mode="bilinear", mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + # go to 0,1 then 0,2 then -1,1 + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None): + B, S, C, H_prev, W_prev = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H_prev, W_prev + + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.depth_pyramid = [] + self.fmaps_pyramid.append(fmaps) + if depths_dnG is not None: + self.depth_pyramid.append(depths_dnG) + for i in range(self.num_levels - 1): + if depths_dnG is not None: + depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev) + depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2) + _, _, H, W = depths_dnG_.shape + depths_dnG = depths_dnG_.reshape(B, S, 1, H, W) + self.depth_pyramid.append(depths_dnG) + fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + _, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + H_prev = H + W_prev = W + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords): + r = self.radius + B, S, N, D = coords.shape + assert D == 2 + + H, W = self.H, self.W + out_pyramid = [] + for i in range(self.num_levels): + corrs = self.corrs_pyramid[i] # B, S, N, H, W + _, _, _, H, W = corrs.shape + + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( + coords.device + ) + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl) + corrs = corrs.view(B, S, N, -1) + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 + return out.contiguous().float() + + def corr(self, targets): + B, S, N, C = targets.shape + assert C == self.C + assert S == self.S + + fmap1 = targets + + self.corrs_pyramid = [] + for fmaps in self.fmaps_pyramid: + _, _, _, H, W = fmaps.shape + fmap2s = fmaps.view(B, S, C, H * W) + corrs = torch.matmul(fmap1, fmap2s) + corrs = corrs.view(B, S, N, H, W) + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + self.corrs_pyramid.append(corrs) + + def corr_sample(self, targets, coords, coords_dp=None): + B, S, N, C = targets.shape + r = self.radius + Dim_c = (2*r+1)**2 + assert C == self.C + assert S == self.S + + out_pyramid = [] + out_pyramid_dp = [] + for i in range(self.num_levels): + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( + coords.device + ) + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + fmaps = self.fmaps_pyramid[i] + _, _, _, H, W = fmaps.shape + fmap2s = fmaps.view(B*S, C, H, W) + if len(self.depth_pyramid)>0: + depths_dnG_i = self.depth_pyramid[i] + depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W) + dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2)) + dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0] + out_pyramid_dp.append(dp_corrs) + fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2)) + fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1 + corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1)) + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + corrs = corrs.view(B, S, N, -1) + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 + if len(self.depth_pyramid)>0: + out_dp = torch.cat(out_pyramid_dp, dim=-1) + self.fcorrD = out_dp.contiguous().float() + else: + self.fcorrD = torch.zeros_like(out).contiguous().float() + return out.contiguous().float() + + +class EUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. + """ + + def __init__( + self, + space_depth=12, + time_depth=12, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + vq_depth=3, + add_space_attn=True, + add_time_attn=True, + flash=True + ): + super().__init__() + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + self.flash = flash + self.flow_head = nn.Sequential( + nn.Linear(hidden_size, output_dim, bias=True), + nn.ReLU(inplace=True), + nn.Linear(output_dim, output_dim, bias=True), + nn.ReLU(inplace=True), + nn.Linear(output_dim, output_dim, bias=True) + ) + + cross_attn_kwargs = { + "d_model": 384, + "nhead": 4, + "layer_names": ['self', 'cross'] * 3, + } + self.gnn = LocalFeatureTransformer(cross_attn_kwargs) + + # Attention Modules in the temporal dimension + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash) if add_time_attn else nn.Identity() + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash) + for _ in range(space_depth) + ] + ) + assert len(self.time_blocks) >= len(self.space_blocks) + + # Placeholder for the rigid transformation + self.RigidProj = nn.Linear(self.hidden_size, 128, bias=True) + self.Proj = nn.Linear(self.hidden_size, 128, bias=True) + + self.se3_dec = nn.Linear(384, 3, bias=True) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def forward(self, input_tensor, se3_feature): + """ Updating with Transformer + + Args: + input_tensor: B, N, T, C + arap_embed: B, N, T, C + """ + B, N, T, C = input_tensor.shape + x = self.input_transform(input_tensor) + tokens = x + K = 0 + j = 0 + for i in range(len(self.time_blocks)): + tokens_time = rearrange(tokens, "b n t c -> (b n) t c", b=B, t=T, n=N+K) + tokens_time = self.time_blocks[i](tokens_time) + tokens = rearrange(tokens_time, "(b n) t c -> b n t c ", b=B, t=T, n=N+K) + if self.add_space_attn and ( + i % (len(self.time_blocks) // len(self.space_blocks)) == 0 + ): + tokens_space = rearrange(tokens, "b n t c -> (b t) n c ", b=B, t=T, n=N) + tokens_space = self.space_blocks[j](tokens_space) + tokens = rearrange(tokens_space, "(b t) n c -> b n t c ", b=B, t=T, n=N) + j += 1 + + B, N, S, _ = tokens.shape + feat0, feat1 = self.gnn(tokens.view(B*N*S, -1)[None,...], se3_feature[None, ...]) + + so3 = F.tanh(self.se3_dec(feat0.view(B*N*S, -1)[None,...].view(B, N, S, -1))/100) + flow = self.flow_head(feat0.view(B,N,S,-1)) + + return flow, _, _, feat1, so3 + + +class FusionFormer(nn.Module): + """ + Fuse the feature tracks info with the low rank motion tokens + """ + def __init__( + self, + d_model=64, + nhead=8, + attn_iters=4, + mlp_ratio=4.0, + flash=False, + input_dim=35, + output_dim=384+3, + ): + super().__init__() + self.flash = flash + self.in_proj = nn.ModuleList( + [ + nn.Linear(input_dim, d_model) + for _ in range(2) + ] + ) + self.out_proj = nn.Linear(d_model, output_dim, bias=True) + self.time_blocks = nn.ModuleList( + [ + CrossAttnBlock(d_model, d_model, nhead, mlp_ratio=mlp_ratio) + for _ in range(attn_iters) + ] + ) + self.space_blocks = nn.ModuleList( + [ + AttnBlock(d_model, nhead, mlp_ratio=mlp_ratio, flash=self.flash) + for _ in range(attn_iters) + ] + ) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + self.out_proj.weight.data.fill_(0) + self.out_proj.bias.data.fill_(0) + + def forward(self, x, token_cls): + """ Fuse the feature tracks info with the low rank motion tokens + + Args: + x: B, S, N, C + Traj_whole: B T N C + + """ + B, S, N, C = x.shape + _, T, _, _ = token_cls.shape + x = self.in_proj[0](x) + token_cls = self.in_proj[1](token_cls) + token_cls = rearrange(token_cls, 'b t n c -> (b n) t c') + + for i in range(len(self.space_blocks)): + x = rearrange(x, 'b s n c -> (b n) s c') + x = self.time_blocks[i](x, token_cls) + x = self.space_blocks[i](x.permute(1,0,2)) + x = rearrange(x, '(b s) n c -> b s n c', b=B, s=S, n=N) + + x = self.out_proj(x) + delta_xyz = x[..., :3] + feat_traj = x[..., 3:] + return delta_xyz, feat_traj + +class Lie(): + """ + Lie algebra for SO(3) and SE(3) operations in PyTorch + """ + + def so3_to_SO3(self,w): # [...,3] + wx = self.skew_symmetric(w) + theta = w.norm(dim=-1)[...,None,None] + I = torch.eye(3,device=w.device,dtype=torch.float32) + A = self.taylor_A(theta) + B = self.taylor_B(theta) + R = I+A*wx+B*wx@wx + return R + + def SO3_to_so3(self,R,eps=1e-7): # [...,3,3] + trace = R[...,0,0]+R[...,1,1]+R[...,2,2] + theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi + lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird + w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0] + w = torch.stack([w0,w1,w2],dim=-1) + return w + + def se3_to_SE3(self,wu): # [...,3] + w,u = wu.split([3,3],dim=-1) + wx = self.skew_symmetric(w) + theta = w.norm(dim=-1)[...,None,None] + I = torch.eye(3,device=w.device,dtype=torch.float32) + A = self.taylor_A(theta) + B = self.taylor_B(theta) + C = self.taylor_C(theta) + R = I+A*wx+B*wx@wx + V = I+B*wx+C*wx@wx + Rt = torch.cat([R,(V@u[...,None])],dim=-1) + return Rt + + def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4] + R,t = Rt.split([3,1],dim=-1) + w = self.SO3_to_so3(R) + wx = self.skew_symmetric(w) + theta = w.norm(dim=-1)[...,None,None] + I = torch.eye(3,device=w.device,dtype=torch.float32) + A = self.taylor_A(theta) + B = self.taylor_B(theta) + invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx + u = (invV@t)[...,0] + wu = torch.cat([w,u],dim=-1) + return wu + + def skew_symmetric(self,w): + w0,w1,w2 = w.unbind(dim=-1) + O = torch.zeros_like(w0) + wx = torch.stack([torch.stack([O,-w2,w1],dim=-1), + torch.stack([w2,O,-w0],dim=-1), + torch.stack([-w1,w0,O],dim=-1)],dim=-2) + return wx + + def taylor_A(self,x,nth=10): + # Taylor expansion of sin(x)/x + ans = torch.zeros_like(x) + denom = 1. + for i in range(nth+1): + if i>0: denom *= (2*i)*(2*i+1) + ans = ans+(-1)**i*x**(2*i)/denom + return ans + def taylor_B(self,x,nth=10): + # Taylor expansion of (1-cos(x))/x**2 + ans = torch.zeros_like(x) + denom = 1. + for i in range(nth+1): + denom *= (2*i+1)*(2*i+2) + ans = ans+(-1)**i*x**(2*i)/denom + return ans + def taylor_C(self,x,nth=10): + # Taylor expansion of (x-sin(x))/x**3 + ans = torch.zeros_like(x) + denom = 1. + for i in range(nth+1): + denom *= (2*i+2)*(2*i+3) + ans = ans+(-1)**i*x**(2*i)/denom + return ans + + + +def pix2cam(coords, + intr): + """ + Args: + coords: [B, T, N, 3] + intr: [B, T, 3, 3] + """ + coords=coords.detach() + B, S, N, _, = coords.shape + xy_src = coords.reshape(B*S*N, 3) + intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3) + xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1) + xyz_src = (torch.inverse(intr)@xy_src[...,None])[...,0] + dp_pred = coords[..., 2] + xyz_src_ = (xyz_src*(dp_pred.reshape(S*N, 1))) + xyz_src_ = xyz_src_.reshape(B, S, N, 3) + return xyz_src_ + +def cam2pix(coords, + intr): + """ + Args: + coords: [B, T, N, 3] + intr: [B, T, 3, 3] + """ + coords=coords.detach() + B, S, N, _, = coords.shape + xy_src = coords.reshape(B*S*N, 3).clone() + intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3) + xy_src = xy_src / (xy_src[..., 2:]+1e-5) + xyz_src = (intr@xy_src[...,None])[...,0] + dp_pred = coords[..., 2] + xyz_src[...,2] *= dp_pred.reshape(S*N) + xyz_src = xyz_src.reshape(B, S, N, 3) + return xyz_src + +def edgeMat(traj3d): + """ + Args: + traj3d: [B, T, N, 3] + """ + B, T, N, _ = traj3d.shape + traj3d = traj3d + traj3d = traj3d.view(B, T, N, 3) + traj3d = traj3d[..., None, :] - traj3d[..., None, :, :] # B, T, N, N, 3 + edgeMat = traj3d.norm(dim=-1) # B, T, N, N + return edgeMat \ No newline at end of file diff --git a/models/spatracker/models/core/spatracker/dpt/__init__.py b/models/spatracker/models/core/spatracker/dpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/spatracker/models/core/spatracker/dpt/base_model.py b/models/spatracker/models/core/spatracker/dpt/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2e0e93b0495f48a3405546b6fe1969be3480a2 --- /dev/null +++ b/models/spatracker/models/core/spatracker/dpt/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device("cpu")) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/models/spatracker/models/core/spatracker/dpt/blocks.py b/models/spatracker/models/core/spatracker/dpt/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..121816f548644999ace947daaf31858e3c596cf1 --- /dev/null +++ b/models/spatracker/models/core/spatracker/dpt/blocks.py @@ -0,0 +1,394 @@ +import torch +import torch.nn as nn + +from models.spatracker.models.core.spatracker.dpt.vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, + _make_pretrained_vit_tiny +) + + +def _make_encoder( + backbone, + features, + use_pretrained, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout="ignore", + enable_attention_hooks=False, +): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch( + [256, 512, 1024, 2048], features, groups=groups, expand=expand + ) # efficientnet_lite3 + elif backbone == "vit_tiny_r_s16_p8_384": + pretrained = _make_pretrained_vit_tiny( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + return scratch + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output diff --git a/models/spatracker/models/core/spatracker/dpt/midas_net.py b/models/spatracker/models/core/spatracker/dpt/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..182d36eb7d386cf6b6cfe60ec5754228ee7b0859 --- /dev/null +++ b/models/spatracker/models/core/spatracker/dpt/midas_net.py @@ -0,0 +1,77 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from models.spatracker.models.core.spatracker.dpt.base_model import BaseModel +from models.spatracker.models.core.spatracker.dpt.blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet_large(BaseModel): + """Network for monocular depth estimation.""" + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_large, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder( + backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained + ) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/models/spatracker/models/core/spatracker/dpt/models.py b/models/spatracker/models/core/spatracker/dpt/models.py new file mode 100644 index 0000000000000000000000000000000000000000..3177d3fb76f42d0f25de13ea0c8f76c4b9273364 --- /dev/null +++ b/models/spatracker/models/core/spatracker/dpt/models.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.spatracker.models.core.spatracker.dpt.base_model import BaseModel +from models.spatracker.models.core.spatracker.dpt.blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=True, + enable_attention_hooks=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + "vit_tiny_r_s16_p8_384": [0, 1, 2, 3], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + enable_attention_hooks=enable_attention_hooks, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + self.proj_out = nn.Sequential( + nn.Conv2d( + 256+512+384+384, + 256, + kernel_size=3, + padding=1, + padding_mode="zeros", + ), + nn.BatchNorm2d(128 * 2), + nn.ReLU(True), + nn.Conv2d( + 128 * 2, + 128, + kernel_size=3, + padding=1, + padding_mode="zeros", + ) + ) + + + def forward(self, x, only_enc=False): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + if only_enc: + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + a = (layer_1) + b = ( + F.interpolate( + layer_2, + scale_factor=2, + mode="bilinear", + align_corners=True, + ) + ) + c = ( + F.interpolate( + layer_3, + scale_factor=8, + mode="bilinear", + align_corners=True, + ) + ) + d = ( + F.interpolate( + layer_4, + scale_factor=16, + mode="bilinear", + align_corners=True, + ) + ) + x = self.proj_out(torch.cat([a, b, c, d], dim=1)) + return x + else: + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + _,_,H_out,W_out = path_1.size() + path_2_up = F.interpolate(path_2, size=(H_out,W_out), mode="bilinear", align_corners=True) + path_3_up = F.interpolate(path_3, size=(H_out,W_out), mode="bilinear", align_corners=True) + path_4_up = F.interpolate(path_4, size=(H_out,W_out), mode="bilinear", align_corners=True) + + out = self.scratch.output_conv(path_1+path_2_up+path_3_up+path_4_up) + + return out + + +class DPTDepthModel(DPT): + def __init__( + self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs + ): + features = kwargs["features"] if "features" in kwargs else 256 + + self.scale = scale + self.shift = shift + self.invert = invert + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + inv_depth = super().forward(x).squeeze(dim=1) + + if self.invert: + depth = self.scale * inv_depth + self.shift + depth[depth < 1e-8] = 1e-8 + depth = 1.0 / depth + return depth + else: + return inv_depth + +class DPTEncoder(DPT): + def __init__( + self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs + ): + features = kwargs["features"] if "features" in kwargs else 256 + + self.scale = scale + self.shift = shift + + head = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + features = super().forward(x, only_enc=True).squeeze(dim=1) + + return features + + +class DPTSegmentationModel(DPT): + def __init__(self, num_classes, path=None, **kwargs): + + features = kwargs["features"] if "features" in kwargs else 256 + + kwargs["use_bn"] = True + + head = nn.Sequential( + nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(features), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(features, num_classes, kernel_size=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + ) + + super().__init__(head, **kwargs) + + self.auxlayer = nn.Sequential( + nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(features), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(features, num_classes, kernel_size=1), + ) + + if path is not None: + self.load(path) diff --git a/models/spatracker/models/core/spatracker/dpt/transforms.py b/models/spatracker/models/core/spatracker/dpt/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..399adbcdad096ae3fb8a190ecd3ec5483a897251 --- /dev/null +++ b/models/spatracker/models/core/spatracker/dpt/transforms.py @@ -0,0 +1,231 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height).""" + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std.""" + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input.""" + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/models/spatracker/models/core/spatracker/dpt/vit.py b/models/spatracker/models/core/spatracker/dpt/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..0a0b10806f7411e464c7cd3ae21554825d62f80a --- /dev/null +++ b/models/spatracker/models/core/spatracker/dpt/vit.py @@ -0,0 +1,596 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +attention = {} + + +def get_attention(name): + def hook(module, input, output): + x = input[0] + B, N, C = x.shape + qkv = ( + module.qkv(x) + .reshape(B, N, 3, module.num_heads, C // module.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * module.scale + + attn = attn.softmax(dim=-1) # [:,:,1,1:] + attention[name] = attn + + return hook + + +def get_mean_attention_map(attn, token, shape): + attn = attn[:, :, token, 1:] + attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float() + attn = torch.nn.functional.interpolate( + attn, size=shape[2:], mode="bicubic", align_corners=False + ).squeeze(0) + + all_attn = torch.mean(attn, 0) + + return all_attn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + enable_attention_hooks=False, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if enable_attention_hooks: + pretrained.model.blocks[hooks[0]].attn.register_forward_hook( + get_attention("attn_1") + ) + pretrained.model.blocks[hooks[1]].attn.register_forward_hook( + get_attention("attn_2") + ) + pretrained.model.blocks[hooks[2]].attn.register_forward_hook( + get_attention("attn_3") + ) + pretrained.model.blocks[hooks[3]].attn.register_forward_hook( + get_attention("attn_4") + ) + pretrained.attention = attention + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=384, + use_vit_only=False, + use_readout="ignore", + start_index=1, + enable_attention_hooks=False, +): + pretrained = nn.Module() + pretrained.model = model + pretrained.model.patch_size = [32, 32] + ps = pretrained.model.patch_size[0] + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + if enable_attention_hooks: + pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1")) + pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2")) + pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3")) + pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4")) + pretrained.attention = attention + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [32, 32] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, + use_readout="ignore", + hooks=None, + use_vit_only=False, + enable_attention_hooks=False, +): + # model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + # model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained) + model = timm.create_model("vit_small_r26_s32_384", pretrained=pretrained) + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[128, 256, 384, 384], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + +def _make_pretrained_vit_tiny( + pretrained, + use_readout="ignore", + hooks=None, + use_vit_only=False, + enable_attention_hooks=False, +): + # model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained) + import ipdb; ipdb.set_trace() + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_tiny_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + +def _make_pretrained_vitl16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_vitb16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_deitb16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_deitb16_distil_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + enable_attention_hooks=enable_attention_hooks, + ) diff --git a/models/spatracker/models/core/spatracker/feature_net.py b/models/spatracker/models/core/spatracker/feature_net.py new file mode 100644 index 0000000000000000000000000000000000000000..48c3b3ab6728b13faf6a54d79c6758f647feb19c --- /dev/null +++ b/models/spatracker/models/core/spatracker/feature_net.py @@ -0,0 +1,915 @@ +""" + Adapted from ConvONet + https://github.com/autonomousvision/convolutional_occupancy_networks/blob/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/encoder/pointnet.py#L1 +""" + + +import torch +import torch.nn as nn +import torch.nn.functional as F +# from torch_scatter import scatter_mean, scatter_max +from models.spatracker.models.core.spatracker.unet import UNet +from models.spatracker.models.core.model_utils import ( + vis_PCA +) +from einops import rearrange + +def compute_iou(occ1, occ2): + ''' Computes the Intersection over Union (IoU) value for two sets of + occupancy values. + + Args: + occ1 (tensor): first set of occupancy values + occ2 (tensor): second set of occupancy values + ''' + occ1 = np.asarray(occ1) + occ2 = np.asarray(occ2) + + # Put all data in second dimension + # Also works for 1-dimensional data + if occ1.ndim >= 2: + occ1 = occ1.reshape(occ1.shape[0], -1) + if occ2.ndim >= 2: + occ2 = occ2.reshape(occ2.shape[0], -1) + + # Convert to boolean values + occ1 = (occ1 >= 0.5) + occ2 = (occ2 >= 0.5) + + # Compute IOU + area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1) + area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1) + + iou = (area_intersect / area_union) + + return iou + + +def chamfer_distance(points1, points2, use_kdtree=True, give_id=False): + ''' Returns the chamfer distance for the sets of points. + + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + use_kdtree (bool): whether to use a kdtree + give_id (bool): whether to return the IDs of nearest points + ''' + if use_kdtree: + return chamfer_distance_kdtree(points1, points2, give_id=give_id) + else: + return chamfer_distance_naive(points1, points2) + + +def chamfer_distance_naive(points1, points2): + ''' Naive implementation of the Chamfer distance. + + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + ''' + assert(points1.size() == points2.size()) + batch_size, T, _ = points1.size() + + points1 = points1.view(batch_size, T, 1, 3) + points2 = points2.view(batch_size, 1, T, 3) + + distances = (points1 - points2).pow(2).sum(-1) + + chamfer1 = distances.min(dim=1)[0].mean(dim=1) + chamfer2 = distances.min(dim=2)[0].mean(dim=1) + + chamfer = chamfer1 + chamfer2 + return chamfer + + +def chamfer_distance_kdtree(points1, points2, give_id=False): + ''' KD-tree based implementation of the Chamfer distance. + + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + give_id (bool): whether to return the IDs of the nearest points + ''' + # Points have size batch_size x T x 3 + batch_size = points1.size(0) + + # First convert points to numpy + points1_np = points1.detach().cpu().numpy() + points2_np = points2.detach().cpu().numpy() + + # Get list of nearest neighbors indieces + idx_nn_12, _ = get_nearest_neighbors_indices_batch(points1_np, points2_np) + idx_nn_12 = torch.LongTensor(idx_nn_12).to(points1.device) + # Expands it as batch_size x 1 x 3 + idx_nn_12_expand = idx_nn_12.view(batch_size, -1, 1).expand_as(points1) + + # Get list of nearest neighbors indieces + idx_nn_21, _ = get_nearest_neighbors_indices_batch(points2_np, points1_np) + idx_nn_21 = torch.LongTensor(idx_nn_21).to(points1.device) + # Expands it as batch_size x T x 3 + idx_nn_21_expand = idx_nn_21.view(batch_size, -1, 1).expand_as(points2) + + # Compute nearest neighbors in points2 to points in points1 + # points_12[i, j, k] = points2[i, idx_nn_12_expand[i, j, k], k] + points_12 = torch.gather(points2, dim=1, index=idx_nn_12_expand) + + # Compute nearest neighbors in points1 to points in points2 + # points_21[i, j, k] = points2[i, idx_nn_21_expand[i, j, k], k] + points_21 = torch.gather(points1, dim=1, index=idx_nn_21_expand) + + # Compute chamfer distance + chamfer1 = (points1 - points_12).pow(2).sum(2).mean(1) + chamfer2 = (points2 - points_21).pow(2).sum(2).mean(1) + + # Take sum + chamfer = chamfer1 + chamfer2 + + # If required, also return nearest neighbors + if give_id: + return chamfer1, chamfer2, idx_nn_12, idx_nn_21 + + return chamfer + + +def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1): + ''' Returns the nearest neighbors for point sets batchwise. + + Args: + points_src (numpy array): source points + points_tgt (numpy array): target points + k (int): number of nearest neighbors to return + ''' + indices = [] + distances = [] + + for (p1, p2) in zip(points_src, points_tgt): + raise NotImplementedError() + # kdtree = KDTree(p2) + dist, idx = kdtree.query(p1, k=k) + indices.append(idx) + distances.append(dist) + + return indices, distances + + +def make_3d_grid(bb_min, bb_max, shape): + ''' Makes a 3D grid. + + Args: + bb_min (tuple): bounding box minimum + bb_max (tuple): bounding box maximum + shape (tuple): output shape + ''' + size = shape[0] * shape[1] * shape[2] + + pxs = torch.linspace(bb_min[0], bb_max[0], shape[0]) + pys = torch.linspace(bb_min[1], bb_max[1], shape[1]) + pzs = torch.linspace(bb_min[2], bb_max[2], shape[2]) + + pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size) + pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size) + pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size) + p = torch.stack([pxs, pys, pzs], dim=1) + + return p + + +def transform_points(points, transform): + ''' Transforms points with regard to passed camera information. + + Args: + points (tensor): points tensor + transform (tensor): transformation matrices + ''' + assert(points.size(2) == 3) + assert(transform.size(1) == 3) + assert(points.size(0) == transform.size(0)) + + if transform.size(2) == 4: + R = transform[:, :, :3] + t = transform[:, :, 3:] + points_out = points @ R.transpose(1, 2) + t.transpose(1, 2) + elif transform.size(2) == 3: + K = transform + points_out = points @ K.transpose(1, 2) + + return points_out + + +def b_inv(b_mat): + ''' Performs batch matrix inversion. + + Arguments: + b_mat: the batch of matrices that should be inverted + ''' + + eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat) + b_inv, _ = torch.gesv(eye, b_mat) + return b_inv + +def project_to_camera(points, transform): + ''' Projects points to the camera plane. + + Args: + points (tensor): points tensor + transform (tensor): transformation matrices + ''' + p_camera = transform_points(points, transform) + p_camera = p_camera[..., :2] / p_camera[..., 2:] + return p_camera + + +def fix_Rt_camera(Rt, loc, scale): + ''' Fixes Rt camera matrix. + + Args: + Rt (tensor): Rt camera matrix + loc (tensor): location + scale (float): scale + ''' + # Rt is B x 3 x 4 + # loc is B x 3 and scale is B + batch_size = Rt.size(0) + R = Rt[:, :, :3] + t = Rt[:, :, 3:] + + scale = scale.view(batch_size, 1, 1) + R_new = R * scale + t_new = t + R @ loc.unsqueeze(2) + + Rt_new = torch.cat([R_new, t_new], dim=2) + + assert(Rt_new.size() == (batch_size, 3, 4)) + return Rt_new + +def normalize_coordinate(p, padding=0.1, plane='xz'): + ''' Normalize coordinate to [0, 1] for unit cube experiments + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + plane (str): plane feature type, ['xz', 'xy', 'yz'] + ''' + # breakpoint() + if plane == 'xz': + xy = p[:, :, [0, 2]] + elif plane =='xy': + xy = p[:, :, [0, 1]] + else: + xy = p[:, :, [1, 2]] + + xy = torch.clamp(xy, min=1e-6, max=1. - 1e-6) + + # xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) + # xy_new = xy_new + 0.5 # range (0, 1) + + # # f there are outliers out of the range + # if xy_new.max() >= 1: + # xy_new[xy_new >= 1] = 1 - 10e-6 + # if xy_new.min() < 0: + # xy_new[xy_new < 0] = 0.0 + # xy_new = (xy + 1.) / 2. + return xy + +def normalize_3d_coordinate(p, padding=0.1): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + ''' + + p_nor = p / (1 + padding + 10e-4) # (-0.5, 0.5) + p_nor = p_nor + 0.5 # range (0, 1) + # f there are outliers out of the range + if p_nor.max() >= 1: + p_nor[p_nor >= 1] = 1 - 10e-4 + if p_nor.min() < 0: + p_nor[p_nor < 0] = 0.0 + return p_nor + +def normalize_coord(p, vol_range, plane='xz'): + ''' Normalize coordinate to [0, 1] for sliding-window experiments + + Args: + p (tensor): point + vol_range (numpy array): volume boundary + plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume + ''' + p[:, 0] = (p[:, 0] - vol_range[0][0]) / (vol_range[1][0] - vol_range[0][0]) + p[:, 1] = (p[:, 1] - vol_range[0][1]) / (vol_range[1][1] - vol_range[0][1]) + p[:, 2] = (p[:, 2] - vol_range[0][2]) / (vol_range[1][2] - vol_range[0][2]) + + if plane == 'xz': + x = p[:, [0, 2]] + elif plane =='xy': + x = p[:, [0, 1]] + elif plane =='yz': + x = p[:, [1, 2]] + else: + x = p + return x + +def coordinate2index(x, reso, coord_type='2d'): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model + + Args: + x (tensor): coordinate + reso (int): defined resolution + coord_type (str): coordinate type + ''' + x = (x * reso).long() + if coord_type == '2d': # plane + index = x[:, :, 0] + reso * x[:, :, 1] + elif coord_type == '3d': # grid + index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2]) + index = index[:, None, :] + return index + +def coord2index(p, vol_range, reso=None, plane='xz'): + ''' Normalize coordinate to [0, 1] for sliding-window experiments. + Corresponds to our 3D model + + Args: + p (tensor): points + vol_range (numpy array): volume boundary + reso (int): defined resolution + plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume + ''' + # normalize to [0, 1] + x = normalize_coord(p, vol_range, plane=plane) + + if isinstance(x, np.ndarray): + x = np.floor(x * reso).astype(int) + else: #* pytorch tensor + x = (x * reso).long() + + if x.shape[1] == 2: + index = x[:, 0] + reso * x[:, 1] + index[index > reso**2] = reso**2 + elif x.shape[1] == 3: + index = x[:, 0] + reso * (x[:, 1] + reso * x[:, 2]) + index[index > reso**3] = reso**3 + + return index[None] + +def update_reso(reso, depth): + ''' Update the defined resolution so that UNet can process. + + Args: + reso (int): defined resolution + depth (int): U-Net number of layers + ''' + base = 2**(int(depth) - 1) + if ~(reso / base).is_integer(): # when this is not integer, U-Net dimension error + for i in range(base): + if ((reso + i) / base).is_integer(): + reso = reso + i + break + return reso + +def decide_total_volume_range(query_vol_metric, recep_field, unit_size, unet_depth): + ''' Update the defined resolution so that UNet can process. + + Args: + query_vol_metric (numpy array): query volume size + recep_field (int): defined the receptive field for U-Net + unit_size (float): the defined voxel size + unet_depth (int): U-Net number of layers + ''' + reso = query_vol_metric / unit_size + recep_field - 1 + reso = update_reso(int(reso), unet_depth) # make sure input reso can be processed by UNet + input_vol_metric = reso * unit_size + p_c = np.array([0.0, 0.0, 0.0]).astype(np.float32) + lb_input_vol, ub_input_vol = p_c - input_vol_metric/2, p_c + input_vol_metric/2 + lb_query_vol, ub_query_vol = p_c - query_vol_metric/2, p_c + query_vol_metric/2 + input_vol = [lb_input_vol, ub_input_vol] + query_vol = [lb_query_vol, ub_query_vol] + + # handle the case when resolution is too large + if reso > 10000: + reso = 1 + + return input_vol, query_vol, reso + +def add_key(base, new, base_name, new_name, device=None): + ''' Add new keys to the given input + + Args: + base (tensor): inputs + new (tensor): new info for the inputs + base_name (str): name for the input + new_name (str): name for the new info + device (device): pytorch device + ''' + if (new is not None) and (isinstance(new, dict)): + if device is not None: + for key in new.keys(): + new[key] = new[key].to(device) + base = {base_name: base, + new_name: new} + return base + +class map2local(object): + ''' Add new keys to the given input + + Args: + s (float): the defined voxel size + pos_encoding (str): method for the positional encoding, linear|sin_cos + ''' + def __init__(self, s, pos_encoding='linear'): + super().__init__() + self.s = s + self.pe = positional_encoding(basis_function=pos_encoding) + + def __call__(self, p): + p = torch.remainder(p, self.s) / self.s # always possitive + # p = torch.fmod(p, self.s) / self.s # same sign as input p! + p = self.pe(p) + return p + +class positional_encoding(object): + ''' Positional Encoding (presented in NeRF) + + Args: + basis_function (str): basis function + ''' + def __init__(self, basis_function='sin_cos'): + super().__init__() + self.func = basis_function + + L = 10 + freq_bands = 2.**(np.linspace(0, L-1, L)) + self.freq_bands = freq_bands * math.pi + + def __call__(self, p): + if self.func == 'sin_cos': + out = [] + p = 2.0 * p - 1.0 # chagne to the range [-1, 1] + for freq in self.freq_bands: + out.append(torch.sin(freq * p)) + out.append(torch.cos(freq * p)) + p = torch.cat(out, dim=2) + return p + +# Resnet Blocks +class ResnetBlockFC(nn.Module): + ''' Fully connected ResNet Block class. + + Args: + size_in (int): input dimension + size_out (int): output dimension + size_h (int): hidden dimension + ''' + + def __init__(self, size_in, size_out=None, size_h=None): + super().__init__() + # Attributes + if size_out is None: + size_out = size_in + + if size_h is None: + size_h = min(size_in, size_out) + + self.size_in = size_in + self.size_h = size_h + self.size_out = size_out + # Submodules + self.fc_0 = nn.Linear(size_in, size_h) + self.fc_1 = nn.Linear(size_h, size_out) + self.actvn = nn.ReLU() + + if size_in == size_out: + self.shortcut = None + else: + self.shortcut = nn.Linear(size_in, size_out, bias=False) + # Initialization + nn.init.zeros_(self.fc_1.weight) + + def forward(self, x): + net = self.fc_0(self.actvn(x)) + dx = self.fc_1(self.actvn(net)) + + if self.shortcut is not None: + x_s = self.shortcut(x) + else: + x_s = x + + return x_s + dx + + + +''' +------------------ the key model for Pointnet ---------------------------- +''' + + +class LocalSoftSplat(nn.Module): + + def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max', + unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, + hw=None, grid_resolution=None, plane_type='xz', padding=0.1, + n_blocks=4, splat_func=None): + super().__init__() + c_dim = ch + + self.c_dim = c_dim + + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + + if unet: + self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) + else: + self.unet = None + + # get splat func + self.splat_func = splat_func + def forward(self, img_feat, + Fxy2xz, Fxy2yz, Dz, gridxy=None): + """ + Args: + img_feat (tensor): image features + Fxy2xz (tensor): transformation matrix from xy to xz + Fxy2yz (tensor): transformation matrix from xy to yz + """ + B, T, _, H, W = img_feat.shape + fea_reshp = rearrange(img_feat, 'b t c h w -> (b h w) t c', + c=img_feat.shape[2], h=H, w=W) + + gridyz = gridxy + Fxy2yz + gridxz = gridxy + Fxy2xz + # normalize + gridyz[:, 0, ...] = (gridyz[:, 0, ...] / (H - 1) - 0.5) * 2 + gridyz[:, 1, ...] = (gridyz[:, 1, ...] / (Dz - 1) - 0.5) * 2 + gridxz[:, 0, ...] = (gridxz[:, 0, ...] / (W - 1) - 0.5) * 2 + gridxz[:, 1, ...] = (gridxz[:, 1, ...] / (Dz - 1) - 0.5) * 2 + if len(self.blocks) > 0: + net = self.fc_pos(fea_reshp) + net = self.blocks[0](net) + for block in self.blocks[1:]: + # splat and fusion + net_plane = rearrange(net, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W) + + net_planeYZ = self.splat_func(net_plane, Fxy2yz, None, + strMode="avg", tenoutH=Dz, tenoutW=H) + + net_planeXZ = self.splat_func(net_plane, Fxy2xz, None, + strMode="avg", tenoutH=Dz, tenoutW=W) + + net_plane = net_plane + ( + F.grid_sample( + net_planeYZ, gridyz.permute(0,2,3,1), mode='bilinear', padding_mode='border') + + F.grid_sample( + net_planeXZ, gridxz.permute(0,2,3,1), mode='bilinear', padding_mode='border') + ) + + pooled = rearrange(net_plane, 't c h w -> (h w) t c', + c=net_plane.shape[1], h=H, w=W) + + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + net_plane = rearrange(c, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W) + else: + net_plane = rearrange(img_feat, 'b t c h w -> (b t) c h w', + c=img_feat.shape[2], h=H, w=W) + net_planeYZ = self.splat_func(net_plane, Fxy2yz, None, + strMode="avg", tenoutH=Dz, tenoutW=H) + net_planeXZ = self.splat_func(net_plane, Fxy2xz, None, + strMode="avg", tenoutH=Dz, tenoutW=W) + + return net_plane[None], net_planeYZ[None], net_planeXZ[None] + + + +class LocalPoolPointnet(nn.Module): + ''' PointNet-based encoder network with ResNet blocks for each point. + Number of input points are fixed. + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + unet3d (bool): weather to use 3D U-Net + unet3d_kwargs (str): 3D U-Net parameters + plane_resolution (int): defined resolution for plane feature + grid_resolution (int): defined resolution for grid feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + ''' + + def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max', + unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, + hw=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5): + super().__init__() + c_dim = ch + unet3d = False + plane_type = ['xy', 'xz', 'yz'] + plane_resolution = hw + + self.c_dim = c_dim + + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + + if unet: + self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) + else: + self.unet = None + + if unet3d: + # self.unet3d = UNet3D(**unet3d_kwargs) + raise NotImplementedError() + else: + self.unet3d = None + + self.reso_plane = plane_resolution + self.reso_grid = grid_resolution + self.plane_type = plane_type + self.padding = padding + + if scatter_type == 'max': + self.scatter = scatter_max + elif scatter_type == 'mean': + self.scatter = scatter_mean + else: + raise ValueError('incorrect scatter type') + + def generate_plane_features(self, p, c, plane='xz'): + # acquire indices of features in plane + xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) + index = coordinate2index(xy, self.reso_plane) + + # scatter plane features from points + fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2) + c = c.permute(0, 2, 1) # B x 512 x T + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso) + + # process the plane features with UNet + if self.unet is not None: + fea_plane = self.unet(fea_plane) + + return fea_plane + + def generate_grid_features(self, p, c): + p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding) + index = coordinate2index(p_nor, self.reso_grid, coord_type='3d') + # scatter grid features from points + fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3) + c = c.permute(0, 2, 1) + fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3 + fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparce matrix (B x 512 x reso x reso) + + if self.unet3d is not None: + fea_grid = self.unet3d(fea_grid) + + return fea_grid + + def pool_local(self, xy, index, c): + bs, fea_dim = c.size(0), c.size(2) + keys = xy.keys() + + c_out = 0 + for key in keys: + # scatter plane features from points + if key == 'grid': + fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3) + else: + c_permute = c.permute(0, 2, 1) + fea = self.scatter(c_permute, index[key], dim_size=self.reso_plane**2) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out = c_out + fea + return c_out.permute(0, 2, 1) + + + def forward(self, p_input, img_feats=None): + """ + Args: + p_input (tensor): input points T 3 H W + img_feats (tensor): image features T C H W + """ + T, _, H, W = img_feats.size() + p = rearrange(p_input, 't c h w -> (h w) t c', c=3, h=H, w=W) + fea_reshp = rearrange(img_feats, 't c h w -> (h w) t c', + c=img_feats.shape[1], h=H, w=W) + + # acquire the index for each point + coord = {} + index = {} + if 'xz' in self.plane_type: + coord['xz'] = normalize_coordinate(p.clone(), plane='xz', padding=self.padding) + index['xz'] = coordinate2index(coord['xz'], self.reso_plane) + if 'xy' in self.plane_type: + coord['xy'] = normalize_coordinate(p.clone(), plane='xy', padding=self.padding) + index['xy'] = coordinate2index(coord['xy'], self.reso_plane) + if 'yz' in self.plane_type: + coord['yz'] = normalize_coordinate(p.clone(), plane='yz', padding=self.padding) + index['yz'] = coordinate2index(coord['yz'], self.reso_plane) + if 'grid' in self.plane_type: + coord['grid'] = normalize_3d_coordinate(p.clone(), padding=self.padding) + index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d') + + net = self.fc_pos(p) + fea_reshp + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(coord, index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + + fea = {} + + if 'grid' in self.plane_type: + fea['grid'] = self.generate_grid_features(p, c) + if 'xz' in self.plane_type: + fea['xz'] = self.generate_plane_features(p, c, plane='xz') + if 'xy' in self.plane_type: + fea['xy'] = self.generate_plane_features(p, c, plane='xy') + if 'yz' in self.plane_type: + fea['yz'] = self.generate_plane_features(p, c, plane='yz') + + ret = torch.stack([fea['xy'], fea['xz'], fea['yz']]).permute((1, 0, 2, 3, 4)) + return ret + +class PatchLocalPoolPointnet(nn.Module): + ''' PointNet-based encoder network with ResNet blocks. + First transform input points to local system based on the given voxel size. + Support non-fixed number of point cloud, but need to precompute the index + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + unet3d (bool): weather to use 3D U-Net + unet3d_kwargs (str): 3D U-Net parameters + plane_resolution (int): defined resolution for plane feature + grid_resolution (int): defined resolution for grid feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + local_coord (bool): whether to use local coordinate + pos_encoding (str): method for the positional encoding, linear|sin_cos + unit_size (float): defined voxel unit size for local system + ''' + + def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', + unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, + plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5, + local_coord=False, pos_encoding='linear', unit_size=0.1): + super().__init__() + self.c_dim = c_dim + + self.blocks = nn.ModuleList([ + ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + self.reso_plane = plane_resolution + self.reso_grid = grid_resolution + self.plane_type = plane_type + self.padding = padding + + if unet: + self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) + else: + self.unet = None + + if unet3d: + # self.unet3d = UNet3D(**unet3d_kwargs) + raise NotImplementedError() + else: + self.unet3d = None + + if scatter_type == 'max': + self.scatter = scatter_max + elif scatter_type == 'mean': + self.scatter = scatter_mean + else: + raise ValueError('incorrect scatter type') + + if local_coord: + self.map2local = map2local(unit_size, pos_encoding=pos_encoding) + else: + self.map2local = None + + if pos_encoding == 'sin_cos': + self.fc_pos = nn.Linear(60, 2*hidden_dim) + else: + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + + def generate_plane_features(self, index, c): + c = c.permute(0, 2, 1) + # scatter plane features from points + if index.max() < self.reso_plane**2: + fea_plane = c.new_zeros(c.size(0), self.c_dim, self.reso_plane**2) + fea_plane = scatter_mean(c, index, out=fea_plane) # B x c_dim x reso^2 + else: + fea_plane = scatter_mean(c, index) # B x c_dim x reso^2 + if fea_plane.shape[-1] > self.reso_plane**2: # deal with outliers + fea_plane = fea_plane[:, :, :-1] + + fea_plane = fea_plane.reshape(c.size(0), self.c_dim, self.reso_plane, self.reso_plane) + + # process the plane features with UNet + if self.unet is not None: + fea_plane = self.unet(fea_plane) + + return fea_plane + + def generate_grid_features(self, index, c): + # scatter grid features from points + c = c.permute(0, 2, 1) + if index.max() < self.reso_grid**3: + fea_grid = c.new_zeros(c.size(0), self.c_dim, self.reso_grid**3) + fea_grid = scatter_mean(c, index, out=fea_grid) # B x c_dim x reso^3 + else: + fea_grid = scatter_mean(c, index) # B x c_dim x reso^3 + if fea_grid.shape[-1] > self.reso_grid**3: # deal with outliers + fea_grid = fea_grid[:, :, :-1] + fea_grid = fea_grid.reshape(c.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) + + if self.unet3d is not None: + fea_grid = self.unet3d(fea_grid) + + return fea_grid + + def pool_local(self, index, c): + bs, fea_dim = c.size(0), c.size(2) + keys = index.keys() + + c_out = 0 + for key in keys: + # scatter plane features from points + if key == 'grid': + fea = self.scatter(c.permute(0, 2, 1), index[key]) + else: + fea = self.scatter(c.permute(0, 2, 1), index[key]) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out += fea + return c_out.permute(0, 2, 1) + + + def forward(self, inputs): + p = inputs['points'] + index = inputs['index'] + + batch_size, T, D = p.size() + + if self.map2local: + pp = self.map2local(p) + net = self.fc_pos(pp) + else: + net = self.fc_pos(p) + + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + + fea = {} + if 'grid' in self.plane_type: + fea['grid'] = self.generate_grid_features(index['grid'], c) + if 'xz' in self.plane_type: + fea['xz'] = self.generate_plane_features(index['xz'], c) + if 'xy' in self.plane_type: + fea['xy'] = self.generate_plane_features(index['xy'], c) + if 'yz' in self.plane_type: + fea['yz'] = self.generate_plane_features(index['yz'], c) + + return fea \ No newline at end of file diff --git a/models/spatracker/models/core/spatracker/loftr/__init__.py b/models/spatracker/models/core/spatracker/loftr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a343f89f0facafd70637406897f749b9a60732b5 --- /dev/null +++ b/models/spatracker/models/core/spatracker/loftr/__init__.py @@ -0,0 +1 @@ +from .transformer import LocalFeatureTransformer diff --git a/models/spatracker/models/core/spatracker/loftr/linear_attention.py b/models/spatracker/models/core/spatracker/loftr/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..61b1b8573e6454b6d340c20381ad5f945d479791 --- /dev/null +++ b/models/spatracker/models/core/spatracker/loftr/linear_attention.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() \ No newline at end of file diff --git a/models/spatracker/models/core/spatracker/loftr/transformer.py b/models/spatracker/models/core/spatracker/loftr/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6abe749e74062a37365beed5982f16e0f38b20 --- /dev/null +++ b/models/spatracker/models/core/spatracker/loftr/transformer.py @@ -0,0 +1,142 @@ +''' +modified from +https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py +''' +import torch +from torch.nn import Module, Dropout +import copy +import torch.nn as nn +import torch.nn.functional as F + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + # QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + # if kv_mask is not None: + # QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float(-1e12)) + # softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + # A = torch.softmax(softmax_temp * QK, dim=2) + # if self.use_dropout: + # A = self.dropout(A) + # queried_values_ = torch.einsum("nlsh,nshd->nlhd", A, values) + + # Compute the attention and the weighted average + input_args = [x.half().contiguous() for x in [queries.permute(0,2,1,3), keys.permute(0,2,1,3), values.permute(0,2,1,3)]] + queried_values = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).float() # type: ignore + + + return queried_values.contiguous() + +class TransformerEncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead,): + super(TransformerEncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = TransformerEncoderLayer(config['d_model'], config['nhead']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 \ No newline at end of file diff --git a/models/spatracker/models/core/spatracker/losses.py b/models/spatracker/models/core/spatracker/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..42603febcb33ce7d46901b5c1674f1bac5afb99d --- /dev/null +++ b/models/spatracker/models/core/spatracker/losses.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from models.spatracker.models.core.model_utils import reduce_masked_mean +from models.spatracker.models.core.spatracker.blocks import ( + pix2cam +) +from models.spatracker.models.core.model_utils import ( + bilinear_sample2d +) + +EPS = 1e-6 +import torchvision.transforms.functional as TF + +sigma = 3 +x_grid = torch.arange(-7,8,1) +y_grid = torch.arange(-7,8,1) +x_grid, y_grid = torch.meshgrid(x_grid, y_grid) +gridxy = torch.stack([x_grid, y_grid], dim=-1).float() +gs_kernel = torch.exp(-torch.sum(gridxy**2, dim=-1)/(2*sigma**2)) + + +def balanced_ce_loss(pred, gt, valid=None): + total_balanced_loss = 0.0 + for j in range(len(gt)): + B, S, N = gt[j].shape + # pred and gt are the same shape + for (a, b) in zip(pred[j].size(), gt[j].size()): + assert a == b # some shape mismatch! + # if valid is not None: + for (a, b) in zip(pred[j].size(), valid[j].size()): + assert a == b # some shape mismatch! + + pos = (gt[j] > 0.95).float() + neg = (gt[j] < 0.05).float() + + label = pos * 2.0 - 1.0 + a = -label * pred[j] + b = F.relu(a) + loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) + + pos_loss = reduce_masked_mean(loss, pos * valid[j]) + neg_loss = reduce_masked_mean(loss, neg * valid[j]) + balanced_loss = pos_loss + neg_loss + total_balanced_loss += balanced_loss / float(N) + import ipdb; ipdb.set_trace() + return total_balanced_loss + + +def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, + intr=None, trajs_g_all=None): + """Loss function defined over sequence of flow predictions""" + total_flow_loss = 0.0 + + for j in range(len(flow_gt)): + B, S, N, D = flow_gt[j].shape + # assert D == 3 + B, S1, N = vis[j].shape + B, S2, N = valids[j].shape + assert S == S1 + assert S == S2 + n_predictions = len(flow_preds[j]) + if intr is not None: + intr_i = intr[j] + flow_loss = 0.0 + for i in range(n_predictions): + i_weight = gamma ** (n_predictions - i - 1) + flow_pred = flow_preds[j][i][..., -N:, :D] + flow_gt_j = flow_gt[j].clone() + if intr is not None: + xyz_j_gt = pix2cam(flow_gt_j, intr_i) + try: + i_loss = (flow_pred - flow_gt_j).abs() # B, S, N, 3 + except: + import ipdb; ipdb.set_trace() + if D==3: + i_loss[...,2]*=30 + i_loss = torch.mean(i_loss, dim=3) # B, S, N + flow_loss += i_weight * (reduce_masked_mean(i_loss, valids[j])) + + flow_loss = flow_loss / n_predictions + total_flow_loss += flow_loss / float(N) + + + return total_flow_loss diff --git a/models/spatracker/models/core/spatracker/softsplat.py b/models/spatracker/models/core/spatracker/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4aad51ac0d11cd8e435aeb0457dc7eb5301643 --- /dev/null +++ b/models/spatracker/models/core/spatracker/softsplat.py @@ -0,0 +1,539 @@ +#!/usr/bin/env python + +"""The code of softsplat function is modified from: +https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py + +""" + + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() + # end + + return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, + tenMetric:torch.Tensor, strMode:str, tenoutH=None, tenoutW=None): + assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert(tenMetric is None) + if strMode == 'avg': assert(tenMetric is None) + if strMode.split('-')[0] == 'linear': assert(tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert(tenMetric is not None) + + if strMode == 'avg': + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow, tenoutH, tenoutW) + + if strMode.split('-')[0] in ['avg', 'linear', 'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] == 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow, H=None, W=None): + if H is None: + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + else: + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W]) + + if tenIn.is_cuda == True: + cuda_launch(cuda_kernel('softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn); + const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) ) % SIZE_1(tenIn); + const int intY = ( intIndex / SIZE_3(tenIn) ) % SIZE_2(tenIn); + const int intX = ( intIndex ) % SIZE_3(tenIn); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))( + grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + elif tenIn.is_cuda != True: + assert(False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None + Hgrad = None + Wgrad = None + + if tenIngrad is not None: + cuda_launch(cuda_kernel('softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenFlowgrad is not None: + cuda_launch(cuda_kernel('softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + return tenIngrad, tenFlowgrad, Hgrad, Wgrad + # end +# end \ No newline at end of file diff --git a/models/spatracker/models/core/spatracker/spatracker.py b/models/spatracker/models/core/spatracker/spatracker.py new file mode 100644 index 0000000000000000000000000000000000000000..577cac62a6c861b688af67edb3ecbc9f621ba397 --- /dev/null +++ b/models/spatracker/models/core/spatracker/spatracker.py @@ -0,0 +1,732 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from easydict import EasyDict as edict +from einops import rearrange +from sklearn.cluster import SpectralClustering +from models.spatracker.models.core.spatracker.blocks import Lie +import matplotlib.pyplot as plt +import cv2 + +import torch.nn.functional as F +from models.spatracker.models.core.spatracker.blocks import ( + BasicEncoder, + CorrBlock, + EUpdateFormer, + FusionFormer, + pix2cam, + cam2pix, + edgeMat, + VitEncoder, + DPTEnc, + Dinov2 +) + +from models.spatracker.models.core.spatracker.feature_net import ( + LocalSoftSplat +) + +from models.spatracker.models.core.model_utils import ( + meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA +) +from models.spatracker.models.core.embeddings import ( + get_2d_embedding, + get_3d_embedding, + get_1d_sincos_pos_embed_from_grid, + get_2d_sincos_pos_embed, + get_3d_sincos_pos_embed_from_grid, + Embedder_Fourier, +) +import numpy as np +from models.spatracker.models.core.spatracker.softsplat import softsplat + +torch.manual_seed(0) + + +def get_points_on_a_grid(grid_size, interp_shape, + grid_center=(0, 0), device="cuda"): + if grid_size == 1: + return torch.tensor([interp_shape[1] / 2, + interp_shape[0] / 2], device=device)[ + None, None + ] + + grid_y, grid_x = meshgrid2d( + 1, grid_size, grid_size, stack=False, norm=False, device=device + ) + step = interp_shape[1] // 64 + if grid_center[0] != 0 or grid_center[1] != 0: + grid_y = grid_y - grid_size / 2.0 + grid_x = grid_x - grid_size / 2.0 + grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * ( + interp_shape[0] - step * 2 + ) + grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * ( + interp_shape[1] - step * 2 + ) + + grid_y = grid_y + grid_center[0] + grid_x = grid_x + grid_center[1] + xy = torch.stack([grid_x, grid_y], dim=-1).to(device) + return xy + + +def sample_pos_embed(grid_size, embed_dim, coords): + if coords.shape[-1] == 2: + pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, + grid_size=grid_size) + pos_embed = ( + torch.from_numpy(pos_embed) + .reshape(grid_size[0], grid_size[1], embed_dim) + .float() + .unsqueeze(0) + .to(coords.device) + ) + sampled_pos_embed = bilinear_sample2d( + pos_embed.permute(0, 3, 1, 2), + coords[:, 0, :, 0], coords[:, 0, :, 1] + ) + elif coords.shape[-1] == 3: + sampled_pos_embed = get_3d_sincos_pos_embed_from_grid( + embed_dim, coords[:, :1, ...] + ).float()[:,0,...].permute(0, 2, 1) + + return sampled_pos_embed + + +class SpaTracker(nn.Module): + def __init__( + self, + S=8, + stride=8, + add_space_attn=True, + num_heads=8, + hidden_size=384, + space_depth=12, + time_depth=12, + args=edict({}) + ): + super(SpaTracker, self).__init__() + + # step1: config the arch of the model + self.args=args + # step1.1: config the default value of the model + if getattr(args, "depth_color", None) == None: + self.args.depth_color = False + if getattr(args, "if_ARAP", None) == None: + self.args.if_ARAP = True + if getattr(args, "flash_attn", None) == None: + self.args.flash_attn = True + if getattr(args, "backbone", None) == None: + self.args.backbone = "CNN" + if getattr(args, "Nblock", None) == None: + self.args.Nblock = 0 + if getattr(args, "Embed3D", None) == None: + self.args.Embed3D = True + + # step1.2: config the model parameters + self.S = S + self.stride = stride + self.hidden_dim = 256 + self.latent_dim = latent_dim = 128 + self.b_latent_dim = self.latent_dim//3 + self.corr_levels = 4 + self.corr_radius = 3 + self.add_space_attn = add_space_attn + self.lie = Lie() + + # step2: config the model components + # @Encoder + self.fnet = BasicEncoder(input_dim=3, + output_dim=self.latent_dim, norm_fn="instance", dropout=0, + stride=stride, Embed3D=False + ) + + # conv head for the tri-plane features + self.headyz = nn.Sequential( + nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1)) + + self.headxz = nn.Sequential( + nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1)) + + # @UpdateFormer + self.updateformer = EUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=456, + hidden_size=hidden_size, + num_heads=num_heads, + output_dim=latent_dim + 3, + mlp_ratio=4.0, + add_space_attn=add_space_attn, + flash=getattr(self.args, "flash_attn", True) + ) + self.support_features = torch.zeros(100, 384).to("cuda") + 0.1 + + self.norm = nn.GroupNorm(1, self.latent_dim) + + self.ffeat_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), + nn.GELU(), + ) + self.ffeatyz_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), + nn.GELU(), + ) + self.ffeatxz_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), + nn.GELU(), + ) + + #TODO @NeuralArap: optimize the arap + self.embed_traj = Embedder_Fourier( + input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True + ) + self.embed3d = Embedder_Fourier( + input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True + ) + self.embedConv = nn.Conv2d(self.latent_dim+63, + self.latent_dim, 3, padding=1) + + # @Vis_predictor + self.vis_predictor = nn.Sequential( + nn.Linear(128, 1), + ) + + self.embedProj = nn.Linear(63, 456) + self.zeroMLPflow = nn.Linear(195, 130) + + def prepare_track(self, rgbds, queries): + """ + NOTE: + Normalized the rgbs and sorted the queries via their first appeared time + Args: + rgbds: the input rgbd images (B T 4 H W) + queries: the input queries (B N 4) + Return: + rgbds: the normalized rgbds (B T 4 H W) + queries: the sorted queries (B N 4) + track_mask: + """ + assert (rgbds.shape[2]==4) and (queries.shape[2]==4) + #Step1: normalize the rgbs input + device = rgbds.device + rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0 + B, T, C, H, W = rgbds.shape + B, N, __ = queries.shape + self.traj_e = torch.zeros((B, T, N, 3), device=device) + self.vis_e = torch.zeros((B, T, N), device=device) + + #Step2: sort the points via their first appeared time + first_positive_inds = queries[0, :, 0].long() + __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False) + inv_sort_inds = torch.argsort(sort_inds, dim=0) + first_positive_sorted_inds = first_positive_inds[sort_inds] + # check if can be inverse + assert torch.allclose( + first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds] + ) + + # filter those points never appear points during 1 - T + ind_array = torch.arange(T, device=device) + ind_array = ind_array[None, :, None].repeat(B, 1, N) + track_mask = (ind_array >= + first_positive_inds[None, None, :]).unsqueeze(-1) + + # scale the coords_init + coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat( + 1, self.S, 1, 1 + ) + coords_init[..., :2] /= float(self.stride) + + #Step3: initial the regular grid + gridx = torch.linspace(0, W//self.stride - 1, W//self.stride) + gridy = torch.linspace(0, H//self.stride - 1, H//self.stride) + gridx, gridy = torch.meshgrid(gridx, gridy) + gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute( + 2, 1, 0 + ) + vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10 + + # Step4: initial traj for neural arap + T_series = torch.linspace(0, 5, T).reshape(1, T, 1 , 1).cuda() # 1 T 1 1 + T_series = T_series.repeat(B, 1, N, 1) + # get the 3d traj in the camera coordinates + intr_init = self.intrs[:,queries[0,:,0].long()] + Traj_series = pix2cam(queries[:,:,None,1:].double(), intr_init.double()) + #torch.inverse(intr_init.double())@queries[:,:,1:,None].double() # B N 3 1 + Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float() + Traj_series = torch.cat([T_series, Traj_series], dim=-1) + # get the indicator for the neural arap + Traj_mask = -1e2*torch.ones_like(T_series) + Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1) + + return ( + rgbds, + first_positive_inds, + first_positive_sorted_inds, + sort_inds, inv_sort_inds, + track_mask, gridxy, coords_init[..., sort_inds, :].clone(), + vis_init, Traj_series[..., sort_inds, :].clone() + ) + + def sample_trifeat(self, t, + coords, + featMapxy, + featMapyz, + featMapxz): + """ + Sample the features from the 5D triplane feature map 3*(B S C H W) + Args: + t: the time index + coords: the coordinates of the points B S N 3 + featMapxy: the feature map B S C Hx Wy + featMapyz: the feature map B S C Hy Wz + featMapxz: the feature map B S C Hx Wz + """ + # get xy_t yz_t xz_t + queried_t = t.reshape(1, 1, -1, 1) + xy_t = torch.cat( + [queried_t, coords[..., [0,1]]], + dim=-1 + ) + yz_t = torch.cat( + [queried_t, coords[..., [1, 2]]], + dim=-1 + ) + xz_t = torch.cat( + [queried_t, coords[..., [0, 2]]], + dim=-1 + ) + featxy_init = sample_features5d(featMapxy, xy_t) + + featyz_init = sample_features5d(featMapyz, yz_t) + featxz_init = sample_features5d(featMapxz, xz_t) + + featxy_init = featxy_init.repeat(1, self.S, 1, 1) + featyz_init = featyz_init.repeat(1, self.S, 1, 1) + featxz_init = featxz_init.repeat(1, self.S, 1, 1) + + return featxy_init, featyz_init, featxz_init + + def neural_arap(self, coords, Traj_arap, intrs_S, T_mark): + """ calculate the ARAP embedding and offset + Args: + coords: the coordinates of the current points 1 S N' 3 + Traj_arap: the trajectory of the points 1 T N' 5 + intrs_S: the camera intrinsics B S 3 3 + + """ + coords_out = coords.clone() + coords_out[..., :2] *= float(self.stride) + coords_out[..., 2] = coords_out[..., 2]/self.Dz + coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near + intrs_S = intrs_S[:, :, None, ...].repeat(1, 1, coords_out.shape[2], 1, 1) + B, S, N, D = coords_out.shape + if S != intrs_S.shape[1]: + intrs_S = torch.cat( + [intrs_S, intrs_S[:, -1:].repeat(1, S - intrs_S.shape[1],1,1,1)], dim=1 + ) + T_mark = torch.cat( + [T_mark, T_mark[:, -1:].repeat(1, S - T_mark.shape[1],1)], dim=1 + ) + xyz_ = pix2cam(coords_out.double(), intrs_S.double()[:,:,0]) + xyz_ = xyz_.float() + xyz_embed = torch.cat([T_mark[...,None], xyz_, + torch.zeros_like(T_mark[...,None])], dim=-1) + + xyz_embed = self.embed_traj(xyz_embed) + Traj_arap_embed = self.embed_traj(Traj_arap) + d_xyz,traj_feat = self.arapFormer(xyz_embed, Traj_arap_embed) + # update in camera coordinate + xyz_ = xyz_ + d_xyz.clamp(-5, 5) + # project back to the image plane + coords_out = cam2pix(xyz_.double(), intrs_S[:,:,0].double()).float() + # resize back + coords_out[..., :2] /= float(self.stride) + coords_out[..., 2] = (coords_out[..., 2] - self.d_near)/(self.d_far-self.d_near) + coords_out[..., 2] *= self.Dz + + return xyz_, coords_out, traj_feat + + def gradient_arap(self, coords, aff_avg=None, aff_std=None, aff_f_sg=None, + iter=0, iter_num=4, neigh_idx=None, intr=None, msk_track=None): + with torch.enable_grad(): + coords.requires_grad_(True) + y = self.ARAP_ln(coords, aff_f_sg=aff_f_sg, neigh_idx=neigh_idx, + iter=iter, iter_num=iter_num, intr=intr,msk_track=msk_track) + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=coords, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True, allow_unused=True)[0] + + return gradients.detach() + + def forward_iteration( + self, + fmapXY, + fmapYZ, + fmapXZ, + coords_init, + feat_init=None, + vis_init=None, + track_mask=None, + iters=4, + intrs_S=None, + ): + B, S_init, N, D = coords_init.shape + assert D == 3 + assert B == 1 + B, S, __, H8, W8 = fmapXY.shape + device = fmapXY.device + + if S_init < S: + coords = torch.cat( + [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], + dim=1 + ) + vis_init = torch.cat( + [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 + ) + intrs_S = torch.cat( + [intrs_S, intrs_S[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 + ) + else: + coords = coords_init.clone() + + fcorr_fnXY = CorrBlock( + fmapXY, num_levels=self.corr_levels, radius=self.corr_radius + ) + fcorr_fnYZ = CorrBlock( + fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius + ) + fcorr_fnXZ = CorrBlock( + fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius + ) + + ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1) + ffeats = [f.squeeze(-1) for f in ffeats] + + times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) + pos_embed = sample_pos_embed( + grid_size=(H8, W8), + embed_dim=456, + coords=coords[..., :2], + ) + pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1) + + times_embed = ( + torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None] + .repeat(B, 1, 1) + .float() + .to(device) + ) + coord_predictions = [] + attn_predictions = [] + Rot_ln = 0 + support_feat = self.support_features + + for __ in range(iters): + coords = coords.detach() + # if self.args.if_ARAP == True: + # # refine the track with arap + # xyz_pred, coords, flows_cat0 = self.neural_arap(coords.detach(), + # Traj_arap.detach(), + # intrs_S, T_mark) + with torch.no_grad(): + fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2]) + fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1,2]]) + fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0,2]]) + # fcorrs = fcorrsXY + fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ + LRR = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR) + + flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3) + flows_cat = get_3d_embedding(flows_, 64, cat_coords=True) + flows_cat = self.zeroMLPflow(flows_cat) + + + ffeats_xy = ffeats[0].permute(0, + 2, 1, 3).reshape(B * N, S, self.latent_dim) + ffeats_yz = ffeats[1].permute(0, + 2, 1, 3).reshape(B * N, S, self.latent_dim) + ffeats_xz = ffeats[2].permute(0, + 2, 1, 3).reshape(B * N, S, self.latent_dim) + ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz + + if track_mask.shape[1] < vis_init.shape[1]: + track_mask = torch.cat( + [ + track_mask, + torch.zeros_like(track_mask[:, 0]).repeat( + 1, vis_init.shape[1] - track_mask.shape[1], 1, 1 + ), + ], + dim=1, + ) + concat = ( + torch.cat([track_mask, vis_init], dim=2) + .permute(0, 2, 1, 3) + .reshape(B * N, S, 2) + ) + + transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2) + + if transformer_input.shape[-1] < pos_embed.shape[-1]: + # padding the transformer_input to the same dimension as pos_embed + transformer_input = F.pad( + transformer_input, (0, pos_embed.shape[-1] - transformer_input.shape[-1]), + "constant", 0 + ) + + x = transformer_input + pos_embed + times_embed + x = rearrange(x, "(b n) t d -> b n t d", b=B) + + delta, AttnMap, so3_dist, delta_se3F, so3 = self.updateformer(x, support_feat) + support_feat = support_feat + delta_se3F[0]/100 + delta = rearrange(delta, " b n t d -> (b n) t d") + d_coord = delta[:, :, :3] + d_feats = delta[:, :, 3:] + + ffeats_xy = self.ffeat_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xy.reshape(-1, self.latent_dim) + ffeats_yz = self.ffeatyz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_yz.reshape(-1, self.latent_dim) + ffeats_xz = self.ffeatxz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xz.reshape(-1, self.latent_dim) + ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute( + 0, 2, 1, 3 + ) # B,S,N,C + ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute( + 0, 2, 1, 3 + ) # B,S,N,C + ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute( + 0, 2, 1, 3 + ) # B,S,N,C + coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3) + if torch.isnan(coords).any(): + import ipdb; ipdb.set_trace() + + coords_out = coords.clone() + coords_out[..., :2] *= float(self.stride) + + coords_out[..., 2] = coords_out[..., 2]/self.Dz + coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near + + coord_predictions.append(coords_out) + attn_predictions.append(AttnMap) + + ffeats_f = ffeats[0] + ffeats[1] + ffeats[2] + vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape( + B, S, N + ) + self.support_features = support_feat.detach() + return coord_predictions, attn_predictions, vis_e, feat_init, Rot_ln + + + def forward(self, rgbds, queries, iters=4, feat_init=None, + is_train=False, intrs=None, wind_S=None): + self.support_features = torch.zeros(100, 384).to("cuda") + 0.1 + self.is_train=is_train + B, T, C, H, W = rgbds.shape + # set the intrinsic or simply initialized + if intrs is None: + intrs = torch.from_numpy(np.array([[W, 0.0, W//2], + [0.0, W, H//2], + [0.0, 0.0, 1.0]])) + intrs = intrs[None, + None,...].repeat(B, T, 1, 1).float().to(rgbds.device) + self.intrs = intrs + + # prepare the input for tracking + ( + rgbds, + first_positive_inds, + first_positive_sorted_inds, sort_inds, + inv_sort_inds, track_mask, gridxy, + coords_init, vis_init, Traj_arap + ) = self.prepare_track(rgbds.clone(), queries) + coords_init_ = coords_init.clone() + vis_init_ = vis_init[:, :, sort_inds].clone() + + depth_all = rgbds[:, :, 3,...] + d_near = self.d_near = depth_all[depth_all>0.01].min().item() + d_far = self.d_far = depth_all[depth_all>0.01].max().item() + + if wind_S is not None: + self.S = wind_S + + B, N, __ = queries.shape + self.Dz = Dz = W//self.stride + w_idx_start = 0 + p_idx_end = 0 + p_idx_start = 0 + fmaps_ = None + vis_predictions = [] + coord_predictions = [] + attn_predictions = [] + p_idx_end_list = [] + Rigid_ln_total = 0 + while w_idx_start < T - self.S // 2: + curr_wind_points = torch.nonzero( + first_positive_sorted_inds < w_idx_start + self.S) + if curr_wind_points.shape[0] == 0: + w_idx_start = w_idx_start + self.S // 2 + continue + p_idx_end = curr_wind_points[-1] + 1 + p_idx_end_list.append(p_idx_end) + # the T may not be divided by self.S + rgbds_seq = rgbds[:, w_idx_start:w_idx_start + self.S].clone() + S = S_local = rgbds_seq.shape[1] + if S < self.S: + rgbds_seq = torch.cat( + [rgbds_seq, + rgbds_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)], + dim=1, + ) + S = rgbds_seq.shape[1] + + rgbs_ = rgbds_seq.reshape(B * S, C, H, W)[:, :3] + depths = rgbds_seq.reshape(B * S, C, H, W)[:, 3:].clone() + # open the mask + # Traj_arap[:, w_idx_start:w_idx_start + self.S, :p_idx_end, -1] = 0 + #step1: normalize the depth map + + depths = (depths - d_near)/(d_far-d_near) + depths_dn = nn.functional.interpolate( + depths, scale_factor=1.0 / self.stride, mode="nearest") + depths_dnG = depths_dn*Dz + + #step2: normalize the coordinate + coords_init_[:, :, p_idx_start:p_idx_end, 2] = ( + coords_init[:, :, p_idx_start:p_idx_end, 2] - d_near + )/(d_far-d_near) + coords_init_[:, :, p_idx_start:p_idx_end, 2] *= Dz + + # efficient triplane splatting + gridxyz = torch.cat([gridxy[None,...].repeat( + depths_dn.shape[0],1,1,1), depths_dnG], dim=1) + Fxy2yz = gridxyz[:,[1, 2], ...] - gridxyz[:,:2] + Fxy2xz = gridxyz[:,[0, 2], ...] - gridxyz[:,:2] + if getattr(self.args, "Embed3D", None) == True: + gridxyz_nm = gridxyz.clone() + gridxyz_nm[:,0,...] = (gridxyz_nm[:,0,...]-gridxyz_nm[:,0,...].min())/(gridxyz_nm[:,0,...].max()-gridxyz_nm[:,0,...].min()) + gridxyz_nm[:,1,...] = (gridxyz_nm[:,1,...]-gridxyz_nm[:,1,...].min())/(gridxyz_nm[:,1,...].max()-gridxyz_nm[:,1,...].min()) + gridxyz_nm[:,2,...] = (gridxyz_nm[:,2,...]-gridxyz_nm[:,2,...].min())/(gridxyz_nm[:,2,...].max()-gridxyz_nm[:,2,...].min()) + gridxyz_nm = 2*(gridxyz_nm-0.5) + _,_,h4,w4 = gridxyz_nm.shape + gridxyz_nm = gridxyz_nm.permute(0,2,3,1).reshape(S*h4*w4, 3) + featPE = self.embed3d(gridxyz_nm).view(S, h4, w4, -1).permute(0,3,1,2) + if fmaps_ is None: + fmaps_ = torch.cat([self.fnet(rgbs_),featPE], dim=1) + fmaps_ = self.embedConv(fmaps_) + else: + fmaps_new = torch.cat([self.fnet(rgbs_[self.S // 2 :]),featPE[self.S // 2 :]], dim=1) + fmaps_new = self.embedConv(fmaps_new) + fmaps_ = torch.cat( + [fmaps_[self.S // 2 :], fmaps_new], dim=0 + ) + else: + if fmaps_ is None: + fmaps_ = self.fnet(rgbs_) + else: + fmaps_ = torch.cat( + [fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0 + ) + + fmapXY = fmaps_[:, :self.latent_dim].reshape( + B, S, self.latent_dim, H // self.stride, W // self.stride + ) + + fmapYZ = softsplat(fmapXY[0], Fxy2yz, None, + strMode="avg", tenoutH=self.Dz, tenoutW=H//self.stride) + fmapXZ = softsplat(fmapXY[0], Fxy2xz, None, + strMode="avg", tenoutH=self.Dz, tenoutW=W//self.stride) + + fmapYZ = self.headyz(fmapYZ)[None, ...] + fmapXZ = self.headxz(fmapXZ)[None, ...] + + if p_idx_end - p_idx_start > 0: + queried_t = (first_positive_sorted_inds[p_idx_start:p_idx_end] + - w_idx_start) + (featxy_init, + featyz_init, + featxz_init) = self.sample_trifeat( + t=queried_t,featMapxy=fmapXY, + featMapyz=fmapYZ,featMapxz=fmapXZ, + coords=coords_init_[:, :1, p_idx_start:p_idx_end] + ) + # T, S, N, C, 3 + feat_init_curr = torch.stack([featxy_init, + featyz_init, featxz_init], dim=-1) + feat_init = smart_cat(feat_init, feat_init_curr, dim=2) + + if p_idx_start > 0: + # preprocess the coordinates of last windows + last_coords = coords[-1][:, self.S // 2 :].clone() + last_coords[..., :2] /= float(self.stride) + last_coords[..., 2:] = (last_coords[..., 2:]-d_near)/(d_far-d_near) + last_coords[..., 2:] = last_coords[..., 2:]*Dz + + coords_init_[:, : self.S // 2, :p_idx_start] = last_coords + coords_init_[:, self.S // 2 :, :p_idx_start] = last_coords[ + :, -1 + ].repeat(1, self.S // 2, 1, 1) + + last_vis = vis[:, self.S // 2 :].unsqueeze(-1) + vis_init_[:, : self.S // 2, :p_idx_start] = last_vis + vis_init_[:, self.S // 2 :, :p_idx_start] = last_vis[:, -1].repeat( + 1, self.S // 2, 1, 1 + ) + + coords, attns, vis, __, Rigid_ln = self.forward_iteration( + fmapXY=fmapXY, + fmapYZ=fmapYZ, + fmapXZ=fmapXZ, + coords_init=coords_init_[:, :, :p_idx_end], + feat_init=feat_init[:, :, :p_idx_end], + vis_init=vis_init_[:, :, :p_idx_end], + track_mask=track_mask[:, w_idx_start : w_idx_start + self.S, :p_idx_end], + iters=iters, + intrs_S=self.intrs[:, w_idx_start : w_idx_start + self.S], + ) + + Rigid_ln_total+=Rigid_ln + + if is_train: + vis_predictions.append(torch.sigmoid(vis[:, :S_local])) + coord_predictions.append([coord[:, :S_local] for coord in coords]) + attn_predictions.append(attns) + + self.traj_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = coords[-1][:, :S_local] + self.vis_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = vis[:, :S_local] + + track_mask[:, : w_idx_start + self.S, :p_idx_end] = 0.0 + w_idx_start = w_idx_start + self.S // 2 + + p_idx_start = p_idx_end + + self.traj_e = self.traj_e[:, :, inv_sort_inds] + self.vis_e = self.vis_e[:, :, inv_sort_inds] + + self.vis_e = torch.sigmoid(self.vis_e) + train_data = ( + (vis_predictions, coord_predictions, attn_predictions, + p_idx_end_list, sort_inds, Rigid_ln_total) + ) + if self.is_train: + return self.traj_e, feat_init, self.vis_e, train_data + else: + return self.traj_e, feat_init, self.vis_e + diff --git a/models/spatracker/models/core/spatracker/unet.py b/models/spatracker/models/core/spatracker/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..715ae32b5cc1de01dfb749906a78c3c87bce821e --- /dev/null +++ b/models/spatracker/models/core/spatracker/unet.py @@ -0,0 +1,258 @@ +''' +Codes are from: +https://github.com/jaxony/unet-pytorch/blob/master/model.py +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from collections import OrderedDict +from torch.nn import init +import numpy as np + +def conv3x3(in_channels, out_channels, stride=1, + padding=1, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=bias, + groups=groups) + +def upconv2x2(in_channels, out_channels, mode='transpose'): + if mode == 'transpose': + return nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=2, + stride=2) + else: + # out_channels is always going to be the same + # as in_channels + return nn.Sequential( + nn.Upsample(mode='bilinear', scale_factor=2), + conv1x1(in_channels, out_channels)) + +def conv1x1(in_channels, out_channels, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + groups=groups, + stride=1) + + +class DownConv(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 MaxPool. + A ReLU activation follows each convolution. + """ + def __init__(self, in_channels, out_channels, pooling=True): + super(DownConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.pooling = pooling + + self.conv1 = conv3x3(self.in_channels, self.out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + if self.pooling: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + before_pool = x + if self.pooling: + x = self.pool(x) + return x, before_pool + + +class UpConv(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 UpConvolution. + A ReLU activation follows each convolution. + """ + def __init__(self, in_channels, out_channels, + merge_mode='concat', up_mode='transpose'): + super(UpConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.merge_mode = merge_mode + self.up_mode = up_mode + + self.upconv = upconv2x2(self.in_channels, self.out_channels, + mode=self.up_mode) + + if self.merge_mode == 'concat': + self.conv1 = conv3x3( + 2*self.out_channels, self.out_channels) + else: + # num of input channels to conv2 is same + self.conv1 = conv3x3(self.out_channels, self.out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + + def forward(self, from_down, from_up): + """ Forward pass + Arguments: + from_down: tensor from the encoder pathway + from_up: upconv'd tensor from the decoder pathway + """ + from_up = self.upconv(from_up) + if self.merge_mode == 'concat': + x = torch.cat((from_up, from_down), 1) + else: + x = from_up + from_down + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + return x + + +class UNet(nn.Module): + """ `UNet` class is based on https://arxiv.org/abs/1505.04597 + + The U-Net is a convolutional encoder-decoder neural network. + Contextual spatial information (from the decoding, + expansive pathway) about an input tensor is merged with + information representing the localization of details + (from the encoding, compressive pathway). + + Modifications to the original paper: + (1) padding is used in 3x3 convolutions to prevent loss + of border pixels + (2) merging outputs does not require cropping due to (1) + (3) residual connections can be used by specifying + UNet(merge_mode='add') + (4) if non-parametric upsampling is used in the decoder + pathway (specified by upmode='upsample'), then an + additional 1x1 2d convolution occurs after upsampling + to reduce channel dimensionality by a factor of 2. + This channel halving happens with the convolution in + the tranpose convolution (specified by upmode='transpose') + """ + + def __init__(self, num_classes, in_channels=3, depth=5, + start_filts=64, up_mode='transpose', + merge_mode='concat', **kwargs): + """ + Arguments: + in_channels: int, number of channels in the input tensor. + Default is 3 for RGB images. + depth: int, number of MaxPools in the U-Net. + start_filts: int, number of convolutional filters for the + first conv. + up_mode: string, type of upconvolution. Choices: 'transpose' + for transpose convolution or 'upsample' for nearest neighbour + upsampling. + """ + super(UNet, self).__init__() + + if up_mode in ('transpose', 'upsample'): + self.up_mode = up_mode + else: + raise ValueError("\"{}\" is not a valid mode for " + "upsampling. Only \"transpose\" and " + "\"upsample\" are allowed.".format(up_mode)) + + if merge_mode in ('concat', 'add'): + self.merge_mode = merge_mode + else: + raise ValueError("\"{}\" is not a valid mode for" + "merging up and down paths. " + "Only \"concat\" and " + "\"add\" are allowed.".format(up_mode)) + + # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' + if self.up_mode == 'upsample' and self.merge_mode == 'add': + raise ValueError("up_mode \"upsample\" is incompatible " + "with merge_mode \"add\" at the moment " + "because it doesn't make sense to use " + "nearest neighbour to reduce " + "depth channels (by half).") + + self.num_classes = num_classes + self.in_channels = in_channels + self.start_filts = start_filts + self.depth = depth + + self.down_convs = [] + self.up_convs = [] + + # create the encoder pathway and add to a list + for i in range(depth): + ins = self.in_channels if i == 0 else outs + outs = self.start_filts*(2**i) + pooling = True if i < depth-1 else False + + down_conv = DownConv(ins, outs, pooling=pooling) + self.down_convs.append(down_conv) + + # create the decoder pathway and add to a list + # - careful! decoding only requires depth-1 blocks + for i in range(depth-1): + ins = outs + outs = ins // 2 + up_conv = UpConv(ins, outs, up_mode=up_mode, + merge_mode=merge_mode) + self.up_convs.append(up_conv) + + # add the list of modules to current module + self.down_convs = nn.ModuleList(self.down_convs) + self.up_convs = nn.ModuleList(self.up_convs) + + self.conv_final = conv1x1(outs, self.num_classes) + + self.reset_params() + + @staticmethod + def weight_init(m): + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight) + init.constant_(m.bias, 0) + + + def reset_params(self): + for i, m in enumerate(self.modules()): + self.weight_init(m) + + + def forward(self, x): + encoder_outs = [] + # encoder pathway, save outputs for merging + for i, module in enumerate(self.down_convs): + x, before_pool = module(x) + encoder_outs.append(before_pool) + for i, module in enumerate(self.up_convs): + before_pool = encoder_outs[-(i+2)] + x = module(before_pool, x) + + # No softmax is used. This means you need to use + # nn.CrossEntropyLoss is your training script, + # as this module includes a softmax already. + x = self.conv_final(x) + return x + +if __name__ == "__main__": + """ + testing + """ + model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32) + print(model) + print(sum(p.numel() for p in model.parameters())) + + reso = 176 + x = np.zeros((1, 1, reso, reso)) + x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan + x = torch.FloatTensor(x) + + out = model(x) + print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso))) + + # loss = torch.sum(out) + # loss.backward() \ No newline at end of file diff --git a/models/spatracker/models/core/spatracker/vit/__init__.py b/models/spatracker/models/core/spatracker/vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/spatracker/models/core/spatracker/vit/common.py b/models/spatracker/models/core/spatracker/vit/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d67662c6a517be28bf3b8d037056a6e376cf7a7e --- /dev/null +++ b/models/spatracker/models/core/spatracker/vit/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x \ No newline at end of file diff --git a/models/spatracker/models/core/spatracker/vit/encoder.py b/models/spatracker/models/core/spatracker/vit/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c957ea1f7df5c768f47d5a8f4b46f154217913ca --- /dev/null +++ b/models/spatracker/models/core/spatracker/vit/encoder.py @@ -0,0 +1,397 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from models.spatracker.models.core.spatracker.vit.common import ( + LayerNorm2d, MLPBlock +) + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x \ No newline at end of file diff --git a/models/spatracker/predictor.py b/models/spatracker/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4dacbb36c205ac9aebf578f9294099e1cca35860 --- /dev/null +++ b/models/spatracker/predictor.py @@ -0,0 +1,284 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import time + +from tqdm import tqdm +from models.spatracker.models.core.spatracker.spatracker import get_points_on_a_grid +from models.spatracker.models.core.model_utils import smart_cat +from models.spatracker.models.build_spatracker import ( + build_spatracker, +) +from models.spatracker.models.core.model_utils import ( + meshgrid2d, bilinear_sample2d, smart_cat +) + + +class SpaTrackerPredictor(torch.nn.Module): + def __init__( + self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", + interp_shape=(384, 512), + seq_length=16 + ): + super().__init__() + self.interp_shape = interp_shape + self.support_grid_size = 6 + model = build_spatracker(checkpoint, seq_length=seq_length) + + self.model = model + self.model.eval() + + @torch.no_grad() + def forward( + self, + video, # (1, T, 3, H, W) + video_depth = None, # (T, 1, H, W) + # input prompt types: + # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. + # *backward_tracking=True* will compute tracks in both directions. + # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. + # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask. + # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. + queries: torch.Tensor = None, + segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W) + grid_size: int = 0, + grid_query_frame: int = 0, # only for dense and regular grid tracks + backward_tracking: bool = False, + depth_predictor=None, + wind_length: int = 8, + progressive_tracking: bool = False, + ): + if queries is None and grid_size == 0: + tracks, visibilities, T_Firsts = self._compute_dense_tracks( + video, + grid_query_frame=grid_query_frame, + backward_tracking=backward_tracking, + video_depth=video_depth, + depth_predictor=depth_predictor, + wind_length=wind_length, + ) + else: + tracks, visibilities, T_Firsts = self._compute_sparse_tracks( + video, + queries, + segm_mask, + grid_size, + add_support_grid=False, #(grid_size == 0 or segm_mask is not None), + grid_query_frame=grid_query_frame, + backward_tracking=backward_tracking, + video_depth=video_depth, + depth_predictor=depth_predictor, + wind_length=wind_length, + ) + + return tracks, visibilities, T_Firsts + + def _compute_dense_tracks( + self, video, grid_query_frame, grid_size=30, backward_tracking=False, + depth_predictor=None, video_depth=None, wind_length=8 + ): + *_, H, W = video.shape + grid_step = W // grid_size + grid_width = W // grid_step + grid_height = H // grid_step + tracks = visibilities = T_Firsts = None + grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) + grid_pts[0, :, 0] = grid_query_frame + for offset in tqdm(range(grid_step * grid_step)): + ox = offset % grid_step + oy = offset // grid_step + grid_pts[0, :, 1] = ( + torch.arange(grid_width).repeat(grid_height) * grid_step + ox + ) + grid_pts[0, :, 2] = ( + torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy + ) + tracks_step, visibilities_step, T_First_step = self._compute_sparse_tracks( + video=video, + queries=grid_pts, + backward_tracking=backward_tracking, + wind_length=wind_length, + video_depth=video_depth, + depth_predictor=depth_predictor, + ) + tracks = smart_cat(tracks, tracks_step, dim=2) + visibilities = smart_cat(visibilities, visibilities_step, dim=2) + T_Firsts = smart_cat(T_Firsts, T_First_step, dim=1) + + + return tracks, visibilities, T_Firsts + + def _compute_sparse_tracks( + self, + video, + queries, + segm_mask=None, + grid_size=0, + add_support_grid=False, + grid_query_frame=0, + backward_tracking=False, + depth_predictor=None, + video_depth=None, + wind_length=8, + ): + B, T, C, H, W = video.shape + assert B == 1 + + video = video.reshape(B * T, C, H, W) + video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear") + video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) + + if queries is not None: + queries = queries.clone() + B, N, D = queries.shape + assert D == 3 + queries[:, :, 1] *= self.interp_shape[1] / W + queries[:, :, 2] *= self.interp_shape[0] / H + elif grid_size > 0: + grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device) + if segm_mask is not None: + segm_mask = F.interpolate( + segm_mask, tuple(self.interp_shape), mode="nearest" + ) + point_mask = segm_mask[0, 0][ + (grid_pts[0, :, 1]).round().long().cpu(), + (grid_pts[0, :, 0]).round().long().cpu(), + ].bool() + grid_pts_extra = grid_pts[:, point_mask] + else: + grid_pts_extra = None + if grid_pts_extra is not None: + total_num = int(grid_pts_extra.shape[1]) + total_num = min(800, total_num) + pick_idx = torch.randperm(grid_pts_extra.shape[1])[:total_num] + grid_pts_extra = grid_pts_extra[:, pick_idx] + queries_extra = torch.cat( + [ + torch.ones_like(grid_pts_extra[:, :, :1]) * grid_query_frame, + grid_pts_extra, + ], + dim=2, + ) + + queries = torch.cat( + [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], + dim=2, + ) + + if add_support_grid: + grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device) + grid_pts = torch.cat( + [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 + ) + queries = torch.cat([queries, grid_pts], dim=1) + + ## ----------- estimate the video depth -----------## + if video_depth is None: + with torch.no_grad(): + if video[0].shape[0]>30: + vidDepths = [] + for i in range(video[0].shape[0]//30+1): + if (i+1)*30 > video[0].shape[0]: + end_idx = video[0].shape[0] + else: + end_idx = (i+1)*30 + if end_idx == i*30: + break + video_ = video[0][i*30:end_idx] + vidDepths.append(depth_predictor.infer(video_/255)) + + video_depth = torch.cat(vidDepths, dim=0) + + else: + video_depth = depth_predictor.infer(video[0]/255) + video_depth = F.interpolate(video_depth, + tuple(self.interp_shape), mode="nearest") + + # from PIL import Image + # import numpy + # depth_frame = video_depth[0].detach().cpu() + # depth_frame = depth_frame.squeeze(0) + # print(depth_frame) + # print(depth_frame.min(), depth_frame.max()) + # depth_img = (depth_frame * 255).numpy().astype(numpy.uint8) + # depth_img = Image.fromarray(depth_img, mode='L') + # depth_img.save('outputs/depth_map.png') + + # frame = video[0, 0].detach().cpu() + # frame = frame.permute(1, 2, 0) + # frame = (frame * 255).numpy().astype(numpy.uint8) + # frame = Image.fromarray(frame, mode='RGB') + # frame.save('outputs/frame.png') + + depths = video_depth + rgbds = torch.cat([video, depths[None,...]], dim=2) + # get the 3D queries + depth_interp=[] + for i in range(queries.shape[1]): + depth_interp_i = bilinear_sample2d(video_depth[queries[:, i:i+1, 0].long()], + queries[:, i:i+1, 1], queries[:, i:i+1, 2]) + depth_interp.append(depth_interp_i) + + depth_interp = torch.cat(depth_interp, dim=1) + queries = smart_cat(queries, depth_interp,dim=-1) + + #NOTE: free the memory of depth_predictor + del depth_predictor + torch.cuda.empty_cache() + t0 = time.time() + tracks, __, visibilities = self.model(rgbds=rgbds, queries=queries, iters=6, wind_S=wind_length) + print("Time taken for inference: ", time.time()-t0) + + if backward_tracking: + tracks, visibilities = self._compute_backward_tracks( + rgbds, queries, tracks, visibilities + ) + if add_support_grid: + queries[:, -self.support_grid_size ** 2 :, 0] = T - 1 + if add_support_grid: + tracks = tracks[:, :, : -self.support_grid_size ** 2] + visibilities = visibilities[:, :, : -self.support_grid_size ** 2] + thr = 0.9 + visibilities = visibilities > thr + + # correct query-point predictions + # see https://github.com/facebookresearch/co-tracker/issues/28 + + # TODO: batchify + for i in range(len(queries)): + queries_t = queries[i, :tracks.size(2), 0].to(torch.int64) + arange = torch.arange(0, len(queries_t)) + + # overwrite the predictions with the query points + tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:] + + # correct visibilities, the query points should be visible + visibilities[i, queries_t, arange] = True + + T_First = queries[..., :tracks.size(2), 0].to(torch.uint8) + tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) + tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) + return tracks, visibilities, T_First + + def _compute_backward_tracks(self, video, queries, tracks, visibilities): + inv_video = video.flip(1).clone() + inv_queries = queries.clone() + inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 + + inv_tracks, __, inv_visibilities = self.model( + rgbds=inv_video, queries=queries, iters=6 + ) + + inv_tracks = inv_tracks.flip(1) + inv_visibilities = inv_visibilities.flip(1) + + mask = tracks == 0 + + tracks[mask] = inv_tracks[mask] + visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] + return tracks, visibilities \ No newline at end of file diff --git a/models/spatracker/utils/__init__.py b/models/spatracker/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/models/spatracker/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/spatracker/utils/basic.py b/models/spatracker/utils/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4a15ecf2c566fe9216f2622ff21c576f0d43f7 --- /dev/null +++ b/models/spatracker/utils/basic.py @@ -0,0 +1,397 @@ +import os +import numpy as np +from os.path import isfile +import torch +import torch.nn.functional as F +EPS = 1e-6 +import copy + +def sub2ind(height, width, y, x): + return y*width + x + +def ind2sub(height, width, ind): + y = ind // width + x = ind % width + return y, x + +def get_lr_str(lr): + lrn = "%.1e" % lr # e.g., 5.0e-04 + lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4 + return lrn + +def strnum(x): + s = '%g' % x + if '.' in s: + if x < 1.0: + s = s[s.index('.'):] + s = s[:min(len(s),4)] + return s + +def assert_same_shape(t1, t2): + for (x, y) in zip(list(t1.shape), list(t2.shape)): + assert(x==y) + +def print_stats(name, tensor): + shape = tensor.shape + tensor = tensor.detach().cpu().numpy() + print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape) + +def print_stats_py(name, tensor): + shape = tensor.shape + print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape) + +def print_(name, tensor): + tensor = tensor.detach().cpu().numpy() + print(name, tensor, tensor.shape) + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +def normalize_single(d): + # d is a whatever shape torch tensor + dmin = torch.min(d) + dmax = torch.max(d) + d = (d-dmin)/(EPS+(dmax-dmin)) + return d + +def normalize(d): + # d is B x whatever. normalize within each element of the batch + out = torch.zeros(d.size()) + if d.is_cuda: + out = out.cuda() + B = list(d.size())[0] + for b in list(range(B)): + out[b] = normalize_single(d[b]) + return out + +def hard_argmax2d(tensor): + B, C, Y, X = list(tensor.shape) + assert(C==1) + + # flatten the Tensor along the height and width axes + flat_tensor = tensor.reshape(B, -1) + # argmax of the flat tensor + argmax = torch.argmax(flat_tensor, dim=1) + + # convert the indices into 2d coordinates + argmax_y = torch.floor(argmax / X) # row + argmax_x = argmax % X # col + + argmax_y = argmax_y.reshape(B) + argmax_x = argmax_x.reshape(B) + return argmax_y, argmax_x + +def argmax2d(heat, hard=True): + B, C, Y, X = list(heat.shape) + assert(C==1) + + if hard: + # hard argmax + loc_y, loc_x = hard_argmax2d(heat) + loc_y = loc_y.float() + loc_x = loc_x.float() + else: + heat = heat.reshape(B, Y*X) + prob = torch.nn.functional.softmax(heat, dim=1) + + grid_y, grid_x = meshgrid2d(B, Y, X) + + grid_y = grid_y.reshape(B, -1) + grid_x = grid_x.reshape(B, -1) + + loc_y = torch.sum(grid_y*prob, dim=1) + loc_x = torch.sum(grid_x*prob, dim=1) + # these are B + + return loc_y, loc_x + +def reduce_masked_mean(x, mask, dim=None, keepdim=False): + # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting + # returns shape-1 + # axis can be a list of axes + for (a,b) in zip(x.size(), mask.size()): + # if not b==1: + assert(a==b) # some shape mismatch! + # assert(x.size() == mask.size()) + prod = x*mask + if dim is None: + numer = torch.sum(prod) + denom = EPS+torch.sum(mask) + else: + numer = torch.sum(prod, dim=dim, keepdim=keepdim) + denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim) + + mean = numer/denom + return mean + +def reduce_masked_median(x, mask, keep_batch=False): + # x and mask are the same shape + assert(x.size() == mask.size()) + device = x.device + + B = list(x.shape)[0] + x = x.detach().cpu().numpy() + mask = mask.detach().cpu().numpy() + + if keep_batch: + x = np.reshape(x, [B, -1]) + mask = np.reshape(mask, [B, -1]) + meds = np.zeros([B], np.float32) + for b in list(range(B)): + xb = x[b] + mb = mask[b] + if np.sum(mb) > 0: + xb = xb[mb > 0] + meds[b] = np.median(xb) + else: + meds[b] = np.nan + meds = torch.from_numpy(meds).to(device) + return meds.float() + else: + x = np.reshape(x, [-1]) + mask = np.reshape(mask, [-1]) + if np.sum(mask) > 0: + x = x[mask > 0] + med = np.median(x) + else: + med = np.nan + med = np.array([med], np.float32) + med = torch.from_numpy(med).to(device) + return med.float() + +def pack_seqdim(tensor, B): + shapelist = list(tensor.shape) + B_, S = shapelist[:2] + assert(B==B_) + otherdims = shapelist[2:] + tensor = torch.reshape(tensor, [B*S]+otherdims) + return tensor + +def unpack_seqdim(tensor, B): + shapelist = list(tensor.shape) + BS = shapelist[0] + assert(BS%B==0) + otherdims = shapelist[1:] + S = int(BS/B) + tensor = torch.reshape(tensor, [B,S]+otherdims) + return tensor + +def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False): + # returns a meshgrid sized B x Y x X + + grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device)) + grid_y = torch.reshape(grid_y, [1, Y, 1]) + grid_y = grid_y.repeat(B, 1, X) + + grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device)) + grid_x = torch.reshape(grid_x, [1, 1, X]) + grid_x = grid_x.repeat(B, Y, 1) + + if norm: + grid_y, grid_x = normalize_grid2d( + grid_y, grid_x, Y, X) + + if stack: + # note we stack in xy order + # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) + if on_chans: + grid = torch.stack([grid_x, grid_y], dim=1) + else: + grid = torch.stack([grid_x, grid_y], dim=-1) + return grid + else: + return grid_y, grid_x + +def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'): + # returns a meshgrid sized B x Z x Y x X + + grid_z = torch.linspace(0.0, Z-1, Z, device=device) + grid_z = torch.reshape(grid_z, [1, Z, 1, 1]) + grid_z = grid_z.repeat(B, 1, Y, X) + + grid_y = torch.linspace(0.0, Y-1, Y, device=device) + grid_y = torch.reshape(grid_y, [1, 1, Y, 1]) + grid_y = grid_y.repeat(B, Z, 1, X) + + grid_x = torch.linspace(0.0, X-1, X, device=device) + grid_x = torch.reshape(grid_x, [1, 1, 1, X]) + grid_x = grid_x.repeat(B, Z, Y, 1) + + # if cuda: + # grid_z = grid_z.cuda() + # grid_y = grid_y.cuda() + # grid_x = grid_x.cuda() + + if norm: + grid_z, grid_y, grid_x = normalize_grid3d( + grid_z, grid_y, grid_x, Z, Y, X) + + if stack: + # note we stack in xyz order + # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) + grid = torch.stack([grid_x, grid_y, grid_z], dim=-1) + return grid + else: + return grid_z, grid_y, grid_x + +def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True): + # make things in [-1,1] + grid_y = 2.0*(grid_y / float(Y-1)) - 1.0 + grid_x = 2.0*(grid_x / float(X-1)) - 1.0 + + if clamp_extreme: + grid_y = torch.clamp(grid_y, min=-2.0, max=2.0) + grid_x = torch.clamp(grid_x, min=-2.0, max=2.0) + + return grid_y, grid_x + +def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True): + # make things in [-1,1] + grid_z = 2.0*(grid_z / float(Z-1)) - 1.0 + grid_y = 2.0*(grid_y / float(Y-1)) - 1.0 + grid_x = 2.0*(grid_x / float(X-1)) - 1.0 + + if clamp_extreme: + grid_z = torch.clamp(grid_z, min=-2.0, max=2.0) + grid_y = torch.clamp(grid_y, min=-2.0, max=2.0) + grid_x = torch.clamp(grid_x, min=-2.0, max=2.0) + + return grid_z, grid_y, grid_x + +def gridcloud2d(B, Y, X, norm=False, device='cuda'): + # we want to sample for each location in the grid + grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device) + x = torch.reshape(grid_x, [B, -1]) + y = torch.reshape(grid_y, [B, -1]) + # these are B x N + xy = torch.stack([x, y], dim=2) + # this is B x N x 2 + return xy + +def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'): + # we want to sample for each location in the grid + grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device) + x = torch.reshape(grid_x, [B, -1]) + y = torch.reshape(grid_y, [B, -1]) + z = torch.reshape(grid_z, [B, -1]) + # these are B x N + xyz = torch.stack([x, y, z], dim=2) + # this is B x N x 3 + return xyz + +import re +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def normalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin / float(H) + ymax = ymax / float(H) + xmin = xmin / float(W) + xmax = xmax / float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin * float(H) + ymax = ymax * float(H) + xmin = xmin * float(W) + xmax = xmax * float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_box2d(box2d, H, W): + return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def normalize_box2d(box2d, H, W): + return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False): + C = channels + xy_grid = gridcloud2d(C, kernel_size, kernel_size) # C x N x 2 + + mean = (kernel_size - 1)/2.0 + variance = sigma**2.0 + + gaussian_kernel = (1.0/(2.0*np.pi*variance)**1.5) * torch.exp(-torch.sum((xy_grid - mean)**2.0, dim=-1) / (2.0*variance)) # C X N + gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size) # C x 1 x 3 x 3 + kernel_sum = torch.sum(gaussian_kernel, dim=(2,3), keepdim=True) + + gaussian_kernel = gaussian_kernel / kernel_sum # normalize + + if mid_one: + # normalize so that the middle element is 1 + maxval = gaussian_kernel[:,:,(kernel_size//2),(kernel_size//2)].reshape(C, 1, 1, 1) + gaussian_kernel = gaussian_kernel / maxval + + return gaussian_kernel + +def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False): + B, C, Z, X = input.shape + kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one) + if reflect_pad: + pad = (kernel_size - 1)//2 + out = F.pad(input, (pad, pad, pad, pad), mode='reflect') + out = F.conv2d(out, kernel, padding=0, groups=C) + else: + out = F.conv2d(input, kernel, padding=(kernel_size - 1)//2, groups=C) + return out + +def gradient2d(x, absolute=False, square=False, return_sum=False): + # x should be B x C x H x W + dh = x[:, :, 1:, :] - x[:, :, :-1, :] + dw = x[:, :, :, 1:] - x[:, :, :, :-1] + + zeros = torch.zeros_like(x) + zero_h = zeros[:, :, 0:1, :] + zero_w = zeros[:, :, :, 0:1] + dh = torch.cat([dh, zero_h], axis=2) + dw = torch.cat([dw, zero_w], axis=3) + if absolute: + dh = torch.abs(dh) + dw = torch.abs(dw) + if square: + dh = dh ** 2 + dw = dw ** 2 + if return_sum: + return dh+dw + else: + return dh, dw diff --git a/models/spatracker/utils/geom.py b/models/spatracker/utils/geom.py new file mode 100644 index 0000000000000000000000000000000000000000..486bfff4fa0d2ab677d64666899755bd6c7780eb --- /dev/null +++ b/models/spatracker/utils/geom.py @@ -0,0 +1,547 @@ +import torch +import models.spatracker.utils.basic +import numpy as np +import torchvision.ops as ops +from models.spatracker.utils.basic import print_ + +def matmul2(mat1, mat2): + return torch.matmul(mat1, mat2) + +def matmul3(mat1, mat2, mat3): + return torch.matmul(mat1, torch.matmul(mat2, mat3)) + +def eye_3x3(B, device='cuda'): + rt = torch.eye(3, device=torch.device(device)).view(1,3,3).repeat([B, 1, 1]) + return rt + +def eye_4x4(B, device='cuda'): + rt = torch.eye(4, device=torch.device(device)).view(1,4,4).repeat([B, 1, 1]) + return rt + +def safe_inverse(a): #parallel version + B, _, _ = list(a.shape) + inv = a.clone() + r_transpose = a[:, :3, :3].transpose(1,2) #inverse of rotation matrix + + inv[:, :3, :3] = r_transpose + inv[:, :3, 3:4] = -torch.matmul(r_transpose, a[:, :3, 3:4]) + + return inv + +def safe_inverse_single(a): + r, t = split_rt_single(a) + t = t.view(3,1) + r_transpose = r.t() + inv = torch.cat([r_transpose, -torch.matmul(r_transpose, t)], 1) + bottom_row = a[3:4, :] # this is [0, 0, 0, 1] + # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4) + inv = torch.cat([inv, bottom_row], 0) + return inv + +def split_intrinsics(K): + # K is B x 3 x 3 or B x 4 x 4 + fx = K[:,0,0] + fy = K[:,1,1] + x0 = K[:,0,2] + y0 = K[:,1,2] + return fx, fy, x0, y0 + +def apply_pix_T_cam(pix_T_cam, xyz): + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + B, N, C = list(xyz.shape) + assert(C==3) + + x, y, z = torch.unbind(xyz, axis=-1) + + fx = torch.reshape(fx, [B, 1]) + fy = torch.reshape(fy, [B, 1]) + x0 = torch.reshape(x0, [B, 1]) + y0 = torch.reshape(y0, [B, 1]) + + EPS = 1e-4 + z = torch.clamp(z, min=EPS) + x = (x*fx)/(z)+x0 + y = (y*fy)/(z)+y0 + xy = torch.stack([x, y], axis=-1) + return xy + +def apply_pix_T_cam_py(pix_T_cam, xyz): + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + B, N, C = list(xyz.shape) + assert(C==3) + + x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2] + + fx = np.reshape(fx, [B, 1]) + fy = np.reshape(fy, [B, 1]) + x0 = np.reshape(x0, [B, 1]) + y0 = np.reshape(y0, [B, 1]) + + EPS = 1e-4 + z = np.clip(z, EPS, None) + x = (x*fx)/(z)+x0 + y = (y*fy)/(z)+y0 + xy = np.stack([x, y], axis=-1) + return xy + +def get_camM_T_camXs(origin_T_camXs, ind=0): + B, S = list(origin_T_camXs.shape)[0:2] + camM_T_camXs = torch.zeros_like(origin_T_camXs) + for b in list(range(B)): + camM_T_origin = safe_inverse_single(origin_T_camXs[b,ind]) + for s in list(range(S)): + camM_T_camXs[b,s] = torch.matmul(camM_T_origin, origin_T_camXs[b,s]) + return camM_T_camXs + +def apply_4x4(RT, xyz): + B, N, _ = list(xyz.shape) + ones = torch.ones_like(xyz[:,:,0:1]) + xyz1 = torch.cat([xyz, ones], 2) + xyz1_t = torch.transpose(xyz1, 1, 2) + # this is B x 4 x N + xyz2_t = torch.matmul(RT, xyz1_t) + xyz2 = torch.transpose(xyz2_t, 1, 2) + xyz2 = xyz2[:,:,:3] + return xyz2 + +def apply_4x4_py(RT, xyz): + # print('RT', RT.shape) + B, N, _ = list(xyz.shape) + ones = np.ones_like(xyz[:,:,0:1]) + xyz1 = np.concatenate([xyz, ones], 2) + # print('xyz1', xyz1.shape) + xyz1_t = xyz1.transpose(0,2,1) + # print('xyz1_t', xyz1_t.shape) + # this is B x 4 x N + xyz2_t = np.matmul(RT, xyz1_t) + # print('xyz2_t', xyz2_t.shape) + xyz2 = xyz2_t.transpose(0,2,1) + # print('xyz2', xyz2.shape) + xyz2 = xyz2[:,:,:3] + return xyz2 + +def apply_3x3(RT, xy): + B, N, _ = list(xy.shape) + ones = torch.ones_like(xy[:,:,0:1]) + xy1 = torch.cat([xy, ones], 2) + xy1_t = torch.transpose(xy1, 1, 2) + # this is B x 4 x N + xy2_t = torch.matmul(RT, xy1_t) + xy2 = torch.transpose(xy2_t, 1, 2) + xy2 = xy2[:,:,:2] + return xy2 + +def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts): + ''' + Start with the center of the polygon at ctr_x, ctr_y, + Then creates the polygon by sampling points on a circle around the center. + Random noise is added by varying the angular spacing between sequential points, + and by varying the radial distance of each point from the centre. + + Params: + ctr_x, ctr_y - coordinates of the "centre" of the polygon + avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude. + irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts] + spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r] +pp num_verts + + Returns: + np.array [num_verts, 2] - CCW order. + ''' + # spikiness + spikiness = np.clip(spikiness, 0, 1) * avg_r + + # generate n angle steps + irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts + lower = (2*np.pi / num_verts) - irregularity + upper = (2*np.pi / num_verts) + irregularity + + # angle steps + angle_steps = np.random.uniform(lower, upper, num_verts) + sc = (2 * np.pi) / angle_steps.sum() + angle_steps *= sc + + # get all radii + angle = np.random.uniform(0, 2*np.pi) + radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r) + + # compute all points + points = [] + for i in range(num_verts): + x = ctr_x + radii[i] * np.cos(angle) + y = ctr_y + radii[i] * np.sin(angle) + points.append([x, y]) + angle += angle_steps[i] + + return np.array(points).astype(int) + + +def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05): + ''' + Params: + rot_min: rotation amount min + rot_max: rotation amount max + + tx_min: translation x min + tx_max: translation x max + + ty_min: translation y min + ty_max: translation y max + + sx_min: scaling x min + sx_max: scaling x max + + sy_min: scaling y min + sy_max: scaling y max + + shx_min: shear x min + shx_max: shear x max + + shy_min: shear y min + shy_max: shear y max + + Returns: + transformation matrix: (B, 3, 3) + ''' + # rotation + if rot_max - rot_min != 0: + rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B) + rot_amount = np.pi/180.0*rot_amount + else: + rot_amount = rot_min + rotation = np.zeros((B, 3, 3)) # B, 3, 3 + rotation[:, 2, 2] = 1 + rotation[:, 0, 0] = np.cos(rot_amount) + rotation[:, 0, 1] = -np.sin(rot_amount) + rotation[:, 1, 0] = np.sin(rot_amount) + rotation[:, 1, 1] = np.cos(rot_amount) + + # translation + translation = np.zeros((B, 3, 3)) # B, 3, 3 + translation[:, [0,1,2], [0,1,2]] = 1 + if (tx_max - tx_min) > 0: + trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B) + translation[:, 0, 2] = trans_x + # else: + # translation[:, 0, 2] = tx_max + if ty_max - ty_min != 0: + trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B) + translation[:, 1, 2] = trans_y + # else: + # translation[:, 1, 2] = ty_max + + # scaling + scaling = np.zeros((B, 3, 3)) # B, 3, 3 + scaling[:, [0,1,2], [0,1,2]] = 1 + if (sx_max - sx_min) > 0: + scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B) + scaling[:, 0, 0] = scale_x + # else: + # scaling[:, 0, 0] = sx_max + if (sy_max - sy_min) > 0: + scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B) + scaling[:, 1, 1] = scale_y + # else: + # scaling[:, 1, 1] = sy_max + + # shear + shear = np.zeros((B, 3, 3)) # B, 3, 3 + shear[:, [0,1,2], [0,1,2]] = 1 + if (shx_max - shx_min) > 0: + shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B) + shear[:, 0, 1] = shear_x + # else: + # shear[:, 0, 1] = shx_max + if (shy_max - shy_min) > 0: + shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B) + shear[:, 1, 0] = shear_y + # else: + # shear[:, 1, 0] = shy_max + + # compose all those + rt = np.einsum("ijk,ikl->ijl", rotation, translation) + ss = np.einsum("ijk,ikl->ijl", scaling, shear) + trans = np.einsum("ijk,ikl->ijl", rt, ss) + + return trans + +def get_centroid_from_box2d(box2d): + ymin = box2d[:,0] + xmin = box2d[:,1] + ymax = box2d[:,2] + xmax = box2d[:,3] + x = (xmin+xmax)/2.0 + y = (ymin+ymax)/2.0 + return y, x + +def normalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin / float(H) + ymax = ymax / float(H) + xmin = xmin / float(W) + xmax = xmax / float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin * float(H) + ymax = ymax * float(H) + xmin = xmin * float(W) + xmax = xmax * float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_box2d(box2d, H, W): + return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def normalize_box2d(box2d, H, W): + return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def get_size_from_box2d(box2d): + ymin = box2d[:,0] + xmin = box2d[:,1] + ymax = box2d[:,2] + xmax = box2d[:,3] + height = ymax-ymin + width = xmax-xmin + return height, width + +def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False): + B, C, H, W = im.shape + B2, N, D = boxlist.shape + assert(B==B2) + assert(D==4) + # PH, PW is the size to resize to + + # output is B,N,C,PH,PW + + # pt wants xy xy, unnormalized + if boxlist_is_normalized: + boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W) + else: + boxlist_unnorm = boxlist + + ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2) + # boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1) + boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2) + # we want a B-len list of K x 4 arrays + + # print('im', im.shape) + # print('boxlist', boxlist.shape) + # print('boxlist_pt', boxlist_pt.shape) + + # boxlist_pt = list(boxlist_pt.unbind(0)) + + crops = [] + for b in range(B): + crops_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW)) + crops.append(crops_b) + # # crops = im + + # print('crops', crops.shape) + # crops = crops.reshape(B,N,C,PH,PW) + + + # crops = [] + # for b in range(B): + # crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW)) + # print('crop_b', crop_b.shape) + # crops.append(crop_b) + crops = torch.stack(crops, dim=0) + + # print('crops', crops.shape) + # boxlist_list = boxlist_pt.unbind(0) + # print('rgb_crop', rgb_crop.shape) + + return crops + + +# def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True): +# # cy,cx are both B,N +# ymin = cy - h/2 +# ymax = cy + h/2 +# xmin = cx - w/2 +# xmax = cx + w/2 + +# box = torch.stack([ymin, xmin, ymax, xmax], dim=-1) +# if clip: +# box = torch.clamp(box, 0, 1) +# return box + + +def get_boxlist_from_centroid_and_size(cy, cx, h, w):#, clip=False): + # cy,cx are the same shape + ymin = cy - h/2 + ymax = cy + h/2 + xmin = cx - w/2 + xmax = cx + w/2 + + # if clip: + # ymin = torch.clamp(ymin, 0, H-1) + # ymax = torch.clamp(ymax, 0, H-1) + # xmin = torch.clamp(xmin, 0, W-1) + # xmax = torch.clamp(xmax, 0, W-1) + + box = torch.stack([ymin, xmin, ymax, xmax], dim=-1) + return box + + +def get_box2d_from_mask(mask, normalize=False): + # mask is B, 1, H, W + + B, C, H, W = mask.shape + assert(C==1) + xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device) # B, H*W, 2 + + box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device) + for b in range(B): + xy_b = xy[b] # H*W, 2 + mask_b = mask[b].reshape(H*W) + xy_ = xy_b[mask_b > 0] + x_ = xy_[:,0] + y_ = xy_[:,1] + ymin = torch.min(y_) + ymax = torch.max(y_) + xmin = torch.min(x_) + xmax = torch.max(x_) + box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0) + if normalize: + box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1) + return box + +def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0): + # box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords + # ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1) + # H, W is the original size of the image + # mult_padding is relative to object size in pixels + + # i assume we're rendering an image the same size as the original (H, W) + + if not mult_padding==1.0: + y, x = get_centroid_from_box2d(box2d) + h, w = get_size_from_box2d(box2d) + box2d = get_box2d_from_centroid_and_size( + y, x, h*mult_padding, w*mult_padding, clip=False) + + if use_image_aspect_ratio: + h, w = get_size_from_box2d(box2d) + y, x = get_centroid_from_box2d(box2d) + + # note h,w are relative right now + # we need to undo this, to see the real ratio + + h = h*float(H) + w = w*float(W) + box_ratio = h/w + im_ratio = H/float(W) + + # print('box_ratio:', box_ratio) + # print('im_ratio:', im_ratio) + + if box_ratio >= im_ratio: + w = h/im_ratio + # print('setting w:', h/im_ratio) + else: + h = w*im_ratio + # print('setting h:', w*im_ratio) + + box2d = get_box2d_from_centroid_and_size( + y, x, h/float(H), w/float(W), clip=False) + + assert(h > 1e-4) + assert(w > 1e-4) + + ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1) + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # the topleft of the new image will now have a different offset from the center of projection + + new_x0 = x0 - xmin*W + new_y0 = y0 - ymin*H + + pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0) + # this alone will give me an image in original resolution, + # with its topleft at the box corner + + box_h, box_w = get_size_from_box2d(box2d) + # these are normalized, and shaped B. (e.g., [0.4], [0.3]) + + # we are going to scale the image by the inverse of this, + # since we are zooming into this area + + sy = 1./box_h + sx = 1./box_w + + pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy) + return pix_T_cam, box2d + +def pixels2camera(x,y,z,fx,fy,x0,y0): + # x and y are locations in pixel coordinates, z is a depth in meters + # they can be images or pointclouds + # fx, fy, x0, y0 are camera intrinsics + # returns xyz, sized B x N x 3 + + B = x.shape[0] + + fx = torch.reshape(fx, [B,1]) + fy = torch.reshape(fy, [B,1]) + x0 = torch.reshape(x0, [B,1]) + y0 = torch.reshape(y0, [B,1]) + + x = torch.reshape(x, [B,-1]) + y = torch.reshape(y, [B,-1]) + z = torch.reshape(z, [B,-1]) + + # unproject + x = (z/fx)*(x-x0) + y = (z/fy)*(y-y0) + + xyz = torch.stack([x,y,z], dim=2) + # B x N x 3 + return xyz + +def camera2pixels(xyz, pix_T_cam): + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + x, y, z = torch.unbind(xyz, dim=-1) + B = list(z.shape)[0] + + fx = torch.reshape(fx, [B,1]) + fy = torch.reshape(fy, [B,1]) + x0 = torch.reshape(x0, [B,1]) + y0 = torch.reshape(y0, [B,1]) + x = torch.reshape(x, [B,-1]) + y = torch.reshape(y, [B,-1]) + z = torch.reshape(z, [B,-1]) + + EPS = 1e-4 + z = torch.clamp(z, min=EPS) + x = (x*fx)/z + x0 + y = (y*fy)/z + y0 + xy = torch.stack([x, y], dim=-1) + return xy + +def depth2pointcloud(z, pix_T_cam): + B, C, H, W = list(z.shape) + device = z.device + y, x = utils.basic.meshgrid2d(B, H, W, device=device) + z = torch.reshape(z, [B, H, W]) + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + xyz = pixels2camera(x, y, z, fx, fy, x0, y0) + return xyz diff --git a/models/spatracker/utils/improc.py b/models/spatracker/utils/improc.py new file mode 100644 index 0000000000000000000000000000000000000000..364daefa705df0b56e7df13e338bc2fabadd1fed --- /dev/null +++ b/models/spatracker/utils/improc.py @@ -0,0 +1,1447 @@ +import torch +import numpy as np +import models.spatracker.utils.basic +from sklearn.decomposition import PCA +from matplotlib import cm +import matplotlib.pyplot as plt +import cv2 +import torch.nn.functional as F +import torchvision +EPS = 1e-6 + +from skimage.color import ( + rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb, + rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb) + +def _convert(input_, type_): + return { + 'float': input_.float(), + 'double': input_.double(), + }.get(type_, input_) + +def _generic_transform_sk_3d(transform, in_type='', out_type=''): + def apply_transform_individual(input_): + device = input_.device + input_ = input_.cpu() + input_ = _convert(input_, in_type) + + input_ = input_.permute(1, 2, 0).detach().numpy() + transformed = transform(input_) + output = torch.from_numpy(transformed).float().permute(2, 0, 1) + output = _convert(output, out_type) + return output.to(device) + + def apply_transform(input_): + to_stack = [] + for image in input_: + to_stack.append(apply_transform_individual(image)) + return torch.stack(to_stack) + return apply_transform + +hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb) + +def preprocess_color_tf(x): + import tensorflow as tf + return tf.cast(x,tf.float32) * 1./255 - 0.5 + +def preprocess_color(x): + if isinstance(x, np.ndarray): + return x.astype(np.float32) * 1./255 - 0.5 + else: + return x.float() * 1./255 - 0.5 + +def pca_embed(emb, keep, valid=None): + ## emb -- [S,H/2,W/2,C] + ## keep is the number of principal components to keep + ## Helper function for reduce_emb. + emb = emb + EPS + #emb is B x C x H x W + emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C + + if valid: + valid = valid.cpu().detach().numpy().reshape((H*W)) + + emb_reduced = list() + + B, H, W, C = np.shape(emb) + for img in emb: + if np.isnan(img).any(): + emb_reduced.append(np.zeros([H, W, keep])) + continue + + pixels_kd = np.reshape(img, (H*W, C)) + + if valid: + pixels_kd_pca = pixels_kd[valid] + else: + pixels_kd_pca = pixels_kd + + P = PCA(keep) + P.fit(pixels_kd_pca) + + if valid: + pixels3d = P.transform(pixels_kd)*valid + else: + pixels3d = P.transform(pixels_kd) + + out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32) + if np.isnan(out_img).any(): + emb_reduced.append(np.zeros([H, W, keep])) + continue + + emb_reduced.append(out_img) + + emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32) + + return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2) + +def pca_embed_together(emb, keep): + ## emb -- [S,H/2,W/2,C] + ## keep is the number of principal components to keep + ## Helper function for reduce_emb. + emb = emb + EPS + #emb is B x C x H x W + emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C + + B, H, W, C = np.shape(emb) + if np.isnan(emb).any(): + return torch.zeros(B, keep, H, W) + + pixelskd = np.reshape(emb, (B*H*W, C)) + P = PCA(keep) + P.fit(pixelskd) + pixels3d = P.transform(pixelskd) + out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32) + + if np.isnan(out_img).any(): + return torch.zeros(B, keep, H, W) + + return torch.from_numpy(out_img).permute(0, 3, 1, 2) + +def reduce_emb(emb, valid=None, inbound=None, together=False): + ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2] + ## Reduce number of chans to 3 with PCA. For vis. + # S,H,W,C = emb.shape.as_list() + S, C, H, W = list(emb.size()) + keep = 3 + + if together: + reduced_emb = pca_embed_together(emb, keep) + else: + reduced_emb = pca_embed(emb, keep, valid) #not im + + reduced_emb = utils.basic.normalize(reduced_emb) - 0.5 + if inbound is not None: + emb_inbound = emb*inbound + else: + emb_inbound = None + + return reduced_emb, emb_inbound + +def get_feat_pca(feat, valid=None): + B, C, D, W = list(feat.size()) + # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function. + + pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True) + # pca is B x 3 x W x D + return pca + +def gif_and_tile(ims, just_gif=False): + S = len(ims) + # each im is B x H x W x C + # i want a gif in the left, and the tiled frames on the right + # for the gif tool, this means making a B x S x H x W tensor + # where the leftmost part is sequential and the rest is tiled + gif = torch.stack(ims, dim=1) + if just_gif: + return gif + til = torch.cat(ims, dim=2) + til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1) + im = torch.cat([gif, til], dim=3) + return im + +def back2color(i, blacken_zeros=False): + if blacken_zeros: + const = torch.tensor([-0.5]) + i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i) + return back2color(i) + else: + return ((i+0.5)*255).type(torch.ByteTensor) + +def convert_occ_to_height(occ, reduce_axis=3): + B, C, D, H, W = list(occ.shape) + assert(C==1) + # note that height increases DOWNWARD in the tensor + # (like pixel/camera coordinates) + + G = list(occ.shape)[reduce_axis] + values = torch.linspace(float(G), 1.0, steps=G, dtype=torch.float32, device=occ.device) + if reduce_axis==2: + # fro view + values = values.view(1, 1, G, 1, 1) + elif reduce_axis==3: + # top view + values = values.view(1, 1, 1, G, 1) + elif reduce_axis==4: + # lateral view + values = values.view(1, 1, 1, 1, G) + else: + assert(False) # you have to reduce one of the spatial dims (2-4) + values = torch.max(occ*values, dim=reduce_axis)[0]/float(G) + # values = values.view([B, C, D, W]) + return values + +def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False): + # xy is B x N x 2, containing float x and y coordinates of N things + # grid_xs and grid_ys are B x N x Y x X + + B, N, Y, X = list(grid_xs.shape) + + mu_x = xy[:,:,0].clone() + mu_y = xy[:,:,1].clone() + + x_valid = (mu_x>-0.5) & (mu_x-0.5) & (mu_y 0.5).float() + return prior + +def seq2color(im, norm=True, colormap='coolwarm'): + B, S, H, W = list(im.shape) + # S is sequential + + # prep a mask of the valid pixels, so we can blacken the invalids later + mask = torch.max(im, dim=1, keepdim=True)[0] + + # turn the S dim into an explicit sequence + coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S) + + # # increase the spacing from the center + # coeffs[:int(S/2)] -= 2.0 + # coeffs[int(S/2)+1:] += 2.0 + + coeffs = torch.from_numpy(coeffs).float().cuda() + coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W) + # scale each channel by the right coeff + im = im * coeffs + # now im is in [1/S, 1], except for the invalid parts which are 0 + # keep the highest valid coeff at each pixel + im = torch.max(im, dim=1, keepdim=True)[0] + + out = [] + for b in range(B): + im_ = im[b] + # move channels out to last dim_ + im_ = im_.detach().cpu().numpy() + im_ = np.squeeze(im_) + # im_ is H x W + if colormap=='coolwarm': + im_ = cm.coolwarm(im_)[:, :, :3] + elif colormap=='PiYG': + im_ = cm.PiYG(im_)[:, :, :3] + elif colormap=='winter': + im_ = cm.winter(im_)[:, :, :3] + elif colormap=='spring': + im_ = cm.spring(im_)[:, :, :3] + elif colormap=='onediff': + im_ = np.reshape(im_, (-1)) + im0_ = cm.spring(im_)[:, :3] + im1_ = cm.winter(im_)[:, :3] + im1_[im_==1/float(S)] = im0_[im_==1/float(S)] + im_ = np.reshape(im1_, (H, W, 3)) + else: + assert(False) # invalid colormap + # move channels into dim 0 + im_ = np.transpose(im_, [2, 0, 1]) + im_ = torch.from_numpy(im_).float().cuda() + out.append(im_) + out = torch.stack(out, dim=0) + + # blacken the invalid pixels, instead of using the 0-color + out = out*mask + # out = out*255.0 + + # put it in [-0.5, 0.5] + out = out - 0.5 + + return out + +def colorize(d): + # this is actually just grayscale right now + + if d.ndim==2: + d = d.unsqueeze(dim=0) + else: + assert(d.ndim==3) + + # color_map = cm.get_cmap('plasma') + color_map = cm.get_cmap('inferno') + # S1, D = traj.shape + + # print('d1', d.shape) + C,H,W = d.shape + assert(C==1) + d = d.reshape(-1) + d = d.detach().cpu().numpy() + # print('d2', d.shape) + color = np.array(color_map(d)) * 255 # rgba + # print('color1', color.shape) + color = np.reshape(color[:,:3], [H*W, 3]) + # print('color2', color.shape) + color = torch.from_numpy(color).permute(1,0).reshape(3,H,W) + # # gather + # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray') + # if cmap=='RdBu' or cmap=='RdYlGn': + # colors = cm(np.arange(256))[:, :3] + # else: + # colors = cm.colors + # colors = np.array(colors).astype(np.float32) + # colors = np.reshape(colors, [-1, 3]) + # colors = tf.constant(colors, dtype=tf.float32) + + # value = tf.gather(colors, indices) + # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255) + + # copy to the three chans + # d = d.repeat(3, 1, 1) + return color + + +def oned2inferno(d, norm=True, do_colorize=False): + # convert a 1chan input to a 3chan image output + + # if it's just B x H x W, add a C dim + if d.ndim==3: + d = d.unsqueeze(dim=1) + # d should be B x C x H x W, where C=1 + B, C, H, W = list(d.shape) + assert(C==1) + + if norm: + d = utils.basic.normalize(d) + + if do_colorize: + rgb = torch.zeros(B, 3, H, W) + for b in list(range(B)): + rgb[b] = colorize(d[b]) + else: + rgb = d.repeat(1, 3, 1, 1)*255.0 + # rgb = (255.0*rgb).type(torch.ByteTensor) + rgb = rgb.type(torch.ByteTensor) + + # rgb = tf.cast(255.0*rgb, tf.uint8) + # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3]) + # rgb = tf.expand_dims(rgb, axis=0) + return rgb + +def oned2gray(d, norm=True): + # convert a 1chan input to a 3chan image output + + # if it's just B x H x W, add a C dim + if d.ndim==3: + d = d.unsqueeze(dim=1) + # d should be B x C x H x W, where C=1 + B, C, H, W = list(d.shape) + assert(C==1) + + if norm: + d = utils.basic.normalize(d) + + rgb = d.repeat(1,3,1,1) + rgb = (255.0*rgb).type(torch.ByteTensor) + return rgb + + +def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20): + + rgb = vis.detach().cpu().numpy()[0] + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + color = (255, 255, 255) + # print('putting frame id', frame_id) + + frame_str = utils.basic.strnum(frame_id) + + text_color_bg = (0,0,0) + font = cv2.FONT_HERSHEY_SIMPLEX + text_size, _ = cv2.getTextSize(frame_str, font, scale, 1) + text_w, text_h = text_size + cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1) + + cv2.putText( + rgb, + frame_str, + (left, top), # from left, from top + font, + scale, # font scale (float) + color, + 1) # font thickness (int) + rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) + vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + return vis + +COLORMAP_FILE = "./utils/bremm.png" +class ColorMap2d: + def __init__(self, filename=None): + self._colormap_file = filename or COLORMAP_FILE + self._img = plt.imread(self._colormap_file) + + self._height = self._img.shape[0] + self._width = self._img.shape[1] + + def __call__(self, X): + assert len(X.shape) == 2 + output = np.zeros((X.shape[0], 3)) + for i in range(X.shape[0]): + x, y = X[i, :] + xp = int((self._width-1) * x) + yp = int((self._height-1) * y) + xp = np.clip(xp, 0, self._width-1) + yp = np.clip(yp, 0, self._height-1) + output[i, :] = self._img[yp, xp] + return output + +def get_n_colors(N, sequential=False): + label_colors = [] + for ii in range(N): + if sequential: + rgb = cm.winter(ii/(N-1)) + rgb = (np.array(rgb) * 255).astype(np.uint8)[:3] + else: + rgb = np.zeros(3) + while np.sum(rgb) < 128: # ensure min brightness + rgb = np.random.randint(0,256,3) + label_colors.append(rgb) + return label_colors + +class Summ_writer(object): + def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False): + self.writer = writer + self.global_step = global_step + self.log_freq = log_freq + self.fps = fps + self.just_gif = just_gif + self.maxwidth = 10000 + self.save_this = (self.global_step % self.log_freq == 0) + self.scalar_freq = max(scalar_freq,1) + + + def summ_gif(self, name, tensor, blacken_zeros=False): + # tensor should be in B x S x C x H x W + + assert tensor.dtype in {torch.uint8,torch.float32} + shape = list(tensor.shape) + + if tensor.dtype == torch.float32: + tensor = back2color(tensor, blacken_zeros=blacken_zeros) + + video_to_write = tensor[0:1] + + S = video_to_write.shape[1] + if S==1: + # video_to_write is 1 x 1 x C x H x W + self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step) + else: + self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step) + + return video_to_write + + def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1): + B, C, H, W = list(rgb.shape) + assert(C==3) + B2, N, D = list(boxlist.shape) + assert(B2==B) + assert(D==4) # ymin, xmin, ymax, xmax + + rgb = back2color(rgb) + if scores is None: + scores = torch.ones(B2, N).float() + if tids is None: + tids = torch.arange(N).reshape(1,N).repeat(B2,N).long() + # tids = torch.zeros(B2, N).long() + out = self.draw_boxlist2d_on_image_py( + rgb[0].cpu().detach().numpy(), + boxlist[0].cpu().detach().numpy(), + scores[0].cpu().detach().numpy(), + tids[0].cpu().detach().numpy(), + linewidth=linewidth) + out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1) + out = torch.unsqueeze(out, dim=0) + out = preprocess_color(out) + out = torch.reshape(out, [1, C, H, W]) + return out + + def draw_boxlist2d_on_image_py(self, rgb, boxlist, scores, tids, linewidth=1): + # all inputs are numpy tensors + # rgb is H x W x 3 + # boxlist is N x 4 + # scores is N + # tids is N + + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + # rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + + rgb = rgb.astype(np.uint8).copy() + + + H, W, C = rgb.shape + assert(C==3) + N, D = boxlist.shape + assert(D==4) + + # color_map = cm.get_cmap('tab20') + # color_map = cm.get_cmap('set1') + color_map = cm.get_cmap('Accent') + color_map = color_map.colors + # print('color_map', color_map) + + # draw + for ind, box in enumerate(boxlist): + # box is 4 + if not np.isclose(scores[ind], 0.0): + # box = utils.geom.scale_box2d(box, H, W) + ymin, xmin, ymax, xmax = box + + # ymin, ymax = ymin*H, ymax*H + # xmin, xmax = xmin*W, xmax*W + + # print 'score = %.2f' % scores[ind] + # color_id = tids[ind] % 20 + color_id = tids[ind] + color = color_map[color_id] + color = np.array(color)*255.0 + color = color.round() + # color = color.astype(np.uint8) + # color = color[::-1] + # print('color', color) + + # print 'tid = %d; score = %.3f' % (tids[ind], scores[ind]) + + # if False: + if scores[ind] < 1.0: # not gt + cv2.putText(rgb, + # '%d (%.2f)' % (tids[ind], scores[ind]), + '%.2f' % (scores[ind]), + (int(xmin), int(ymin)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, # font size + color), + #1) # font weight + + xmin = np.clip(int(xmin), 0, W-1) + xmax = np.clip(int(xmax), 0, W-1) + ymin = np.clip(int(ymin), 0, H-1) + ymax = np.clip(int(ymax), 0, H-1) + + cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_AA) + + # rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) + return rgb + + def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, only_return=False, linewidth=2): + B, C, H, W = list(rgb.shape) + boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth) + return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, only_return=only_return) + + def summ_rgbs(self, name, ims, frame_ids=None, blacken_zeros=False, only_return=False): + if self.save_this: + + ims = gif_and_tile(ims, just_gif=self.just_gif) + vis = ims + + assert vis.dtype in {torch.uint8,torch.float32} + + if vis.dtype == torch.float32: + vis = back2color(vis, blacken_zeros) + + B, S, C, H, W = list(vis.shape) + + if frame_ids is not None: + assert(len(frame_ids)==S) + for s in range(S): + vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) + + if int(W) > self.maxwidth: + vis = vis[:,:,:,:self.maxwidth] + + if only_return: + return vis + else: + return self.summ_gif(name, vis, blacken_zeros) + + def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, only_return=False, halfres=False): + if self.save_this: + assert ims.dtype in {torch.uint8,torch.float32} + + if ims.dtype == torch.float32: + ims = back2color(ims, blacken_zeros) + + #ims is B x C x H x W + vis = ims[0:1] # just the first one + B, C, H, W = list(vis.shape) + + if halfres: + vis = F.interpolate(vis, scale_factor=0.5) + + if frame_id is not None: + vis = draw_frame_id_on_vis(vis, frame_id) + + if int(W) > self.maxwidth: + vis = vis[:,:,:,:self.maxwidth] + + if only_return: + return vis + else: + return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros) + + def flow2color(self, flow, clip=50.0): + """ + :param flow: Optical flow tensor. + :return: RGB image normalized between 0 and 1. + """ + + # flow is B x C x H x W + + B, C, H, W = list(flow.size()) + + flow = flow.clone().detach() + + abs_image = torch.abs(flow) + flow_mean = abs_image.mean(dim=[1,2,3]) + flow_std = abs_image.std(dim=[1,2,3]) + + if clip: + flow = torch.clamp(flow, -clip, clip)/clip + else: + # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2) + flow_max = flow_mean + flow_std*2 + 1e-10 + for b in range(B): + flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1) + + radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W + radius_clipped = torch.clamp(radius, 0.0, 1.0) + + angle = torch.atan2(flow[:, 1:], flow[:, 0:1]) / np.pi #B x 1 x H x W + + hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0) + saturation = torch.ones_like(hue) * 0.75 + value = radius_clipped + hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W + + #flow = tf.image.hsv_to_rgb(hsv) + flow = hsv_to_rgb(hsv) + flow = (flow*255.0).type(torch.ByteTensor) + return flow + + def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None): + # flow is B x C x D x W + if self.save_this: + return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id) + else: + return None + + def summ_oneds(self, name, ims, frame_ids=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False): + if self.save_this: + if bev: + B, C, H, _, W = list(ims[0].shape) + if reduce_max: + ims = [torch.max(im, dim=3)[0] for im in ims] + else: + ims = [torch.mean(im, dim=3) for im in ims] + elif fro: + B, C, _, H, W = list(ims[0].shape) + if reduce_max: + ims = [torch.max(im, dim=2)[0] for im in ims] + else: + ims = [torch.mean(im, dim=2) for im in ims] + + + if len(ims) != 1: # sequence + im = gif_and_tile(ims, just_gif=self.just_gif) + else: + im = torch.stack(ims, dim=1) # single frame + + B, S, C, H, W = list(im.shape) + + if logvis and max_val: + max_val = np.log(max_val) + im = torch.log(torch.clamp(im, 0)+1.0) + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + elif max_val: + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + + if norm: + # normalize before oned2inferno, + # so that the ranges are similar within B across S + im = utils.basic.normalize(im) + + im = im.view(B*S, C, H, W) + vis = oned2inferno(im, norm=norm, do_colorize=do_colorize) + vis = vis.view(B, S, 3, H, W) + + if frame_ids is not None: + assert(len(frame_ids)==S) + for s in range(S): + vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) + + if W > self.maxwidth: + vis = vis[...,:self.maxwidth] + + if only_return: + return vis + else: + self.summ_gif(name, vis) + + def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, only_return=False): + if self.save_this: + + if bev: + B, C, H, _, W = list(im.shape) + if max_along_y: + im = torch.max(im, dim=3)[0] + else: + im = torch.mean(im, dim=3) + elif fro: + B, C, _, H, W = list(im.shape) + if max_along_y: + im = torch.max(im, dim=2)[0] + else: + im = torch.mean(im, dim=2) + else: + B, C, H, W = list(im.shape) + + im = im[0:1] # just the first one + assert(C==1) + + if logvis and max_val: + max_val = np.log(max_val) + im = torch.log(im) + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + elif max_val: + im = torch.clamp(im, 0, max_val)/max_val + norm = False + + vis = oned2inferno(im, norm=norm) + if W > self.maxwidth: + vis = vis[...,:self.maxwidth] + return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, only_return=only_return) + + def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None): + if self.save_this: + if valids is not None: + valids = torch.stack(valids, dim=1) + + feats = torch.stack(feats, dim=1) + # feats leads with B x S x C + + if feats.ndim==6: + + # feats is B x S x C x D x H x W + if fro: + reduce_dim = 3 + else: + reduce_dim = 4 + + if valids is None: + feats = torch.mean(feats, dim=reduce_dim) + else: + valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1) + feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim) + + B, S, C, D, W = list(feats.size()) + + if not pca: + # feats leads with B x S x C + feats = torch.mean(torch.abs(feats), dim=2, keepdims=True) + # feats leads with B x S x 1 + feats = torch.unbind(feats, dim=1) + return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids) + + else: + __p = lambda x: utils.basic.pack_seqdim(x, B) + __u = lambda x: utils.basic.unpack_seqdim(x, B) + + feats_ = __p(feats) + + if valids is None: + feats_pca_ = get_feat_pca(feats_) + else: + valids_ = __p(valids) + feats_pca_ = get_feat_pca(feats_, valids) + + feats_pca = __u(feats_pca_) + + return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids) + + def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None): + if self.save_this: + if feat.ndim==5: # B x C x D x H x W + + if bev: + reduce_axis = 3 + elif fro: + reduce_axis = 2 + else: + # default to bev + reduce_axis = 3 + + if valid is None: + feat = torch.mean(feat, dim=reduce_axis) + else: + valid = valid.repeat(1, feat.size()[1], 1, 1, 1) + feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis) + + B, C, D, W = list(feat.shape) + + if not pca: + feat = torch.mean(torch.abs(feat), dim=1, keepdims=True) + # feat is B x 1 x D x W + return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id) + else: + feat_pca = get_feat_pca(feat, valid) + return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id) + + def summ_scalar(self, name, value): + if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()): + value = value.detach().cpu().numpy() + if not np.isnan(value): + if (self.log_freq == 1): + self.writer.add_scalar(name, value, global_step=self.global_step) + elif self.save_this or np.mod(self.global_step, self.scalar_freq)==0: + self.writer.add_scalar(name, value, global_step=self.global_step) + + def summ_seg(self, name, seg, only_return=False, frame_id=None, colormap='tab20', label_colors=None): + if not self.save_this: + return + + B,H,W = seg.shape + + if label_colors is None: + custom_label_colors = False + # label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True) + label_colors = cm.get_cmap(colormap).colors + label_colors = [[int(i*255) for i in l] for l in label_colors] + else: + custom_label_colors = True + # label_colors = matplotlib.cm.get_cmap(colormap).colors + # label_colors = [[int(i*255) for i in l] for l in label_colors] + # print('label_colors', label_colors) + + # label_colors = [ + # (0, 0, 0), # None + # (70, 70, 70), # Buildings + # (190, 153, 153), # Fences + # (72, 0, 90), # Other + # (220, 20, 60), # Pedestrians + # (153, 153, 153), # Poles + # (157, 234, 50), # RoadLines + # (128, 64, 128), # Roads + # (244, 35, 232), # Sidewalks + # (107, 142, 35), # Vegetation + # (0, 0, 255), # Vehicles + # (102, 102, 156), # Walls + # (220, 220, 0) # TrafficSigns + # ] + + r = torch.zeros_like(seg,dtype=torch.uint8) + g = torch.zeros_like(seg,dtype=torch.uint8) + b = torch.zeros_like(seg,dtype=torch.uint8) + + for label in range(0,len(label_colors)): + if (not custom_label_colors):# and (N > 20): + label_ = label % 20 + else: + label_ = label + + idx = (seg == label+1) + r[idx] = label_colors[label_][0] + g[idx] = label_colors[label_][1] + b[idx] = label_colors[label_][2] + + rgb = torch.stack([r,g,b],axis=1) + return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id) + + def summ_pts_on_rgb(self, name, trajs, rgb, valids=None, frame_id=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, C, H, W = rgb.shape + B, S, N, D = trajs.shape + + rgb = rgb[0] # C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + rgb = rgb.astype(np.uint8).copy() + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + valid = valids[:,i] # S + + color_map = cm.get_cmap(cmap) + color = np.array(color_map(i)[:3]) * 255 # rgb + for s in range(S): + if valid[s]: + cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1) + rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0) + rgb = preprocess_color(rgb) + return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) + + def summ_pts_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color] + + for i in range(N): + traj = trajs[:,i] # S,2 + valid = valids[:,i] # S + + color_map = cm.get_cmap(cmap) + color = np.array(color_map(0)[:3]) * 255 # rgb + for s in range(S): + if valid[s]: + cv2.circle(rgbs_color[s], (traj[s,0], traj[s,1]), linewidth, color, -1) + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + + def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + if vals is not None: + vals = vals[0] # N + # print('vals', vals.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i].long().detach().cpu().numpy() # S, 2 + valid = valids[:,i].long().detach().cpu().numpy() # S + + # print('traj', traj.shape) + # print('valid', valid.shape) + + if vals is not None: + # val = vals[:,i].float().detach().cpu().numpy() # [] + val = vals[i].float().detach().cpu().numpy() # [] + # print('val', val.shape) + else: + val = None + + for t in range(S): + # if valid[t]: + # traj_seq = traj[max(t-16,0):t+1] + traj_seq = traj[max(t-8,0):t+1] + val_seq = np.linspace(0,1,len(traj_seq)) + # if t<2: + # val_seq = np.zeros_like(val_seq) + # print('val_seq', val_seq) + # val_seq = 1.0 + # val_seq = np.arange(8)/8.0 + # val_seq = val_seq[-len(traj_seq):] + # rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) + rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) + # input() + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + # vis = visibles[:,i] # S + vis = torch.ones_like(traj[:,0]) # S + valid = valids[:,i] # S + rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) + + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap=None, linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + visibles = visibles[0] # S, N + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + visibles = visibles.float().detach().cpu().numpy() # S, N + valids = valids.long().detach().cpu().numpy() # S, N + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + vis = visibles[:,i] # S + valid = valids[:,i] # S + rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + vis = visibles[:,i] # S + valid = valids[:,i] # S + if valid[0]: + rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth) + + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=False, show_lines=True, frame_id=None, only_return=False, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgb is B, C, H, W + B, C, H, W = rgb.shape + B, S, N, D = trajs.shape + + rgb = rgb[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) + else: + valids = valids[0] + + rgb_color = back2color(rgb).detach().cpu().numpy() + rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last + + # using maxdist will dampen the colors for short motions + norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N + maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy() + maxdist = None + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S, 2 + valid = valids[:,i] # S + if valid[0]==1: + traj = traj[valid>0] + rgb_color = self.draw_traj_on_image_py( + rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth) + + rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0) + rgb = preprocess_color(rgb_color) + return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) + + def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None): + # all inputs are numpy tensors + # rgb is 3 x H x W + # traj is S x 2 + + H, W, C = rgb.shape + assert(C==3) + + rgb = rgb.astype(np.uint8).copy() + + S1, D = traj.shape + assert(D==2) + + color_map = cm.get_cmap(cmap) + S1, D = traj.shape + + for s in range(S1): + if val is not None: + # if len(val) == S1: + color = np.array(color_map(val[s])[:3]) * 255 # rgb + # else: + # color = np.array(color_map(val)[:3]) * 255 # rgb + else: + if maxdist is not None: + val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1) + color = np.array(color_map(val)[:3]) * 255 # rgb + else: + color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb + + if show_lines and s<(S1-1): + cv2.line(rgb, + (int(traj[s,0]), int(traj[s,1])), + (int(traj[s+1,0]), int(traj[s+1,1])), + color, + linewidth, + cv2.LINE_AA) + if show_dots: + cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, np.array(color_map(1)[:3])*255, -1) + + # if maxdist is not None: + # val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1) + # color = np.array(color_map(val)[:3]) * 255 # rgb + # else: + # # draw the endpoint of traj, using the next color (which may be the last color) + # color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb + + # # emphasize endpoint + # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1) + + return rgb + + + + def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of H,W,3 + # traj is S,2 + H, W, C = rgbs[0].shape + assert(C==3) + + rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] + + S1, D = traj.shape + assert(D==2) + + x = int(np.clip(traj[0,0], 0, W-1)) + y = int(np.clip(traj[0,1], 0, H-1)) + color = rgbs[0][y,x] + color = (int(color[0]),int(color[1]),int(color[2])) + for s in range(S): + # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb + # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1) + cv2.polylines(rgbs[s], + [traj[:s+1]], + False, + color, + linewidth, + cv2.LINE_AA) + return rgbs + + def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of 3,H,W + # xy is N,2 + H, W, C = rgb.shape + assert(C==3) + + rgb = rgb.astype(np.uint8).copy() + + N, D = xy.shape + assert(D==2) + + + xy = xy.astype(np.float32) + xy[:,0] = np.clip(xy[:,0], 0, W-1) + xy[:,1] = np.clip(xy[:,1], 0, H-1) + xy = xy.astype(np.int32) + + + + if colors is None: + colors = get_n_colors(N) + + for n in range(N): + color = colors[n] + # print('color', color) + # color = (color[0]*255).astype(np.uint8) + color = (int(color[0]),int(color[1]),int(color[2])) + + # x = int(np.clip(xy[0,0], 0, W-1)) + # y = int(np.clip(xy[0,1], 0, H-1)) + # color_ = rgbs[0][y,x] + # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) + # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) + + cv2.circle(rgb, (xy[n,0], xy[n,1]), linewidth, color, 3) + # vis_color = int(np.squeeze(vis[s])*255) + # vis_color = (vis_color,vis_color,vis_color) + # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1) + return rgb + + def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of 3,H,W + # traj is S,2 + H, W, C = rgbs[0].shape + assert(C==3) + + rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] + + S1, D = traj.shape + assert(D==2) + + if cmap is None: + bremm = ColorMap2d() + traj_ = traj[0:1].astype(np.float32) + traj_[:,0] /= float(W) + traj_[:,1] /= float(H) + color = bremm(traj_) + # print('color', color) + color = (color[0]*255).astype(np.uint8) + # color = (int(color[0]),int(color[1]),int(color[2])) + color = (int(color[2]),int(color[1]),int(color[0])) + + for s in range(S1): + if cmap is not None: + color_map = cm.get_cmap(cmap) + # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb + color = np.array(color_map((s+1)/max(1,float(S-1)))[:3]) * 255 # rgb + # color = color.astype(np.uint8) + # color = (color[0], color[1], color[2]) + # print('color', color) + # import ipdb; ipdb.set_trace() + + cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, color, -1) + # vis_color = int(np.squeeze(vis[s])*255) + # vis_color = (vis_color,vis_color,vis_color) + # cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1) + + return rgbs + + def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, only_return=False, show_circ=False, trajs_g=None, is_g=False): + B, S, N, D = trajs_e.shape + assert(N==1) + assert(D==2) + + rgbs_vis = [] + n = 0 + pad_amount = 100 + trajs_e_py = trajs_e[0].detach().cpu().numpy() + # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun + trajs_e_py = trajs_e_py + pad_amount + + if trajs_g is not None: + trajs_g_py = trajs_g[0].detach().cpu().numpy() + trajs_g_py = trajs_g_py + pad_amount + + for s in range(S): + rgb = rgbs[0,s].detach().cpu().numpy() + # print('orig rgb', rgb.shape) + rgb = np.transpose(rgb,(1,2,0)) # H, W, 3 + + rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0))) + # print('pad rgb', rgb.shape) + H, W, C = rgb.shape + + if trajs_g is not None: + xy_g = trajs_g_py[s,n] + xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount) + xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount) + rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) + + xy_e = trajs_e_py[s,n] + xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount) + xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount) + + if show_circ: + if is_g: + rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) + else: + rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3) + + + xmin = int(xy_e[0])-pad_amount//2 + xmax = int(xy_e[0])+pad_amount//2 + ymin = int(xy_e[1])-pad_amount//2 + ymax = int(xy_e[1])+pad_amount//2 + + rgb_ = rgb[ymin:ymax, xmin:xmax] + + H_, W_ = rgb_.shape[:2] + # if np.any(rgb_.shape==0): + # input() + if H_==0 or W_==0: + import ipdb; ipdb.set_trace() + + rgb_ = rgb_.transpose(2,0,1) + rgb_ = torch.from_numpy(rgb_) + + rgbs_vis.append(rgb_) + + # nrow = int(np.sqrt(S)*(16.0/9)/2.0) + nrow = int(np.sqrt(S)*1.5) + grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0) + # print('grid_img', grid_img.shape) + return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, only_return=only_return) + + def summ_occ(self, name, occ, reduce_axes=[3], bev=False, fro=False, pro=False, frame_id=None, only_return=False): + if self.save_this: + B, C, D, H, W = list(occ.shape) + if bev: + reduce_axes = [3] + elif fro: + reduce_axes = [2] + elif pro: + reduce_axes = [4] + for reduce_axis in reduce_axes: + height = convert_occ_to_height(occ, reduce_axis=reduce_axis) + if reduce_axis == reduce_axes[-1]: + return self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) + else: + self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) + +def erode2d(im, times=1, device='cuda'): + weights2d = torch.ones(1, 1, 3, 3, device=device) + for time in range(times): + im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1) + return im + +def dilate2d(im, times=1, device='cuda', mode='square'): + weights2d = torch.ones(1, 1, 3, 3, device=device) + if mode=='cross': + weights2d[:,:,0,0] = 0.0 + weights2d[:,:,0,2] = 0.0 + weights2d[:,:,2,0] = 0.0 + weights2d[:,:,2,2] = 0.0 + for time in range(times): + im = F.conv2d(im, weights2d, padding=1).clamp(0, 1) + return im + + diff --git a/models/spatracker/utils/misc.py b/models/spatracker/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..adc31966644a4639152770f8067d74a24f36b8a2 --- /dev/null +++ b/models/spatracker/utils/misc.py @@ -0,0 +1,166 @@ +import torch +import numpy as np +import math +from prettytable import PrettyTable + +def count_parameters(model): + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in model.named_parameters(): + if not parameter.requires_grad: + continue + param = parameter.numel() + if param > 100000: + table.add_row([name, param]) + total_params+=param + print(table) + print('total params: %.2f M' % (total_params/1000000.0)) + return total_params + +def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False): + device = xy.device + dtype = xy.dtype + B, S, D = xy.shape + assert(D==2) + x = xy[:,:,0] + y = xy[:,:,1] + assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' + omega = torch.arange(C // 4, device=device) / (C // 4 - 1) + omega = 1. / (temperature ** omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + pe = pe.reshape(B,S,C).type(dtype) + if cat_coords: + pe = torch.cat([pe, xy], dim=2) # B,N,C+2 + return pe + +class SimplePool(): + def __init__(self, pool_size, version='pt'): + self.pool_size = pool_size + self.version = version + self.items = [] + + if not (version=='pt' or version=='np'): + print('version = %s; please choose pt or np') + assert(False) # please choose pt or np + + def __len__(self): + return len(self.items) + + def mean(self, min_size=1): + if min_size=='half': + pool_size_thresh = self.pool_size/2 + else: + pool_size_thresh = min_size + + if self.version=='np': + if len(self.items) >= pool_size_thresh: + return np.sum(self.items)/float(len(self.items)) + else: + return np.nan + if self.version=='pt': + if len(self.items) >= pool_size_thresh: + return torch.sum(self.items)/float(len(self.items)) + else: + return torch.from_numpy(np.nan) + + def sample(self, with_replacement=True): + idx = np.random.randint(len(self.items)) + if with_replacement: + return self.items[idx] + else: + return self.items.pop(idx) + + def fetch(self, num=None): + if self.version=='pt': + item_array = torch.stack(self.items) + elif self.version=='np': + item_array = np.stack(self.items) + if num is not None: + # there better be some items + assert(len(self.items) >= num) + + # if there are not that many elements just return however many there are + if len(self.items) < num: + return item_array + else: + idxs = np.random.randint(len(self.items), size=num) + return item_array[idxs] + else: + return item_array + + def is_full(self): + full = len(self.items)==self.pool_size + return full + + def empty(self): + self.items = [] + + def update(self, items): + for item in items: + if len(self.items) < self.pool_size: + # the pool is not full, so let's add this in + self.items.append(item) + else: + # the pool is full + # pop from the front + self.items.pop(0) + # add to the back + self.items.append(item) + return self.items + +def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False): + """ + Input: + xyz: pointcloud data, [B, N, C], where C is probably 3 + npoint: number of samples + Return: + inds: sampled pointcloud index, [B, npoint] + """ + device = xyz.device + B, N, C = xyz.shape + xyz = xyz.float() + inds = torch.zeros(B, npoint, dtype=torch.long).to(device) + distance = torch.ones(B, N).to(device) * 1e10 + if deterministic: + farthest = torch.randint(0, 1, (B,), dtype=torch.long).to(device) + else: + farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) + batch_indices = torch.arange(B, dtype=torch.long).to(device) + for i in range(npoint): + if include_ends: + if i==0: + farthest = 0 + elif i==1: + farthest = N-1 + inds[:, i] = farthest + centroid = xyz[batch_indices, farthest, :].view(B, 1, C) + dist = torch.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = torch.max(distance, -1)[1] + + if npoint > N: + # if we need more samples, make them random + distance += torch.randn_like(distance) + return inds + +def farthest_point_sample_py(xyz, npoint): + N,C = xyz.shape + inds = np.zeros(npoint, dtype=np.int32) + distance = np.ones(N) * 1e10 + farthest = np.random.randint(0, N, dtype=np.int32) + for i in range(npoint): + inds[i] = farthest + centroid = xyz[farthest, :].reshape(1,C) + dist = np.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = np.argmax(distance, -1) + if npoint > N: + # if we need more samples, make them random + distance += np.random.randn(*distance.shape) + return inds + diff --git a/models/spatracker/utils/samp.py b/models/spatracker/utils/samp.py new file mode 100644 index 0000000000000000000000000000000000000000..3632c9c1164638aec4c1caf3de2bfdbcb4ee6126 --- /dev/null +++ b/models/spatracker/utils/samp.py @@ -0,0 +1,152 @@ +import torch +import utils.basic +import torch.nn.functional as F + +def bilinear_sample2d(im, x, y, return_inbounds=False): + # x and y are each B, N + # output is B, C, N + B, C, H, W = list(im.shape) + N = list(x.shape)[1] + + x = x.float() + y = y.float() + H_f = torch.tensor(H, dtype=torch.float32) + W_f = torch.tensor(W, dtype=torch.float32) + + # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() + y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() + inbounds = (x_valid & y_valid).float() + inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) + return output, inbounds + + return output # B, C, N + +def paste_crop_on_canvas(crop, box2d_unnorm, H, W, fast=True, mask=None, canvas=None): + # this is the inverse of crop_and_resize_box2d + B, C, Y, X = list(crop.shape) + B2, D = list(box2d_unnorm.shape) + assert(B == B2) + assert(D == 4) + + # here, we want to place the crop into a bigger image, + # at the location specified by the box2d. + + if canvas is None: + canvas = torch.zeros((B, C, H, W), device=crop.device) + else: + B2, C2, H2, W2 = canvas.shape + assert(B==B2) + assert(C==C2) + assert(H==H2) + assert(W==W2) + + # box2d_unnorm = utils.geom.unnormalize_box2d(box2d, H, W) + + if fast: + ymin = box2d_unnorm[:, 0].long() + xmin = box2d_unnorm[:, 1].long() + ymax = box2d_unnorm[:, 2].long() + xmax = box2d_unnorm[:, 3].long() + w = (xmax - xmin).float() + h = (ymax - ymin).float() + + grids = utils.basic.gridcloud2d(B, H, W) + grids_flat = grids.reshape(B, -1, 2) + # grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * X + # grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * Y + + # for each pixel in the main image, + # grids_flat tells us where to sample in the crop image + + # print('grids_flat', grids_flat.shape) + # print('crop', crop.shape) + + grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * 2.0 - 1.0 + grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * 2.0 - 1.0 + + grid = grids_flat.reshape(B,H,W,2) + + canvas = F.grid_sample(crop, grid, align_corners=False) + # print('canvas', canvas.shape) + + # if mask is None: + # crop_resamp, inb = bilinear_sample2d(crop, grids_flat[:, :, 0], grids_flat[:, :, 1], return_inbounds=True) + # crop_resamp = crop_resamp.reshape(B, C, H, W) + # inb = inb.reshape(B, 1, H, W) + # canvas = canvas * (1 - inb) + crop_resamp * inb + # else: + # full_resamp = bilinear_sample2d(torch.cat([crop, mask], dim=1), grids_flat[:, :, 0], grids_flat[:, :, 1]) + # full_resamp = full_resamp.reshape(B, C+1, H, W) + # crop_resamp = full_resamp[:,:3] + # mask_resamp = full_resamp[:,3:4] + # canvas = canvas * (1 - mask_resamp) + crop_resamp * mask_resamp + else: + for b in range(B): + ymin = box2d_unnorm[b, 0].long() + xmin = box2d_unnorm[b, 1].long() + ymax = box2d_unnorm[b, 2].long() + xmax = box2d_unnorm[b, 3].long() + + crop_b = F.interpolate(crop[b:b + 1], (ymax - ymin, xmax - xmin)).squeeze(0) + + # print('canvas[b,:,...', canvas[b,:,ymin:ymax,xmin:xmax].shape) + # print('crop_b', crop_b.shape) + + canvas[b, :, ymin:ymax, xmin:xmax] = crop_b + return canvas diff --git a/models/spatracker/utils/visualizer.py b/models/spatracker/utils/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5202507bf6643f19dec5ede73c9edfe7f0a42770 --- /dev/null +++ b/models/spatracker/utils/visualizer.py @@ -0,0 +1,409 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import numpy as np +import cv2 +import torch +import flow_vis + +from matplotlib import cm +import torch.nn.functional as F +import torchvision.transforms as transforms +from moviepy.editor import ImageSequenceClip +import matplotlib.pyplot as plt +from tqdm import tqdm + +def read_video_from_path(path): + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + print("Error opening video file") + else: + frames = [] + while cap.isOpened(): + ret, frame = cap.read() + if ret == True: + frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) + else: + break + cap.release() + return np.stack(frames) + + +class Visualizer: + def __init__( + self, + save_dir: str = "./results", + grayscale: bool = False, + pad_value: int = 0, + fps: int = 10, + mode: str = "rainbow", # 'cool', 'optical_flow' + linewidth: int = 1, + show_first_frame: int = 10, + tracks_leave_trace: int = 0, # -1 for infinite + ): + self.mode = mode + self.save_dir = save_dir + self.vtxt_path = os.path.join(save_dir, "videos.txt") + self.ttxt_path = os.path.join(save_dir, "trackings.txt") + if mode == "rainbow": + self.color_map = cm.get_cmap("gist_rainbow") + elif mode == "cool": + self.color_map = cm.get_cmap(mode) + self.show_first_frame = show_first_frame + self.grayscale = grayscale + self.tracks_leave_trace = tracks_leave_trace + self.pad_value = pad_value + self.linewidth = linewidth + self.fps = fps + + def visualize( + self, + video: torch.Tensor, # (B,T,C,H,W) + tracks: torch.Tensor, # (B,T,N,2) + visibility: torch.Tensor = None, # (B, T, N, 1) bool + gt_tracks: torch.Tensor = None, # (B,T,N,2) + segm_mask: torch.Tensor = None, # (B,1,H,W) + filename: str = "video", + writer=None, # tensorboard Summary Writer, used for visualization during training + step: int = 0, + query_frame: int = 0, + save_video: bool = True, + compensate_for_camera_motion: bool = False, + rigid_part = None, + video_depth = None # (B,T,C,H,W) + ): + if compensate_for_camera_motion: + assert segm_mask is not None + if segm_mask is not None: + coords = tracks[0, query_frame].round().long() + segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long() + + video = F.pad( + video, + (self.pad_value, self.pad_value, self.pad_value, self.pad_value), + "constant", + 255, + ) + + if video_depth is not None: + video_depth = (video_depth*255).cpu().numpy().astype(np.uint8) + video_depth = ([cv2.applyColorMap(video_depth[0,i,0], cv2.COLORMAP_INFERNO) + for i in range(video_depth.shape[1])]) + video_depth = np.stack(video_depth, axis=0) + video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None] + + tracks = tracks + self.pad_value + + if self.grayscale: + transform = transforms.Grayscale() + video = transform(video) + video = video.repeat(1, 1, 3, 1, 1) + + tracking_video = self.draw_tracks_on_video( + video=video, + tracks=tracks, + visibility=visibility, + segm_mask=segm_mask, + gt_tracks=gt_tracks, + query_frame=query_frame, + compensate_for_camera_motion=compensate_for_camera_motion, + rigid_part=rigid_part + ) + + if save_video: + # import ipdb; ipdb.set_trace() + tracking_dir = os.path.join(self.save_dir, "tracking") + if not os.path.exists(tracking_dir): + os.makedirs(tracking_dir) + self.save_video(tracking_video, filename=filename+"_tracking", + savedir=tracking_dir, writer=writer, step=step) + # with open(self.ttxt_path, 'a') as file: + # file.write(f"tracking/{filename}_tracking.mp4\n") + + videos_dir = os.path.join(self.save_dir, "videos") + if not os.path.exists(videos_dir): + os.makedirs(videos_dir) + self.save_video(video, filename=filename, + savedir=videos_dir, writer=writer, step=step) + # with open(self.vtxt_path, 'a') as file: + # file.write(f"videos/{filename}.mp4\n") + if video_depth is not None: + self.save_video(video_depth, filename=filename+"_depth", + savedir=os.path.join(self.save_dir, "depth"), writer=writer, step=step) + return tracking_video + + def save_video(self, video, filename, savedir=None, writer=None, step=0): + if writer is not None: + writer.add_video( + f"{filename}", + video.to(torch.uint8), + global_step=step, + fps=self.fps, + ) + else: + os.makedirs(self.save_dir, exist_ok=True) + wide_list = list(video.unbind(1)) + wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] + # clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps) + clip = ImageSequenceClip(wide_list, fps=self.fps) + + # Write the video file + if savedir is None: + save_path = os.path.join(self.save_dir, f"{filename}.mp4") + else: + save_path = os.path.join(savedir, f"{filename}.mp4") + clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None) + + print(f"Video saved to {save_path}") + + def draw_tracks_on_video( + self, + video: torch.Tensor, + tracks: torch.Tensor, + visibility: torch.Tensor = None, + segm_mask: torch.Tensor = None, + gt_tracks=None, + query_frame: int = 0, + compensate_for_camera_motion=False, + rigid_part=None, + ): + B, T, C, H, W = video.shape + _, _, N, D = tracks.shape + + assert D == 3 + assert C == 3 + video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C + tracks = tracks[0].detach().cpu().numpy() # S, N, 2 + if gt_tracks is not None: + gt_tracks = gt_tracks[0].detach().cpu().numpy() + + res_video = [] + + # process input video + # for rgb in video: + # res_video.append(rgb.copy()) + + # create a blank tensor with the same shape as the video + for rgb in video: + black_frame = np.zeros_like(rgb.copy(), dtype=rgb.dtype) + res_video.append(black_frame) + + vector_colors = np.zeros((T, N, 3)) + + if self.mode == "optical_flow": + + vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None]) + + elif segm_mask is None: + if self.mode == "rainbow": + x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max() + y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max() + + z_inv = 1/tracks[0, :, 2] + z_min, z_max = np.percentile(z_inv, [2, 98]) + + norm_x = plt.Normalize(x_min, x_max) + norm_y = plt.Normalize(y_min, y_max) + norm_z = plt.Normalize(z_min, z_max) + + for n in range(N): + r = norm_x(tracks[0, n, 0]) + g = norm_y(tracks[0, n, 1]) + # r = 0 + # g = 0 + b = norm_z(1/tracks[0, n, 2]) + color = np.array([r, g, b])[None] * 255 + vector_colors[:, n] = np.repeat(color, T, axis=0) + else: + # color changes with time + for t in range(T): + color = np.array(self.color_map(t / T)[:3])[None] * 255 + vector_colors[t] = np.repeat(color, N, axis=0) + else: + if self.mode == "rainbow": + vector_colors[:, segm_mask <= 0, :] = 255 + + x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max() + y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max() + z_min, z_max = tracks[0, :, 2].min(), tracks[0, :, 2].max() + + norm_x = plt.Normalize(x_min, x_max) + norm_y = plt.Normalize(y_min, y_max) + norm_z = plt.Normalize(z_min, z_max) + + for n in range(N): + r = norm_x(tracks[0, n, 0]) + g = norm_y(tracks[0, n, 1]) + b = norm_z(tracks[0, n, 2]) + color = np.array([r, g, b])[None] * 255 + vector_colors[:, n] = np.repeat(color, T, axis=0) + + else: + # color changes with segm class + segm_mask = segm_mask.cpu() + color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) + color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 + color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 + vector_colors = np.repeat(color[None], T, axis=0) + + # Draw tracks + if self.tracks_leave_trace != 0: + for t in range(1, T): + first_ind = ( + max(0, t - self.tracks_leave_trace) + if self.tracks_leave_trace >= 0 + else 0 + ) + curr_tracks = tracks[first_ind : t + 1] + curr_colors = vector_colors[first_ind : t + 1] + if compensate_for_camera_motion: + diff = ( + tracks[first_ind : t + 1, segm_mask <= 0] + - tracks[t : t + 1, segm_mask <= 0] + ).mean(1)[:, None] + + curr_tracks = curr_tracks - diff + curr_tracks = curr_tracks[:, segm_mask > 0] + curr_colors = curr_colors[:, segm_mask > 0] + + res_video[t] = self._draw_pred_tracks( + res_video[t], + curr_tracks, + curr_colors, + ) + if gt_tracks is not None: + res_video[t] = self._draw_gt_tracks( + res_video[t], gt_tracks[first_ind : t + 1] + ) + + if rigid_part is not None: + cls_label = torch.unique(rigid_part) + cls_num = len(torch.unique(rigid_part)) + # visualize the clustering results + cmap = plt.get_cmap('jet') # get the color mapping + colors = cmap(np.linspace(0, 1, cls_num)) + colors = (colors[:, :3] * 255) + color_map = {lable.item(): color for lable, color in zip(cls_label, colors)} + + # Draw points + for t in tqdm(range(T)): + # Create a list to store information for each point + points_info = [] + for i in range(N): + coord = (tracks[t, i, 0], tracks[t, i, 1]) + depth = tracks[t, i, 2] # assume the third dimension is depth + visibile = True + if visibility is not None: + visibile = visibility[0, t, i] + if coord[0] != 0 and coord[1] != 0: + if not compensate_for_camera_motion or ( + compensate_for_camera_motion and segm_mask[i] > 0 + ): + points_info.append((i, coord, depth, visibile)) + + # Sort points by depth, points with smaller depth (closer) will be drawn later + points_info.sort(key=lambda x: x[2], reverse=True) + + for i, coord, _, visibile in points_info: + if rigid_part is not None: + color = color_map[rigid_part.squeeze()[i].item()] + cv2.circle( + res_video[t], + coord, + int(self.linewidth * 2), + color.tolist(), + thickness=-1 if visibile else 2 + -1, + ) + else: + # Determine rectangle width based on the distance between adjacent tracks in the first frame + if t == 0: + distances = np.linalg.norm(tracks[0] - tracks[0, i], axis=1) + distances = distances[distances > 0] + rect_size = int(np.min(distances))/2 + + # Define coordinates for top-left and bottom-right corners of the rectangle + top_left = (int(coord[0] - rect_size), int(coord[1] - rect_size/1.5)) # Rectangle width is 1.5x (video aspect ratio is 1.5:1) + bottom_right = (int(coord[0] + rect_size), int(coord[1] + rect_size/1.5)) + + # Draw rectangle + cv2.rectangle( + res_video[t], + top_left, + bottom_right, + vector_colors[t, i].tolist(), + thickness=-1 if visibile else 0 + -1, + ) + + # Construct the final rgb sequence + return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte() + + def _draw_pred_tracks( + self, + rgb: np.ndarray, # H x W x 3 + tracks: np.ndarray, # T x 2 + vector_colors: np.ndarray, + alpha: float = 0.5, + ): + T, N, _ = tracks.shape + + for s in range(T - 1): + vector_color = vector_colors[s] + original = rgb.copy() + alpha = (s / T) ** 2 + for i in range(N): + coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1])) + coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1])) + if coord_y[0] != 0 and coord_y[1] != 0: + cv2.line( + rgb, + coord_y, + coord_x, + vector_color[i].tolist(), + self.linewidth, + cv2.LINE_AA, + ) + if self.tracks_leave_trace > 0: + rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0) + return rgb + + def _draw_gt_tracks( + self, + rgb: np.ndarray, # H x W x 3, + gt_tracks: np.ndarray, # T x 2 + ): + T, N, _ = gt_tracks.shape + color = np.array((211.0, 0.0, 0.0)) + + for t in range(T): + for i in range(N): + gt_tracks = gt_tracks[t][i] + # draw a red cross + if gt_tracks[0] > 0 and gt_tracks[1] > 0: + length = self.linewidth * 3 + coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length) + coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length) + cv2.line( + rgb, + coord_y, + coord_x, + color, + self.linewidth, + cv2.LINE_AA, + ) + coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length) + coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length) + cv2.line( + rgb, + coord_y, + coord_x, + color, + self.linewidth, + cv2.LINE_AA, + ) + return rgb diff --git a/models/spatracker/utils/vox.py b/models/spatracker/utils/vox.py new file mode 100644 index 0000000000000000000000000000000000000000..203097b8736eabc2158950a11f4600b7848f119e --- /dev/null +++ b/models/spatracker/utils/vox.py @@ -0,0 +1,500 @@ +import numpy as np +import torch +import torch.nn.functional as F + +import utils.geom + +class Vox_util(object): + def __init__(self, Z, Y, X, scene_centroid, bounds, pad=None, assert_cube=False): + self.XMIN, self.XMAX, self.YMIN, self.YMAX, self.ZMIN, self.ZMAX = bounds + B, D = list(scene_centroid.shape) + self.Z, self.Y, self.X = Z, Y, X + + scene_centroid = scene_centroid.detach().cpu().numpy() + x_centroid, y_centroid, z_centroid = scene_centroid[0] + self.XMIN += x_centroid + self.XMAX += x_centroid + self.YMIN += y_centroid + self.YMAX += y_centroid + self.ZMIN += z_centroid + self.ZMAX += z_centroid + + self.default_vox_size_X = (self.XMAX-self.XMIN)/float(X) + self.default_vox_size_Y = (self.YMAX-self.YMIN)/float(Y) + self.default_vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z) + + if pad: + Z_pad, Y_pad, X_pad = pad + self.ZMIN -= self.default_vox_size_Z * Z_pad + self.ZMAX += self.default_vox_size_Z * Z_pad + self.YMIN -= self.default_vox_size_Y * Y_pad + self.YMAX += self.default_vox_size_Y * Y_pad + self.XMIN -= self.default_vox_size_X * X_pad + self.XMAX += self.default_vox_size_X * X_pad + + if assert_cube: + # we assume cube voxels + if (not np.isclose(self.default_vox_size_X, self.default_vox_size_Y)) or (not np.isclose(self.default_vox_size_X, self.default_vox_size_Z)): + print('Z, Y, X', Z, Y, X) + print('bounds for this iter:', + 'X = %.2f to %.2f' % (self.XMIN, self.XMAX), + 'Y = %.2f to %.2f' % (self.YMIN, self.YMAX), + 'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX), + ) + print('self.default_vox_size_X', self.default_vox_size_X) + print('self.default_vox_size_Y', self.default_vox_size_Y) + print('self.default_vox_size_Z', self.default_vox_size_Z) + assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Y)) + assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Z)) + + def Ref2Mem(self, xyz, Z, Y, X, assert_cube=False): + # xyz is B x N x 3, in ref coordinates + # transforms ref coordinates into mem coordinates + B, N, C = list(xyz.shape) + device = xyz.device + assert(C==3) + mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device) + xyz = utils.geom.apply_4x4(mem_T_ref, xyz) + return xyz + + def Mem2Ref(self, xyz_mem, Z, Y, X, assert_cube=False): + # xyz is B x N x 3, in mem coordinates + # transforms mem coordinates into ref coordinates + B, N, C = list(xyz_mem.shape) + ref_T_mem = self.get_ref_T_mem(B, Z, Y, X, assert_cube=assert_cube, device=xyz_mem.device) + xyz_ref = utils.geom.apply_4x4(ref_T_mem, xyz_mem) + return xyz_ref + + def get_mem_T_ref(self, B, Z, Y, X, assert_cube=False, device='cuda'): + vox_size_X = (self.XMAX-self.XMIN)/float(X) + vox_size_Y = (self.YMAX-self.YMIN)/float(Y) + vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z) + + if assert_cube: + if (not np.isclose(vox_size_X, vox_size_Y)) or (not np.isclose(vox_size_X, vox_size_Z)): + print('Z, Y, X', Z, Y, X) + print('bounds for this iter:', + 'X = %.2f to %.2f' % (self.XMIN, self.XMAX), + 'Y = %.2f to %.2f' % (self.YMIN, self.YMAX), + 'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX), + ) + print('vox_size_X', vox_size_X) + print('vox_size_Y', vox_size_Y) + print('vox_size_Z', vox_size_Z) + assert(np.isclose(vox_size_X, vox_size_Y)) + assert(np.isclose(vox_size_X, vox_size_Z)) + + # translation + # (this makes the left edge of the leftmost voxel correspond to XMIN) + center_T_ref = utils.geom.eye_4x4(B, device=device) + center_T_ref[:,0,3] = -self.XMIN-vox_size_X/2.0 + center_T_ref[:,1,3] = -self.YMIN-vox_size_Y/2.0 + center_T_ref[:,2,3] = -self.ZMIN-vox_size_Z/2.0 + + # scaling + # (this makes the right edge of the rightmost voxel correspond to XMAX) + mem_T_center = utils.geom.eye_4x4(B, device=device) + mem_T_center[:,0,0] = 1./vox_size_X + mem_T_center[:,1,1] = 1./vox_size_Y + mem_T_center[:,2,2] = 1./vox_size_Z + mem_T_ref = utils.geom.matmul2(mem_T_center, center_T_ref) + + return mem_T_ref + + def get_ref_T_mem(self, B, Z, Y, X, assert_cube=False, device='cuda'): + mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device) + # note safe_inverse is inapplicable here, + # since the transform is nonrigid + ref_T_mem = mem_T_ref.inverse() + return ref_T_mem + + def get_inbounds(self, xyz, Z, Y, X, already_mem=False, padding=0.0, assert_cube=False): + # xyz is B x N x 3 + # padding should be 0 unless you are trying to account for some later cropping + if not already_mem: + xyz = self.Ref2Mem(xyz, Z, Y, X, assert_cube=assert_cube) + + x = xyz[:,:,0] + y = xyz[:,:,1] + z = xyz[:,:,2] + + x_valid = ((x-padding)>-0.5).byte() & ((x+padding)-0.5).byte() & ((y+padding)-0.5).byte() & ((z+padding) 0: + # only take points that are already near centers + xyz_round = torch.round(xyz) # B, N, 3 + dist = torch.norm(xyz_round - xyz, dim=2) + mask[dist > clean_eps] = 0 + + # set the invalid guys to zero + # we then need to zero out 0,0,0 + # (this method seems a bit clumsy) + x = x*mask + y = y*mask + z = z*mask + + x = torch.round(x) + y = torch.round(y) + z = torch.round(z) + x = torch.clamp(x, 0, X-1).int() + y = torch.clamp(y, 0, Y-1).int() + z = torch.clamp(z, 0, Z-1).int() + + x = x.view(B*N) + y = y.view(B*N) + z = z.view(B*N) + + dim3 = X + dim2 = X * Y + dim1 = X * Y * Z + + base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1 + base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N) + + vox_inds = base + z * dim2 + y * dim3 + x + voxels = torch.zeros(B*Z*Y*X, device=xyz.device).float() + voxels[vox_inds.long()] = 1.0 + # zero out the singularity + voxels[base.long()] = 0.0 + voxels = voxels.reshape(B, 1, Z, Y, X) + # B x 1 x Z x Y x X + return voxels + + def get_feat_occupancy(self, xyz, feat, Z, Y, X, clean_eps=0, xyz_zero=None): + # xyz is B x N x 3 and in mem coords + # feat is B x N x D + # we want to fill a voxel tensor with 1's at these inds + B, N, C = list(xyz.shape) + B2, N2, D2 = list(feat.shape) + assert(C==3) + assert(B==B2) + assert(N==N2) + + # these papers say simple 1/0 occupancy is ok: + # http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_PIXOR_Real-Time_3d_CVPR_2018_paper.pdf + # http://openaccess.thecvf.com/content_cvpr_2018/papers/Luo_Fast_and_Furious_CVPR_2018_paper.pdf + # cont fusion says they do 8-neighbor interp + # voxelnet does occupancy but with a bit of randomness in terms of the reflectance value i think + + inbounds = self.get_inbounds(xyz, Z, Y, X, already_mem=True) + x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2] + mask = torch.zeros_like(x) + mask[inbounds] = 1.0 + + if xyz_zero is not None: + # only take points that are beyond a thresh of zero + dist = torch.norm(xyz_zero-xyz, dim=2) + mask[dist < 0.1] = 0 + + if clean_eps > 0: + # only take points that are already near centers + xyz_round = torch.round(xyz) # B, N, 3 + dist = torch.norm(xyz_round - xyz, dim=2) + mask[dist > clean_eps] = 0 + + # set the invalid guys to zero + # we then need to zero out 0,0,0 + # (this method seems a bit clumsy) + x = x*mask # B, N + y = y*mask + z = z*mask + feat = feat*mask.unsqueeze(-1) # B, N, D + + x = torch.round(x) + y = torch.round(y) + z = torch.round(z) + x = torch.clamp(x, 0, X-1).int() + y = torch.clamp(y, 0, Y-1).int() + z = torch.clamp(z, 0, Z-1).int() + + # permute point orders + perm = torch.randperm(N) + x = x[:, perm] + y = y[:, perm] + z = z[:, perm] + feat = feat[:, perm] + + x = x.view(B*N) + y = y.view(B*N) + z = z.view(B*N) + feat = feat.view(B*N, -1) + + dim3 = X + dim2 = X * Y + dim1 = X * Y * Z + + base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1 + base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N) + + vox_inds = base + z * dim2 + y * dim3 + x + feat_voxels = torch.zeros((B*Z*Y*X, D2), device=xyz.device).float() + feat_voxels[vox_inds.long()] = feat + # zero out the singularity + feat_voxels[base.long()] = 0.0 + feat_voxels = feat_voxels.reshape(B, Z, Y, X, D2).permute(0, 4, 1, 2, 3) + # B x C x Z x Y x X + return feat_voxels + + def unproject_image_to_mem(self, rgb_camB, pixB_T_camA, camB_T_camA, Z, Y, X, assert_cube=False, xyz_camA=None): + # rgb_camB is B x C x H x W + # pixB_T_camA is B x 4 x 4 + + # rgb lives in B pixel coords + # we want everything in A memory coords + + # this puts each C-dim pixel in the rgb_camB + # along a ray in the voxelgrid + B, C, H, W = list(rgb_camB.shape) + + if xyz_camA is None: + xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device) + xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube) + + xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA) + z = xyz_camB[:,:,2] + + xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA) + normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2) + EPS=1e-6 + # z = xyz_pixB[:,:,2] + xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS) + # this is B x N x 2 + # this is the (floating point) pixel coordinate of each voxel + x, y = xy_pixB[:,:,0], xy_pixB[:,:,1] + # these are B x N + + x_valid = (x>-0.5).bool() & (x-0.5).bool() & (y0.0).bool() + valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float() + + if (0): + # handwritten version + values = torch.zeros([B, C, Z*Y*X], dtype=torch.float32) + for b in list(range(B)): + values[b] = utils.samp.bilinear_sample_single(rgb_camB[b], x_pixB[b], y_pixB[b]) + else: + # native pytorch version + y_pixB, x_pixB = utils.basic.normalize_grid2d(y, x, H, W) + # since we want a 3d output, we need 5d tensors + z_pixB = torch.zeros_like(x) + xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], axis=2) + rgb_camB = rgb_camB.unsqueeze(2) + xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3]) + values = F.grid_sample(rgb_camB, xyz_pixB, align_corners=False) + + values = torch.reshape(values, (B, C, Z, Y, X)) + values = values * valid_mem + return values + + def warp_tiled_to_mem(self, rgb_tileB, pixB_T_camA, camB_T_camA, Z, Y, X, DMIN, DMAX, assert_cube=False): + # rgb_tileB is B,C,D,H,W + # pixB_T_camA is B,4,4 + # camB_T_camA is B,4,4 + + # rgb_tileB lives in B pixel coords but it has been tiled across the Z dimension + # we want everything in A memory coords + + # this resamples the so that each C-dim pixel in rgb_tilB + # is put into its correct place in the voxelgrid + # (using the pinhole camera model) + + B, C, D, H, W = list(rgb_tileB.shape) + + xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device) + + xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube) + + xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA) + z_camB = xyz_camB[:,:,2] + + # rgb_tileB has depth=DMIN in tile 0, and depth=DMAX in tile D-1 + z_tileB = (D-1.0) * (z_camB-float(DMIN)) / float(DMAX-DMIN) + + xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA) + normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2) + EPS=1e-6 + # z = xyz_pixB[:,:,2] + xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS) + # this is B x N x 2 + # this is the (floating point) pixel coordinate of each voxel + x, y = xy_pixB[:,:,0], xy_pixB[:,:,1] + # these are B x N + + x_valid = (x>-0.5).bool() & (x-0.5).bool() & (y0.0).bool() + valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float() + + z_tileB, y_pixB, x_pixB = utils.basic.normalize_grid3d(z_tileB, y, x, D, H, W) + xyz_pixB = torch.stack([x_pixB, y_pixB, z_tileB], axis=2) + xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3]) + values = F.grid_sample(rgb_tileB, xyz_pixB, align_corners=False) + + values = torch.reshape(values, (B, C, Z, Y, X)) + values = values * valid_mem + return values + + + def apply_mem_T_ref_to_lrtlist(self, lrtlist_cam, Z, Y, X, assert_cube=False): + # lrtlist is B x N x 19, in cam coordinates + # transforms them into mem coordinates, including a scale change for the lengths + B, N, C = list(lrtlist_cam.shape) + assert(C==19) + mem_T_cam = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=lrtlist_cam.device) + + def xyz2circles(self, xyz, radius, Z, Y, X, soft=True, already_mem=True, also_offset=False, grid=None): + # xyz is B x N x 3 + # radius is B x N or broadcastably so + # output is B x N x Z x Y x X + B, N, D = list(xyz.shape) + assert(D==3) + if not already_mem: + xyz = self.Ref2Mem(xyz, Z, Y, X) + + if grid is None: + grid_z, grid_y, grid_x = utils.basic.meshgrid3d(B, Z, Y, X, stack=False, norm=False, device=xyz.device) + # note the default stack is on -1 + grid = torch.stack([grid_x, grid_y, grid_z], dim=1) + # this is B x 3 x Z x Y x X + + xyz = xyz.reshape(B, N, 3, 1, 1, 1) + grid = grid.reshape(B, 1, 3, Z, Y, X) + # this is B x N x Z x Y x X + + # round the xyzs, so that at least one value matches the grid perfectly, + # and we get a value of 1 there (since exp(0)==1) + xyz = xyz.round() + + if torch.is_tensor(radius): + radius = radius.clamp(min=0.01) + + if soft: + off = grid - xyz # B,N,3,Z,Y,X + # interpret radius as sigma + dist_grid = torch.sum(off**2, dim=2, keepdim=False) + # this is B x N x Z x Y x X + if torch.is_tensor(radius): + radius = radius.reshape(B, N, 1, 1, 1) + mask = torch.exp(-dist_grid/(2*radius*radius)) + # zero out near zero + mask[mask < 0.001] = 0.0 + # h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + # h[h < np.finfo(h.dtype).eps * h.max()] = 0 + # return h + if also_offset: + return mask, off + else: + return mask + else: + assert(False) # something is wrong with this. come back later to debug + + dist_grid = torch.norm(grid - xyz, dim=2, keepdim=False) + # this is 0 at/near the xyz, and increases by 1 for each voxel away + + radius = radius.reshape(B, N, 1, 1, 1) + + within_radius_mask = (dist_grid < radius).float() + within_radius_mask = torch.sum(within_radius_mask, dim=1, keepdim=True).clamp(0, 1) + return within_radius_mask + + def xyz2circles_bev(self, xyz, radius, Z, Y, X, already_mem=True, also_offset=False): + # xyz is B x N x 3 + # radius is B x N or broadcastably so + # output is B x N x Z x Y x X + B, N, D = list(xyz.shape) + assert(D==3) + if not already_mem: + xyz = self.Ref2Mem(xyz, Z, Y, X) + + xz = torch.stack([xyz[:,:,0], xyz[:,:,2]], dim=2) + + grid_z, grid_x = utils.basic.meshgrid2d(B, Z, X, stack=False, norm=False, device=xyz.device) + # note the default stack is on -1 + grid = torch.stack([grid_x, grid_z], dim=1) + # this is B x 2 x Z x X + + xz = xz.reshape(B, N, 2, 1, 1) + grid = grid.reshape(B, 1, 2, Z, X) + # these are ready to broadcast to B x N x Z x X + + # round the points, so that at least one value matches the grid perfectly, + # and we get a value of 1 there (since exp(0)==1) + xz = xz.round() + + if torch.is_tensor(radius): + radius = radius.clamp(min=0.01) + + off = grid - xz # B,N,2,Z,X + # interpret radius as sigma + dist_grid = torch.sum(off**2, dim=2, keepdim=False) + # this is B x N x Z x X + if torch.is_tensor(radius): + radius = radius.reshape(B, N, 1, 1, 1) + mask = torch.exp(-dist_grid/(2*radius*radius)) + # zero out near zero + mask[mask < 0.001] = 0.0 + + # add a Y dim + mask = mask.unsqueeze(-2) + off = off.unsqueeze(-2) + # # B,N,2,Z,1,X + + if also_offset: + return mask, off + else: + return mask + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..13ad89f9c24fbcd8abaa5fd52995f2bfcbd67194 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,32 @@ +# spatrack +easydict==1.13 +opencv-python==4.9.0.80 +moviepy==1.0.3 +flow-vis==0.1 +matplotlib==3.8.3 +einops==0.7.0 +timm==0.6.7 +scikit-image==0.22.0 +scikit-learn==1.4.1.post1 +cupy-cuda11x +accelerate +yt-dlp +pandas + +# cogvideox +bitsandbytes +diffusers>=0.31.2 +transformers>=4.45.2 +hf_transfer>=0.1.8 +peft>=0.12.0 +decord>=0.6.0 +wandb +torchao>=0.5.0 +sentencepiece>=0.2.0 +imageio-ffmpeg>=0.5.1 +numpy>=1.26.4 +git+https://github.com/asomoza/image_gen_aux.git +deepspeed + +# submodules +-r submodules/MoGe/requirements.txt \ No newline at end of file diff --git a/submodules/MoGe/.gitignore b/submodules/MoGe/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8964a27da39b31a8a47d3f584197f6335c854dae --- /dev/null +++ b/submodules/MoGe/.gitignore @@ -0,0 +1,425 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. +## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. +!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) +*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files +*.ncb +*.aps + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml + +# MoGe +/data +/download +/extract +/view_point_cloud +/view_depth_map +/blobcache +/snapshot +/reference_embeddings +/.msra_intern_s_toolkit +/debug +/workspace +/mlruns +/infer_output +/video_output +/eval_output +/.blobcache +/test_images +/test_videos +/vis +/videos +/raid +/blobmnt +/eval_dump +/pretrained +/.gradio diff --git a/submodules/MoGe/CHANGELOG.md b/submodules/MoGe/CHANGELOG.md new file mode 100644 index 0000000000000000000000000000000000000000..36d768b237508e7ffb0b1fa677b6fa252baebcfa --- /dev/null +++ b/submodules/MoGe/CHANGELOG.md @@ -0,0 +1,15 @@ +## 2024-11-28 +### Added +- Supported user-provided camera FOV. See [scripts/infer.py](scripts/infer.py) --fov_x. + - Related issues: [#25](https://github.com/microsoft/MoGe/issues/25) and [#24](https://github.com/microsoft/MoGe/issues/24). +- Added inference scripts for panorama images. See [scripts/infer_panorama.py](scripts/infer_panorama.py). + - Related issue: [#19](https://github.com/microsoft/MoGe/issues/19). + +### Fixed +- Suppressed unnecessary numpy runtime warnings. +- Specified recommended versions of requirements. + - Related issue: [#21](https://github.com/microsoft/MoGe/issues/21). + +### Changed +- Moved `app.py` and `infer.py` to [scripts/](scripts/) +- Improved edge removal. diff --git a/submodules/MoGe/CODE_OF_CONDUCT.md b/submodules/MoGe/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..f9ba8cf65f3e3104dd061c178066ec8247811f33 --- /dev/null +++ b/submodules/MoGe/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/submodules/MoGe/LICENSE b/submodules/MoGe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3458b5ccd398afed340e17a4d0615c9a8666bb5d --- /dev/null +++ b/submodules/MoGe/LICENSE @@ -0,0 +1,224 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/submodules/MoGe/README.md b/submodules/MoGe/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a7505cc96abeb2e5bbb0f3ee718878619f90cfb7 --- /dev/null +++ b/submodules/MoGe/README.md @@ -0,0 +1,189 @@ +
+ +# MoGe: Unlocking Accurate Monocular Geometry Estimation for Open-Domain Images with Optimal Training Supervision + +arXiv +Project Page + + +
+ +Method overview + +MoGe is a powerful model for recovering 3D geometry from monocular open-domain images. The model consists of a ViT encoder and a convolutional decoder. It directly predicts an affine-invariant point map as well as a mask that excludes regions with undefined geometry (e.g., sky), from which the camera shift, camera focal length and depth map can be further derived. + +***Check our [website](https://wangrc.site/MoGePage) for videos and interactive results!*** + +## Features + +* **Accurately** estimate 3D geometry in point map or mesh format from a **single** image. +* Support various image resolutions and aspect ratios, ranging from **2:1** to **1:2**. +* Capable of producing an extensive depth range, with distances from nearest to farthest reaching up to **1000x**. +* **Fast** inference, typically **0.2s** for a single image on an A100 or RTX 3090 GPU. + +## TODO List + +- [x] Release inference code & ViT-Large model. +- [ ] Release ViT-Base and ViT-Giant models. +- [ ] Release evaluation and training code. + +🌟*Updated on 2024/11/28* - [CHANGELOG](CHANGELOG.md): + * Supported user-provided camera FOV. + * Added the script for panorama images [scripts/infer_panorama.py](scripts/infer_panorama.py). + +## Usage + +### Prerequisite + +- Clone this repository. + + ```bash + git clone https://github.com/microsoft/MoGe.git + cd MoGe + ``` + +- Python (>= 3.10) environment: + - torch (>= 2.0) and torchvision (compatible with the torch version). + - other requirements + ```bash + pip install -r requirements.txt + ``` + MoGe should be compatible with most requirements versions. Please check the `requirements.txt` for more details if you have concerns. + +### Pretrained model + +The ViT-Large model has been uploaded to Hugging Face hub at [Ruicheng/moge-vitl](https://huggingface.co/Ruicheng/moge-vitl). +You may load the model via `MoGeModel.from_pretrained("Ruicheng/moge-vitl")` without manually downloading. + +If loading the model from a local file is preferred, you may manually download the model from the huggingface hub and load it via `MoGeModel.from_pretrained("PATH_TO_LOCAL_MODEL.pt")`. + +### Minimal example + +Here is a minimal example for loading the model and inferring on a single image. + +```python +import cv2 +import torch +from moge.model import MoGeModel + +device = torch.device("cuda") + +# Load the model from huggingface hub (or load from local). +model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) + +# Read the input image and convert to tensor (3, H, W) and normalize to [0, 1] +input_image = cv2.cvtColor(cv2.imread("PATH_TO_IMAGE.jpg"), cv2.COLOR_BGR2RGB) +input_image = torch.tensor(input_image / 255, dtype=torch.float32, device=device).permute(2, 0, 1) + +# Infer +output = model.infer(input_image) +# `output` has keys "points", "depth", "mask" and "intrinsics", +# The maps are in the same size as the input image. +# { +# "points": (H, W, 3), # scale-invariant point map in OpenCV camera coordinate system (x right, y down, z forward) +# "depth": (H, W), # scale-invariant depth map +# "mask": (H, W), # a binary mask for valid pixels. +# "intrinsics": (3, 3), # normalized camera intrinsics +# } +# For more usage details, see the `MoGeModel.infer` docstring. +``` + +### Using [scripts/app.py](scripts/app.py) for a web demo + +Make sure that `gradio` is installed and then run the following command to start the web demo: + +```bash +python scripts/app.py # --share for Gradio public sharing +``` + +The web demo is also available at our [Hugging Face space](https://huggingface.co/spaces/Ruicheng/MoGe). + + +### Using [scripts/infer.py](scripts/infer.py) + +Run the script `scripts/infer.py` via the following command: + +```bash +# Save the output [maps], [glb] and [ply] files +python scripts/infer.py --input IMAGES_FOLDER_OR_IMAGE_PATH --output OUTPUT_FOLDER --maps --glb --ply + +# Show the result in a window (requires pyglet < 2.0, e.g. pip install pyglet==1.5.29) +python scripts/infer.py --input IMAGES_FOLDER_OR_IMAGE_PATH --output OUTPUT_FOLDER --show +``` + +For detailed options, run `python scripts/infer.py --help`: + +``` +Usage: infer.py [OPTIONS] + + Inference script for the MoGe model. + +Options: + --input PATH Input image or folder path. "jpg" and "png" are + supported. + --fov_x FLOAT If camera parameters are known, set the + horizontal field of view in degrees. Otherwise, + MoGe will estimate it. + --output PATH Output folder path + --pretrained TEXT Pretrained model name or path. Default is + "Ruicheng/moge-vitl" + --device TEXT Device name (e.g. "cuda", "cuda:0", "cpu"). + Default is "cuda" + --resize INTEGER Resize the image(s) & output maps to a specific + size. Default is None (no resizing). + --resolution_level INTEGER An integer [0-9] for the resolution level of + inference. The higher, the better but slower. + Default is 9. Note that it is irrelevant to the + output resolution. + --threshold FLOAT Threshold for removing edges. Default is 0.03. + Smaller value removes more edges. "inf" means no + thresholding. + --maps Whether to save the output maps and fov(image, + depth, mask, points, fov). + --glb Whether to save the output as a.glb file. The + color will be saved as a texture. + --ply Whether to save the output as a.ply file. The + color will be saved as vertex colors. + --show Whether show the output in a window. Note that + this requires pyglet<2 installed as required by + trimesh. + --help Show this message and exit. +``` + +### Using [scripts/infer_panorama.py](scripts/infer_panorama.py) for 360° panorama images + +> *NOTE: This is an experimental extension of MoGe.* + +The script will split the 360-degree panorama image into multiple perspective views and infer on each view separately. +The output maps will be combined to produce a panorama depth map and point map. + +Note that the panorama image must have spherical parameterization (e.g., environment maps or equirectangular images). Other formats must be converted to spherical format before using this script. Run `python scripts/infer_panorama.py --help` for detailed options. + + +
+ + +The photo is from [this URL](https://commons.wikimedia.org/wiki/Category:360%C2%B0_panoramas_with_equirectangular_projection#/media/File:Braunschweig_Sankt-%C3%84gidien_Panorama_02.jpg) +
+ +## License + +MoGe code is released under the MIT license, except for DINOv2 code in `moge/model/dinov2` which is released by Meta AI under the Apache 2.0 license. +See [LICENSE](LICENSE) for more details. + + +## Citation + +If you find our work useful in your research, we gratefully request that you consider citing our paper: + +``` +@misc{wang2024moge, + title={MoGe: Unlocking Accurate Monocular Geometry Estimation for Open-Domain Images with Optimal Training Supervision}, + author={Wang, Ruicheng and Xu, Sicheng and Dai, Cassie and Xiang, Jianfeng and Deng, Yu and Tong, Xin and Yang, Jiaolong}, + year={2024}, + eprint={2410.19115}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2410.19115}, +} +``` diff --git a/submodules/MoGe/SECURITY.md b/submodules/MoGe/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..b3c89efc852e22f71eabf5dfbc6ac62493425eb6 --- /dev/null +++ b/submodules/MoGe/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). + + diff --git a/submodules/MoGe/SUPPORT.md b/submodules/MoGe/SUPPORT.md new file mode 100644 index 0000000000000000000000000000000000000000..291d4d43733f4c15a81ff598ec1c99fd6c18f64c --- /dev/null +++ b/submodules/MoGe/SUPPORT.md @@ -0,0 +1,25 @@ +# TODO: The maintainer of this repo has not yet edited this file + +**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? + +- **No CSS support:** Fill out this template with information about how to file issues and get help. +- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. +- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. + +*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* + +# Support + +## How to file issues and get help + +This project uses GitHub Issues to track bugs and feature requests. Please search the existing +issues before filing new issues to avoid duplicates. For new issues, file your bug or +feature request as a new Issue. + +For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE +FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER +CHANNEL. WHERE WILL YOU HELP PEOPLE?**. + +## Microsoft Support Policy + +Support for this **PROJECT or PRODUCT** is limited to the resources listed above. diff --git a/submodules/MoGe/moge/model/__init__.py b/submodules/MoGe/moge/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52b16d09fc22a91d1c8f947909971ee9830d5db3 --- /dev/null +++ b/submodules/MoGe/moge/model/__init__.py @@ -0,0 +1 @@ +from .moge_model import MoGeModel \ No newline at end of file diff --git a/submodules/MoGe/moge/model/dinov2/__init__.py b/submodules/MoGe/moge/model/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/submodules/MoGe/moge/model/dinov2/hub/__init__.py b/submodules/MoGe/moge/model/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/submodules/MoGe/moge/model/dinov2/hub/backbones.py b/submodules/MoGe/moge/model/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/submodules/MoGe/moge/model/dinov2/hub/utils.py b/submodules/MoGe/moge/model/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/submodules/MoGe/moge/model/dinov2/layers/__init__.py b/submodules/MoGe/moge/model/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/submodules/MoGe/moge/model/dinov2/layers/attention.py b/submodules/MoGe/moge/model/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3fed573116d5c837be46a7525d8acf77422c2400 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/submodules/MoGe/moge/model/dinov2/layers/block.py b/submodules/MoGe/moge/model/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/block.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/submodules/MoGe/moge/model/dinov2/layers/dino_head.py b/submodules/MoGe/moge/model/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/submodules/MoGe/moge/model/dinov2/layers/drop_path.py b/submodules/MoGe/moge/model/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/submodules/MoGe/moge/model/dinov2/layers/layer_scale.py b/submodules/MoGe/moge/model/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/submodules/MoGe/moge/model/dinov2/layers/mlp.py b/submodules/MoGe/moge/model/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/submodules/MoGe/moge/model/dinov2/layers/patch_embed.py b/submodules/MoGe/moge/model/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/submodules/MoGe/moge/model/dinov2/layers/swiglu_ffn.py b/submodules/MoGe/moge/model/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/submodules/MoGe/moge/model/dinov2/models/__init__.py b/submodules/MoGe/moge/model/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/submodules/MoGe/moge/model/dinov2/models/vision_transformer.py b/submodules/MoGe/moge/model/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1007ba57ddb35109c91716f1f5bf203db346e7be --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/models/vision_transformer.py @@ -0,0 +1,396 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/submodules/MoGe/moge/model/dinov2/utils/__init__.py b/submodules/MoGe/moge/model/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/submodules/MoGe/moge/model/dinov2/utils/cluster.py b/submodules/MoGe/moge/model/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/submodules/MoGe/moge/model/dinov2/utils/config.py b/submodules/MoGe/moge/model/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/submodules/MoGe/moge/model/dinov2/utils/dtype.py b/submodules/MoGe/moge/model/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/submodules/MoGe/moge/model/dinov2/utils/param_groups.py b/submodules/MoGe/moge/model/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/submodules/MoGe/moge/model/dinov2/utils/utils.py b/submodules/MoGe/moge/model/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f --- /dev/null +++ b/submodules/MoGe/moge/model/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/submodules/MoGe/moge/model/moge_model.py b/submodules/MoGe/moge/model/moge_model.py new file mode 100644 index 0000000000000000000000000000000000000000..014333763c0f2df9f34cd82cea6efaa61bf5f31e --- /dev/null +++ b/submodules/MoGe/moge/model/moge_model.py @@ -0,0 +1,389 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import importlib +import warnings +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +import utils3d +from huggingface_hub import hf_hub_download + +from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from ..utils.tools import timeit + + +class ResidualConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation =='relu': + activation_cls = lambda: nn.ReLU(inplace=True) + elif activation == 'leaky_relu': + activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) + elif activation =='silu': + activation_cls = lambda: nn.SiLU(inplace=True) + elif activation == 'elu': + activation_cls = lambda: nn.ELU(inplace=True) + else: + raise ValueError(f'Unsupported activation function: {activation}') + + self.layers = nn.Sequential( + nn.GroupNorm(1, in_channels), + activation_cls(), + nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode), + nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels), + activation_cls(), + nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode) + ) + + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity() + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class Head(nn.Module): + def __init__( + self, + num_features: int, + dim_in: int, + dim_out: List[int], + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1 + ): + super().__init__() + + self.projects = nn.ModuleList([ + nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features) + ]) + + self.upsample_blocks = nn.ModuleList([ + nn.Sequential( + self._make_upsampler(in_ch + 2, out_ch), + *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks)) + ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) + ]) + + self.output_block = nn.ModuleList([ + self._make_output_block( + dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm, + ) for dim_out_ in dim_out + ]) + + def _make_upsampler(self, in_channels: int, out_channels: int): + upsampler = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] + return upsampler + + def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']): + return nn.Sequential( + nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)), + nn.ReLU(inplace=True), + nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'), + ) + + def forward(self, hidden_states: torch.Tensor, image: torch.Tensor): + img_h, img_w = image.shape[-2:] + patch_h, patch_w = img_h // 14, img_w // 14 + + # Process the hidden states + x = torch.stack([ + proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) + for proj, (feat, clstoken) in zip(self.projects, hidden_states) + ], dim=1).sum(dim=1) + + # Upsample stage + # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8) + for i, block in enumerate(self.upsample_blocks): + # UV coordinates is for awareness of image aspect ratio + uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + for layer in block: + x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) + + # (patch_h * 8, patch_w * 8) -> (img_h, img_w) + x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + + if isinstance(self.output_block, nn.ModuleList): + output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block] + else: + output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False) + + return output + + +class MoGeModel(nn.Module): + image_mean: torch.Tensor + image_std: torch.Tensor + + def __init__(self, + encoder: str = 'dinov2_vitb14', + intermediate_layers: Union[int, List[int]] = 4, + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + output_mask: bool = False, + split_head: bool = False, + remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear', + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + trained_diagonal_size_range: Tuple[Number, Number] = (600, 900), + trained_area_range: Tuple[Number, Number] = (500 * 500, 700 * 700), + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + **deprecated_kwargs + ): + super(MoGeModel, self).__init__() + if deprecated_kwargs: + warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") + + self.encoder = encoder + self.remap_output = remap_output + self.intermediate_layers = intermediate_layers + self.trained_diagonal_size_range = trained_diagonal_size_range + self.trained_area_range = trained_area_range + self.output_mask = output_mask + self.split_head = split_head + + # NOTE: We have copied the DINOv2 code in torchhub to this repository. + # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues. + hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder) + self.backbone = hub_loader(pretrained=False) + dim_feature = self.backbone.blocks[0].attn.qkv.in_features + + self.head = Head( + num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers), + dim_in=dim_feature, + dim_out=3 if not output_mask else 4 if output_mask and not split_head else [3, 1], + dim_proj=dim_proj, + dim_upsample=dim_upsample, + dim_times_res_block_hidden=dim_times_res_block_hidden, + num_res_blocks=num_res_blocks, + res_block_norm=res_block_norm, + last_res_blocks=last_res_blocks, + last_conv_channels=last_conv_channels, + last_conv_size=last_conv_size + ) + + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + if torch.__version__ >= '2.0': + self.enable_pytorch_native_sdpa() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel': + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. + """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True) + else: + cached_checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs + ) + checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True) + model_config = checkpoint['model_config'] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint['model']) + return model + + @staticmethod + def cache_pretrained_backbone(encoder: str, pretrained: bool): + _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained) + + def load_pretrained_backbone(self): + "Load the backbone with pretrained dinov2 weights from torch hub" + state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict() + self.backbone.load_state_dict(state_dict) + + def enable_backbone_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def enable_pytorch_native_sdpa(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) + + def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]: + raw_img_h, raw_img_w = image.shape[-2:] + patch_h, patch_w = raw_img_h // 14, raw_img_w // 14 + + image = (image - self.image_mean) / self.image_std + + # Apply image transformation for DINOv2 + image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True) + + # Get intermediate layers from the backbone + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): + features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True) + + # Predict points (and mask) + output = self.head(features, image) + if self.output_mask: + if self.split_head: + points, mask = output + else: + points, mask = output.split([3, 1], dim=1) + points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1) + else: + points = output.permute(0, 2, 3, 1) + + if self.remap_output == 'linear' or self.remap_output == False: + pass + elif self.remap_output =='sinh' or self.remap_output == True: + points = torch.sinh(points) + elif self.remap_output == 'exp': + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output =='sinh_exp': + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + + return_dict = {'points': points} + if self.output_mask: + return_dict['mask'] = mask + return return_dict + + @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + force_projection: bool = True, + resolution_level: int = 9, + apply_mask: bool = True, + fov_x: Union[Number, torch.Tensor] = None + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W) + - `resolution_level`: the resolution level to use for the output point map in 0-9. Default: 9 (highest) + - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. + """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + + original_height, original_width = image.shape[-2:] + area = original_height * original_width + aspect_ratio = original_width / original_height + + min_area, max_area = self.trained_area_range + expected_area = min_area + (max_area - min_area) * (resolution_level / 9) + + if expected_area != area: + expected_width, expected_height = int(original_width * (expected_area / area) ** 0.5), int(original_height * (expected_area / area) ** 0.5) + image = F.interpolate(image, (expected_height, expected_width), mode="bicubic", align_corners=False, antialias=True) + + output = self.forward(image) + points, mask = output['points'], output.get('mask', None) + + # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal) + if fov_x is None: + focal, shift = recover_focal_shift(points, None if mask is None else mask > 0.5) + else: + focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2)) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, None if mask is None else mask > 0.5, focal=focal) + fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio + fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 + intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + depth = points[..., 2] + shift[..., None, None] + + # If projection constraint is forced, recompute the point map using the actual depth map + if force_projection: + points = utils3d.torch.unproject_cv(utils3d.torch.image_uv(width=expected_width, height=expected_height, dtype=points.dtype, device=points.device), depth, extrinsics=None, intrinsics=intrinsics[..., None, :, :]) + else: + points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :] + + # Resize the output to the original resolution + if expected_area != area: + points = F.interpolate(points.permute(0, 3, 1, 2), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).permute(0, 2, 3, 1) + depth = F.interpolate(depth.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1) + mask = None if mask is None else F.interpolate(mask.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1) + + # Apply mask if needed + if self.output_mask and apply_mask: + mask_binary = (depth > 0) & (mask > 0.5) + points = torch.where(mask_binary[..., None], points, torch.inf) + depth = torch.where(mask_binary, depth, torch.inf) + + if omit_batch_dim: + points = points.squeeze(0) + intrinsics = intrinsics.squeeze(0) + depth = depth.squeeze(0) + if self.output_mask: + mask = mask.squeeze(0) + + return_dict = { + 'points': points, + 'intrinsics': intrinsics, + 'depth': depth, + } + if self.output_mask: + return_dict['mask'] = mask > 0.5 + + return return_dict \ No newline at end of file diff --git a/submodules/MoGe/moge/model/utils.py b/submodules/MoGe/moge/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af0d0a042209ed87cb60f340529940359fdfa900 --- /dev/null +++ b/submodules/MoGe/moge/model/utils.py @@ -0,0 +1,38 @@ +from typing import * + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def wrap_module_with_gradient_checkpointing(module: nn.Module): + from torch.utils.checkpoint import checkpoint + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +def unwrap_module_with_gradient_checkpointing(module: nn.Module): + module.__class__ = module.__class__._restore_cls + + +def wrap_dinov2_attention_with_sdpa(module: nn.Module): + assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later" + class _AttentionWrapper(module.__class__): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + module.__class__ = _AttentionWrapper + return module \ No newline at end of file diff --git a/submodules/MoGe/moge/utils/__init__.py b/submodules/MoGe/moge/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/submodules/MoGe/moge/utils/download.py b/submodules/MoGe/moge/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..886edbccc81cc0c3daed4d858f641097bdfceee2 --- /dev/null +++ b/submodules/MoGe/moge/utils/download.py @@ -0,0 +1,55 @@ +from pathlib import Path +from typing import * +import requests + +from tqdm import tqdm + + +__all__ = ["download_file", "download_bytes"] + + +def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None: + # Ensure headers is a dict if not provided + headers = headers or {} + + # Initialize local variables + file_path = Path(filepath) + downloaded_bytes = 0 + + # Check if we should resume the download + if resume and file_path.exists(): + downloaded_bytes = file_path.stat().st_size + headers['Range'] = f"bytes={downloaded_bytes}-" + + # Make a GET request to fetch the file + with requests.get(url, stream=True, headers=headers) as response: + response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx + + # Calculate the total size to download + total_size = downloaded_bytes + int(response.headers.get('content-length', 0)) + + # Display a progress bar while downloading + with ( + tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar, + open(file_path, 'ab') as file, + ): + # Set the initial position of the progress bar + pbar.update(downloaded_bytes) + + # Write the content to the file in chunks + for chunk in response.iter_content(chunk_size=4096): + file.write(chunk) + pbar.update(len(chunk)) + + +def download_bytes(url: str, headers: dict = None) -> bytes: + # Ensure headers is a dict if not provided + headers = headers or {} + + # Make a GET request to fetch the file + with requests.get(url, stream=True, headers=headers) as response: + response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx + + # Read the content of the response + return response.content + \ No newline at end of file diff --git a/submodules/MoGe/moge/utils/geometry_numpy.py b/submodules/MoGe/moge/utils/geometry_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..299120caf608be22b4cf0c57ab38554fbc8dcfd7 --- /dev/null +++ b/submodules/MoGe/moge/utils/geometry_numpy.py @@ -0,0 +1,189 @@ +from typing import * +from functools import partial +import math + +import numpy as np +import utils3d + +from .tools import timeit + +def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: + if w is None: + return np.mean(x, axis=axis) + else: + w = w.astype(x.dtype) + return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None) + + +def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: + if w is None: + return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis) + else: + w = w.astype(x.dtype) + return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps) + + +def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype) + v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + uv = np.stack([u, v], axis=-1) + return uv + + +def focal_to_fov_numpy(focal: np.ndarray): + return 2 * np.arctan(0.5 / focal) + + +def fov_to_focal_numpy(fov: np.ndarray): + return 0.5 / np.tan(fov / 2) + + +def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0]) + fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1]) + return fov_x, fov_y + + +def point_map_to_depth_legacy_numpy(points: np.ndarray): + height, width = points.shape[-3:-1] + diagonal = (height ** 2 + width ** 2) ** 0.5 + uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2) + _, uv = np.broadcast_arrays(points[..., :2], uv) + + # Solve least squares problem + b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2) + A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2) + + M = A.swapaxes(-2, -1) @ A + solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1) + focal, shift = solution + + depth = points[..., 2] + shift[..., None, None] + fov_x = np.arctan(width / diagonal / focal) * 2 + fov_y = np.arctan(height / diagonal / focal) * 2 + return depth, fov_x, fov_y, shift + + +def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal" + from scipy.optimize import least_squares + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[: , None] + f = (xy_proj * uv).sum() / np.square(xy_proj).sum() + err = (f * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm') + optim_shift = solution['x'].squeeze().astype(np.float32) + + xy_proj = xy / (z + optim_shift)[: , None] + optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum() + + return optim_shift, optim_focal + + +def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift" + from scipy.optimize import least_squares + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy/ (z + shift)[: , None] + err = (focal * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm') + optim_shift = solution['x'].squeeze().astype(np.float32) + + return optim_shift + + +def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)): + import cv2 + assert points.shape[-1] == 3, "Points should (H, W, 3)" + + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + uv = normalized_view_plane_uv_numpy(width=width, height=height) + + if mask is None: + points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3) + uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2) + else: + index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size) + points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr] + + if points_lr.size == 0: + return np.zeros((height, width)), 0, 0, 0 + + if focal is None: + focal, shift = solve_optimal_focal_shift(uv_lr, points_lr) + else: + shift = solve_optimal_shift(uv_lr, points_lr, focal) + + return focal, shift + + +def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. + + ### Parameters + - `mask`: Input 2D mask of shape (..., H, W) + - `target_width`: target width of the resized map + - `target_height`: target height of the resized map + + ### Returns + - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width). Indices are like j + i * W, where j is the row index and i is the column index. + - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) + """ + height, width = mask.shape[-2:] + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2) + + # Window the original mask and uv + uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32) + indices = np.arange(height * width, dtype=np.int32).reshape(height, width) + padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1)) + windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + + # Gather the target pixels's local window + target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32) + target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) + target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32) + + target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + + # Compute nearest neighbor in the local window for each pixel + dist = np.square(target_window_uv - target_uv[..., None]) + dist = dist[..., 0, :] + dist[..., 1, :] + dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size) + nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1) + nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width) + nearest_i, nearest_j = nearest_idx // width, nearest_idx % width + target_mask = np.any(target_window_mask, axis=-1) + batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])] + + return (*batch_indices, nearest_i, nearest_j), target_mask \ No newline at end of file diff --git a/submodules/MoGe/moge/utils/geometry_torch.py b/submodules/MoGe/moge/utils/geometry_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a5a303d79329bf3e9feca1f4ada09e6a0ad51a --- /dev/null +++ b/submodules/MoGe/moge/utils/geometry_torch.py @@ -0,0 +1,219 @@ +from typing import * +import math +from collections import namedtuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types +import utils3d + +from .tools import timeit +from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift + + +def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.mean(dim=dim, keepdim=keepdim) + else: + w = w.to(x.dtype) + return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps) + + +def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal() + + +def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).log().mean(dim=dim).exp() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp() + + +def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device) + v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + + +def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor: + kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2)) + kernel = kernel / kernel.sum() + kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size) + input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate') + input = F.conv2d(input, kernel, groups=input.shape[1]) + return input + + +def focal_to_fov(focal: torch.Tensor): + return 2 * torch.atan(0.5 / focal) + + +def fov_to_focal(fov: torch.Tensor): + return 0.5 / torch.tan(fov / 2) + + +def intrinsics_to_fov(intrinsics: torch.Tensor): + """ + Returns field of view in radians from normalized intrinsics matrix. + ### Parameters: + - intrinsics: torch.Tensor of shape (..., 3, 3) + + ### Returns: + - fov_x: torch.Tensor of shape (...) + - fov_y: torch.Tensor of shape (...) + """ + focal_x = intrinsics[..., 0, 0] + focal_y = intrinsics[..., 1, 1] + return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y) + + +def point_map_to_depth_legacy(points: torch.Tensor): + height, width = points.shape[-3:-1] + diagonal = (height ** 2 + width ** 2) ** 0.5 + uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + # Solve least squares problem + b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2) + A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2) + + M = A.transpose(-2, -1) @ A + solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1) + focal, shift = solution.unbind(-1) + + depth = points[..., 2] + shift[..., None, None] + fov_x = torch.atan(width / diagonal / focal) * 2 + fov_y = torch.atan(height / diagonal / focal) * 2 + return depth, fov_x, fov_y, shift + + +def view_plane_uv_to_focal(uv: torch.Tensor): + normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype) + focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12) + return focal + + +def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)): + """ + Recover the depth map and FoV from a point map with unknown z shift and focal. + + Note that it assumes: + - the optical center is at the center of the map + - the map is undistorted + - the map is isometric in the x and y directions + + ### Parameters: + - `points: torch.Tensor` of shape (..., H, W, 3) + - `mask: torch.Tensor` of shape (..., H, W). Optional. + - `focal: torch.Tensor` of shape (...). Optional. + - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps. + + ### Returns: + - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map + - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space + """ + shape = points.shape + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + points = points.reshape(-1, *shape[-3:]) + mask = None if mask is None else mask.reshape(-1, *shape[-3:-1]) + focal = focal.reshape(-1) if focal is not None else None + uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1) + uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0) + mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0 + + uv_lr_np = uv_lr.cpu().numpy() + points_lr_np = points_lr.detach().cpu().numpy() + focal_np = focal.cpu().numpy() if focal is not None else None + mask_lr_np = None if mask is None else mask_lr.cpu().numpy() + optim_shift, optim_focal = [], [] + for i in range(points.shape[0]): + points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]] + uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]] + if focal is None: + optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np) + optim_focal.append(float(optim_focal_i)) + else: + optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i]) + optim_shift.append(float(optim_shift_i)) + optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + + if focal is None: + optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + else: + optim_focal = focal.reshape(shape[:-3]) + + return optim_focal, optim_shift + + +def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. + + ### Parameters + - `mask`: Input 2D mask of shape (..., H, W) + - `target_width`: target width of the resized map + - `target_height`: target height of the resized map + + ### Returns + - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension + - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) + """ + height, width = mask.shape[-2:] + device = mask.device + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2) + + # Window the original mask and uv + uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device) + indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width) + padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1)) + windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + + # Gather the target pixels's local window + target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device) + target_corner = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device) + target_corner = torch.round(target_corner - 0.5).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device) + + target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + target_window_indices = target_window_indices.expand_as(target_window_mask) + + # Compute nearest neighbor in the local window for each pixel + dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size) + nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1) + nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width) + target_mask = torch.any(target_window_mask, dim=-1) + nearest_i, nearest_j = nearest_idx // width, nearest_idx % width + batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])] + + return (*batch_indices, nearest_i, nearest_j), target_mask diff --git a/submodules/MoGe/moge/utils/io.py b/submodules/MoGe/moge/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..7166682ca6a1faedf8e5d54aacd79a81d8d5208f --- /dev/null +++ b/submodules/MoGe/moge/utils/io.py @@ -0,0 +1,391 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from typing import IO +import zipfile +import json +import io +from typing import * +from pathlib import Path +import re + +import numpy as np +import cv2 + +from .tools import timeit + + +LEGACY_SEGFORMER_CLASSES = [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag' +] +LEGACY_SEGFORMER_LABELS = {k: i for i, k in enumerate(LEGACY_SEGFORMER_CLASSES)} + + +def write_rgbd_zip( + file: Union[IO, os.PathLike], + image: Union[np.ndarray, bytes], + depth: Union[np.ndarray, bytes], mask: Union[np.ndarray, bytes], + segmentation_mask: Union[np.ndarray, bytes] = None, segmentation_labels: Union[Dict[str, int], bytes] = None, + intrinsics: np.ndarray = None, + normal: np.ndarray = None, normal_mask: np.ndarray = None, + meta: Union[Dict[str, Any], bytes] = None, + *, image_quality: int = 95, depth_type: Literal['linear', 'log', 'disparity'] = 'linear', depth_format: Literal['png', 'exr'] = 'png', depth_max_dynamic_range: float = 1e4, png_compression: int = 7 +): + """ + Write RGBD data as zip archive containing the image, depth, mask, segmentation_mask, and meta data. + In the zip file there will be: + - `meta.json`: The meta data as a JSON file. + - `image.jpg`: The RGB image as a JPEG file. + - `depth.png/exr`: The depth map as a PNG or EXR file, depending on the `depth_type`. + - `mask.png` (optional): The mask as a uint8 PNG file. + - `segmentation_mask.png` (optional): The segformer mask as a uint8/uint16 PNG file. + + You can provided those data as np.ndarray or bytes. If you provide them as np.ndarray, they will be properly processed and encoded. + If you provide them as bytes, they will be written as is, assuming they are already encoded. + """ + if meta is None: + meta = {} + elif isinstance(meta, bytes): + meta = json.loads(meta.decode()) + + if isinstance(image, bytes): + image_bytes = image + elif isinstance(image, np.ndarray): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + image_bytes = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, image_quality])[1].tobytes() + + if isinstance(depth, bytes): + depth_bytes = depth + elif isinstance(depth, np.ndarray): + meta['depth_type'] = depth_type + if depth_type == 'linear': + if depth.dtype == np.float16: + depth_format = 'exr' + depth_bytes = cv2.imencode('.exr', depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])[1].tobytes() + elif np.issubdtype(depth.dtype, np.floating): + depth_format = 'exr' + depth_bytes = cv2.imencode('.exr', depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])[1].tobytes() + elif depth.dtype in [np.uint8, np.uint16]: + depth_format = 'png' + depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes() + elif depth_type == 'log': + depth_format = 'png' + depth = depth.astype(np.float32) + near = max(depth[mask].min(), 1e-3) + far = min(depth[mask].max(), near * depth_max_dynamic_range) + depth = ((np.log(depth.clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65535).astype(np.uint16) + depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes() + meta['depth_near'] = float(near) + meta['depth_far'] = float(far) + elif depth_type == 'disparity': + depth_format = 'png' + depth = depth.astype(np.float32) + depth = 1 / (depth + 1e-12) + depth = (depth / depth[mask].max()).clip(0, 1) + if np.unique(depth) < 200: + depth = (depth * 255).astype(np.uint8) + else: + depth = (depth * 65535).astype(np.uint16) + depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes() + + if isinstance(mask, bytes): + mask_bytes = mask + elif isinstance(mask, np.ndarray): + mask_bytes = cv2.imencode('.png', mask.astype(np.uint8) * 255)[1].tobytes() + + if segmentation_mask is not None: + if isinstance(segmentation_mask, bytes): + segmentation_mask_bytes = segmentation_mask + else: + segmentation_mask_bytes = cv2.imencode('.png', segmentation_mask)[1].tobytes() + assert segmentation_labels is not None, "You provided a segmentation mask, but not the corresponding labels." + if isinstance(segmentation_labels, bytes): + segmentation_labels = json.loads(segmentation_labels) + meta['segmentation_labels'] = segmentation_labels + + if intrinsics is not None: + meta['intrinsics'] = intrinsics.tolist() + + if normal is not None: + if isinstance(normal, bytes): + normal_bytes = normal + elif isinstance(normal, np.ndarray): + normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16) + normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR) + normal_bytes = cv2.imencode('.png', normal, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes() + if normal_mask is None: + normal_mask = np.ones(image.shape[:2], dtype=bool) + normal_mask_bytes = cv2.imencode('.png', normal_mask.astype(np.uint8) * 255)[1].tobytes() + + meta_bytes = meta if isinstance(meta, bytes) else json.dumps(meta).encode() + + with zipfile.ZipFile(file, 'w') as z: + z.writestr('meta.json', meta_bytes) + z.writestr('image.jpg', image_bytes) + z.writestr(f'depth.{depth_format}', depth_bytes) + z.writestr('mask.png', mask_bytes) + if segmentation_mask is not None: + z.writestr('segmentation_mask.png', segmentation_mask_bytes) + if normal is not None: + z.writestr('normal.png', normal_bytes) + z.writestr('normal_mask.png', normal_mask_bytes) + + +def read_rgbd_zip(file: Union[str, Path, IO], return_bytes: bool = False) -> Dict[str, Union[np.ndarray, Dict[str, Any], bytes]]: + """ + Read an RGBD zip file and return the image, depth, mask, segmentation_mask, intrinsics, and meta data. + + ### Parameters: + - `file: Union[str, Path, IO]` + The file path or file object to read from. + - `return_bytes: bool = False` + If True, return the image, depth, mask, and segmentation_mask as raw bytes. + + ### Returns: + - `Tuple[Dict[str, Union[np.ndarray, Dict[str, Any]]], Dict[str, bytes]]` + A dictionary containing: (If missing, the value will be None; if return_bytes is True, the value will be bytes) + - `image`: RGB numpy.ndarray of shape (H, W, 3). + - `depth`: float32 numpy.ndarray of shape (H, W). + - `mask`: bool numpy.ndarray of shape (H, W). + - `segformer_mask`: uint8 numpy.ndarray of shape (H, W). + - `intrinsics`: float32 numpy.ndarray of shape (3, 3). + - `meta`: Dict[str, Any]. + """ + # Load & extract archive + with zipfile.ZipFile(file, 'r') as z: + meta = z.read('meta.json') + if not return_bytes: + meta = json.loads(z.read('meta.json')) + + image = z.read('image.jpg') + if not return_bytes: + image = cv2.imdecode(np.frombuffer(z.read('image.jpg'), np.uint8), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + depth_name = next(s for s in z.namelist() if s.startswith('depth')) + depth = z.read(depth_name) + if not return_bytes: + depth = cv2.imdecode(np.frombuffer(z.read(depth_name), np.uint8), cv2.IMREAD_UNCHANGED) + + if 'mask.png' in z.namelist(): + mask = z.read('mask.png') + if not return_bytes: + mask = cv2.imdecode(np.frombuffer(z.read('mask.png'), np.uint8), cv2.IMREAD_UNCHANGED) > 0 + else: + mask = None + + if 'segformer_mask.png' in z.namelist(): + # NOTE: Legacy support for segformer_mask.png + segmentation_mask = z.read('segformer_mask.png') + segmentation_labels = None + if not return_bytes: + segmentation_mask = cv2.imdecode(np.frombuffer(segmentation_mask, np.uint8), cv2.IMREAD_UNCHANGED) + segmentation_labels = LEGACY_SEGFORMER_LABELS + elif 'segmentation_mask.png' in z.namelist(): + segmentation_mask = z.read('segmentation_mask.png') + segmentation_labels = None + if not return_bytes: + segmentation_mask = cv2.imdecode(np.frombuffer(segmentation_mask, np.uint8), cv2.IMREAD_UNCHANGED) + segmentation_labels = meta['segmentation_labels'] + else: + segmentation_mask = None + segmentation_labels = None + + if 'normal.png' in z.namelist(): + normal = z.read('normal.png') + if not return_bytes: + normal = cv2.imdecode(np.frombuffer(z.read('normal.png'), np.uint8), cv2.IMREAD_UNCHANGED) + normal = cv2.cvtColor(normal, cv2.COLOR_BGR2RGB) + normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0] + normal = normal / np.linalg.norm(normal, axis=-1, keepdims=True) + + if 'normal_mask.png' in z.namelist(): + normal_mask = z.read('normal_mask.png') + normal_mask = cv2.imdecode(np.frombuffer(normal_mask, np.uint8), cv2.IMREAD_UNCHANGED) > 0 + else: + normal_mask = np.ones(image.shape[:2], dtype=bool) + else: + normal, normal_mask = None, None + + # recover linear depth + if not return_bytes: + if mask is None: + mask = np.ones(image.shape[:2], dtype=bool) + if meta['depth_type'] == 'linear': + depth = depth.astype(np.float32) + mask = mask & (depth > 0) + elif meta['depth_type'] == 'log': + near, far = meta['depth_near'], meta['depth_far'] + if depth.dtype == np.uint16: + depth = depth.astype(np.float32) / 65535 + elif depth.dtype == np.uint8: + depth = depth.astype(np.float32) / 255 + depth = near ** (1 - depth) * far ** depth + mask = mask & ~np.isnan(depth) + elif meta['depth_type'] == 'disparity': + mask = mask & (depth > 0) + if depth.dtype == np.uint16: + depth = depth.astype(np.float32) / 65535 + elif depth.dtype == np.uint8: + depth = depth.astype(np.float32) / 255 + depth = 1 / (depth + 1e-12) + + # intrinsics + if not return_bytes and 'intrinsics' in meta: + intrinsics = np.array(meta['intrinsics'], dtype=np.float32) + else: + intrinsics = None + + # depth unit + if not return_bytes and 'depth_unit' in meta: + depth_unit_str = meta['depth_unit'] + if r := re.match(r'([\d.]*)(\w*)', depth_unit_str): + digits, unit = r.groups() + depth_unit = float(digits or 1) * {'m': 1, 'cm': 0.01, 'mm': 0.001}[unit] + else: + depth_unit = None + else: + depth_unit = None + + return_dict = { + 'image': image, + 'depth': depth, + 'mask': mask, + 'segmentation_mask': segmentation_mask, + 'segmentation_labels': segmentation_labels, + 'normal': normal, + 'normal_mask': normal_mask, + 'intrinsics': intrinsics, + 'depth_unit': depth_unit, + 'meta': meta, + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + return return_dict + +def write_rgbxyz(file: Union[IO, Path], image: np.ndarray, points: np.ndarray, mask: np.ndarray = None, image_quality: int = 95): + if isinstance(image, bytes): + image_bytes = image + elif isinstance(image, np.ndarray): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + image_bytes = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, image_quality])[1].tobytes() + + if isinstance(points, bytes): + points_bytes = points + elif isinstance(points, np.ndarray): + points_bytes = cv2.imencode('.exr', points.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])[1].tobytes() + + if mask is None: + mask = np.ones(image.shape[:2], dtype=bool) + if isinstance(mask, bytes): + mask_bytes = mask + elif isinstance(mask, np.ndarray): + mask_bytes = cv2.imencode('.png', mask.astype(np.uint8) * 255)[1].tobytes() + + is_archive = hasattr(file, 'write') or Path(file).suffix == '.zip' + if is_archive: + with zipfile.ZipFile(file, 'w') as z: + z.writestr('image.jpg', image_bytes) + z.writestr('points.exr', points_bytes) + if mask is not None: + z.writestr('mask.png', mask_bytes) + else: + file = Path(file) + file.mkdir(parents=True, exist_ok=True) + with open(file / 'image.jpg', 'wb') as f: + f.write(image_bytes) + with open(file / 'points.exr', 'wb') as f: + f.write(points_bytes) + if mask is not None: + with open(file / 'mask.png', 'wb') as f: + f.write(mask_bytes) + + +def read_rgbxyz(file: Union[IO, str, Path]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]: + is_archive = hasattr(file, 'read') or Path(file).suffix == '.zip' + if is_archive: + with zipfile.ZipFile(file, 'r') as z: + image = cv2.imdecode(np.frombuffer(z.read('image.jpg'), np.uint8), cv2.IMREAD_COLOR) + points = cv2.imdecode(np.frombuffer(z.read('points.exr'), np.uint8), cv2.IMREAD_UNCHANGED) + if 'mask.png' in z.namelist(): + mask = cv2.imdecode(np.frombuffer(z.read('mask.png'), np.uint8), cv2.IMREAD_UNCHANGED) > 0 + else: + mask = np.ones(image.shape[:2], dtype=bool) + else: + file = Path(file) + file.mkdir(parents=True, exist_ok=True) + image = cv2.imread(str(file / 'image.jpg'), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + points = cv2.imread(str(file / 'points.exr'), cv2.IMREAD_UNCHANGED) + if (file /'mask.png').exists(): + mask = cv2.imread(str(file / 'mask.png'), cv2.IMREAD_UNCHANGED) > 0 + else: + mask = np.ones(image.shape[:2], dtype=bool) + + return image, points, mask + + +def save_glb( + save_path: Union[str, os.PathLike], + vertices: np.ndarray, + faces: np.ndarray, + vertex_uvs: np.ndarray, + texture: np.ndarray, +): + import trimesh + import trimesh.visual + from PIL import Image + + trimesh.Trimesh( + vertices=vertices, + faces=faces, + visual = trimesh.visual.texture.TextureVisuals( + uv=vertex_uvs, + material=trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(texture), + metallicFactor=0.5, + roughnessFactor=1.0 + ) + ), + process=False + ).export(save_path) + + +def save_ply( + save_path: Union[str, os.PathLike], + vertices: np.ndarray, + faces: np.ndarray, + vertex_colors: np.ndarray, +): + import trimesh + import trimesh.visual + from PIL import Image + + trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_colors=vertex_colors, + process=False + ).export(save_path) \ No newline at end of file diff --git a/submodules/MoGe/moge/utils/pipeline.py b/submodules/MoGe/moge/utils/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f652595c8b840bc672b8a3033b0694aa0986c4bf --- /dev/null +++ b/submodules/MoGe/moge/utils/pipeline.py @@ -0,0 +1,503 @@ +from typing import * +from abc import abstractmethod +from queue import Empty, Full +from threading import Thread +from queue import Queue +from multiprocessing import Process +from threading import Thread, Event +import multiprocessing +import threading +import inspect +import time +import uuid +from copy import deepcopy +import itertools +import functools + +__all__ = [ + 'Node', + 'Link', + 'ConcurrentNode', + 'Worker', + 'WorkerFunction', + 'Provider', + 'ProviderFunction', + 'Sequential', + 'Batch', + 'Unbatch', + 'Parallel', + 'Graph', + 'Buffer', +] + +TERMINATE_CHECK_INTERVAL = 0.5 + + +class _ItemWrapper: + def __init__(self, data: Any, id: Union[int, List[int]] = None): + self.data = data + self.id = id + + +class Terminate(Exception): + pass + + +def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper: + while True: + try: + item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, TERMINATE_CHECK_INTERVAL)) + if terminate_flag.is_set(): + raise Terminate() + return item + except Empty: + if terminate_flag.is_set(): + raise Terminate() + + if timeout is not None: + timeout -= TERMINATE_CHECK_INTERVAL + if timeout <= 0: + raise Empty() + + +def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event): + while True: + try: + queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL) + if terminate_flag.is_set(): + raise Terminate() + return + except Full: + if terminate_flag.is_set(): + raise Terminate() + +class Node: + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + self.input: Queue = Queue(maxsize=in_buffer_size) + self.output: Queue = Queue(maxsize=out_buffer_size) + self.in_buffer_size = in_buffer_size + self.out_buffer_size = out_buffer_size + + @abstractmethod + def start(self): + pass + + @abstractmethod + def terminate(self): + pass + + def stop(self): + self.terminate() + self.join() + + @abstractmethod + def join(self): + pass + + def put(self, data: Any, key: str = None, block: bool = True) -> None: + item = _ItemWrapper(data) + self.input.put(item, block=block) + + def get(self, key: str = None, block: bool = True) -> Any: + item: _ItemWrapper = self.output.get(block=block) + return item.data + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.terminate() + self.join() + + +class ConcurrentNode(Node): + job: Union[Thread, Process] + + def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + super().__init__(in_buffer_size, out_buffer_size) + self.running_as = running_as + + @abstractmethod + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + pass + + def start(self): + if self.running_as == 'thread': + terminate_flag = threading.Event() + job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag)) + elif self.running_as == 'process': + terminate_flag = multiprocessing.Event() + job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag)) + job.start() + self.job = job + self.terminate_flag = terminate_flag + + def terminate(self): + self.terminate_flag.set() + + def join(self): + self.job.join() + + +class Worker(ConcurrentNode): + def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None: + super().__init__(running_as, in_buffer_size, out_buffer_size) + + def init(self) -> None: + """ + This method is called the the thread is started, to initialize any resources that is only held in the thread. + """ + pass + + @abstractmethod + def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]: + """ + This method defines the job that the node should do for each input item. + A item obtained from the input queue is passed as arguments to this method, and the result is placed in the output queue. + The method is executed concurrently with other nodes. + """ + pass + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + self.init() + try: + while True: + item = _get_queue_item(input, terminate_flag) + result = self.work(item.data) + _put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag) + + except Terminate: + return + + +class Provider(ConcurrentNode): + """ + A node that provides data to successive nodes. It takes no input and provides data to the output queue. + """ + def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None: + super().__init__(running_as, 0, out_buffer_size) + + def init(self) -> None: + """ + This method is called the the thread or process is started, to initialize any resources that is only held in the thread or process. + """ + pass + + @abstractmethod + def provide(self) -> Generator[Any, None, None]: + pass + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + self.init() + try: + for data in self.provide(): + _put_queue_item(output, _ItemWrapper(data), terminate_flag) + except Terminate: + return + + +class WorkerFunction(Worker): + def __init__(self, fn: Callable, running_as: 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + super().__init__(running_as, in_buffer_size, out_buffer_size) + self.fn = fn + + def work(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + +class ProviderFunction(Provider): + def __init__(self, fn: Callable, running_as: 'thread', out_buffer_size: int = 1) -> None: + super().__init__(running_as, out_buffer_size) + self.fn = fn + + def provide(self): + for item in self.fn(): + yield item + + +class Link: + def __init__(self, src: Queue, dst: Queue): + self.src = src + self.dst = dst + + def _thread_fn(self): + try: + while True: + item = _get_queue_item(self.src, self.terminate_flag) + _put_queue_item(self.dst, item, self.terminate_flag) + except Terminate: + return + + def start(self): + self.terminate_flag = threading.Event() + self.thread = Thread(target=self._thread_fn) + self.thread.start() + + def terminate(self): + self.terminate_flag.set() + + def join(self): + self.thread.join() + + +class Graph(Node): + """ + Graph pipeline of nodes and links + """ + nodes: List[Node] + links: List[Link] + + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1): + super().__init__(in_buffer_size, out_buffer_size) + self.nodes = [] + self.links = [] + + def add(self, node: Node): + self.nodes.append(node) + + def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]): + """ + Links the output of the source node to the input of the destination node. + If the source or destination node is None, the pipeline's input or output is used. + """ + src_queue = self.input if src is None else src.output + dst_queue = self.output if dst is None else dst.input + self.links.append(Link(src_queue, dst_queue)) + + def chain(self, nodes: Iterable[Node]): + """ + Link the output of each node to the input of the next node. + """ + nodes = list(nodes) + for i in range(len(nodes) - 1): + self.link(nodes[i], nodes[i + 1]) + + def start(self): + for node in self.nodes: + node.start() + for link in self.links: + link.start() + + def terminate(self): + for node in self.nodes: + node.terminate() + for link in self.links: + link.terminate() + + def join(self): + for node in self.nodes: + node.join() + for link in self.links: + link.join() + + def __iter__(self): + providers = [node for node in self.nodes if isinstance(node, Provider)] + if len(providers) == 0: + raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.") + with self: + # while all(provider.job.is_alive() for provider in providers): + while True: + yield self.get() + + def __call__(self, data: Any) -> Any: + """ + Submit data to the pipeline's input queue, and return the output data asynchronously. + NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work. + """ + # TODO + + +class Sequential(Graph): + """ + Pipeline of nodes in sequential order, where each node takes the output of the previous node as input. + The order of input and output items is preserved (FIFO) + """ + def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1): + """ + Initialize the pipeline with a list of nodes to execute sequentially. + ### Parameters: + - nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes. + - function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'. + - in_buffer_size: Maximum size of the input queue of the pipeline. Default is 0 (unlimited). + - out_buffer_size: Maximum size of the output queue of the pipeline. Default is 0 (unlimited). + """ + super().__init__(in_buffer_size, out_buffer_size) + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.add(node) + self.chain([None, *self.nodes, None]) + + +class Parallel(Node): + """ + A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available. + NOTE: It is FIFO if and only if all the nested nodes are FIFO. + """ + nodes: List[Node] + + def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'): + super().__init__(in_buffer_size, out_buffer_size) + self.nodes = [] + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.nodes.append(node) + self.output_order = Queue() + self.lock = threading.Lock() + + def _in_thread_fn(self, node: Node): + try: + while True: + with self.lock: + # A better idea: first make sure its node is vacant, then get it a new item. + # Currently we will not be able to know which node is busy util there is at least one item already waiting in the queue of the node. + # This could lead to suboptimal scheduling. + item = _get_queue_item(self.input, self.terminate_flag) + self.output_order.put(node.output) + _put_queue_item(node.input, item, self.terminate_flag) + except Terminate: + return + + def _out_thread_fn(self): + try: + while True: + queue = _get_queue_item(self.output_order, self.terminate_flag) + item = _get_queue_item(queue, self.terminate_flag) + _put_queue_item(self.output, item, self.terminate_flag) + except Terminate: + return + + def start(self): + self.terminate_flag = threading.Event() + self.in_threads = [] + for node in self.nodes: + thread = Thread(target=self._in_thread_fn, args=(node,)) + thread.start() + self.in_threads.append(thread) + thread = Thread(target=self._out_thread_fn) + thread.start() + self.out_thread = thread + for node in self.nodes: + node.start() + + def terminate(self): + self.terminate_flag.set() + for node in self.nodes: + node.terminate() + + def join(self): + for thread in self.in_threads: + thread.join() + self.out_thread.join() + + +class UnorderedParallel(Graph): + """ + Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available. + NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input. + """ + def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1): + """ + Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node. + ### Parameters: + - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes. + - function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'. + - in_buffer_size: Maximum size of the input queue of the pipeline. Default is 0 (unlimited). + - out_buffer_size: Maximum size of the output queue of the pipeline. Default is 0 (unlimited). + """ + super().__init__(in_buffer_size, out_buffer_size) + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.add(node) + for i in range(len(nodes)): + self.chain([None, self.nodes[i], None]) + + +class Batch(ConcurrentNode): + """ + Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes. + The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node, + i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size. + """ + def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1): + assert batch_size > 0, "Batch size must be greater than 0." + super().__init__('thread', in_buffer_size, out_buffer_size) + self.batch_size = batch_size + self.patience = patience + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + try: + while True: + batch_id, batch_data = [], [] + # Try to fill the batch + for i in range(self.batch_size): + if i == 0 or self.patience is None: + timeout = None + else: + timeout = self.patience - (time.time() - earliest_time) + if timeout < 0: + break + try: + item = _get_queue_item(input, terminate_flag, timeout) + except Empty: + break + + if i == 0: + earliest_time = time.time() + batch_data.append(item.data) + batch_id.append(item.id) + + batch = _ItemWrapper(batch_data, batch_id) + _put_queue_item(output, batch, terminate_flag) + except Terminate: + return + + +class Unbatch(ConcurrentNode): + """ + Ungroups every batch (a list of items) into individual items and passes them to successive nodes. + """ + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1): + super().__init__('thread', in_buffer_size, out_buffer_size) + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + try: + while True: + batch = _get_queue_item(input, terminate_flag) + for id, data in zip(batch.id or itertools.repeat(None), batch.data): + item = _ItemWrapper(data, id) + _put_queue_item(output, item, terminate_flag) + except Terminate: + return + + +class Buffer(Node): + "A FIFO node that buffers items in a queue. Usefull achieve better temporal balance when its successor node has a variable processing time." + def __init__(self, size: int): + super().__init__(size, size) + self.size = size + self.input = self.output = Queue(maxsize=size) \ No newline at end of file diff --git a/submodules/MoGe/moge/utils/tools.py b/submodules/MoGe/moge/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..4659462450dd946388cfd9638b3b6b17cc03ba0d --- /dev/null +++ b/submodules/MoGe/moge/utils/tools.py @@ -0,0 +1,240 @@ +from typing import * +import time +from pathlib import Path +from numbers import Number + + +def catch_exception(fn): + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + import traceback + print(f"Exception in {fn.__name__}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})") + traceback.print_exc(chain=False) + time.sleep(0.1) + return None + return wrapper + + +class CallbackOnException: + def __init__(self, callback: Callable, exception: type): + self.exception = exception + self.callback = callback + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if isinstance(exc_val, self.exception): + self.callback() + return True + return False + +def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]: + for k, v in d.items(): + if isinstance(v, dict): + for sub_key in traverse_nested_dict_keys(v): + yield (k, ) + sub_key + else: + yield (k, ) + + +def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None): + for k in keys: + d = d.get(k, default) + if d is None: + break + return d + +def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any): + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = value + + +def key_average(list_of_dicts: list) -> Dict[str, Any]: + """ + Returns a dictionary with the average value of each key in the input list of dictionaries. + """ + _nested_dict_keys = set() + for d in list_of_dicts: + _nested_dict_keys.update(traverse_nested_dict_keys(d)) + _nested_dict_keys = sorted(_nested_dict_keys) + result = {} + for k in _nested_dict_keys: + values = [ + get_nested_dict(d, k) for d in list_of_dicts + if get_nested_dict(d, k) is not None + ] + avg = sum(values) / len(values) if values else float('nan') + set_nested_dict(result, k, avg) + return result + + +def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]: + """ + Flattens a nested dictionary into a single-level dictionary, with keys as tuples. + """ + items = [] + if parent_key is None: + parent_key = () + for k, v in d.items(): + new_key = parent_key + (k, ) + if isinstance(v, MutableMapping): + items.extend(flatten_nested_dict(v, new_key).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """ + Unflattens a single-level dictionary into a nested dictionary, with keys as tuples. + """ + result = {} + for k, v in d.items(): + sub_dict = result + for k_ in k[:-1]: + if k_ not in sub_dict: + sub_dict[k_] = {} + sub_dict = sub_dict[k_] + sub_dict[k[-1]] = v + return result + + +def read_jsonl(file): + import json + with open(file, 'r') as f: + data = f.readlines() + return [json.loads(line) for line in data] + + +def write_jsonl(data: List[dict], file): + import json + with open(file, 'w') as f: + for item in data: + f.write(json.dumps(item) + '\n') + + +def save_metrics(save_path: Union[str, Path], all_metrics: Dict[str, List[Dict]]): + import pandas as pd + import json + + with open(save_path, 'w') as f: + json.dump(all_metrics, f, indent=4) + + +def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]): + import pandas as pd + data = [flatten_nested_dict(d) for d in data] + df = pd.DataFrame(data) + df = df.sort_index(axis=1) + df.columns = pd.MultiIndex.from_tuples(df.columns) + return df + + +def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]): + if isinstance(d, str): + for old, new in mapping.items(): + d = d.replace(old, new) + elif isinstance(d, list): + for i, item in enumerate(d): + d[i] = recursive_replace(item, mapping) + elif isinstance(d, dict): + for k, v in d.items(): + d[k] = recursive_replace(v, mapping) + return d + + +class timeit: + _history: Dict[str, List['timeit']] = {} + + def __init__(self, name: str = None, verbose: bool = True, multiple: bool = False): + self.name = name + self.verbose = verbose + self.start = None + self.end = None + self.multiple = multiple + if multiple and name not in timeit._history: + timeit._history[name] = [] + + def __call__(self, func: Callable): + import inspect + if inspect.iscoroutinefunction(func): + async def wrapper(*args, **kwargs): + with timeit(self.name or func.__qualname__): + ret = await func(*args, **kwargs) + return ret + return wrapper + else: + def wrapper(*args, **kwargs): + with timeit(self.name or func.__qualname__): + ret = func(*args, **kwargs) + return ret + return wrapper + + def __enter__(self): + self.start = time.time() + + @property + def time(self) -> float: + assert self.start is not None, "Time not yet started." + assert self.end is not None, "Time not yet ended." + return self.end - self.start + + @property + def history(self) -> List['timeit']: + return timeit._history.get(self.name, []) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = time.time() + if self.multiple: + timeit._history[self.name].append(self) + if self.verbose: + if self.multiple: + avg = sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name]) + print(f"{self.name or 'It'} took {avg} seconds in average.") + else: + print(f"{self.name or 'It'} took {self.time} seconds.") + + +def strip_common_prefix_suffix(strings: List[str]) -> List[str]: + first = strings[0] + + for start in range(len(first)): + if any(s[start] != strings[0][start] for s in strings): + break + + for end in range(1, min(len(s) for s in strings)): + if any(s[-end] != first[-end] for s in strings): + break + + return [s[start:len(s) - end + 1] for s in strings] + + +def multithead_execute(inputs: List[Any], num_workers: int, pbar = None): + from concurrent.futures import ThreadPoolExecutor + from contextlib import nullcontext + from tqdm import tqdm + + if pbar is not None: + pbar.total = len(inputs) if hasattr(inputs, '__len__') else None + else: + pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None) + + def decorator(fn: Callable): + with ( + ThreadPoolExecutor(max_workers=num_workers) as executor, + pbar + ): + pbar.refresh() + @catch_exception + def _fn(input): + ret = fn(input) + pbar.update() + return ret + executor.map(_fn, inputs) + executor.shutdown(wait=True) + + return decorator \ No newline at end of file diff --git a/submodules/MoGe/moge/utils/vis.py b/submodules/MoGe/moge/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..a85945ce10adfbf29bdbd95fb9ad765082b3e4df --- /dev/null +++ b/submodules/MoGe/moge/utils/vis.py @@ -0,0 +1,51 @@ +import numpy as np +import matplotlib + + +def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: + if mask is None: + depth = np.where(depth > 0, depth, np.nan) + else: + depth = np.where((depth > 0) & mask, depth, np.nan) + disp = 1 / depth + if normalize: + min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.999) + disp = (disp - min_disp) / (max_disp - min_disp) + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp), 0) + colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3] + return colored + + +def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray: + if mask is not None: + depth = np.where(mask, depth, np.nan) + + min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999) + depth = (depth - min_depth) / (max_depth - min_depth) + colored = np.nan_to_num(matplotlib.colormaps[cmap](depth), 0) + colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3] + return colored + + +def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: + if mask is not None: + disparity = np.where(mask, disparity, np.nan) + + if normalize: + min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999) + disparity = (disparity - min_disp) / (max_disp - min_disp) + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity), 0) + colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3] + return colored + + +def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray: + colored = matplotlib.colormaps[cmap]((segmentation % 20) / 20) + colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3] + return colored + + +def colorize_normal(normal: np.ndarray) -> np.ndarray: + normal = normal * [0.5, -0.5, -0.5] + 0.5 + normal = (normal.clip(0, 1) * 255).astype(np.uint8) + return normal \ No newline at end of file diff --git a/submodules/MoGe/moge/utils/webfile.py b/submodules/MoGe/moge/utils/webfile.py new file mode 100644 index 0000000000000000000000000000000000000000..1e98abf8413e1c9f408849b74f4d2025d25511b6 --- /dev/null +++ b/submodules/MoGe/moge/utils/webfile.py @@ -0,0 +1,73 @@ +import requests +from typing import * + +__all__ = ["WebFile"] + + +class WebFile: + def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None): + self.url = url + self.session = session or requests.Session() + self.session.headers.update(headers or {}) + self._offset = 0 + self.size = size if size is not None else self._fetch_size() + + def _fetch_size(self): + with self.session.get(self.url, stream=True) as response: + response.raise_for_status() + content_length = response.headers.get("Content-Length") + if content_length is None: + raise ValueError("Missing Content-Length in header") + return int(content_length) + + def _fetch_data(self, offset: int, n: int) -> bytes: + headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size)}"} + response = self.session.get(self.url, headers=headers) + response.raise_for_status() + return response.content + + def seekable(self) -> bool: + return True + + def tell(self) -> int: + return self._offset + + def available(self) -> int: + return self.size - self._offset + + def seek(self, offset: int, whence: int = 0) -> None: + if whence == 0: + new_offset = offset + elif whence == 1: + new_offset = self._offset + offset + elif whence == 2: + new_offset = self.size + offset + else: + raise ValueError("Invalid value for whence") + + self._offset = max(0, min(new_offset, self.size)) + + def read(self, n: Optional[int] = None) -> bytes: + if n is None or n < 0: + n = self.available() + else: + n = min(n, self.available()) + + if n == 0: + return b'' + + data = self._fetch_data(self._offset, n) + self._offset += len(data) + + return data + + def close(self) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + \ No newline at end of file diff --git a/submodules/MoGe/moge/utils/webzipfile.py b/submodules/MoGe/moge/utils/webzipfile.py new file mode 100644 index 0000000000000000000000000000000000000000..25ed1d3cd34720335eb001d77a278539ffef569b --- /dev/null +++ b/submodules/MoGe/moge/utils/webzipfile.py @@ -0,0 +1,128 @@ +from typing import * +import io +import os +from zipfile import ( + ZipInfo, BadZipFile, ZipFile, ZipExtFile, + sizeFileHeader, structFileHeader, stringFileHeader, + _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS, + _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED +) +import struct +from requests import Session + +from .webfile import WebFile + + +class _SharedWebFile(WebFile): + def __init__(self, webfile: WebFile, pos: int): + super().__init__(webfile.url, webfile.session, size=webfile.size) + self.seek(pos) + + +class WebZipFile(ZipFile): + "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads." + def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None): + """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x', + or append 'a'.""" + webf = WebFile(url, session=session, headers=headers) + super().__init__(webf, mode='r') + + def open(self, name, mode="r", pwd=None, *, force_zip64=False): + """Return file-like object for 'name'. + + name is a string for the file name within the ZIP file, or a ZipInfo + object. + + mode should be 'r' to read a file already in the ZIP file, or 'w' to + write to a file newly added to the archive. + + pwd is the password to decrypt files (only used for reading). + + When writing, if the file size is not known in advance but may exceed + 2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large + files. If the size is known in advance, it is best to pass a ZipInfo + instance for name, with zinfo.file_size set. + """ + if mode not in {"r", "w"}: + raise ValueError('open() requires mode "r" or "w"') + if pwd and (mode == "w"): + raise ValueError("pwd is only supported for reading files") + if not self.fp: + raise ValueError( + "Attempt to use ZIP archive that was already closed") + + assert mode == "r", "Only read mode is supported for now" + + # Make sure we have an info object + if isinstance(name, ZipInfo): + # 'name' is already an info object + zinfo = name + elif mode == 'w': + zinfo = ZipInfo(name) + zinfo.compress_type = self.compression + zinfo._compresslevel = self.compresslevel + else: + # Get info object for name + zinfo = self.getinfo(name) + + if mode == 'w': + return self._open_to_write(zinfo, force_zip64=force_zip64) + + if self._writing: + raise ValueError("Can't read from the ZIP file while there " + "is an open writing handle on it. " + "Close the writing handle before trying to read.") + + # Open for reading: + self._fileRefCnt += 1 + zef_file = _SharedWebFile(self.fp, zinfo.header_offset) + + try: + # Skip the file header: + fheader = zef_file.read(sizeFileHeader) + if len(fheader) != sizeFileHeader: + raise BadZipFile("Truncated file header") + fheader = struct.unpack(structFileHeader, fheader) + if fheader[_FH_SIGNATURE] != stringFileHeader: + raise BadZipFile("Bad magic number for file header") + + fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) + if fheader[_FH_EXTRA_FIELD_LENGTH]: + zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1) + + if zinfo.flag_bits & _MASK_COMPRESSED_PATCH: + # Zip 2.7: compressed patched data + raise NotImplementedError("compressed patched data (flag bit 5)") + + if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION: + # strong encryption + raise NotImplementedError("strong encryption (flag bit 6)") + + if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME: + # UTF-8 filename + fname_str = fname.decode("utf-8") + else: + fname_str = fname.decode(self.metadata_encoding or "cp437") + + if fname_str != zinfo.orig_filename: + raise BadZipFile( + 'File name in directory %r and header %r differ.' + % (zinfo.orig_filename, fname)) + + # check for encrypted flag & handle password + is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED + if is_encrypted: + if not pwd: + pwd = self.pwd + if pwd and not isinstance(pwd, bytes): + raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__) + if not pwd: + raise RuntimeError("File %r is encrypted, password " + "required for extraction" % name) + else: + pwd = None + + return ZipExtFile(zef_file, mode, zinfo, pwd, True) + except: + zef_file.close() + raise \ No newline at end of file diff --git a/submodules/MoGe/pyrightconfig.json b/submodules/MoGe/pyrightconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..73fadfe308984b5bfebfd5b1ce1fbd8892667e6a --- /dev/null +++ b/submodules/MoGe/pyrightconfig.json @@ -0,0 +1,11 @@ +{ + "include": [ + "moge", + "utils3d", + "infer.py", + "app.py" + ], + "ignore": [ + "**" + ] +} \ No newline at end of file diff --git a/submodules/MoGe/requirements.txt b/submodules/MoGe/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..94782c8a0fa5555f77c2fea251df4f99da02b6e9 --- /dev/null +++ b/submodules/MoGe/requirements.txt @@ -0,0 +1,9 @@ +# The versions are not specified since MoGe should be compatible with most versions of the packages. +# If incompatibilities are found, consider upgrading to latest versions or installing the following recommended version of the package. +click # ==8.1.7 +opencv-python # ==4.10.0.84 +scipy # ==1.14.1 +matplotlib # ==3.9.2 +trimesh # ==4.5.1 +pillow # ==10.4.0 +huggingface_hub # ==0.25.2 \ No newline at end of file diff --git a/submodules/MoGe/scripts/app.py b/submodules/MoGe/scripts/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c458be8a893070342597c082b1b0164e8d0e73 --- /dev/null +++ b/submodules/MoGe/scripts/app.py @@ -0,0 +1,133 @@ +import os +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).absolute().parents[1])) +import time +import uuid +import tempfile +from typing import Union +import atexit +from concurrent.futures import ThreadPoolExecutor + +import gradio as gr +import cv2 +import torch +import numpy as np +import click +import trimesh +import trimesh.visual +from PIL import Image + +from moge.model import MoGeModel +from moge.utils.vis import colorize_depth +import utils3d + +model = MoGeModel.from_pretrained('Ruicheng/moge-vitl').cuda().eval() +thread_pool_executor = ThreadPoolExecutor(max_workers=1) + + +def delete_later(path: Union[str, os.PathLike], delay: int = 300): + def _delete(): + try: + os.remove(path) + except: + pass + def _wait_and_delete(): + time.sleep(delay) + _delete(path) + thread_pool_executor.submit(_wait_and_delete) + atexit.register(_delete) + + +def run(image: np.ndarray, remove_edge: bool = True, max_size: int = 800): + run_id = str(uuid.uuid4()) + + larger_size = max(image.shape[:2]) + if larger_size > max_size: + scale = max_size / larger_size + image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA) + + height, width = image.shape[:2] + + image_tensor = torch.tensor(image, dtype=torch.float32, device=torch.device('cuda')).permute(2, 0, 1) / 255 + output = model.infer(image_tensor, resolution_level=9, apply_mask=True) + points, depth, mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy() + normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask) + + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=0.03, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), + tri=True + ) + vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] + + tempdir = Path(tempfile.gettempdir(), 'moge') + tempdir.mkdir(exist_ok=True) + + output_glb_path = Path(tempdir, f'{run_id}.glb') + output_glb_path.parent.mkdir(exist_ok=True) + trimesh.Trimesh( + vertices=vertices * [-1, 1, -1], # No idea why Gradio 3D Viewer' default camera is flipped + faces=faces, + visual = trimesh.visual.texture.TextureVisuals( + uv=vertex_uvs, + material=trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(image), + metallicFactor=0.5, + roughnessFactor=1.0 + ) + ), + process=False + ).export(output_glb_path) + + output_ply_path = Path(tempdir, f'{run_id}.ply') + output_ply_path.parent.mkdir(exist_ok=True) + trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_colors=vertex_colors, + process=False + ).export(output_ply_path) + + colorized_depth = colorize_depth(depth) + + delete_later(output_glb_path, delay=300) + delete_later(output_ply_path, delay=300) + + return colorized_depth, output_glb_path, output_ply_path.as_posix() + + +DESCRIPTION = """ +## Turn a 2D image into a 3D point map with [MoGe](https://wangrc.site/MoGePage/) + +NOTE: +* The maximum size is set to 800px for efficiency purpose. Oversized images will be downsampled. +* The color in the 3D viewer may look dark due to rendering of 3D viewer. You may download the 3D model as .glb or .ply file to view it in other 3D viewers. +""" + +@click.command() +@click.option('--share', is_flag=True, help='Whether to run the app in shared mode.') +def main(share: bool): + gr.Interface( + fn=run, + inputs=[ + gr.Image(type="numpy", image_mode="RGB"), + gr.Checkbox(True, label="Remove edges"), + ], + outputs=[ + gr.Image(type="numpy", label="Depth map (colorized)"), + gr.Model3D(display_mode="solid", clear_color=[1.0, 1.0, 1.0, 1.0], label="3D Viewer"), + gr.File(type="filepath", label="Download the model as .ply file"), + ], + title=None, + description=DESCRIPTION, + clear_btn=None, + allow_flagging="never", + theme=gr.themes.Soft() + ).launch(share=share) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/submodules/MoGe/scripts/infer.py b/submodules/MoGe/scripts/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..b6420fda71a5df40766489530d217fe9849cd27d --- /dev/null +++ b/submodules/MoGe/scripts/infer.py @@ -0,0 +1,130 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).absolute().parents[1])) + +from typing import * +import itertools +import json +import warnings + +import cv2 +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm +import trimesh +import trimesh.visual +import click + +from moge.model import MoGeModel +from moge.utils.io import save_glb, save_ply +from moge.utils.vis import colorize_depth, colorize_normal +import utils3d + + +@click.command(help='Inference script for the MoGe model.') +@click.option('--input', 'input_path', type=click.Path(exists=True), help='Input image or folder path. "jpg" and "png" are supported.') +@click.option('--fov_x', 'fov_x_', type=float, default=None, help='If camera parameters are known, set the horizontal field of view in degrees. Otherwise, MoGe will estimate it.') +@click.option('--output', 'output_path', type=click.Path(), help='Output folder path') +@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Default is "Ruicheng/moge-vitl"') +@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Default is "cuda"') +@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Default is None (no resizing).') +@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Default is 9. Note that it is irrelevant to the output resolution.') +@click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Default is 0.03. Smaller value removes more edges. "inf" means no thresholding.') +@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).') +@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') +@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') +@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.') +def main( + input_path: str, + fov_x_: float, + output_path: str, + pretrained_model_name_or_path: str, + device_name: str, + resize_to: int, + resolution_level: int, + threshold: float, + save_maps_: bool, + save_glb_: bool, + save_ply_: bool, + show: bool, +): + device = torch.device(device_name) + + include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] + if Path(input_path).is_dir(): + image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) + else: + image_paths = [Path(input_path)] + + if len(image_paths) == 0: + raise FileNotFoundError(f'No image files found in {input_path}') + + model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() + + for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)): + image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + height, width = image.shape[:2] + if resize_to is not None: + height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) + image = cv2.resize(image, (width, height), cv2.INTER_AREA) + image_tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1) + + # Inference + output = model.infer(image_tensor, fov_x=fov_x_) + points, depth, mask, intrinsics = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy(), output['intrinsics'].cpu().numpy() + normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask) + + # Write outputs + if not any([save_maps_, save_glb_, save_ply_]): + warnings.warn('No output format specified. Please use "--maps", "--glb", or "--ply" to specify the output.') + + save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) + save_path.mkdir(exist_ok=True, parents=True) + + if save_maps_: + cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path / 'mask.png'), (mask * 255).astype(np.uint8)) + cv2.imwrite(str(save_path / 'points.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics) + with open(save_path / 'fov.json', 'w') as f: + json.dump({ + 'fov_x': round(float(np.rad2deg(fov_x)), 2), + 'fov_y': round(float(np.rad2deg(fov_y)), 2), + }, f) + + # Export mesh & visulization + if save_glb_ or save_ply_ or show: + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=threshold, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), + tri=True + ) + # When exporting the model, follow the OpenGL coordinate conventions: + # - world coordinate system: x right, y up, z backward. + # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top. + vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] + + if save_glb_: + save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image) + + if save_ply_: + save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) + + if show: + trimesh.Trimesh( + vertices=vertices, + vertex_colors=vertex_colors, + faces=faces, + process=False + ).show() + + +if __name__ == '__main__': + main() diff --git a/submodules/MoGe/scripts/infer_panorama.py b/submodules/MoGe/scripts/infer_panorama.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e61001cfbdc4e091cebaa90e1432ad739e10b8 --- /dev/null +++ b/submodules/MoGe/scripts/infer_panorama.py @@ -0,0 +1,329 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).absolute().parents[1])) + +from typing import * +import itertools +import json +import warnings + +import cv2 +import numpy as np +from numpy import ndarray +import torch +from PIL import Image +from tqdm import tqdm, trange +import trimesh +import trimesh.visual +import click +from scipy.sparse import csr_array, hstack, vstack +from scipy.ndimage import convolve +from scipy.sparse.linalg import lsmr + +from moge.model import MoGeModel +from moge.utils.io import save_glb, save_ply +from moge.utils.vis import colorize_depth +import utils3d + + +def get_panorama_cameras(): + vertices, _ = utils3d.numpy.icosahedron() + intrinsics = utils3d.numpy.intrinsics_from_fov(fov_x=np.deg2rad(90), fov_y=np.deg2rad(90)) + extrinsics = utils3d.numpy.extrinsics_look_at([0, 0, 0], vertices, [0, 0, 1]).astype(np.float32) + return extrinsics, [intrinsics] * len(vertices) + + +def spherical_uv_to_directions(uv: np.ndarray): + theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi + directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1) + return directions + + +def directions_to_spherical_uv(directions: np.ndarray): + directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True) + u = 1 - np.arctan2(directions[..., 1], directions[..., 0]) / (2 * np.pi) % 1.0 + v = np.arccos(directions[..., 2]) / np.pi + return np.stack([u, v], axis=-1) + + +def split_panorama_image(image: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, resolution: int): + height, width = image.shape[:2] + uv = utils3d.numpy.image_uv(width=resolution, height=resolution) + splitted_images = [] + for i in range(len(extrinsics)): + spherical_uv = directions_to_spherical_uv(utils3d.numpy.unproject_cv(uv, extrinsics=extrinsics[i], intrinsics=intrinsics[i])) + pixels = utils3d.numpy.uv_to_pixel(spherical_uv, width=width, height=height).astype(np.float32) + + splitted_image = cv2.remap(image, pixels[..., 0], pixels[..., 1], interpolation=cv2.INTER_LINEAR) + splitted_images.append(splitted_image) + return splitted_images + + +def poisson_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, ndarray]: + grid_index = np.arange(height * width).reshape(height, width) + grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode='wrap' if wrap_x else 'edge') + grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode='wrap' if wrap_y else 'edge') + + data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(height * width, axis=0).reshape(-1) + indices = np.stack([ + grid_index[1:-1, 1:-1], + grid_index[:-2, 1:-1], # up + grid_index[2:, 1:-1], # down + grid_index[1:-1, :-2], # left + grid_index[1:-1, 2:] # right + ], axis=-1).reshape(-1) + indptr = np.arange(0, height * width * 5 + 1, 5) + A = csr_array((data, indices, indptr), shape=(height * width, height * width)) + + return A + + +def grad_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, np.ndarray]: + grid_index = np.arange(width * height).reshape(height, width) + if wrap_x: + grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode='wrap') + if wrap_y: + grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode='wrap') + + data = np.concatenate([ + np.concatenate([ + np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j] + -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j-1] + ], axis=1).reshape(-1), + np.concatenate([ + np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i,j] + -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i-1,j] + ], axis=1).reshape(-1), + ]) + indices = np.concatenate([ + np.concatenate([ + grid_index[:, :-1].reshape(-1, 1), + grid_index[:, 1:].reshape(-1, 1), + ], axis=1).reshape(-1), + np.concatenate([ + grid_index[:-1, :].reshape(-1, 1), + grid_index[1:, :].reshape(-1, 1), + ], axis=1).reshape(-1), + ]) + indptr = np.arange(0, grid_index.shape[0] * (grid_index.shape[1] - 1) * 2 + (grid_index.shape[0] - 1) * grid_index.shape[1] * 2 + 1, 2) + A = csr_array((data, indices, indptr), shape=(grid_index.shape[0] * (grid_index.shape[1] - 1) + (grid_index.shape[0] - 1) * grid_index.shape[1], height * width)) + + return A + + +def merge_panorama_depth(width: int, height: int, distance_maps: List[np.ndarray], pred_masks: List[np.ndarray], extrinsics: List[np.ndarray], intrinsics: List[np.ndarray]): + if max(width, height) > 256: + panorama_depth_init, _ = merge_panorama_depth(width // 2, height // 2, distance_maps, pred_masks, extrinsics, intrinsics) + panorama_depth_init = cv2.resize(panorama_depth_init, (width, height), cv2.INTER_LINEAR) + else: + panorama_depth_init = None + + uv = utils3d.numpy.image_uv(width=width, height=height) + spherical_directions = spherical_uv_to_directions(uv) + + # Warp each view to the panorama + panorama_log_distance_grad_maps, panorama_grad_masks = [], [] + panorama_log_distance_laplacian_maps, panorama_laplacian_masks = [], [] + panorama_pred_masks = [] + for i in range(len(distance_maps)): + projected_uv, projected_depth = utils3d.numpy.project_cv(spherical_directions, extrinsics=extrinsics[i], intrinsics=intrinsics[i]) + projection_valid_mask = (projected_depth > 0) & (projected_uv > 0).all(axis=-1) & (projected_uv < 1).all(axis=-1) + + projected_pixels = utils3d.numpy.uv_to_pixel(np.clip(projected_uv, 0, 1), width=distance_maps[i].shape[1], height=distance_maps[i].shape[0]).astype(np.float32) + + log_splitted_distance = np.log(distance_maps[i]) + panorama_log_distance_map = np.where(projection_valid_mask, cv2.remap(log_splitted_distance, projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE), 0) + panorama_pred_mask = projection_valid_mask & (cv2.remap(pred_masks[i].astype(np.uint8), projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE) > 0) + + # calculate gradient map + padded = np.pad(panorama_log_distance_map, ((0, 0), (0, 1)), mode='wrap') + grad_x, grad_y = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :] + + padded = np.pad(panorama_pred_mask, ((0, 0), (0, 1)), mode='wrap') + mask_x, mask_y = padded[:, :-1] & padded[:, 1:], padded[:-1, :] & padded[1:, :] + + panorama_log_distance_grad_maps.append((grad_x, grad_y)) + panorama_grad_masks.append((mask_x, mask_y)) + + # calculate laplacian map + padded = np.pad(panorama_log_distance_map, ((1, 1), (0, 0)), mode='edge') + padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap') + laplacian = convolve(padded, np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32))[1:-1, 1:-1] + + padded = np.pad(panorama_pred_mask, ((1, 1), (0, 0)), mode='edge') + padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap') + mask = convolve(padded.astype(np.uint8), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8))[1:-1, 1:-1] == 5 + + panorama_log_distance_laplacian_maps.append(laplacian) + panorama_laplacian_masks.append(mask) + + panorama_pred_masks.append(panorama_pred_mask) + + panorama_log_distance_grad_x = np.stack([grad_map[0] for grad_map in panorama_log_distance_grad_maps], axis=0) + panorama_log_distance_grad_y = np.stack([grad_map[1] for grad_map in panorama_log_distance_grad_maps], axis=0) + panorama_grad_mask_x = np.stack([mask_map[0] for mask_map in panorama_grad_masks], axis=0) + panorama_grad_mask_y = np.stack([mask_map[1] for mask_map in panorama_grad_masks], axis=0) + + panorama_log_distance_grad_x = np.sum(panorama_log_distance_grad_x * panorama_grad_mask_x, axis=0) / np.sum(panorama_grad_mask_x, axis=0).clip(1e-3) + panorama_log_distance_grad_y = np.sum(panorama_log_distance_grad_y * panorama_grad_mask_y, axis=0) / np.sum(panorama_grad_mask_y, axis=0).clip(1e-3) + + panorama_laplacian_maps = np.stack(panorama_log_distance_laplacian_maps, axis=0) + panorama_laplacian_masks = np.stack(panorama_laplacian_masks, axis=0) + panorama_laplacian_map = np.sum(panorama_laplacian_maps * panorama_laplacian_masks, axis=0) / np.sum(panorama_laplacian_masks, axis=0).clip(1e-3) + + grad_x_mask = np.any(panorama_grad_mask_x, axis=0).reshape(-1) + grad_y_mask = np.any(panorama_grad_mask_y, axis=0).reshape(-1) + grad_mask = np.concatenate([grad_x_mask, grad_y_mask]) + laplacian_mask = np.any(panorama_laplacian_masks, axis=0).reshape(-1) + + # Solve overdetermined system + A = vstack([ + grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask], + poisson_equation(width, height, wrap_x=True, wrap_y=False)[laplacian_mask], + ]) + b = np.concatenate([ + panorama_log_distance_grad_x.reshape(-1)[grad_x_mask], + panorama_log_distance_grad_y.reshape(-1)[grad_y_mask], + panorama_laplacian_map.reshape(-1)[laplacian_mask] + ]) + x, *_ = lsmr( + A, b, + atol=1e-5, btol=1e-5, + x0=np.log(panorama_depth_init).reshape(-1) if panorama_depth_init is not None else None, + show=False, + ) + + panorama_depth = np.exp(x).reshape(height, width).astype(np.float32) + panorama_mask = np.any(panorama_pred_masks, axis=0) + + return panorama_depth, panorama_mask + + +@click.command(help='Inference script for the MoGe model.') +@click.option('--input', 'input_path', type=click.Path(exists=True), help='Input image or folder path. "jpg" and "png" are supported.') +@click.option('--output', 'output_path', type=click.Path(), help='Output folder path') +@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Default is "Ruicheng/moge-vitl"') +@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Default is "cuda"') +@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Default is None (no resizing).') +@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Default is 9. Note that it is irrelevant to the output resolution.') +@click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Default is 0.03. Smaller value removes more edges. "inf" means no thresholding.') +@click.option('--batch_size', type=int, default=4, help='Batch size for inference. Default is 4.') +@click.option('--splitted', 'save_splitted', is_flag=True, help='Whether to save the splitted images. Default is False.') +@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).') +@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') +@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') +@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.') +def main( + input_path: str, + output_path: str, + pretrained_model_name_or_path: str, + device_name: str, + resize_to: int, + resolution_level: int, + threshold: float, + batch_size: int, + save_splitted: bool, + save_maps_: bool, + save_glb_: bool, + save_ply_: bool, + show: bool, +): + device = torch.device(device_name) + + include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] + if Path(input_path).is_dir(): + image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) + else: + image_paths = [Path(input_path)] + + if len(image_paths) == 0: + raise FileNotFoundError(f'No image files found in {input_path}') + + if not any([save_maps_, save_glb_, save_ply_]): + warnings.warn('No output format specified. Please use "--maps", "--glb", or "--ply" to specify the output.') + + model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() + + for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)): + image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + height, width = image.shape[:2] + if resize_to is not None: + height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) + image = cv2.resize(image, (width, height), cv2.INTER_AREA) + + splitted_extrinsics, splitted_intriniscs = get_panorama_cameras() + splitted_resolution = 512 + splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution) + + # Infer each view + print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring') + + splitted_distance_maps, splitted_masks = [], [] + for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False): + image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2) + fov_x, fov_y = np.rad2deg(utils3d.numpy.intrinsics_to_fov(np.array(splitted_intriniscs[i:i + batch_size]))) + fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device) + output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False) + distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy() + splitted_distance_maps.extend(list(distance_map)) + splitted_masks.extend(list(mask)) + + # Save splitted + if save_splitted: + splitted_save_path = Path(output_path, image_path.stem, 'splitted') + splitted_save_path.mkdir(exist_ok=True, parents=True) + for i in range(len(splitted_images)): + cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR)) + + # Merge + print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging') + + merging_width, merging_height = min(1920, width), min(960, height) + panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs) + panorama_depth = panorama_depth.astype(np.float32) + panorama_depth = cv2.resize(panorama_depth, (width, height), cv2.INTER_LINEAR) + panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), cv2.INTER_NEAREST) > 0 + points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.numpy.image_uv(width=width, height=height)) + + # Write outputs + print('Writing outputs...') if pbar.disable else pbar.set_postfix_str(f'Inferring') + save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) + save_path.mkdir(exist_ok=True, parents=True) + if save_maps_: + cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path /'mask.png'), (panorama_mask * 255).astype(np.uint8)) + + # Export mesh & visulization + if save_glb_ or save_ply_ or show: + normals, normals_mask = utils3d.numpy.points_to_normals(points, panorama_mask) + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=panorama_mask & ~(utils3d.numpy.depth_edge(panorama_depth, rtol=threshold) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), + tri=True + ) + + if save_glb_: + save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image) + + if save_ply_: + save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) + + if show: + trimesh.Trimesh( + vertices=vertices, + vertex_colors=vertex_colors, + faces=faces, + process=False + ).show() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/submodules/MoGe/utils3d/README.md b/submodules/MoGe/utils3d/README.md new file mode 100644 index 0000000000000000000000000000000000000000..25b9c737af1398819f077988b4dfca878f839205 --- /dev/null +++ b/submodules/MoGe/utils3d/README.md @@ -0,0 +1,3 @@ +# utils3d + +This is a collection of utility functions for 3D computer vision tasks copied from https://github.com/EasternJournalist/utils3d. diff --git a/submodules/MoGe/utils3d/__init__.py b/submodules/MoGe/utils3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3291ba3a79a7f263c6208a546c20ca458a5058d9 --- /dev/null +++ b/submodules/MoGe/utils3d/__init__.py @@ -0,0 +1,20 @@ +""" +A package for common utility functions in 3D computer graphics and vision. Providing NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`. +""" +import importlib +from typing import TYPE_CHECKING + +try: + from ._unified import * +except ImportError: + pass + +__all__ = ['numpy', 'torch', 'io'] + +def __getattr__(name: str): + return globals().get(name, importlib.import_module(f'.{name}', __package__)) + +if TYPE_CHECKING: + from . import torch + from . import numpy + from . import io \ No newline at end of file diff --git a/submodules/MoGe/utils3d/_helpers.py b/submodules/MoGe/utils3d/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1d5b086dec88b8e2283e2546520d6f2a3d8505 --- /dev/null +++ b/submodules/MoGe/utils3d/_helpers.py @@ -0,0 +1,35 @@ +from functools import wraps +import warnings + + +def suppress_traceback(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + e.__traceback__ = e.__traceback__.tb_next.tb_next + raise + return wrapper + + +class no_warnings: + def __init__(self, action: str = 'ignore', **kwargs): + self.action = action + self.filter_kwargs = kwargs + + def __call__(self, fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter(self.action, **self.filter_kwargs) + return fn(*args, **kwargs) + return wrapper + + def __enter__(self): + self.warnings_manager = warnings.catch_warnings() + self.warnings_manager.__enter__() + warnings.simplefilter(self.action, **self.filter_kwargs) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.warnings_manager.__exit__(exc_type, exc_val, exc_tb) diff --git a/submodules/MoGe/utils3d/_unified/__init__.py b/submodules/MoGe/utils3d/_unified/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84a675766935a51cbae7f3bcbb7378ca25227bd5 --- /dev/null +++ b/submodules/MoGe/utils3d/_unified/__init__.py @@ -0,0 +1,934 @@ +# Auto-generated implementation redirecting to numpy/torch implementations +import sys +from typing import TYPE_CHECKING +import utils3d +from .._helpers import suppress_traceback + +__all__ = ["triangulate", +"compute_face_normal", +"compute_face_angle", +"compute_vertex_normal", +"compute_vertex_normal_weighted", +"remove_corrupted_faces", +"merge_duplicate_vertices", +"remove_unreferenced_vertices", +"subdivide_mesh_simple", +"mesh_relations", +"flatten_mesh_indices", +"calc_quad_candidates", +"calc_quad_distortion", +"calc_quad_direction", +"calc_quad_smoothness", +"sovle_quad", +"sovle_quad_qp", +"tri_to_quad", +"sliding_window_1d", +"sliding_window_nd", +"sliding_window_2d", +"max_pool_1d", +"max_pool_2d", +"max_pool_nd", +"depth_edge", +"normals_edge", +"depth_aliasing", +"interpolate", +"image_scrcoord", +"image_uv", +"image_pixel_center", +"image_pixel", +"image_mesh", +"image_mesh_from_depth", +"depth_to_normals", +"points_to_normals", +"chessboard", +"cube", +"icosahedron", +"square", +"camera_frustum", +"perspective", +"perspective_from_fov", +"perspective_from_fov_xy", +"intrinsics_from_focal_center", +"intrinsics_from_fov", +"fov_to_focal", +"focal_to_fov", +"intrinsics_to_fov", +"view_look_at", +"extrinsics_look_at", +"perspective_to_intrinsics", +"perspective_to_near_far", +"intrinsics_to_perspective", +"extrinsics_to_view", +"view_to_extrinsics", +"normalize_intrinsics", +"crop_intrinsics", +"pixel_to_uv", +"pixel_to_ndc", +"uv_to_pixel", +"project_depth", +"depth_buffer_to_linear", +"unproject_cv", +"unproject_gl", +"project_cv", +"project_gl", +"quaternion_to_matrix", +"axis_angle_to_matrix", +"matrix_to_quaternion", +"extrinsics_to_essential", +"euler_axis_angle_rotation", +"euler_angles_to_matrix", +"skew_symmetric", +"rotation_matrix_from_vectors", +"ray_intersection", +"se3_matrix", +"slerp_quaternion", +"slerp_vector", +"lerp", +"lerp_se3_matrix", +"piecewise_lerp", +"piecewise_lerp_se3_matrix", +"apply_transform", +"linear_spline_interpolate", +"RastContext", +"rasterize_triangle_faces", +"rasterize_edges", +"texture", +"warp_image_by_depth", +"test_rasterization", +"compute_face_angles", +"compute_face_tbn", +"compute_vertex_tbn", +"laplacian", +"laplacian_smooth_mesh", +"taubin_smooth_mesh", +"laplacian_hc_smooth_mesh", +"get_rays", +"get_image_rays", +"get_mipnerf_cones", +"volume_rendering", +"bin_sample", +"importance_sample", +"nerf_render_rays", +"mipnerf_render_rays", +"nerf_render_view", +"mipnerf_render_view", +"InstantNGP", +"point_to_normal", +"depth_to_normal", +"masked_min", +"masked_max", +"bounding_rect", +"intrinsics_from_fov_xy", +"matrix_to_euler_angles", +"matrix_to_axis_angle", +"axis_angle_to_quaternion", +"quaternion_to_axis_angle", +"slerp", +"interpolate_extrinsics", +"interpolate_view", +"to4x4", +"rotation_matrix_2d", +"rotate_2d", +"translate_2d", +"scale_2d", +"apply_2d", +"warp_image_by_forward_flow"] + +def _contains_tensor(obj): + if isinstance(obj, (list, tuple)): + return any(_contains_tensor(item) for item in obj) + elif isinstance(obj, dict): + return any(_contains_tensor(value) for value in obj.values()) + else: + import torch + return isinstance(obj, torch.Tensor) + + +@suppress_traceback +def _call_based_on_args(fname, args, kwargs): + if 'torch' in sys.modules: + if any(_contains_tensor(arg) for arg in args) or any(_contains_tensor(v) for v in kwargs.values()): + fn = getattr(utils3d.torch, fname, None) + if fn is None: + raise NotImplementedError(f"Function {fname} has no torch implementation.") + return fn(*args, **kwargs) + fn = getattr(utils3d.numpy, fname, None) + if fn is None: + raise NotImplementedError(f"Function {fname} has no numpy implementation.") + return fn(*args, **kwargs) + + +@suppress_traceback +def triangulate(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.triangulate, utils3d.torch.triangulate + return _call_based_on_args('triangulate', args, kwargs) + +@suppress_traceback +def compute_face_normal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.compute_face_normal, utils3d.torch.compute_face_normal + return _call_based_on_args('compute_face_normal', args, kwargs) + +@suppress_traceback +def compute_face_angle(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.compute_face_angle, None + return _call_based_on_args('compute_face_angle', args, kwargs) + +@suppress_traceback +def compute_vertex_normal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.compute_vertex_normal, utils3d.torch.compute_vertex_normal + return _call_based_on_args('compute_vertex_normal', args, kwargs) + +@suppress_traceback +def compute_vertex_normal_weighted(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.compute_vertex_normal_weighted, utils3d.torch.compute_vertex_normal_weighted + return _call_based_on_args('compute_vertex_normal_weighted', args, kwargs) + +@suppress_traceback +def remove_corrupted_faces(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.remove_corrupted_faces, utils3d.torch.remove_corrupted_faces + return _call_based_on_args('remove_corrupted_faces', args, kwargs) + +@suppress_traceback +def merge_duplicate_vertices(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.merge_duplicate_vertices, utils3d.torch.merge_duplicate_vertices + return _call_based_on_args('merge_duplicate_vertices', args, kwargs) + +@suppress_traceback +def remove_unreferenced_vertices(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.remove_unreferenced_vertices, utils3d.torch.remove_unreferenced_vertices + return _call_based_on_args('remove_unreferenced_vertices', args, kwargs) + +@suppress_traceback +def subdivide_mesh_simple(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.subdivide_mesh_simple, utils3d.torch.subdivide_mesh_simple + return _call_based_on_args('subdivide_mesh_simple', args, kwargs) + +@suppress_traceback +def mesh_relations(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.mesh_relations, None + return _call_based_on_args('mesh_relations', args, kwargs) + +@suppress_traceback +def flatten_mesh_indices(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.flatten_mesh_indices, None + return _call_based_on_args('flatten_mesh_indices', args, kwargs) + +@suppress_traceback +def calc_quad_candidates(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.calc_quad_candidates, None + return _call_based_on_args('calc_quad_candidates', args, kwargs) + +@suppress_traceback +def calc_quad_distortion(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.calc_quad_distortion, None + return _call_based_on_args('calc_quad_distortion', args, kwargs) + +@suppress_traceback +def calc_quad_direction(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.calc_quad_direction, None + return _call_based_on_args('calc_quad_direction', args, kwargs) + +@suppress_traceback +def calc_quad_smoothness(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.calc_quad_smoothness, None + return _call_based_on_args('calc_quad_smoothness', args, kwargs) + +@suppress_traceback +def sovle_quad(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sovle_quad, None + return _call_based_on_args('sovle_quad', args, kwargs) + +@suppress_traceback +def sovle_quad_qp(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sovle_quad_qp, None + return _call_based_on_args('sovle_quad_qp', args, kwargs) + +@suppress_traceback +def tri_to_quad(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.tri_to_quad, None + return _call_based_on_args('tri_to_quad', args, kwargs) + +@suppress_traceback +def sliding_window_1d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sliding_window_1d, utils3d.torch.sliding_window_1d + return _call_based_on_args('sliding_window_1d', args, kwargs) + +@suppress_traceback +def sliding_window_nd(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sliding_window_nd, utils3d.torch.sliding_window_nd + return _call_based_on_args('sliding_window_nd', args, kwargs) + +@suppress_traceback +def sliding_window_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sliding_window_2d, utils3d.torch.sliding_window_2d + return _call_based_on_args('sliding_window_2d', args, kwargs) + +@suppress_traceback +def max_pool_1d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.max_pool_1d, None + return _call_based_on_args('max_pool_1d', args, kwargs) + +@suppress_traceback +def max_pool_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.max_pool_2d, None + return _call_based_on_args('max_pool_2d', args, kwargs) + +@suppress_traceback +def max_pool_nd(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.max_pool_nd, None + return _call_based_on_args('max_pool_nd', args, kwargs) + +@suppress_traceback +def depth_edge(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.depth_edge, utils3d.torch.depth_edge + return _call_based_on_args('depth_edge', args, kwargs) + +@suppress_traceback +def normals_edge(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.normals_edge, None + return _call_based_on_args('normals_edge', args, kwargs) + +@suppress_traceback +def depth_aliasing(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.depth_aliasing, utils3d.torch.depth_aliasing + return _call_based_on_args('depth_aliasing', args, kwargs) + +@suppress_traceback +def interpolate(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.interpolate, None + return _call_based_on_args('interpolate', args, kwargs) + +@suppress_traceback +def image_scrcoord(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_scrcoord, None + return _call_based_on_args('image_scrcoord', args, kwargs) + +@suppress_traceback +def image_uv(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_uv, utils3d.torch.image_uv + return _call_based_on_args('image_uv', args, kwargs) + +@suppress_traceback +def image_pixel_center(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_pixel_center, utils3d.torch.image_pixel_center + return _call_based_on_args('image_pixel_center', args, kwargs) + +@suppress_traceback +def image_pixel(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_pixel, None + return _call_based_on_args('image_pixel', args, kwargs) + +@suppress_traceback +def image_mesh(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_mesh, utils3d.torch.image_mesh + return _call_based_on_args('image_mesh', args, kwargs) + +@suppress_traceback +def image_mesh_from_depth(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_mesh_from_depth, utils3d.torch.image_mesh_from_depth + return _call_based_on_args('image_mesh_from_depth', args, kwargs) + +@suppress_traceback +def depth_to_normals(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.depth_to_normals, None + return _call_based_on_args('depth_to_normals', args, kwargs) + +@suppress_traceback +def points_to_normals(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.points_to_normals, None + return _call_based_on_args('points_to_normals', args, kwargs) + +@suppress_traceback +def chessboard(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.chessboard, utils3d.torch.chessboard + return _call_based_on_args('chessboard', args, kwargs) + +@suppress_traceback +def cube(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.cube, None + return _call_based_on_args('cube', args, kwargs) + +@suppress_traceback +def icosahedron(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.icosahedron, None + return _call_based_on_args('icosahedron', args, kwargs) + +@suppress_traceback +def square(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.square, None + return _call_based_on_args('square', args, kwargs) + +@suppress_traceback +def camera_frustum(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.camera_frustum, None + return _call_based_on_args('camera_frustum', args, kwargs) + +@suppress_traceback +def perspective(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective, utils3d.torch.perspective + return _call_based_on_args('perspective', args, kwargs) + +@suppress_traceback +def perspective_from_fov(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective_from_fov, utils3d.torch.perspective_from_fov + return _call_based_on_args('perspective_from_fov', args, kwargs) + +@suppress_traceback +def perspective_from_fov_xy(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective_from_fov_xy, utils3d.torch.perspective_from_fov_xy + return _call_based_on_args('perspective_from_fov_xy', args, kwargs) + +@suppress_traceback +def intrinsics_from_focal_center(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.intrinsics_from_focal_center, utils3d.torch.intrinsics_from_focal_center + return _call_based_on_args('intrinsics_from_focal_center', args, kwargs) + +@suppress_traceback +def intrinsics_from_fov(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.intrinsics_from_fov, utils3d.torch.intrinsics_from_fov + return _call_based_on_args('intrinsics_from_fov', args, kwargs) + +@suppress_traceback +def fov_to_focal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.fov_to_focal, None + return _call_based_on_args('fov_to_focal', args, kwargs) + +@suppress_traceback +def focal_to_fov(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.focal_to_fov, None + return _call_based_on_args('focal_to_fov', args, kwargs) + +@suppress_traceback +def intrinsics_to_fov(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.intrinsics_to_fov, None + return _call_based_on_args('intrinsics_to_fov', args, kwargs) + +@suppress_traceback +def view_look_at(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.view_look_at, utils3d.torch.view_look_at + return _call_based_on_args('view_look_at', args, kwargs) + +@suppress_traceback +def extrinsics_look_at(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.extrinsics_look_at, utils3d.torch.extrinsics_look_at + return _call_based_on_args('extrinsics_look_at', args, kwargs) + +@suppress_traceback +def perspective_to_intrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective_to_intrinsics, utils3d.torch.perspective_to_intrinsics + return _call_based_on_args('perspective_to_intrinsics', args, kwargs) + +@suppress_traceback +def perspective_to_near_far(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective_to_near_far, None + return _call_based_on_args('perspective_to_near_far', args, kwargs) + +@suppress_traceback +def intrinsics_to_perspective(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.intrinsics_to_perspective, utils3d.torch.intrinsics_to_perspective + return _call_based_on_args('intrinsics_to_perspective', args, kwargs) + +@suppress_traceback +def extrinsics_to_view(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.extrinsics_to_view, utils3d.torch.extrinsics_to_view + return _call_based_on_args('extrinsics_to_view', args, kwargs) + +@suppress_traceback +def view_to_extrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.view_to_extrinsics, utils3d.torch.view_to_extrinsics + return _call_based_on_args('view_to_extrinsics', args, kwargs) + +@suppress_traceback +def normalize_intrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.normalize_intrinsics, utils3d.torch.normalize_intrinsics + return _call_based_on_args('normalize_intrinsics', args, kwargs) + +@suppress_traceback +def crop_intrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.crop_intrinsics, utils3d.torch.crop_intrinsics + return _call_based_on_args('crop_intrinsics', args, kwargs) + +@suppress_traceback +def pixel_to_uv(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.pixel_to_uv, utils3d.torch.pixel_to_uv + return _call_based_on_args('pixel_to_uv', args, kwargs) + +@suppress_traceback +def pixel_to_ndc(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.pixel_to_ndc, utils3d.torch.pixel_to_ndc + return _call_based_on_args('pixel_to_ndc', args, kwargs) + +@suppress_traceback +def uv_to_pixel(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.uv_to_pixel, utils3d.torch.uv_to_pixel + return _call_based_on_args('uv_to_pixel', args, kwargs) + +@suppress_traceback +def project_depth(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.project_depth, utils3d.torch.project_depth + return _call_based_on_args('project_depth', args, kwargs) + +@suppress_traceback +def depth_buffer_to_linear(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.depth_buffer_to_linear, utils3d.torch.depth_buffer_to_linear + return _call_based_on_args('depth_buffer_to_linear', args, kwargs) + +@suppress_traceback +def unproject_cv(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.unproject_cv, utils3d.torch.unproject_cv + return _call_based_on_args('unproject_cv', args, kwargs) + +@suppress_traceback +def unproject_gl(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.unproject_gl, utils3d.torch.unproject_gl + return _call_based_on_args('unproject_gl', args, kwargs) + +@suppress_traceback +def project_cv(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.project_cv, utils3d.torch.project_cv + return _call_based_on_args('project_cv', args, kwargs) + +@suppress_traceback +def project_gl(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.project_gl, utils3d.torch.project_gl + return _call_based_on_args('project_gl', args, kwargs) + +@suppress_traceback +def quaternion_to_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.quaternion_to_matrix, utils3d.torch.quaternion_to_matrix + return _call_based_on_args('quaternion_to_matrix', args, kwargs) + +@suppress_traceback +def axis_angle_to_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.axis_angle_to_matrix, utils3d.torch.axis_angle_to_matrix + return _call_based_on_args('axis_angle_to_matrix', args, kwargs) + +@suppress_traceback +def matrix_to_quaternion(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.matrix_to_quaternion, utils3d.torch.matrix_to_quaternion + return _call_based_on_args('matrix_to_quaternion', args, kwargs) + +@suppress_traceback +def extrinsics_to_essential(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.extrinsics_to_essential, utils3d.torch.extrinsics_to_essential + return _call_based_on_args('extrinsics_to_essential', args, kwargs) + +@suppress_traceback +def euler_axis_angle_rotation(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.euler_axis_angle_rotation, utils3d.torch.euler_axis_angle_rotation + return _call_based_on_args('euler_axis_angle_rotation', args, kwargs) + +@suppress_traceback +def euler_angles_to_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.euler_angles_to_matrix, utils3d.torch.euler_angles_to_matrix + return _call_based_on_args('euler_angles_to_matrix', args, kwargs) + +@suppress_traceback +def skew_symmetric(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.skew_symmetric, utils3d.torch.skew_symmetric + return _call_based_on_args('skew_symmetric', args, kwargs) + +@suppress_traceback +def rotation_matrix_from_vectors(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.rotation_matrix_from_vectors, utils3d.torch.rotation_matrix_from_vectors + return _call_based_on_args('rotation_matrix_from_vectors', args, kwargs) + +@suppress_traceback +def ray_intersection(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.ray_intersection, None + return _call_based_on_args('ray_intersection', args, kwargs) + +@suppress_traceback +def se3_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.se3_matrix, None + return _call_based_on_args('se3_matrix', args, kwargs) + +@suppress_traceback +def slerp_quaternion(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.slerp_quaternion, None + return _call_based_on_args('slerp_quaternion', args, kwargs) + +@suppress_traceback +def slerp_vector(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.slerp_vector, None + return _call_based_on_args('slerp_vector', args, kwargs) + +@suppress_traceback +def lerp(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.lerp, None + return _call_based_on_args('lerp', args, kwargs) + +@suppress_traceback +def lerp_se3_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.lerp_se3_matrix, None + return _call_based_on_args('lerp_se3_matrix', args, kwargs) + +@suppress_traceback +def piecewise_lerp(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.piecewise_lerp, None + return _call_based_on_args('piecewise_lerp', args, kwargs) + +@suppress_traceback +def piecewise_lerp_se3_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.piecewise_lerp_se3_matrix, None + return _call_based_on_args('piecewise_lerp_se3_matrix', args, kwargs) + +@suppress_traceback +def apply_transform(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.apply_transform, None + return _call_based_on_args('apply_transform', args, kwargs) + +@suppress_traceback +def linear_spline_interpolate(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.linear_spline_interpolate, None + return _call_based_on_args('linear_spline_interpolate', args, kwargs) + +@suppress_traceback +def RastContext(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.RastContext, utils3d.torch.RastContext + return _call_based_on_args('RastContext', args, kwargs) + +@suppress_traceback +def rasterize_triangle_faces(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.rasterize_triangle_faces, utils3d.torch.rasterize_triangle_faces + return _call_based_on_args('rasterize_triangle_faces', args, kwargs) + +@suppress_traceback +def rasterize_edges(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.rasterize_edges, None + return _call_based_on_args('rasterize_edges', args, kwargs) + +@suppress_traceback +def texture(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.texture, None + return _call_based_on_args('texture', args, kwargs) + +@suppress_traceback +def warp_image_by_depth(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.warp_image_by_depth, utils3d.torch.warp_image_by_depth + return _call_based_on_args('warp_image_by_depth', args, kwargs) + +@suppress_traceback +def test_rasterization(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.test_rasterization, None + return _call_based_on_args('test_rasterization', args, kwargs) + +@suppress_traceback +def compute_face_angles(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.compute_face_angles + return _call_based_on_args('compute_face_angles', args, kwargs) + +@suppress_traceback +def compute_face_tbn(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.compute_face_tbn + return _call_based_on_args('compute_face_tbn', args, kwargs) + +@suppress_traceback +def compute_vertex_tbn(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.compute_vertex_tbn + return _call_based_on_args('compute_vertex_tbn', args, kwargs) + +@suppress_traceback +def laplacian(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.laplacian + return _call_based_on_args('laplacian', args, kwargs) + +@suppress_traceback +def laplacian_smooth_mesh(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.laplacian_smooth_mesh + return _call_based_on_args('laplacian_smooth_mesh', args, kwargs) + +@suppress_traceback +def taubin_smooth_mesh(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.taubin_smooth_mesh + return _call_based_on_args('taubin_smooth_mesh', args, kwargs) + +@suppress_traceback +def laplacian_hc_smooth_mesh(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.laplacian_hc_smooth_mesh + return _call_based_on_args('laplacian_hc_smooth_mesh', args, kwargs) + +@suppress_traceback +def get_rays(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.get_rays + return _call_based_on_args('get_rays', args, kwargs) + +@suppress_traceback +def get_image_rays(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.get_image_rays + return _call_based_on_args('get_image_rays', args, kwargs) + +@suppress_traceback +def get_mipnerf_cones(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.get_mipnerf_cones + return _call_based_on_args('get_mipnerf_cones', args, kwargs) + +@suppress_traceback +def volume_rendering(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.volume_rendering + return _call_based_on_args('volume_rendering', args, kwargs) + +@suppress_traceback +def bin_sample(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.bin_sample + return _call_based_on_args('bin_sample', args, kwargs) + +@suppress_traceback +def importance_sample(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.importance_sample + return _call_based_on_args('importance_sample', args, kwargs) + +@suppress_traceback +def nerf_render_rays(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.nerf_render_rays + return _call_based_on_args('nerf_render_rays', args, kwargs) + +@suppress_traceback +def mipnerf_render_rays(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.mipnerf_render_rays + return _call_based_on_args('mipnerf_render_rays', args, kwargs) + +@suppress_traceback +def nerf_render_view(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.nerf_render_view + return _call_based_on_args('nerf_render_view', args, kwargs) + +@suppress_traceback +def mipnerf_render_view(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.mipnerf_render_view + return _call_based_on_args('mipnerf_render_view', args, kwargs) + +@suppress_traceback +def InstantNGP(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.InstantNGP + return _call_based_on_args('InstantNGP', args, kwargs) + +@suppress_traceback +def point_to_normal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.point_to_normal + return _call_based_on_args('point_to_normal', args, kwargs) + +@suppress_traceback +def depth_to_normal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.depth_to_normal + return _call_based_on_args('depth_to_normal', args, kwargs) + +@suppress_traceback +def masked_min(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.masked_min + return _call_based_on_args('masked_min', args, kwargs) + +@suppress_traceback +def masked_max(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.masked_max + return _call_based_on_args('masked_max', args, kwargs) + +@suppress_traceback +def bounding_rect(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.bounding_rect + return _call_based_on_args('bounding_rect', args, kwargs) + +@suppress_traceback +def intrinsics_from_fov_xy(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.intrinsics_from_fov_xy + return _call_based_on_args('intrinsics_from_fov_xy', args, kwargs) + +@suppress_traceback +def matrix_to_euler_angles(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.matrix_to_euler_angles + return _call_based_on_args('matrix_to_euler_angles', args, kwargs) + +@suppress_traceback +def matrix_to_axis_angle(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.matrix_to_axis_angle + return _call_based_on_args('matrix_to_axis_angle', args, kwargs) + +@suppress_traceback +def axis_angle_to_quaternion(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.axis_angle_to_quaternion + return _call_based_on_args('axis_angle_to_quaternion', args, kwargs) + +@suppress_traceback +def quaternion_to_axis_angle(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.quaternion_to_axis_angle + return _call_based_on_args('quaternion_to_axis_angle', args, kwargs) + +@suppress_traceback +def slerp(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.slerp + return _call_based_on_args('slerp', args, kwargs) + +@suppress_traceback +def interpolate_extrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.interpolate_extrinsics + return _call_based_on_args('interpolate_extrinsics', args, kwargs) + +@suppress_traceback +def interpolate_view(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.interpolate_view + return _call_based_on_args('interpolate_view', args, kwargs) + +@suppress_traceback +def to4x4(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.to4x4 + return _call_based_on_args('to4x4', args, kwargs) + +@suppress_traceback +def rotation_matrix_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.rotation_matrix_2d + return _call_based_on_args('rotation_matrix_2d', args, kwargs) + +@suppress_traceback +def rotate_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.rotate_2d + return _call_based_on_args('rotate_2d', args, kwargs) + +@suppress_traceback +def translate_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.translate_2d + return _call_based_on_args('translate_2d', args, kwargs) + +@suppress_traceback +def scale_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.scale_2d + return _call_based_on_args('scale_2d', args, kwargs) + +@suppress_traceback +def apply_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.apply_2d + return _call_based_on_args('apply_2d', args, kwargs) + +@suppress_traceback +def warp_image_by_forward_flow(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.warp_image_by_forward_flow + return _call_based_on_args('warp_image_by_forward_flow', args, kwargs) + diff --git a/submodules/MoGe/utils3d/_unified/__init__.pyi b/submodules/MoGe/utils3d/_unified/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..28f662efdf95ed2894c9af693674e74b307109dc --- /dev/null +++ b/submodules/MoGe/utils3d/_unified/__init__.pyi @@ -0,0 +1,2431 @@ +# Auto-generated interface file +from typing import List, Tuple, Dict, Union, Optional, Any, overload, Literal, Callable +import numpy as numpy_ +import torch as torch_ +import nvdiffrast.torch +import numbers +from . import numpy, torch +import utils3d.numpy, utils3d.torch + +__all__ = ["triangulate", +"compute_face_normal", +"compute_face_angle", +"compute_vertex_normal", +"compute_vertex_normal_weighted", +"remove_corrupted_faces", +"merge_duplicate_vertices", +"remove_unreferenced_vertices", +"subdivide_mesh_simple", +"mesh_relations", +"flatten_mesh_indices", +"calc_quad_candidates", +"calc_quad_distortion", +"calc_quad_direction", +"calc_quad_smoothness", +"sovle_quad", +"sovle_quad_qp", +"tri_to_quad", +"sliding_window_1d", +"sliding_window_nd", +"sliding_window_2d", +"max_pool_1d", +"max_pool_2d", +"max_pool_nd", +"depth_edge", +"normals_edge", +"depth_aliasing", +"interpolate", +"image_scrcoord", +"image_uv", +"image_pixel_center", +"image_pixel", +"image_mesh", +"image_mesh_from_depth", +"depth_to_normals", +"points_to_normals", +"chessboard", +"cube", +"icosahedron", +"square", +"camera_frustum", +"perspective", +"perspective_from_fov", +"perspective_from_fov_xy", +"intrinsics_from_focal_center", +"intrinsics_from_fov", +"fov_to_focal", +"focal_to_fov", +"intrinsics_to_fov", +"view_look_at", +"extrinsics_look_at", +"perspective_to_intrinsics", +"perspective_to_near_far", +"intrinsics_to_perspective", +"extrinsics_to_view", +"view_to_extrinsics", +"normalize_intrinsics", +"crop_intrinsics", +"pixel_to_uv", +"pixel_to_ndc", +"uv_to_pixel", +"project_depth", +"depth_buffer_to_linear", +"unproject_cv", +"unproject_gl", +"project_cv", +"project_gl", +"quaternion_to_matrix", +"axis_angle_to_matrix", +"matrix_to_quaternion", +"extrinsics_to_essential", +"euler_axis_angle_rotation", +"euler_angles_to_matrix", +"skew_symmetric", +"rotation_matrix_from_vectors", +"ray_intersection", +"se3_matrix", +"slerp_quaternion", +"slerp_vector", +"lerp", +"lerp_se3_matrix", +"piecewise_lerp", +"piecewise_lerp_se3_matrix", +"apply_transform", +"linear_spline_interpolate", +"RastContext", +"rasterize_triangle_faces", +"rasterize_edges", +"texture", +"warp_image_by_depth", +"test_rasterization", +"compute_face_angles", +"compute_face_tbn", +"compute_vertex_tbn", +"laplacian", +"laplacian_smooth_mesh", +"taubin_smooth_mesh", +"laplacian_hc_smooth_mesh", +"get_rays", +"get_image_rays", +"get_mipnerf_cones", +"volume_rendering", +"bin_sample", +"importance_sample", +"nerf_render_rays", +"mipnerf_render_rays", +"nerf_render_view", +"mipnerf_render_view", +"InstantNGP", +"point_to_normal", +"depth_to_normal", +"masked_min", +"masked_max", +"bounding_rect", +"intrinsics_from_fov_xy", +"matrix_to_euler_angles", +"matrix_to_axis_angle", +"axis_angle_to_quaternion", +"quaternion_to_axis_angle", +"slerp", +"interpolate_extrinsics", +"interpolate_view", +"to4x4", +"rotation_matrix_2d", +"rotate_2d", +"translate_2d", +"scale_2d", +"apply_2d", +"warp_image_by_forward_flow"] + +@overload +def triangulate(faces: numpy_.ndarray, vertices: numpy_.ndarray = None, backslash: numpy_.ndarray = None) -> numpy_.ndarray: + """Triangulate a polygonal mesh. + +Args: + faces (np.ndarray): [L, P] polygonal faces + vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (np.ndarray, optional): [L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + +Returns: + (np.ndarray): [L * (P - 2), 3] triangular faces""" + utils3d.numpy.mesh.triangulate + +@overload +def compute_face_normal(vertices: numpy_.ndarray, faces: numpy_.ndarray) -> numpy_.ndarray: + """Compute face normals of a triangular mesh + +Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + normals (np.ndarray): [..., T, 3] face normals""" + utils3d.numpy.mesh.compute_face_normal + +@overload +def compute_face_angle(vertices: numpy_.ndarray, faces: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray: + """Compute face angles of a triangular mesh + +Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + angles (np.ndarray): [..., T, 3] face angles""" + utils3d.numpy.mesh.compute_face_angle + +@overload +def compute_vertex_normal(vertices: numpy_.ndarray, faces: numpy_.ndarray, face_normal: numpy_.ndarray = None) -> numpy_.ndarray: + """Compute vertex normals of a triangular mesh by averaging neightboring face normals +TODO: can be improved. + +Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + +Returns: + normals (np.ndarray): [..., N, 3] vertex normals""" + utils3d.numpy.mesh.compute_vertex_normal + +@overload +def compute_vertex_normal_weighted(vertices: numpy_.ndarray, faces: numpy_.ndarray, face_normal: numpy_.ndarray = None) -> numpy_.ndarray: + """Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals +according to the angles + +Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [..., T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + +Returns: + normals (np.ndarray): [..., N, 3] vertex normals""" + utils3d.numpy.mesh.compute_vertex_normal_weighted + +@overload +def remove_corrupted_faces(faces: numpy_.ndarray) -> numpy_.ndarray: + """Remove corrupted faces (faces with duplicated vertices) + +Args: + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + np.ndarray: [T_, 3] triangular face indices""" + utils3d.numpy.mesh.remove_corrupted_faces + +@overload +def merge_duplicate_vertices(vertices: numpy_.ndarray, faces: numpy_.ndarray, tol: float = 1e-06) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Merge duplicate vertices of a triangular mesh. +Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + +Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices""" + utils3d.numpy.mesh.merge_duplicate_vertices + +@overload +def remove_unreferenced_vertices(faces: numpy_.ndarray, *vertice_attrs, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]: + """Remove unreferenced vertices of a mesh. +Unreferenced vertices are removed, and the face indices are updated accordingly. + +Args: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + +Returns: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None.""" + utils3d.numpy.mesh.remove_unreferenced_vertices + +@overload +def subdivide_mesh_simple(vertices: numpy_.ndarray, faces: numpy_.ndarray, n: int = 1) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. +NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + +Returns: + vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices + faces (np.ndarray): [4 * T, 3] subdivided triangular face indices""" + utils3d.numpy.mesh.subdivide_mesh_simple + +@overload +def mesh_relations(faces: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Calculate the relation between vertices and faces. +NOTE: The input mesh must be a manifold triangle mesh. + +Args: + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + edges (np.ndarray): [E, 2] edge indices + edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary. + face2edge (np.ndarray): [T, 3] face to edge relation + face2face (np.ndarray): [T, 3] face to face relation""" + utils3d.numpy.mesh.mesh_relations + +@overload +def flatten_mesh_indices(*args: numpy_.ndarray) -> Tuple[numpy_.ndarray, ...]: + utils3d.numpy.mesh.flatten_mesh_indices + +@overload +def calc_quad_candidates(edges: numpy_.ndarray, face2edge: numpy_.ndarray, edge2face: numpy_.ndarray): + """Calculate the candidate quad faces. + +Args: + edges (np.ndarray): [E, 2] edge indices + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + +Returns: + quads (np.ndarray): [Q, 4] quad candidate indices + quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation + quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid""" + utils3d.numpy.quadmesh.calc_quad_candidates + +@overload +def calc_quad_distortion(vertices: numpy_.ndarray, quads: numpy_.ndarray): + """Calculate the distortion of each candidate quad face. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + +Returns: + distortion (np.ndarray): [Q] distortion of each quad face""" + utils3d.numpy.quadmesh.calc_quad_distortion + +@overload +def calc_quad_direction(vertices: numpy_.ndarray, quads: numpy_.ndarray): + """Calculate the direction of each candidate quad face. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + +Returns: + direction (np.ndarray): [Q, 4] direction of each quad face. + Represented by the angle between the crossing and each edge.""" + utils3d.numpy.quadmesh.calc_quad_direction + +@overload +def calc_quad_smoothness(quad2edge: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_direction: numpy_.ndarray): + """Calculate the smoothness of each candidate quad face connection. + +Args: + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_direction (np.ndarray): [Q, 4] direction of each quad face + +Returns: + smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection""" + utils3d.numpy.quadmesh.calc_quad_smoothness + +@overload +def sovle_quad(face2edge: numpy_.ndarray, edge2face: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_distortion: numpy_.ndarray, quads_smoothness: numpy_.ndarray, quads_valid: numpy_.ndarray): + """Solve the quad mesh from the candidate quad faces. + +Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + +Returns: + weights (np.ndarray): [Q] weight of each valid quad face""" + utils3d.numpy.quadmesh.sovle_quad + +@overload +def sovle_quad_qp(face2edge: numpy_.ndarray, edge2face: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_distortion: numpy_.ndarray, quads_smoothness: numpy_.ndarray, quads_valid: numpy_.ndarray): + """Solve the quad mesh from the candidate quad faces. + +Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + +Returns: + weights (np.ndarray): [Q] weight of each valid quad face""" + utils3d.numpy.quadmesh.sovle_quad_qp + +@overload +def tri_to_quad(vertices: numpy_.ndarray, faces: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Convert a triangle mesh to a quad mesh. +NOTE: The input mesh must be a manifold mesh. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [Q, 4] quad face indices""" + utils3d.numpy.quadmesh.tri_to_quad + +@overload +def sliding_window_1d(x: numpy_.ndarray, window_size: int, stride: int, axis: int = -1): + """Return x view of the input array with x sliding window of the given kernel size and stride. +The sliding window is performed over the given axis, and the window dimension is append to the end of the output array's shape. + +Args: + x (np.ndarray): input array with shape (..., axis_size, ...) + kernel_size (int): size of the sliding window + stride (int): stride of the sliding window + axis (int): axis to perform sliding window over + +Returns: + a_sliding (np.ndarray): view of the input array with shape (..., n_windows, ..., kernel_size), where n_windows = (axis_size - kernel_size + 1) // stride""" + utils3d.numpy.utils.sliding_window_1d + +@overload +def sliding_window_nd(x: numpy_.ndarray, window_size: Tuple[int, ...], stride: Tuple[int, ...], axis: Tuple[int, ...]) -> numpy_.ndarray: + utils3d.numpy.utils.sliding_window_nd + +@overload +def sliding_window_2d(x: numpy_.ndarray, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)) -> numpy_.ndarray: + utils3d.numpy.utils.sliding_window_2d + +@overload +def max_pool_1d(x: numpy_.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1): + utils3d.numpy.utils.max_pool_1d + +@overload +def max_pool_2d(x: numpy_.ndarray, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)): + utils3d.numpy.utils.max_pool_2d + +@overload +def max_pool_nd(x: numpy_.ndarray, kernel_size: Tuple[int, ...], stride: Tuple[int, ...], padding: Tuple[int, ...], axis: Tuple[int, ...]) -> numpy_.ndarray: + utils3d.numpy.utils.max_pool_nd + +@overload +def depth_edge(depth: numpy_.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray: + """Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth. + +Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + +Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool""" + utils3d.numpy.utils.depth_edge + +@overload +def normals_edge(normals: numpy_.ndarray, tol: float, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray: + """Compute the edge mask from normal map. + +Args: + normal (np.ndarray): shape (..., height, width, 3), normal map + tol (float): tolerance in degrees + +Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool""" + utils3d.numpy.utils.normals_edge + +@overload +def depth_aliasing(depth: numpy_.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray: + """Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. +Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + +Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool""" + utils3d.numpy.utils.depth_aliasing + +@overload +def interpolate(bary: numpy_.ndarray, tri_id: numpy_.ndarray, attr: numpy_.ndarray, faces: numpy_.ndarray) -> numpy_.ndarray: + """Interpolate with given barycentric coordinates and triangle indices + +Args: + bary (np.ndarray): shape (..., 3), barycentric coordinates + tri_id (np.ndarray): int array of shape (...), triangle indices + attr (np.ndarray): shape (N, M), vertices attributes + faces (np.ndarray): int array of shape (T, 3), face vertex indices + +Returns: + np.ndarray: shape (..., M) interpolated result""" + utils3d.numpy.utils.interpolate + +@overload +def image_scrcoord(width: int, height: int) -> numpy_.ndarray: + """Get OpenGL's screen space coordinates, ranging in [0, 1]. +[0, 0] is the bottom-left corner of the image. + +Args: + width (int): image width + height (int): image height + +Returns: + (np.ndarray): shape (height, width, 2)""" + utils3d.numpy.utils.image_scrcoord + +@overload +def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.float32) -> numpy_.ndarray: + """Get image space UV grid, ranging in [0, 1]. + +>>> image_uv(10, 10): +[[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + +Args: + width (int): image width + height (int): image height + +Returns: + np.ndarray: shape (height, width, 2)""" + utils3d.numpy.utils.image_uv + +@overload +def image_pixel_center(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.float32) -> numpy_.ndarray: + """Get image pixel center coordinates, ranging in [0, width] and [0, height]. +`image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + +>>> image_pixel_center(10, 10): +[[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... +[[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + +Args: + width (int): image width + height (int): image height + +Returns: + np.ndarray: shape (height, width, 2)""" + utils3d.numpy.utils.image_pixel_center + +@overload +def image_pixel(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.int32) -> numpy_.ndarray: + """Get image pixel coordinates grid, ranging in [0, width - 1] and [0, height - 1]. +`image[i, j]` has pixel center coordinates `(j, i)`. + +>>> image_pixel_center(10, 10): +[[[0, 0], [1, 0], ..., [9, 0]], + [[0, 1.5], [1, 1], ..., [9, 1]], + ... ... ... +[[0, 9.5], [1, 9], ..., [9, 9 ]]] + +Args: + width (int): image width + height (int): image height + +Returns: + np.ndarray: shape (height, width, 2)""" + utils3d.numpy.utils.image_pixel + +@overload +def image_mesh(*image_attrs: numpy_.ndarray, mask: numpy_.ndarray = None, tri: bool = False, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]: + """Get a mesh regarding image pixel uv coordinates as vertices and image grid as faces. + +Args: + *image_attrs (np.ndarray): image attributes in shape (height, width, [channels]) + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + +Returns: + faces (np.ndarray): faces connecting neighboring pixels. shape (T, 4) if tri is False, else (T, 3) + *vertex_attrs (np.ndarray): vertex attributes in corresponding order with input image_attrs + indices (np.ndarray, optional): indices of vertices in the original mesh""" + utils3d.numpy.utils.image_mesh + +@overload +def image_mesh_from_depth(depth: numpy_.ndarray, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None, *vertice_attrs: numpy_.ndarray, atol: float = None, rtol: float = None, remove_by_depth: bool = False, return_uv: bool = False, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]: + """Get x triangle mesh by lifting depth map to 3D. + +Args: + depth (np.ndarray): [H, W] depth map + extrinsics (np.ndarray, optional): [4, 4] extrinsics matrix. Defaults to None. + intrinsics (np.ndarray, optional): [3, 3] intrinsics matrix. Defaults to None. + *vertice_attrs (np.ndarray): [H, W, C] vertex attributes. Defaults to None. + atol (float, optional): absolute tolerance. Defaults to None. + rtol (float, optional): relative tolerance. Defaults to None. + triangles with vertices having depth difference larger than atol + rtol * depth will be marked. + remove_by_depth (bool, optional): whether to remove triangles with large depth difference. Defaults to True. + return_uv (bool, optional): whether to return uv coordinates. Defaults to False. + return_indices (bool, optional): whether to return indices of vertices in the original mesh. Defaults to False. + +Returns: + vertices (np.ndarray): [N, 3] vertices + faces (np.ndarray): [T, 3] faces + *vertice_attrs (np.ndarray): [N, C] vertex attributes + image_uv (np.ndarray, optional): [N, 2] uv coordinates + ref_indices (np.ndarray, optional): [N] indices of vertices in the original mesh""" + utils3d.numpy.utils.image_mesh_from_depth + +@overload +def depth_to_normals(depth: numpy_.ndarray, intrinsics: numpy_.ndarray, mask: numpy_.ndarray = None) -> numpy_.ndarray: + """Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + +Args: + depth (np.ndarray): shape (height, width), linear depth map + intrinsics (np.ndarray): shape (3, 3), intrinsics matrix +Returns: + normal (np.ndarray): shape (height, width, 3), normal map. """ + utils3d.numpy.utils.depth_to_normals + +@overload +def points_to_normals(point: numpy_.ndarray, mask: numpy_.ndarray = None) -> numpy_.ndarray: + """Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + +Args: + point (np.ndarray): shape (height, width, 3), point map +Returns: + normal (np.ndarray): shape (height, width, 3), normal map. """ + utils3d.numpy.utils.points_to_normals + +@overload +def chessboard(width: int, height: int, grid_size: int, color_a: numpy_.ndarray, color_b: numpy_.ndarray) -> numpy_.ndarray: + """get x chessboard image + +Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (np.ndarray): color of the grid at the top-left corner + color_b (np.ndarray): color in complementary grid cells + +Returns: + image (np.ndarray): shape (height, width, channels), chessboard image""" + utils3d.numpy.utils.chessboard + +@overload +def cube(tri: bool = False) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Get x cube mesh of size 1 centered at origin. + +### Parameters + tri (bool, optional): return triangulated mesh. Defaults to False, which returns quad mesh. + +### Returns + vertices (np.ndarray): shape (8, 3) + faces (np.ndarray): shape (12, 3)""" + utils3d.numpy.utils.cube + +@overload +def icosahedron(): + utils3d.numpy.utils.icosahedron + +@overload +def square(tri: bool = False) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Get a square mesh of area 1 centered at origin in the xy-plane. + +### Returns + vertices (np.ndarray): shape (4, 3) + faces (np.ndarray): shape (1, 4)""" + utils3d.numpy.utils.square + +@overload +def camera_frustum(extrinsics: numpy_.ndarray, intrinsics: numpy_.ndarray, depth: float = 1.0) -> Tuple[numpy_.ndarray, numpy_.ndarray, numpy_.ndarray]: + """Get x triangle mesh of camera frustum.""" + utils3d.numpy.utils.camera_frustum + +@overload +def perspective(fov_y: Union[float, numpy_.ndarray], aspect: Union[float, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """Get OpenGL perspective matrix + +Args: + fov_y (float | np.ndarray): field of view in y axis + aspect (float | np.ndarray): aspect ratio + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + +Returns: + (np.ndarray): [..., 4, 4] perspective matrix""" + utils3d.numpy.transforms.perspective + +@overload +def perspective_from_fov(fov: Union[float, numpy_.ndarray], width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """Get OpenGL perspective matrix from field of view in largest dimension + +Args: + fov (float | np.ndarray): field of view in largest dimension + width (int | np.ndarray): image width + height (int | np.ndarray): image height + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + +Returns: + (np.ndarray): [..., 4, 4] perspective matrix""" + utils3d.numpy.transforms.perspective_from_fov + +@overload +def perspective_from_fov_xy(fov_x: Union[float, numpy_.ndarray], fov_y: Union[float, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """Get OpenGL perspective matrix from field of view in x and y axis + +Args: + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + +Returns: + (np.ndarray): [..., 4, 4] perspective matrix""" + utils3d.numpy.transforms.perspective_from_fov_xy + +@overload +def intrinsics_from_focal_center(fx: Union[float, numpy_.ndarray], fy: Union[float, numpy_.ndarray], cx: Union[float, numpy_.ndarray], cy: Union[float, numpy_.ndarray], dtype: Optional[numpy_.dtype] = numpy_.float32) -> numpy_.ndarray: + """Get OpenCV intrinsics matrix + +Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.numpy.transforms.intrinsics_from_focal_center + +@overload +def intrinsics_from_fov(fov_max: Union[float, numpy_.ndarray] = None, fov_min: Union[float, numpy_.ndarray] = None, fov_x: Union[float, numpy_.ndarray] = None, fov_y: Union[float, numpy_.ndarray] = None, width: Union[int, numpy_.ndarray] = None, height: Union[int, numpy_.ndarray] = None) -> numpy_.ndarray: + """Get normalized OpenCV intrinsics matrix from given field of view. +You can provide either fov_max, fov_min, fov_x or fov_y + +Args: + width (int | np.ndarray): image width + height (int | np.ndarray): image height + fov_max (float | np.ndarray): field of view in largest dimension + fov_min (float | np.ndarray): field of view in smallest dimension + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + +Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.numpy.transforms.intrinsics_from_fov + +@overload +def fov_to_focal(fov: numpy_.ndarray): + utils3d.numpy.transforms.fov_to_focal + +@overload +def focal_to_fov(focal: numpy_.ndarray): + utils3d.numpy.transforms.focal_to_fov + +@overload +def intrinsics_to_fov(intrinsics: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + utils3d.numpy.transforms.intrinsics_to_fov + +@overload +def view_look_at(eye: numpy_.ndarray, look_at: numpy_.ndarray, up: numpy_.ndarray) -> numpy_.ndarray: + """Get OpenGL view matrix looking at something + +Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + +Returns: + (np.ndarray): [..., 4, 4], view matrix""" + utils3d.numpy.transforms.view_look_at + +@overload +def extrinsics_look_at(eye: numpy_.ndarray, look_at: numpy_.ndarray, up: numpy_.ndarray) -> numpy_.ndarray: + """Get OpenCV extrinsics matrix looking at something + +Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction + +Returns: + (np.ndarray): [..., 4, 4], extrinsics matrix""" + utils3d.numpy.transforms.extrinsics_look_at + +@overload +def perspective_to_intrinsics(perspective: numpy_.ndarray) -> numpy_.ndarray: + """OpenGL perspective matrix to OpenCV intrinsics + +Args: + perspective (np.ndarray): [..., 4, 4] OpenGL perspective matrix + +Returns: + (np.ndarray): shape [..., 3, 3] OpenCV intrinsics""" + utils3d.numpy.transforms.perspective_to_intrinsics + +@overload +def perspective_to_near_far(perspective: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Get near and far planes from OpenGL perspective matrix + +Args:""" + utils3d.numpy.transforms.perspective_to_near_far + +@overload +def intrinsics_to_perspective(intrinsics: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """OpenCV intrinsics to OpenGL perspective matrix +NOTE: not work for tile-shifting intrinsics currently + +Args: + intrinsics (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip +Returns: + (np.ndarray): [..., 4, 4] OpenGL perspective matrix""" + utils3d.numpy.transforms.intrinsics_to_perspective + +@overload +def extrinsics_to_view(extrinsics: numpy_.ndarray) -> numpy_.ndarray: + """OpenCV camera extrinsics to OpenGL view matrix + +Args: + extrinsics (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix + +Returns: + (np.ndarray): [..., 4, 4] OpenGL view matrix""" + utils3d.numpy.transforms.extrinsics_to_view + +@overload +def view_to_extrinsics(view: numpy_.ndarray) -> numpy_.ndarray: + """OpenGL view matrix to OpenCV camera extrinsics + +Args: + view (np.ndarray): [..., 4, 4] OpenGL view matrix + +Returns: + (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix""" + utils3d.numpy.transforms.view_to_extrinsics + +@overload +def normalize_intrinsics(intrinsics: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], integer_pixel_centers: bool = True) -> numpy_.ndarray: + """Normalize intrinsics from pixel cooridnates to uv coordinates + +Args: + intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to normalize + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + integer_pixel_centers (bool): whether the integer pixel coordinates are at the center of the pixel. If False, the integer coordinates are at the left-top corner of the pixel. + +Returns: + (np.ndarray): [..., 3, 3] normalized camera intrinsics(s)""" + utils3d.numpy.transforms.normalize_intrinsics + +@overload +def crop_intrinsics(intrinsics: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], left: Union[int, numpy_.ndarray], top: Union[int, numpy_.ndarray], crop_width: Union[int, numpy_.ndarray], crop_height: Union[int, numpy_.ndarray]) -> numpy_.ndarray: + """Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + +Args: + intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to crop + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + left (int | np.ndarray): [...] left crop boundary + top (int | np.ndarray): [...] top crop boundary + crop_width (int | np.ndarray): [...] crop width + crop_height (int | np.ndarray): [...] crop height + +Returns: + (np.ndarray): [..., 3, 3] cropped camera intrinsics(s)""" + utils3d.numpy.transforms.crop_intrinsics + +@overload +def pixel_to_uv(pixel: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray: + """Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + +Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1)""" + utils3d.numpy.transforms.pixel_to_uv + +@overload +def pixel_to_ndc(pixel: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray: + """Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + +Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1)""" + utils3d.numpy.transforms.pixel_to_ndc + +@overload +def uv_to_pixel(uv: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray: + """Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + +Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1)""" + utils3d.numpy.transforms.uv_to_pixel + +@overload +def project_depth(depth: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """Project linear depth to depth value in screen space + +Args: + depth (np.ndarray): [...] depth value + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + +Returns: + (np.ndarray): [..., 1] depth value in screen space, value ranging in [0, 1]""" + utils3d.numpy.transforms.project_depth + +@overload +def depth_buffer_to_linear(depth_buffer: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """OpenGL depth buffer to linear depth + +Args: + depth_buffer (np.ndarray): [...] depth value + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + +Returns: + (np.ndarray): [..., 1] linear depth""" + utils3d.numpy.transforms.depth_buffer_to_linear + +@overload +def unproject_cv(uv_coord: numpy_.ndarray, depth: numpy_.ndarray = None, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None) -> numpy_.ndarray: + """Unproject uv coordinates to 3D view space following the OpenCV convention + +Args: + uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (np.ndarray): [..., N] depth value + extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix + intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix + +Returns: + points (np.ndarray): [..., N, 3] 3d points""" + utils3d.numpy.transforms.unproject_cv + +@overload +def unproject_gl(screen_coord: numpy_.ndarray, model: numpy_.ndarray = None, view: numpy_.ndarray = None, perspective: numpy_.ndarray = None) -> numpy_.ndarray: + """Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + +Args: + screen_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + +Returns: + points (np.ndarray): [..., N, 3] 3d points""" + utils3d.numpy.transforms.unproject_gl + +@overload +def project_cv(points: numpy_.ndarray, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Project 3D points to 2D following the OpenCV convention + +Args: + points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix + intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix + +Returns: + uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + linear_depth (np.ndarray): [..., N] linear depth""" + utils3d.numpy.transforms.project_cv + +@overload +def project_gl(points: numpy_.ndarray, model: numpy_.ndarray = None, view: numpy_.ndarray = None, perspective: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Project 3D points to 2D following the OpenGL convention (except for row major matrice) + +Args: + points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + +Returns: + scr_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (np.ndarray): [..., N] linear depth""" + utils3d.numpy.transforms.project_gl + +@overload +def quaternion_to_matrix(quaternion: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + +Args: + quaternion (np.ndarray): shape (..., 4), the quaternions to convert + +Returns: + np.ndarray: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions""" + utils3d.numpy.transforms.quaternion_to_matrix + +@overload +def axis_angle_to_matrix(axis_angle: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + +Args: + axis_angle (np.ndarray): shape (..., 3), axis-angle vcetors + +Returns: + np.ndarray: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters""" + utils3d.numpy.transforms.axis_angle_to_matrix + +@overload +def matrix_to_quaternion(rot_mat: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + +Args: + rot_mat (np.ndarray): shape (..., 3, 3), the rotation matrices to convert + +Returns: + np.ndarray: shape (..., 4), the quaternions corresponding to the given rotation matrices""" + utils3d.numpy.transforms.matrix_to_quaternion + +@overload +def extrinsics_to_essential(extrinsics: numpy_.ndarray): + """extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + +Args: + extrinsics (np.ndaray): [..., 4, 4] extrinsics matrix + +Returns: + (np.ndaray): [..., 3, 3] essential matrix""" + utils3d.numpy.transforms.extrinsics_to_essential + +@overload +def euler_axis_angle_rotation(axis: str, angle: numpy_.ndarray) -> numpy_.ndarray: + """Return the rotation matrices for one of the rotations about an axis +of which Euler angles describe, for each value of the angle given. + +Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + +Returns: + Rotation matrices as tensor of shape (..., 3, 3).""" + utils3d.numpy.transforms.euler_axis_angle_rotation + +@overload +def euler_angles_to_matrix(euler_angles: numpy_.ndarray, convention: str = 'XYZ') -> numpy_.ndarray: + """Convert rotations given as Euler angles in radians to rotation matrices. + +Args: + euler_angles: Euler angles in radians as ndarray of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + +Returns: + Rotation matrices as ndarray of shape (..., 3, 3).""" + utils3d.numpy.transforms.euler_angles_to_matrix + +@overload +def skew_symmetric(v: numpy_.ndarray): + """Skew symmetric matrix from a 3D vector""" + utils3d.numpy.transforms.skew_symmetric + +@overload +def rotation_matrix_from_vectors(v1: numpy_.ndarray, v2: numpy_.ndarray): + """Rotation matrix that rotates v1 to v2""" + utils3d.numpy.transforms.rotation_matrix_from_vectors + +@overload +def ray_intersection(p1: numpy_.ndarray, d1: numpy_.ndarray, p2: numpy_.ndarray, d2: numpy_.ndarray): + """Compute the intersection/closest point of two D-dimensional rays +If the rays are intersecting, the closest point is the intersection point. + +Args: + p1 (np.ndarray): (..., D) origin of ray 1 + d1 (np.ndarray): (..., D) direction of ray 1 + p2 (np.ndarray): (..., D) origin of ray 2 + d2 (np.ndarray): (..., D) direction of ray 2 + +Returns: + (np.ndarray): (..., N) intersection point""" + utils3d.numpy.transforms.ray_intersection + +@overload +def se3_matrix(R: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Convert rotation matrix and translation vector to 4x4 transformation matrix. + +Args: + R (np.ndarray): [..., 3, 3] rotation matrix + t (np.ndarray): [..., 3] translation vector + +Returns: + np.ndarray: [..., 4, 4] transformation matrix""" + utils3d.numpy.transforms.se3_matrix + +@overload +def slerp_quaternion(q1: numpy_.ndarray, q2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Spherical linear interpolation between two unit quaternions. + +Args: + q1 (np.ndarray): [..., d] unit vector 1 + q2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + +Returns: + np.ndarray: [..., 3] interpolated unit vector""" + utils3d.numpy.transforms.slerp_quaternion + +@overload +def slerp_vector(v1: numpy_.ndarray, v2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Spherical linear interpolation between two unit vectors. The vectors are assumed to be normalized. + +Args: + v1 (np.ndarray): [..., d] unit vector 1 + v2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + +Returns: + np.ndarray: [..., d] interpolated unit vector""" + utils3d.numpy.transforms.slerp_vector + +@overload +def lerp(x1: numpy_.ndarray, x2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Linear interpolation between two vectors. + +Args: + x1 (np.ndarray): [..., d] vector 1 + x2 (np.ndarray): [..., d] vector 2 + t (np.ndarray): [...] interpolation parameter. [0, 1] for interpolation between x1 and x2, otherwise for extrapolation. + +Returns: + np.ndarray: [..., d] interpolated vector""" + utils3d.numpy.transforms.lerp + +@overload +def lerp_se3_matrix(T1: numpy_.ndarray, T2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Linear interpolation between two SE(3) matrices. + +Args: + T1 (np.ndarray): [..., 4, 4] SE(3) matrix 1 + T2 (np.ndarray): [..., 4, 4] SE(3) matrix 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + +Returns: + np.ndarray: [..., 4, 4] interpolated SE(3) matrix""" + utils3d.numpy.transforms.lerp_se3_matrix + +@overload +def piecewise_lerp(x: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray: + """Linear spline interpolation. + +### Parameters: +- `x`: np.ndarray, shape (n, d): the values of data points. +- `t`: np.ndarray, shape (n,): the times of the data points. +- `s`: np.ndarray, shape (m,): the times to be interpolated. +- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + +### Returns: +- `y`: np.ndarray, shape (..., m, d): the interpolated values.""" + utils3d.numpy.transforms.piecewise_lerp + +@overload +def piecewise_lerp_se3_matrix(T: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray: + """Linear spline interpolation for SE(3) matrices. + +### Parameters: +- `T`: np.ndarray, shape (n, 4, 4): the SE(3) matrices. +- `t`: np.ndarray, shape (n,): the times of the data points. +- `s`: np.ndarray, shape (m,): the times to be interpolated. +- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + +### Returns: +- `T_interp`: np.ndarray, shape (..., m, 4, 4): the interpolated SE(3) matrices.""" + utils3d.numpy.transforms.piecewise_lerp_se3_matrix + +@overload +def apply_transform(T: numpy_.ndarray, x: numpy_.ndarray) -> numpy_.ndarray: + """Apply SE(3) transformation to a point or a set of points. + +### Parameters: +- `T`: np.ndarray, shape (..., 4, 4): the SE(3) matrix. +- `x`: np.ndarray, shape (..., 3): the point or a set of points to be transformed. + +### Returns: +- `x_transformed`: np.ndarray, shape (..., 3): the transformed point or a set of points.""" + utils3d.numpy.transforms.apply_transform + +@overload +def linear_spline_interpolate(x: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray: + """Linear spline interpolation. + +### Parameters: +- `x`: np.ndarray, shape (n, d): the values of data points. +- `t`: np.ndarray, shape (n,): the times of the data points. +- `s`: np.ndarray, shape (m,): the times to be interpolated. +- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + +### Returns: +- `y`: np.ndarray, shape (..., m, d): the interpolated values.""" + utils3d.numpy.spline.linear_spline_interpolate + +@overload +def RastContext(*args, **kwargs): + utils3d.numpy.rasterization.RastContext + +@overload +def rasterize_triangle_faces(ctx: utils3d.numpy.rasterization.RastContext, vertices: numpy_.ndarray, faces: numpy_.ndarray, attr: numpy_.ndarray, width: int, height: int, transform: numpy_.ndarray = None, cull_backface: bool = True, return_depth: bool = False, image: numpy_.ndarray = None, depth: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Rasterize vertex attribute. + +Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection transformation matrix. + cull_backface (bool): whether to cull backface + image: (np.ndarray): [H, W, C] background image + depth: (np.ndarray): [H, W] background depth + +Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.""" + utils3d.numpy.rasterization.rasterize_triangle_faces + +@overload +def rasterize_edges(ctx: utils3d.numpy.rasterization.RastContext, vertices: numpy_.ndarray, edges: numpy_.ndarray, attr: numpy_.ndarray, width: int, height: int, transform: numpy_.ndarray = None, line_width: float = 1.0, return_depth: bool = False, image: numpy_.ndarray = None, depth: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, ...]: + """Rasterize vertex attribute. + +Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection matrix + line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms. + cull_backface (bool): whether to cull backface + +Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.""" + utils3d.numpy.rasterization.rasterize_edges + +@overload +def texture(ctx: utils3d.numpy.rasterization.RastContext, uv: numpy_.ndarray, texture: numpy_.ndarray, interpolation: str = 'linear', wrap: str = 'clamp') -> numpy_.ndarray: + """Given an UV image, texturing from the texture map""" + utils3d.numpy.rasterization.texture + +@overload +def warp_image_by_depth(ctx: utils3d.numpy.rasterization.RastContext, src_depth: numpy_.ndarray, src_image: numpy_.ndarray = None, width: int = None, height: int = None, *, extrinsics_src: numpy_.ndarray = None, extrinsics_tgt: numpy_.ndarray = None, intrinsics_src: numpy_.ndarray = None, intrinsics_tgt: numpy_.ndarray = None, near: float = 0.1, far: float = 100.0, cull_backface: bool = True, ssaa: int = 1, return_depth: bool = False) -> Tuple[numpy_.ndarray, ...]: + """Warp image by depth map. + +Args: + ctx (RastContext): rasterizer context + src_depth (np.ndarray): [H, W] + src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates). + width (int, optional): width of the output image. None to use depth map width. Defaults to None. + height (int, optional): height of the output image. None to use depth map height. Defaults to None. + extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. Defaults to None (identity). + extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity). + intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt). + intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src). + cull_backface (bool, optional): whether to cull backface. Defaults to True. + ssaa (int, optional): super sampling anti-aliasing. Defaults to 1. + +Returns: + tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None). + tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.""" + utils3d.numpy.rasterization.warp_image_by_depth + +@overload +def test_rasterization(ctx: utils3d.numpy.rasterization.RastContext): + """Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file.""" + utils3d.numpy.rasterization.test_rasterization + +@overload +def triangulate(faces: torch_.Tensor, vertices: torch_.Tensor = None, backslash: bool = None) -> torch_.Tensor: + """Triangulate a polygonal mesh. + +Args: + faces (torch.Tensor): [..., L, P] polygonal faces + vertices (torch.Tensor, optional): [..., N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (torch.Tensor, optional): [..., L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + + +Returns: + (torch.Tensor): [L * (P - 2), 3] triangular faces""" + utils3d.torch.mesh.triangulate + +@overload +def compute_face_normal(vertices: torch_.Tensor, faces: torch_.Tensor) -> torch_.Tensor: + """Compute face normals of a triangular mesh + +Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [..., T, 3] triangular face indices + +Returns: + normals (torch.Tensor): [..., T, 3] face normals""" + utils3d.torch.mesh.compute_face_normal + +@overload +def compute_face_angles(vertices: torch_.Tensor, faces: torch_.Tensor) -> torch_.Tensor: + """Compute face angles of a triangular mesh + +Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + +Returns: + angles (torch.Tensor): [..., T, 3] face angles""" + utils3d.torch.mesh.compute_face_angles + +@overload +def compute_vertex_normal(vertices: torch_.Tensor, faces: torch_.Tensor, face_normal: torch_.Tensor = None) -> torch_.Tensor: + """Compute vertex normals of a triangular mesh by averaging neightboring face normals + +Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + +Returns: + normals (torch.Tensor): [..., N, 3] vertex normals""" + utils3d.torch.mesh.compute_vertex_normal + +@overload +def compute_vertex_normal_weighted(vertices: torch_.Tensor, faces: torch_.Tensor, face_normal: torch_.Tensor = None) -> torch_.Tensor: + """Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals +according to the angles + +Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + +Returns: + normals (torch.Tensor): [..., N, 3] vertex normals""" + utils3d.torch.mesh.compute_vertex_normal_weighted + +@overload +def remove_unreferenced_vertices(faces: torch_.Tensor, *vertice_attrs, return_indices: bool = False) -> Tuple[torch_.Tensor, ...]: + """Remove unreferenced vertices of a mesh. +Unreferenced vertices are removed, and the face indices are updated accordingly. + +Args: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + +Returns: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + indices (torch.Tensor, optional): [N] indices of vertices that are kept. Defaults to None.""" + utils3d.torch.mesh.remove_unreferenced_vertices + +@overload +def remove_corrupted_faces(faces: torch_.Tensor) -> torch_.Tensor: + """Remove corrupted faces (faces with duplicated vertices) + +Args: + faces (torch.Tensor): [T, 3] triangular face indices + +Returns: + torch.Tensor: [T_, 3] triangular face indices""" + utils3d.torch.mesh.remove_corrupted_faces + +@overload +def merge_duplicate_vertices(vertices: torch_.Tensor, faces: torch_.Tensor, tol: float = 1e-06) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Merge duplicate vertices of a triangular mesh. +Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + +Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + +Returns: + vertices (torch.Tensor): [N_, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices""" + utils3d.torch.mesh.merge_duplicate_vertices + +@overload +def subdivide_mesh_simple(vertices: torch_.Tensor, faces: torch_.Tensor, n: int = 1) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. +NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + +Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + +Returns: + vertices (torch.Tensor): [N_, 3] subdivided 3-dimensional vertices + faces (torch.Tensor): [4 * T, 3] subdivided triangular face indices""" + utils3d.torch.mesh.subdivide_mesh_simple + +@overload +def compute_face_tbn(pos: torch_.Tensor, faces_pos: torch_.Tensor, uv: torch_.Tensor, faces_uv: torch_.Tensor, eps: float = 1e-07) -> torch_.Tensor: + """compute TBN matrix for each face + +Args: + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + +Returns: + torch.Tensor: (..., T, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthognal""" + utils3d.torch.mesh.compute_face_tbn + +@overload +def compute_vertex_tbn(faces_topo: torch_.Tensor, pos: torch_.Tensor, faces_pos: torch_.Tensor, uv: torch_.Tensor, faces_uv: torch_.Tensor) -> torch_.Tensor: + """compute TBN matrix for each face + +Args: + faces_topo (torch.Tensor): (T, 3), face indice of topology + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + +Returns: + torch.Tensor: (..., V, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthognal""" + utils3d.torch.mesh.compute_vertex_tbn + +@overload +def laplacian(vertices: torch_.Tensor, faces: torch_.Tensor, weight: str = 'uniform') -> torch_.Tensor: + """Laplacian smooth with cotangent weights + +Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent'""" + utils3d.torch.mesh.laplacian + +@overload +def laplacian_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, weight: str = 'uniform', times: int = 5) -> torch_.Tensor: + """Laplacian smooth with cotangent weights + +Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent'""" + utils3d.torch.mesh.laplacian_smooth_mesh + +@overload +def taubin_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, lambda_: float = 0.5, mu_: float = -0.51) -> torch_.Tensor: + """Taubin smooth mesh + +Args: + vertices (torch.Tensor): _description_ + faces (torch.Tensor): _description_ + lambda_ (float, optional): _description_. Defaults to 0.5. + mu_ (float, optional): _description_. Defaults to -0.51. + +Returns: + torch.Tensor: _description_""" + utils3d.torch.mesh.taubin_smooth_mesh + +@overload +def laplacian_hc_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, times: int = 5, alpha: float = 0.5, beta: float = 0.5, weight: str = 'uniform'): + """HC algorithm from Improved Laplacian Smoothing of Noisy Surface Meshes by J.Vollmer et al. + """ + utils3d.torch.mesh.laplacian_hc_smooth_mesh + +@overload +def get_rays(extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, uv: torch_.Tensor) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + uv: (..., n_rays, 2) uv coordinates of the rays. + +Returns: + rays_o: (..., 1, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth.""" + utils3d.torch.nerf.get_rays + +@overload +def get_image_rays(extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + width: width of the image. + height: height of the image. + +Returns: + rays_o: (..., 1, 1, 3) ray origins + rays_d: (..., height, width, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth.""" + utils3d.torch.nerf.get_image_rays + +@overload +def get_mipnerf_cones(rays_o: torch_.Tensor, rays_d: torch_.Tensor, z_vals: torch_.Tensor, pixel_width: torch_.Tensor) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Args: + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + z_vals: (..., n_rays, n_samples) z values. + pixel_width: (...) pixel width. = 1 / (normalized focal length * width) + +Returns: + mu: (..., n_rays, n_samples, 3) cone mu. + sigma: (..., n_rays, n_samples, 3, 3) cone sigma.""" + utils3d.torch.nerf.get_mipnerf_cones + +@overload +def volume_rendering(color: torch_.Tensor, sigma: torch_.Tensor, z_vals: torch_.Tensor, ray_length: torch_.Tensor, rgb: bool = True, depth: bool = True) -> Tuple[torch_.Tensor, torch_.Tensor, torch_.Tensor]: + """Given color, sigma and z_vals (linear depth of the sampling points), render the volume. + +NOTE: By default, color and sigma should have one less sample than z_vals, in correspondence with the average value in intervals. +If queried color are aligned with z_vals, we use trapezoidal rule to calculate the average values in intervals. + +Args: + color: (..., n_samples or n_samples - 1, 3) color values. + sigma: (..., n_samples or n_samples - 1) density values. + z_vals: (..., n_samples) z values. + ray_length: (...) length of the ray + +Returns: + rgb: (..., 3) rendered color values. + depth: (...) rendered depth values. + weights (..., n_samples) weights.""" + utils3d.torch.nerf.volume_rendering + +@overload +def bin_sample(size: Union[torch_.Size, Tuple[int, ...]], n_samples: int, min_value: numbers.Number, max_value: numbers.Number, spacing: Literal['linear', 'inverse_linear'], dtype: torch_.dtype = None, device: torch_.device = None) -> torch_.Tensor: + """Uniformly (or uniformly in inverse space) sample z values in `n_samples` bins in range [min_value, max_value]. +Args: + size: size of the rays + n_samples: number of samples to be sampled, also the number of bins + min_value: minimum value of the range + max_value: maximum value of the range + space: 'linear' or 'inverse_linear'. If 'inverse_linear', the sampling is uniform in inverse space. + +Returns: + z_rand: (*size, n_samples) sampled z values, sorted in ascending order.""" + utils3d.torch.nerf.bin_sample + +@overload +def importance_sample(z_vals: torch_.Tensor, weights: torch_.Tensor, n_samples: int) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Importance sample z values. + +NOTE: By default, weights should have one less sample than z_vals, in correspondence with the intervals. +If weights has the same number of samples as z_vals, we use trapezoidal rule to calculate the average weights in intervals. + +Args: + z_vals: (..., n_rays, n_input_samples) z values, sorted in ascending order. + weights: (..., n_rays, n_input_samples or n_input_samples - 1) weights. + n_samples: number of output samples for importance sampling. + +Returns: + z_importance: (..., n_rays, n_samples) importance sampled z values, unsorted.""" + utils3d.torch.nerf.importance_sample + +@overload +def nerf_render_rays(nerf: Union[Callable[[torch_.Tensor, torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], Tuple[Callable[[torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], Callable[[torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]]]], rays_o: torch_.Tensor, rays_d: torch_.Tensor, *, return_dict: bool = False, n_coarse: int = 64, n_fine: int = 64, near: float = 0.1, far: float = 100.0, z_spacing: Literal['linear', 'inverse_linear'] = 'linear'): + """NeRF rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) + +Args: + nerf: nerf model, which takes (points, directions) as input and returns (color, density) as output. + If nerf is a tuple, it should be (nerf_coarse, nerf_fine), where nerf_coarse and nerf_fine are two nerf models for coarse and fine stages respectively. + + nerf args: + points: (..., n_rays, n_samples, 3) + directions: (..., n_rays, n_samples, 3) + nerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + +Returns + if return_dict is False, return rendered rgb and depth for short cut. (If there are separate coarse and fine results, return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0` or `nerf` is a single model, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If there are two models for coarse and fine stages, the dict contains both coarse and fine results: + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ```""" + utils3d.torch.nerf.nerf_render_rays + +@overload +def mipnerf_render_rays(mipnerf: Callable[[torch_.Tensor, torch_.Tensor, torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], rays_o: torch_.Tensor, rays_d: torch_.Tensor, pixel_width: torch_.Tensor, *, return_dict: bool = False, n_coarse: int = 64, n_fine: int = 64, uniform_ratio: float = 0.4, near: float = 0.1, far: float = 100.0, z_spacing: Literal['linear', 'inverse_linear'] = 'linear') -> Union[Tuple[torch_.Tensor, torch_.Tensor], Dict[str, torch_.Tensor]]: + """MipNeRF rendering. + +Args: + mipnerf: mipnerf model, which takes (points_mu, points_sigma) as input and returns (color, density) as output. + + mipnerf args: + points_mu: (..., n_rays, n_samples, 3) cone mu. + points_sigma: (..., n_rays, n_samples, 3, 3) cone sigma. + directions: (..., n_rays, n_samples, 3) + mipnerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + +Returns + if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0`, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If n_fine > 0, the dict contains both coarse and fine results : + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ```""" + utils3d.torch.nerf.mipnerf_render_rays + +@overload +def nerf_render_view(nerf: torch_.Tensor, extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int, *, patchify: bool = False, patch_size: Tuple[int, int] = (64, 64), **options: Dict[str, Any]) -> Tuple[torch_.Tensor, torch_.Tensor]: + """NeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + +Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + +Returns: + rgb: (..., channels, height, width) rendered color values. + depth: (..., height, width) rendered depth values.""" + utils3d.torch.nerf.nerf_render_view + +@overload +def mipnerf_render_view(mipnerf: torch_.Tensor, extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int, *, patchify: bool = False, patch_size: Tuple[int, int] = (64, 64), **options: Dict[str, Any]) -> Tuple[torch_.Tensor, torch_.Tensor]: + """MipNeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + +Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + +Returns: + rgb: (..., 3, height, width) rendered color values. + depth: (..., height, width) rendered depth values.""" + utils3d.torch.nerf.mipnerf_render_view + +@overload +def InstantNGP(view_dependent: bool = True, base_resolution: int = 16, finest_resolution: int = 2048, n_levels: int = 16, num_layers_density: int = 2, hidden_dim_density: int = 64, num_layers_color: int = 3, hidden_dim_color: int = 64, log2_hashmap_size: int = 19, bound: float = 1.0, color_channels: int = 3): + """An implementation of InstantNGP, Müller et. al., https://nvlabs.github.io/instant-ngp/. +Requires `tinycudann` package. +Install it by: +``` +pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch +```""" + utils3d.torch.nerf.InstantNGP + +@overload +def sliding_window_1d(x: torch_.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch_.Tensor: + """Sliding window view of the input tensor. The dimension of the sliding window is appended to the end of the input tensor's shape. +NOTE: Since Pytorch has `unfold` function, 1D sliding window view is just a wrapper of it.""" + utils3d.torch.utils.sliding_window_1d + +@overload +def sliding_window_2d(x: torch_.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch_.Tensor: + utils3d.torch.utils.sliding_window_2d + +@overload +def sliding_window_nd(x: torch_.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch_.Tensor: + utils3d.torch.utils.sliding_window_nd + +@overload +def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch_.device = None, dtype: torch_.dtype = None) -> torch_.Tensor: + """Get image space UV grid, ranging in [0, 1]. + +>>> image_uv(10, 10): +[[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + +Args: + width (int): image width + height (int): image height + +Returns: + np.ndarray: shape (height, width, 2)""" + utils3d.torch.utils.image_uv + +@overload +def image_pixel_center(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: torch_.dtype = None, device: torch_.device = None) -> torch_.Tensor: + """Get image pixel center coordinates, ranging in [0, width] and [0, height]. +`image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + +>>> image_pixel_center(10, 10): +[[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... +[[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + +Args: + width (int): image width + height (int): image height + +Returns: + np.ndarray: shape (height, width, 2)""" + utils3d.torch.utils.image_pixel_center + +@overload +def image_mesh(height: int, width: int, mask: torch_.Tensor = None, device: torch_.device = None, dtype: torch_.dtype = None) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Get a quad mesh regarding image pixel uv coordinates as vertices and image grid as faces. + +Args: + width (int): image width + height (int): image height + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + +Returns: + uv (np.ndarray): uv corresponding to pixels as described in image_uv() + faces (np.ndarray): quad faces connecting neighboring pixels + indices (np.ndarray, optional): indices of vertices in the original mesh""" + utils3d.torch.utils.image_mesh + +@overload +def chessboard(width: int, height: int, grid_size: int, color_a: torch_.Tensor, color_b: torch_.Tensor) -> torch_.Tensor: + """get a chessboard image + +Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (torch.Tensor): shape (chanenls,), color of the grid at the top-left corner + color_b (torch.Tensor): shape (chanenls,), color in complementary grids + +Returns: + image (torch.Tensor): shape (height, width, channels), chessboard image""" + utils3d.torch.utils.chessboard + +@overload +def depth_edge(depth: torch_.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch_.Tensor = None) -> torch_.BoolTensor: + """Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth. + +Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + +Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool""" + utils3d.torch.utils.depth_edge + +@overload +def depth_aliasing(depth: torch_.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch_.Tensor = None) -> torch_.BoolTensor: + """Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. +Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + +Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool""" + utils3d.torch.utils.depth_aliasing + +@overload +def image_mesh_from_depth(depth: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]: + utils3d.torch.utils.image_mesh_from_depth + +@overload +def point_to_normal(point: torch_.Tensor, mask: torch_.Tensor = None) -> torch_.Tensor: + """Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + +Args: + point (torch.Tensor): shape (..., height, width, 3), point map +Returns: + normal (torch.Tensor): shape (..., height, width, 3), normal map. """ + utils3d.torch.utils.point_to_normal + +@overload +def depth_to_normal(depth: torch_.Tensor, intrinsics: torch_.Tensor, mask: torch_.Tensor = None) -> torch_.Tensor: + """Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + +Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix +Returns: + normal (torch.Tensor): shape (..., 3, height, width), normal map. """ + utils3d.torch.utils.depth_to_normal + +@overload +def masked_min(input: torch_.Tensor, mask: torch_.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch_.Tensor, Tuple[torch_.Tensor, torch_.Tensor]]: + """Similar to torch.min, but with mask + """ + utils3d.torch.utils.masked_min + +@overload +def masked_max(input: torch_.Tensor, mask: torch_.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch_.Tensor, Tuple[torch_.Tensor, torch_.Tensor]]: + """Similar to torch.max, but with mask + """ + utils3d.torch.utils.masked_max + +@overload +def bounding_rect(mask: torch_.BoolTensor): + """get bounding rectangle of a mask + +Args: + mask (torch.Tensor): shape (..., height, width), mask + +Returns: + rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom)""" + utils3d.torch.utils.bounding_rect + +@overload +def perspective(fov_y: Union[float, torch_.Tensor], aspect: Union[float, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenGL perspective matrix + +Args: + fov_y (float | torch.Tensor): field of view in y axis + aspect (float | torch.Tensor): aspect ratio + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + +Returns: + (torch.Tensor): [..., 4, 4] perspective matrix""" + utils3d.torch.transforms.perspective + +@overload +def perspective_from_fov(fov: Union[float, torch_.Tensor], width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenGL perspective matrix from field of view in largest dimension + +Args: + fov (float | torch.Tensor): field of view in largest dimension + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + +Returns: + (torch.Tensor): [..., 4, 4] perspective matrix""" + utils3d.torch.transforms.perspective_from_fov + +@overload +def perspective_from_fov_xy(fov_x: Union[float, torch_.Tensor], fov_y: Union[float, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenGL perspective matrix from field of view in x and y axis + +Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + +Returns: + (torch.Tensor): [..., 4, 4] perspective matrix""" + utils3d.torch.transforms.perspective_from_fov_xy + +@overload +def intrinsics_from_focal_center(fx: Union[float, torch_.Tensor], fy: Union[float, torch_.Tensor], cx: Union[float, torch_.Tensor], cy: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenCV intrinsics matrix + +Args: + focal_x (float | torch.Tensor): focal length in x axis + focal_y (float | torch.Tensor): focal length in y axis + cx (float | torch.Tensor): principal point in x axis + cy (float | torch.Tensor): principal point in y axis + +Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.torch.transforms.intrinsics_from_focal_center + +@overload +def intrinsics_from_fov(fov_max: Union[float, torch_.Tensor] = None, fov_min: Union[float, torch_.Tensor] = None, fov_x: Union[float, torch_.Tensor] = None, fov_y: Union[float, torch_.Tensor] = None, width: Union[int, torch_.Tensor] = None, height: Union[int, torch_.Tensor] = None) -> torch_.Tensor: + """Get normalized OpenCV intrinsics matrix from given field of view. +You can provide either fov_max, fov_min, fov_x or fov_y + +Args: + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + fov_max (float | torch.Tensor): field of view in largest dimension + fov_min (float | torch.Tensor): field of view in smallest dimension + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + +Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.torch.transforms.intrinsics_from_fov + +@overload +def intrinsics_from_fov_xy(fov_x: Union[float, torch_.Tensor], fov_y: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenCV intrinsics matrix from field of view in x and y axis + +Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + +Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.torch.transforms.intrinsics_from_fov_xy + +@overload +def view_look_at(eye: torch_.Tensor, look_at: torch_.Tensor, up: torch_.Tensor) -> torch_.Tensor: + """Get OpenGL view matrix looking at something + +Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + +Returns: + (torch.Tensor): [..., 4, 4], view matrix""" + utils3d.torch.transforms.view_look_at + +@overload +def extrinsics_look_at(eye: torch_.Tensor, look_at: torch_.Tensor, up: torch_.Tensor) -> torch_.Tensor: + """Get OpenCV extrinsics matrix looking at something + +Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction + +Returns: + (torch.Tensor): [..., 4, 4], extrinsics matrix""" + utils3d.torch.transforms.extrinsics_look_at + +@overload +def perspective_to_intrinsics(perspective: torch_.Tensor) -> torch_.Tensor: + """OpenGL perspective matrix to OpenCV intrinsics + +Args: + perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix + +Returns: + (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics""" + utils3d.torch.transforms.perspective_to_intrinsics + +@overload +def intrinsics_to_perspective(intrinsics: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """OpenCV intrinsics to OpenGL perspective matrix + +Args: + intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip +Returns: + (torch.Tensor): [..., 4, 4] OpenGL perspective matrix""" + utils3d.torch.transforms.intrinsics_to_perspective + +@overload +def extrinsics_to_view(extrinsics: torch_.Tensor) -> torch_.Tensor: + """OpenCV camera extrinsics to OpenGL view matrix + +Args: + extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix + +Returns: + (torch.Tensor): [..., 4, 4] OpenGL view matrix""" + utils3d.torch.transforms.extrinsics_to_view + +@overload +def view_to_extrinsics(view: torch_.Tensor) -> torch_.Tensor: + """OpenGL view matrix to OpenCV camera extrinsics + +Args: + view (torch.Tensor): [..., 4, 4] OpenGL view matrix + +Returns: + (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix""" + utils3d.torch.transforms.view_to_extrinsics + +@overload +def normalize_intrinsics(intrinsics: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Normalize camera intrinsics(s) to uv space + +Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + +Returns: + (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s)""" + utils3d.torch.transforms.normalize_intrinsics + +@overload +def crop_intrinsics(intrinsics: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor], left: Union[int, torch_.Tensor], top: Union[int, torch_.Tensor], crop_width: Union[int, torch_.Tensor], crop_height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + +Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + left (int | torch.Tensor): [...] left crop boundary + top (int | torch.Tensor): [...] top crop boundary + crop_width (int | torch.Tensor): [...] crop width + crop_height (int | torch.Tensor): [...] crop height + +Returns: + (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s)""" + utils3d.torch.transforms.crop_intrinsics + +@overload +def pixel_to_uv(pixel: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + +Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1)""" + utils3d.torch.transforms.pixel_to_uv + +@overload +def pixel_to_ndc(pixel: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + +Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1)""" + utils3d.torch.transforms.pixel_to_ndc + +@overload +def uv_to_pixel(uv: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Args: + uv (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + +Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1)""" + utils3d.torch.transforms.uv_to_pixel + +@overload +def project_depth(depth: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Project linear depth to depth value in screen space + +Args: + depth (torch.Tensor): [...] depth value + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + +Returns: + (torch.Tensor): [..., 1] depth value in screen space, value ranging in [0, 1]""" + utils3d.torch.transforms.project_depth + +@overload +def depth_buffer_to_linear(depth: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Linearize depth value to linear depth + +Args: + depth (torch.Tensor): [...] screen depth value, ranging in [0, 1] + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + +Returns: + (torch.Tensor): [...] linear depth""" + utils3d.torch.transforms.depth_buffer_to_linear + +@overload +def project_gl(points: torch_.Tensor, model: torch_.Tensor = None, view: torch_.Tensor = None, perspective: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Project 3D points to 2D following the OpenGL convention (except for row major matrice) + +Args: + points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + +Returns: + scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (torch.Tensor): [..., N] linear depth""" + utils3d.torch.transforms.project_gl + +@overload +def project_cv(points: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Project 3D points to 2D following the OpenCV convention + +Args: + points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + +Returns: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + linear_depth (torch.Tensor): [..., N] linear depth""" + utils3d.torch.transforms.project_cv + +@overload +def unproject_gl(screen_coord: torch_.Tensor, model: torch_.Tensor = None, view: torch_.Tensor = None, perspective: torch_.Tensor = None) -> torch_.Tensor: + """Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + +Args: + screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + +Returns: + points (torch.Tensor): [..., N, 3] 3d points""" + utils3d.torch.transforms.unproject_gl + +@overload +def unproject_cv(uv_coord: torch_.Tensor, depth: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> torch_.Tensor: + """Unproject uv coordinates to 3D view space following the OpenCV convention + +Args: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (torch.Tensor): [..., N] depth value + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + +Returns: + points (torch.Tensor): [..., N, 3] 3d points""" + utils3d.torch.transforms.unproject_cv + +@overload +def skew_symmetric(v: torch_.Tensor): + """Skew symmetric matrix from a 3D vector""" + utils3d.torch.transforms.skew_symmetric + +@overload +def rotation_matrix_from_vectors(v1: torch_.Tensor, v2: torch_.Tensor): + """Rotation matrix that rotates v1 to v2""" + utils3d.torch.transforms.rotation_matrix_from_vectors + +@overload +def euler_axis_angle_rotation(axis: str, angle: torch_.Tensor) -> torch_.Tensor: + """Return the rotation matrices for one of the rotations about an axis +of which Euler angles describe, for each value of the angle given. + +Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + +Returns: + Rotation matrices as tensor of shape (..., 3, 3).""" + utils3d.torch.transforms.euler_axis_angle_rotation + +@overload +def euler_angles_to_matrix(euler_angles: torch_.Tensor, convention: str = 'XYZ') -> torch_.Tensor: + """Convert rotations given as Euler angles in radians to rotation matrices. + +Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + +Returns: + Rotation matrices as tensor of shape (..., 3, 3).""" + utils3d.torch.transforms.euler_angles_to_matrix + +@overload +def matrix_to_euler_angles(matrix: torch_.Tensor, convention: str) -> torch_.Tensor: + """Convert rotations given as rotation matrices to Euler angles in radians. +NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d) + +Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + +Returns: + Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d)""" + utils3d.torch.transforms.matrix_to_euler_angles + +@overload +def matrix_to_quaternion(rot_mat: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + +Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + +Returns: + torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices""" + utils3d.torch.transforms.matrix_to_quaternion + +@overload +def quaternion_to_matrix(quaternion: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + +Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + +Returns: + torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions""" + utils3d.torch.transforms.quaternion_to_matrix + +@overload +def matrix_to_axis_angle(rot_mat: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector) + +Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + +Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices""" + utils3d.torch.transforms.matrix_to_axis_angle + +@overload +def axis_angle_to_matrix(axis_angle: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + +Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + +Returns: + torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters""" + utils3d.torch.transforms.axis_angle_to_matrix + +@overload +def axis_angle_to_quaternion(axis_angle: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z) + +Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + +Returns: + torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters""" + utils3d.torch.transforms.axis_angle_to_quaternion + +@overload +def quaternion_to_axis_angle(quaternion: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector) + +Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + +Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions""" + utils3d.torch.transforms.quaternion_to_axis_angle + +@overload +def slerp(rot_mat_1: torch_.Tensor, rot_mat_2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]) -> torch_.Tensor: + """Spherical linear interpolation between two rotation matrices + +Args: + rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix + rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix + t (torch.Tensor): scalar or shape (...,), the interpolation factor + +Returns: + torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix""" + utils3d.torch.transforms.slerp + +@overload +def interpolate_extrinsics(ext1: torch_.Tensor, ext2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]) -> torch_.Tensor: + """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. + +Args: + ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose + ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose + t (torch.Tensor): scalar or shape (...,), the interpolation factor + +Returns: + torch.Tensor: shape (..., 4, 4), the interpolated camera pose""" + utils3d.torch.transforms.interpolate_extrinsics + +@overload +def interpolate_view(view1: torch_.Tensor, view2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]): + """Interpolate view matrices between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. + +Args: + ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose + ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose + t (torch.Tensor): scalar or shape (...,), the interpolation factor + +Returns: + torch.Tensor: shape (..., 4, 4), the interpolated camera pose""" + utils3d.torch.transforms.interpolate_view + +@overload +def extrinsics_to_essential(extrinsics: torch_.Tensor): + """extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + +Args: + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + +Returns: + (torch.Tensor): [..., 3, 3] essential matrix""" + utils3d.torch.transforms.extrinsics_to_essential + +@overload +def to4x4(R: torch_.Tensor, t: torch_.Tensor): + """Compose rotation matrix and translation vector to 4x4 transformation matrix + +Args: + R (torch.Tensor): [..., 3, 3] rotation matrix + t (torch.Tensor): [..., 3] translation vector + +Returns: + (torch.Tensor): [..., 4, 4] transformation matrix""" + utils3d.torch.transforms.to4x4 + +@overload +def rotation_matrix_2d(theta: Union[float, torch_.Tensor]): + """2x2 matrix for 2D rotation + +Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + +Returns: + (torch.Tensor): (..., 2, 2) rotation matrix""" + utils3d.torch.transforms.rotation_matrix_2d + +@overload +def rotate_2d(theta: Union[float, torch_.Tensor], center: torch_.Tensor = None): + """3x3 matrix for 2D rotation around a center +``` + [[Rxx, Rxy, tx], + [Ryx, Ryy, ty], + [0, 0, 1]] +``` +Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0) + +Returns: + (torch.Tensor): (..., 3, 3) transformation matrix""" + utils3d.torch.transforms.rotate_2d + +@overload +def translate_2d(translation: torch_.Tensor): + """Translation matrix for 2D translation +``` + [[1, 0, tx], + [0, 1, ty], + [0, 0, 1]] +``` +Args: + translation (torch.Tensor): translation vector, arbitrary shape (..., 2) + +Returns: + (torch.Tensor): (..., 3, 3) transformation matrix""" + utils3d.torch.transforms.translate_2d + +@overload +def scale_2d(scale: Union[float, torch_.Tensor], center: torch_.Tensor = None): + """Scale matrix for 2D scaling +``` + [[s, 0, tx], + [0, s, ty], + [0, 0, 1]] +``` +Args: + scale (float | torch.Tensor): scale factor, arbitrary shape (...,) + center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0) + +Returns: + (torch.Tensor): (..., 3, 3) transformation matrix""" + utils3d.torch.transforms.scale_2d + +@overload +def apply_2d(transform: torch_.Tensor, points: torch_.Tensor): + """Apply (3x3 or 2x3) 2D affine transformation to points +``` + p = R @ p + t +``` +Args: + transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix + points (torch.Tensor): (..., N, 2) points to transform + +Returns: + (torch.Tensor): (..., N, 2) transformed points""" + utils3d.torch.transforms.apply_2d + +@overload +def RastContext(nvd_ctx: Union[nvdiffrast.torch.ops.RasterizeCudaContext, nvdiffrast.torch.ops.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Union[str, torch_.device] = None): + """Create a rasterization context. Nothing but a wrapper of nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext.""" + utils3d.torch.rasterization.RastContext + +@overload +def rasterize_triangle_faces(ctx: utils3d.torch.rasterization.RastContext, vertices: torch_.Tensor, faces: torch_.Tensor, attr: torch_.Tensor, width: int, height: int, model: torch_.Tensor = None, view: torch_.Tensor = None, projection: torch_.Tensor = None, antialiasing: Union[bool, List[int]] = True, diff_attrs: Optional[List[int]] = None) -> Tuple[torch_.Tensor, torch_.Tensor, Optional[torch_.Tensor]]: + """Rasterize a mesh with vertex attributes. + +Args: + ctx (GLContext): rasterizer context + vertices (np.ndarray): (B, N, 2 or 3 or 4) + faces (torch.Tensor): (T, 3) + attr (torch.Tensor): (B, N, C) + width (int): width of the output image + height (int): height of the output image + model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity). + view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity). + projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity). + antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased. + diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives. Defaults to None. + +Returns: + image: (torch.Tensor): (B, C, H, W) + depth: (torch.Tensor): (B, H, W) screen space depth, ranging from 0 (near) to 1. (far) + NOTE: Empty pixels will have depth 1., i.e. far plane.""" + utils3d.torch.rasterization.rasterize_triangle_faces + +@overload +def warp_image_by_depth(ctx: utils3d.torch.rasterization.RastContext, depth: torch_.FloatTensor, image: torch_.FloatTensor = None, mask: torch_.BoolTensor = None, width: int = None, height: int = None, *, extrinsics_src: torch_.FloatTensor = None, extrinsics_tgt: torch_.FloatTensor = None, intrinsics_src: torch_.FloatTensor = None, intrinsics_tgt: torch_.FloatTensor = None, near: float = 0.1, far: float = 100.0, antialiasing: bool = True, backslash: bool = False, padding: int = 0, return_uv: bool = False, return_dr: bool = False) -> Tuple[torch_.FloatTensor, torch_.FloatTensor, torch_.BoolTensor, Optional[torch_.FloatTensor], Optional[torch_.FloatTensor]]: + """Warp image by depth. +NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. +Otherwise, image mesh will be triangulated simply for batch rendering. + +Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + depth (torch.Tensor): (B, H, W) linear depth + image (torch.Tensor): (B, C, H, W). None to use image space uv. Defaults to None. + width (int, optional): width of the output image. None to use the same as depth. Defaults to None. + height (int, optional): height of the output image. Defaults the same as depth.. + extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None. + extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None. + intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None. + intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None. + near (float, optional): near plane. Defaults to 0.1. + far (float, optional): far plane. Defaults to 100.0. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + padding (int, optional): padding of the image. Defaults to 0. + return_uv (bool, optional): whether to return the uv. Defaults to False. + return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False. + +Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + depth: (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + uv: (torch.FloatTensor): (B, 2, H, W) image-space uv + dr: (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv""" + utils3d.torch.rasterization.warp_image_by_depth + +@overload +def warp_image_by_forward_flow(ctx: utils3d.torch.rasterization.RastContext, image: torch_.FloatTensor, flow: torch_.FloatTensor, depth: torch_.FloatTensor = None, *, antialiasing: bool = True, backslash: bool = False) -> Tuple[torch_.FloatTensor, torch_.BoolTensor]: + """Warp image by forward flow. +NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. +Otherwise, image mesh will be triangulated simply for batch rendering. + +Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + image (torch.Tensor): (B, C, H, W) image + flow (torch.Tensor): (B, 2, H, W) forward flow + depth (torch.Tensor, optional): (B, H, W) linear depth. If None, will use the same for all pixels. Defaults to None. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + +Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels""" + utils3d.torch.rasterization.warp_image_by_forward_flow + diff --git a/submodules/MoGe/utils3d/io/__init__.py b/submodules/MoGe/utils3d/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e67b88d2d06b04520fec1cb21b70bdda521eafed --- /dev/null +++ b/submodules/MoGe/utils3d/io/__init__.py @@ -0,0 +1,3 @@ +from .obj import * +from .colmap import * +from .ply import * diff --git a/submodules/MoGe/utils3d/io/colmap.py b/submodules/MoGe/utils3d/io/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..d00ccbea8b3974e45fc91678000e9083e4ce378b --- /dev/null +++ b/submodules/MoGe/utils3d/io/colmap.py @@ -0,0 +1,139 @@ +from typing import * +from pathlib import Path + +import numpy as np +from scipy.spatial.transform import Rotation + + +__all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap'] + + +def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None): + """ + Write extrinsics to colmap `images.txt` file. + Args: + file: Path to `images.txt` file. + extrinsics: (N, 4, 4) array of extrinsics. + image_names: str or List of str, image names. Length is N. + If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap) + camera_ids: List of int, camera ids. Length is N. + If None, it will be set to [1, 2, ..., N]. + """ + assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4) + if extrinsics.ndim == 2: + extrinsics = extrinsics[np.newaxis, ...] + quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat() + trans = extrinsics[:, :3, 3] + if camera_ids is None: + camera_ids = list(range(1, len(extrinsics) + 1)) + if isinstance(image_names, str): + image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)] + assert len(extrinsics) == len(image_names) == len(camera_ids), \ + f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same' + with open(file, 'w') as fp: + print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp) + for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)): + # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order. + qx, qy, qz, qw = quat + tx, ty, tz = t + print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp) + print() + + +def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False): + """ + Write intrinsics to colmap `cameras.txt` file. Currently only support PINHOLE model (no distortion) + Args: + file: Path to `cameras.txt` file. + intrinsics: (N, 3, 3) array of intrinsics. + width: Image width. + height: Image height. + normalized: Whether the intrinsics are normalized. If True, the intrinsics will unnormalized for writing. + """ + assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3) + if intrinsics.ndim == 2: + intrinsics = intrinsics[np.newaxis, ...] + if normalized: + intrinsics = intrinsics * np.array([width, height, 1])[:, None] + with open(file, 'w') as fp: + print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp) + for i, intr in enumerate(intrinsics): + fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] + print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp) + + +def read_extrinsics_from_colmap(file: Union[str, Path]) -> Union[np.ndarray, List[int], List[str]]: + """ + Read extrinsics from colmap `images.txt` file. + Args: + file: Path to `images.txt` file. + Returns: + extrinsics: (N, 4, 4) array of extrinsics. + camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1. + image_names: List of str, image names. Length is N. + """ + with open(file) as fp: + lines = fp.readlines() + image_names, quats, trans, camera_ids = [], [], [], [] + i_line = 0 + for line in lines: + line = line.strip() + if line.startswith('#'): + continue + i_line += 1 + if i_line % 2 == 0: + continue + image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split() + quats.append([float(qx), float(qy), float(qz), float(qw)]) + trans.append([float(tx), float(ty), float(tz)]) + camera_ids.append(int(camera_id)) + image_names.append(name) + + quats = np.array(quats, dtype=np.float32) + trans = np.array(trans, dtype=np.float32) + rotation = Rotation.from_quat(quats).as_matrix() + extrinsics = np.concatenate([ + np.concatenate([rotation, trans[..., None]], axis=-1), + np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0) + ], axis=-2) + + return extrinsics, camera_ids, image_names + + +def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]: + """ + Read intrinsics from colmap `cameras.txt` file. + Args: + file: Path to `cameras.txt` file. + normalize: Whether to normalize the intrinsics. If True, the intrinsics will be normalized. (mapping coordinates to [0, 1] range) + Returns: + camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1. + intrinsics: (N, 3, 3) array of intrinsics. + distortions: (N, 5) array of distortions. + """ + with open(file) as fp: + lines = fp.readlines() + intrinsics, distortions, camera_ids = [], [], [] + for line in lines: + line = line.strip() + if not line or line.startswith('#'): + continue + camera_id, model, width, height, *params = line.split() + camera_id, width, height = int(camera_id), int(width), int(height) + if model == 'PINHOLE': + fx, fy, cx, cy = map(float, params[:4]) + k1 = k2 = k3 = p1 = p2 = 0.0 + elif model == 'OPENCV': + fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0 + elif model == 'SIMPLE_RADIAL': + f, cx, cy, k = map(float, params[:4]) + fx = fy = f + k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0 + camera_ids.append(camera_id) + if normalize: + fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height + intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + distortions.append([k1, k2, p1, p2, k3]) + intrinsics = np.array(intrinsics, dtype=np.float32) + distortions = np.array(distortions, dtype=np.float32) + return camera_ids, intrinsics, distortions diff --git a/submodules/MoGe/utils3d/io/obj.py b/submodules/MoGe/utils3d/io/obj.py new file mode 100644 index 0000000000000000000000000000000000000000..3471e490bd758fbf58173cfb7297ec747f46f173 --- /dev/null +++ b/submodules/MoGe/utils3d/io/obj.py @@ -0,0 +1,146 @@ +from io import TextIOWrapper +from typing import Dict, Any, Union, Iterable +import numpy as np +from pathlib import Path + +__all__ = [ + 'read_obj', + 'write_obj', + 'simple_write_obj' +] + +def read_obj( + file : Union[str, Path, TextIOWrapper], + encoding: Union[str, None] = None, + ignore_unknown: bool = False +): + """ + Read wavefront .obj file, without preprocessing. + + Why bothering having this read_obj() while we already have other libraries like `trimesh`? + This function read the raw format from .obj file and keeps the order of vertices and faces, + while trimesh which involves modification like merge/split vertices, which could break the orders of vertices and faces, + Those libraries are commonly aiming at geometry processing and rendering supporting various formats. + If you want mesh geometry processing, you may turn to `trimesh` for more features. + + ### Parameters + `file` (str, Path, TextIOWrapper): filepath or file object + encoding (str, optional): + + ### Returns + obj (dict): A dict containing .obj components + { + 'mtllib': [], + 'v': [[0,1, 0.2, 1.0], [1.2, 0.0, 0.0], ...], + 'vt': [[0.5, 0.5], ...], + 'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...], + 'f': [[0, 1, 2], [2, 3, 4],...], + 'usemtl': [{'name': 'mtl1', 'f': 7}] + } + """ + if hasattr(file,'read'): + lines = file.read().splitlines() + else: + with open(file, 'r', encoding=encoding) as fp: + lines = fp.read().splitlines() + mtllib = [] + v, vt, vn, vp = [], [], [], [] # Vertex coordinates, Vertex texture coordinate, Vertex normal, Vertex parameter + f, ft, fn = [], [], [] # Face indices, Face texture indices, Face normal indices + o = [] + s = [] + usemtl = [] + + def pad(l: list, n: Any): + return l + [n] * (3 - len(l)) + + for i, line in enumerate(lines): + sq = line.strip().split() + if len(sq) == 0: + continue + if sq[0] == 'v': + assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}' + v.append([float(e) for e in sq[1:]][:3]) + elif sq[0] == 'vt': + assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}' + vt.append([float(e) for e in sq[1:]][:2]) + elif sq[0] == 'vn': + assert len(sq) == 4, f'Invalid format of line {i}: {line}' + vn.append([float(e) for e in sq[1:]]) + elif sq[0] == 'vp': + assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}' + vp.append(pad([float(e) for e in sq[1:]], 0)) + elif sq[0] == 'f': + spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]] + f.append([e[0] for e in spliting]) + ft.append([e[1] for e in spliting]) + fn.append([e[2] for e in spliting]) + elif sq[0] == 'usemtl': + assert len(sq) == 2 + usemtl.append((sq[1], len(f))) + elif sq[0] == 'o': + assert len(sq) == 2 + o.append((sq[1], len(f))) + elif sq[0] == 's': + s.append((sq[1], len(f))) + elif sq[0] == 'mtllib': + assert len(sq) == 2 + mtllib.append(sq[1]) + elif sq[0][0] == '#': + continue + else: + if not ignore_unknown: + raise Exception(f'Unknown keyword {sq[0]}') + + min_poly_vertices = min(len(f) for f in f) + max_poly_vertices = max(len(f) for f in f) + + return { + 'mtllib': mtllib, + 'v': np.array(v, dtype=np.float32), + 'vt': np.array(vt, dtype=np.float32), + 'vn': np.array(vn, dtype=np.float32), + 'vp': np.array(vp, dtype=np.float32), + 'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f, + 'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft, + 'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn, + 'o': o, + 's': s, + 'usemtl': usemtl, + } + + +def write_obj( + file: Union[str, Path], + obj: Dict[str, Any], + encoding: Union[str, None] = None + ): + with open(file, 'w', encoding=encoding) as fp: + for k in ['v', 'vt', 'vn', 'vp']: + if k not in obj: + continue + for v in obj[k]: + print(k, *map(float, v), file=fp) + for f in obj['f']: + print('f', *((str('/').join(map(int, i)) if isinstance(int(i), Iterable) else i) for i in f), file=fp) + + +def simple_write_obj( + file: Union[str, Path], + vertices: np.ndarray, + faces: np.ndarray, + encoding: Union[str, None] = None + ): + """ + Write wavefront .obj file, without preprocessing. + + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + file (Any): filepath + encoding (str, optional): + """ + with open(file, 'w', encoding=encoding) as fp: + for v in vertices: + print('v', *map(float, v), file=fp) + for f in faces: + print('f', *map(int, f + 1), file=fp) diff --git a/submodules/MoGe/utils3d/io/ply.py b/submodules/MoGe/utils3d/io/ply.py new file mode 100644 index 0000000000000000000000000000000000000000..39fa41728a7be76d25743788c85dacb384d6d83e --- /dev/null +++ b/submodules/MoGe/utils3d/io/ply.py @@ -0,0 +1,104 @@ +import numpy as np + +from typing import * +from pathlib import Path + + +def read_ply( + file: Union[str, Path], + encoding: Union[str, None] = None, + ignore_unknown: bool = False +) -> Tuple[np.ndarray, np.ndarray]: + """ + Read .ply file, without preprocessing. + + Args: + file (Any): filepath + encoding (str, optional): + + Returns: + Tuple[np.ndarray, np.ndarray]: vertices, faces + """ + import plyfile + plydata = plyfile.PlyData.read(file) + vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1) + if 'face' in plydata: + faces = np.array(plydata['face']['vertex_indices'].tolist()) + else: + faces = None + return vertices, faces + + +def write_ply( + file: Union[str, Path], + vertices: np.ndarray, + faces: np.ndarray = None, + edges: np.ndarray = None, + vertex_colors: np.ndarray = None, + edge_colors: np.ndarray = None, + text: bool = False +): + """ + Write .ply file, without preprocessing. + + Args: + file (Any): filepath + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, E] + edges (np.ndarray): [E, 2] + vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None. + edge_colors (np.ndarray, optional): [E, 3]. Defaults to None. + text (bool, optional): save data in text format. Defaults to False. + """ + import plyfile + assert vertices.ndim == 2 and vertices.shape[1] == 3 + vertices = vertices.astype(np.float32) + if faces is not None: + assert faces.ndim == 2 + faces = faces.astype(np.int32) + if edges is not None: + assert edges.ndim == 2 and edges.shape[1] == 2 + edges = edges.astype(np.int32) + + if vertex_colors is not None: + assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3 + if vertex_colors.dtype in [np.float32, np.float64]: + vertex_colors = vertex_colors * 255 + vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8) + vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) + vertices_data['x'] = vertices[:, 0] + vertices_data['y'] = vertices[:, 1] + vertices_data['z'] = vertices[:, 2] + vertices_data['red'] = vertex_colors[:, 0] + vertices_data['green'] = vertex_colors[:, 1] + vertices_data['blue'] = vertex_colors[:, 2] + else: + vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) + + if faces is not None: + faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))]) + faces_data['vertex_indices'] = faces + + if edges is not None: + if edge_colors is not None: + assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3 + if edge_colors.dtype in [np.float32, np.float64]: + edge_colors = edge_colors * 255 + edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8) + edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) + edges_data['vertex1'] = edges[:, 0] + edges_data['vertex2'] = edges[:, 1] + edges_data['red'] = edge_colors[:, 0] + edges_data['green'] = edge_colors[:, 1] + edges_data['blue'] = edge_colors[:, 2] + else: + edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')]) + + ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')] + if faces is not None: + ply_data.append(plyfile.PlyElement.describe(faces_data, 'face')) + if edges is not None: + ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge')) + + plyfile.PlyData(ply_data, text=text).write(file) + \ No newline at end of file diff --git a/submodules/MoGe/utils3d/numpy/__init__.py b/submodules/MoGe/utils3d/numpy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06c53abe3b4abd39f1d7f8372851d7cdc58260 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/__init__.py @@ -0,0 +1,142 @@ +""" +3D utility functions workings with NumPy. +""" +import importlib +import itertools +import numpy +from typing import TYPE_CHECKING + + +__modules_all__ = { + 'mesh':[ + 'triangulate', + 'compute_face_normal', + 'compute_face_angle', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'remove_corrupted_faces', + 'merge_duplicate_vertices', + 'remove_unreferenced_vertices', + 'subdivide_mesh_simple', + 'mesh_relations', + 'flatten_mesh_indices' + ], + 'quadmesh': [ + 'calc_quad_candidates', + 'calc_quad_distortion', + 'calc_quad_direction', + 'calc_quad_smoothness', + 'sovle_quad', + 'sovle_quad_qp', + 'tri_to_quad' + ], + 'utils': [ + 'sliding_window_1d', + 'sliding_window_nd', + 'sliding_window_2d', + 'max_pool_1d', + 'max_pool_2d', + 'max_pool_nd', + 'depth_edge', + 'normals_edge', + 'depth_aliasing', + 'interpolate', + 'image_scrcoord', + 'image_uv', + 'image_pixel_center', + 'image_pixel', + 'image_mesh', + 'image_mesh_from_depth', + 'depth_to_normals', + 'points_to_normals', + 'chessboard', + 'cube', + 'icosahedron', + 'square', + 'camera_frustum', + ], + 'transforms': [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'fov_to_focal', + 'focal_to_fov', + 'intrinsics_to_fov', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'perspective_to_near_far', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'unproject_cv', + 'unproject_gl', + 'project_cv', + 'project_gl', + 'quaternion_to_matrix', + 'axis_angle_to_matrix', + 'matrix_to_quaternion', + 'extrinsics_to_essential', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'ray_intersection', + 'se3_matrix', + 'slerp_quaternion', + 'slerp_vector', + 'lerp', + 'lerp_se3_matrix', + 'piecewise_lerp', + 'piecewise_lerp_se3_matrix', + 'apply_transform' + ], + 'spline': [ + 'linear_spline_interpolate', + ], + 'rasterization': [ + 'RastContext', + 'rasterize_triangle_faces', + 'rasterize_edges', + 'texture', + 'warp_image_by_depth', + 'test_rasterization' + ], +} + + +__all__ = list(itertools.chain(*__modules_all__.values())) + +def __getattr__(name): + try: + return globals()[name] + except KeyError: + pass + + try: + module_name = next(m for m in __modules_all__ if name in __modules_all__[m]) + except StopIteration: + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + module = importlib.import_module(f'.{module_name}', __name__) + for key in __modules_all__[module_name]: + globals()[key] = getattr(module, key) + + return globals()[name] + + +if TYPE_CHECKING: + from .quadmesh import * + from .transforms import * + from .mesh import * + from .utils import * + from .rasterization import * + from .spline import * \ No newline at end of file diff --git a/submodules/MoGe/utils3d/numpy/_helpers.py b/submodules/MoGe/utils3d/numpy/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7c397df338e3e04e0f228341f68171d8e067eb4e --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/_helpers.py @@ -0,0 +1,93 @@ +# decorator +import numpy as np +from numbers import Number +import inspect +from functools import wraps +from typing import * +from .._helpers import suppress_traceback + + +def get_args_order(func, args, kwargs): + """ + Get the order of the arguments of a function. + """ + names = inspect.getfullargspec(func).args + names_idx = {name: i for i, name in enumerate(names)} + args_order = [] + kwargs_order = {} + for name, arg in kwargs.items(): + if name in names: + kwargs_order[name] = names_idx[name] + names.remove(name) + for i, arg in enumerate(args): + if i < len(names): + args_order.append(names_idx[names[i]]) + return args_order, kwargs_order + + +def broadcast_args(args, kwargs, args_dim, kwargs_dim): + spatial = [] + for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())): + if isinstance(arg, np.ndarray) and arg_dim is not None: + arg_spatial = arg.shape[:arg.ndim-arg_dim] + if len(arg_spatial) > len(spatial): + spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial + for j in range(len(arg_spatial)): + if spatial[-j] < arg_spatial[-j]: + if spatial[-j] == 1: + spatial[-j] = arg_spatial[-j] + else: + raise ValueError("Cannot broadcast arguments.") + for i, arg in enumerate(args): + if isinstance(arg, np.ndarray) and args_dim[i] is not None: + args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]]) + for key, arg in kwargs.items(): + if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None: + kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + return args, kwargs, spatial + + +def batched(*dims): + """ + Decorator that allows a function to be called with batched arguments. + """ + def decorator(func): + @wraps(func) + @suppress_traceback + def wrapper(*args, **kwargs): + args = list(args) + # get arguments dimensions + args_order, kwargs_order = get_args_order(func, args, kwargs) + args_dim = [dims[i] for i in args_order] + kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} + # convert to numpy array + for i, arg in enumerate(args): + if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: + args[i] = np.array(arg) + for key, arg in kwargs.items(): + if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None: + kwargs[key] = np.array(arg) + # broadcast arguments + args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) + for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): + if isinstance(arg, np.ndarray) and arg_dim is not None: + args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]]) + for key, arg in kwargs.items(): + if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None: + kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + # call function + results = func(*args, **kwargs) + type_results = type(results) + results = list(results) if isinstance(results, (tuple, list)) else [results] + # restore spatial dimensions + for i, result in enumerate(results): + results[i] = result.reshape([*spatial, *result.shape[1:]]) + if type_results == tuple: + results = tuple(results) + elif type_results == list: + results = list(results) + else: + results = results[0] + return results + return wrapper + return decorator diff --git a/submodules/MoGe/utils3d/numpy/mesh.py b/submodules/MoGe/utils3d/numpy/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..afadb5f2510b58a1c5acbabff2ff798c041744d6 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/mesh.py @@ -0,0 +1,355 @@ +import numpy as np +from typing import * +from ._helpers import batched + + +__all__ = [ + 'triangulate', + 'compute_face_normal', + 'compute_face_angle', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'remove_corrupted_faces', + 'merge_duplicate_vertices', + 'remove_unreferenced_vertices', + 'subdivide_mesh_simple', + 'mesh_relations', + 'flatten_mesh_indices' +] + + +def triangulate( + faces: np.ndarray, + vertices: np.ndarray = None, + backslash: np.ndarray = None +) -> np.ndarray: + """ + Triangulate a polygonal mesh. + + Args: + faces (np.ndarray): [L, P] polygonal faces + vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (np.ndarray, optional): [L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + + Returns: + (np.ndarray): [L * (P - 2), 3] triangular faces + """ + if faces.shape[-1] == 3: + return faces + P = faces.shape[-1] + if vertices is not None: + assert faces.shape[-1] == 4, "now only support quad mesh" + if backslash is None: + backslash = np.linalg.norm(vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1) < \ + np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1) + if backslash is None: + loop_indice = np.stack([ + np.zeros(P - 2, dtype=int), + np.arange(1, P - 1, 1, dtype=int), + np.arange(2, P, 1, dtype=int) + ], axis=1) + return faces[:, loop_indice].reshape((-1, 3)) + else: + assert faces.shape[-1] == 4, "now only support quad mesh" + faces = np.where( + backslash[:, None], + faces[:, [0, 1, 2, 0, 2, 3]], + faces[:, [0, 1, 3, 3, 1, 2]] + ).reshape((-1, 3)) + return faces + + +@batched(2, None) +def compute_face_normal( + vertices: np.ndarray, + faces: np.ndarray +) -> np.ndarray: + """ + Compute face normals of a triangular mesh + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + normals (np.ndarray): [..., T, 3] face normals + """ + normal = np.cross( + vertices[..., faces[:, 1], :] - vertices[..., faces[:, 0], :], + vertices[..., faces[:, 2], :] - vertices[..., faces[:, 0], :] + ) + normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True) + normal_norm[normal_norm == 0] = 1 + normal /= normal_norm + return normal + + +@batched(2, None) +def compute_face_angle( + vertices: np.ndarray, + faces: np.ndarray, + eps: float = 1e-12 + ) -> np.ndarray: + """ + Compute face angles of a triangular mesh + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + angles (np.ndarray): [..., T, 3] face angles + """ + face_angle = np.zeros_like(faces, dtype=vertices.dtype) + for i in range(3): + edge1 = vertices[..., faces[:, (i + 1) % 3], :] - vertices[..., faces[:, i], :] + edge2 = vertices[..., faces[:, (i + 2) % 3], :] - vertices[..., faces[:, i], :] + face_angle[..., i] = np.arccos(np.sum( + edge1 / np.clip(np.linalg.norm(edge1, axis=-1, keepdims=True), eps, None) * + edge2 / np.clip(np.linalg.norm(edge2, axis=-1, keepdims=True), eps, None), + axis=-1 + )) + return face_angle + + +@batched(2, None, 2) +def compute_vertex_normal( + vertices: np.ndarray, + faces: np.ndarray, + face_normal: np.ndarray = None +) -> np.ndarray: + """ + Compute vertex normals of a triangular mesh by averaging neightboring face normals + TODO: can be improved. + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (np.ndarray): [..., N, 3] vertex normals + """ + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + vertex_normal = np.zeros_like(vertices, dtype=vertices.dtype) + for n in range(vertices.shape[0]): + for i in range(3): + vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0], minlength=vertices.shape[1]) + vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1], minlength=vertices.shape[1]) + vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2], minlength=vertices.shape[1]) + vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True) + vertex_normal_norm[vertex_normal_norm == 0] = 1 + vertex_normal /= vertex_normal_norm + return vertex_normal + + +@batched(2, None, 2) +def compute_vertex_normal_weighted( + vertices: np.ndarray, + faces: np.ndarray, + face_normal: np.ndarray = None +) -> np.ndarray: + """ + Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals + according to the angles + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [..., T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (np.ndarray): [..., N, 3] vertex normals + """ + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_angle = compute_face_angle(vertices, faces) + vertex_normal = np.zeros_like(vertices) + for n in range(vertices.shape[0]): + for i in range(3): + vertex_normal[n, :, 0] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 0] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal[n, :, 1] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 1] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal[n, :, 2] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 2] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True) + vertex_normal_norm[vertex_normal_norm == 0] = 1 + vertex_normal /= vertex_normal_norm + return vertex_normal + + +def remove_corrupted_faces( + faces: np.ndarray + ) -> np.ndarray: + """ + Remove corrupted faces (faces with duplicated vertices) + + Args: + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + np.ndarray: [T_, 3] triangular face indices + """ + corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0]) + return faces[~corrupted] + + +def merge_duplicate_vertices( + vertices: np.ndarray, + faces: np.ndarray, + tol: float = 1e-6 + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Merge duplicate vertices of a triangular mesh. + Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + + Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + """ + vertices_round = np.round(vertices / tol) + _, uni_i, uni_inv = np.unique(vertices_round, return_index=True, return_inverse=True, axis=0) + vertices = vertices[uni_i] + faces = uni_inv[faces] + return vertices, faces + + +def remove_unreferenced_vertices( + faces: np.ndarray, + *vertice_attrs, + return_indices: bool = False +) -> Tuple[np.ndarray, ...]: + """ + Remove unreferenced vertices of a mesh. + Unreferenced vertices are removed, and the face indices are updated accordingly. + + Args: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + + Returns: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None. + """ + P = faces.shape[-1] + fewer_indices, inv_map = np.unique(faces, return_inverse=True) + faces = inv_map.astype(np.int32).reshape(-1, P) + ret = [faces] + for attr in vertice_attrs: + ret.append(attr[fewer_indices]) + if return_indices: + ret.append(fewer_indices) + return tuple(ret) + + +def subdivide_mesh_simple( + vertices: np.ndarray, + faces: np.ndarray, + n: int = 1 +) -> Tuple[np.ndarray, np.ndarray]: + """ + Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. + NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + + Returns: + vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices + faces (np.ndarray): [4 * T, 3] subdivided triangular face indices + """ + for _ in range(n): + edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0) + edges = np.sort(edges, axis=2) + uni_edges, uni_inv = np.unique(edges.reshape(-1, 2), return_inverse=True, axis=0) + uni_inv = uni_inv.reshape(3, -1) + midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2 + + n_vertices = vertices.shape[0] + vertices = np.concatenate([vertices, midpoints], axis=0) + faces = np.concatenate([ + np.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1), + np.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1), + np.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1), + np.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1), + ], axis=0) + return vertices, faces + + +def mesh_relations( + faces: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculate the relation between vertices and faces. + NOTE: The input mesh must be a manifold triangle mesh. + + Args: + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + edges (np.ndarray): [E, 2] edge indices + edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary. + face2edge (np.ndarray): [T, 3] face to edge relation + face2face (np.ndarray): [T, 3] face to face relation + """ + T = faces.shape[0] + edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=1).reshape(-1, 2) # [3T, 2] + edges = np.sort(edges, axis=1) # [3T, 2] + edges, face2edge, occurence = np.unique(edges, axis=0, return_inverse=True, return_counts=True) # [E, 2], [3T], [E] + E = edges.shape[0] + assert np.all(occurence <= 2), "The input mesh is not a manifold mesh." + + # Edge to face relation + padding = np.arange(E, dtype=np.int32)[occurence == 1] + padded_face2edge = np.concatenate([face2edge, padding], axis=0) # [2E] + edge2face = np.argsort(padded_face2edge, kind='stable').reshape(-1, 2) // 3 # [E, 2] + edge2face_valid = edge2face[:, 1] < T # [E] + edge2face[~edge2face_valid, 1] = -1 + + # Face to edge relation + face2edge = face2edge.reshape(-1, 3) # [T, 3] + + # Face to face relation + face2face = edge2face[face2edge] # [T, 3, 2] + face2face = face2face[face2face != np.arange(T)[:, None, None]].reshape(T, 3) # [T, 3] + + return edges, edge2face, face2edge, face2face + + +@overload +def flatten_mesh_indices(faces1: np.ndarray, attr1: np.ndarray, *other_faces_attrs_pairs: np.ndarray) -> Tuple[np.ndarray, ...]: + """ + Rearrange the indices of a mesh to a flattened version. Vertices will be no longer shared. + + ### Parameters: + - `faces1`: [T, P] face indices of the first attribute + - `attr1`: [N1, ...] attributes of the first mesh + - ... + + ### Returns: + - `faces`: [T, P] flattened face indices, contigous from 0 to T * P - 1 + - `attr1`: [T * P, ...] attributes of the first mesh, where every P values correspond to a face + _ ... + """ +def flatten_mesh_indices(*args: np.ndarray) -> Tuple[np.ndarray, ...]: + assert len(args) % 2 == 0, "The number of arguments must be even." + T, P = args[0].shape + assert all(arg.shape[0] == T and arg.shape[1] == P for arg in args[::2]), "The faces must have the same shape." + attr_flat = [] + for faces_, attr_ in zip(args[::2], args[1::2]): + attr_flat_ = attr_[faces_].reshape(-1, *attr_.shape[1:]) + attr_flat.append(attr_flat_) + faces_flat = np.arange(T * P, dtype=np.int32).reshape(T, P) + return faces_flat, *attr_flat \ No newline at end of file diff --git a/submodules/MoGe/utils3d/numpy/quadmesh.py b/submodules/MoGe/utils3d/numpy/quadmesh.py new file mode 100644 index 0000000000000000000000000000000000000000..6728d91124020767cc9b3c1fdd6b21d50dc55828 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/quadmesh.py @@ -0,0 +1,472 @@ +import numpy as np +import scipy as sp +import scipy.optimize as spopt +from typing import * + + +__all__ = [ + 'calc_quad_candidates', + 'calc_quad_distortion', + 'calc_quad_direction', + 'calc_quad_smoothness', + 'sovle_quad', + 'sovle_quad_qp', + 'tri_to_quad' +] + + +def calc_quad_candidates( + edges: np.ndarray, + face2edge: np.ndarray, + edge2face: np.ndarray, +): + """ + Calculate the candidate quad faces. + + Args: + edges (np.ndarray): [E, 2] edge indices + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + + Returns: + quads (np.ndarray): [Q, 4] quad candidate indices + quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation + quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + """ + E = edges.shape[0] + T = face2edge.shape[0] + + quads_valid = edge2face[:, 1] != -1 + Q = quads_valid.sum() + quad2face = edge2face[quads_valid] # [Q, 2] + quad2edge = face2edge[quad2face] # [Q, 2, 3] + flag = quad2edge == np.arange(E)[quads_valid][:, None, None] # [Q, 2, 3] + flag = flag.argmax(axis=-1) # [Q, 2] + quad2edge = np.stack([ + quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 1) % 3], + quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 2) % 3], + ], axis=-1).reshape(Q, 4) # [Q, 4] + + quads = np.concatenate([ + np.where( + (edges[quad2edge[:, 0:1], 1:] == edges[quad2edge[:, 1:2], :]).any(axis=-1), + edges[quad2edge[:, 0:1], [[0, 1]]], + edges[quad2edge[:, 0:1], [[1, 0]]], + ), + np.where( + (edges[quad2edge[:, 2:3], 1:] == edges[quad2edge[:, 3:4], :]).any(axis=-1), + edges[quad2edge[:, 2:3], [[0, 1]]], + edges[quad2edge[:, 2:3], [[1, 0]]], + ), + ], axis=1) # [Q, 4] + + quad2adj = edge2face[quad2edge] # [Q, 4, 2] + quad2adj = quad2adj[quad2adj != quad2face[:, [0,0,1,1], None]].reshape(Q, 4) # [Q, 4] + quad2adj_valid = quad2adj != -1 + quad2adj = face2edge[quad2adj] # [Q, 4, 3] + quad2adj[~quad2adj_valid, 0] = quad2edge[~quad2adj_valid] + quad2adj[~quad2adj_valid, 1:] = -1 + quad2adj = quad2adj[quad2adj != quad2edge[..., None]].reshape(Q, 8) # [Q, 8] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + quad2adj_valid = quad2adj != -1 + quad2adj[quad2adj_valid] = edge_valid[quad2adj[quad2adj_valid]] # [Q, 8] + + return quads, quad2edge, quad2adj, quads_valid + + +def calc_quad_distortion( + vertices: np.ndarray, + quads: np.ndarray, +): + """ + Calculate the distortion of each candidate quad face. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + + Returns: + distortion (np.ndarray): [Q] distortion of each quad face + """ + edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3] + edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3] + edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3] + edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3] + cross = vertices[quads[:, 0]] - vertices[quads[:, 2]] # [Q, 3] + + len0 = np.maximum(np.linalg.norm(edge0, axis=-1), 1e-10) # [Q] + len1 = np.maximum(np.linalg.norm(edge1, axis=-1), 1e-10) # [Q] + len2 = np.maximum(np.linalg.norm(edge2, axis=-1), 1e-10) # [Q] + len3 = np.maximum(np.linalg.norm(edge3, axis=-1), 1e-10) # [Q] + len_cross = np.maximum(np.linalg.norm(cross, axis=-1), 1e-10) # [Q] + + angle0 = np.arccos(np.clip(np.sum(-edge0 * edge1, axis=-1) / (len0 * len1), -1, 1)) # [Q] + angle1 = np.arccos(np.clip(np.sum(-edge1 * cross, axis=-1) / (len1 * len_cross), -1, 1)) \ + + np.arccos(np.clip(np.sum(cross * edge2, axis=-1) / (len_cross * len2), -1, 1)) # [Q] + angle2 = np.arccos(np.clip(np.sum(-edge2 * edge3, axis=-1) / (len2 * len3), -1, 1)) # [Q] + angle3 = np.arccos(np.clip(np.sum(-edge3 * -cross, axis=-1) / (len3 * len_cross), -1, 1)) \ + + np.arccos(np.clip(np.sum(-cross * edge0, axis=-1) / (len_cross * len0), -1, 1)) # [Q] + + normal0 = np.cross(edge0, edge1) # [Q, 3] + normal1 = np.cross(edge2, edge3) # [Q, 3] + normal0 = normal0 / np.maximum(np.linalg.norm(normal0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + normal1 = normal1 / np.maximum(np.linalg.norm(normal1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + angle_normal = np.arccos(np.clip(np.sum(normal0 * normal1, axis=-1), -1, 1)) # [Q] + + D90 = np.pi / 2 + D180 = np.pi + D360 = np.pi * 2 + ang_eng = (np.abs(angle0 - D90)**2 + np.abs(angle1 - D90)**2 + np.abs(angle2 - D90)**2 + np.abs(angle3 - D90)**2) / 4 # [Q] + dist_eng = np.abs(angle0 - angle2)**2 / np.minimum(np.maximum(np.minimum(angle0, angle2), 1e-10), np.maximum(D180 - np.maximum(angle0, angle2), 1e-10)) \ + + np.abs(angle1 - angle3)**2 / np.minimum(np.maximum(np.minimum(angle1, angle3), 1e-10), np.maximum(D180 - np.maximum(angle1, angle3), 1e-10)) # [Q] + plane_eng = np.where(angle_normal < D90/2, np.abs(angle_normal)**2, 1e10) # [Q] + eng = ang_eng + 2 * dist_eng + 2 * plane_eng # [Q] + + return eng + + +def calc_quad_direction( + vertices: np.ndarray, + quads: np.ndarray, + ): + """ + Calculate the direction of each candidate quad face. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + + Returns: + direction (np.ndarray): [Q, 4] direction of each quad face. + Represented by the angle between the crossing and each edge. + """ + mid0 = (vertices[quads[:, 0]] + vertices[quads[:, 1]]) / 2 # [Q, 3] + mid1 = (vertices[quads[:, 1]] + vertices[quads[:, 2]]) / 2 # [Q, 3] + mid2 = (vertices[quads[:, 2]] + vertices[quads[:, 3]]) / 2 # [Q, 3] + mid3 = (vertices[quads[:, 3]] + vertices[quads[:, 0]]) / 2 # [Q, 3] + + cross0 = mid2 - mid0 # [Q, 3] + cross1 = mid3 - mid1 # [Q, 3] + cross0 = cross0 / np.maximum(np.linalg.norm(cross0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + cross1 = cross1 / np.maximum(np.linalg.norm(cross1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + + edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3] + edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3] + edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3] + edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3] + edge0 = edge0 / np.maximum(np.linalg.norm(edge0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge1 = edge1 / np.maximum(np.linalg.norm(edge1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge2 = edge2 / np.maximum(np.linalg.norm(edge2, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge3 = edge3 / np.maximum(np.linalg.norm(edge3, axis=-1, keepdims=True), 1e-10) # [Q, 3] + + direction = np.stack([ + np.arccos(np.clip(np.sum(cross0 * edge0, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(cross1 * edge1, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(-cross0 * edge2, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(-cross1 * edge3, axis=-1), -1, 1)), + ], axis=-1) # [Q, 4] + + return direction + + +def calc_quad_smoothness( + quad2edge: np.ndarray, + quad2adj: np.ndarray, + quads_direction: np.ndarray, + ): + """ + Calculate the smoothness of each candidate quad face connection. + + Args: + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_direction (np.ndarray): [Q, 4] direction of each quad face + + Returns: + smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + """ + Q = quad2adj.shape[0] + quad2adj_valid = quad2adj != -1 + connections = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj_valid] # [C, 2] + shared_edge_idx_0 = np.array([[0, 0, 1, 1, 2, 2, 3, 3]]).repeat(Q, axis=0)[quad2adj_valid] # [C] + shared_edge_idx_1 = np.argmax(quad2edge[quad2adj][quad2adj_valid] == quad2edge[connections[:, 0], shared_edge_idx_0][:, None], axis=-1) # [C] + valid_smoothness = np.abs(quads_direction[connections[:, 0], shared_edge_idx_0] - quads_direction[connections[:, 1], shared_edge_idx_1])**2 # [C] + smoothness = np.zeros([Q, 8], dtype=np.float32) + smoothness[quad2adj_valid] = valid_smoothness + return smoothness + + +def sovle_quad( + face2edge: np.ndarray, + edge2face: np.ndarray, + quad2adj: np.ndarray, + quads_distortion: np.ndarray, + quads_smoothness: np.ndarray, + quads_valid: np.ndarray, + ): + """ + Solve the quad mesh from the candidate quad faces. + + Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + + Returns: + weights (np.ndarray): [Q] weight of each valid quad face + """ + T = face2edge.shape[0] + E = edge2face.shape[0] + Q = quads_distortion.shape[0] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + + quads_connection = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj != -1] # [C, 2] + quads_connection = np.sort(quads_connection, axis=-1) # [C, 2] + quads_connection, quads_connection_idx = np.unique(quads_connection, axis=0, return_index=True) # [C, 2], [C] + quads_smoothness = quads_smoothness[quad2adj != -1] # [C] + quads_smoothness = quads_smoothness[quads_connection_idx] # [C] + C = quads_connection.shape[0] + + # Construct the linear programming problem + + # Variables: + # quads_weight: [Q] weight of each quad face + # tri_min_weight: [T] minimum weight of each triangle face + # conn_min_weight: [C] minimum weight of each quad face connection + # conn_max_weight: [C] maximum weight of each quad face connection + # Objective: + # mimi + + c = np.concatenate([ + quads_distortion - 3, + quads_smoothness*4 - 2, + quads_smoothness*4, + ], axis=0) # [Q+C] + + A_ub_triplet = np.concatenate([ + np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T, T+C), np.arange(Q, Q+C), np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T, T+C), quads_connection[:, 0], -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T, T+C), quads_connection[:, 1], -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), np.arange(Q+C, Q+2*C), -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), quads_connection[:, 0], np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), quads_connection[:, 1], np.ones(C)], axis=1), # [C, 3] + ], axis=0) # [3T+6C, 3] + A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3] + A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T+2*C, Q+2*C]) # [T, + b_ub = np.concatenate([np.ones(T), -np.ones(C), np.ones(C)], axis=0) # [T+2C] + bound = np.stack([ + np.concatenate([np.zeros(Q), -np.ones(C), np.zeros(C)], axis=0), + np.concatenate([np.ones(Q), np.ones(C), np.ones(C)], axis=0), + ], axis=1) # [Q+2C, 2] + A_eq = None + b_eq = None + + print('Solver statistics:') + print(f' #T = {T}') + print(f' #Q = {Q}') + print(f' #C = {C}') + + # Solve the linear programming problem + last_num_valid = 0 + for i in range(100): + res_ = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bound) + if not res_.success: + print(f' Iter {i} | Failed with {res_.message}') + break + res = res_ + weights = res.x[:Q] + valid = (weights > 0.5) + num_valid = valid.sum() + print(f' Iter {i} | #Q_valid = {num_valid}') + if num_valid == last_num_valid: + break + last_num_valid = num_valid + A_eq_triplet = np.stack([ + np.arange(num_valid), + np.arange(Q)[valid], + np.ones(num_valid), + ], axis=1) # [num_valid, 3] + A_eq = sp.sparse.coo_matrix((A_eq_triplet[:, 2], (A_eq_triplet[:, 0], A_eq_triplet[:, 1])), shape=[num_valid, Q+2*C]) # [num_valid, Q+C] + b_eq = np.where(weights[valid] > 0.5, 1, 0) # [num_valid] + + # Return the result + quads_weight = res.x[:Q] + conn_min_weight = res.x[Q:Q+C] + conn_max_weight = res.x[Q+C:Q+2*C] + return quads_weight, conn_min_weight, conn_max_weight + + +def sovle_quad_qp( + face2edge: np.ndarray, + edge2face: np.ndarray, + quad2adj: np.ndarray, + quads_distortion: np.ndarray, + quads_smoothness: np.ndarray, + quads_valid: np.ndarray, + ): + """ + Solve the quad mesh from the candidate quad faces. + + Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + + Returns: + weights (np.ndarray): [Q] weight of each valid quad face + """ + T = face2edge.shape[0] + E = edge2face.shape[0] + Q = quads_distortion.shape[0] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + + # Construct the quadratic programming problem + C_smoothness_triplet = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1)[quad2adj != -1], + quad2adj[quad2adj != -1], + 5 * quads_smoothness[quad2adj != -1], + ], axis=-1) # [C, 3] + # C_smoothness_triplet = np.concatenate([ + # C_smoothness_triplet, + # np.stack([np.arange(Q), np.arange(Q), 20*np.ones(Q)], axis=1), + # ], axis=0) # [C+Q, 3] + C_smoothness = sp.sparse.coo_matrix((C_smoothness_triplet[:, 2], (C_smoothness_triplet[:, 0], C_smoothness_triplet[:, 1])), shape=[Q, Q]) # [Q, Q] + C_smoothness = C_smoothness.tocsc() + C_dist = quads_distortion - 20 # [Q] + + A_eq = sp.sparse.coo_matrix((np.zeros(Q), (np.zeros(Q), np.arange(Q))), shape=[1, Q]) # [1, Q]\ + A_eq = A_eq.tocsc() + b_eq = np.array([0]) + + A_ub_triplet = np.concatenate([ + np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3] + ], axis=0) # [3T, 3] + A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3] + A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T, Q]) # [T, Q] + A_ub = A_ub.tocsc() + b_ub = np.ones(T) + + lb = np.zeros(Q) + ub = np.ones(Q) + + import piqp + solver = piqp.SparseSolver() + solver.settings.verbose = True + solver.settings.compute_timings = True + solver.setup(C_smoothness, C_dist, A_eq, b_eq, A_ub, b_ub, lb, ub) + + status = solver.solve() + + # x = cp.Variable(Q) + # prob = cp.Problem( + # cp.Minimize(cp.quad_form(x, C_smoothness) + C_dist.T @ x), + # [ + # A_ub @ x <= b_ub, + # x >= 0, x <= 1, + # ] + # ) + + # # Solve the quadratic programming problem + # prob.solve(solver=cp.PIQP, verbose=True) + + # Return the result + weights = solver.result.x + return weights + + +def tri_to_quad( + vertices: np.ndarray, + faces: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Convert a triangle mesh to a quad mesh. + NOTE: The input mesh must be a manifold mesh. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [Q, 4] quad face indices + """ + raise NotImplementedError + + +if __name__ == '__main__': + import os + import sys + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) + import utils3d + import numpy as np + import cv2 + from vis import vis_edge_color + + file = 'miku' + + vertices, faces = utils3d.io.read_ply(f'test/assets/{file}.ply') + edges, edge2face, face2edge, face2face = calc_relations(faces) + quad_cands, quad2edge, quad2adj, quad_valid = calc_quad_candidates(edges, face2edge, edge2face) + distortion = calc_quad_distortion(vertices, quad_cands) + direction = calc_quad_direction(vertices, quad_cands) + smoothness = calc_quad_smoothness(quad2edge, quad2adj, direction) + boundary_edges = edges[edge2face[:, 1] == -1] + quads_weight, conn_min_weight, conn_max_weight = sovle_quad(face2edge, edge2face, quad2adj, distortion, smoothness, quad_valid) + quads = quad_cands[quads_weight > 0.5] + print('Mesh statistics') + print(f' #V = {vertices.shape[0]}') + print(f' #F = {faces.shape[0]}') + print(f' #E = {edges.shape[0]}') + print(f' #B = {boundary_edges.shape[0]}') + print(f' #Q_cand = {quad_cands.shape[0]}') + print(f' #Q = {quads.shape[0]}') + + utils3d.io.write_ply(f'test/assets/{file}_boundary_edges.ply', vertices=vertices, edges=boundary_edges) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates.ply', vertices=vertices, faces=quads) + + edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8) + distortion = (distortion - distortion.min()) / (distortion.max() - distortion.min()) + distortion = (distortion * 255).astype(np.uint8) + edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap(distortion, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_distortion.ply', **vis_edge_color(vertices, edges, edge_colors)) + + edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8) + edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap((quads_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_weights.ply', **vis_edge_color(vertices, edges, edge_colors)) + utils3d.io.write_ply(f'test/assets/{file}_quad.ply', vertices=vertices, faces=quads) + + quad_centers = vertices[quad_cands].mean(axis=1) + conns = np.stack([ + np.arange(quad_cands.shape[0])[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj != -1] # [C, 2] + conns, conns_idx = np.unique(np.sort(conns, axis=-1), axis=0, return_index=True) # [C, 2], [C] + smoothness = smoothness[quad2adj != -1][conns_idx] # [C] + conns_color = cv2.cvtColor(cv2.applyColorMap((smoothness * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_conn_smoothness.ply', **vis_edge_color(quad_centers, conns, conns_color)) + conns_color = cv2.cvtColor(cv2.applyColorMap((conn_min_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_conn_min.ply', **vis_edge_color(quad_centers, conns, conns_color)) + conns_color = cv2.cvtColor(cv2.applyColorMap((conn_max_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_conn_max.ply', **vis_edge_color(quad_centers, conns, conns_color)) + + \ No newline at end of file diff --git a/submodules/MoGe/utils3d/numpy/rasterization.py b/submodules/MoGe/utils3d/numpy/rasterization.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0f0db55d87f37f108a778dac29ae6320418f3a --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/rasterization.py @@ -0,0 +1,469 @@ +import os +from typing import * + +import numpy as np +import moderngl + +from . import transforms, utils, mesh + + +__all__ = [ + 'RastContext', + 'rasterize_triangle_faces', + 'rasterize_edges', + 'texture', + 'test_rasterization', + 'warp_image_by_depth', +] + + +def map_np_dtype(dtype) -> str: + if dtype == int: + return 'i4' + elif dtype == np.uint8: + return 'u1' + elif dtype == np.uint32: + return 'u2' + elif dtype == np.float16: + return 'f2' + elif dtype == np.float32: + return 'f4' + + +def one_value(dtype): + if dtype == 'u1': + return 255 + elif dtype == 'u2': + return 65535 + else: + return 1 + + +class RastContext: + def __init__(self, *args, **kwargs): + """ + Create a moderngl context. + + Args: + See moderngl.create_context + """ + if len(args) == 1 and isinstance(args[0], moderngl.Context): + self.mgl_ctx = args[0] + else: + self.mgl_ctx = moderngl.create_context(*args, **kwargs) + self.__prog_src = {} + self.__prog = {} + + def program_vertex_attribute(self, n: int) -> moderngl.Program: + assert n in [1, 2, 3, 4], 'vertex attribute only supports channels 1, 2, 3, 4' + + if 'vertex_attribute_vsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.vsh'), 'r') as f: + self.__prog_src['vertex_attribute_vsh'] = f.read() + if 'vertex_attribute_fsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.fsh'), 'r') as f: + self.__prog_src['vertex_attribute_fsh'] = f.read() + + if f'vertex_attribute_{n}' not in self.__prog: + vsh = self.__prog_src['vertex_attribute_vsh'].replace('vecN', f'vec{n}') + fsh = self.__prog_src['vertex_attribute_fsh'].replace('vecN', f'vec{n}') + self.__prog[f'vertex_attribute_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh) + + return self.__prog[f'vertex_attribute_{n}'] + + def program_texture(self, n: int) -> moderngl.Program: + assert n in [1, 2, 3, 4], 'texture only supports channels 1, 2, 3, 4' + + if 'texture_vsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.vsh'), 'r') as f: + self.__prog_src['texture_vsh'] = f.read() + if 'texture_fsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.fsh'), 'r') as f: + self.__prog_src['texture_fsh'] = f.read() + + if f'texture_{n}' not in self.__prog: + vsh = self.__prog_src['texture_vsh'].replace('vecN', f'vec{n}') + fsh = self.__prog_src['texture_fsh'].replace('vecN', f'vec{n}') + self.__prog[f'texture_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh) + self.__prog[f'texture_{n}']['tex'] = 0 + self.__prog[f'texture_{n}']['uv'] = 1 + + return self.__prog[f'texture_{n}'] + + +def rasterize_triangle_faces( + ctx: RastContext, + vertices: np.ndarray, + faces: np.ndarray, + attr: np.ndarray, + width: int, + height: int, + transform: np.ndarray = None, + cull_backface: bool = True, + return_depth: bool = False, + image: np.ndarray = None, + depth: np.ndarray = None +) -> Tuple[np.ndarray, np.ndarray]: + """ + Rasterize vertex attribute. + + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection transformation matrix. + cull_backface (bool): whether to cull backface + image: (np.ndarray): [H, W, C] background image + depth: (np.ndarray): [H, W] background depth + + Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. + """ + assert vertices.ndim == 2 and vertices.shape[1] == 3 + assert faces.ndim == 2 and faces.shape[1] == 3, f"Faces should be a 2D array with shape (T, 3), but got {faces.shape}" + assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}' + assert vertices.shape[0] == attr.shape[0] + assert vertices.dtype == np.float32 + assert faces.dtype == np.uint32 or faces.dtype == np.int32 + assert attr.dtype == np.float32, "Attribute should be float32" + assert transform is None or transform.shape == (4, 4), f"Transform should be a 4x4 matrix, but got {transform.shape}" + assert transform is None or transform.dtype == np.float32, f"Transform should be float32, but got {transform.dtype}" + if image is not None: + assert image.ndim == 3 and image.shape == (height, width, attr.shape[1]), f"Image should be a 3D array with shape (H, W, {attr.shape[1]}), but got {image.shape}" + assert image.dtype == np.float32, f"Image should be float32, but got {image.dtype}" + if depth is not None: + assert depth.ndim == 2 and depth.shape == (height, width), f"Depth should be a 2D array with shape (H, W), but got {depth.shape}" + assert depth.dtype == np.float32, f"Depth should be float32, but got {depth.dtype}" + + C = attr.shape[1] + prog = ctx.program_vertex_attribute(C) + + transform = np.eye(4, np.float32) if transform is None else transform + + # Create buffers + ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(faces, dtype='i4')) + vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4')) + vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4')) + vao = ctx.mgl_ctx.vertex_array( + prog, + [ + (vbo_vertices, '3f', 'i_position'), + (vbo_attr, f'{C}f', 'i_attr'), + ], + ibo, + mode=moderngl.TRIANGLES, + ) + + # Create framebuffer + image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None) + depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None) + fbo = ctx.mgl_ctx.framebuffer( + color_attachments=[image_tex], + depth_attachment=depth_tex, + ) + + # Render + prog['u_mvp'].write(transform.transpose().copy().astype('f4')) + fbo.use() + fbo.viewport = (0, 0, width, height) + ctx.mgl_ctx.depth_func = '<' + if depth is None: + ctx.mgl_ctx.clear(depth=1.0) + ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST) + if cull_backface: + ctx.mgl_ctx.enable(ctx.mgl_ctx.CULL_FACE) + else: + ctx.mgl_ctx.disable(ctx.mgl_ctx.CULL_FACE) + vao.render() + ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST) + + # Read + image = np.zeros((height, width, C), dtype='f4') + image_tex.read_into(image) + image = image[::-1, :, :] + if return_depth: + depth = np.zeros((height, width), dtype='f4') + depth_tex.read_into(depth) + depth = depth[::-1, :] + else: + depth = None + + # Release + vao.release() + ibo.release() + vbo_vertices.release() + vbo_attr.release() + fbo.release() + image_tex.release() + depth_tex.release() + + return image, depth + + +def rasterize_edges( + ctx: RastContext, + vertices: np.ndarray, + edges: np.ndarray, + attr: np.ndarray, + width: int, + height: int, + transform: np.ndarray = None, + line_width: float = 1.0, + return_depth: bool = False, + image: np.ndarray = None, + depth: np.ndarray = None +) -> Tuple[np.ndarray, ...]: + """ + Rasterize vertex attribute. + + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection matrix + line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms. + cull_backface (bool): whether to cull backface + + Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. + """ + assert vertices.ndim == 2 and vertices.shape[1] == 3 + assert edges.ndim == 2 and edges.shape[1] == 2, f"Edges should be a 2D array with shape (T, 2), but got {edges.shape}" + assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}' + assert vertices.shape[0] == attr.shape[0] + assert vertices.dtype == np.float32 + assert edges.dtype == np.uint32 or edges.dtype == np.int32 + assert attr.dtype == np.float32, "Attribute should be float32" + + C = attr.shape[1] + prog = ctx.program_vertex_attribute(C) + + transform = transform if transform is not None else np.eye(4, np.float32) + + # Create buffers + ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(edges, dtype='i4')) + vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4')) + vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4')) + vao = ctx.mgl_ctx.vertex_array( + prog, + [ + (vbo_vertices, '3f', 'i_position'), + (vbo_attr, f'{C}f', 'i_attr'), + ], + ibo, + mode=moderngl.LINES, + ) + + # Create framebuffer + image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None) + depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None) + fbo = ctx.mgl_ctx.framebuffer( + color_attachments=[image_tex], + depth_attachment=depth_tex, + ) + + # Render + prog['u_mvp'].write(transform.transpose().copy().astype('f4')) + fbo.use() + fbo.viewport = (0, 0, width, height) + if depth is None: + ctx.mgl_ctx.clear(depth=1.0) + ctx.mgl_ctx.depth_func = '<' + ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST) + ctx.mgl_ctx.line_width = line_width + vao.render() + ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST) + + # Read + image = np.zeros((height, width, C), dtype='f4') + image_tex.read_into(image) + image = image[::-1, :, :] + if return_depth: + depth = np.zeros((height, width), dtype='f4') + depth_tex.read_into(depth) + depth = depth[::-1, :] + else: + depth = None + + # Release + vao.release() + ibo.release() + vbo_vertices.release() + vbo_attr.release() + fbo.release() + image_tex.release() + depth_tex.release() + + return image, depth + + +def texture( + ctx: RastContext, + uv: np.ndarray, + texture: np.ndarray, + interpolation: str= 'linear', + wrap: str = 'clamp' +) -> np.ndarray: + """ + Given an UV image, texturing from the texture map + """ + assert len(texture.shape) == 3 and 1 <= texture.shape[2] <= 4 + assert uv.shape[2] == 2 + height, width = uv.shape[:2] + texture_dtype = map_np_dtype(texture.dtype) + + # Create VAO + screen_quad_vbo = ctx.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4')) + screen_quad_ibo = ctx.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32)) + screen_quad_vao = ctx.mgl_ctx.vertex_array(ctx.program_texture(texture.shape[2]), [(screen_quad_vbo, '2f4', 'in_vert')], index_buffer=screen_quad_ibo, index_element_size=4) + + # Create texture, set filter and bind. TODO: min mag filter, mipmap + texture_tex = ctx.mgl_ctx.texture((texture.shape[1], texture.shape[0]), texture.shape[2], dtype=texture_dtype, data=np.ascontiguousarray(texture)) + if interpolation == 'linear': + texture_tex.filter = (moderngl.LINEAR, moderngl.LINEAR) + elif interpolation == 'nearest': + texture_tex.filter = (moderngl.NEAREST, moderngl.NEAREST) + texture_tex.use(location=0) + texture_uv = ctx.mgl_ctx.texture((width, height), 2, dtype='f4', data=np.ascontiguousarray(uv.astype('f4', copy=False))) + texture_uv.filter = (moderngl.NEAREST, moderngl.NEAREST) + texture_uv.use(location=1) + + # Create render buffer and frame buffer + rb = ctx.mgl_ctx.renderbuffer((uv.shape[1], uv.shape[0]), texture.shape[2], dtype=texture_dtype) + fbo = ctx.mgl_ctx.framebuffer(color_attachments=[rb]) + + # Render + fbo.use() + fbo.viewport = (0, 0, width, height) + ctx.mgl_ctx.disable(ctx.mgl_ctx.BLEND) + screen_quad_vao.render() + + # Read buffer + image_buffer = np.frombuffer(fbo.read(components=texture.shape[2], attachment=0, dtype=texture_dtype), dtype=texture_dtype).reshape((height, width, texture.shape[2])) + + # Release + texture_tex.release() + rb.release() + fbo.release() + + return image_buffer + + +def warp_image_by_depth( + ctx: RastContext, + src_depth: np.ndarray, + src_image: np.ndarray = None, + width: int = None, + height: int = None, + *, + extrinsics_src: np.ndarray = None, + extrinsics_tgt: np.ndarray = None, + intrinsics_src: np.ndarray = None, + intrinsics_tgt: np.ndarray = None, + near: float = 0.1, + far: float = 100.0, + cull_backface: bool = True, + ssaa: int = 1, + return_depth: bool = False, +) -> Tuple[np.ndarray, ...]: + """ + Warp image by depth map. + + Args: + ctx (RastContext): rasterizer context + src_depth (np.ndarray): [H, W] + src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates). + width (int, optional): width of the output image. None to use depth map width. Defaults to None. + height (int, optional): height of the output image. None to use depth map height. Defaults to None. + extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. Defaults to None (identity). + extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity). + intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt). + intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src). + cull_backface (bool, optional): whether to cull backface. Defaults to True. + ssaa (int, optional): super sampling anti-aliasing. Defaults to 1. + + Returns: + tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None). + tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. + """ + assert src_depth.ndim == 2 + + if width is None: + width = src_depth.shape[1] + if height is None: + height = src_depth.shape[0] + if src_image is not None: + assert src_image.shape[-2:] == src_depth.shape[-2:], f'Shape of source image {src_image.shape} does not match shape of source depth {src_depth.shape}' + + # set up default camera parameters + extrinsics_src = np.eye(4) if extrinsics_src is None else extrinsics_src + extrinsics_tgt = np.eye(4) if extrinsics_tgt is None else extrinsics_tgt + intrinsics_src = intrinsics_tgt if intrinsics_src is None else intrinsics_src + intrinsics_tgt = intrinsics_src if intrinsics_tgt is None else intrinsics_tgt + + assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters." + + # check shapes + assert extrinsics_src.shape == (4, 4) and extrinsics_tgt.shape == (4, 4) + assert intrinsics_src.shape == (3, 3) and intrinsics_tgt.shape == (3, 3) + + # convert to view and perspective matrices + view_tgt = transforms.extrinsics_to_view(extrinsics_tgt) + perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far) + + # unproject depth map + uv, faces = utils.image_mesh(*src_depth.shape[-2:]) + pts = transforms.unproject_cv(uv, src_depth.reshape(-1), extrinsics_src, intrinsics_src) + faces = mesh.triangulate(faces, vertices=pts) + + # rasterize attributes + if src_image is not None: + attr = src_image.reshape(-1, src_image.shape[-1]) + else: + attr = uv + + tgt_image, tgt_depth = rasterize_triangle_faces( + ctx, + pts, + faces, + attr, + width * ssaa, + height * ssaa, + transform=perspective_tgt @ view_tgt, + cull_backface=cull_backface, + return_depth=return_depth, + ) + + if ssaa > 1: + tgt_image = tgt_image.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3)) + tgt_depth = tgt_depth.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3)) if return_depth else None + + return tgt_image, tgt_depth + +def test_rasterization(ctx: RastContext): + """ + Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file. + """ + vertices, faces = utils.cube(tri=True) + attr = np.random.rand(len(vertices), 3).astype(np.float32) + perspective = transforms.perspective(np.deg2rad(60), 1, 0.01, 100) + view = transforms.view_look_at(np.array([2, 2, 2]), np.array([0, 0, 0]), np.array([0, 1, 0])) + image, depth = rasterize_triangle_faces( + ctx, + vertices, + faces, + attr, + 512, 512, + transform=(perspective @ view).astype(np.float32), + cull_backface=False, + return_depth=True, + ) + import cv2 + cv2.imwrite('CHECKME.png', cv2.cvtColor((image.clip(0, 1) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + \ No newline at end of file diff --git a/submodules/MoGe/utils3d/numpy/shaders/texture.fsh b/submodules/MoGe/utils3d/numpy/shaders/texture.fsh new file mode 100644 index 0000000000000000000000000000000000000000..c8be72f94cbf38fb0b2a9609e8db4d50ac7753d6 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/shaders/texture.fsh @@ -0,0 +1,11 @@ +#version 330 + +uniform sampler2D tex; +uniform sampler2D uv; + +in vec2 scr_coord; +out vecN tex_color; + +void main() { + tex_color = vecN(texture(tex, texture(uv, scr_coord).xy)); +} \ No newline at end of file diff --git a/submodules/MoGe/utils3d/numpy/shaders/texture.vsh b/submodules/MoGe/utils3d/numpy/shaders/texture.vsh new file mode 100644 index 0000000000000000000000000000000000000000..f96c6b14a8931fbcd5f4ca22ea917b9c8f80f195 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/shaders/texture.vsh @@ -0,0 +1,9 @@ + #version 330 core + +in vec2 in_vert; +out vec2 scr_coord; + +void main() { + scr_coord = in_vert * 0.5 + 0.5; + gl_Position = vec4(in_vert, 0., 1.); +} \ No newline at end of file diff --git a/submodules/MoGe/utils3d/numpy/shaders/vertex_attribute.fsh b/submodules/MoGe/utils3d/numpy/shaders/vertex_attribute.fsh new file mode 100644 index 0000000000000000000000000000000000000000..54409764c5600ee190db89313b07dd91b940d6eb --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/shaders/vertex_attribute.fsh @@ -0,0 +1,9 @@ +#version 330 + +in vecN v_attr; + +out vecN f_attr; + +void main() { + f_attr = v_attr; +} diff --git a/submodules/MoGe/utils3d/numpy/shaders/vertex_attribute.vsh b/submodules/MoGe/utils3d/numpy/shaders/vertex_attribute.vsh new file mode 100644 index 0000000000000000000000000000000000000000..7c94f91aaabfd714a47a194b93f8e53bf63577f5 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/shaders/vertex_attribute.vsh @@ -0,0 +1,13 @@ +#version 330 + +uniform mat4 u_mvp; + +in vec3 i_position; +in vecN i_attr; + +out vecN v_attr; + +void main() { + gl_Position = u_mvp * vec4(i_position, 1.0); + v_attr = i_attr; +} diff --git a/submodules/MoGe/utils3d/numpy/spline.py b/submodules/MoGe/utils3d/numpy/spline.py new file mode 100644 index 0000000000000000000000000000000000000000..03c664136bc3734215d37669a3446c248dffe097 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/spline.py @@ -0,0 +1,82 @@ +from typing import * + +import numpy as np + + +__all__ = ['linear_spline_interpolate'] + + +def linear_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (n, d): the values of data points. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `y`: np.ndarray, shape (..., m, d): the interpolated values. + """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + y = u * x[suc] + (1 - u) * x[prev] + + return y + + + +def _solve_tridiagonal(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: + n = b.shape[-1] + cc = np.zeros_like(b) + dd = np.zeros_like(b) + cc[..., 0] = c[..., 0] / b[..., 0] + dd[..., 0] = d[..., 0] / b[..., 0] + for i in range(1, n): + cc[..., i] = c[..., i] / (b[..., i] - a[..., i - 1] * cc[..., i - 1]) + dd[..., i] = (d[..., i] - a[..., i - 1] * dd[..., i - 1]) / (b[..., i] - a[..., i - 1] * cc[..., i - 1]) + x = np.zeros_like(b) + x[..., -1] = dd[..., -1] + for i in range(n - 2, -1, -1): + x[..., i] = dd[..., i] - cc[..., i] * x[..., i + 1] + return x + + +def cubic_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, v0: np.ndarray = None, vn: np.ndarray = None) -> np.ndarray: + """ + Cubic spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (..., n,): the x-coordinates of the data points. + - `t`: np.ndarray, shape (n,): the knot vector. NOTE: t must be sorted in ascending order. + - `s`: np.ndarray, shape (..., m,): the y-coordinates of the data points. + - `v0`: np.ndarray, shape (...,): the value of the derivative at the first knot, as the boundary condition. If None, it is set to zero. + - `vn`: np.ndarray, shape (...,): the value of the derivative at the last knot, as the boundary condition. If None, it is set to zero. + + ### Returns: + - `y`: np.ndarray, shape (..., m): the interpolated values. + """ + h = t[..., 1:] - t[..., :-1] + mu = h[..., :-1] / (h[..., :-1] + h[..., 1:]) + la = 1 - mu + d = (x[..., 1:] - x[..., :-1]) / h + d = 6 * (d[..., 1:] - d[..., :-1]) / (t[..., 2:] - t[..., :-2]) + + mu = np.concatenate([mu, np.ones_like(mu[..., :1])], axis=-1) + la = np.concatenate([np.ones_like(la[..., :1]), la], axis=-1) + d = np.concatenate([(((x[..., 1] - x[..., 0]) / h[0] - v0) / h[0])[..., None], d, ((vn - (x[..., -1] - x[..., -2]) / h[-1]) / h[-1])[..., None]], axis=-1) + + M = _solve_tridiagonal(mu, np.full_like(d, fill_value=2), la, d) + + i = np.searchsorted(t, s, side='left') + diff --git a/submodules/MoGe/utils3d/numpy/transforms.py b/submodules/MoGe/utils3d/numpy/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..2e2418540b4a1d152a5e072900079f59f474a139 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/transforms.py @@ -0,0 +1,1104 @@ +import numpy as np +from typing import * +from numbers import Number +from ._helpers import batched +from .._helpers import no_warnings + + +__all__ = [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'fov_to_focal', + 'focal_to_fov', + 'intrinsics_to_fov', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'perspective_to_near_far', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'unproject_cv', + 'unproject_gl', + 'project_cv', + 'project_gl', + 'quaternion_to_matrix', + 'axis_angle_to_matrix', + 'matrix_to_quaternion', + 'extrinsics_to_essential', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'ray_intersection', + 'se3_matrix', + 'slerp_quaternion', + 'slerp_vector', + 'lerp', + 'lerp_se3_matrix', + 'piecewise_lerp', + 'piecewise_lerp_se3_matrix', + 'apply_transform' +] + + +@batched(0,0,0,0) +def perspective( + fov_y: Union[float, np.ndarray], + aspect: Union[float, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Get OpenGL perspective matrix + + Args: + fov_y (float | np.ndarray): field of view in y axis + aspect (float | np.ndarray): aspect ratio + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + N = fov_y.shape[0] + ret = np.zeros((N, 4, 4), dtype=fov_y.dtype) + ret[:, 0, 0] = 1. / (np.tan(fov_y / 2) * aspect) + ret[:, 1, 1] = 1. / (np.tan(fov_y / 2)) + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +def perspective_from_fov( + fov: Union[float, np.ndarray], + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Get OpenGL perspective matrix from field of view in largest dimension + + Args: + fov (float | np.ndarray): field of view in largest dimension + width (int | np.ndarray): image width + height (int | np.ndarray): image height + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + fov_y = 2 * np.arctan(np.tan(fov / 2) * height / np.maximum(width, height)) + aspect = width / height + return perspective(fov_y, aspect, near, far) + + +def perspective_from_fov_xy( + fov_x: Union[float, np.ndarray], + fov_y: Union[float, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Get OpenGL perspective matrix from field of view in x and y axis + + Args: + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + aspect = np.tan(fov_x / 2) / np.tan(fov_y / 2) + return perspective(fov_y, aspect, near, far) + + +def intrinsics_from_focal_center( + fx: Union[float, np.ndarray], + fy: Union[float, np.ndarray], + cx: Union[float, np.ndarray], + cy: Union[float, np.ndarray], + dtype: Optional[np.dtype] = np.float32 +) -> np.ndarray: + """ + Get OpenCV intrinsics matrix + + Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + """ + if any(isinstance(x, np.ndarray) for x in (fx, fy, cx, cy)): + dtype = np.result_type(fx, fy, cx, cy) + fx, fy, cx, cy = np.broadcast_arrays(fx, fy, cx, cy) + ret = np.zeros((*fx.shape, 3, 3), dtype=dtype) + ret[..., 0, 0] = fx + ret[..., 1, 1] = fy + ret[..., 0, 2] = cx + ret[..., 1, 2] = cy + ret[..., 2, 2] = 1. + return ret + + +def intrinsics_from_fov( + fov_max: Union[float, np.ndarray] = None, + fov_min: Union[float, np.ndarray] = None, + fov_x: Union[float, np.ndarray] = None, + fov_y: Union[float, np.ndarray] = None, + width: Union[int, np.ndarray] = None, + height: Union[int, np.ndarray] = None, +) -> np.ndarray: + """ + Get normalized OpenCV intrinsics matrix from given field of view. + You can provide either fov_max, fov_min, fov_x or fov_y + + Args: + width (int | np.ndarray): image width + height (int | np.ndarray): image height + fov_max (float | np.ndarray): field of view in largest dimension + fov_min (float | np.ndarray): field of view in smallest dimension + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + + Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + """ + if fov_max is not None: + fx = np.maximum(width, height) / width / (2 * np.tan(fov_max / 2)) + fy = np.maximum(width, height) / height / (2 * np.tan(fov_max / 2)) + elif fov_min is not None: + fx = np.minimum(width, height) / width / (2 * np.tan(fov_min / 2)) + fy = np.minimum(width, height) / height / (2 * np.tan(fov_min / 2)) + elif fov_x is not None and fov_y is not None: + fx = 1 / (2 * np.tan(fov_x / 2)) + fy = 1 / (2 * np.tan(fov_y / 2)) + elif fov_x is not None: + fx = 1 / (2 * np.tan(fov_x / 2)) + fy = fx * width / height + elif fov_y is not None: + fy = 1 / (2 * np.tan(fov_y / 2)) + fx = fy * height / width + cx = 0.5 + cy = 0.5 + ret = intrinsics_from_focal_center(fx, fy, cx, cy) + return ret + + +def focal_to_fov(focal: np.ndarray): + return 2 * np.arctan(0.5 / focal) + + +def fov_to_focal(fov: np.ndarray): + return 0.5 / np.tan(fov / 2) + + +def intrinsics_to_fov(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + fov_x = focal_to_fov(intrinsics[..., 0, 0]) + fov_y = focal_to_fov(intrinsics[..., 1, 1]) + return fov_x, fov_y + + +@batched(1,1,1) +def view_look_at( + eye: np.ndarray, + look_at: np.ndarray, + up: np.ndarray + ) -> np.ndarray: + """ + Get OpenGL view matrix looking at something + + Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (np.ndarray): [..., 4, 4], view matrix + """ + z = eye - look_at + x = np.cross(up, z) + y = np.cross(z, x) + # x = np.cross(y, z) + x = x / np.linalg.norm(x, axis=-1, keepdims=True) + y = y / np.linalg.norm(y, axis=-1, keepdims=True) + z = z / np.linalg.norm(z, axis=-1, keepdims=True) + R = np.stack([x, y, z], axis=-2) + t = -np.matmul(R, eye[..., None]) + return np.concatenate([ + np.concatenate([R, t], axis=-1), + np.array([[[0., 0., 0., 1.]]]).repeat(eye.shape[0], axis=0) + ], axis=-2) + + +@batched(1,1,1) +def extrinsics_look_at( + eye: np.ndarray, + look_at: np.ndarray, + up: np.ndarray +) -> np.ndarray: + """ + Get OpenCV extrinsics matrix looking at something + + Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (np.ndarray): [..., 4, 4], extrinsics matrix + """ + z = look_at - eye + x = np.cross(-up, z) + y = np.cross(z, x) + # x = np.cross(y, z) + x = x / np.linalg.norm(x, axis=-1, keepdims=True) + y = y / np.linalg.norm(y, axis=-1, keepdims=True) + z = z / np.linalg.norm(z, axis=-1, keepdims=True) + R = np.stack([x, y, z], axis=-2) + t = -np.matmul(R, eye[..., None]) + return np.concatenate([ + np.concatenate([R, t], axis=-1), + np.array([[[0., 0., 0., 1.]]], dtype=eye.dtype).repeat(eye.shape[0], axis=0) + ], axis=-2) + + +def perspective_to_intrinsics( + perspective: np.ndarray +) -> np.ndarray: + """ + OpenGL perspective matrix to OpenCV intrinsics + + Args: + perspective (np.ndarray): [..., 4, 4] OpenGL perspective matrix + + Returns: + (np.ndarray): shape [..., 3, 3] OpenCV intrinsics + """ + ret = np.array([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype) \ + @ perspective[..., [0, 1, 3], :3] \ + @ np.diag(np.array([1, -1, -1], dtype=perspective.dtype)) + return ret + + +def perspective_to_near_far(perspective: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Get near and far planes from OpenGL perspective matrix + + Args: + """ + a, b = perspective[..., 2, 2], perspective[..., 2, 3] + near, far = b / (a - 1), b / (a + 1) + return near, far + + +@batched(2,0,0) +def intrinsics_to_perspective( + intrinsics: np.ndarray, + near: Union[float, np.ndarray], + far: Union[float, np.ndarray], +) -> np.ndarray: + """ + OpenCV intrinsics to OpenGL perspective matrix + NOTE: not work for tile-shifting intrinsics currently + + Args: + intrinsics (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + Returns: + (np.ndarray): [..., 4, 4] OpenGL perspective matrix + """ + N = intrinsics.shape[0] + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1] + cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2] + ret = np.zeros((N, 4, 4), dtype=intrinsics.dtype) + ret[:, 0, 0] = 2 * fx + ret[:, 1, 1] = 2 * fy + ret[:, 0, 2] = -2 * cx + 1 + ret[:, 1, 2] = 2 * cy - 1 + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +@batched(2) +def extrinsics_to_view( + extrinsics: np.ndarray + ) -> np.ndarray: + """ + OpenCV camera extrinsics to OpenGL view matrix + + Args: + extrinsics (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix + + Returns: + (np.ndarray): [..., 4, 4] OpenGL view matrix + """ + return extrinsics * np.array([1, -1, -1, 1], dtype=extrinsics.dtype)[:, None] + + +@batched(2) +def view_to_extrinsics( + view: np.ndarray + ) -> np.ndarray: + """ + OpenGL view matrix to OpenCV camera extrinsics + + Args: + view (np.ndarray): [..., 4, 4] OpenGL view matrix + + Returns: + (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix + """ + return view * np.array([1, -1, -1, 1], dtype=view.dtype)[:, None] + + +@batched(2, 0, 0, None) +def normalize_intrinsics( + intrinsics: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + integer_pixel_centers: bool = True +) -> np.ndarray: + """ + Normalize intrinsics from pixel cooridnates to uv coordinates + + Args: + intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to normalize + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + integer_pixel_centers (bool): whether the integer pixel coordinates are at the center of the pixel. If False, the integer coordinates are at the left-top corner of the pixel. + + Returns: + (np.ndarray): [..., 3, 3] normalized camera intrinsics(s) + """ + zeros = np.zeros_like(width) + ones = np.ones_like(width) + if integer_pixel_centers: + transform = np.stack([ + 1 / width, zeros, 0.5 / width, + zeros, 1 / height, 0.5 / height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + else: + transform = np.stack([ + 1 / width, zeros, zeros, + zeros, 1 / height, zeros, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + return transform @ intrinsics + + +@batched(2,0,0,0,0,0,0) +def crop_intrinsics( + intrinsics: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + left: Union[int, np.ndarray], + top: Union[int, np.ndarray], + crop_width: Union[int, np.ndarray], + crop_height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + + Args: + intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to crop + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + left (int | np.ndarray): [...] left crop boundary + top (int | np.ndarray): [...] top crop boundary + crop_width (int | np.ndarray): [...] crop width + crop_height (int | np.ndarray): [...] crop height + + Returns: + (np.ndarray): [..., 3, 3] cropped camera intrinsics(s) + """ + zeros = np.zeros_like(width) + ones = np.ones_like(width) + transform = np.stack([ + width / crop_width, zeros, -left / crop_width, + zeros, height / crop_height, -top / crop_height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + return transform @ intrinsics + + +@batched(1,0,0) +def pixel_to_uv( + pixel: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + + Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + if not np.issubdtype(pixel.dtype, np.floating): + pixel = pixel.astype(np.float32) + dtype = pixel.dtype + uv = (pixel + np.array(0.5, dtype=dtype)) / np.stack([width, height], axis=-1) + return uv + + +@batched(1,0,0) +def uv_to_pixel( + uv: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + + Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + pixel = uv * np.stack([width, height], axis=-1).astype(uv.dtype) - 0.5 + return pixel + + +@batched(1,0,0) +def pixel_to_ndc( + pixel: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + + Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1) + """ + if not np.issubdtype(pixel.dtype, np.floating): + pixel = pixel.astype(np.float32) + dtype = pixel.dtype + ndc = (pixel + np.array(0.5, dtype=dtype)) / (np.stack([width, height], dim=-1) * np.array([2, -2], dtype=dtype)) \ + + np.array([-1, 1], dtype=dtype) + return ndc + + +@batched(0,0,0) +def project_depth( + depth: np.ndarray, + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Project linear depth to depth value in screen space + + Args: + depth (np.ndarray): [...] depth value + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + + Returns: + (np.ndarray): [..., 1] depth value in screen space, value ranging in [0, 1] + """ + return (far - near * far / depth) / (far - near) + + +@batched(0,0,0) +def depth_buffer_to_linear( + depth_buffer: np.ndarray, + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + OpenGL depth buffer to linear depth + + Args: + depth_buffer (np.ndarray): [...] depth value + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + + Returns: + (np.ndarray): [..., 1] linear depth + """ + return near * far / (far - (far - near) * depth_buffer) + + +@batched(2,2,2,2) +def project_gl( + points: np.ndarray, + model: np.ndarray = None, + view: np.ndarray = None, + perspective: np.ndarray = None + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Project 3D points to 2D following the OpenGL convention (except for row major matrice) + + Args: + points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + + Returns: + scr_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (np.ndarray): [..., N] linear depth + """ + assert perspective is not None, "perspective matrix is required" + if points.shape[-1] == 3: + points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + if model is not None: + points = points @ model.swapaxes(-1, -2) + if view is not None: + points = points @ view.swapaxes(-1, -2) + clip_coord = points @ perspective.swapaxes(-1, -2) + ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:] + scr_coord = ndc_coord * 0.5 + 0.5 + linear_depth = clip_coord[..., 3] + return scr_coord, linear_depth + + +@batched(2,2,2) +def project_cv( + points: np.ndarray, + extrinsics: np.ndarray = None, + intrinsics: np.ndarray = None + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Project 3D points to 2D following the OpenCV convention + + Args: + points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix + intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix + + Returns: + uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + linear_depth (np.ndarray): [..., N] linear depth + """ + assert intrinsics is not None, "intrinsics matrix is required" + if points.shape[-1] == 3: + points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + if extrinsics is not None: + points = points @ extrinsics.swapaxes(-1, -2) + points = points[..., :3] @ intrinsics.swapaxes(-1, -2) + with no_warnings(): + uv_coord = points[..., :2] / points[..., 2:] + linear_depth = points[..., 2] + return uv_coord, linear_depth + + +@batched(2,2,2,2) +def unproject_gl( + screen_coord: np.ndarray, + model: np.ndarray = None, + view: np.ndarray = None, + perspective: np.ndarray = None + ) -> np.ndarray: + """ + Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + + Args: + screen_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + + Returns: + points (np.ndarray): [..., N, 3] 3d points + """ + assert perspective is not None, "perspective matrix is required" + ndc_xy = screen_coord * 2 - 1 + clip_coord = np.concatenate([ndc_xy, np.ones_like(ndc_xy[..., :1])], axis=-1) + transform = perspective + if view is not None: + transform = transform @ view + if model is not None: + transform = transform @ model + transform = np.linalg.inv(transform) + points = clip_coord @ transform.swapaxes(-1, -2) + points = points[..., :3] / points[..., 3:] + return points + + +@batched(2,1,2,2) +def unproject_cv( + uv_coord: np.ndarray, + depth: np.ndarray = None, + extrinsics: np.ndarray = None, + intrinsics: np.ndarray = None +) -> np.ndarray: + """ + Unproject uv coordinates to 3D view space following the OpenCV convention + + Args: + uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (np.ndarray): [..., N] depth value + extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix + intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix + + Returns: + points (np.ndarray): [..., N, 3] 3d points + """ + assert intrinsics is not None, "intrinsics matrix is required" + points = np.concatenate([uv_coord, np.ones_like(uv_coord[..., :1])], axis=-1) + points = points @ np.linalg.inv(intrinsics).swapaxes(-1, -2) + if depth is not None: + points = points * depth[..., None] + if extrinsics is not None: + points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + points = (points @ np.linalg.inv(extrinsics).swapaxes(-1, -2))[..., :3] + return points + + +def quaternion_to_matrix(quaternion: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + + Args: + quaternion (np.ndarray): shape (..., 4), the quaternions to convert + + Returns: + np.ndarray: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + quaternion = quaternion / np.linalg.norm(quaternion, axis=-1, keepdims=True).clip(min=eps) + w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3] + zeros = np.zeros_like(w) + I = np.eye(3, dtype=quaternion.dtype) + xyz = quaternion[..., 1:] + A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(axis=-1)[..., None, None] + B = np.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros + ], axis=-1).reshape(*quaternion.shape[:-1], 3, 3) + rot_mat = I + 2 * (A + w[..., None, None] * B) + return rot_mat + + +def matrix_to_quaternion(rot_mat: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + + Args: + rot_mat (np.ndarray): shape (..., 3, 3), the rotation matrices to convert + + Returns: + np.ndarray: shape (..., 4), the quaternions corresponding to the given rotation matrices + """ + # Extract the diagonal and off-diagonal elements of the rotation matrix + m00, m01, m02, m10, m11, m12, m20, m21, m22 = [rot_mat[..., i, j] for i in range(3) for j in range(3)] + + diag = np.diagonal(rot_mat, axis1=-2, axis2=-1) + M = np.array([ + [1, 1, 1], + [1, -1, -1], + [-1, 1, -1], + [-1, -1, 1] + ], dtype=rot_mat.dtype) + wxyz = 0.5 * np.clip(1 + diag @ M.T, 0.0, None) ** 0.5 + max_idx = np.argmax(wxyz, axis=-1) + xw = np.sign(m21 - m12) + yw = np.sign(m02 - m20) + zw = np.sign(m10 - m01) + yz = np.sign(m21 + m12) + xz = np.sign(m02 + m20) + xy = np.sign(m01 + m10) + ones = np.ones_like(xw) + sign = np.where( + max_idx[..., None] == 0, + np.stack([ones, xw, yw, zw], axis=-1), + np.where( + max_idx[..., None] == 1, + np.stack([xw, ones, xy, xz], axis=-1), + np.where( + max_idx[..., None] == 2, + np.stack([yw, xy, ones, yz], axis=-1), + np.stack([zw, xz, yz, ones], axis=-1) + ) + ) + ) + quat = sign * wxyz + quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True).clip(min=eps) + return quat + + +def extrinsics_to_essential(extrinsics: np.ndarray): + """ + extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + + Args: + extrinsics (np.ndaray): [..., 4, 4] extrinsics matrix + + Returns: + (np.ndaray): [..., 3, 3] essential matrix + """ + assert extrinsics.shape[-2:] == (4, 4) + R = extrinsics[..., :3, :3] + t = extrinsics[..., :3, 3] + zeros = np.zeros_like(t[..., 0]) + t_x = np.stack([ + zeros, -t[..., 2], t[..., 1], + t[..., 2], zeros, -t[..., 0], + -t[..., 1], t[..., 0], zeros + ]).reshape(*t.shape[:-1], 3, 3) + return t_x @ R + + +def euler_axis_angle_rotation(axis: str, angle: np.ndarray) -> np.ndarray: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = np.cos(angle) + sin = np.sin(angle) + one = np.ones_like(angle) + zero = np.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return np.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: np.ndarray, convention: str = 'XYZ') -> np.ndarray: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as ndarray of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + + Returns: + Rotation matrices as ndarray of shape (..., 3, 3). + """ + if euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)]) + for c in convention + ] + return matrices[2] @ matrices[1] @ matrices[0] + + +def skew_symmetric(v: np.ndarray): + "Skew symmetric matrix from a 3D vector" + assert v.shape[-1] == 3, "v must be 3D" + x, y, z = v[..., 0], v[..., 1], v[..., 2] + zeros = np.zeros_like(x) + return np.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros, + ], axis=-1).reshape(*v.shape[:-1], 3, 3) + + +def rotation_matrix_from_vectors(v1: np.ndarray, v2: np.ndarray): + "Rotation matrix that rotates v1 to v2" + I = np.eye(3, dtype=v1.dtype) + v1 = v1 / np.linalg.norm(v1, axis=-1) + v2 = v2 / np.linalg.norm(v2, axis=-1) + v = np.cross(v1, v2, axis=-1) + c = np.sum(v1 * v2, axis=-1) + K = skew_symmetric(v) + R = I + K + (1 / (1 + c)).astype(v1.dtype)[None, None] * (K @ K) # Avoid numpy's default type casting for scalars + return R + + +def axis_angle_to_matrix(axis_angle: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + + Args: + axis_angle (np.ndarray): shape (..., 3), axis-angle vcetors + + Returns: + np.ndarray: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters + """ + batch_shape = axis_angle.shape[:-1] + dtype = axis_angle.dtype + + angle = np.linalg.norm(axis_angle, axis=-1, keepdims=True) + axis = axis_angle / (angle + eps) + + cos = np.cos(angle)[..., None, :] + sin = np.sin(angle)[..., None, :] + + rx, ry, rz = np.split(axis, 3, axis=-1) + zeros = np.zeros((*batch_shape, 1), dtype=dtype) + K = np.concatenate([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], axis=-1).reshape((*batch_shape, 3, 3)) + + ident = np.eye(3, dtype=dtype) + rot_mat = ident + sin * K + (1 - cos) * (K @ K) + return rot_mat + + +def ray_intersection(p1: np.ndarray, d1: np.ndarray, p2: np.ndarray, d2: np.ndarray): + """ + Compute the intersection/closest point of two D-dimensional rays + If the rays are intersecting, the closest point is the intersection point. + + Args: + p1 (np.ndarray): (..., D) origin of ray 1 + d1 (np.ndarray): (..., D) direction of ray 1 + p2 (np.ndarray): (..., D) origin of ray 2 + d2 (np.ndarray): (..., D) direction of ray 2 + + Returns: + (np.ndarray): (..., N) intersection point + """ + p1, d1, p2, d2 = np.broadcast_arrays(p1, d1, p2, d2) + dtype = p1.dtype + dim = p1.shape[-1] + d = np.stack([d1, d2], axis=-2) # (..., 2, D) + p = np.stack([p1, p2], axis=-2) # (..., 2, D) + A = np.concatenate([ + (np.eye(dim, dtype=dtype) * np.ones((*p.shape[:-2], 2, 1, 1))).reshape(*d.shape[:-2], 2 * dim, dim), # (..., 2 * D, D) + -(np.eye(2, dtype=dtype)[..., None] * d[..., None, :]).swapaxes(-2, -1).reshape(*d.shape[:-2], 2 * dim, 2) # (..., 2 * D, 2) + ], axis=-1) # (..., 2 * D, D + 2) + b = p.reshape(*p.shape[:-2], 2 * dim) # (..., 2 * D) + x = np.linalg.solve(A.swapaxes(-1, -2) @ A + 1e-12 * np.eye(dim + 2, dtype=dtype), (A.swapaxes(-1, -2) @ b[..., :, None])[..., 0]) + return x[..., :dim], (x[..., dim], x[..., dim + 1]) + + +def se3_matrix(R: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Convert rotation matrix and translation vector to 4x4 transformation matrix. + + Args: + R (np.ndarray): [..., 3, 3] rotation matrix + t (np.ndarray): [..., 3] translation vector + + Returns: + np.ndarray: [..., 4, 4] transformation matrix + """ + assert R.shape[:-2] == t.shape[:-1] + assert R.shape[-1] == 3 and R.shape[-2] == 3 + return np.concatenate([ + np.concatenate([R, t[..., None]], axis=-1), + np.concatenate([np.zeros_like(t), np.ones_like(t[..., :1])], axis=-1)[..., None, :] + ], axis=-2) + + +def slerp_quaternion(q1: np.ndarray, q2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two unit quaternions. + + Args: + q1 (np.ndarray): [..., d] unit vector 1 + q2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 3] interpolated unit vector + """ + q1 = q1 / np.linalg.norm(q1, axis=-1, keepdims=True) + q2 = q2 / np.linalg.norm(q2, axis=-1, keepdims=True) + dot = np.sum(q1 * q2, axis=-1, keepdims=True) + + dot = np.where(dot < 0, -dot, dot) # handle negative dot product + + dot = np.minimum(dot, 1.) + theta = np.arccos(dot) * t + + q_ortho = q2 - q1 * dot + q_ortho = q_ortho / np.maximum(np.linalg.norm(q_ortho, axis=-1, keepdims=True), 1e-12) + q = q1 * np.cos(theta) + q_ortho * np.sin(theta) + return q + + +def slerp_rotation_matrix(R1: np.ndarray, R2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two rotation matrices. + + Args: + R1 (np.ndarray): [..., 3, 3] rotation matrix 1 + R2 (np.ndarray): [..., 3, 3] rotation matrix 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 3, 3] interpolated rotation matrix + """ + quat1 = matrix_to_quaternion(R1) + quat2 = matrix_to_quaternion(R2) + quat = slerp_quaternion(quat1, quat2, t) + return quaternion_to_matrix(quat) + + +def slerp_vector(v1: np.ndarray, v2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two unit vectors. The vectors are assumed to be normalized. + + Args: + v1 (np.ndarray): [..., d] unit vector 1 + v2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., d] interpolated unit vector + """ + dot = np.sum(v1 * v2, axis=-1, keepdims=True) + + dot = np.minimum(dot, 1.) + theta = np.arccos(dot) * t + + v_ortho = v2 - v1 * dot + v_ortho = v_ortho / np.maximum(np.linalg.norm(v_ortho, axis=-1, keepdims=True), 1e-12) + v = v1 * np.cos(theta) + v_ortho * np.sin(theta) + return v + + +def lerp(x1: np.ndarray, x2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Linear interpolation between two vectors. + + Args: + x1 (np.ndarray): [..., d] vector 1 + x2 (np.ndarray): [..., d] vector 2 + t (np.ndarray): [...] interpolation parameter. [0, 1] for interpolation between x1 and x2, otherwise for extrapolation. + + Returns: + np.ndarray: [..., d] interpolated vector + """ + return x1 + np.asarray(t)[..., None] * (x2 - x1) + + +def lerp_se3_matrix(T1: np.ndarray, T2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Linear interpolation between two SE(3) matrices. + + Args: + T1 (np.ndarray): [..., 4, 4] SE(3) matrix 1 + T2 (np.ndarray): [..., 4, 4] SE(3) matrix 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 4, 4] interpolated SE(3) matrix + """ + R1 = T1[..., :3, :3] + R2 = T2[..., :3, :3] + trans1 = T1[..., :3, 3] + trans2 = T2[..., :3, 3] + R = slerp_rotation_matrix(R1, R2, t) + trans = lerp(trans1, trans2, t) + return se3_matrix(R, trans) + + +def piecewise_lerp(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (n, d): the values of data points. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `y`: np.ndarray, shape (..., m, d): the interpolated values. + """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + y = lerp(x[prev], x[suc], u) + + return y + + +def piecewise_lerp_se3_matrix(T: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation for SE(3) matrices. + + ### Parameters: + - `T`: np.ndarray, shape (n, 4, 4): the SE(3) matrices. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `T_interp`: np.ndarray, shape (..., m, 4, 4): the interpolated SE(3) matrices. + """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + T = lerp_se3_matrix(T[prev], T[suc], u) + + return T + + +def apply_transform(T: np.ndarray, x: np.ndarray) -> np.ndarray: + """ + Apply SE(3) transformation to a point or a set of points. + + ### Parameters: + - `T`: np.ndarray, shape (..., 4, 4): the SE(3) matrix. + - `x`: np.ndarray, shape (..., 3): the point or a set of points to be transformed. + + ### Returns: + - `x_transformed`: np.ndarray, shape (..., 3): the transformed point or a set of points. + """ + x = np.asarray(x) + assert x.shape[-1] == 3 + T = np.asarray(T) + assert T.shape[-2:] == (4, 4) + x_transformed = (T[..., :3, :3] @ x[..., :, None]) + T[..., :3, 3][..., None] + return x_transformed[..., 0] \ No newline at end of file diff --git a/submodules/MoGe/utils3d/numpy/utils.py b/submodules/MoGe/utils3d/numpy/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3fbfccef9eb8224f658307c425125095f4c2a7 --- /dev/null +++ b/submodules/MoGe/utils3d/numpy/utils.py @@ -0,0 +1,625 @@ +import numpy as np +from typing import * +from numbers import Number +import warnings +import functools + +from ._helpers import batched +from .._helpers import no_warnings +from . import transforms +from . import mesh + +__all__ = [ + 'sliding_window_1d', + 'sliding_window_nd', + 'sliding_window_2d', + 'max_pool_1d', + 'max_pool_2d', + 'max_pool_nd', + 'depth_edge', + 'normals_edge', + 'depth_aliasing', + 'interpolate', + 'image_scrcoord', + 'image_uv', + 'image_pixel_center', + 'image_pixel', + 'image_mesh', + 'image_mesh_from_depth', + 'points_to_normals', + 'points_to_normals', + 'chessboard', + 'cube', + 'icosahedron', + 'square', + 'camera_frustum', + 'to4x4' +] + + + +def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1): + """ + Return x view of the input array with x sliding window of the given kernel size and stride. + The sliding window is performed over the given axis, and the window dimension is append to the end of the output array's shape. + + Args: + x (np.ndarray): input array with shape (..., axis_size, ...) + kernel_size (int): size of the sliding window + stride (int): stride of the sliding window + axis (int): axis to perform sliding window over + + Returns: + a_sliding (np.ndarray): view of the input array with shape (..., n_windows, ..., kernel_size), where n_windows = (axis_size - kernel_size + 1) // stride + """ + assert x.shape[axis] >= window_size, f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})" + axis = axis % x.ndim + shape = (*x.shape[:axis], (x.shape[axis] - window_size + 1) // stride, *x.shape[axis + 1:], window_size) + strides = (*x.strides[:axis], stride * x.strides[axis], *x.strides[axis + 1:], x.strides[axis]) + x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return x_sliding + + +def sliding_window_nd(x: np.ndarray, window_size: Tuple[int,...], stride: Tuple[int,...], axis: Tuple[int,...]) -> np.ndarray: + axis = [axis[i] % x.ndim for i in range(len(axis))] + for i in range(len(axis)): + x = sliding_window_1d(x, window_size[i], stride[i], axis[i]) + return x + + +def sliding_window_2d(x: np.ndarray, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)) -> np.ndarray: + if isinstance(window_size, int): + window_size = (window_size, window_size) + if isinstance(stride, int): + stride = (stride, stride) + return sliding_window_nd(x, window_size, stride, axis) + + +def max_pool_1d(x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1): + axis = axis % x.ndim + if padding > 0: + fill_value = np.nan if x.dtype.kind == 'f' else np.iinfo(x.dtype).min + padding_arr = np.full((*x.shape[:axis], padding, *x.shape[axis + 1:]), fill_value=fill_value, dtype=x.dtype) + x = np.concatenate([padding_arr, x, padding_arr], axis=axis) + a_sliding = sliding_window_1d(x, kernel_size, stride, axis) + max_pool = np.nanmax(a_sliding, axis=-1) + return max_pool + + +def max_pool_nd(x: np.ndarray, kernel_size: Tuple[int,...], stride: Tuple[int,...], padding: Tuple[int,...], axis: Tuple[int,...]) -> np.ndarray: + for i in range(len(axis)): + x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i]) + return x + + +def max_pool_2d(x: np.ndarray, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)): + if isinstance(kernel_size, Number): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, Number): + stride = (stride, stride) + if isinstance(padding, Number): + padding = (padding, padding) + axis = tuple(axis) + return max_pool_nd(x, kernel_size, stride, padding, axis) + +@no_warnings(category=RuntimeWarning) +def depth_edge(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray: + """ + Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth. + + Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + if mask is None: + diff = (max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) + else: + diff = (max_pool_2d(np.where(mask, depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) + max_pool_2d(np.where(mask, -depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2)) + + edge = np.zeros_like(depth, dtype=bool) + if atol is not None: + edge |= diff > atol + + if rtol is not None: + edge |= diff / depth > rtol + return edge + + +def depth_aliasing(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray: + """ + Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. + Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + if mask is None: + diff_max = max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth + else: + diff_max = max_pool_2d(np.where(mask, depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = max_pool_2d(np.where(mask, -depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth + diff = np.minimum(diff_max, diff_min) + + edge = np.zeros_like(depth, dtype=bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= diff / depth > rtol + return edge + +@no_warnings(category=RuntimeWarning) +def normals_edge(normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray: + """ + Compute the edge mask from normal map. + + Args: + normal (np.ndarray): shape (..., height, width, 3), normal map + tol (float): tolerance in degrees + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + assert normals.ndim >= 3 and normals.shape[-1] == 3, "normal should be of shape (..., height, width, 3)" + normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12) + + padding = kernel_size // 2 + normals_window = sliding_window_2d( + np.pad(normals, (*([(0, 0)] * (normals.ndim - 3)), (padding, padding), (padding, padding), (0, 0)), mode='edge'), + window_size=kernel_size, + stride=1, + axis=(-3, -2) + ) + if mask is None: + angle_diff = np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)).max(axis=(-2, -1)) + else: + mask_window = sliding_window_2d( + np.pad(mask, (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)), mode='edge'), + window_size=kernel_size, + stride=1, + axis=(-3, -2) + ) + angle_diff = np.where(mask_window, np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)), 0).max(axis=(-2, -1)) + + angle_diff = max_pool_2d(angle_diff, kernel_size, stride=1, padding=kernel_size // 2) + edge = angle_diff > np.deg2rad(tol) + return edge + + +@no_warnings(category=RuntimeWarning) +def points_to_normals(point: np.ndarray, mask: np.ndarray = None) -> np.ndarray: + """ + Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + point (np.ndarray): shape (height, width, 3), point map + Returns: + normal (np.ndarray): shape (height, width, 3), normal map. + """ + height, width = point.shape[-3:-1] + has_mask = mask is not None + + if mask is None: + mask = np.ones_like(point[..., 0], dtype=bool) + mask_pad = np.zeros((height + 2, width + 2), dtype=bool) + mask_pad[1:-1, 1:-1] = mask + mask = mask_pad + + pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype) + pts[1:-1, 1:-1, :] = point + up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :] + left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :] + down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :] + right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :] + normal = np.stack([ + np.cross(up, left, axis=-1), + np.cross(left, down, axis=-1), + np.cross(down, right, axis=-1), + np.cross(right, up, axis=-1), + ]) + normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) + valid = np.stack([ + mask[:-2, 1:-1] & mask[1:-1, :-2], + mask[1:-1, :-2] & mask[2:, 1:-1], + mask[2:, 1:-1] & mask[1:-1, 2:], + mask[1:-1, 2:] & mask[:-2, 1:-1], + ]) & mask[None, 1:-1, 1:-1] + normal = (normal * valid[..., None]).sum(axis=0) + normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) + + if has_mask: + normal_mask = valid.any(axis=0) + normal = np.where(normal_mask[..., None], normal, 0) + return normal, normal_mask + else: + return normal + + +def depth_to_normals(depth: np.ndarray, intrinsics: np.ndarray, mask: np.ndarray = None) -> np.ndarray: + """ + Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + depth (np.ndarray): shape (height, width), linear depth map + intrinsics (np.ndarray): shape (3, 3), intrinsics matrix + Returns: + normal (np.ndarray): shape (height, width, 3), normal map. + """ + has_mask = mask is not None + + height, width = depth.shape[-2:] + if mask is None: + mask = np.ones_like(depth, dtype=bool) + + uv = image_uv(width=width, height=height, dtype=np.float32) + pts = transforms.unproject_cv(uv, depth, intrinsics=intrinsics, extrinsics=None) + + return points_to_normals(pts, mask) + +def interpolate(bary: np.ndarray, tri_id: np.ndarray, attr: np.ndarray, faces: np.ndarray) -> np.ndarray: + """Interpolate with given barycentric coordinates and triangle indices + + Args: + bary (np.ndarray): shape (..., 3), barycentric coordinates + tri_id (np.ndarray): int array of shape (...), triangle indices + attr (np.ndarray): shape (N, M), vertices attributes + faces (np.ndarray): int array of shape (T, 3), face vertex indices + + Returns: + np.ndarray: shape (..., M) interpolated result + """ + faces_ = np.concatenate([np.zeros((1, 3), dtype=faces.dtype), faces + 1], axis=0) + attr_ = np.concatenate([np.zeros((1, attr.shape[1]), dtype=attr.dtype), attr], axis=0) + return np.sum(bary[..., None] * attr_[faces_[tri_id + 1]], axis=-2) + + +def image_scrcoord( + width: int, + height: int, +) -> np.ndarray: + """ + Get OpenGL's screen space coordinates, ranging in [0, 1]. + [0, 0] is the bottom-left corner of the image. + + Args: + width (int): image width + height (int): image height + + Returns: + (np.ndarray): shape (height, width, 2) + """ + x, y = np.meshgrid( + np.linspace(0.5 / width, 1 - 0.5 / width, width, dtype=np.float32), + np.linspace(1 - 0.5 / height, 0.5 / height, height, dtype=np.float32), + indexing='xy' + ) + return np.stack([x, y], axis=2) + + +def image_uv( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.float32 +) -> np.ndarray: + """ + Get image space UV grid, ranging in [0, 1]. + + >>> image_uv(10, 10): + [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, dtype=dtype) + v = np.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + + +def image_pixel_center( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.float32 +) -> np.ndarray: + """ + Get image pixel center coordinates, ranging in [0, width] and [0, height]. + `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + + >>> image_pixel_center(10, 10): + [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... + [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype) + v = np.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + +def image_pixel( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.int32 +) -> np.ndarray: + """ + Get image pixel coordinates grid, ranging in [0, width - 1] and [0, height - 1]. + `image[i, j]` has pixel center coordinates `(j, i)`. + + >>> image_pixel_center(10, 10): + [[[0, 0], [1, 0], ..., [9, 0]], + [[0, 1.5], [1, 1], ..., [9, 1]], + ... ... ... + [[0, 9.5], [1, 9], ..., [9, 9 ]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.arange(left, right, dtype=dtype) + v = np.arange(top, bottom, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + + +def image_mesh( + *image_attrs: np.ndarray, + mask: np.ndarray = None, + tri: bool = False, + return_indices: bool = False +) -> Tuple[np.ndarray, ...]: + """ + Get a mesh regarding image pixel uv coordinates as vertices and image grid as faces. + + Args: + *image_attrs (np.ndarray): image attributes in shape (height, width, [channels]) + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + + Returns: + faces (np.ndarray): faces connecting neighboring pixels. shape (T, 4) if tri is False, else (T, 3) + *vertex_attrs (np.ndarray): vertex attributes in corresponding order with input image_attrs + indices (np.ndarray, optional): indices of vertices in the original mesh + """ + assert (len(image_attrs) > 0) or (mask is not None), "At least one of image_attrs or mask should be provided" + height, width = next(image_attrs).shape[:2] if mask is None else mask.shape + assert all(img.shape[:2] == (height, width) for img in image_attrs), "All image_attrs should have the same shape" + + row_faces = np.stack([np.arange(0, width - 1, dtype=np.int32), np.arange(width, 2 * width - 1, dtype=np.int32), np.arange(1 + width, 2 * width, dtype=np.int32), np.arange(1, width, dtype=np.int32)], axis=1) + faces = (np.arange(0, (height - 1) * width, width, dtype=np.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4)) + if mask is None: + if tri: + faces = mesh.triangulate(faces) + ret = [faces, *(img.reshape(-1, *img.shape[2:]) for img in image_attrs)] + if return_indices: + ret.append(np.arange(height * width, dtype=np.int32)) + return tuple(ret) + else: + quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel() + faces = faces[quad_mask] + if tri: + faces = mesh.triangulate(faces) + return mesh.remove_unreferenced_vertices( + faces, + *(x.reshape(-1, *x.shape[2:]) for x in image_attrs), + return_indices=return_indices + ) + +def image_mesh_from_depth( + depth: np.ndarray, + extrinsics: np.ndarray = None, + intrinsics: np.ndarray = None, + *vertice_attrs: np.ndarray, + atol: float = None, + rtol: float = None, + remove_by_depth: bool = False, + return_uv: bool = False, + return_indices: bool = False +) -> Tuple[np.ndarray, ...]: + """ + Get x triangle mesh by lifting depth map to 3D. + + Args: + depth (np.ndarray): [H, W] depth map + extrinsics (np.ndarray, optional): [4, 4] extrinsics matrix. Defaults to None. + intrinsics (np.ndarray, optional): [3, 3] intrinsics matrix. Defaults to None. + *vertice_attrs (np.ndarray): [H, W, C] vertex attributes. Defaults to None. + atol (float, optional): absolute tolerance. Defaults to None. + rtol (float, optional): relative tolerance. Defaults to None. + triangles with vertices having depth difference larger than atol + rtol * depth will be marked. + remove_by_depth (bool, optional): whether to remove triangles with large depth difference. Defaults to True. + return_uv (bool, optional): whether to return uv coordinates. Defaults to False. + return_indices (bool, optional): whether to return indices of vertices in the original mesh. Defaults to False. + + Returns: + vertices (np.ndarray): [N, 3] vertices + faces (np.ndarray): [T, 3] faces + *vertice_attrs (np.ndarray): [N, C] vertex attributes + image_uv (np.ndarray, optional): [N, 2] uv coordinates + ref_indices (np.ndarray, optional): [N] indices of vertices in the original mesh + """ + height, width = depth.shape + image_uv, image_face = image_mesh(height, width) + depth = depth.reshape(-1) + pts = transforms.unproject_cv(image_uv, depth, extrinsics, intrinsics) + image_face = mesh.triangulate(image_face, vertices=pts) + ref_indices = None + ret = [] + if atol is not None or rtol is not None: + atol = 0 if atol is None else atol + rtol = 0 if rtol is None else rtol + mean = depth[image_face].mean(axis=1) + diff = np.max(np.abs(depth[image_face] - depth[image_face[:, [1, 2, 0]]]), axis=1) + mask = (diff <= atol + rtol * mean) + image_face_ = image_face[mask] + image_face_, ref_indices = mesh.remove_unreferenced_vertices(image_face_, return_indices=True) + + remove = remove_by_depth and ref_indices is not None + if remove: + pts = pts[ref_indices] + image_face = image_face_ + ret += [pts, image_face] + for attr in vertice_attrs: + ret.append(attr.reshape(-1, attr.shape[-1]) if not remove else attr.reshape(-1, attr.shape[-1])[ref_indices]) + if return_uv: + ret.append(image_uv if not remove else image_uv[ref_indices]) + if return_indices and ref_indices is not None: + ret.append(ref_indices) + return tuple(ret) + + +def chessboard(width: int, height: int, grid_size: int, color_a: np.ndarray, color_b: np.ndarray) -> np.ndarray: + """get x chessboard image + + Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (np.ndarray): color of the grid at the top-left corner + color_b (np.ndarray): color in complementary grid cells + + Returns: + image (np.ndarray): shape (height, width, channels), chessboard image + """ + x = np.arange(width) // grid_size + y = np.arange(height) // grid_size + mask = (x[None, :] + y[:, None]) % 2 + image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b + return image + + +def square(tri: bool = False) -> Tuple[np.ndarray, np.ndarray]: + """ + Get a square mesh of area 1 centered at origin in the xy-plane. + + ### Returns + vertices (np.ndarray): shape (4, 3) + faces (np.ndarray): shape (1, 4) + """ + vertices = np.array([ + [-0.5, 0.5, 0], [0.5, 0.5, 0], [0.5, -0.5, 0], [-0.5, -0.5, 0] # v0-v1-v2-v3 + ], dtype=np.float32) + if tri: + faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32) + else: + faces = np.array([[0, 1, 2, 3]], dtype=np.int32) + return vertices, faces + + +def cube(tri: bool = False) -> Tuple[np.ndarray, np.ndarray]: + """ + Get x cube mesh of size 1 centered at origin. + + ### Parameters + tri (bool, optional): return triangulated mesh. Defaults to False, which returns quad mesh. + + ### Returns + vertices (np.ndarray): shape (8, 3) + faces (np.ndarray): shape (12, 3) + """ + vertices = np.array([ + [-0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [-0.5, -0.5, 0.5], # v0-v1-v2-v3 + [-0.5, 0.5, -0.5], [0.5, 0.5, -0.5], [0.5, -0.5, -0.5], [-0.5, -0.5, -0.5] # v4-v5-v6-v7 + ], dtype=np.float32).reshape((-1, 3)) + + faces = np.array([ + [0, 1, 2, 3], # v0-v1-v2-v3 (front) + [4, 5, 1, 0], # v4-v5-v1-v0 (top) + [3, 2, 6, 7], # v3-v2-v6-v7 (bottom) + [5, 4, 7, 6], # v5-v4-v7-v6 (back) + [1, 5, 6, 2], # v1-v5-v6-v2 (right) + [4, 0, 3, 7] # v4-v0-v3-v7 (left) + ], dtype=np.int32) + + if tri: + faces = mesh.triangulate(faces, vertices=vertices) + + return vertices, faces + + +def camera_frustum(extrinsics: np.ndarray, intrinsics: np.ndarray, depth: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get x triangle mesh of camera frustum. + """ + assert extrinsics.shape == (4, 4) and intrinsics.shape == (3, 3) + vertices = transforms.unproject_cv( + np.array([[0, 0], [0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32), + np.array([0] + [depth] * 4, dtype=np.float32), + extrinsics, + intrinsics + ).astype(np.float32) + edges = np.array([ + [0, 1], [0, 2], [0, 3], [0, 4], + [1, 2], [2, 3], [3, 4], [4, 1] + ], dtype=np.int32) + faces = np.array([ + [0, 1, 2], + [0, 2, 3], + [0, 3, 4], + [0, 4, 1], + [1, 2, 3], + [1, 3, 4] + ], dtype=np.int32) + return vertices, edges, faces + + +def icosahedron(): + A = (1 + 5 ** 0.5) / 2 + vertices = np.array([ + [0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A], + [1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0], + [A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1] + ], dtype=np.float32) + faces = np.array([ + [0, 1, 8], [0, 8, 4], [0, 4, 5], [0, 5, 10], [0, 10, 1], + [3, 2, 9], [3, 9, 6], [3, 6, 7], [3, 7, 11], [3, 11, 2], + [1, 6, 8], [8, 9, 4], [4, 2, 5], [5, 11, 10], [10, 7, 1], + [2, 4, 9], [9, 8, 6], [6, 1, 7], [7, 10, 11], [11, 5, 2] + ], dtype=np.int32) + return vertices, faces diff --git a/submodules/MoGe/utils3d/torch/__init__.py b/submodules/MoGe/utils3d/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bffcf41b5e906c25f8c8f01fb0a1b557151103c1 --- /dev/null +++ b/submodules/MoGe/utils3d/torch/__init__.py @@ -0,0 +1,139 @@ +import importlib +import itertools +import torch +from typing import TYPE_CHECKING + +__modules_all__ = { + 'mesh': [ + 'triangulate', + 'compute_face_normal', + 'compute_face_angles', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'compute_edges', + 'compute_connected_components', + 'compute_edge_connected_components', + 'compute_boundarys', + 'compute_dual_graph', + 'remove_unreferenced_vertices', + 'remove_corrupted_faces', + 'remove_isolated_pieces', + 'merge_duplicate_vertices', + 'subdivide_mesh_simple', + 'compute_face_tbn', + 'compute_vertex_tbn', + 'laplacian', + 'laplacian_smooth_mesh', + 'taubin_smooth_mesh', + 'laplacian_hc_smooth_mesh', + ], + 'nerf': [ + 'get_rays', + 'get_image_rays', + 'get_mipnerf_cones', + 'volume_rendering', + 'bin_sample', + 'importance_sample', + 'nerf_render_rays', + 'mipnerf_render_rays', + 'nerf_render_view', + 'mipnerf_render_view', + 'InstantNGP', + ], + 'utils': [ + 'sliding_window_1d', + 'sliding_window_2d', + 'sliding_window_nd', + 'image_uv', + 'image_pixel_center', + 'image_mesh', + 'chessboard', + 'depth_edge', + 'depth_aliasing', + 'image_mesh_from_depth', + 'point_to_normal', + 'depth_to_normal', + 'masked_min', + 'masked_max', + 'bounding_rect' + ], + 'transforms': [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'intrinsics_from_fov_xy', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'project_gl', + 'project_cv', + 'unproject_gl', + 'unproject_cv', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'matrix_to_euler_angles', + 'matrix_to_quaternion', + 'quaternion_to_matrix', + 'matrix_to_axis_angle', + 'axis_angle_to_matrix', + 'axis_angle_to_quaternion', + 'quaternion_to_axis_angle', + 'slerp', + 'interpolate_extrinsics', + 'interpolate_view', + 'extrinsics_to_essential', + 'to4x4', + 'rotation_matrix_2d', + 'rotate_2d', + 'translate_2d', + 'scale_2d', + 'apply_2d', + ], + 'rasterization': [ + 'RastContext', + 'rasterize_triangle_faces', + 'warp_image_by_depth', + 'warp_image_by_forward_flow', + ], +} + + +__all__ = list(itertools.chain(*__modules_all__.values())) + +def __getattr__(name): + try: + return globals()[name] + except KeyError: + pass + + try: + module_name = next(m for m in __modules_all__ if name in __modules_all__[m]) + except StopIteration: + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + module = importlib.import_module(f'.{module_name}', __name__) + for key in __modules_all__[module_name]: + globals()[key] = getattr(module, key) + + return globals()[name] + + +if TYPE_CHECKING: + from .transforms import * + from .mesh import * + from .utils import * + from .nerf import * + from .rasterization import * \ No newline at end of file diff --git a/submodules/MoGe/utils3d/torch/_helpers.py b/submodules/MoGe/utils3d/torch/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..442e2cb6358ba4f105d2664f9b9f44b6ec6561ca --- /dev/null +++ b/submodules/MoGe/utils3d/torch/_helpers.py @@ -0,0 +1,103 @@ +# decorator +import torch +from numbers import Number +import inspect +from functools import wraps +from .._helpers import suppress_traceback + + +def get_device(args, kwargs): + device = None + for arg in (list(args) + list(kwargs.values())): + if isinstance(arg, torch.Tensor): + if device is None: + device = arg.device + elif device != arg.device: + raise ValueError("All tensors must be on the same device.") + return device + + +def get_args_order(func, args, kwargs): + """ + Get the order of the arguments of a function. + """ + names = inspect.getfullargspec(func).args + names_idx = {name: i for i, name in enumerate(names)} + args_order = [] + kwargs_order = {} + for name, arg in kwargs.items(): + if name in names: + kwargs_order[name] = names_idx[name] + names.remove(name) + for i, arg in enumerate(args): + if i < len(names): + args_order.append(names_idx[names[i]]) + return args_order, kwargs_order + + +def broadcast_args(args, kwargs, args_dim, kwargs_dim): + spatial = [] + for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + arg_spatial = arg.shape[:arg.ndim-arg_dim] + if len(arg_spatial) > len(spatial): + spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial + for j in range(len(arg_spatial)): + if spatial[-j] < arg_spatial[-j]: + if spatial[-j] == 1: + spatial[-j] = arg_spatial[-j] + else: + raise ValueError("Cannot broadcast arguments.") + for i, arg in enumerate(args): + if isinstance(arg, torch.Tensor) and args_dim[i] is not None: + args[i] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]]) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + return args, kwargs, spatial + +@suppress_traceback +def batched(*dims): + """ + Decorator that allows a function to be called with batched arguments. + """ + def decorator(func): + @wraps(func) + def wrapper(*args, device=torch.device('cpu'), **kwargs): + args = list(args) + # get arguments dimensions + args_order, kwargs_order = get_args_order(func, args, kwargs) + args_dim = [dims[i] for i in args_order] + kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} + # convert to torch tensor + device = get_device(args, kwargs) or device + for i, arg in enumerate(args): + if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: + args[i] = torch.tensor(arg, device=device) + for key, arg in kwargs.items(): + if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None: + kwargs[key] = torch.tensor(arg, device=device) + # broadcast arguments + args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) + for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]]) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + # call function + results = func(*args, **kwargs) + type_results = type(results) + results = list(results) if isinstance(results, (tuple, list)) else [results] + # restore spatial dimensions + for i, result in enumerate(results): + results[i] = result.reshape([*spatial, *result.shape[1:]]) + if type_results == tuple: + results = tuple(results) + elif type_results == list: + results = list(results) + else: + results = results[0] + return results + return wrapper + return decorator \ No newline at end of file diff --git a/submodules/MoGe/utils3d/torch/mesh.py b/submodules/MoGe/utils3d/torch/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..5b874d163e5edad3ef871a276b4edccf2e593265 --- /dev/null +++ b/submodules/MoGe/utils3d/torch/mesh.py @@ -0,0 +1,688 @@ +import torch +import torch.nn.functional as F +from typing import * +from ._helpers import batched + + +__all__ = [ + 'triangulate', + 'compute_face_normal', + 'compute_face_angles', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'compute_edges', + 'compute_connected_components', + 'compute_edge_connected_components', + 'compute_boundarys', + 'compute_dual_graph', + 'remove_unreferenced_vertices', + 'remove_corrupted_faces', + 'remove_isolated_pieces', + 'merge_duplicate_vertices', + 'subdivide_mesh_simple', + 'compute_face_tbn', + 'compute_vertex_tbn', + 'laplacian', + 'laplacian_smooth_mesh', + 'taubin_smooth_mesh', + 'laplacian_hc_smooth_mesh', +] + + +def _group( + values: torch.Tensor, + required_group_size: Optional[int] = None, + return_values: bool = False +) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]: + """ + Group values into groups with identical values. + + Args: + values (torch.Tensor): [N] values to group + required_group_size (int, optional): required group size. Defaults to None. + return_values (bool, optional): return values of groups. Defaults to False. + + Returns: + group (Union[List[torch.Tensor], torch.Tensor]): list of groups or group indices. It will be a list of groups if required_group_size is None, otherwise a tensor of group indices. + group_values (Optional[torch.Tensor]): values of groups. Only returned if return_values is True. + """ + sorted_values, indices = torch.sort(values) + nondupe = torch.cat([torch.tensor([True], dtype=torch.bool, device=values.device), sorted_values[1:] != sorted_values[:-1]]) + nondupe_indices = torch.cumsum(nondupe, dim=0) - 1 + counts = torch.bincount(nondupe_indices) + if required_group_size is None: + groups = torch.split(indices, counts.tolist()) + if return_values: + group_values = sorted_values[nondupe] + return groups, group_values + else: + return groups + else: + counts = counts[nondupe_indices] + groups = indices[counts == required_group_size].reshape(-1, required_group_size) + if return_values: + group_values = sorted_values[nondupe][counts[nondupe] == required_group_size] + return groups, group_values + else: + return groups + +def triangulate( + faces: torch.Tensor, + vertices: torch.Tensor = None, + backslash: bool = None +) -> torch.Tensor: + """ + Triangulate a polygonal mesh. + + Args: + faces (torch.Tensor): [..., L, P] polygonal faces + vertices (torch.Tensor, optional): [..., N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (torch.Tensor, optional): [..., L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + + + Returns: + (torch.Tensor): [L * (P - 2), 3] triangular faces + """ + if faces.shape[-1] == 3: + return faces + P = faces.shape[-1] + if vertices is not None: + assert faces.shape[-1] == 4, "now only support quad mesh" + if backslash is None: + faces_idx = faces.long() + backslash = torch.norm(vertices[faces_idx[..., 0]] - vertices[faces_idx[..., 2]], p=2, dim=-1) < \ + torch.norm(vertices[faces_idx[..., 1]] - vertices[faces_idx[..., 3]], p=2, dim=-1) + if backslash is None: + loop_indice = torch.stack([ + torch.zeros(P - 2, dtype=int), + torch.arange(1, P - 1, 1, dtype=int), + torch.arange(2, P, 1, dtype=int) + ], axis=1) + return faces[:, loop_indice].reshape(-1, 3) + else: + assert faces.shape[-1] == 4, "now only support quad mesh" + if isinstance(backslash, bool): + if backslash: + faces = faces[:, [0, 1, 2, 0, 2, 3]].reshape(-1, 3) + else: + faces = faces[:, [0, 1, 3, 3, 1, 2]].reshape(-1, 3) + else: + faces = torch.where( + backslash[:, None], + faces[:, [0, 1, 2, 0, 2, 3]], + faces[:, [0, 1, 3, 3, 1, 2]] + ).reshape(-1, 3) + return faces + + +@batched(2, None) +def compute_face_normal( + vertices: torch.Tensor, + faces: torch.Tensor +) -> torch.Tensor: + """ + Compute face normals of a triangular mesh + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [..., T, 3] triangular face indices + + Returns: + normals (torch.Tensor): [..., T, 3] face normals + """ + N = vertices.shape[0] + index = torch.arange(N)[:, None] + normal = torch.cross( + vertices[index, faces[..., 1].long()] - vertices[index, faces[..., 0].long()], + vertices[index, faces[..., 2].long()] - vertices[index, faces[..., 0].long()], + dim=-1 + ) + return F.normalize(normal, p=2, dim=-1) + + +@batched(2, None) +def compute_face_angles( + vertices: torch.Tensor, + faces: torch.Tensor +) -> torch.Tensor: + """ + Compute face angles of a triangular mesh + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + + Returns: + angles (torch.Tensor): [..., T, 3] face angles + """ + face_angles = [] + for i in range(3): + edge1 = torch.index_select(vertices, dim=-2, index=faces[:, (i + 1) % 3]) - torch.index_select(vertices, dim=-2, index=faces[:, i]) + edge2 = torch.index_select(vertices, dim=-2, index=faces[:, (i + 2) % 3]) - torch.index_select(vertices, dim=-2, index=faces[:, i]) + face_angle = torch.arccos(torch.sum(F.normalize(edge1, p=2, dim=-1) * F.normalize(edge2, p=2, dim=-1), dim=-1)) + face_angles.append(face_angle) + face_angles = torch.stack(face_angles, dim=-1) + return face_angles + + +@batched(2, None, 2) +def compute_vertex_normal( + vertices: torch.Tensor, + faces: torch.Tensor, + face_normal: torch.Tensor = None +) -> torch.Tensor: + """ + Compute vertex normals of a triangular mesh by averaging neightboring face normals + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (torch.Tensor): [..., N, 3] vertex normals + """ + N = vertices.shape[0] + assert faces.shape[-1] == 3, "Only support triangular mesh" + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_normal = face_normal[:, :, None, :].expand(-1, -1, 3, -1).flatten(-3, -2) + faces = faces.flatten() + vertex_normal = torch.index_put(torch.zeros_like(vertices), (torch.arange(N)[:, None], faces[None, :]), face_normal, accumulate=True) + vertex_normal = F.normalize(vertex_normal, p=2, dim=-1) + return vertex_normal + + +@batched(2, None, 2) +def compute_vertex_normal_weighted( + vertices: torch.Tensor, + faces: torch.Tensor, + face_normal: torch.Tensor = None +) -> torch.Tensor: + """ + Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals + according to the angles + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (torch.Tensor): [..., N, 3] vertex normals + """ + N = vertices.shape[0] + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_angle = compute_face_angles(vertices, faces) + face_normal = face_normal[:, :, None, :].expand(-1, -1, 3, -1) * face_angle[..., None] + vertex_normal = torch.index_put(torch.zeros_like(vertices), (torch.arange(N)[:, None], faces.view(N, -1)), face_normal.view(N, -1, 3), accumulate=True) + vertex_normal = F.normalize(vertex_normal, p=2, dim=-1) + return vertex_normal + + +def compute_edges( + faces: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute edges of a mesh. + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + + Returns: + edges (torch.Tensor): [E, 2] edge indices + face2edge (torch.Tensor): [T, 3] mapping from face to edge + counts (torch.Tensor): [E] degree of each edge + """ + T = faces.shape[0] + edges = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) # [3T, 2] + edges = torch.sort(edges, dim=1).values + edges, inv_map, counts = torch.unique(edges, return_inverse=True, return_counts=True, dim=0) + face2edge = inv_map.view(3, T).T + return edges, face2edge, counts + + +def compute_connected_components( + faces: torch.Tensor, + edges: torch.Tensor=None, + face2edge: torch.Tensor=None +) -> List[torch.Tensor]: + """ + Compute connected faces of a mesh. + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + edges (torch.Tensor, optional): [E, 2] edge indices. Defaults to None. + face2edge (torch.Tensor, optional): [T, 3] mapping from face to edge. Defaults to None. + NOTE: If edges and face2edge are not provided, they will be computed. + + Returns: + components (List[torch.Tensor]): list of connected faces + """ + T = faces.shape[0] + if edges is None or face2edge is None: + edges, face2edge, _ = compute_edges(faces) + E = edges.shape[0] + + labels = torch.arange(T, dtype=torch.int32, device=faces.device) + while True: + edge_labels = torch.scatter_reduce( + torch.zeros(E, dtype=torch.int32, device=faces.device), + 0, + face2edge.flatten().long(), + labels.view(-1, 1).expand(-1, 3).flatten(), + reduce='amin', + include_self=False + ) + new_labels = torch.min(edge_labels[face2edge], dim=-1).values + if torch.equal(labels, new_labels): + break + labels = new_labels + + components = _group(labels) + + return components + + +def compute_edge_connected_components( + edges: torch.Tensor, +) -> List[torch.Tensor]: + """ + Compute connected edges of a mesh. + + Args: + edges (torch.Tensor): [E, 2] edge indices + + Returns: + components (List[torch.Tensor]): list of connected edges + """ + E = edges.shape[0] + + # Re-index edges + verts, edges = torch.unique(edges.flatten(), return_inverse=True) + edges = edges.view(-1, 2) + V = verts.shape[0] + + labels = torch.arange(E, dtype=torch.int32, device=edges.device) + while True: + vertex_labels = torch.scatter_reduce( + torch.zeros(V, dtype=torch.int32, device=edges.device), + 0, + edges.flatten().long(), + labels.view(-1, 1).expand(-1, 2).flatten(), + reduce='amin', + include_self=False + ) + new_labels = torch.min(vertex_labels[edges], dim=-1).values + if torch.equal(labels, new_labels): + break + labels = new_labels + + components = _group(labels) + + return components + + +def compute_boundarys( + faces: torch.Tensor, + edges: torch.Tensor=None, + face2edge: torch.Tensor=None, + edge_degrees: torch.Tensor=None +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Compute boundary edges of a mesh. + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + edges (torch.Tensor): [E, 2] edge indices. + face2edge (torch.Tensor): [T, 3] mapping from face to edge. + edge_degrees (torch.Tensor): [E] degree of each edge. + + Returns: + boundary_edge_indices (List[torch.Tensor]): list of boundary edge indices + boundary_face_indices (List[torch.Tensor]): list of boundary face indices + """ + # Map each edge to boundary edge index + boundary_edges = edges[edge_degrees == 1] # [BE, 2] + boundary_edges_idx = torch.nonzero(edge_degrees == 1, as_tuple=False).flatten() # [BE] + E = edges.shape[0] # Edge count + BE = boundary_edges.shape[0] # Boundary edge count + map_to_boundary_edges = torch.full((E,), -1, dtype=torch.int32, device=faces.device) # [E] + map_to_boundary_edges[boundary_edges_idx] = torch.arange(BE, dtype=torch.int32, device=faces.device) + + # Re-index boundary vertices + boundary_vertices, boundary_edges = torch.unique(boundary_edges.flatten(), return_inverse=True) + boundary_edges = boundary_edges.view(-1, 2) + BV = boundary_vertices.shape[0] + + boundary_edge_labels = torch.arange(BE, dtype=torch.int32, device=faces.device) + while True: + boundary_vertex_labels = torch.scatter_reduce( + torch.zeros(BV, dtype=torch.int32, device=faces.device), + 0, + boundary_edges.flatten().long(), + boundary_edge_labels.view(-1, 1).expand(-1, 2).flatten(), + reduce='amin', + include_self=False + ) + new_boundary_edge_labels = torch.min(boundary_vertex_labels[boundary_edges], dim=-1).values + if torch.equal(boundary_edge_labels, new_boundary_edge_labels): + break + boundary_edge_labels = new_boundary_edge_labels + + labels = torch.unique(boundary_edge_labels) + boundary_edge_indices = [boundary_edges_idx[boundary_edge_labels == label] for label in labels] + edge_labels = torch.full((E,), -1, dtype=torch.int32, device=faces.device) + edge_labels[boundary_edges_idx] = boundary_edge_labels + boundary_face_indices = [torch.nonzero((edge_labels[face2edge] == label).any(dim=-1), as_tuple=False).flatten() for label in labels] + + return boundary_edge_indices, boundary_face_indices + + +def compute_dual_graph( + face2edge: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute dual graph of a mesh. + + Args: + face2edge (torch.Tensor): [T, 3] mapping from face to edge. + + Returns: + dual_edges (torch.Tensor): [DE, 2] face indices of dual edges + dual_edge2edge (torch.Tensor): [DE] mapping from dual edge to edge + """ + all_edge_indices = face2edge.flatten() # [3T] + dual_edges, dual_edge2edge = _group(all_edge_indices, required_group_size=2, return_values=True) + dual_edges = dual_edges // face2edge.shape[1] + return dual_edges, dual_edge2edge + + +def remove_unreferenced_vertices( + faces: torch.Tensor, + *vertice_attrs, + return_indices: bool = False +) -> Tuple[torch.Tensor, ...]: + """ + Remove unreferenced vertices of a mesh. + Unreferenced vertices are removed, and the face indices are updated accordingly. + + Args: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + + Returns: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + indices (torch.Tensor, optional): [N] indices of vertices that are kept. Defaults to None. + """ + P = faces.shape[-1] + fewer_indices, inv_map = torch.unique(faces, return_inverse=True) + faces = inv_map.to(torch.int32).reshape(-1, P) + ret = [faces] + for attr in vertice_attrs: + ret.append(attr[fewer_indices]) + if return_indices: + ret.append(fewer_indices) + return tuple(ret) + + +def remove_corrupted_faces( + faces: torch.Tensor +) -> torch.Tensor: + """ + Remove corrupted faces (faces with duplicated vertices) + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + + Returns: + torch.Tensor: [T_, 3] triangular face indices + """ + corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0]) + return faces[~corrupted] + + +def merge_duplicate_vertices( + vertices: torch.Tensor, + faces: torch.Tensor, + tol: float = 1e-6 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Merge duplicate vertices of a triangular mesh. + Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + + Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + + Returns: + vertices (torch.Tensor): [N_, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + """ + vertices_round = torch.round(vertices / tol) + uni, uni_inv = torch.unique(vertices_round, dim=0, return_inverse=True) + uni[uni_inv] = vertices + faces = uni_inv[faces] + return uni, faces + + +def remove_isolated_pieces( + vertices: torch.Tensor, + faces: torch.Tensor, + connected_components: List[torch.Tensor] = None, + thresh_num_faces: int = None, + thresh_radius: float = None, + thresh_boundary_ratio: float = None, + remove_unreferenced: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Remove isolated pieces of a mesh. + Isolated pieces are removed, and the face indices are updated accordingly. + If no face is left, will return the largest connected component. + + Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + connected_components (List[torch.Tensor], optional): connected components of the mesh. If None, it will be computed. Defaults to None. + thresh_num_faces (int, optional): threshold of number of faces for isolated pieces. Defaults to None. + thresh_radius (float, optional): threshold of radius for isolated pieces. Defaults to None. + remove_unreferenced (bool, optional): remove unreferenced vertices after removing isolated pieces. Defaults to True. + + Returns: + vertices (torch.Tensor): [N_, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + """ + if connected_components is None: + connected_components = compute_connected_components(faces) + connected_components = sorted(connected_components, key=lambda x: len(x), reverse=True) + if thresh_num_faces is not None: + removed = [] + for i in range(1, len(connected_components)): + if len(connected_components[i]) < thresh_num_faces: + removed.append(i) + for i in removed[::-1]: + connected_components.pop(i) + if thresh_radius is not None: + removed = [] + for i in range(1, len(connected_components)): + comp_vertices = vertices[faces[connected_components[i]].flatten().unique()] + comp_center = comp_vertices.mean(dim=0) + comp_radius = (comp_vertices - comp_center).norm(p=2, dim=-1).max() + if comp_radius < thresh_radius: + removed.append(i) + for i in removed[::-1]: + connected_components.pop(i) + if thresh_boundary_ratio is not None: + removed = [] + for i in range(1, len(connected_components)): + edges = torch.cat([faces[connected_components[i]][:, [0, 1]], faces[connected_components[i]][:, [1, 2]], faces[connected_components[i]][:, [2, 0]]], dim=0) + edges = torch.sort(edges, dim=1).values + edges, counts = torch.unique(edges, return_counts=True, dim=0) + num_boundary_edges = (counts == 1).sum().item() + num_faces = len(connected_components[i]) + if num_boundary_edges / num_faces > thresh_boundary_ratio: + removed.append(i) + for i in removed[::-1]: + connected_components.pop(i) + + # post-process + faces = torch.cat([faces[connected_components[i]] for i in range(len(connected_components))], dim=0) + if remove_unreferenced: + faces, vertices = remove_unreferenced_vertices(faces, vertices) + return vertices, faces + + +def subdivide_mesh_simple(vertices: torch.Tensor, faces: torch.Tensor, n: int = 1) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. + NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + + Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + + Returns: + vertices (torch.Tensor): [N_, 3] subdivided 3-dimensional vertices + faces (torch.Tensor): [4 * T, 3] subdivided triangular face indices + """ + for _ in range(n): + edges = torch.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) + edges = torch.sort(edges, dim=2) + uni_edges, uni_inv = torch.unique(edges, return_inverse=True, dim=0) + midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2 + + n_vertices = vertices.shape[0] + vertices = torch.cat([vertices, midpoints], dim=0) + faces = torch.cat([ + torch.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1), + torch.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1), + torch.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1), + torch.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1), + ], dim=0) + return vertices, faces + + +def compute_face_tbn(pos: torch.Tensor, faces_pos: torch.Tensor, uv: torch.Tensor, faces_uv: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: + """compute TBN matrix for each face + + Args: + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + + Returns: + torch.Tensor: (..., T, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthognal + """ + e01 = torch.index_select(pos, dim=-2, index=faces_pos[:, 1]) - torch.index_select(pos, dim=-2, index=faces_pos[:, 0]) + e02 = torch.index_select(pos, dim=-2, index=faces_pos[:, 2]) - torch.index_select(pos, dim=-2, index=faces_pos[:, 0]) + uv01 = torch.index_select(uv, dim=-2, index=faces_uv[:, 1]) - torch.index_select(uv, dim=-2, index=faces_uv[:, 0]) + uv02 = torch.index_select(uv, dim=-2, index=faces_uv[:, 2]) - torch.index_select(uv, dim=-2, index=faces_uv[:, 0]) + normal = torch.cross(e01, e02) + tangent_bitangent = torch.stack([e01, e02], dim=-1) @ torch.inverse(torch.stack([uv01, uv02], dim=-1)) + tbn = torch.cat([tangent_bitangent, normal.unsqueeze(-1)], dim=-1) + tbn = tbn / (torch.norm(tbn, p=2, dim=-2, keepdim=True) + eps) + return tbn + + +def compute_vertex_tbn(faces_topo: torch.Tensor, pos: torch.Tensor, faces_pos: torch.Tensor, uv: torch.Tensor, faces_uv: torch.Tensor) -> torch.Tensor: + """compute TBN matrix for each face + + Args: + faces_topo (torch.Tensor): (T, 3), face indice of topology + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + + Returns: + torch.Tensor: (..., V, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthognal + """ + n_vertices = faces_topo.max().item() + 1 + n_tri = faces_topo.shape[-2] + batch_shape = pos.shape[:-2] + face_tbn = compute_face_tbn(pos, faces_pos, uv, faces_uv) # (..., T, 3, 3) + face_tbn = face_tbn[..., :, None, :, :].repeat(*[1] * len(batch_shape), 1, 3, 1, 1).view(*batch_shape, n_tri * 3, 3, 3) # (..., T * 3, 3, 3) + vertex_tbn = torch.index_add(torch.zeros(*batch_shape, n_vertices, 3, 3).to(face_tbn), dim=-3, index=faces_topo.view(-1), source=face_tbn) + vertex_tbn = vertex_tbn / (torch.norm(vertex_tbn, p=2, dim=-2, keepdim=True) + 1e-7) + return vertex_tbn + + +def laplacian(vertices: torch.Tensor, faces: torch.Tensor, weight: str = 'uniform') -> torch.Tensor: + """Laplacian smooth with cotangent weights + + Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent' + """ + sum_verts = torch.zeros_like(vertices) # (..., N, 3) + sum_weights = torch.zeros(*vertices.shape[:-1]).to(vertices) # (..., N) + face_verts = torch.index_select(vertices, -2, faces.view(-1)).view(*vertices.shape[:-2], *faces.shape, vertices.shape[-1]) # (..., T, 3) + if weight == 'cotangent': + for i in range(3): + e1 = face_verts[..., (i + 1) % 3, :] - face_verts[..., i, :] + e2 = face_verts[..., (i + 2) % 3, :] - face_verts[..., i, :] + cot_angle = (e1 * e2).sum(dim=-1) / torch.cross(e1, e2, dim=-1).norm(p=2, dim=-1) # (..., T, 3) + sum_verts = torch.index_add(sum_verts, -2, faces[:, (i + 1) % 3], face_verts[..., (i + 2) % 3, :] * cot_angle[..., None]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, (i + 1) % 3], cot_angle) + sum_verts = torch.index_add(sum_verts, -2, faces[:, (i + 2) % 3], face_verts[..., (i + 1) % 3, :] * cot_angle[..., None]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, (i + 2) % 3], cot_angle) + elif weight == 'uniform': + for i in range(3): + sum_verts = torch.index_add(sum_verts, -2, faces[:, i], face_verts[..., (i + 1) % 3, :]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, i], torch.ones_like(face_verts[..., i, 0])) + else: + raise NotImplementedError + return sum_verts / (sum_weights[..., None] + 1e-7) + + +def laplacian_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, weight: str = 'uniform', times: int = 5) -> torch.Tensor: + """Laplacian smooth with cotangent weights + + Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent' + """ + for _ in range(times): + vertices = laplacian(vertices, faces, weight) + return vertices + + +def taubin_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, lambda_: float = 0.5, mu_: float = -0.51) -> torch.Tensor: + """Taubin smooth mesh + + Args: + vertices (torch.Tensor): _description_ + faces (torch.Tensor): _description_ + lambda_ (float, optional): _description_. Defaults to 0.5. + mu_ (float, optional): _description_. Defaults to -0.51. + + Returns: + torch.Tensor: _description_ + """ + pt = vertices + lambda_ * laplacian_smooth_mesh(vertices, faces) + p = pt + mu_ * laplacian_smooth_mesh(pt, faces) + return p + + +def laplacian_hc_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, times: int = 5, alpha: float = 0.5, beta: float = 0.5, weight: str = 'uniform'): + """HC algorithm from Improved Laplacian Smoothing of Noisy Surface Meshes by J.Vollmer et al. + """ + p = vertices + for i in range(times): + q = p + p = laplacian_smooth_mesh(vertices, faces, weight) + b = p - (alpha * vertices + (1 - alpha) * q) + p = p - (beta * b + (1 - beta) * laplacian_smooth_mesh(b, faces, weight)) * 0.8 + return p diff --git a/submodules/MoGe/utils3d/torch/nerf.py b/submodules/MoGe/utils3d/torch/nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..7d20bc747255dbb1a68191f93a395a824d76e108 --- /dev/null +++ b/submodules/MoGe/utils3d/torch/nerf.py @@ -0,0 +1,749 @@ +from typing import * +from numbers import Number +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from .utils import image_uv + + +__all__ = [ + 'get_rays', + 'get_image_rays', + 'get_mipnerf_cones', + 'volume_rendering', + 'bin_sample', + 'importance_sample', + 'nerf_render_rays', + 'mipnerf_render_rays', + 'nerf_render_view', + 'mipnerf_render_view', + 'InstantNGP', +] + + +def get_rays(extrinsics: Tensor, intrinsics: Tensor, uv: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + uv: (..., n_rays, 2) uv coordinates of the rays. + + Returns: + rays_o: (..., 1, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth. + """ + uvz = torch.cat([uv, torch.ones_like(uv[..., :1])], dim=-1).to(extrinsics) # (n_batch, n_views, n_rays, 3) + + with torch.cuda.amp.autocast(enabled=False): + inv_transformation = (intrinsics @ extrinsics[..., :3, :3]).inverse() + inv_extrinsics = extrinsics.inverse() + rays_d = uvz @ inv_transformation.transpose(-1, -2) + rays_o = inv_extrinsics[..., None, :3, 3] # (n_batch, n_views, 1, 3) + return rays_o, rays_d + + +def get_image_rays(extrinsics: Tensor, intrinsics: Tensor, width: int, height: int) -> Tuple[Tensor, Tensor]: + """ + Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + width: width of the image. + height: height of the image. + + Returns: + rays_o: (..., 1, 1, 3) ray origins + rays_d: (..., height, width, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth. + """ + uv = image_uv(height, width).to(extrinsics).flatten(0, 1) + rays_o, rays_d = get_rays(extrinsics, intrinsics, uv) + rays_o = rays_o.unflatten(-2, (1, 1)) + rays_d = rays_d.unflatten(-2, (height, width)) + return rays_o, rays_d + + +def get_mipnerf_cones(rays_o: Tensor, rays_d: Tensor, z_vals: Tensor, pixel_width: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + z_vals: (..., n_rays, n_samples) z values. + pixel_width: (...) pixel width. = 1 / (normalized focal length * width) + + Returns: + mu: (..., n_rays, n_samples, 3) cone mu. + sigma: (..., n_rays, n_samples, 3, 3) cone sigma. + """ + t_mu = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + t_delta = (z_vals[..., 1:] - z_vals[..., :-1]).mul_(0.5) + t_mu_square = t_mu.square() + t_delta_square = t_delta.square() + t_delta_quad = t_delta_square.square() + mu_t = t_mu + 2.0 * t_mu * t_delta_square / (3.0 * t_mu_square + t_delta_square) + sigma_t = t_delta_square / 3.0 - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square).square() * (12.0 * t_mu_square - t_delta_square) + sigma_r = (pixel_width[..., None, None].square() / 3.0) * (t_mu_square / 4.0 + (5.0 / 12.0) * t_delta_square - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square)) + points_mu = rays_o[:, :, :, None, :] + rays_d[:, :, :, None, :] * mu_t[..., None] + d_dt = rays_d[..., :, None] * rays_d[..., None, :] # (..., n_rays, 3, 3) + points_sigma = sigma_t[..., None, None] * d_dt[..., None, :, :] + sigma_r[..., None, None] * (torch.eye(3).to(rays_o) - d_dt[..., None, :, :]) + return points_mu, points_sigma + + +def get_pixel_width(intrinsics: Tensor, width: int, height: int) -> Tensor: + """ + Args: + intrinsics: (..., 3, 3) intrinsics matrices. + width: width of the image. + height: height of the image. + + Returns: + pixel_width: (...) pixel width. = 1 / (normalized focal length * width) + """ + assert width == height, "Currently, only square images are supported." + pixel_width = torch.reciprocal((intrinsics[..., 0, 0] * intrinsics[..., 1, 1]).sqrt() * width) + return pixel_width + + +def volume_rendering(color: Tensor, sigma: Tensor, z_vals: Tensor, ray_length: Tensor, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + """ + Given color, sigma and z_vals (linear depth of the sampling points), render the volume. + + NOTE: By default, color and sigma should have one less sample than z_vals, in correspondence with the average value in intervals. + If queried color are aligned with z_vals, we use trapezoidal rule to calculate the average values in intervals. + + Args: + color: (..., n_samples or n_samples - 1, 3) color values. + sigma: (..., n_samples or n_samples - 1) density values. + z_vals: (..., n_samples) z values. + ray_length: (...) length of the ray + + Returns: + rgb: (..., 3) rendered color values. + depth: (...) rendered depth values. + weights (..., n_samples) weights. + """ + dists = (z_vals[..., 1:] - z_vals[..., :-1]) * ray_length[..., None] + if color.shape[-2] == z_vals.shape[-1]: + color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5) + sigma = (sigma[..., 1:] + sigma[..., :-1]).mul_(0.5) + sigma_delta = sigma * dists + transparancy = (-torch.cat([torch.zeros_like(sigma_delta[..., :1]), sigma_delta[..., :-1]], dim=-1).cumsum(dim=-1)).exp_() # First cumsum then exp for numerical stability + alpha = 1.0 - (-sigma_delta).exp_() + weights = alpha * transparancy + if rgb: + rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None + if depth: + z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None + return rgb, depth, weights + + +def neus_volume_rendering(color: Tensor, sdf: Tensor, s: torch.Tensor, z_vals: Tensor = None, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + """ + Given color, sdf values and z_vals (linear depth of the sampling points), do volume rendering. (NeuS) + + Args: + color: (..., n_samples or n_samples - 1, 3) color values. + sdf: (..., n_samples) sdf values. + s: (..., n_samples) S values of S-density function in NeuS. The standard deviation of such S-density distribution is 1 / s. + z_vals: (..., n_samples) z values. + ray_length: (...) length of the ray + + Returns: + rgb: (..., 3) rendered color values. + depth: (...) rendered depth values. + weights (..., n_samples) weights. + """ + + if color.shape[-2] == z_vals.shape[-1]: + color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5) + + sigmoid_sdf = torch.sigmoid(s * sdf) + alpha = F.relu(1 - sigmoid_sdf[..., :-1] / sigmoid_sdf[..., :-1]) + transparancy = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), alpha], dim=-1), dim=-1) + weights = alpha * transparancy + + if rgb: + rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None + if depth: + z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None + return rgb, depth, weights + + +def bin_sample(size: Union[torch.Size, Tuple[int, ...]], n_samples: int, min_value: Number, max_value: Number, spacing: Literal['linear', 'inverse_linear'], dtype: torch.dtype = None, device: torch.device = None) -> Tensor: + """ + Uniformly (or uniformly in inverse space) sample z values in `n_samples` bins in range [min_value, max_value]. + Args: + size: size of the rays + n_samples: number of samples to be sampled, also the number of bins + min_value: minimum value of the range + max_value: maximum value of the range + space: 'linear' or 'inverse_linear'. If 'inverse_linear', the sampling is uniform in inverse space. + + Returns: + z_rand: (*size, n_samples) sampled z values, sorted in ascending order. + """ + if spacing == 'linear': + pass + elif spacing == 'inverse_linear': + min_value = 1.0 / min_value + max_value = 1.0 / max_value + bin_length = (max_value - min_value) / n_samples + z_rand = (torch.rand(*size, n_samples, device=device, dtype=dtype) - 0.5) * bin_length + torch.linspace(min_value + bin_length * 0.5, max_value - bin_length * 0.5, n_samples, device=device, dtype=dtype) + if spacing == 'inverse_linear': + z_rand = 1.0 / z_rand + return z_rand + + +def importance_sample(z_vals: Tensor, weights: Tensor, n_samples: int) -> Tuple[Tensor, Tensor]: + """ + Importance sample z values. + + NOTE: By default, weights should have one less sample than z_vals, in correspondence with the intervals. + If weights has the same number of samples as z_vals, we use trapezoidal rule to calculate the average weights in intervals. + + Args: + z_vals: (..., n_rays, n_input_samples) z values, sorted in ascending order. + weights: (..., n_rays, n_input_samples or n_input_samples - 1) weights. + n_samples: number of output samples for importance sampling. + + Returns: + z_importance: (..., n_rays, n_samples) importance sampled z values, unsorted. + """ + if weights.shape[-1] == z_vals.shape[-1]: + weights = (weights[..., 1:] + weights[..., :-1]).mul_(0.5) + weights = weights / torch.sum(weights, dim=-1, keepdim=True) # (..., n_rays, n_input_samples - 1) + bins_a, bins_b = z_vals[..., :-1], z_vals[..., 1:] + + pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # (..., n_rays, n_input_samples - 1) + cdf = torch.cumsum(pdf, dim=-1) + u = torch.rand(*z_vals.shape[:-1], n_samples, device=z_vals.device, dtype=z_vals.dtype) + + inds = torch.searchsorted(cdf, u, right=True).clamp(0, cdf.shape[-1] - 1) # (..., n_rays, n_samples) + + bins_a = torch.gather(bins_a, dim=-1, index=inds) + bins_b = torch.gather(bins_b, dim=-1, index=inds) + z_importance = bins_a + (bins_b - bins_a) * torch.rand_like(u) + return z_importance + + +def nerf_render_rays( + nerf: Union[Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], Tuple[Callable[[Tensor], Tuple[Tensor, Tensor]], Callable[[Tensor], Tuple[Tensor, Tensor]]]], + rays_o: Tensor, rays_d: Tensor, + *, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +): + """ + NeRF rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + nerf: nerf model, which takes (points, directions) as input and returns (color, density) as output. + If nerf is a tuple, it should be (nerf_coarse, nerf_fine), where nerf_coarse and nerf_fine are two nerf models for coarse and fine stages respectively. + + nerf args: + points: (..., n_rays, n_samples, 3) + directions: (..., n_rays, n_samples, 3) + nerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered rgb and depth for short cut. (If there are separate coarse and fine results, return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0` or `nerf` is a single model, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If there are two models for coarse and fine stages, the dict contains both coarse and fine results: + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + if isinstance(nerf, tuple): + nerf_coarse, nerf_fine = nerf + else: + nerf_coarse = nerf_fine = nerf + # 1. Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) # (n_batch, n_views, n_rays, n_samples) + points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] # (n_batch, n_views, n_rays, n_samples, 3) + ray_length = rays_d.norm(dim=-1) + + # Query color and density + color_coarse, density_coarse = nerf_coarse(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + with torch.no_grad(): + rgb_coarse, depth_coarse, weights = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse} + else: + return rgb_coarse, depth_coarse + + # 2. Fine: Importance sampling + if nerf_coarse is nerf_fine: + # If coarse and fine stages share the same model, the points of coarse stage can be reused, + # and we only need to query the importance samples of fine stage. + z_fine = importance_sample(z_coarse, weights, n_fine) + points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None] + color_fine, density_fine = nerf_fine(points_fine, rays_d[..., None, :].expand_as(points_fine)) + + # Merge & volume rendering + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + color = torch.cat([color_coarse, color_fine], dim=-2) + density = torch.cat([density_coarse, density_fine], dim=-1) + z_vals, sort_inds = torch.sort(z_vals, dim=-1) + color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color)) + density = torch.gather(density, dim=-1, index=sort_inds) + rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length) + + if return_dict: + return {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density} + else: + return rgb, depth + else: + # If coarse and fine stages use different models, we need to query the importance samples of both stages. + z_fine = importance_sample(z_coarse, weights, n_fine) + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + points = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., None] + color, density = nerf_fine(points) + rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}, + 'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density} + } + else: + return rgb, depth + + +def mipnerf_render_rays( + mipnerf: Callable[[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]], + rays_o: Tensor, rays_d: Tensor, pixel_width: Tensor, + *, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, uniform_ratio: float = 0.4, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +) -> Union[Tuple[Tensor, Tensor], Dict[str, Tensor]]: + """ + MipNeRF rendering. + + Args: + mipnerf: mipnerf model, which takes (points_mu, points_sigma) as input and returns (color, density) as output. + + mipnerf args: + points_mu: (..., n_rays, n_samples, 3) cone mu. + points_sigma: (..., n_rays, n_samples, 3, 3) cone sigma. + directions: (..., n_rays, n_samples, 3) + mipnerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0`, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If n_fine > 0, the dict contains both coarse and fine results : + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + # 1. Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, spacing=z_spacing, device=rays_o.device, dtype=rays_o.dtype) + points_mu_coarse, points_sigma_coarse = get_mipnerf_cones(rays_o, rays_d, z_coarse, pixel_width) + ray_length = rays_d.norm(dim=-1) + + # Query color and density + color_coarse, density_coarse = mipnerf(points_mu_coarse, points_sigma_coarse, rays_d[..., None, :].expand_as(points_mu_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + rgb_coarse, depth_coarse, weights_coarse = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse} + else: + return rgb_coarse, depth_coarse + + # 2. Fine: Importance sampling. (NOTE: coarse stages and fine stages always share the same model, but coarse stage points can not be reused) + with torch.no_grad(): + weights_coarse = (1.0 - uniform_ratio) * weights_coarse + uniform_ratio / weights_coarse.shape[-1] + z_fine = importance_sample(z_coarse, weights_coarse, n_fine) + z_fine, _ = torch.sort(z_fine, dim=-2) + points_mu_fine, points_sigma_fine = get_mipnerf_cones(rays_o, rays_d, z_fine, pixel_width) + color_fine, density_fine = mipnerf(points_mu_fine, points_sigma_fine, rays_d[..., None, :].expand_as(points_mu_fine)) + + # Volume rendering + rgb_fine, depth_fine, weights_fine = volume_rendering(color_fine, density_fine, z_fine, ray_length) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}, + 'fine': {'rgb': rgb_fine, 'depth': depth_fine, 'weights': weights_fine, 'z_vals': z_fine, 'color': color_fine, 'density': density_fine} + } + else: + return rgb_fine, depth_fine + + +def neus_render_rays( + neus: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], + s: Union[Number, Tensor], + rays_o: Tensor, rays_d: Tensor, + *, + compute_normal: bool = True, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +): + """ + TODO + NeuS rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + neus: neus model, which takes (points, directions) as input and returns (color, density) as output. + + nerf args: + points: (..., n_rays, n_samples, 3) + directions: (..., n_rays, n_samples, 3) + nerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0`, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'sdf': ..., 'normal': ...} + ``` + If n_fine > 0, the dict contains both coarse and fine results: + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + + # 1. Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) # (n_batch, n_views, n_rays, n_samples) + points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] # (n_batch, n_views, n_rays, n_samples, 3) + + # Query color and density + color_coarse, sdf_coarse = neus(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + with torch.no_grad(): + rgb_coarse, depth_coarse, weights = neus_volume_rendering(color_coarse, sdf_coarse, s, z_coarse) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse} + else: + return rgb_coarse, depth_coarse + + # If coarse and fine stages share the same model, the points of coarse stage can be reused, + # and we only need to query the importance samples of fine stage. + z_fine = importance_sample(z_coarse, weights, n_fine) + points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None] + color_fine, sdf_fine = neus(points_fine, rays_d[..., None, :].expand_as(points_fine)) + + # Merge & volume rendering + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + color = torch.cat([color_coarse, color_fine], dim=-2) + sdf = torch.cat([sdf_coarse, sdf_fine], dim=-1) + z_vals, sort_inds = torch.sort(z_vals, dim=-1) + color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color)) + sdf = torch.gather(sdf, dim=-1, index=sort_inds) + rgb, depth, weights = neus_volume_rendering(color, sdf, s, z_vals) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse}, + 'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'sdf': sdf} + } + else: + return rgb, depth + + +def nerf_render_view( + nerf: Tensor, + extrinsics: Tensor, + intrinsics: Tensor, + width: int, + height: int, + *, + patchify: bool = False, + patch_size: Tuple[int, int] = (64, 64), + **options: Dict[str, Any] +) -> Tuple[Tensor, Tensor]: + """ + NeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + + Returns: + rgb: (..., channels, height, width) rendered color values. + depth: (..., height, width) rendered depth values. + """ + if patchify: + # Patchified rendering + max_patch_width, max_patch_height = patch_size + n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width) + + rgb_rows, depth_rows = [], [] + for i_row in range(n_rows): + rgb_row, depth_row = [], [] + for i_column in range(n_columns): + patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width) + uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics) + uv = uv.flatten(0, 1) # (patch_height * patch_width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb_, depth_ = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False) + rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) # (..., 3, patch_height, patch_width) + depth_ = depth_.unflatten(-1, patch_shape) # (..., patch_height, patch_width) + + rgb_row.append(rgb_) + depth_row.append(depth_) + rgb_rows.append(torch.cat(rgb_row, dim=-1)) + depth_rows.append(torch.cat(depth_row, dim=-1)) + rgb = torch.cat(rgb_rows, dim=-2) + depth = torch.cat(depth_rows, dim=-2) + + return rgb, depth + else: + # Full rendering + uv = image_uv(height, width).to(extrinsics) + uv = uv.flatten(0, 1) # (height * width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb, depth = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False) + rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) # (..., 3, height, width) + depth = depth.unflatten(-1, (height, width)) # (..., height, width) + + return rgb, depth + + +def mipnerf_render_view( + mipnerf: Tensor, + extrinsics: Tensor, + intrinsics: Tensor, + width: int, + height: int, + *, + patchify: bool = False, + patch_size: Tuple[int, int] = (64, 64), + **options: Dict[str, Any] +) -> Tuple[Tensor, Tensor]: + """ + MipNeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + + Returns: + rgb: (..., 3, height, width) rendered color values. + depth: (..., height, width) rendered depth values. + """ + pixel_width = get_pixel_width(intrinsics, width, height) + + if patchify: + # Patchified rendering + max_patch_width, max_patch_height = patch_size + n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width) + + rgb_rows, depth_rows = [], [] + for i_row in range(n_rows): + rgb_row, depth_row = [], [] + for i_column in range(n_columns): + patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width) + uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics) + uv = uv.flatten(0, 1) # (patch_height * patch_width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb_, depth_ = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options) + rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) # (..., 3, patch_height, patch_width) + depth_ = depth_.unflatten(-1, patch_shape) # (..., patch_height, patch_width) + + rgb_row.append(rgb_) + depth_row.append(depth_) + rgb_rows.append(torch.cat(rgb_row, dim=-1)) + depth_rows.append(torch.cat(depth_row, dim=-1)) + rgb = torch.cat(rgb_rows, dim=-2) + depth = torch.cat(depth_rows, dim=-2) + + return rgb, depth + else: + # Full rendering + uv = image_uv(height, width).to(extrinsics) + uv = uv.flatten(0, 1) # (height * width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb, depth = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options) + rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) # (..., 3, height, width) + depth = depth.unflatten(-1, (height, width)) # (..., height, width) + + return rgb, depth + + +class InstantNGP(nn.Module): + """ + An implementation of InstantNGP, Müller et. al., https://nvlabs.github.io/instant-ngp/. + Requires `tinycudann` package. + Install it by: + ``` + pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch + ``` + """ + def __init__(self, + view_dependent: bool = True, + base_resolution: int = 16, + finest_resolution: int = 2048, + n_levels: int = 16, + num_layers_density: int = 2, + hidden_dim_density: int = 64, + num_layers_color: int = 3, + hidden_dim_color: int = 64, + log2_hashmap_size: int = 19, + bound: float = 1.0, + color_channels: int = 3, + ): + super().__init__() + import tinycudann + N_FEATURES_PER_LEVEL = 2 + GEO_FEAT_DIM = 15 + + self.bound = bound + self.color_channels = color_channels + + # density network + self.num_layers_density = num_layers_density + self.hidden_dim_density = hidden_dim_density + + per_level_scale = (finest_resolution / base_resolution) ** (1 / (n_levels - 1)) + + self.encoder = tinycudann.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "HashGrid", + "n_levels": n_levels, + "n_features_per_level": N_FEATURES_PER_LEVEL, + "log2_hashmap_size": log2_hashmap_size, + "base_resolution": base_resolution, + "per_level_scale": per_level_scale, + }, + ) + + self.density_net = tinycudann.Network( + n_input_dims=N_FEATURES_PER_LEVEL * n_levels, + n_output_dims=1 + GEO_FEAT_DIM, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim_density, + "n_hidden_layers": num_layers_density - 1, + }, + ) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + + self.view_dependent = view_dependent + if view_dependent: + self.encoder_dir = tinycudann.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "SphericalHarmonics", + "degree": 4, + }, + ) + self.in_dim_color = self.encoder_dir.n_output_dims + GEO_FEAT_DIM + else: + self.in_dim_color = GEO_FEAT_DIM + + self.color_net = tinycudann.Network( + n_input_dims=self.in_dim_color, + n_output_dims=color_channels, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim_color, + "n_hidden_layers": num_layers_color - 1, + }, + ) + + def forward(self, x: torch.Tensor, d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: (..., 3) points + d: (..., 3) directions + Returns: + color: (..., 3) color values. + density: (..., 1) density values. + """ + batch_shape = x.shape[:-1] + x, d = x.reshape(-1, 3), d.reshape(-1, 3) + + # density + x = (x + self.bound) / (2 * self.bound) # to [0, 1] + x = self.encoder(x) + density, geo_feat = self.density_net(x).split([1, 15], dim=-1) + density = F.softplus(density).squeeze(-1) + + # color + if self.view_dependent: + d = (F.normalize(d, dim=-1) + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + else: + h = geo_feat + color = self.color_net(h) + + return color.reshape(*batch_shape, self.color_channels), density.reshape(*batch_shape) + diff --git a/submodules/MoGe/utils3d/torch/rasterization.py b/submodules/MoGe/utils3d/torch/rasterization.py new file mode 100644 index 0000000000000000000000000000000000000000..11802737ebeae9be2e6b7bda7ee0d933b01ab909 --- /dev/null +++ b/submodules/MoGe/utils3d/torch/rasterization.py @@ -0,0 +1,392 @@ +from typing import * + +import torch +import nvdiffrast.torch as dr + +from . import utils, transforms, mesh +from ._helpers import batched + + +__all__ = [ + 'RastContext', + 'rasterize_triangle_faces', + 'warp_image_by_depth', + 'warp_image_by_forward_flow', +] + + +class RastContext: + """ + Create a rasterization context. Nothing but a wrapper of nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext. + """ + def __init__(self, nvd_ctx: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Union[str, torch.device] = None): + if nvd_ctx is not None: + self.nvd_ctx = nvd_ctx + return + + if backend == 'gl': + self.nvd_ctx = dr.RasterizeGLContext(device=device) + elif backend == 'cuda': + self.nvd_ctx = dr.RasterizeCudaContext(device=device) + else: + raise ValueError(f'Unknown backend: {backend}') + + +def rasterize_triangle_faces( + ctx: RastContext, + vertices: torch.Tensor, + faces: torch.Tensor, + width: int, + height: int, + attr: torch.Tensor = None, + uv: torch.Tensor = None, + texture: torch.Tensor = None, + model: torch.Tensor = None, + view: torch.Tensor = None, + projection: torch.Tensor = None, + antialiasing: Union[bool, List[int]] = True, + diff_attrs: Union[None, List[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Rasterize a mesh with vertex attributes. + + Args: + ctx (GLContext): rasterizer context + vertices (np.ndarray): (B, N, 2 or 3 or 4) + faces (torch.Tensor): (T, 3) + width (int): width of the output image + height (int): height of the output image + attr (torch.Tensor, optional): (B, N, C) vertex attributes. Defaults to None. + uv (torch.Tensor, optional): (B, N, 2) uv coordinates. Defaults to None. + texture (torch.Tensor, optional): (B, H, W, C) texture. Defaults to None. + model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity). + view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity). + projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity). + antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased. + diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives. Defaults to None. + + Returns: + Dictionary containing: + - image: (torch.Tensor): (B, C, H, W) + - depth: (torch.Tensor): (B, H, W) screen space depth, ranging from 0 (near) to 1. (far) + NOTE: Empty pixels will have depth 1., i.e. far plane. + - mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + - image_dr: (torch.Tensor): (B, 4, H, W) screen space derivatives of the attributes + - face_id: (torch.Tensor): (B, H, W) face ids + - uv: (torch.Tensor): (B, N, 2) uv coordinates (if uv is not None) + - uv_dr: (torch.Tensor): (B, N, 4) uv derivatives (if uv is not None) + - texture: (torch.Tensor): (B, H, W, C) texture (if uv and texture are not None) + """ + assert vertices.ndim == 3 + assert faces.ndim == 2 + + if vertices.shape[-1] == 2: + vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1) + elif vertices.shape[-1] == 3: + vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + elif vertices.shape[-1] == 4: + pass + else: + raise ValueError(f'Wrong shape of vertices: {vertices.shape}') + + mvp = projection if projection is not None else torch.eye(4).to(vertices) + if view is not None: + mvp = mvp @ view + if model is not None: + mvp = mvp @ model + + pos_clip = vertices @ mvp.transpose(-1, -2) + faces = faces.contiguous() + if attr is not None: + attr = attr.contiguous() + + rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True) + face_id = rast_out[..., 3].flip(1) + depth = rast_out[..., 2].flip(1) + mask = (face_id > 0).float() + depth = (depth * 0.5 + 0.5) * mask + (1.0 - mask) + + ret = { + 'depth': depth, + 'mask': mask, + 'face_id': face_id, + } + + if attr is not None: + image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs) + if antialiasing == True: + image = dr.antialias(image, rast_out, pos_clip, faces) + elif isinstance(antialiasing, list): + aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces) + image[..., antialiasing] = aa_image + image = image.flip(1).permute(0, 3, 1, 2) + ret['image'] = image + + if uv is not None: + uv_map, uv_map_dr = dr.interpolate(uv, rast_out, faces, rast_db, diff_attrs='all') + ret['uv'] = uv_map + ret['uv_dr'] = uv_map_dr + if texture is not None: + texture_map = dr.texture(ctx.nvd_ctx, uv_map, uv_map_dr) + ret['texture'] = texture_map.flip(1).permute(0, 3, 1, 2) + + if diff_attrs is not None: + image_dr = image_dr.flip(1).permute(0, 3, 1, 2) + ret['image_dr'] = image_dr + + return ret + + +def texture( + ctx: RastContext, + uv: torch.Tensor, + uv_da: torch.Tensor, + texture: torch.Tensor, +) -> torch.Tensor: + dr.texture(ctx.nvd_ctx, uv, texture) + + +def warp_image_by_depth( + ctx: RastContext, + depth: torch.FloatTensor, + image: torch.FloatTensor = None, + mask: torch.BoolTensor = None, + width: int = None, + height: int = None, + *, + extrinsics_src: torch.FloatTensor = None, + extrinsics_tgt: torch.FloatTensor = None, + intrinsics_src: torch.FloatTensor = None, + intrinsics_tgt: torch.FloatTensor = None, + near: float = 0.1, + far: float = 100.0, + antialiasing: bool = True, + backslash: bool = False, + padding: int = 0, + return_uv: bool = False, + return_dr: bool = False, +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.BoolTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]: + """ + Warp image by depth. + NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. + Otherwise, image mesh will be triangulated simply for batch rendering. + + Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + depth (torch.Tensor): (B, H, W) linear depth + image (torch.Tensor): (B, C, H, W). None to use image space uv. Defaults to None. + width (int, optional): width of the output image. None to use the same as depth. Defaults to None. + height (int, optional): height of the output image. Defaults the same as depth.. + extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None. + extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None. + intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None. + intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None. + near (float, optional): near plane. Defaults to 0.1. + far (float, optional): far plane. Defaults to 100.0. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + padding (int, optional): padding of the image. Defaults to 0. + return_uv (bool, optional): whether to return the uv. Defaults to False. + return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False. + + Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + depth: (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + uv: (torch.FloatTensor): (B, 2, H, W) image-space uv + dr: (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv + """ + assert depth.ndim == 3 + batch_size = depth.shape[0] + + if width is None: + width = depth.shape[-1] + if height is None: + height = depth.shape[-2] + if image is not None: + assert image.shape[-2:] == depth.shape[-2:], f'Shape of image {image.shape} does not match shape of depth {depth.shape}' + + if extrinsics_src is None: + extrinsics_src = torch.eye(4).to(depth) + if extrinsics_tgt is None: + extrinsics_tgt = torch.eye(4).to(depth) + if intrinsics_src is None: + intrinsics_src = intrinsics_tgt + if intrinsics_tgt is None: + intrinsics_tgt = intrinsics_src + + assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters." + + view_tgt = transforms.extrinsics_to_view(extrinsics_tgt) + perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far) + + if padding > 0: + uv, faces = utils.image_mesh(width=width+2, height=height+2) + uv = (uv - 1 / (width + 2)) * ((width + 2) / width) + uv_ = uv.clone().reshape(height+2, width+2, 2) + uv_[0, :, 1] -= padding / height + uv_[-1, :, 1] += padding / height + uv_[:, 0, 0] -= padding / width + uv_[:, -1, 0] += padding / width + uv_ = uv_.reshape(-1, 2) + depth = torch.nn.functional.pad(depth, [1, 1, 1, 1], mode='replicate') + if image is not None: + image = torch.nn.functional.pad(image, [1, 1, 1, 1], mode='replicate') + uv, uv_, faces = uv.to(depth.device), uv_.to(depth.device), faces.to(depth.device) + pts = transforms.unproject_cv( + uv_, + depth.flatten(-2, -1), + extrinsics_src, + intrinsics_src, + ) + else: + uv, faces = utils.image_mesh(width=depth.shape[-1], height=depth.shape[-2]) + if mask is not None: + depth = torch.where(mask, depth, torch.tensor(far, dtype=depth.dtype, device=depth.device)) + uv, faces = uv.to(depth.device), faces.to(depth.device) + pts = transforms.unproject_cv( + uv, + depth.flatten(-2, -1), + extrinsics_src, + intrinsics_src, + ) + + # triangulate + if batch_size == 1: + faces = mesh.triangulate(faces, vertices=pts[0]) + else: + faces = mesh.triangulate(faces, backslash=backslash) + + # rasterize attributes + diff_attrs = None + if image is not None: + attr = image.permute(0, 2, 3, 1).flatten(1, 2) + if return_dr or return_uv: + if return_dr: + diff_attrs = [image.shape[1], image.shape[1]+1] + if return_uv and antialiasing: + antialiasing = list(range(image.shape[1])) + attr = torch.cat([attr, uv.expand(batch_size, -1, -1)], dim=-1) + else: + attr = uv.expand(batch_size, -1, -1) + if antialiasing: + print("\033[93mWarning: you are performing antialiasing on uv. This may cause artifacts.\033[0m") + if return_uv: + return_uv = False + print("\033[93mWarning: image is None, return_uv is ignored.\033[0m") + if return_dr: + diff_attrs = [0, 1] + + if mask is not None: + attr = torch.cat([attr, mask.float().flatten(1, 2).unsqueeze(-1)], dim=-1) + + rast = rasterize_triangle_faces( + ctx, + pts, + faces, + width, + height, + attr=attr, + view=view_tgt, + perspective=perspective_tgt, + antialiasing=antialiasing, + diff_attrs=diff_attrs, + ) + if return_dr: + output_image, screen_depth, output_dr = rast['image'], rast['depth'], rast['image_dr'] + else: + output_image, screen_depth = rast['image'], rast['depth'] + output_mask = screen_depth < 1.0 + + if mask is not None: + output_image, rast_mask = output_image[..., :-1, :, :], output_image[..., -1, :, :] + output_mask &= (rast_mask > 0.9999).reshape(-1, height, width) + + if (return_dr or return_uv) and image is not None: + output_image, output_uv = output_image[..., :-2, :, :], output_image[..., -2:, :, :] + + output_depth = transforms.depth_buffer_to_linear(screen_depth, near=near, far=far) * output_mask + output_image = output_image * output_mask.unsqueeze(1) + + outs = [output_image, output_depth, output_mask] + if return_uv: + outs.append(output_uv) + if return_dr: + outs.append(output_dr) + return tuple(outs) + + +def warp_image_by_forward_flow( + ctx: RastContext, + image: torch.FloatTensor, + flow: torch.FloatTensor, + depth: torch.FloatTensor = None, + *, + antialiasing: bool = True, + backslash: bool = False, +) -> Tuple[torch.FloatTensor, torch.BoolTensor]: + """ + Warp image by forward flow. + NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. + Otherwise, image mesh will be triangulated simply for batch rendering. + + Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + image (torch.Tensor): (B, C, H, W) image + flow (torch.Tensor): (B, 2, H, W) forward flow + depth (torch.Tensor, optional): (B, H, W) linear depth. If None, will use the same for all pixels. Defaults to None. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + + Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + """ + assert image.ndim == 4, f'Wrong shape of image: {image.shape}' + batch_size, _, height, width = image.shape + + if depth is None: + depth = torch.ones_like(flow[:, 0]) + + extrinsics = torch.eye(4).to(image) + fov = torch.deg2rad(torch.tensor([45.0], device=image.device)) + intrinsics = transforms.intrinsics_from_fov(fov, width, height, normalize=True)[0] + + view = transforms.extrinsics_to_view(extrinsics) + perspective = transforms.intrinsics_to_perspective(intrinsics, near=0.1, far=100) + + uv, faces = utils.image_mesh(width=width, height=height) + uv, faces = uv.to(image.device), faces.to(image.device) + uv = uv + flow.permute(0, 2, 3, 1).flatten(1, 2) + pts = transforms.unproject_cv( + uv, + depth.flatten(-2, -1), + extrinsics, + intrinsics, + ) + + # triangulate + if batch_size == 1: + faces = mesh.triangulate(faces, vertices=pts[0]) + else: + faces = mesh.triangulate(faces, backslash=backslash) + + # rasterize attributes + attr = image.permute(0, 2, 3, 1).flatten(1, 2) + rast = rasterize_triangle_faces( + ctx, + pts, + faces, + width, + height, + attr=attr, + view=view, + perspective=perspective, + antialiasing=antialiasing, + ) + output_image, screen_depth = rast['image'], rast['depth'] + output_mask = screen_depth < 1.0 + output_image = output_image * output_mask.unsqueeze(1) + + outs = [output_image, output_mask] + return tuple(outs) diff --git a/submodules/MoGe/utils3d/torch/transforms.py b/submodules/MoGe/utils3d/torch/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..46e61d741c6ef000c80aa65201b55e31ed4246c6 --- /dev/null +++ b/submodules/MoGe/utils3d/torch/transforms.py @@ -0,0 +1,1189 @@ +from typing import * +from numbers import Number + +import torch +import torch.nn.functional as F + +from ._helpers import batched + + +__all__ = [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'intrinsics_from_fov_xy', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'project_gl', + 'project_cv', + 'unproject_gl', + 'unproject_cv', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'matrix_to_euler_angles', + 'matrix_to_quaternion', + 'quaternion_to_matrix', + 'matrix_to_axis_angle', + 'axis_angle_to_matrix', + 'axis_angle_to_quaternion', + 'quaternion_to_axis_angle', + 'slerp', + 'interpolate_extrinsics', + 'interpolate_view', + 'extrinsics_to_essential', + 'to4x4', + 'rotation_matrix_2d', + 'rotate_2d', + 'translate_2d', + 'scale_2d', + 'apply_2d', +] + + +@batched(0,0,0,0) +def perspective( + fov_y: Union[float, torch.Tensor], + aspect: Union[float, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix + + Args: + fov_y (float | torch.Tensor): field of view in y axis + aspect (float | torch.Tensor): aspect ratio + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + N = fov_y.shape[0] + ret = torch.zeros((N, 4, 4), dtype=fov_y.dtype, device=fov_y.device) + ret[:, 0, 0] = 1. / (torch.tan(fov_y / 2) * aspect) + ret[:, 1, 1] = 1. / (torch.tan(fov_y / 2)) + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +def perspective_from_fov( + fov: Union[float, torch.Tensor], + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix from field of view in largest dimension + + Args: + fov (float | torch.Tensor): field of view in largest dimension + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + fov_y = 2 * torch.atan(torch.tan(fov / 2) * height / torch.maximum(width, height)) + aspect = width / height + return perspective(fov_y, aspect, near, far) + + +def perspective_from_fov_xy( + fov_x: Union[float, torch.Tensor], + fov_y: Union[float, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix from field of view in x and y axis + + Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + aspect = torch.tan(fov_x / 2) / torch.tan(fov_y / 2) + return perspective(fov_y, aspect, near, far) + + +@batched(0,0,0,0) +def intrinsics_from_focal_center( + fx: Union[float, torch.Tensor], + fy: Union[float, torch.Tensor], + cx: Union[float, torch.Tensor], + cy: Union[float, torch.Tensor] +) -> torch.Tensor: + """ + Get OpenCV intrinsics matrix + + Args: + focal_x (float | torch.Tensor): focal length in x axis + focal_y (float | torch.Tensor): focal length in y axis + cx (float | torch.Tensor): principal point in x axis + cy (float | torch.Tensor): principal point in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + N = fx.shape[0] + ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device) + zeros, ones = torch.zeros(N, dtype=fx.dtype, device=fx.device), torch.ones(N, dtype=fx.dtype, device=fx.device) + ret = torch.stack([fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1).unflatten(-1, (3, 3)) + return ret + + +@batched(0, 0, 0, 0, 0, 0) +def intrinsics_from_fov( + fov_max: Union[float, torch.Tensor] = None, + fov_min: Union[float, torch.Tensor] = None, + fov_x: Union[float, torch.Tensor] = None, + fov_y: Union[float, torch.Tensor] = None, + width: Union[int, torch.Tensor] = None, + height: Union[int, torch.Tensor] = None, +) -> torch.Tensor: + """ + Get normalized OpenCV intrinsics matrix from given field of view. + You can provide either fov_max, fov_min, fov_x or fov_y + + Args: + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + fov_max (float | torch.Tensor): field of view in largest dimension + fov_min (float | torch.Tensor): field of view in smallest dimension + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + if fov_max is not None: + fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2)) + fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2)) + elif fov_min is not None: + fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2)) + fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2)) + elif fov_x is not None and fov_y is not None: + fx = 1 / (2 * torch.tan(fov_x / 2)) + fy = 1 / (2 * torch.tan(fov_y / 2)) + elif fov_x is not None: + fx = 1 / (2 * torch.tan(fov_x / 2)) + fy = fx * width / height + elif fov_y is not None: + fy = 1 / (2 * torch.tan(fov_y / 2)) + fx = fy * height / width + cx = 0.5 + cy = 0.5 + ret = intrinsics_from_focal_center(fx, fy, cx, cy) + return ret + + + +def intrinsics_from_fov_xy( + fov_x: Union[float, torch.Tensor], + fov_y: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenCV intrinsics matrix from field of view in x and y axis + + Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + focal_x = 0.5 / torch.tan(fov_x / 2) + focal_y = 0.5 / torch.tan(fov_y / 2) + cx = cy = 0.5 + return intrinsics_from_focal_center(focal_x, focal_y, cx, cy) + + +@batched(1,1,1) +def view_look_at( + eye: torch.Tensor, + look_at: torch.Tensor, + up: torch.Tensor + ) -> torch.Tensor: + """ + Get OpenGL view matrix looking at something + + Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (torch.Tensor): [..., 4, 4], view matrix + """ + N = eye.shape[0] + z = eye - look_at + x = torch.cross(up, z, dim=-1) + y = torch.cross(z, x, dim=-1) + # x = torch.cross(y, z, dim=-1) + x = x / x.norm(dim=-1, keepdim=True) + y = y / y.norm(dim=-1, keepdim=True) + z = z / z.norm(dim=-1, keepdim=True) + R = torch.stack([x, y, z], dim=-2) + t = -torch.matmul(R, eye[..., None]) + ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device) + ret[:, :3, :3] = R + ret[:, :3, 3] = t[:, :, 0] + ret[:, 3, 3] = 1. + return ret + + +@batched(1, 1, 1) +def extrinsics_look_at( + eye: torch.Tensor, + look_at: torch.Tensor, + up: torch.Tensor +) -> torch.Tensor: + """ + Get OpenCV extrinsics matrix looking at something + + Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (torch.Tensor): [..., 4, 4], extrinsics matrix + """ + N = eye.shape[0] + z = look_at - eye + x = torch.cross(-up, z, dim=-1) + y = torch.cross(z, x, dim=-1) + # x = torch.cross(y, z, dim=-1) + x = x / x.norm(dim=-1, keepdim=True) + y = y / y.norm(dim=-1, keepdim=True) + z = z / z.norm(dim=-1, keepdim=True) + R = torch.stack([x, y, z], dim=-2) + t = -torch.matmul(R, eye[..., None]) + ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device) + ret[:, :3, :3] = R + ret[:, :3, 3] = t[:, :, 0] + ret[:, 3, 3] = 1. + return ret + + +@batched(2) +def perspective_to_intrinsics( + perspective: torch.Tensor +) -> torch.Tensor: + """ + OpenGL perspective matrix to OpenCV intrinsics + + Args: + perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix + + Returns: + (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics + """ + assert torch.allclose(perspective[:, [0, 1, 3], 3], 0), "The perspective matrix is not a projection matrix" + ret = torch.tensor([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype, device=perspective.device) \ + @ perspective[:, [0, 1, 3], :3] \ + @ torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device)) + return ret / ret[:, 2, 2, None, None] + + +@batched(2,0,0) +def intrinsics_to_perspective( + intrinsics: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + Returns: + (torch.Tensor): [..., 4, 4] OpenGL perspective matrix + """ + N = intrinsics.shape[0] + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1] + cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2] + ret = torch.zeros((N, 4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[:, 0, 0] = 2 * fx + ret[:, 1, 1] = 2 * fy + ret[:, 0, 2] = -2 * cx + 1 + ret[:, 1, 2] = 2 * cy - 1 + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +@batched(2) +def extrinsics_to_view( + extrinsics: torch.Tensor + ) -> torch.Tensor: + """ + OpenCV camera extrinsics to OpenGL view matrix + + Args: + extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix + + Returns: + (torch.Tensor): [..., 4, 4] OpenGL view matrix + """ + return extrinsics * torch.tensor([1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device)[:, None] + + +@batched(2) +def view_to_extrinsics( + view: torch.Tensor + ) -> torch.Tensor: + """ + OpenGL view matrix to OpenCV camera extrinsics + + Args: + view (torch.Tensor): [..., 4, 4] OpenGL view matrix + + Returns: + (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix + """ + return view * torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device)[:, None] + + +@batched(2,0,0) +def normalize_intrinsics( + intrinsics: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] + ) -> torch.Tensor: + """ + Normalize camera intrinsics(s) to uv space + + Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s) + """ + zeros = torch.zeros_like(width) + ones = torch.ones_like(width) + transform = torch.stack([ + 1 / width, zeros, 0.5 / width, + zeros, 1 / height, 0.5 / height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3).to(intrinsics) + return transform @ intrinsics + + + +@batched(2,0,0,0,0,0,0) +def crop_intrinsics( + intrinsics: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor], + left: Union[int, torch.Tensor], + top: Union[int, torch.Tensor], + crop_width: Union[int, torch.Tensor], + crop_height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + + Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + left (int | torch.Tensor): [...] left crop boundary + top (int | torch.Tensor): [...] top crop boundary + crop_width (int | torch.Tensor): [...] crop width + crop_height (int | torch.Tensor): [...] crop height + + Returns: + (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s) + """ + zeros = torch.zeros_like(width) + ones = torch.ones_like(width) + transform = torch.stack([ + width / crop_width, zeros, -left / crop_width, + zeros, height / crop_height, -top / crop_height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3).to(intrinsics) + return transform @ intrinsics + + +@batched(1,0,0) +def pixel_to_uv( + pixel: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + if not torch.is_floating_point(pixel): + pixel = pixel.float() + uv = (pixel + 0.5) / torch.stack([width, height], dim=-1).to(pixel) + return uv + + +@batched(1,0,0) +def uv_to_pixel( + uv: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + uv (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + pixel = uv * torch.stack([width, height], dim=-1).to(uv) - 0.5 + return pixel + + +@batched(1,0,0) +def pixel_to_ndc( + pixel: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1) + """ + if not torch.is_floating_point(pixel): + pixel = pixel.float() + ndc = (pixel + 0.5) / (torch.stack([width, height], dim=-1).to(pixel) * torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device)) \ + + torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device) + return ndc + + +@batched(0,0,0) +def project_depth( + depth: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Project linear depth to depth value in screen space + + Args: + depth (torch.Tensor): [...] depth value + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + + Returns: + (torch.Tensor): [..., 1] depth value in screen space, value ranging in [0, 1] + """ + return (far - near * far / depth) / (far - near) + + +@batched(0,0,0) +def depth_buffer_to_linear( + depth: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Linearize depth value to linear depth + + Args: + depth (torch.Tensor): [...] screen depth value, ranging in [0, 1] + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + + Returns: + (torch.Tensor): [...] linear depth + """ + return near * far / (far - (far - near) * depth) + + +@batched(2, 2, 2, 2) +def project_gl( + points: torch.Tensor, + model: torch.Tensor = None, + view: torch.Tensor = None, + perspective: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D following the OpenGL convention (except for row major matrice) + + Args: + points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + + Returns: + scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (torch.Tensor): [..., N] linear depth + """ + assert perspective is not None, "perspective matrix is required" + + if points.shape[-1] == 3: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + mvp = perspective if perspective is not None else torch.eye(4).to(points) + if view is not None: + mvp = mvp @ view + if model is not None: + mvp = mvp @ model + clip_coord = points @ mvp.transpose(-1, -2) + ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:] + scr_coord = ndc_coord * 0.5 + 0.5 + linear_depth = clip_coord[..., 3] + return scr_coord, linear_depth + + +@batched(2, 2, 2) +def project_cv( + points: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D following the OpenCV convention + + Args: + points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + + Returns: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + linear_depth (torch.Tensor): [..., N] linear depth + """ + assert intrinsics is not None, "intrinsics matrix is required" + if points.shape[-1] == 3: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + if extrinsics is not None: + points = points @ extrinsics.transpose(-1, -2) + points = points[..., :3] @ intrinsics.transpose(-2, -1) + uv_coord = points[..., :2] / points[..., 2:] + linear_depth = points[..., 2] + return uv_coord, linear_depth + + +@batched(2, 2, 2, 2) +def unproject_gl( + screen_coord: torch.Tensor, + model: torch.Tensor = None, + view: torch.Tensor = None, + perspective: torch.Tensor = None + ) -> torch.Tensor: + """ + Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + + Args: + screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + + Returns: + points (torch.Tensor): [..., N, 3] 3d points + """ + assert perspective is not None, "perspective matrix is required" + ndc_xy = screen_coord * 2 - 1 + clip_coord = torch.cat([ndc_xy, torch.ones_like(ndc_xy[..., :1])], dim=-1) + transform = perspective + if view is not None: + transform = transform @ view + if model is not None: + transform = transform @ model + transform = torch.inverse(transform) + points = clip_coord @ transform.transpose(-1, -2) + points = points[..., :3] / points[..., 3:] + return points + + +@batched(2, 1, 2, 2) +def unproject_cv( + uv_coord: torch.Tensor, + depth: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> torch.Tensor: + """ + Unproject uv coordinates to 3D view space following the OpenCV convention + + Args: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (torch.Tensor): [..., N] depth value + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + + Returns: + points (torch.Tensor): [..., N, 3] 3d points + """ + assert intrinsics is not None, "intrinsics matrix is required" + points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1) + points = points @ torch.inverse(intrinsics).transpose(-2, -1) + points = points * depth[..., None] + if extrinsics is not None: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3] + return points + + +def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)]) + for c in convention + ] + # return functools.reduce(torch.matmul, matrices) + return matrices[2] @ matrices[1] @ matrices[0] + + +def skew_symmetric(v: torch.Tensor): + "Skew symmetric matrix from a 3D vector" + assert v.shape[-1] == 3, "v must be 3D" + x, y, z = v.unbind(dim=-1) + zeros = torch.zeros_like(x) + return torch.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros, + ], dim=-1).reshape(*v.shape[:-1], 3, 3) + + +def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor): + "Rotation matrix that rotates v1 to v2" + I = torch.eye(3).to(v1) + v1 = F.normalize(v1, dim=-1) + v2 = F.normalize(v2, dim=-1) + v = torch.cross(v1, v2, dim=-1) + c = torch.sum(v1 * v2, dim=-1) + K = skew_symmetric(v) + R = I + K + (1 / (1 + c))[None, None] * (K @ K) + return R + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d) + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d) + """ + if not all(c in 'XYZ' for c in convention) or not all(c in convention for c in 'XYZ'): + raise ValueError(f"Invalid convention {convention}.") + if not matrix.shape[-2:] == (3, 3): + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + i0 = 'XYZ'.index(convention[0]) + i2 = 'XYZ'.index(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin(matrix[..., i2, i0] * (-1.0 if i2 - i0 in [-1, 2] else 1.0)) + else: + central_angle = torch.acos(matrix[..., i2, i2]) + + # Angles in composition order + o = [ + _angle_from_tan( + convention[0], convention[1], matrix[..., i2, :], True, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0], False, tait_bryan + ), + ] + return torch.stack([o[convention.index(c)] for c in 'XYZ'], -1) + + +def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + + Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + + Returns: + torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters + """ + batch_shape = axis_angle.shape[:-1] + device, dtype = axis_angle.device, axis_angle.dtype + + angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True) + axis = axis_angle / angle + + cos = torch.cos(angle)[..., None, :] + sin = torch.sin(angle)[..., None, :] + + rx, ry, rz = torch.split(axis, 3, dim=-1) + zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view((*batch_shape, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device) + rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K) + return rot_mat + + +def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector) + + Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + + Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices + """ + quat = matrix_to_quaternion(rot_mat) + axis_angle = quaternion_to_axis_angle(quat, eps=eps) + return axis_angle + + +def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector) + + Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + + Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + norm = torch.norm(quaternion[..., 1:], dim=-1, keepdim=True) + axis = quaternion[..., 1:] / norm.clamp(min=eps) + angle = 2 * torch.atan2(norm, quaternion[..., 0:1]) + return angle * axis + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z) + + Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + + Returns: + torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters + """ + axis = F.normalize(axis_angle, dim=-1, eps=eps) + angle = torch.norm(axis_angle, dim=-1, keepdim=True) + quat = torch.cat([torch.cos(angle / 2), torch.sin(angle / 2) * axis], dim=-1) + return quat + + +def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + + Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + + Returns: + torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices + """ + # Extract the diagonal and off-diagonal elements of the rotation matrix + m00, m01, m02, m10, m11, m12, m20, m21, m22 = rot_mat.flatten(-2).unbind(dim=-1) + + diag = torch.diagonal(rot_mat, dim1=-2, dim2=-1) + M = torch.tensor([ + [1, 1, 1], + [1, -1, -1], + [-1, 1, -1], + [-1, -1, 1] + ], dtype=rot_mat.dtype, device=rot_mat.device) + wxyz = (1 + diag @ M.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5) + _, max_idx = wxyz.max(dim=-1) + xw = torch.sign(m21 - m12) + yw = torch.sign(m02 - m20) + zw = torch.sign(m10 - m01) + yz = torch.sign(m21 + m12) + xz = torch.sign(m02 + m20) + xy = torch.sign(m01 + m10) + ones = torch.ones_like(xw) + sign = torch.where( + max_idx[..., None] == 0, + torch.stack([ones, xw, yw, zw], dim=-1), + torch.where( + max_idx[..., None] == 1, + torch.stack([xw, ones, xy, xz], dim=-1), + torch.where( + max_idx[..., None] == 2, + torch.stack([yw, xy, ones, yz], dim=-1), + torch.stack([zw, xz, yz, ones], dim=-1) + ) + ) + ) + quat = sign * wxyz + quat = F.normalize(quat, dim=-1, eps=eps) + return quat + + +def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + + Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + + Returns: + torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + quaternion = F.normalize(quaternion, dim=-1, eps=eps) + w, x, y, z = quaternion.unbind(dim=-1) + zeros = torch.zeros_like(w) + I = torch.eye(3, dtype=quaternion.dtype, device=quaternion.device) + xyz = quaternion[..., 1:] + A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(dim=-1)[..., None, None] + B = torch.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros + ], dim=-1).unflatten(-1, (3, 3)) + rot_mat = I + 2 * (A + w[..., None, None] * B) + return rot_mat + + +def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor: + """Spherical linear interpolation between two rotation matrices + + Args: + rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix + rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix + t (torch.Tensor): scalar or shape (...,), the interpolation factor + + Returns: + torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix + """ + assert rot_mat_1.shape[-2:] == (3, 3) + rot_vec_1 = matrix_to_axis_angle(rot_mat_1) + rot_vec_2 = matrix_to_axis_angle(rot_mat_2) + if isinstance(t, Number): + t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device) + rot_vec = (1 - t[..., None]) * rot_vec_1 + t[..., None] * rot_vec_2 + rot_mat = axis_angle_to_matrix(rot_vec) + return rot_mat + + +def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor: + """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. + + Args: + ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose + ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose + t (torch.Tensor): scalar or shape (...,), the interpolation factor + + Returns: + torch.Tensor: shape (..., 4, 4), the interpolated camera pose + """ + return torch.inverse(interpolate_transform(torch.inverse(ext1), torch.inverse(ext2), t)) + + +def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]): + """Interpolate view matrices between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. + + Args: + ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose + ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose + t (torch.Tensor): scalar or shape (...,), the interpolation factor + + Returns: + torch.Tensor: shape (..., 4, 4), the interpolated camera pose + """ + return interpolate_extrinsics(view1, view2, t) + + +def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]): + assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4) + if isinstance(t, Number): + t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device) + pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3] + rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t) + transform = torch.cat([rot, pos[..., None]], dim=-1) + transform = torch.cat([ext, torch.tensor([0, 0, 0, 1], dtype=transform.dtype, device=transform.device).expand_as(transform[..., :1, :])], dim=-2) + return transform + + +def extrinsics_to_essential(extrinsics: torch.Tensor): + """ + extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + + Args: + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + + Returns: + (torch.Tensor): [..., 3, 3] essential matrix + """ + assert extrinsics.shape[-2:] == (4, 4) + R = extrinsics[..., :3, :3] + t = extrinsics[..., :3, 3] + zeros = torch.zeros_like(t) + t_x = torch.stack([ + zeros, -t[..., 2], t[..., 1], + t[..., 2], zeros, -t[..., 0], + -t[..., 1], t[..., 0], zeros + ]).reshape(*t.shape[:-1], 3, 3) + return R @ t_x + + +def to4x4(R: torch.Tensor, t: torch.Tensor): + """ + Compose rotation matrix and translation vector to 4x4 transformation matrix + + Args: + R (torch.Tensor): [..., 3, 3] rotation matrix + t (torch.Tensor): [..., 3] translation vector + + Returns: + (torch.Tensor): [..., 4, 4] transformation matrix + """ + assert R.shape[-2:] == (3, 3) + assert t.shape[-1] == 3 + assert R.shape[:-2] == t.shape[:-1] + return torch.cat([ + torch.cat([R, t[..., None]], dim=-1), + torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device).expand(*R.shape[:-2], 1, 4) + ], dim=-2) + + +def rotation_matrix_2d(theta: Union[float, torch.Tensor]): + """ + 2x2 matrix for 2D rotation + + Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + + Returns: + (torch.Tensor): (..., 2, 2) rotation matrix + """ + if isinstance(theta, float): + theta = torch.tensor(theta) + return torch.stack([ + torch.cos(theta), -torch.sin(theta), + torch.sin(theta), torch.cos(theta), + ], dim=-1).unflatten(-1, (2, 2)) + + +def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None): + """ + 3x3 matrix for 2D rotation around a center + ``` + [[Rxx, Rxy, tx], + [Ryx, Ryy, ty], + [0, 0, 1]] + ``` + Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + if isinstance(theta, float): + theta = torch.tensor(theta) + if center is not None: + theta = theta.to(center) + if center is None: + center = torch.zeros(2).to(theta).expand(*theta.shape, -1) + R = rotation_matrix_2d(theta) + return torch.cat([ + torch.cat([ + R, + center[..., :, None] - R @ center[..., :, None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=center.dtype, device=center.device).expand(*center.shape[:-1], -1, -1), + ], dim=-2) + + +def translate_2d(translation: torch.Tensor): + """ + Translation matrix for 2D translation + ``` + [[1, 0, tx], + [0, 1, ty], + [0, 0, 1]] + ``` + Args: + translation (torch.Tensor): translation vector, arbitrary shape (..., 2) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + return torch.cat([ + torch.cat([ + torch.eye(2, dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1), + translation[..., None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1), + ], dim=-2) + + +def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None): + """ + Scale matrix for 2D scaling + ``` + [[s, 0, tx], + [0, s, ty], + [0, 0, 1]] + ``` + Args: + scale (float | torch.Tensor): scale factor, arbitrary shape (...,) + center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + if isinstance(scale, float): + scale = torch.tensor(scale) + if center is not None: + scale = scale.to(center) + if center is None: + center = torch.zeros(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1) + return torch.cat([ + torch.cat([ + scale * torch.eye(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape[:-1], -1, -1), + center[..., :, None] - center[..., :, None] * scale[..., None, None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=scale.dtype, device=scale.device).expand(*center.shape[:-1], -1, -1), + ], dim=-2) + + +def apply_2d(transform: torch.Tensor, points: torch.Tensor): + """ + Apply (3x3 or 2x3) 2D affine transformation to points + ``` + p = R @ p + t + ``` + Args: + transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix + points (torch.Tensor): (..., N, 2) points to transform + + Returns: + (torch.Tensor): (..., N, 2) transformed points + """ + assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3" + assert points.shape[-1] == 2, "points must be 2D" + return points @ transform[..., :2, :2].mT + transform[..., :2, None, 2] \ No newline at end of file diff --git a/submodules/MoGe/utils3d/torch/utils.py b/submodules/MoGe/utils3d/torch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..877ffb8a60a7f5206fbeb5a9e4a584758b875da4 --- /dev/null +++ b/submodules/MoGe/utils3d/torch/utils.py @@ -0,0 +1,351 @@ +from typing import * + +import torch +import torch.nn.functional as F + +from . import transforms +from . import mesh +from ._helpers import batched + + +__all__ = [ + 'sliding_window_1d', + 'sliding_window_2d', + 'sliding_window_nd', + 'image_uv', + 'image_pixel_center', + 'image_mesh', + 'chessboard', + 'depth_edge', + 'depth_aliasing', + 'image_mesh_from_depth', + 'point_to_normal', + 'depth_to_normal', + 'masked_min', + 'masked_max', + 'bounding_rect' +] + + +def sliding_window_1d(x: torch.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch.Tensor: + """ + Sliding window view of the input tensor. The dimension of the sliding window is appended to the end of the input tensor's shape. + NOTE: Since Pytorch has `unfold` function, 1D sliding window view is just a wrapper of it. + """ + return x.unfold(dim, window_size, stride) + + +def sliding_window_nd(x: torch.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch.Tensor: + dim = [dim[i] % x.ndim for i in range(len(dim))] + assert len(window_size) == len(stride) == len(dim) + for i in range(len(window_size)): + x = sliding_window_1d(x, window_size[i], stride[i], dim[i]) + return x + + +def sliding_window_2d(x: torch.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch.Tensor: + if isinstance(window_size, int): + window_size = (window_size, window_size) + if isinstance(stride, int): + stride = (stride, stride) + return sliding_window_nd(x, window_size, stride, dim) + + +def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch.device = None, dtype: torch.dtype = None) -> torch.Tensor: + """ + Get image space UV grid, ranging in [0, 1]. + + >>> image_uv(10, 10): + [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype) + v = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + + +def image_pixel_center( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: torch.dtype = None, + device: torch.device = None +) -> torch.Tensor: + """ + Get image pixel center coordinates, ranging in [0, width] and [0, height]. + `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + + >>> image_pixel_center(10, 10): + [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... + [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = torch.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype, device=device) + v = torch.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + return torch.stack([u, v], dim=2) + + +def image_mesh(height: int, width: int, mask: torch.Tensor = None, device: torch.device = None, dtype: torch.dtype = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get a quad mesh regarding image pixel uv coordinates as vertices and image grid as faces. + + Args: + width (int): image width + height (int): image height + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + + Returns: + uv (np.ndarray): uv corresponding to pixels as described in image_uv() + faces (np.ndarray): quad faces connecting neighboring pixels + indices (np.ndarray, optional): indices of vertices in the original mesh + """ + if device is None and mask is not None: + device = mask.device + if mask is not None: + assert mask.shape[0] == height and mask.shape[1] == width + assert mask.dtype == torch.bool + uv = image_uv(height, width, device=device, dtype=dtype).reshape((-1, 2)) + row_faces = torch.stack([ + torch.arange(0, width - 1, dtype=torch.int32, device=device), + torch.arange(width, 2 * width - 1, dtype=torch.int32, device=device), + torch.arange(1 + width, 2 * width, dtype=torch.int32, device=device), + torch.arange(1, width, dtype=torch.int32, device=device) + ], dim=1) + faces = (torch.arange(0, (height - 1) * width, width, device=device, dtype=torch.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4)) + if mask is not None: + quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel() + faces = faces[quad_mask] + faces, uv, indices = mesh.remove_unreferenced_vertices(faces, uv, return_indices=True) + return uv, faces, indices + return uv, faces + + +def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor: + """ + Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth. + + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool + """ + shape = depth.shape + depth = depth.reshape(-1, 1, *shape[-2:]) + if mask is not None: + mask = mask.reshape(-1, 1, *shape[-2:]) + + if mask is None: + diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) + else: + diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)) + + edge = torch.zeros_like(depth, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / depth).nan_to_num_() > rtol + edge = edge.reshape(*shape) + return edge + + +def depth_aliasing(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor: + """ + Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool + """ + shape = depth.shape + depth = depth.reshape(-1, 1, *shape[-2:]) + if mask is not None: + mask = mask.reshape(-1, 1, *shape[-2:]) + + if mask is None: + diff_max = F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth + else: + diff_max = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth + diff = torch.minimum(diff_max, diff_min) + + edge = torch.zeros_like(depth, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / depth).nan_to_num_() > rtol + edge = edge.reshape(*shape) + return edge + + +def image_mesh_from_depth( + depth: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + height, width = depth.shape + uv, faces = image_mesh(height, width) + faces = faces.reshape(-1, 4) + depth = depth.reshape(-1) + pts = transforms.unproject_cv(image_uv, depth, extrinsics, intrinsics) + faces = mesh.triangulate(faces, vertices=pts) + return pts, faces + + +@batched(3, 2, 2) +def point_to_normal(point: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + point (torch.Tensor): shape (..., height, width, 3), point map + Returns: + normal (torch.Tensor): shape (..., height, width, 3), normal map. + """ + has_mask = mask is not None + + if mask is None: + mask = torch.ones_like(point[..., 0], dtype=torch.bool) + mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) + + pts = F.pad(point.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='constant', value=1).permute(0, 2, 3, 1) + up = pts[:, :-2, 1:-1, :] - pts[:, 1:-1, 1:-1, :] + left = pts[:, 1:-1, :-2, :] - pts[:, 1:-1, 1:-1, :] + down = pts[:, 2:, 1:-1, :] - pts[:, 1:-1, 1:-1, :] + right = pts[:, 1:-1, 2:, :] - pts[:, 1:-1, 1:-1, :] + normal = torch.stack([ + torch.cross(up, left, dim=-1), + torch.cross(left, down, dim=-1), + torch.cross(down, right, dim=-1), + torch.cross(right, up, dim=-1), + ]) + normal = F.normalize(normal, dim=-1) + valid = torch.stack([ + mask[:, :-2, 1:-1] & mask[:, 1:-1, :-2], + mask[:, 1:-1, :-2] & mask[:, 2:, 1:-1], + mask[:, 2:, 1:-1] & mask[:, 1:-1, 2:], + mask[:, 1:-1, 2:] & mask[:, :-2, 1:-1], + ]) & mask[None, :, 1:-1, 1:-1] + normal = (normal * valid[..., None]).sum(dim=0) + normal = F.normalize(normal, dim=-1) + + if has_mask: + return normal, valid.any(dim=0) + else: + return normal + + +@batched(2, 2, 2) +def depth_to_normal(depth: torch.Tensor, intrinsics: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix + Returns: + normal (torch.Tensor): shape (..., 3, height, width), normal map. + """ + has_mask = mask is not None + + height, width = depth.shape[-2:] + if mask is None: + mask = torch.ones_like(depth, dtype=torch.bool) + mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) + + uv = image_uv(*depth.shape[-2:]).unsqueeze(0).to(depth) + pts = transforms.unproject_cv(uv.reshape(-1, 2), depth.flatten(-2), intrinsics=intrinsics, extrinsics=None).unflatten(-2, (height, width)) + + return point_to_normal(pts, mask) + + +def masked_min(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Similar to torch.min, but with mask + """ + if dim is None: + return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min() + else: + return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min(dim=dim, keepdim=keepdim) + + +def masked_max(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Similar to torch.max, but with mask + """ + if dim is None: + return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max() + else: + return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max(dim=dim, keepdim=keepdim) + + +def bounding_rect(mask: torch.BoolTensor): + """get bounding rectangle of a mask + + Args: + mask (torch.Tensor): shape (..., height, width), mask + + Returns: + rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom) + """ + height, width = mask.shape[-2:] + mask = mask.flatten(-2).unsqueeze(-1) + uv = image_uv(height, width).to(mask.device).reshape(-1, 2) + left_top = masked_min(uv, mask, dim=-2)[0] + right_bottom = masked_max(uv, mask, dim=-2)[0] + return torch.cat([left_top, right_bottom], dim=-1) + + +def chessboard(width: int, height: int, grid_size: int, color_a: torch.Tensor, color_b: torch.Tensor) -> torch.Tensor: + """get a chessboard image + + Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (torch.Tensor): shape (chanenls,), color of the grid at the top-left corner + color_b (torch.Tensor): shape (chanenls,), color in complementary grids + + Returns: + image (torch.Tensor): shape (height, width, channels), chessboard image + """ + x = torch.div(torch.arange(width), grid_size, rounding_mode='floor') + y = torch.div(torch.arange(height), grid_size, rounding_mode='floor') + mask = ((x[None, :] + y[:, None]) % 2).to(color_a) + image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b + return image \ No newline at end of file diff --git a/testing/evaluation.py b/testing/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..b58fdd9a0f80f1139bcf1d93b792d2ede71b7f6d --- /dev/null +++ b/testing/evaluation.py @@ -0,0 +1,691 @@ +import argparse +from typing import Any, Dict, List, Literal, Tuple +import pandas as pd +import os +import sys + +import torch +from diffusers import ( + CogVideoXPipeline, + CogVideoXDDIMScheduler, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXVideoToVideoPipeline, +) + +from diffusers.utils import export_to_video, load_image, load_video + +import numpy as np +import random +import cv2 +from pathlib import Path +import decord +from torchvision import transforms +from torchvision.transforms.functional import resize + +import PIL.Image +from PIL import Image + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.join(current_dir, '..')) +from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking, CogVideoXPipelineTracking, CogVideoXVideoToVideoPipelineTracking +from training.dataset import VideoDataset, VideoDatasetWithResizingTracking + +class VideoDatasetWithResizingTrackingEval(VideoDataset): + def __init__(self, *args, **kwargs) -> None: + self.tracking_column = kwargs.pop("tracking_column", None) + self.image_paths = kwargs.pop("image_paths", None) + super().__init__(*args, **kwargs) + + def _preprocess_video(self, path: Path, tracking_path: Path, image_paths: Path = None) -> torch.Tensor: + if self.load_tensors: + return self._load_preprocessed_latents_and_embeds(path, tracking_path) + else: + video_reader = decord.VideoReader(uri=path.as_posix()) + video_num_frames = len(video_reader) + nearest_frame_bucket = min( + self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + ) + + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) + frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) + frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) + + image = Image.open(image_paths) + if image.mode != 'RGB': + image = image.convert('RGB') + + image = torch.from_numpy(np.array(image)).float() + image = image.permute(2, 0, 1).contiguous() + image = resize(image, nearest_res) + image = self.video_transforms(image) + + tracking_reader = decord.VideoReader(uri=tracking_path.as_posix()) + tracking_frames = tracking_reader.get_batch(frame_indices) + tracking_frames = tracking_frames[:nearest_frame_bucket].float() + tracking_frames = tracking_frames.permute(0, 3, 1, 2).contiguous() + tracking_frames_resized = torch.stack([resize(tracking_frame, nearest_res) for tracking_frame in tracking_frames], dim=0) + tracking_frames = torch.stack([self.video_transforms(tracking_frame) for tracking_frame in tracking_frames_resized], dim=0) + + return image, frames, tracking_frames, None + + def _find_nearest_resolution(self, height, width): + nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) + return nearest_res[1], nearest_res[2] + + def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str], List[str]]: + if not self.data_root.exists(): + raise ValueError("Root folder for videos does not exist") + + prompt_path = self.data_root.joinpath(self.caption_column) + video_path = self.data_root.joinpath(self.video_column) + tracking_path = self.data_root.joinpath(self.tracking_column) + image_paths = self.data_root.joinpath(self.image_paths) + + if not prompt_path.exists() or not prompt_path.is_file(): + raise ValueError( + "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." + ) + if not video_path.exists() or not video_path.is_file(): + raise ValueError( + "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory." + ) + if not tracking_path.exists() or not tracking_path.is_file(): + raise ValueError( + "Expected `--tracking_column` to be path to a file in `--data_root` containing line-separated tracking information." + ) + + with open(prompt_path, "r", encoding="utf-8") as file: + prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] + with open(video_path, "r", encoding="utf-8") as file: + video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] + with open(tracking_path, "r", encoding="utf-8") as file: + tracking_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] + + with open(image_paths, "r", encoding="utf-8") as file: + image_paths_list = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] + + if not self.load_tensors and any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + self.tracking_paths = tracking_paths + self.image_paths = image_paths_list + return prompts, video_paths + + def _load_dataset_from_csv(self) -> Tuple[List[str], List[str], List[str]]: + df = pd.read_csv(self.dataset_file) + prompts = df[self.caption_column].tolist() + video_paths = df[self.video_column].tolist() + tracking_paths = df[self.tracking_column].tolist() + image_paths = df[self.image_paths].tolist() + video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] + tracking_paths = [self.data_root.joinpath(line.strip()) for line in tracking_paths] + image_paths = [self.data_root.joinpath(line.strip()) for line in image_paths] + + if any(not path.is_file() for path in video_paths): + raise ValueError( + f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file." + ) + + self.tracking_paths = tracking_paths + self.image_paths = image_paths + return prompts, video_paths + + def __getitem__(self, index: int) -> Dict[str, Any]: + if isinstance(index, list): + return index + + if self.load_tensors: + image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index], self.tracking_paths[index]) + + # The VAE's temporal compression ratio is 4. + # The VAE's spatial compression ratio is 8. + latent_num_frames = video_latents.size(1) + if latent_num_frames % 2 == 0: + num_frames = latent_num_frames * 4 + else: + num_frames = (latent_num_frames - 1) * 4 + 1 + + height = video_latents.size(2) * 8 + width = video_latents.size(3) * 8 + + return { + "prompt": prompt_embeds, + "image": image_latents, + "video": video_latents, + "tracking_map": tracking_map, + "video_metadata": { + "num_frames": num_frames, + "height": height, + "width": width, + }, + } + else: + image, video, tracking_map, _ = self._preprocess_video(self.video_paths[index], self.tracking_paths[index], self.image_paths[index]) + + return { + "prompt": self.id_token + self.prompts[index], + "image": image, + "video": video, + "tracking_map": tracking_map, + "video_metadata": { + "num_frames": video.shape[0], + "height": video.shape[2], + "width": video.shape[3], + }, + } + + def _load_preprocessed_latents_and_embeds(self, path: Path, tracking_path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + filename_without_ext = path.name.split(".")[0] + pt_filename = f"{filename_without_ext}.pt" + + # The current path is something like: /a/b/c/d/videos/00001.mp4 + # We need to reach: /a/b/c/d/video_latents/00001.pt + image_latents_path = path.parent.parent.joinpath("image_latents") + video_latents_path = path.parent.parent.joinpath("video_latents") + tracking_map_path = path.parent.parent.joinpath("tracking_map") + embeds_path = path.parent.parent.joinpath("prompt_embeds") + + if ( + not video_latents_path.exists() + or not embeds_path.exists() + or not tracking_map_path.exists() + or (self.image_to_video and not image_latents_path.exists()) + ): + raise ValueError( + f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains folders named `video_latents`, `prompt_embeds`, and `tracking_map`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_data.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present." + ) + + if self.image_to_video: + image_latent_filepath = image_latents_path.joinpath(pt_filename) + video_latent_filepath = video_latents_path.joinpath(pt_filename) + tracking_map_filepath = tracking_map_path.joinpath(pt_filename) + embeds_filepath = embeds_path.joinpath(pt_filename) + + if not video_latent_filepath.is_file() or not embeds_filepath.is_file() or not tracking_map_filepath.is_file(): + if self.image_to_video: + image_latent_filepath = image_latent_filepath.as_posix() + video_latent_filepath = video_latent_filepath.as_posix() + tracking_map_filepath = tracking_map_filepath.as_posix() + embeds_filepath = embeds_filepath.as_posix() + raise ValueError( + f"The file {video_latent_filepath=} or {embeds_filepath=} or {tracking_map_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." + ) + + images = ( + torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None + ) + latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) + tracking_map = torch.load(tracking_map_filepath, map_location="cpu", weights_only=True) + embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) + + return images, latents, tracking_map, embeds + +def sample_from_dataset( + data_root: str, + caption_column: str, + tracking_column: str, + image_paths: str, + video_column: str, + num_samples: int = -1, + random_seed: int = 42 +): + """Sample from dataset""" + if image_paths: + # If image_paths is provided, use VideoDatasetWithResizingTrackingEval + dataset = VideoDatasetWithResizingTrackingEval( + data_root=data_root, + caption_column=caption_column, + tracking_column=tracking_column, + image_paths=image_paths, + video_column=video_column, + max_num_frames=49, + load_tensors=False, + random_flip=None, + frame_buckets=[49], + image_to_video=True + ) + else: + # If image_paths is not provided, use VideoDatasetWithResizingTracking + dataset = VideoDatasetWithResizingTracking( + data_root=data_root, + caption_column=caption_column, + tracking_column=tracking_column, + video_column=video_column, + max_num_frames=49, + load_tensors=False, + random_flip=None, + frame_buckets=[49], + image_to_video=True + ) + + # Set random seed + random.seed(random_seed) + + # Randomly sample from dataset + total_samples = len(dataset) + if num_samples == -1: + # If num_samples is -1, process all samples + selected_indices = range(total_samples) + else: + selected_indices = random.sample(range(total_samples), min(num_samples, total_samples)) + + samples = [] + for idx in selected_indices: + sample = dataset[idx] + # Get data based on dataset.__getitem__ return value + image = sample["image"] # Already processed tensor + video = sample["video"] # Already processed tensor + tracking_map = sample["tracking_map"] # Already processed tensor + prompt = sample["prompt"] + + samples.append({ + "prompt": prompt, + "tracking_frame": tracking_map[0], # Get first frame + "video_frame": image, # Get first frame + "video": video, # Complete video + "tracking_maps": tracking_map, # Complete tracking maps + "height": sample["video_metadata"]["height"], + "width": sample["video_metadata"]["width"] + }) + + return samples + +def generate_video( + prompt: str, + model_path: str, + tracking_path: str = None, + output_path: str = "./output.mp4", + image_or_video_path: str = "", + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int = 1, + dtype: torch.dtype = torch.bfloat16, + generate_type: str = Literal["i2v", "i2vo"], # i2v: image to video, i2vo: original CogVideoX-5b-I2V + seed: int = 42, + data_root: str = None, + caption_column: str = None, + tracking_column: str = None, + video_column: str = None, + image_paths: str = None, + num_samples: int = -1, + evaluation_dir: str = "evaluations", + fps: int = 8, +): + device = "cuda" if torch.cuda.is_available() else "cpu" + + # If dataset parameters are provided, sample from dataset + samples = None + if all([data_root, caption_column, tracking_column, video_column]): + samples = sample_from_dataset( + data_root=data_root, + caption_column=caption_column, + tracking_column=tracking_column, + image_paths=image_paths, + video_column=video_column, + num_samples=num_samples, + random_seed=seed + ) + + # Load model and data + if generate_type == "i2v": + pipe = CogVideoXImageToVideoPipelineTracking.from_pretrained(model_path, torch_dtype=dtype) + if not samples: + image = load_image(image=image_or_video_path) + height, width = image.height, image.width + else: + pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=dtype) + if not samples: + image = load_image(image=image_or_video_path) + height, width = image.height, image.width + + # Set model parameters + pipe.to(device, dtype=dtype) + pipe.vae.enable_slicing() + pipe.vae.enable_tiling() + pipe.transformer.eval() + pipe.text_encoder.eval() + pipe.vae.eval() + pipe.transformer.gradient_checkpointing = False + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + + # Generate video + if samples: + from tqdm import tqdm + for i, sample in tqdm(enumerate(samples), desc="Samples Num:"): + print(f"Prompt: {sample['prompt'][:30]}") + tracking_frame = sample["tracking_frame"].to(device=device, dtype=dtype) + video_frame = sample["video_frame"].to(device=device, dtype=dtype) + video = sample["video"].to(device=device, dtype=dtype) + tracking_maps = sample["tracking_maps"].to(device=device, dtype=dtype) + + # VAE + print("encoding tracking maps") + tracking_video = tracking_maps + tracking_maps = tracking_maps.unsqueeze(0) + tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + with torch.no_grad(): + tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist + tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor + tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + + + pipeline_args = { + "prompt": sample["prompt"], + "negative_prompt": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.", + "num_inference_steps": num_inference_steps, + "num_frames": 49, + "use_dynamic_cfg": True, + "guidance_scale": guidance_scale, + "generator": torch.Generator(device=device).manual_seed(seed), + "height": sample["height"], + "width": sample["width"] + } + + pipeline_args["image"] = (video_frame + 1.0) / 2.0 + + if tracking_column and generate_type == "i2v": + pipeline_args["tracking_maps"] = tracking_maps + pipeline_args["tracking_image"] = (tracking_frame.unsqueeze(0) + 1.0) / 2.0 + + with torch.no_grad(): + video_generate = pipe(**pipeline_args).frames[0] + + output_dir = os.path.join(data_root, evaluation_dir) + output_name = f"{i:04d}.mp4" + output_file = os.path.join(output_dir, output_name) + os.makedirs(output_dir, exist_ok=True) + export_concat_video(video_generate, video, tracking_video, output_file, fps=fps) + + else: + pipeline_args = { + "prompt": prompt, + "num_videos_per_prompt": num_videos_per_prompt, + "num_inference_steps": num_inference_steps, + "num_frames": 49, + "use_dynamic_cfg": True, + "guidance_scale": guidance_scale, + "generator": torch.Generator().manual_seed(seed), + } + + pipeline_args["video"] = video + pipeline_args["image"] = image + pipeline_args["height"] = height + pipeline_args["width"] = width + + if tracking_path and generate_type == "i2v": + tracking_maps = load_video(tracking_path) + tracking_maps = torch.stack([ + torch.from_numpy(np.array(frame)).permute(2, 0, 1).float() / 255.0 + for frame in tracking_maps + ]).to(device=device, dtype=dtype) + + tracking_video = tracking_maps + tracking_maps = tracking_maps.unsqueeze(0) + tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) + with torch.no_grad(): + tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist + tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor + tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) + + pipeline_args["tracking_maps"] = tracking_maps + pipeline_args["tracking_image"] = tracking_maps[:, :1] + + with torch.no_grad(): + video_generate = pipe(**pipeline_args).frames[0] + + output_dir = os.path.join(data_root, evaluation_dir) + output_name = f"{os.path.splitext(os.path.basename(image_or_video_path))[0]}.mp4" + output_file = os.path.join(output_dir, output_name) + os.makedirs(output_dir, exist_ok=True) + export_concat_video(video_generate, video, tracking_video, output_file, fps=fps) + +def create_frame_grid(frames: List[np.ndarray], interval: int = 9, max_cols: int = 7) -> np.ndarray: + """ + Arrange video frames into a grid image by sampling at intervals + + Args: + frames: List of video frames + interval: Sampling interval + max_cols: Maximum number of frames per row + + Returns: + Grid image array + """ + # Sample frames at intervals + sampled_frames = frames[::interval] + + # Calculate number of rows and columns + n_frames = len(sampled_frames) + n_cols = min(max_cols, n_frames) + n_rows = (n_frames + n_cols - 1) // n_cols + + # Get height and width of single frame + frame_height, frame_width = sampled_frames[0].shape[:2] + + # Create blank canvas + grid = np.zeros((frame_height * n_rows, frame_width * n_cols, 3), dtype=np.uint8) + + # Fill frames + for idx, frame in enumerate(sampled_frames): + i = idx // n_cols + j = idx % n_cols + grid[i*frame_height:(i+1)*frame_height, j*frame_width:(j+1)*frame_width] = frame + + return grid + +def export_concat_video( + generated_frames: List[PIL.Image.Image], + original_video: torch.Tensor, + tracking_maps: torch.Tensor = None, + output_video_path: str = None, + fps: int = 8 +) -> str: + """ + Export generated video frames, original video and tracking maps as video files, + and save sampled frames to different folders + """ + import imageio + import os + + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + # Create subdirectories + base_dir = os.path.dirname(output_video_path) + generated_dir = os.path.join(base_dir, "generated") # For storing generated videos + group_dir = os.path.join(base_dir, "group") # For storing concatenated videos + + # Get filename (without path) and create video-specific folder + filename = os.path.basename(output_video_path) + name_without_ext = os.path.splitext(filename)[0] + video_frames_dir = os.path.join(base_dir, "frames", name_without_ext) # frames/video_name/ + + # Create three subdirectories under video-specific folder + groundtruth_dir = os.path.join(video_frames_dir, "gt") + generated_frames_dir = os.path.join(video_frames_dir, "generated") + tracking_dir = os.path.join(video_frames_dir, "tracking") + + # Create all required directories + os.makedirs(generated_dir, exist_ok=True) + os.makedirs(group_dir, exist_ok=True) + os.makedirs(groundtruth_dir, exist_ok=True) + os.makedirs(generated_frames_dir, exist_ok=True) + os.makedirs(tracking_dir, exist_ok=True) + + # Convert original video tensor to numpy array and adjust format + original_frames = [] + for frame in original_video: + frame = frame.permute(1,2,0).to(dtype=torch.float32,device="cpu").numpy() + frame = ((frame + 1.0) * 127.5).astype(np.uint8) + original_frames.append(frame) + + tracking_frames = [] + if tracking_maps is not None: + for frame in tracking_maps: + frame = frame.permute(1,2,0).to(dtype=torch.float32,device="cpu").numpy() + frame = ((frame + 1.0) * 127.5).astype(np.uint8) + tracking_frames.append(frame) + + # Ensure all videos have same number of frames + num_frames = min(len(generated_frames), len(original_frames)) + if tracking_maps is not None: + num_frames = min(num_frames, len(tracking_frames)) + + generated_frames = generated_frames[:num_frames] + original_frames = original_frames[:num_frames] + if tracking_maps is not None: + tracking_frames = tracking_frames[:num_frames] + + # Convert generated PIL images to numpy arrays + generated_frames_np = [np.array(frame) for frame in generated_frames] + + # Save generated video separately to generated folder + gen_video_path = os.path.join(generated_dir, f"{name_without_ext}_generated.mp4") + with imageio.get_writer(gen_video_path, fps=fps) as writer: + for frame in generated_frames_np: + writer.append_data(frame) + + # Concatenate frames vertically and save sampled frames + concat_frames = [] + for i in range(num_frames): + gen_frame = generated_frames_np[i] + orig_frame = original_frames[i] + + width = min(gen_frame.shape[1], orig_frame.shape[1]) + height = orig_frame.shape[0] + + gen_frame = Image.fromarray(gen_frame).resize((width, height)) + gen_frame = np.array(gen_frame) + orig_frame = Image.fromarray(orig_frame).resize((width, height)) + orig_frame = np.array(orig_frame) + + if tracking_maps is not None: + track_frame = tracking_frames[i] + track_frame = Image.fromarray(track_frame).resize((width, height)) + track_frame = np.array(track_frame) + + right_concat = np.concatenate([orig_frame, track_frame], axis=0) + + right_concat_pil = Image.fromarray(right_concat) + new_height = right_concat.shape[0] // 2 + new_width = right_concat.shape[1] // 2 + right_concat_resized = right_concat_pil.resize((new_width, new_height)) + right_concat_resized = np.array(right_concat_resized) + + concat_frame = np.concatenate([gen_frame, right_concat_resized], axis=1) + else: + orig_frame_pil = Image.fromarray(orig_frame) + new_height = orig_frame.shape[0] // 2 + new_width = orig_frame.shape[1] // 2 + orig_frame_resized = orig_frame_pil.resize((new_width, new_height)) + orig_frame_resized = np.array(orig_frame_resized) + + concat_frame = np.concatenate([gen_frame, orig_frame_resized], axis=1) + + concat_frames.append(concat_frame) + + # Save every 9 frames of each type of frame + if i % 9 == 0: + # Save generated frame + gen_frame_path = os.path.join(generated_frames_dir, f"{i:04d}.png") + Image.fromarray(gen_frame).save(gen_frame_path) + + # Save original frame + gt_frame_path = os.path.join(groundtruth_dir, f"{i:04d}.png") + Image.fromarray(orig_frame).save(gt_frame_path) + + # If tracking maps, save tracking frame + if tracking_maps is not None: + track_frame_path = os.path.join(tracking_dir, f"{i:04d}.png") + Image.fromarray(track_frame).save(track_frame_path) + + # Export concatenated video to group folder + group_video_path = os.path.join(group_dir, filename) + with imageio.get_writer(group_video_path, fps=fps) as writer: + for frame in concat_frames: + writer.append_data(frame) + + return group_video_path + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") + parser.add_argument("--prompt", type=str, help="Optional: override the prompt from dataset") + parser.add_argument( + "--image_or_video_path", + type=str, + default=None, + help="The path of the image to be used as the background of the video", + ) + parser.add_argument( + "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used" + ) + parser.add_argument( + "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved" + ) + parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") + parser.add_argument( + "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process" + ) + parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") + parser.add_argument( + "--generate_type", type=str, default="i2v", help="The type of video generation (e.g., 'i2v', 'i2vo')" + ) + parser.add_argument( + "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')" + ) + parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility") + parser.add_argument("--tracking_path", type=str, default=None, help="The path of the tracking maps to be used") + + # Dataset related parameters are required + parser.add_argument("--data_root", type=str, required=True, help="Root directory of the dataset") + parser.add_argument("--caption_column", type=str, required=True, help="Name of the caption column") + parser.add_argument("--tracking_column", type=str, required=True, help="Name of the tracking column") + parser.add_argument("--video_column", type=str, required=True, help="Name of the video column") + parser.add_argument("--image_paths", type=str, required=False, help="Name of the image column") + + # Add num_samples parameter + parser.add_argument("--num_samples", type=int, default=-1, + help="Number of samples to process. -1 means process all samples") + + # Add evaluation_dir parameter + parser.add_argument("--evaluation_dir", type=str, default="evaluations", + help="Name of the directory to store evaluation results") + + # Add fps parameter + parser.add_argument("--fps", type=int, default=8, + help="Frames per second for the output video") + + args = parser.parse_args() + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 + + # If prompt is not provided, generate_video function will use prompts from dataset + generate_video( + prompt=args.prompt, # Can be None + model_path=args.model_path, + tracking_path=args.tracking_path, + image_paths=args.image_paths, + output_path=args.output_path, + image_or_video_path=args.image_or_video_path, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + num_videos_per_prompt=args.num_videos_per_prompt, + dtype=dtype, + generate_type=args.generate_type, + seed=args.seed, + data_root=args.data_root, + caption_column=args.caption_column, + tracking_column=args.tracking_column, + video_column=args.video_column, + num_samples=args.num_samples, + evaluation_dir=args.evaluation_dir, + fps=args.fps, + ) \ No newline at end of file diff --git a/testing/inference.py b/testing/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..65ee4211a2071b348411a02580567a71e1ac3769 --- /dev/null +++ b/testing/inference.py @@ -0,0 +1,210 @@ +import argparse +from typing import Literal +import os +import sys + +import torch +from diffusers import ( + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, +) + +from diffusers.utils import export_to_video, load_image, load_video + +import numpy as np + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.join(current_dir, '..')) +from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking, CogVideoXPipelineTracking, CogVideoXVideoToVideoPipelineTracking +from models.cogvideox_tracking import CogVideoXTransformer3DModelTracking + +def generate_video( + prompt: str, + model_path: str, + tracking_path: str = None, + tracking_video: torch.Tensor = None, + output_path: str = "./output.mp4", + image_or_video_path: str = "", + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int = 1, + dtype: torch.dtype = torch.bfloat16, + generate_type: str = Literal["t2v", "i2v"], # i2v: image to video, i2vo: original CogVideoX-5b-I2V + fps: int = 24, + seed: int = 42, +): + """ + Generates a video based on the given prompt and saves it to the specified path. + + Parameters: + - prompt (str): The description of the video to be generated. + - model_path (str): The path of the pre-trained model to be used. + - tracking_path (str): The path of the tracking maps to be used. + - output_path (str): The path where the generated video will be saved. + - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality. + - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt. + - num_videos_per_prompt (int): Number of videos to generate per prompt. + - dtype (torch.dtype): The data type for computation (default is torch.bfloat16). + - generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').· + - seed (int): The seed for reproducibility. + """ + + # 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16). + # add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload() + # function to use Multi GPUs. + + image = None + video = None + device = "cuda" if torch.cuda.is_available() else "cpu" + + # transformer = CogVideoXTransformer3DModelTracking.from_pretrained( + # model_path, + # subfolder="transformer", + # torch_dtype=dtype + # ) + + if generate_type == "i2v": + pipe = CogVideoXImageToVideoPipelineTracking.from_pretrained(model_path, torch_dtype=dtype) + image = load_image(image=image_or_video_path) + height, width = image.height, image.width + else: + pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=dtype) + image = load_image(image=image_or_video_path) + height, width = image.height, image.width + + pipe.transformer.eval() + pipe.text_encoder.eval() + pipe.vae.eval() + + for param in pipe.transformer.parameters(): + param.requires_grad = False + + pipe.transformer.gradient_checkpointing = False + + # Convert tracking maps from list of PIL Images to tensor + if tracking_path is not None: + tracking_maps = load_video(tracking_path) + # Convert list of PIL Images to tensor [T, C, H, W] + tracking_maps = torch.stack([ + torch.from_numpy(np.array(frame)).permute(2, 0, 1).float() / 255.0 + for frame in tracking_maps + ]) + tracking_maps = tracking_maps.to(device=device, dtype=dtype) + tracking_first_frame = tracking_maps[0:1] # Get first frame as [1, C, H, W] + height, width = tracking_first_frame.shape[2], tracking_first_frame.shape[3] + elif tracking_video is not None: + tracking_maps = tracking_video.float() / 255.0 # [T, C, H, W] + tracking_maps = tracking_maps.to(device=device, dtype=dtype) + tracking_first_frame = tracking_maps[0:1] # Get first frame as [1, C, H, W] + height, width = tracking_first_frame.shape[2], tracking_first_frame.shape[3] + else: + tracking_maps = None + tracking_first_frame = None + + # 2. Set Scheduler. + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + + pipe.to(device, dtype=dtype) + # pipe.enable_sequential_cpu_offload() + + pipe.vae.enable_slicing() + pipe.vae.enable_tiling() + pipe.transformer.eval() + pipe.text_encoder.eval() + pipe.vae.eval() + + pipe.transformer.gradient_checkpointing = False + + if tracking_maps is not None and generate_type == "i2v": + print("Encoding tracking maps") + tracking_maps = tracking_maps.unsqueeze(0) # [B, T, C, H, W] + tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] + with torch.no_grad(): + tracking_latent_dist = pipe.vae.encode(tracking_maps).latent_dist + tracking_maps = tracking_latent_dist.sample() * pipe.vae.config.scaling_factor + tracking_maps = tracking_maps.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + else: + tracking_maps = None + tracking_first_frame = None + + # 4. Generate the video frames based on the prompt. + if generate_type == "i2v": + with torch.no_grad(): + video_generate = pipe( + prompt=prompt, + negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.", + image=image, + num_videos_per_prompt=num_videos_per_prompt, + num_inference_steps=num_inference_steps, + num_frames=49, + use_dynamic_cfg=True, + guidance_scale=guidance_scale, + generator=torch.Generator().manual_seed(seed), + tracking_maps=tracking_maps, + tracking_image=tracking_first_frame, + height=height, + width=width, + ).frames[0] + else: + with torch.no_grad(): + video_generate = pipe( + prompt=prompt, + negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.", + image=image, + num_videos_per_prompt=num_videos_per_prompt, + num_inference_steps=num_inference_steps, + num_frames=49, + use_dynamic_cfg=True, + guidance_scale=guidance_scale, + generator=torch.Generator().manual_seed(seed), + ).frames[0] + # 5. Export the generated frames to a video file. fps must be 8 for original video. + output_path = output_path if output_path else f"{generate_type}_img[{os.path.splitext(os.path.basename(image_or_video_path))[0]}]_txt[{prompt}].mp4" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + export_to_video(video_generate, output_path, fps=fps) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") + parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") + parser.add_argument( + "--image_or_video_path", + type=str, + default=None, + help="The path of the image to be used as the background of the video", + ) + parser.add_argument( + "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used" + ) + parser.add_argument( + "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved" + ) + parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") + parser.add_argument( + "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process" + ) + parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") + parser.add_argument( + "--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')" + ) + parser.add_argument( + "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')" + ) + parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility") + parser.add_argument("--tracking_path", type=str, default=None, help="The path of the tracking maps to be used") + + args = parser.parse_args() + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 + generate_video( + prompt=args.prompt, + model_path=args.model_path, + tracking_path=args.tracking_path, + output_path=args.output_path, + image_or_video_path=args.image_or_video_path, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + num_videos_per_prompt=args.num_videos_per_prompt, + dtype=dtype, + generate_type=args.generate_type, + seed=args.seed, + ) \ No newline at end of file