diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..837587841c89bc2865e84e36f2fe4382e752cc0d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,17 @@
+# Caches.
+.DS_Store
+__pycache__
+*.egg-info
+
+# VEnv.
+.hsmr_env
+
+# Hydra dev.
+outputs
+
+# Development tools.
+.vscode
+pyrightconfig.json
+
+# Logs.
+gpu_monitor.log
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..6a08aa69d6c15d28ffa17d6f46c9e997a6784971
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,37 @@
+# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+# you will also find guides on how best to write your Dockerfile
+
+FROM python:3.10
+
+WORKDIR /code
+
+COPY ./requirements_part1.txt /code/requirements_part1.txt
+COPY ./requirements_part2.txt /code/requirements_part2.txt
+
+RUN apt-get update -y && \
+ apt-get upgrade -y && \
+ apt-get install -y libglfw3-dev && \
+ apt-get install -y libgles2-mesa-dev && \
+ apt-get install -y aria2 && \
+ pip install --no-cache-dir --upgrade -r /code/requirements_part1.txt && \
+ pip install --no-cache-dir --upgrade -r /code/requirements_part2.txt
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user
+USER user
+
+
+# Set home to the user's home directory
+ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the user's home directory
+WORKDIR $HOME/app
+
+# Copy the current directory contents into the container at $HOME/app setting the owner to the user
+COPY --chown=user . $HOME/app
+
+
+CMD ["bash", "tools/start.sh"]
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..2d868c8b74175093a195fac31ede8a547a01b718
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Yan XIA
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 5c14c7bbc6b7a4b6fb2b2643e82be65d74862b2d..1adfd10b13d9891ac6e20be5da4caae3a34f65ca 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,10 @@
---
title: HSMR
emoji: 💀
-colorFrom: purple
-colorTo: indigo
-sdk: gradio
-sdk_version: 5.20.0
-app_file: tools/service.py
+colorFrom: blue
+colorTo: pink
+sdk: docker
+app_port: 7860
pinned: true
---
diff --git a/configs/.gitignore b/configs/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/README.md b/configs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fcf75522caf15419d5ce98e24fc6db855da4adde
--- /dev/null
+++ b/configs/README.md
@@ -0,0 +1,31 @@
+# Instructions for the Configuration System
+
+The configuration system used here is based on [Hydra](https://hydra.cc/). However, I added a few small 'hacks' to enable extra features, such as `_hub_`. It might be a little hard to understand at first, but comprehensive guidance is provided: check the `README.md` in each directory for more details.
+
+## Philosophy
+
+- Easy to modify and maintain.
+- Helps you code with a clear structure.
+- Consistency.
+- Easy to trace and identify specific items.
+
+## Some Ideas
+
+- Prefer fewer ListConfigs; use ListConfig only for real list data.
+  - Dumped lists are unfolded so that each element occupies one line, which is annoying when presented.
+  - Lists are not friendly to command-line argument overrides.
+- For the defaults list, `_self_` must be explicitly specified (see the sketch below).
+  - Items before `_self_` mean 'this config is based on those items'.
+  - Items after `_self_` mean 'import those items on top of this config'.
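+
+A minimal sketch of this ordering (hypothetical group names):
+
+```yaml
+defaults:
+  - default     # before `_self_`: this config is based on `default.yaml`
+  - _self_      # the values written in this file take effect here
+  - /callback/extra@callbacks.extra   # after `_self_`: imported on top of this file
+```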
+
+### COMPOSITION OVER INHERITANCE
+
+Do not use "overrides" AS MUCH AS POSSIBLE, except when the changes are really tiny. Since it's hard to identify which term is actually used without running the code. Instead, the `default.yaml` serves like a template, you are supposed to copy it and modify it to create a new configuration.
+
+### REFERENCE OVER COPY
+
+If you want to use one thing many times (across experiments, or across the components of one experiment), you'd better use `${...}` to reference the Hydra object, so that you only need to modify one place while keeping everything consistent. Think about where to put the source of the object; `_hub_` is recommended, but it may not be the best choice every time. A minimal example is sketched below.
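+
+A minimal sketch, using an entry that `configs/_hub_/models.yaml` of this PR defines:
+
+```yaml
+pipeline:
+  cfg:
+    backbone: ${_hub_.models.backbones.vit_h}   # edit the hub entry once; every consumer follows
+```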
+
+### PREPARE EVERYTHING YOU NEED LOCALLY
+
+Sometimes you might want to use some configuration values outside the class configuration (i.e., directly in the code). In that case, I recommend referencing those values again in the local configuration package. It may look redundant, but it leads to cleaner code, as sketched below.
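+
+For example, the pipeline configs in this PR re-reference the global values they need inside their own `cfg` block, so the code only has to read its local config:
+
+```yaml
+cfg:
+  img_patch_size: ${policy.img_patch_size}
+  focal_length: ${policy.focal_length}
+  logger: ${logger}
+```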
\ No newline at end of file
diff --git a/configs/_hub_/README.md b/configs/_hub_/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..db39e698657362975a38edb30a6932e2467c6f32
--- /dev/null
+++ b/configs/_hub_/README.md
@@ -0,0 +1,7 @@
+# _hub_
+
+Configs here shouldn't be used as what Hydra calls a "config group". Instead, they serve as a hub for other configs to reference. A given experiment usually uses only a subset of the entries here.
+
+For example, the details of each dataset can be defined here, and then I only need to reference the `Human3.6M` dataset through `${_hub_.datasets.h36m}` in other configs.
+
+Each `.yaml` file here is loaded separately in `base.yaml`; check `base.yaml` to understand how it works.
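+
+Concretely (abridged from `_hub_/datasets.yaml` and the data configs in this PR):
+
+```yaml
+# _hub_/datasets.yaml defines the entry once ...
+eval:
+  h36m_val_p2:
+    dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/h36m_val_p2.npz
+
+# ... and any data config just references it:
+# item: ${_hub_.datasets.eval.h36m_val_p2}
+```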
\ No newline at end of file
diff --git a/configs/_hub_/datasets.yaml b/configs/_hub_/datasets.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..134ed49ef786aa6fe76cd2f26a398d1e0448381f
--- /dev/null
+++ b/configs/_hub_/datasets.yaml
@@ -0,0 +1,71 @@
+train:
+ hsmr: # Fitted according to the SMPL's vertices.
+ # Standard 4 datasets.
+ mpi_inf_3dhp:
+ name: 'HSMR-MPI-INF-3DHP-train-pruned'
+ urls: ${_pm_.inputs}/hsmr_training_data/mpi_inf_3dhp-tars/{000000..000012}.tar
+ epoch_size: 12_000
+ h36m:
+ name: 'HSMR-H36M-train'
+ urls: ${_pm_.inputs}/hsmr_training_data/h36m-tars/{000000..000312}.tar
+ epoch_size: 314_000
+ mpii:
+ name: 'HSMR-MPII-train'
+ urls: ${_pm_.inputs}/hsmr_training_data/mpii-tars/{000000..000009}.tar
+ epoch_size: 10_000
+ coco14:
+ name: 'HSMR-COCO14-train'
+ urls: ${_pm_.inputs}/hsmr_training_data/coco14-tars/{000000..000017}.tar
+ epoch_size: 18_000
+ # The rest full datasets for HMR2.0.
+ coco14vit:
+ name: 'HSMR-COCO14-vit-train'
+ urls: ${_pm_.inputs}/hsmr_training_data/coco14vit-tars/{000000..000044}.tar
+ epoch_size: 45_000
+ aic:
+ name: 'HSMR-AIC-train'
+ urls: ${_pm_.inputs}/hsmr_training_data/aic-tars/{000000..000209}.tar
+ epoch_size: 210_000
+ ava:
+ name: 'HSMR-AVA-train'
+ urls: ${_pm_.inputs}/hsmr_training_data/ava-tars/{000000..000184}.tar
+ epoch_size: 185_000
+ insta:
+ name: 'HSMR-INSTA-train'
+ urls: ${_pm_.inputs}/hsmr_training_data/insta-tars/{000000..003657}.tar
+ epoch_size: 3_658_000
+
+mocap:
+ bioamass_v1:
+ dataset_file: ${_pm_.inputs}/datasets/amass_skel/data_v1.npz
+ pve_threshold: 0.05
+ cmu_mocap:
+ dataset_file: ${_pm_.inputs}/hmr2_training_data/cmu_mocap.npz
+
+eval:
+ h36m_val_p2:
+ dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/h36m_val_p2.npz
+ img_root: ${_pm_.inputs}/datasets/h36m/images
+ kp_list: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 43]
+ use_hips: true
+
+ 3dpw_test:
+ dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/3dpw_test.npz
+ img_root: ${_pm_.inputs}/datasets/3dpw
+ kp_list: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 43]
+ use_hips: false
+
+ posetrack_val:
+ dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/posetrack_2018_val.npz
+ img_root: ${_pm_.inputs}/datasets/posetrack/posetrack2018/posetrack_data/
+ kp_list: [0] # dummy
+
+ lsp_extended:
+ dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/hr-lspet_train.npz
+ img_root: ${_pm_.inputs}/datasets/hr-lspet
+ kp_list: [0] # dummy
+
+ coco_val:
+ dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/coco_val.npz
+ img_root: ${_pm_.inputs}/datasets/coco
+ kp_list: [0] # dummy
diff --git a/configs/_hub_/models.yaml b/configs/_hub_/models.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..241ee7e4f3da2975367f1a9aa9696e2806b53472
--- /dev/null
+++ b/configs/_hub_/models.yaml
@@ -0,0 +1,57 @@
+body_models:
+
+ skel_mix_hsmr:
+ _target_: lib.body_models.skel_wrapper.SKELWrapper
+ model_path: '${_pm_.inputs}/body_models/skel'
+    gender: male # Use male since we don't have a neutral model.
+ joint_regressor_extra: '${_pm_.inputs}/body_models/SMPL_to_J19.pkl'
+    joint_regressor_custom: '${_pm_.inputs}/body_models/J_regressor_SKEL_mix_MALE.pkl'
+
+ skel_hsmr:
+ _target_: lib.body_models.skel_wrapper.SKELWrapper
+ model_path: '${_pm_.inputs}/body_models/skel'
+    gender: male # Use male since we don't have a neutral model.
+ joint_regressor_extra: '${_pm_.inputs}/body_models/SMPL_to_J19.pkl'
+ # joint_regressor_custom: '${_pm_.inputs}/body_models/J_regressor_SMPL_MALE.pkl'
+
+ smpl_hsmr:
+ _target_: lib.body_models.smpl_wrapper.SMPLWrapper
+ model_path: '${_pm_.inputs}/body_models/smpl'
+ gender: male # align with skel_hsmr
+ num_body_joints: 23
+ joint_regressor_extra: '${_pm_.inputs}/body_models/SMPL_to_J19.pkl'
+
+ smpl_hsmr_neutral:
+ _target_: lib.body_models.smpl_wrapper.SMPLWrapper
+ model_path: '${_pm_.inputs}/body_models/smpl'
+ gender: neutral # align with skel_hsmr
+ num_body_joints: 23
+ joint_regressor_extra: '${_pm_.inputs}/body_models/SMPL_to_J19.pkl'
+
+backbones:
+
+ vit_b:
+ _target_: lib.modeling.networks.backbones.ViT
+ img_size: [256, 192]
+ patch_size: 16
+ embed_dim: 768
+ depth: 12
+ num_heads: 12
+ ratio: 1
+ use_checkpoint: False
+ mlp_ratio: 4
+ qkv_bias: True
+ drop_path_rate: 0.3
+
+ vit_h:
+ _target_: lib.modeling.networks.backbones.ViT
+ img_size: [256, 192]
+ patch_size: 16
+ embed_dim: 1280
+ depth: 32
+ num_heads: 16
+ ratio: 1
+ use_checkpoint: False
+ mlp_ratio: 4
+ qkv_bias: True
+ drop_path_rate: 0.55
\ No newline at end of file
diff --git a/configs/base.yaml b/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..151fbc5d81e94a75455db64b211e0acace2cfefc
--- /dev/null
+++ b/configs/base.yaml
@@ -0,0 +1,33 @@
+defaults:
+ # Load each sub-hub.
+ - _hub_/datasets@_hub_.datasets
+ - _hub_/models@_hub_.models
+
+ # Set up the template.
+ - _self_
+
+ # Register some defaults.
+ - exp: default
+ - policy: default
+
+
+# exp_name: !!null # Name of experiment will determine the output folder.
+exp_name: '${exp_topic}-${exp_tag}'
+exp_topic: !!null # Theme of the experiment.
+exp_tag: 'debug' # Tag of the experiment.
+
+
+output_dir: ${_pm_.root}/data_outputs/exp/${exp_name} # Output directory for the experiment.
+# output_dir: ${_pm_.root}/data_outputs/exp/${now:%Y-%m-%d}/${exp_name} # Output directory for the experiment.
+
+hydra:
+ run:
+ dir: ${output_dir}
+
+
+# Information from `PathManager`, which will be automatically set by entrypoint wrapper.
+# The related implementation is in `lib/utils/cfg_utils.py`:`entrypoint_with_args`.
+_pm_:
+ root: !!null
+ inputs: !!null
+ outputs: !!null
\ No newline at end of file
diff --git a/configs/callback/ckpt/all-p1k.yaml b/configs/callback/ckpt/all-p1k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..40d208d77f80c7d9f856271031f80d8e42bb1ef4
--- /dev/null
+++ b/configs/callback/ckpt/all-p1k.yaml
@@ -0,0 +1,5 @@
+_target_: pytorch_lightning.callbacks.ModelCheckpoint
+dirpath: '${output_dir}/checkpoints'
+save_last: True
+every_n_train_steps: 1000
+save_top_k: -1 # save all ckpts
\ No newline at end of file
diff --git a/configs/callback/ckpt/top1-p1k.yaml b/configs/callback/ckpt/top1-p1k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4f7a912720e27db3128a827be624e7694dc5a4d7
--- /dev/null
+++ b/configs/callback/ckpt/top1-p1k.yaml
@@ -0,0 +1,5 @@
+_target_: pytorch_lightning.callbacks.ModelCheckpoint
+dirpath: '${output_dir}/checkpoints'
+save_last: True
+every_n_train_steps: 1000
+save_top_k: 1 # save only the last checkpoint
\ No newline at end of file
diff --git a/configs/callback/skelify-spin/i10kb1.yaml b/configs/callback/skelify-spin/i10kb1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6570b331397608ef4af54eac66a86eb957b88062
--- /dev/null
+++ b/configs/callback/skelify-spin/i10kb1.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - _self_
+ # Import the SKELify-SPIN callback.
+ - /pipeline/skelify-refiner@skelify
+
+_target_: lib.modeling.callbacks.SKELifySPIN
+
+# This configuration is for the frozen-backbone experiment with batch_size = 24000
+
+cfg:
+ interval: 10
+  batch_size: 24000 # values greater than `interval * dataloader batch_size` are equivalent to that product
+ max_batches_per_round: 1 # only the latest k * batch_size items are SPINed to save time
+ # better_pgt_fn: '${_pm_.inputs}/datasets/skel_training_data/spin/better_pseudo_gt.npz'
+ better_pgt_fn: '${output_dir}/better_pseudo_gt.npz'
+ skip_warm_up_steps: 10000
+ # skip_warm_up_steps: 0
+ update_better_pgt: True
+ valid_betas_threshold: 2
diff --git a/configs/callback/skelify-spin/i230kb1.yaml b/configs/callback/skelify-spin/i230kb1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..97969670712fc1d1547e5719ff39b6b68f5993a4
--- /dev/null
+++ b/configs/callback/skelify-spin/i230kb1.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - _self_
+ # Import the SKELify-SPIN callback.
+ - /pipeline/skelify-refiner@skelify
+
+_target_: lib.modeling.callbacks.SKELifySPIN
+
+cfg:
+ interval: 230
+  batch_size: 18000 # values greater than `interval * dataloader batch_size` are equivalent to that product
+ max_batches_per_round: 1 # only the latest k * batch_size items are SPINed to save time
+ # better_pgt_fn: '${_pm_.inputs}/datasets/skel_training_data/spin/better_pseudo_gt.npz'
+ better_pgt_fn: '${output_dir}/better_pseudo_gt.npz'
+ skip_warm_up_steps: 5000
+ update_better_pgt: True
+ valid_betas_threshold: 2
diff --git a/configs/callback/skelify-spin/i80.yaml b/configs/callback/skelify-spin/i80.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d1026e524f47b2a9be77fc22b9de2049c6f76983
--- /dev/null
+++ b/configs/callback/skelify-spin/i80.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - _self_
+ # Import the SKELify-SPIN callback.
+ - /pipeline/skelify-refiner@skelify
+
+_target_: lib.modeling.callbacks.SKELifySPIN
+
+cfg:
+ interval: 80
+  batch_size: 24000 # values greater than `interval * dataloader batch_size` are equivalent to that product
+ # better_pgt_fn: '${_pm_.inputs}/datasets/skel_training_data/spin/better_pseudo_gt.npz'
+ better_pgt_fn: '${output_dir}/better_pseudo_gt.npz'
+ skip_warm_up_steps: 5000
+ update_better_pgt: True
+ valid_betas_threshold: 2
diff --git a/configs/callback/skelify-spin/read_only.yaml b/configs/callback/skelify-spin/read_only.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f21087355c220fc56d64e1608d3108659529e683
--- /dev/null
+++ b/configs/callback/skelify-spin/read_only.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - _self_
+ # Import the SKELify-SPIN callback.
+ - /pipeline/skelify-refiner@skelify
+
+_target_: lib.modeling.callbacks.SKELifySPIN
+
+# This configuration is for the frozen-backbone experiment with batch_size = 24000
+
+cfg:
+ interval: 0
+  batch_size: 0 # values greater than `interval * dataloader batch_size` are equivalent to that product
+ max_batches_per_round: 0 # only the latest k * batch_size items are SPINed to save time
+ # better_pgt_fn: '${_pm_.inputs}/datasets/skel_training_data/spin/better_pseudo_gt.npz'
+ better_pgt_fn: '${output_dir}/better_pseudo_gt.npz'
+ skip_warm_up_steps: 0
+ update_better_pgt: True
+ valid_betas_threshold: 2
diff --git a/configs/data/skel-hmr2_fashion.yaml b/configs/data/skel-hmr2_fashion.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..66868282ec2e1066bded8f821ccbe324fb63e8ce
--- /dev/null
+++ b/configs/data/skel-hmr2_fashion.yaml
@@ -0,0 +1,71 @@
+_target_: lib.data.modules.hmr2_fashion.skel_wds.DataModule
+
+name: HMR2_fashion_WDS
+
+cfg:
+
+ train:
+ shared_ds_opt: # TODO: modify this
+ SUPPRESS_KP_CONF_THRESH: 0.3
+ FILTER_NUM_KP: 4
+ FILTER_NUM_KP_THRESH: 0.0
+ FILTER_REPROJ_THRESH: 31000
+ SUPPRESS_BETAS_THRESH: 3.0
+ SUPPRESS_BAD_POSES: False
+ POSES_BETAS_SIMULTANEOUS: True
+ FILTER_NO_POSES: False
+ BETAS_REG: True
+
+ datasets:
+ - name: 'H36M'
+ item: ${_hub_.datasets.train.hsmr.h36m}
+ weight: 0.3
+ - name: 'MPII'
+ item: ${_hub_.datasets.train.hsmr.mpii}
+ weight: 0.1
+ - name: 'COCO14'
+ item: ${_hub_.datasets.train.hsmr.coco14}
+ weight: 0.4
+ - name: 'MPI-INF-3DHP'
+ item: ${_hub_.datasets.train.hsmr.mpi_inf_3dhp}
+ weight: 0.2
+
+ dataloader:
+ drop_last: True
+ batch_size: 300
+ num_workers: 6
+ prefetch_factor: 2
+
+ eval:
+ datasets:
+ - name: 'LSP-EXTENDED'
+ item: ${_hub_.datasets.eval.lsp_extended}
+ - name: 'H36M-VAL-P2'
+ item: ${_hub_.datasets.eval.h36m_val_p2}
+ - name: '3DPW-TEST'
+ item: ${_hub_.datasets.eval.3dpw_test}
+ - name: 'POSETRACK-VAL'
+ item: ${_hub_.datasets.eval.posetrack_val}
+ - name: 'COCO-VAL'
+ item: ${_hub_.datasets.eval.coco_val}
+
+
+ dataloader:
+ shuffle: False
+ batch_size: 300
+ num_workers: 6
+
+ policy: ${policy}
+
+ # TODO: modify this
+ augm:
+ SCALE_FACTOR: 0.3
+ ROT_FACTOR: 30
+ TRANS_FACTOR: 0.02
+ COLOR_SCALE: 0.2
+ ROT_AUG_RATE: 0.6
+ TRANS_AUG_RATE: 0.5
+ DO_FLIP: True
+ FLIP_AUG_RATE: 0.5
+ EXTREME_CROP_AUG_RATE: 0.10
+ EXTREME_CROP_AUG_LEVEL: 1
\ No newline at end of file
diff --git a/configs/data/skel-hsr_v1_4ds.yaml b/configs/data/skel-hsr_v1_4ds.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b5e5c3d39f6dbeb73c644dbd44656c03d4e8c083
--- /dev/null
+++ b/configs/data/skel-hsr_v1_4ds.yaml
@@ -0,0 +1,84 @@
+_target_: lib.data.modules.hsmr_v1.data_module.DataModule
+
+name: SKEL_HSMR_V1
+
+cfg:
+
+ train:
+ cfg:
+ # Loader settings.
+ suppress_pgt_params_pve_max_thresh: 0.06 # Mark PVE-MAX larger than this value as invalid.
+ suppress_kp_conf_thresh: 0.3 # Mark key-point confidence smaller than this value as invalid.
+ suppress_betas_thresh: 3.0 # Mark betas having components larger than this value as invalid.
+ poses_betas_simultaneous: True # Sync poses and betas for the small person.
+ filter_insufficient_kp_cnt: 4
+ suppress_insufficient_kp_thresh: 0.0
+ filter_reproj_err_thresh: 31000
+ regularize_invalid_betas: True
+ # Others.
+ image_augmentation: ${...image_augmentation}
+ policy: ${...policy}
+
+ datasets:
+ - name: 'H36M'
+ item: ${_hub_.datasets.train.hsmr.h36m}
+ weight: 0.3
+ - name: 'MPII'
+ item: ${_hub_.datasets.train.hsmr.mpii}
+ weight: 0.1
+ - name: 'COCO14'
+ item: ${_hub_.datasets.train.hsmr.coco14}
+ weight: 0.4
+ - name: 'MPI-INF-3DHP'
+ item: ${_hub_.datasets.train.hsmr.mpi_inf_3dhp}
+ weight: 0.2
+
+ dataloader:
+ drop_last: True
+ batch_size: 300
+ num_workers: 6
+ prefetch_factor: 2
+
+ # mocap:
+ # cfg: ${_hub_.datasets.mocap.bioamass_v1}
+ # dataloader:
+ # batch_size: 600 # num_train:2 * batch_size:300 (from HMR2.0's cfg)
+ # drop_last: True
+ # shuffle: True
+ # num_workers: 1
+
+ eval:
+ cfg:
+ image_augmentation: ${...image_augmentation}
+ policy: ${...policy}
+
+ datasets:
+ - name: 'LSP-EXTENDED'
+ item: ${_hub_.datasets.eval.lsp_extended}
+ - name: 'H36M-VAL-P2'
+ item: ${_hub_.datasets.eval.h36m_val_p2}
+ - name: '3DPW-TEST'
+ item: ${_hub_.datasets.eval.3dpw_test}
+ - name: 'POSETRACK-VAL'
+ item: ${_hub_.datasets.eval.posetrack_val}
+ - name: 'COCO-VAL'
+ item: ${_hub_.datasets.eval.coco_val}
+
+ dataloader:
+ shuffle: False
+ batch_size: 300
+ num_workers: 6
+
+ # Augmentation settings.
+ image_augmentation:
+ trans_factor: 0.02
+ bbox_scale_factor: 0.3
+ rot_aug_rate: 0.6
+ rot_factor: 30
+ do_flip: True
+ flip_aug_rate: 0.5
+ extreme_crop_aug_rate: 0.10
+ half_color_scale: 0.2
+
+ # Others.
+ policy: ${policy}
\ No newline at end of file
diff --git a/configs/data/skel-hsr_v1_full.yaml b/configs/data/skel-hsr_v1_full.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3377ebfdbcea6311efef32497bee64d90bcf59de
--- /dev/null
+++ b/configs/data/skel-hsr_v1_full.yaml
@@ -0,0 +1,96 @@
+_target_: lib.data.modules.hsmr_v1.data_module.DataModule
+
+name: SKEL_HSMR_V1
+
+cfg:
+
+ train:
+ cfg:
+ # Loader settings.
+ suppress_pgt_params_pve_max_thresh: 0.06 # Mark PVE-MAX larger than this value as invalid.
+ suppress_kp_conf_thresh: 0.3 # Mark key-point confidence smaller than this value as invalid.
+ suppress_betas_thresh: 3.0 # Mark betas having components larger than this value as invalid.
+ poses_betas_simultaneous: True # Sync poses and betas for the small person.
+ filter_insufficient_kp_cnt: 4
+ suppress_insufficient_kp_thresh: 0.0
+ filter_reproj_err_thresh: 31000
+ regularize_invalid_betas: True
+ # Others.
+ image_augmentation: ${...image_augmentation}
+ policy: ${...policy}
+
+ datasets:
+ - name: 'H36M'
+ item: ${_hub_.datasets.train.hsmr.h36m}
+ weight: 0.1
+ - name: 'MPII'
+ item: ${_hub_.datasets.train.hsmr.mpii}
+ weight: 0.1
+ - name: 'COCO14'
+ item: ${_hub_.datasets.train.hsmr.coco14}
+ weight: 0.1
+ - name: 'COCO14-ViTPose'
+ item: ${_hub_.datasets.train.hsmr.coco14vit}
+ weight: 0.1
+ - name: 'MPI-INF-3DHP'
+ item: ${_hub_.datasets.train.hsmr.mpi_inf_3dhp}
+ weight: 0.02
+ - name: 'AVA'
+ item: ${_hub_.datasets.train.hsmr.ava}
+ weight: 0.19
+ - name: 'AIC'
+ item: ${_hub_.datasets.train.hsmr.aic}
+ weight: 0.19
+ - name: 'INSTA'
+ item: ${_hub_.datasets.train.hsmr.insta}
+ weight: 0.2
+
+ dataloader:
+ drop_last: True
+ batch_size: 300
+ num_workers: 6
+ prefetch_factor: 2
+
+ mocap:
+ cfg: ${_hub_.datasets.mocap.bioamass_v1}
+ dataloader:
+ batch_size: 600 # num_train:2 * batch_size:300 (from HMR2.0's cfg)
+ drop_last: True
+ shuffle: True
+ num_workers: 1
+
+ eval:
+ cfg:
+ image_augmentation: ${...image_augmentation}
+ policy: ${...policy}
+
+ datasets:
+ - name: 'LSP-EXTENDED'
+ item: ${_hub_.datasets.eval.lsp_extended}
+ - name: 'H36M-VAL-P2'
+ item: ${_hub_.datasets.eval.h36m_val_p2}
+ - name: '3DPW-TEST'
+ item: ${_hub_.datasets.eval.3dpw_test}
+ - name: 'POSETRACK-VAL'
+ item: ${_hub_.datasets.eval.posetrack_val}
+ - name: 'COCO-VAL'
+ item: ${_hub_.datasets.eval.coco_val}
+
+ dataloader:
+ shuffle: False
+ batch_size: 300
+ num_workers: 6
+
+ # Augmentation settings.
+ image_augmentation:
+ trans_factor: 0.02
+ bbox_scale_factor: 0.3
+ rot_aug_rate: 0.6
+ rot_factor: 30
+ do_flip: True
+ flip_aug_rate: 0.5
+ extreme_crop_aug_rate: 0.10
+ half_color_scale: 0.2
+
+ # Others.
+ policy: ${policy}
\ No newline at end of file
diff --git a/configs/exp/default.yaml b/configs/exp/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ffb509d4aea8925b5ee865eefae04d8fca4d59fd
--- /dev/null
+++ b/configs/exp/default.yaml
@@ -0,0 +1,16 @@
+# @package _global_
+defaults:
+  # Use absolute paths, because the root for this file is `configs/exp/` rather than `configs/`.
+ # - /data: ...
+
+ # Configurations in this file should have the highest priority.
+ - _self_
+
+# ======= Overwrite Section =======
+# (Avoid using this section whenever possible!)
+
+
+
+# ====== Main Section ======
+
+exp_topic: !!null # The name of the experiment determines the output folder.
\ No newline at end of file
diff --git a/configs/exp/hsr/train.yaml b/configs/exp/hsr/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..866c1e69c8e8f957c7a01e50ffc131671ad06b43
--- /dev/null
+++ b/configs/exp/hsr/train.yaml
@@ -0,0 +1,47 @@
+# @package _global_
+defaults:
+ - /pipeline: hsmr
+ - /data: skel-hsmr_v1_full
+ # Configurations in this file
+ - _self_
+ # Import callbacks.
+ - /callback/ckpt/top1-p1k@callbacks.ckpt
+ # - /callback/skelify-spin/i230kb1@callbacks.SKELifySPIN
+ - /callback/skelify-spin/read_only@callbacks.SKELifySPIN
+
+# ======= Overwrite Section =======
+# (Avoid using this section whenever possible!)
+
+pipeline:
+ cfg:
+ backbone: ${_hub_.models.backbones.vit_h}
+ backbone_ckpt: '${_pm_.inputs}/backbone/vitpose_backbone.pth'
+ freeze_backbone: False
+
+data:
+ cfg:
+ train:
+ dataloader:
+ batch_size: 78
+
+
+# ====== Main Section ======
+
+exp_topic: 'HSMR-train-vit_h'
+
+enable_time_monitor: False
+
+seed: NULL
+ckpt_path: NULL
+
+logger:
+ interval: 1000
+ interval_skelify: 10
+ samples_per_record: 5
+
+task: 'fit'
+pl_trainer:
+ devices: 8
+ max_epochs: 100
+ deterministic: false
+ precision: 16
\ No newline at end of file
diff --git a/configs/exp/skelify-full.yaml b/configs/exp/skelify-full.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d269becd692494eb5c9a97be0f0b5c2495029a3
--- /dev/null
+++ b/configs/exp/skelify-full.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+defaults:
+ - /pipeline: skelify-full
+ - /data: skel-hsmr_v1_full
+ # Configurations in this file should have the highest priority.
+ - _self_
+
+# ======= Overwrite Section =======
+# (Avoid using this section whenever possible!)
+
+
+
+# ====== Main Section ======
+
+exp_topic: 'SKELify-Full'
+
+enable_time_monitor: True
+
+logger:
+ interval: 1
+ samples_per_record: 8
\ No newline at end of file
diff --git a/configs/exp/skelify-refiner.yaml b/configs/exp/skelify-refiner.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b2f644e7f0d2a0fdb08931defb313c4774544765
--- /dev/null
+++ b/configs/exp/skelify-refiner.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+defaults:
+ - /pipeline: skelify-refiner
+ - /data: skel-hsmr_v1_full
+ # Configurations in this file should have the highest priority.
+ - _self_
+
+# ======= Overwrite Section =======
+# (Avoid using this section whenever possible!)
+
+
+
+# ====== Main Section ======
+
+exp_topic: 'SKELify-Refiner'
+
+enable_time_monitor: True
+
+logger:
+ interval: 2
+ samples_per_record: 8
\ No newline at end of file
diff --git a/configs/pipeline/hmr2.yaml b/configs/pipeline/hmr2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d32bbbc8a7bee6b0c2f113abd3df85cf13330f7e
--- /dev/null
+++ b/configs/pipeline/hmr2.yaml
@@ -0,0 +1,43 @@
+_target_: lib.modeling.pipelines.HMR2Pipeline
+
+name: HMR2
+
+cfg:
+ # Body Models
+ SMPL: ${_hub_.models.body_models.smpl_hsmr_neutral}
+
+ # Backbone and its checkpoint.
+ backbone: ${_hub_.models.backbones.vit_h}
+ backbone_ckpt: ???
+
+ # Head to get the parameters.
+ head:
+ _target_: lib.modeling.networks.heads.SMPLTransformerDecoderHead
+ cfg:
+ transformer_decoder:
+ depth: 6
+ heads: 8
+ mlp_dim: 1024
+ dim_head: 64
+ dropout: 0.0
+ emb_dropout: 0.0
+ norm: 'layer'
+ context_dim: ${....backbone.embed_dim}
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ lr: 1e-5
+ weight_decay: 1e-4
+
+ # This may be redesigned, e.g., we can add a loss object to maintain the calculation of the loss, or a callback.
+ loss_weights:
+ kp3d: 0.05
+ kp2d: 0.01
+ poses_orient: 0.002
+ poses_body: 0.001
+ betas: 0.0005
+ adversarial: 0.0005
+ # adversarial: 0.0
+
+ policy: ${policy}
+ logger: ${logger}
\ No newline at end of file
diff --git a/configs/pipeline/hsr.yaml b/configs/pipeline/hsr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a2d5c6183ecccf7ad0af94d922562c656266899b
--- /dev/null
+++ b/configs/pipeline/hsr.yaml
@@ -0,0 +1,49 @@
+_target_: lib.modeling.pipelines.HSMRPipeline
+
+name: HSMR
+
+cfg:
+ pd_poses_repr: 'rotation_6d' # poses representation for prediction, choices: 'euler_angle' | 'rotation_6d'
+ sp_poses_repr: 'rotation_matrix' # poses representation for supervision, choices: 'euler_angle' | 'rotation_matrix'
+
+ # Body Models
+ # SKEL: ${_hub_.models.body_models.skel_hsmr}
+ SKEL: ${_hub_.models.body_models.skel_mix_hsmr}
+
+ # Backbone and its checkpoint.
+ backbone: ???
+ backbone_ckpt: ???
+
+ # Head to get the parameters.
+ head:
+ _target_: lib.modeling.networks.heads.SKELTransformerDecoderHead
+ cfg:
+ pd_poses_repr: ${...pd_poses_repr}
+ transformer_decoder:
+ depth: 6
+ heads: 8
+ mlp_dim: 1024
+ dim_head: 64
+ dropout: 0.0
+ emb_dropout: 0.0
+ norm: 'layer'
+ context_dim: ${....backbone.embed_dim}
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ lr: 1e-5
+ weight_decay: 1e-4
+
+ # This may be redesigned, e.g., we can add a loss object to maintain the calculation of the loss, or a callback.
+ loss_weights:
+ kp3d: 0.05
+ kp2d: 0.01
+ # prior: 0.0005
+ prior: 0.0
+ poses_orient: 0.002
+ poses_body: 0.001
+ betas: 0.0005
+ # adversarial: 0.0005
+
+ policy: ${policy}
+ logger: ${logger}
\ No newline at end of file
diff --git a/configs/pipeline/skelify-full.yaml b/configs/pipeline/skelify-full.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..662f90d5bc35df34ccbfa63303269ce135236839
--- /dev/null
+++ b/configs/pipeline/skelify-full.yaml
@@ -0,0 +1,98 @@
+_target_: lib.modeling.optim.SKELify
+
+name: SKELify
+
+cfg:
+ skel_model: ${_hub_.models.body_models.skel_mix_hsmr}
+
+ _f_normalize_kp2d: True
+ _f_normalize_kp2d_to_mean: False
+ _w_angle_prior_scale: 1.7
+
+ phases:
+ # ================================
+ # ⛩️ Part 1: Camera initialization.
+ # --------------------------------
+ STAGE-camera-init:
+ max_loop: 30
+ params_keys: ['cam_t', 'poses_orient']
+ parts: ['torso']
+ optimizer: ${...optimizer}
+ losses:
+ f_normalize_kp2d: ${...._f_normalize_kp2d}
+ f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean}
+ w_depth: 100.0
+ w_reprojection: 1.78
+ # ================================
+
+ # ================================
+ # ⛩️ Part 2: Overall optimization.
+ # --------------------------------
+ STAGE-overall-1:
+ max_loop: 30
+ params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas']
+ parts: ['all']
+ optimizer: ${...optimizer}
+ losses:
+ f_normalize_kp2d: ${...._f_normalize_kp2d}
+ f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean}
+ w_reprojection: 1.0
+ w_shape_prior: 100.0
+ w_angle_prior: 404.0
+ w_angle_prior_scale: ${...._w_angle_prior_scale} # TODO: Finalize it.
+ # --------------------------------
+ STAGE-overall-2:
+ max_loop: 30
+ params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas']
+ optimizer: ${...optimizer}
+ parts: ['all']
+ losses:
+ f_normalize_kp2d: ${...._f_normalize_kp2d}
+ f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean}
+ w_reprojection: 1.0
+ w_shape_prior: 50.0
+ w_angle_prior: 404.0
+ w_angle_prior_scale: ${...._w_angle_prior_scale} # TODO: Finalize it.
+ # --------------------------------
+ STAGE-overall-3:
+ max_loop: 30
+ params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas']
+ parts: ['all']
+ optimizer: ${...optimizer}
+ losses:
+ f_normalize_kp2d: ${...._f_normalize_kp2d}
+ f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean}
+ w_reprojection: 1.0
+ w_shape_prior: 10.0
+ w_angle_prior: 57.4
+ w_angle_prior_scale: ${...._w_angle_prior_scale} # TODO: Finalize it.
+ # --------------------------------
+ STAGE-overall-4:
+ max_loop: 30
+ params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas']
+ parts: ['all']
+ optimizer: ${...optimizer}
+ losses:
+ f_normalize_kp2d: ${...._f_normalize_kp2d}
+ f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean}
+ w_reprojection: 1.0
+ w_shape_prior: 5.0
+ w_angle_prior: 4.78
+ w_angle_prior_scale: ${...._w_angle_prior_scale} # TODO: Finalize it.
+ # ================================
+
+ optimizer:
+ _target_: torch.optim.LBFGS
+ lr: 1
+ line_search_fn: 'strong_wolfe'
+ tolerance_grad: ${..early_quit_thresholds.abs}
+ tolerance_change: ${..early_quit_thresholds.rel}
+
+ early_quit_thresholds:
+ abs: 1e-9
+ rel: 1e-9
+
+ img_patch_size: ${policy.img_patch_size}
+ focal_length: ${policy.focal_length}
+
+ logger: ${logger}
\ No newline at end of file
diff --git a/configs/pipeline/skelify-refiner.yaml b/configs/pipeline/skelify-refiner.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4e2195579374246d7d4a3e9fc13cdfc4381b6301
--- /dev/null
+++ b/configs/pipeline/skelify-refiner.yaml
@@ -0,0 +1,37 @@
+_target_: lib.modeling.optim.SKELify
+
+name: SKELify-Refiner
+
+cfg:
+ skel_model: ${_hub_.models.body_models.skel_mix_hsmr}
+
+ phases:
+ STAGE-refine:
+ max_loop: 10
+ params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas']
+ optimizer: ${...optimizer}
+ losses:
+ f_normalize_kp2d: True
+ f_normalize_kp2d_to_mean: False
+ w_reprojection: 1.0
+ w_shape_prior: 5.0
+ w_angle_prior: 4.78
+ w_angle_prior_scale: 0.17
+ parts: ['all']
+ # ================================
+
+ optimizer:
+ _target_: torch.optim.LBFGS
+ lr: 1
+ line_search_fn: 'strong_wolfe'
+ tolerance_grad: ${..early_quit_thresholds.abs}
+ tolerance_change: ${..early_quit_thresholds.rel}
+
+ early_quit_thresholds:
+ abs: 1e-7
+ rel: 1e-9
+
+ img_patch_size: ${policy.img_patch_size}
+ focal_length: ${policy.focal_length}
+
+ logger: ${logger}
\ No newline at end of file
diff --git a/configs/policy/README.md b/configs/policy/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7af8f5e62fa39f47dd1d90bd0d13e2b44a90f562
--- /dev/null
+++ b/configs/policy/README.md
@@ -0,0 +1,3 @@
+# policy
+
+Configs here define frequently used values that are shared across multiple sub-modules. They usually do not belong to any specific sub-module.
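+
+For example (taken from `policy/default.yaml` and the pipeline configs in this PR):
+
+```yaml
+# policy/default.yaml defines the shared values ...
+img_patch_size: 256
+focal_length: 5000
+
+# ... and sub-modules reference them instead of redefining them:
+# img_patch_size: ${policy.img_patch_size}
+# focal_length: ${policy.focal_length}
+```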
\ No newline at end of file
diff --git a/configs/policy/default.yaml b/configs/policy/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d92856e58e8a9da9231f0313a836f2f2776fdec3
--- /dev/null
+++ b/configs/policy/default.yaml
@@ -0,0 +1,5 @@
+img_patch_size: 256
+focal_length: 5000
+img_mean: [0.485, 0.456, 0.406]
+img_std: [0.229, 0.224, 0.225]
+bbox_shape: null
\ No newline at end of file
diff --git a/data_inputs/.gitignore b/data_inputs/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..9964fcd73181748d80bc2783fdad0ed777484ada
--- /dev/null
+++ b/data_inputs/.gitignore
@@ -0,0 +1,2 @@
+body_models
+released_models
\ No newline at end of file
diff --git a/data_inputs/body_models.tar.gz b/data_inputs/body_models.tar.gz
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data_inputs/description.md b/data_inputs/description.md
new file mode 100644
index 0000000000000000000000000000000000000000..21c940ed0da5fc5d36743a44ca0cb168ad10a39e
--- /dev/null
+++ b/data_inputs/description.md
@@ -0,0 +1,21 @@
+# HSMR
+
+Reconstructing Humans with a Biomechanically Accurate Skeleton
+
+Project Page | GitHub Repo | Paper
+
+> Reconstructing Humans with a Biomechanically Accurate Skeleton
+> [Yan Xia](https://scholar.isshikih.top),
+> [Xiaowei Zhou](https://xzhou.me),
+> [Etienne Vouga](https://www.cs.utexas.edu/~evouga/),
+> [Qixing Huang](https://www.cs.utexas.edu/~huangqx/),
+> [Georgios Pavlakos](https://geopavlakos.github.io/)
+> *CVPR 2025*
diff --git a/data_inputs/example_imgs/ballerina.png b/data_inputs/example_imgs/ballerina.png
new file mode 100644
index 0000000000000000000000000000000000000000..c40509ecda6bd2e5280a2033a82d885276162dc1
Binary files /dev/null and b/data_inputs/example_imgs/ballerina.png differ
diff --git a/data_inputs/example_imgs/exercise.jpeg b/data_inputs/example_imgs/exercise.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..c2c1e6403eeb7228ca8d605dad28225f1a98d039
Binary files /dev/null and b/data_inputs/example_imgs/exercise.jpeg differ
diff --git a/data_inputs/example_imgs/jump_high.jpg b/data_inputs/example_imgs/jump_high.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fc75dcb8da75e19da1a7a3d794ae3cc957726ffb
Binary files /dev/null and b/data_inputs/example_imgs/jump_high.jpg differ
diff --git a/data_outputs/.gitignore b/data_outputs/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/__init__.py b/lib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7152555990d5a8a9d489ea8dc9c69c528d034ad6
--- /dev/null
+++ b/lib/__init__.py
@@ -0,0 +1 @@
+from .version import __version__
\ No newline at end of file
diff --git a/lib/body_models/abstract_skeletons.py b/lib/body_models/abstract_skeletons.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d3a3614b53aa7f0687843d9d2367e437fb18442
--- /dev/null
+++ b/lib/body_models/abstract_skeletons.py
@@ -0,0 +1,123 @@
+# "Skeletons" here means virtual "bones" between joints. It is used to draw the skeleton on the image.
+
+class Skeleton():
+ bones = []
+ bone_colors = []
+ chains = []
+ parent = []
+
+
+class Skeleton_SMPL24(Skeleton):
+    # The joint definitions are copied from
+ # [ROMP](https://github.com/Arthur151/ROMP/blob/4eebd3647f57d291d26423e51f0d514ff7197cb3/romp/lib/constants.py#L58).
+ chains = [
+ [0, 1, 4, 7, 10], # left leg
+ [0, 2, 5, 8, 11], # right leg
+ [0, 3, 6, 9, 12, 15], # spine & head
+ [12, 13, 16, 18, 20, 22], # left arm
+ [12, 14, 17, 19, 21, 23], # right arm
+ ]
+ bones = [
+ [ 0, 1], [ 1, 4], [ 4, 7], [ 7, 10], # left leg
+ [ 0, 2], [ 2, 5], [ 5, 8], [ 8, 11], # right leg
+ [ 0, 3], [ 3, 6], [ 6, 9], [ 9, 12], [12, 15], # spine & head
+ [12, 13], [13, 16], [16, 18], [18, 20], [20, 22], # left arm
+ [12, 14], [14, 17], [17, 19], [19, 21], [21, 23], # right arm
+ ]
+ bone_colors = [
+ [127, 0, 0], [148, 21, 21], [169, 41, 41], [191, 63, 63], # red
+ [ 0, 127, 0], [ 21, 148, 21], [ 41, 169, 41], [ 63, 191, 63], # green
+ [ 0, 0, 127], [ 15, 15, 143], [ 31, 31, 159], [ 47, 47, 175], [ 63, 63, 191], # blue
+ [ 0, 127, 127], [ 15, 143, 143], [ 31, 159, 159], [ 47, 175, 175], [ 63, 191, 191], # cyan
+ [127, 0, 127], [143, 15, 143], [159, 31, 159], [175, 47, 175], [191, 63, 191], # magenta
+ ]
+ parent = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 12, 12, 13, 14, 16, 17, 18, 19, 20, 21]
+
+
+class Skeleton_SMPL22(Skeleton):
+ chains = [
+ [0, 1, 4, 7, 10], # left leg
+ [0, 2, 5, 8, 11], # right leg
+ [0, 3, 6, 9, 12, 15], # spine & head
+ [12, 13, 16, 18, 20], # left arm
+ [12, 14, 17, 19, 21], # right arm
+ ]
+ bones = [
+ [ 0, 1], [ 1, 4], [ 4, 7], [ 7, 10], # left leg
+ [ 0, 2], [ 2, 5], [ 5, 8], [ 8, 11], # right leg
+ [ 0, 3], [ 3, 6], [ 6, 9], [ 9, 12], [12, 15], # spine & head
+ [12, 13], [13, 16], [16, 18], [18, 20], # left arm
+ [12, 14], [14, 17], [17, 19], [19, 21], # right arm
+ ]
+ bone_colors = [
+ [127, 0, 0], [148, 21, 21], [169, 41, 41], [191, 63, 63], # red
+ [ 0, 127, 0], [ 21, 148, 21], [ 41, 169, 41], [ 63, 191, 63], # green
+ [ 0, 0, 127], [ 15, 15, 143], [ 31, 31, 159], [ 47, 47, 175], [ 63, 63, 191], # blue
+ [ 0, 127, 127], [ 15, 143, 143], [ 31, 159, 159], [ 47, 175, 175], # cyan
+ [127, 0, 127], [143, 15, 143], [159, 31, 159], [175, 47, 175], # magenta
+ ]
+ parent = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 12, 12, 13, 14, 16, 17, 18, 19]
+
+
+class Skeleton_SKEL24(Skeleton):
+ chains = [
+ [ 0, 6, 7, 8, 9, 10], # left leg
+ [ 0, 1, 2, 3, 4, 5], # right leg
+ [ 0, 11, 12, 13], # spine & head
+ [12, 19, 20, 21, 22, 23], # left arm
+ [12, 14, 15, 16, 17, 18], # right arm
+ ]
+ bones = [
+ [ 0, 6], [ 6, 7], [ 7, 8], [ 8, 9], [ 9, 10], # left leg
+ [ 0, 1], [ 1, 2], [ 2, 3], [ 3, 4], [ 4, 5], # right leg
+ [ 0, 11], [11, 12], [12, 13], # spine & head
+ [12, 19], [19, 20], [20, 21], [21, 22], [22, 23], # left arm
+ [12, 14], [14, 15], [15, 16], [16, 17], [17, 18], # right arm
+ ]
+ bone_colors = [
+ [127, 0, 0], [148, 21, 21], [169, 41, 41], [191, 63, 63], [191, 63, 63], # red
+ [ 0, 127, 0], [ 21, 148, 21], [ 41, 169, 41], [ 63, 191, 63], [ 63, 191, 63], # green
+ [ 0, 0, 127], [ 31, 31, 159], [ 63, 63, 191], # blue
+ [ 0, 127, 127], [ 15, 143, 143], [ 31, 159, 159], [ 47, 175, 175], [ 63, 191, 191], # cyan
+ [127, 0, 127], [143, 15, 143], [159, 31, 159], [175, 47, 175], [191, 63, 191], # magenta
+ ]
+ parent = [-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 12, 19, 20, 21, 22, 12, 14, 15, 16, 17]
+
+
+class Skeleton_OpenPose25(Skeleton):
+ ''' https://www.researchgate.net/figure/Twenty-five-keypoints-of-the-OpenPose-software-model_fig1_374116819 '''
+    chains = [
+ [ 8, 12, 13, 14, 19, 20], # left leg
+ [14, 21], # left heel
+ [ 8, 9, 10, 11, 22, 23], # right leg
+ [11, 24], # right heel
+ [ 8, 1, 0], # spine & head
+ [ 0, 16, 18], # left face
+ [ 0, 15, 17], # right face
+ [ 1, 5, 6, 7], # left arm
+ [ 1, 2, 3, 4], # right arm
+ ]
+ bones = [
+ [ 8, 12], [12, 13], [13, 14], # left leg
+ [14, 19], [19, 20], [14, 21], # left foot
+ [ 8, 9], [ 9, 10], [10, 11], # right leg
+ [11, 22], [22, 23], [11, 24], # right foot
+ [ 8, 1], [ 1, 0], # spine & head
+ [ 0, 16], [16, 18], # left face
+ [ 0, 15], [15, 17], # right face
+ [ 1, 5], [ 5, 6], [ 6, 7], # left arm
+ [ 1, 2], [ 2, 3], [ 3, 4], # right arm
+ ]
+ bone_colors = [
+ [ 95, 0, 255], [ 79, 0, 255], [ 83, 0, 255], # dark blue
+ [ 31, 0, 255], [ 15, 0, 255], [ 0, 0, 255], # dark blue
+ [127, 205, 255], [127, 205, 255], [ 95, 205, 255], # light blue
+ [ 63, 205, 255], [ 31, 205, 255], [ 0, 205, 255], # light blue
+ [255, 0, 0], [255, 0, 0], # red
+ [191, 63, 63], [191, 63, 191], # magenta
+ [255, 0, 127], [255, 0, 255], # purple
+ [127, 255, 0], [ 63, 255, 0], [ 0, 255, 0], # green
+ [255, 127, 0], [255, 191, 0], [255, 255, 0], # yellow
+
+ ]
+ parent = [1, 8, 1, 2, 3, 1, 5, 6, -1, 8, 9, 10, 8, 12, 13, 0, 0, 15, 16, 14, 19, 14, 11, 22, 11]
diff --git a/lib/body_models/common.py b/lib/body_models/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b72a77400353c1c5210da4320023e79b06894e
--- /dev/null
+++ b/lib/body_models/common.py
@@ -0,0 +1,69 @@
+from lib.kits.basic import *
+
+from smplx import SMPL
+
+from lib.platform import PM
+from lib.body_models.skel_wrapper import SKELWrapper as SKEL
+from lib.body_models.smpl_wrapper import SMPLWrapper
+
+def make_SMPL(gender='neutral', device='cuda:0'):
+ return SMPL(
+ gender = gender,
+ model_path = PM.inputs / 'body_models' / 'smpl',
+ ).to(device)
+
+
+def make_SMPL_hmr2(gender='neutral', device='cuda:0'):
+    ''' SKEL doesn't have a neutral model, so to align with SKEL, use male. '''
+ return SMPLWrapper(
+ gender = gender,
+ model_path = PM.inputs / 'body_models' / 'smpl',
+ num_body_joints = 23,
+ joint_regressor_extra = PM.inputs / 'body_models/SMPL_to_J19.pkl',
+ ).to(device)
+
+
+
+def make_SKEL(gender='male', device='cuda:0'):
+    ''' We don't have a neutral model for SKEL, so use male for now. '''
+ return make_SKEL_mix_joints(gender, device)
+
+
+def make_SKEL_smpl_joints(gender='male', device='cuda:0'):
+    ''' We don't have a neutral model for SKEL, so use male for now. '''
+ return SKEL(
+ gender = gender,
+ model_path = PM.inputs / 'body_models' / 'skel',
+ joint_regressor_extra = PM.inputs / 'body_models' / 'SMPL_to_J19.pkl',
+ joint_regressor_custom = PM.inputs / 'body_models' / 'J_regressor_SMPL_MALE.pkl',
+ ).to(device)
+
+
+def make_SKEL_skel_joints(gender='male', device='cuda:0'):
+    ''' We don't have a neutral model for SKEL, so use male for now. '''
+ return SKEL(
+ gender = gender,
+ model_path = PM.inputs / 'body_models' / 'skel',
+ joint_regressor_extra = PM.inputs / 'body_models' / 'SMPL_to_J19.pkl',
+ ).to(device)
+
+
+def make_SKEL_mix_joints(gender='male', device='cuda:0'):
+    ''' We don't have a neutral model for SKEL, so use male for now. '''
+ return SKEL(
+ gender = gender,
+ model_path = PM.inputs / 'body_models' / 'skel',
+ joint_regressor_extra = PM.inputs / 'body_models' / 'SMPL_to_J19.pkl',
+ joint_regressor_custom = PM.inputs / 'body_models' / 'J_regressor_SKEL_mix_MALE.pkl',
+ ).to(device)
+
+
+def make_SMPLX_moyo(v_template_path:Union[str, Path], batch_size:int=1, device='cuda:0'):
+ from lib.body_models.moyo_smplx_wrapper import MoYoSMPLX
+
+ return MoYoSMPLX(
+ model_path = PM.inputs / 'body_models' / 'smplx',
+ v_template_path = v_template_path,
+ batch_size = batch_size,
+ device = device,
+ )
\ No newline at end of file
diff --git a/lib/body_models/moyo_smplx_wrapper.py b/lib/body_models/moyo_smplx_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..df391ce8da8838b60549d8990b583b37a8707824
--- /dev/null
+++ b/lib/body_models/moyo_smplx_wrapper.py
@@ -0,0 +1,77 @@
+from lib.kits.basic import *
+
+import smplx
+from psbody.mesh import Mesh
+
+class MoYoSMPLX(smplx.SMPLX):
+
+ def __init__(
+ self,
+ model_path : Union[str, Path],
+ v_template_path : Union[str, Path],
+ batch_size = 1,
+ n_betas = 10,
+ device = 'cpu'
+ ):
+
+ if isinstance(v_template_path, Path):
+ v_template_path = str(v_template_path)
+
+ # Load the `v_template`.
+ v_template_mesh = Mesh(filename=v_template_path)
+ v_template = to_tensor(v_template_mesh.v, device=device)
+
+ self.n_betas = n_betas
+
+ # Create the `body_model_params`.
+ body_model_params = {
+ 'model_path' : model_path,
+ 'gender' : 'neutral',
+ 'v_template' : v_template.float(),
+ 'batch_size' : batch_size,
+ 'create_global_orient' : True,
+ 'create_body_pose' : True,
+ 'create_betas' : True,
+ 'num_betas' : self.n_betas, # They actually don't use num_betas.
+ 'create_left_hand_pose' : True,
+ 'create_right_hand_pose' : True,
+ 'create_expression' : True,
+ 'create_jaw_pose' : True,
+ 'create_leye_pose' : True,
+ 'create_reye_pose' : True,
+ 'create_transl' : True,
+ 'use_pca' : False,
+ 'flat_hand_mean' : True,
+ 'dtype' : torch.float32,
+ }
+
+ super().__init__(**body_model_params)
+ self = self.to(device)
+
+ def forward(self, **kwargs):
+        ''' Only if all parameters are passed will the batch size be adjusted flexibly. '''
+ assert 'global_orient' in kwargs, '`global_orient` is required for the forward pass.'
+ assert 'body_pose' in kwargs, '`body_pose` is required for the forward pass.'
+ B = kwargs['global_orient'].shape[0]
+ body_pose = kwargs['body_pose']
+
+ if 'left_hand_pose' not in kwargs:
+ kwargs['left_hand_pose'] = body_pose.new_zeros((B, 45))
+            get_logger().warning('`left_hand_pose` is not provided but expected; setting it to zeros.')
+        if 'right_hand_pose' not in kwargs:
+            kwargs['right_hand_pose'] = body_pose.new_zeros((B, 45))
+            get_logger().warning('`right_hand_pose` is not provided but expected; setting it to zeros.')
+ if 'transl' not in kwargs:
+ kwargs['transl'] = body_pose.new_zeros((B, 3))
+ if 'betas' not in kwargs:
+ kwargs['betas'] = body_pose.new_zeros((B, self.n_betas))
+ if 'expression' not in kwargs:
+ kwargs['expression'] = body_pose.new_zeros((B, 10))
+ if 'jaw_pose' not in kwargs:
+ kwargs['jaw_pose'] = body_pose.new_zeros((B, 3))
+ if 'leye_pose' not in kwargs:
+ kwargs['leye_pose'] = body_pose.new_zeros((B, 3))
+ if 'reye_pose' not in kwargs:
+ kwargs['reye_pose'] = body_pose.new_zeros((B, 3))
+
+ return super().forward(**kwargs)
\ No newline at end of file
diff --git a/lib/body_models/skel/__init__.py b/lib/body_models/skel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/body_models/skel/alignment/align_config.py b/lib/body_models/skel/alignment/align_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a743fd556d952ba57ee3f73f6f047265ff6dc72
--- /dev/null
+++ b/lib/body_models/skel/alignment/align_config.py
@@ -0,0 +1,72 @@
+""" Optimization config file. In 'optim_steps' we define each optimization steps.
+Each step inherit and overwrite the parameters of the previous step."""
+
+config = {
+ 'keepalive_meshviewer': False,
+ 'optim_steps':
+ [
+ {
+ 'description' : 'Adjust the root orientation and translation',
+ 'use_basic_loss': True,
+ 'lr': 1,
+ 'max_iter': 20,
+ 'num_steps': 10,
+ 'line_search_fn': 'strong_wolfe', #'strong_wolfe',
+ 'tolerance_change': 1e-7,# 1e-4, #0.01
+ 'mode' : 'root_only',
+
+ 'l_verts_loose': 300,
+ 'l_time_loss': 0,#5e2,
+
+ 'l_joint': 0.0,
+ 'l_verts': 0,
+ 'l_scapula_loss': 0.0,
+ 'l_spine_loss': 0.0,
+ 'l_pose_loss': 0.0,
+
+
+ },
+ # Adjust the upper limbs
+ {
+ 'description' : 'Adjust the upper limbs pose',
+ 'lr': 0.1,
+ 'max_iter': 20,
+ 'num_steps': 10,
+ 'tolerance_change': 1e-7,
+ 'mode' : 'fixed_upper_limbs', #'fixed_root',
+
+ 'l_verts_loose': 600,
+ 'l_joint': 1e3,
+ 'l_time_loss': 0,# 5e2,
+ 'l_pose_loss': 1e-4,
+ },
+ # Adjust the whole body
+ {
+ 'description' : 'Adjust the whole body pose with fixed root',
+ 'lr': 0.1,
+ 'max_iter': 20,
+ 'num_steps': 10,
+ 'tolerance_change': 1e-7,
+ 'mode' : 'fixed_root', #'fixed_root',
+
+ 'l_verts_loose': 600,
+ 'l_joint': 1e3,
+ 'l_time_loss': 0,
+ 'l_pose_loss': 1e-4,
+ },
+ #
+ {
+ 'description' : 'Free optimization',
+ 'lr': 0.1,
+ 'max_iter': 20,
+ 'num_steps': 10,
+ 'tolerance_change': 1e-7,
+ 'mode' : 'free', #'fixed_root',
+
+ 'l_verts_loose': 600,
+ 'l_joint': 1e3,
+ 'l_time_loss':0,
+ 'l_pose_loss': 1e-4,
+ },
+ ]
+}
\ No newline at end of file
diff --git a/lib/body_models/skel/alignment/align_config_joint.py b/lib/body_models/skel/alignment/align_config_joint.py
new file mode 100644
index 0000000000000000000000000000000000000000..99b0b55068e9c6f45fb403eed0b39f60c2e67c7b
--- /dev/null
+++ b/lib/body_models/skel/alignment/align_config_joint.py
@@ -0,0 +1,72 @@
+""" Optimization config file. In 'optim_steps' we define each optimization steps.
+Each step inherit and overwrite the parameters of the previous step."""
+
+config = {
+ 'keepalive_meshviewer': False,
+ 'optim_steps':
+ [
+ {
+ 'description' : 'Adjust the root orientation and translation',
+ 'use_basic_loss': True,
+ 'lr': 1,
+ 'max_iter': 20,
+ 'num_steps': 10,
+ 'line_search_fn': 'strong_wolfe', #'strong_wolfe',
+ 'tolerance_change': 1e-7,# 1e-4, #0.01
+ 'mode' : 'root_only',
+
+ 'l_verts_loose': 300,
+ 'l_time_loss': 0,#5e2,
+
+ 'l_joint': 0.0,
+ 'l_verts': 0,
+ 'l_scapula_loss': 0.0,
+ 'l_spine_loss': 0.0,
+ 'l_pose_loss': 0.0,
+
+
+ },
+ # Adjust the upper limbs
+ {
+ 'description' : 'Adjust the upper limbs pose',
+ 'lr': 0.1,
+ 'max_iter': 20,
+ 'num_steps': 10,
+ 'tolerance_change': 1e-7,
+ 'mode' : 'fixed_upper_limbs', #'fixed_root',
+
+ 'l_verts_loose': 600,
+ 'l_joint': 2e3,
+ 'l_time_loss': 0,# 5e2,
+ 'l_pose_loss': 1e-3,
+ },
+ # Adjust the whole body
+ {
+ 'description' : 'Adjust the whole body pose with fixed root',
+ 'lr': 0.1,
+ 'max_iter': 20,
+ 'num_steps': 10,
+ 'tolerance_change': 1e-7,
+ 'mode' : 'fixed_root', #'fixed_root',
+
+ 'l_verts_loose': 600,
+ 'l_joint': 1e3,
+ 'l_time_loss': 0,
+ 'l_pose_loss': 1e-4,
+ },
+ #
+ {
+ 'description' : 'Free optimization',
+ 'lr': 0.1,
+ 'max_iter': 20,
+ 'num_steps': 10,
+ 'tolerance_change': 1e-7,
+ 'mode' : 'free', #'fixed_root',
+
+ 'l_verts_loose': 300,
+ 'l_joint': 2e3,
+ 'l_time_loss':0,
+ 'l_pose_loss': 1e-4,
+ },
+ ]
+}
\ No newline at end of file
diff --git a/lib/body_models/skel/alignment/aligner.py b/lib/body_models/skel/alignment/aligner.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3a953538dcdb2df6792a32015102d314f45baf8
--- /dev/null
+++ b/lib/body_models/skel/alignment/aligner.py
@@ -0,0 +1,469 @@
+
+"""
+Copyright©2023 Max-Planck-Gesellschaft zur Förderung
+der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+for Intelligent Systems. All rights reserved.
+
+Author: Soyong Shin, Marilyn Keller
+See https://skel.is.tue.mpg.de/license.html for licensing and contact information.
+"""
+
+import traceback
+import math
+import os
+import pickle
+import torch
+import smplx
+import omegaconf
+import torch.nn.functional as F
+from psbody.mesh import Mesh, MeshViewer, MeshViewers
+from tqdm import trange
+from pathlib import Path
+
+import lib.body_models.skel.config as cg
+from lib.body_models.skel.skel_model import SKEL
+from .losses import compute_anchor_pose, compute_anchor_trans, compute_pose_loss, compute_scapula_loss, compute_spine_loss, compute_time_loss, pretty_loss_print
+from .utils import location_to_spheres, to_numpy, to_params, to_torch
+from .align_config import config
+from .align_config_joint import config as config_joint
+
+class SkelFitter(object):
+
+ def __init__(self, gender, device, num_betas=10, export_meshes=False, joint_optim=False) -> None:
+
+ self.smpl = smplx.create(cg.smpl_folder, model_type='smpl', gender=gender, num_betas=num_betas, batch_size=1, export_meshes=False).to(device)
+ self.skel = SKEL(gender).to(device)
+ self.gender = gender
+ self.device = device
+ self.num_betas = num_betas
+        # Instantiate the masks used for the vertex-to-vertex fitting
+ fitting_mask_file = Path(__file__).parent / 'riggid_parts_mask.pkl'
+ fitting_indices = pickle.load(open(fitting_mask_file, 'rb'))
+ fitting_mask = torch.zeros(6890, dtype=torch.bool, device=self.device)
+ fitting_mask[fitting_indices] = 1
+ self.fitting_mask = fitting_mask.reshape(1, -1, 1).to(self.device) # 1xVx1 to be applied to verts that are BxVx3
+
+ smpl_torso_joints = [0,3]
+ verts_mask = (self.smpl.lbs_weights[:,smpl_torso_joints]>0.5).sum(dim=-1)>0
+ self.torso_verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) # Because verts are of shape BxVx3
+
+ self.export_meshes = export_meshes
+
+
+ # make the cfg being an object using omegaconf
+ if joint_optim:
+ self.cfg = omegaconf.OmegaConf.create(config_joint)
+ else:
+ self.cfg = omegaconf.OmegaConf.create(config)
+
+        # Instantiate the mesh viewer to visualize the fitting
+ if('DISABLE_VIEWER' in os.environ):
+ self.mv = None
+ print("\n DISABLE_VIEWER flag is set, running in headless mode")
+ else:
+ self.mv = MeshViewers((1,2), keepalive=self.cfg.keepalive_meshviewer)
+
+
+ def run_fit(self,
+ trans_in,
+ betas_in,
+ poses_in,
+ batch_size=20,
+ skel_data_init=None,
+ force_recompute=False,
+ debug=False,
+ watch_frame=0,
+ freevert_mesh=None,
+ opt_sequence=False,
+ fix_poses=False,
+ variant_exp=''):
+ """Align SKEL to a SMPL sequence."""
+
+ self.nb_frames = poses_in.shape[0]
+ self.watch_frame = watch_frame
+ self.is_skel_data_init = skel_data_init is not None
+ self.force_recompute = force_recompute
+
+ print('Fitting {} frames'.format(self.nb_frames))
+ print('Watching frame: {}'.format(watch_frame))
+
+ # Initialize SKEL torch params
+ body_params = self._init_params(betas_in, poses_in, trans_in, skel_data_init, variant_exp)
+
+ # We cut the whole sequence in batches for parallel optimization
+ if batch_size > self.nb_frames:
+ batch_size = self.nb_frames
+ print('Batch size is larger than the number of frames. Setting batch size to {}'.format(batch_size))
+
+ n_batch = math.ceil(self.nb_frames/batch_size)
+ pbar = trange(n_batch, desc='Running batch optimization')
+
+ # Initialize the res dict to store the per frame result skel parameters
+ out_keys = ['poses', 'betas', 'trans']
+ if self.export_meshes:
+ out_keys += ['skel_v', 'skin_v', 'smpl_v']
+ res_dict = {key: [] for key in out_keys}
+
+ res_dict['gender'] = self.gender
+ if self.export_meshes:
+ res_dict['skel_f'] = self.skel.skel_f.cpu().numpy().copy()
+ res_dict['skin_f'] = self.skel.skin_f.cpu().numpy().copy()
+ res_dict['smpl_f'] = self.smpl.faces
+
+ # Iterate over the batches to fit the whole sequence
+ for i in pbar:
+
+ if debug:
+ # Only run the first batch to test, ignore the rest
+ if i > 1:
+ continue
+
+ # Get batch start and end indices
+ i_start = i * batch_size
+ i_end = min((i+1) * batch_size, self.nb_frames)
+
+ # Fit the batch
+ betas, poses, trans, verts = self._fit_batch(body_params, i, i_start, i_end, enable_time=opt_sequence, fix_poses=fix_poses)
+ # if torch.isnan(betas).any() \
+ # or torch.isnan(poses).any() \
+ # or torch.isnan(trans).any():
+ # print(f'Nan values detected.')
+ # raise ValueError('Nan values detected in the output.')
+
+            # Store the results
+ res_dict['poses'].append(poses)
+ res_dict['betas'].append(betas)
+ res_dict['trans'].append(trans)
+ if self.export_meshes:
+ # Store the meshes vertices
+ skel_output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', skelmesh=True)
+ res_dict['skel_v'].append(skel_output.skel_verts)
+ res_dict['skin_v'].append(skel_output.skin_verts)
+ res_dict['smpl_v'].append(verts)
+
+ if opt_sequence:
+ # Initialize the next frames with current frame
+ body_params['poses_skel'][i_end:] = poses[-1:].detach()
+ body_params['trans_skel'][i_end:] = trans[-1].detach()
+ body_params['betas_skel'][i_end:] = betas[-1:].detach()
+
+ # Concatenate the batches and convert to numpy
+ for key, val in res_dict.items():
+ if isinstance(val, list):
+ res_dict[key] = torch.cat(val, dim=0).detach().cpu().numpy()
+
+ return res_dict
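+
+ # Illustrative usage only (not part of the original file). Assuming `fitter` is an instance
+ # of this aligner class, numpy is imported as np, and the SMPL sequence tensors already
+ # live on its device:
+ #   res = fitter.run_fit(trans, betas, poses, batch_size=20, opt_sequence=True)
+ #   np.savez('skel_fit.npz', **res)   # 'poses', 'betas', 'trans' are stacked numpy arrays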
+
+ def _init_params(self, betas_smpl, poses_smpl, trans_smpl, skel_data_init=None, variant_exp=''):
+ """ Return initial SKEL parameters from SMPL data dictionary and an optional SKEL data dictionary."""
+
+ if skel_data_init is None or self.force_recompute:
+
+ poses_skel = torch.zeros((self.nb_frames, self.skel.num_q_params), device=self.device)
+ if variant_exp == '' or variant_exp == '_official_old':
+ poses_skel[:, :3] = poses_smpl[:, :3] # The global orientation is similar between SMPL and SKEL, so init with the SMPL angles
+ elif variant_exp == '_official_fix':
+ # https://github.com/MarilynKeller/SKEL/commit/d1f6ff62235c142ba010158e00e21fd4fe25807f#diff-09188717a56a42e9589e9bd289f9ddb4fb53160e03c81a7ced70b3a84c1d9d0bR157
+ pass
+ elif variant_exp == '_my_fix':
+ gt_orient_aa = poses_smpl[:, :3]
+ # IMPORTANT: The alignment comes from `exp/inspect_skel/archive/orientation.py`.
+ from lib.utils.geometry.rotation import axis_angle_to_matrix, matrix_to_euler_angles
+ gt_orient_mat = axis_angle_to_matrix(gt_orient_aa)
+ gt_orient_ea = matrix_to_euler_angles(gt_orient_mat, 'YXZ')
+ flip = torch.tensor([-1, 1, 1], device=self.device)
+ poses_skel[:, :3] = gt_orient_ea[:, [2, 1, 0]] * flip
+ else:
+ raise ValueError(f'Unknown variant_exp {variant_exp}')
+
+ betas_skel = torch.zeros((self.nb_frames, 10), device=self.device)
+ betas_skel[:] = betas_smpl[..., :10]
+
+ trans_skel = trans_smpl # Translation is similar between SMPL and SKEL, so init with SMPL translation
+
+ else:
+ # Load from previous alignment
+ betas_skel = to_torch(skel_data_init['betas'], self.device)
+ poses_skel = to_torch(skel_data_init['poses'], self.device)
+ trans_skel = to_torch(skel_data_init['trans'], self.device)
+
+ # Make a dictionary out of the necessary body parameters
+ body_params = {
+ 'betas_skel': betas_skel,
+ 'poses_skel': poses_skel,
+ 'trans_skel': trans_skel,
+ 'betas_smpl': betas_smpl,
+ 'poses_smpl': poses_smpl,
+ 'trans_smpl': trans_smpl
+ }
+
+ return body_params
+
+
+ def _fit_batch(self, body_params, i, i_start, i_end, enable_time=False, fix_poses=False):
+ """ Create parameters for the batch and run the optimization."""
+
+ # Sample a batch of the sequence parameters
+ body_params = { key: val[i_start:i_end] for key, val in body_params.items()}
+
+ # SMPL params
+ betas_smpl = body_params['betas_smpl']
+ poses_smpl = body_params['poses_smpl']
+ trans_smpl = body_params['trans_smpl']
+
+ # SKEL params
+ betas = to_params(body_params['betas_skel'], device=self.device)
+ poses = to_params(body_params['poses_skel'], device=self.device)
+ trans = to_params(body_params['trans_skel'], device=self.device)
+
+ if 'verts' in body_params:
+ verts = body_params['verts']
+ else:
+ # Run a SMPL forward pass to get the SMPL body vertices
+ smpl_output = self.smpl(betas=betas_smpl, body_pose=poses_smpl[:,3:], transl=trans_smpl, global_orient=poses_smpl[:,:3])
+ verts = smpl_output.vertices
+
+ # Optimize
+ config = self.cfg.optim_steps
+ current_cfg = config[0]
+
+ # from lib.kits.debug import set_trace
+ # set_trace()
+
+ try:
+ if fix_poses:
+ # for ci, cfg in enumerate(config[1:]):
+ for ci, cfg in enumerate([config[-1]]): # To debug, only run the last step
+ current_cfg.update(cfg)
+ print(f'Step {ci+1}: {current_cfg.description}')
+ self._optim([trans,betas], poses, betas, trans, verts, current_cfg, enable_time)
+ else:
+ if not enable_time or not self.is_skel_data_init:
+ # Optimize the global rotation and translation for the initial fitting
+ print(f'Step 0: {current_cfg.description}')
+ self._optim([trans,poses], poses, betas, trans, verts, current_cfg, enable_time)
+
+ for ci, cfg in enumerate(config[1:]):
+ # for ci, cfg in enumerate([config[-1]]): # To debug, only run the last step
+ current_cfg.update(cfg)
+ print(f'Step {ci+1}: {current_cfg.description}')
+ self._optim([poses], poses, betas, trans, verts, current_cfg, enable_time)
+
+
+ # # Refine by optimizing the whole body
+ # cfg.update(self.cfg_optim[])
+ # cfg.update({'mode' : 'free', 'tolerance_change': 0.0001, 'l_joint': 0.2e4})
+ # self._optim([trans, poses], poses, betas, trans, verts, cfg)
+ except Exception as e:
+ print(e)
+ traceback.print_exc()
+ # from lib.kits.debug import set_trace
+ # set_trace()
+
+ return betas, poses, trans, verts
+
+ def _optim(self,
+ params,
+ poses,
+ betas,
+ trans,
+ verts,
+ cfg,
+ enable_time=False,
+ ):
+
+ # regress anatomical joints from SMPL's vertices
+ anat_joints = torch.einsum('bik,ji->bjk', [verts, self.skel.J_regressor_osim])
+ dJ=torch.zeros((poses.shape[0], 24, 3), device=betas.device)
+
+ # Create the optimizer
+ optimizer = torch.optim.LBFGS(params,
+ lr=cfg.lr,
+ max_iter=cfg.max_iter,
+ line_search_fn=cfg.line_search_fn,
+ tolerance_change=cfg.tolerance_change)
+
+ poses_init = poses.detach().clone()
+ trans_init = trans.detach().clone()
+
+ def closure():
+ optimizer.zero_grad()
+
+ # fi = self.watch_frame #frame of the batch to display
+ # output = self.skel.forward(poses=poses[fi:fi+1],
+ # betas=betas[fi:fi+1],
+ # trans=trans[fi:fi+1],
+ # poses_type='skel',
+ # dJ=dJ[fi:fi+1],
+ # skelmesh=True)
+
+ # self._fstep_plot(output, cfg, verts[fi:fi+1], anat_joints[fi:fi+1], )
+
+ loss_dict = self._fitting_loss(poses,
+ poses_init,
+ betas,
+ trans,
+ trans_init,
+ dJ,
+ anat_joints,
+ verts,
+ cfg,
+ enable_time)
+
+ # print(pretty_loss_print(loss_dict))
+
+ loss = sum(loss_dict.values())
+ loss.backward()
+
+ return loss
+
+ for step_i in range(cfg.num_steps):
+ loss = optimizer.step(closure).item()
+
+ def _get_masks(self, cfg):
+ pose_mask = torch.ones((self.skel.num_q_params)).to(self.device).unsqueeze(0)
+ verts_mask = torch.ones_like(self.fitting_mask)
+ joint_mask = torch.ones((self.skel.num_joints, 3)).to(self.device).unsqueeze(0).bool()
+
+ # Mask vertices
+ if cfg.mode=='root_only':
+ # Only optimize the global rotation of the body, i.e. the first 3 angles of the pose
+ pose_mask[:] = 0 # Only optimize for the global rotation
+ pose_mask[:,:3] = 1
+ # Only fit the thorax vertices to recover the proper body orientation and translation
+ verts_mask = self.torso_verts_mask
+
+ elif cfg.mode=='fixed_upper_limbs':
+ upper_limbs_joints = [0,1,2,3,6,9,12,15,17]
+ verts_mask = (self.smpl.lbs_weights[:,upper_limbs_joints]>0.5).sum(dim=-1)>0
+ verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1)
+
+ joint_mask[:, [3,4,5,8,9,10,18,23], :] = 0 # Do not try to match the joints of the extremities (feet and hands)
+
+ pose_mask[:] = 1
+ pose_mask[:,:3] = 0 # Block the global rotation
+ pose_mask[:,19] = 0 # block the lumbar twist
+ # pose_mask[:, 36:39] = 0
+ # pose_mask[:, 43:46] = 0
+ # pose_mask[:, 62:65] = 0
+ # pose_mask[:, 62:65] = 0
+
+ elif cfg.mode=='fixed_root':
+ pose_mask[:] = 1
+ pose_mask[:,:3] = 0 # Block the global rotation
+ # pose_mask[:,19] = 0 # block the lumbar twist
+
+ # The orientation of the upper limbs is often wrong in SMPL, so ignore these vertices for the final step
+ upper_limbs_joints = [1,2,16,17]
+ verts_mask = (self.smpl.lbs_weights[:,upper_limbs_joints]>0.5).sum(dim=-1)>0
+ verts_mask = torch.logical_not(verts_mask)
+ verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1)
+
+ elif cfg.mode=='free':
+ verts_mask = torch.ones_like(self.fitting_mask )
+
+ joint_mask[:]=0
+ joint_mask[:, [19,14], :] = 1 # Only fit the scapula joints to avoid collapsing shoulders
+
+ else:
+ raise ValueError(f'Unknown mode {cfg.mode}')
+
+ return pose_mask, verts_mask, joint_mask
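+
+ # Illustrative only (not part of the original file): in 'root_only' mode, only the first 3
+ # pose parameters (global rotation) are optimized and only the torso vertices contribute
+ # to the vertex loss, e.g.:
+ #   pose_mask, verts_mask, joint_mask = self._get_masks(cfg)   # with cfg.mode == 'root_only'
+ #   pose_mask[:, :3].sum()   # 3.0, all other entries are zero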
+
+ def _fitting_loss(self,
+ poses,
+ poses_init,
+ betas,
+ trans,
+ trans_init,
+ dJ,
+ anat_joints,
+ verts,
+ cfg,
+ enable_time=False):
+
+ loss_dict = {}
+
+
+ pose_mask, verts_mask, joint_mask = self._get_masks(cfg)
+ poses = poses * pose_mask + poses_init * (1-pose_mask)
+
+ # Mask joints to not optimize before computing the losses
+
+ output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', dJ=dJ, skelmesh=False)
+
+ # Fit the SMPL vertices
+ # We know the skinning of the forearm and the neck are not perfect,
+ # so we create a mask of the SMPL vertices that are important to fit, like the hands and the head
+ loss_dict['verts_loss_loose'] = cfg.l_verts_loose * (verts_mask * (output.skin_verts - verts)**2).sum() / (((verts_mask).sum()*self.nb_frames))
+
+ # Fit the regressed joints, this avoids collapsing shoulders
+ # loss_dict['joint_loss'] = cfg.l_joint * F.mse_loss(output.joints, anat_joints)
+ loss_dict['joint_loss'] = cfg.l_joint * (joint_mask * (output.joints - anat_joints)**2).mean()
+
+ # Time consistency
+ if poses.shape[0] > 1 and enable_time:
+ # This avoids unstable hip orientations
+ loss_dict['time_loss'] = cfg.l_time_loss * F.mse_loss(poses[1:], poses[:-1])
+
+ loss_dict['pose_loss'] = cfg.l_pose_loss * compute_pose_loss(poses, poses_init)
+
+ if cfg.use_basic_loss is False:
+ # These losses can be used to regularize the optimization but are not always necessary
+ loss_dict['anch_rot'] = cfg.l_anch_pose * compute_anchor_pose(poses, poses_init)
+ loss_dict['anch_trans'] = cfg.l_anch_trans * compute_anchor_trans(trans, trans_init)
+
+ loss_dict['verts_loss'] = cfg.l_verts * (verts_mask * self.fitting_mask * (output.skin_verts - verts)**2).sum() / (self.fitting_mask*verts_mask).sum()
+
+ # Regularize the pose
+ loss_dict['scapula_loss'] = cfg.l_scapula_loss * compute_scapula_loss(poses)
+ loss_dict['spine_loss'] = cfg.l_spine_loss * compute_spine_loss(poses)
+
+ # Adjust the losses of all the pose regularizations sub losses with the pose_reg_factor value
+ for key in ['scapula_loss', 'spine_loss', 'pose_loss']:
+ loss_dict[key] = cfg.pose_reg_factor * loss_dict[key]
+
+ return loss_dict
+
+ def _fstep_plot(self, output, cfg, verts, anat_joints):
+ "Function to plot each step"
+
+ if('DISABLE_VIEWER' in os.environ):
+ return
+
+ pose_mask, verts_mask, joint_mask = self._get_masks(cfg)
+
+ skin_err_value = ((output.skin_verts[0] - verts[0])**2).sum(dim=-1).sqrt()
+ skin_err_value = skin_err_value / 0.05
+ skin_err_value = to_numpy(skin_err_value)
+
+ skin_mesh = Mesh(v=to_numpy(output.skin_verts[0]), f=[], vc='white')
+ skel_mesh = Mesh(v=to_numpy(output.skel_verts[0]), f=self.skel.skel_f.cpu().numpy(), vc='white')
+
+ # Display vertex distance on SMPL
+ smpl_verts = to_numpy(verts[0])
+ smpl_mesh = Mesh(v=smpl_verts, f=self.smpl.faces)
+ smpl_mesh.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False)
+
+ smpl_mesh_masked = Mesh(v=smpl_verts[to_numpy(verts_mask[0,:,0])], f=[], vc='green')
+ smpl_mesh_pc = Mesh(v=smpl_verts, f=[], vc='green')
+
+ skin_mesh_err = Mesh(v=to_numpy(output.skin_verts[0]), f=self.skel.skin_f.cpu().numpy(), vc='white')
+ skin_mesh_err.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False)
+ # List the meshes to display
+ meshes_left = [skin_mesh_err, smpl_mesh_pc]
+ meshes_right = [smpl_mesh_masked, skin_mesh, skel_mesh]
+
+ if cfg.l_joint > 0:
+ # Plot the joints
+ meshes_right += location_to_spheres(to_numpy(output.joints[joint_mask[:,:,0]]), color=(1,0,0), radius=0.02)
+ meshes_right += location_to_spheres(to_numpy(anat_joints[joint_mask[:,:,0]]), color=(0,1,0), radius=0.02)
+
+
+ self.mv[0][0].set_dynamic_meshes(meshes_left)
+ self.mv[0][1].set_dynamic_meshes(meshes_right)
+
+ # print(poses[frame_to_watch, :3])
+ # print(trans[frame_to_watch])
+ # print(betas[frame_to_watch, :3])
+ # mv.get_keypress()
diff --git a/lib/body_models/skel/alignment/losses.py b/lib/body_models/skel/alignment/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..51238fe784ff0c65ecd570cea76b6e11c338158c
--- /dev/null
+++ b/lib/body_models/skel/alignment/losses.py
@@ -0,0 +1,47 @@
+import torch
+
+def compute_scapula_loss(poses):
+
+ scapula_indices = [26, 27, 28, 36, 37, 38]
+
+ scapula_poses = poses[:, scapula_indices]
+ scapula_loss = torch.linalg.norm(scapula_poses, ord=2)
+ return scapula_loss
+
+def compute_spine_loss(poses):
+
+ spine_indices = range(17, 25)
+
+ spine_poses = poses[:, spine_indices]
+ spine_loss = torch.linalg.norm(spine_poses, ord=2)
+ return spine_loss
+
+def compute_pose_loss(poses, pose_init):
+
+ pose_loss = torch.linalg.norm(poses[:, 3:], ord=2) # The global rotation should not be constrained
+ return pose_loss
+
+def compute_anchor_pose(poses, pose_init):
+
+ pose_loss = torch.nn.functional.mse_loss(poses[:, :3], pose_init[:, :3])
+ return pose_loss
+
+def compute_anchor_trans(trans, trans_init):
+
+ trans_loss = torch.nn.functional.mse_loss(trans, trans_init)
+ return trans_loss
+
+def compute_time_loss(poses):
+
+ pose_delta = poses[1:] - poses[:-1]
+ time_loss = torch.linalg.norm(pose_delta, ord=2)
+ return time_loss
+
+def pretty_loss_print(loss_dict):
+ # Pretty print the losses in the form: total_loss | loss1 val1 | loss2 val2
+ # Start with the total loss
+ loss = sum(loss_dict.values())
+ pretty_loss = f'{loss:.4f}'
+ for key, val in loss_dict.items():
+ pretty_loss += f' | {key} {val:.4f}'
+ return pretty_loss
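+
+# Illustrative only (not part of the original file): pretty_loss_print formats a loss dict as
+#   '1.2300 | verts_loss_loose 1.0000 | joint_loss 0.2300'
+# i.e. the total loss first, then each named term separated by ' | '.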
diff --git a/lib/body_models/skel/alignment/riggid_parts_mask.pkl b/lib/body_models/skel/alignment/riggid_parts_mask.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..53c9b8baf1992053946e27b33d6f677b8d52f68c
Binary files /dev/null and b/lib/body_models/skel/alignment/riggid_parts_mask.pkl differ
diff --git a/lib/body_models/skel/alignment/utils.py b/lib/body_models/skel/alignment/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1353100ad4f2bdf55956d5f1cbfd6ae17002406
--- /dev/null
+++ b/lib/body_models/skel/alignment/utils.py
@@ -0,0 +1,119 @@
+
+import os
+import pickle
+import torch
+import numpy as np
+from psbody.mesh.sphere import Sphere
+
+# to_params = lambda x: torch.from_numpy(x).float().to(self.device).requires_grad_(True)
+# to_torch = lambda x: torch.from_numpy(x).float().to(self.device)
+
+def to_params(x, device):
+ return x.to(device).requires_grad_(True)
+
+def to_torch(x, device):
+ return torch.from_numpy(x).float().to(device)
+
+def to_numpy(x):
+ return x.detach().cpu().numpy()
+
+def load_smpl_seq(smpl_seq_path, gender=None, straighten_hands=False):
+
+ if not os.path.exists(smpl_seq_path):
+ raise Exception('Path does not exist: {}'.format(smpl_seq_path))
+
+ if smpl_seq_path.endswith('.pkl'):
+ data_dict = pickle.load(open(smpl_seq_path, 'rb'))
+
+ elif smpl_seq_path.endswith('.npz'):
+ data_dict = np.load(smpl_seq_path, allow_pickle=True)
+
+ if data_dict.files == ['pred_smpl_parms', 'verts', 'pred_cam_t']:
+ data_dict = data_dict['pred_smpl_parms'].item()# ['global_orient', 'body_pose', 'body_pose_axis_angle', 'global_orient_axis_angle', 'betas']
+ else:
+ data_dict = {key: data_dict[key] for key in data_dict.keys()} # convert to python dict
+ else:
+ raise Exception('Unknown file format: {}. Supported formats are .pkl and .npz'.format(smpl_seq_path))
+
+ # Instantiate a dictionary with the keys expected by the fitter
+ data_fixed = {}
+
+ # Get gender
+ if 'gender' not in data_dict:
+ assert gender is not None, f"The provided SMPL data dictionary does not contain gender, you need to pass it on the command line"
+ data_fixed['gender'] = gender
+ elif not isinstance(data_dict['gender'], str):
+ # In some npz, the gender type happens to be: array('male', dtype='
+ def __init__(self) -> None:
+ super().__init__()
+ pass
+
+ def q_to_translation(self, q, **kwargs):
+ return torch.zeros(q.shape[0], 3).to(q.device)
+
+
+class CustomJoint(OsimJoint):
+
+ def __init__(self, axis, axis_flip) -> None:
+ super().__init__()
+ self.register_buffer('axis', torch.FloatTensor(axis))
+ self.register_buffer('axis_flip', torch.FloatTensor(axis_flip))
+ self.register_buffer('nb_dof', torch.tensor(len(axis)))
+
+ def q_to_rot(self, q, **kwargs):
+
+ ident = torch.eye(3, dtype=q.dtype).to(q.device)
+
+ Rp = ident.unsqueeze(0).expand(q.shape[0],3,3) # torch.eye(q.shape[0], 3, 3)
+ for i in range(self.nb_dof):
+ axis = self.axis[i].to(q.device)
+ angle_axis = q[:, i:i+1] * self.axis_flip[i].to(q.device) * axis
+ Rp_i = axis_angle_to_matrix(angle_axis)
+ Rp = torch.matmul(Rp_i, Rp)
+ return Rp
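+
+ # Minimal sketch only (not part of the original file): a 3-DoF CustomJoint composes one
+ # axis-angle rotation per DoF, so a zero pose maps to the identity.
+ #   joint = CustomJoint(axis=[[0,0,1],[1,0,0],[0,1,0]], axis_flip=[1,1,1])
+ #   q = torch.zeros(2, 3)          # batch of 2 poses, one angle per DoF
+ #   R = joint.q_to_rot(q)          # (2, 3, 3), identity matrices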
+
+
+
+class CustomJoint1D(OsimJoint):
+
+ def __init__(self, axis, axis_flip) -> None:
+ super().__init__()
+ self.axis = torch.FloatTensor(axis)
+ self.axis = self.axis / torch.linalg.norm(self.axis)
+ self.axis_flip = torch.FloatTensor(axis_flip)
+ self.nb_dof = 1
+
+ def q_to_rot(self, q, **kwargs):
+ axis = self.axis.to(q.device)
+ angle_axis = q[:, 0:1] * self.axis_flip.to(q.device) * axis
+ Rp_i = axis_angle_to_matrix(angle_axis)
+ return Rp_i
+
+
+class WalkerKnee(OsimJoint):
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.register_buffer('nb_dof', torch.tensor(1))
+ # self.nb_dof = 1
+
+ def q_to_rot(self, q, **kwargs):
+ # Todo : for now implement a basic knee
+ theta_i = torch.zeros(q.shape[0], 3).to(q.device)
+ theta_i[:, 2] = -q[:, 0]
+ Rp_i = axis_angle_to_matrix(theta_i)
+ return Rp_i
+
+class PinJoint(OsimJoint):
+
+ def __init__(self, parent_frame_ori) -> None:
+ super().__init__()
+ self.register_buffer('parent_frame_ori', torch.FloatTensor(parent_frame_ori))
+ self.register_buffer('nb_dof', torch.tensor(1))
+
+
+ def q_to_rot(self, q, **kwargs):
+
+ talus_orient_torch = self.parent_frame_ori.to(q.device)
+ Ra_i = euler_angles_to_matrix(talus_orient_torch, 'XYZ')
+
+ z_axis = torch.FloatTensor([0,0,1]).to(q.device)
+ axis = torch.matmul(Ra_i, z_axis).to(q.device)
+
+ axis_angle = q[:, 0:1] * axis
+ Rp_i = axis_angle_to_matrix(axis_angle)
+
+ return Rp_i
+
+
+class ConstantCurvatureJoint(CustomJoint):
+
+ def __init__(self, **kwargs ) -> None:
+ super().__init__( **kwargs)
+
+
+
+class EllipsoidJoint(CustomJoint):
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
diff --git a/lib/body_models/skel/skel_model.py b/lib/body_models/skel/skel_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6cd0695b201be2dd909279050e63d4fd2075c4f
--- /dev/null
+++ b/lib/body_models/skel/skel_model.py
@@ -0,0 +1,675 @@
+
+"""
+Copyright©2024 Max-Planck-Gesellschaft zur Förderung
+der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+for Intelligent Systems. All rights reserved.
+
+Author: Marilyn Keller
+See https://skel.is.tue.mpg.de/license.html for licensing and contact information.
+"""
+
+import os
+import torch.nn as nn
+import torch
+import numpy as np
+import pickle as pkl
+from typing import NewType, Optional
+
+from lib.body_models.skel.joints_def import curve_torch_3d, left_scapula, right_scapula
+from lib.body_models.skel.osim_rot import ConstantCurvatureJoint, CustomJoint, EllipsoidJoint, PinJoint, WalkerKnee
+from lib.body_models.skel.utils import build_homog_matrix, rotation_matrix_from_vectors, sparce_coo_matrix2tensor, with_zeros, matmul_chain
+from dataclasses import dataclass, fields
+
+from lib.body_models.skel.kin_skel import scaling_keypoints, pose_param_names, smpl_joint_corresp
+import lib.body_models.skel.config as cg
+
+Tensor = NewType('Tensor', torch.Tensor)
+
+@dataclass
+class ModelOutput:
+ vertices: Optional[Tensor] = None
+ joints: Optional[Tensor] = None
+ full_pose: Optional[Tensor] = None
+ global_orient: Optional[Tensor] = None
+ transl: Optional[Tensor] = None
+ v_shaped: Optional[Tensor] = None
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def get(self, key, default=None):
+ return getattr(self, key, default)
+
+ def __iter__(self):
+ return self.keys()
+
+ def keys(self):
+ keys = [t.name for t in fields(self)]
+ return iter(keys)
+
+ def values(self):
+ values = [getattr(self, t.name) for t in fields(self)]
+ return iter(values)
+
+ def items(self):
+ data = [(t.name, getattr(self, t.name)) for t in fields(self)]
+ return iter(data)
+
+@dataclass
+class SKELOutput(ModelOutput):
+ betas: Optional[Tensor] = None
+ body_pose: Optional[Tensor] = None
+ skin_verts: Optional[Tensor] = None
+ skel_verts: Optional[Tensor] = None
+ joints: Optional[Tensor] = None
+ joints_ori: Optional[Tensor] = None
+ betas: Optional[Tensor] = None
+ poses: Optional[Tensor] = None
+ trans : Optional[Tensor] = None
+ pose_offsets : Optional[Tensor] = None
+ joints_tpose : Optional[Tensor] = None
+ v_skin_shaped : Optional[Tensor] = None
+
+
+class SKEL(nn.Module):
+
+ num_betas = 10
+
+ def __init__(self, gender, model_path=None, custom_joint_reg_path=None, **kwargs):
+ super(SKEL, self).__init__()
+
+ if gender not in ['male', 'female']:
+ raise RuntimeError(f'Invalid Gender, got {gender}')
+
+ self.gender = gender
+
+ if model_path is None:
+ # skel_file = f"/Users/mkeller2/Data/skel_models_v1.0/skel_{gender}.pkl"
+ skel_file = os.path.join(cg.skel_folder, f"skel_{gender}.pkl")
+ else:
+ skel_file = os.path.join(model_path, f"skel_{gender}.pkl")
+ assert os.path.exists(skel_file), f"Skel model file {skel_file} does not exist"
+
+ skel_data = pkl.load(open(skel_file, 'rb'))
+
+ # Check that the version of the skel model is compatible with this loader
+ assert 'version' in skel_data, f"Expected version 1.1.1 of the SKEL pickle. Please download the latest skel pkl versions from https://skel.is.tue.mpg.de/download.html"
+ version = skel_data['version']
+ assert version == '1.1.1', f"Expected version 1.1.1, got {version}. Please download the latest skel pkl versions from https://skel.is.tue.mpg.de/download.html"
+
+ self.num_betas = 10
+ self.num_q_params = 46
+ self.bone_names = skel_data['bone_names']
+ self.num_joints = skel_data['J_regressor_osim'].shape[0]
+ self.num_joints_smpl = skel_data['J_regressor'].shape[0]
+
+ self.joints_name = skel_data['joints_name']
+ self.pose_params_name = skel_data['pose_params_name']
+
+ # register the template meshes
+ self.register_buffer('skin_template_v', torch.FloatTensor(skel_data['skin_template_v']))
+ self.register_buffer('skin_f', torch.LongTensor(skel_data['skin_template_f']))
+
+ self.register_buffer('skel_template_v', torch.FloatTensor(skel_data['skel_template_v']))
+ self.register_buffer('skel_f', torch.LongTensor(skel_data['skel_template_f']))
+
+ # Shape corrective blend shapes
+ self.register_buffer('shapedirs', torch.FloatTensor(np.array(skel_data['shapedirs'][:,:,:self.num_betas])))
+ self.register_buffer('posedirs', torch.FloatTensor(np.array(skel_data['posedirs'])))
+
+ # Model sparse joints regressor, regresses joints location from a mesh
+ self.register_buffer('J_regressor', sparce_coo_matrix2tensor(skel_data['J_regressor']))
+
+ # Regress the anatomical joint location with a regressor learned from BioAmass
+ if custom_joint_reg_path is not None:
+ J_regressor_skel = pkl.load(open(custom_joint_reg_path, 'rb'))
+ if 'scipy.sparse' in str(type(J_regressor_skel)):
+ J_regressor_skel = J_regressor_skel.todense()
+ self.register_buffer('J_regressor_osim', torch.FloatTensor(J_regressor_skel))
+ print('WARNING: Using custom joint regressor')
+ else:
+ self.register_buffer('J_regressor_osim', sparce_coo_matrix2tensor(skel_data['J_regressor_osim'], make_dense=True))
+
+ self.register_buffer('per_joint_rot', torch.FloatTensor(skel_data['per_joint_rot']))
+
+ # Skin model skinning weights
+ self.register_buffer('skin_weights', sparce_coo_matrix2tensor(skel_data['skin_weights']))
+
+ # Skeleton model skinning weights
+ self.register_buffer('skel_weights', sparce_coo_matrix2tensor(skel_data['skel_weights']))
+ self.register_buffer('skel_weights_rigid', sparce_coo_matrix2tensor(skel_data['skel_weights_rigid']))
+
+ # Kinematic tree of the model
+ self.register_buffer('kintree_table', torch.from_numpy(skel_data['osim_kintree_table'].astype(np.int64)))
+ self.register_buffer('parameter_mapping', torch.from_numpy(skel_data['parameter_mapping'].astype(np.int64)))
+
+ # transformation from the osim canonical pose to the T pose
+ self.register_buffer('tpose_transfo', torch.FloatTensor(skel_data['tpose_transfo']))
+
+ # transformation from the osim canonical pose to the A pose
+ self.register_buffer('apose_transfo', torch.FloatTensor(skel_data['apose_transfo']))
+ self.register_buffer('apose_rel_transfo', torch.FloatTensor(skel_data['apose_rel_transfo']))
+
+ # Indices of bones which orientation should not vary with beta in T pose:
+ joint_idx_fixed_beta = [0, 5, 10, 13, 18, 23]
+ self.register_buffer('joint_idx_fixed_beta', torch.IntTensor(joint_idx_fixed_beta))
+
+ id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
+ self.register_buffer('parent', torch.LongTensor(
+ [id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))
+
+
+ # child array
+ # TODO create this array in the SKEL creator
+ child_array = []
+ Nj = self.num_joints
+ for i in range(0, Nj):
+ try:
+ j_array = torch.where(self.kintree_table[0] == i)[0] # candidate child lines
+ if len(j_array) == 0:
+ child_index = 0
+ else:
+
+ j = j_array[0]
+ if j>=len(self.kintree_table[1]):
+ child_index = 0
+ else:
+ child_index = self.kintree_table[1,j].item()
+ child_array.append(child_index)
+ except:
+ import ipdb; ipdb.set_trace()
+
+ # print(f"child_array: ")
+ # [print(i,child_array[i]) for i in range(0, Nj)]
+ self.register_buffer('child', torch.LongTensor(child_array))
+
+ # Instantiate joints
+ self.joints_dict = nn.ModuleList([
+ CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 0 pelvis
+ CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 1 femur_r
+ WalkerKnee(), # 2 tibia_r
+ PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 3 talus_r Field taken from .osim Joint-> frames -> PhysicalOffsetFrame -> orientation
+ PinJoint(parent_frame_ori = [-1.76818999, 0.906223, 1.8196000]), # 4 calcn_r
+ PinJoint(parent_frame_ori = [-3.141589999, 0.6199010, 0]), # 5 toes_r
+ CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, -1, -1]), # 6 femur_l
+ WalkerKnee(), # 7 tibia_l
+ PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 8 talus_l
+ PinJoint(parent_frame_ori = [1.768189999 ,-0.906223, 1.8196000]), # 9 calcn_l
+ PinJoint(parent_frame_ori = [-3.141589999, -0.6199010, 0]), # 10 toes_l
+ ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 11 lumbar
+ ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 12 thorax
+ ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 13 head
+ EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, -1, -1]), # 14 scapula_r
+ CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 15 humerus_r
+ CustomJoint(axis=[[0.0494, 0.0366, 0.99810825]], axis_flip=[[1]]), # 16 ulna_r
+ CustomJoint(axis=[[-0.01716099, 0.99266564, -0.11966796]], axis_flip=[[1]]), # 17 radius_r
+ CustomJoint(axis=[[1,0,0], [0,0,-1]], axis_flip=[1, 1]), # 18 hand_r
+ EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, 1, 1]), # 19 scapula_l
+ CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 20 humerus_l
+ CustomJoint(axis=[[-0.0494, -0.0366, 0.99810825]], axis_flip=[[1]]), # 21 ulna_l
+ CustomJoint(axis=[[0.01716099, -0.99266564, -0.11966796]], axis_flip=[[1]]), # 22 radius_l
+ CustomJoint(axis=[[-1,0,0], [0,0,-1]], axis_flip=[1, 1]), # 23 hand_l
+ ])
+
+ def pose_params_to_rot(self, osim_poses):
+ """ Transform the pose parameters to 3x3 rotation matrices
+ Each parameter is mapped to a joint as described in joint_dict.
+ The specific joint object is then used to compute the rotation matrix.
+ """
+
+ B = osim_poses.shape[0]
+ Nj = self.num_joints
+
+ ident = torch.eye(3, dtype=osim_poses.dtype).to(osim_poses.device)
+ Rp = ident.unsqueeze(0).unsqueeze(0).repeat(B, Nj,1,1)
+ tp = torch.zeros(B, Nj, 3).to(osim_poses.device)
+ start_index = 0
+ for i in range(0, Nj):
+ joint_object = self.joints_dict[i]
+ end_index = start_index + joint_object.nb_dof
+ Rp[:, i] = joint_object.q_to_rot(osim_poses[:, start_index:end_index])
+ start_index = end_index
+ return Rp, tp
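+
+ # Illustrative only (assumes a SKEL instance `skel` with the model pkl available):
+ #   q = torch.zeros(1, skel.num_q_params)    # rest pose
+ #   Rp, tp = skel.pose_params_to_rot(q)      # Rp: (1, 24, 3, 3) identities, tp: zeros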
+
+
+ def params_name_to_index(self, param_name):
+
+ assert param_name in pose_param_names
+ param_index = pose_param_names.index(param_name)
+ return param_index
+
+
+ def forward(self, poses, betas, trans, poses_type='skel', skelmesh=True, dJ=None, pose_dep_bs=True):
+ """
+ params
+ poses : B x 46 tensor of pose parameters
+ betas : B x 10 tensor of shape parameters, same as SMPL
+ trans : B x 3 tensor of translation
+ poses_type : str, 'skel', should not be changed
+ skelmesh : bool, if True, returns the skeleton vertices. The skeleton mesh is heavy so to fit on GPU memory, set to False when not needed.
+ dJ : B x 24 x 3 tensor of the offset of the joints location from the anatomical regressor. If None, the offset is set to 0.
+ pose_dep_bs : bool, if True (default), applies the pose dependent blend shapes. If False, the pose dependent blend shapes are not applied.
+
+ return SKELOutput class with the following fields:
+ betas : Bx10 tensor of shape parameters
+ poses : Bx46 tensor of pose parameters
+ skin_verts : Bx6890x3 tensor of skin vertices
+ skel_verts : tensor of skeleton vertices
+ joints : Bx24x3 tensor of joints location
+ joints_ori : Bx24x3x3 tensor of joints orientation
+ trans : Bx3 translation
+ pose_offsets : Bx6890x3 pose dependent blend shape offsets
+ joints_tpose : Bx24x3 3D joints location in T pose
+
+ In this function we use the following conventions:
+ B : batch size
+ Ns : skin vertices
+ Nk : skeleton vertices
+ """
+
+ Ns = self.skin_template_v.shape[0] # nb skin vertices
+ Nk = self.skel_template_v.shape[0] # nb skeleton vertices
+ Nj = self.num_joints
+ B = poses.shape[0]
+ device = poses.device
+
+ # Check the shapes of the inputs
+ assert len(betas.shape) == 2, f"Betas should be of shape (B, {self.num_betas}), but got {betas.shape}"
+ assert poses.shape[0] == betas.shape[0], f"Expected poses and betas to have the same batch size, but got {poses.shape[0]} and {betas.shape[0]}"
+ assert poses.shape[0] == trans.shape[0], f"Expected poses and betas to have the same batch size, but got {poses.shape[0]} and {trans.shape[0]}"
+
+ if dJ is not None:
+ assert len(dJ.shape) == 3, f"Expected dJ to have shape (B, {Nj}, 3), but got {dJ.shape}"
+ assert dJ is None or dJ.shape[0] == B, f"Expected dJ to have the same batch size as poses, but got {dJ.shape[0]} and {poses.shape[0]}"
+ assert dJ.shape[1] == Nj, f"Expected dJ to have the same number of joints as the model, but got {dJ.shape[1]} and {Nj}"
+
+ # Check the device of the inputs
+ assert betas.device == device, f"Betas should be on device {device}, but got {betas.device}"
+ assert trans.device == device, f"Trans should be on device {device}, but got {trans.device}"
+
+ skin_v0 = self.skin_template_v[None, :]
+ skel_v0 = self.skel_template_v[None, :]
+ betas = betas[:, :, None] # TODO Name the expanded beta differently
+
+ # TODO clean this part
+ assert poses_type in ['skel', 'bsm'], f"got {poses_type}"
+ if poses_type == 'bsm':
+ assert poses.shape[1] == self.num_q_params - 3, f'With poses_type bsm, expected parameters of shape (B, {self.num_q_params - 3}), got {poses.shape}'
+ poses_bsm = poses
+ poses_skel = torch.zeros(B, self.num_q_params)
+ poses_skel[:,:3] = poses_bsm[:, :3]
+ trans = poses_bsm[:, 3:6] # In BSM parametrization, the hips translation is given by params 3 to 5
+ poses_skel[:, 3:] = poses_bsm
+ poses = poses_skel
+
+ else:
+ assert poses.shape[1] == self.num_q_params, f'With poses_type skel, expected parameters of shape (B, {self.num_q_params}), got {poses.shape}'
+ pass
+ # Poses are already expressed in the SKEL parametrization, no conversion needed.
+ # The 'bsm'/'skel' distinction only matters for the branch above.
+
+ # ------- Shape ----------
+ # Apply the beta offset to the template
+ shapedirs = self.shapedirs.view(-1, self.num_betas)[None, :].expand(B, -1, -1) # B x D*Ns x num_betas
+ v_shaped = skin_v0 + torch.matmul(shapedirs, betas).view(B, Ns, 3)
+
+ # ------- Joints ----------
+ # Regress the anatomical joint location
+ J = torch.einsum('bik,ji->bjk', [v_shaped, self.J_regressor_osim]) # BxJx3 # osim regressor
+ # J = self.apose_transfo[:, :3, -1].view(1, Nj, 3).expand(B, -1, -1) # Osim default pose joints location
+
+ if dJ is not None:
+ J = J + dJ
+ J_tpose = J.clone()
+
+ # Local translation
+ J_ = J.clone() # BxJx3
+ J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
+ t = J_[:, :, :, None] # BxJx3x1
+
+ # ------- Bones transformation matrix----------
+
+ # Bone initial transform to go from unposed to SMPL T pose
+ Rk01 = self.compute_bone_orientation(J, J_)
+
+ # BSM default pose rotations
+ Ra = self.apose_rel_transfo[:, :3, :3].view(1, Nj, 3,3).expand(B, Nj, 3, 3)
+
+ # Local bone rotation given by the pose param
+ Rp, tp = self.pose_params_to_rot(poses) # BxNjx3x3 pose params to rotation
+
+ R = matmul_chain([Rk01, Ra.transpose(2,3), Rp, Ra, Rk01.transpose(2,3)])
+
+ ###### Compute translation for non pure rotation joints
+ t_posed = t.clone()
+
+ # Scapula
+ thorax_width = torch.norm(J[:, 19, :] - J[:, 14, :], dim=1) # Distance between the two scapula joints, size B
+ thorax_height = torch.norm(J[:, 12, :] - J[:, 11, :], dim=1) # Distance between the thorax and lumbar joints, size B
+
+ angle_abduction = poses[:,26]
+ angle_elevation = poses[:,27]
+ angle_rot = poses[:,28]
+ angle_zero = torch.zeros_like(angle_abduction)
+ t_posed[:,14] = t_posed[:,14] + \
+ (right_scapula(angle_abduction, angle_elevation, angle_rot, thorax_width, thorax_height).view(-1,3,1)
+ - right_scapula(angle_zero, angle_zero, angle_zero, thorax_width, thorax_height).view(-1,3,1))
+
+
+ angle_abduction = poses[:,36]
+ angle_elevation = poses[:,37]
+ angle_rot = poses[:,38]
+ angle_zero = torch.zeros_like(angle_abduction)
+ t_posed[:,19] = t_posed[:,19] + \
+ (left_scapula(angle_abduction, angle_elevation, angle_rot, thorax_width, thorax_height).view(-1,3,1)
+ - left_scapula(angle_zero, angle_zero, angle_zero, thorax_width, thorax_height).view(-1,3,1))
+
+
+ # Knee_r
+ # TODO add the Walker knee offset
+ # bone_scale = self.compute_bone_scale(J_,J, skin_v0, v_shaped)
+ # f1 = poses[:, 2*3+2].clone()
+ # scale_femur = bone_scale[:, 2]
+ # factor = 0.076/0.080 * scale_femur # The template femur medial laterak spacing #66
+ # f = -f1*180/torch.pi #knee_flexion
+ # varus = (0.12367*f)-0.0009*f**2
+ # introt = 0.3781*f-0.001781*f**2
+ # ydis = (-0.0683*f
+ # + 8.804e-4 * f**2
+ # - 3.750e-06*f**3
+ # )/1000*factor # up-down
+ # zdis = (-0.1283*f
+ # + 4.796e-4 * f**2)/1000*factor #
+ # import ipdb; ipdb.set_trace()
+ # poses[:, 9] = poses[:, 9] + varus
+ # t_posed[:,2] = t_posed[:,2] + torch.stack([torch.zeros_like(ydis), ydis, zdis], dim=1).view(-1,3,1)
+ # poses[:, 2*3+2]=0
+
+ # t_unposed = torch.zeros_like(t_posed)
+ # t_unposed[:,2] = torch.stack([torch.zeros_like(ydis), ydis, zdis], dim=1).view(-1,3,1)
+
+
+ # Spine
+ lumbar_bending = poses[:,17]
+ lumbar_extension = poses[:,18]
+ angle_zero = torch.zeros_like(lumbar_bending)
+ interp_t = torch.ones_like(lumbar_bending)
+ l = torch.abs(J[:, 11, 1] - J[:, 0, 1]) # Length of the spine section along y axis
+ t_posed[:,11] = t_posed[:,11] + \
+ (curve_torch_3d(lumbar_bending, lumbar_extension, t=interp_t, l=l)
+ - curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l))
+
+ thorax_bending = poses[:,20]
+ thorax_extension = poses[:,21]
+ angle_zero = torch.zeros_like(thorax_bending)
+ interp_t = torch.ones_like(thorax_bending)
+ l = torch.abs(J[:, 12, 1] - J[:, 11, 1]) # Length of the spine section
+
+ t_posed[:,12] = t_posed[:,12] + \
+ (curve_torch_3d(thorax_bending, thorax_extension, t=interp_t, l=l)
+ - curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l))
+
+ head_bending = poses[:, 23]
+ head_extension = poses[:,24]
+ angle_zero = torch.zeros_like(head_bending)
+ interp_t = torch.ones_like(head_bending)
+ l = torch.abs(J[:, 13, 1] - J[:, 12, 1]) # Length of the spine section
+ t_posed[:,13] = t_posed[:,13] + \
+ (curve_torch_3d(head_bending, head_extension, t=interp_t, l=l)
+ - curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l))
+
+
+ # ------- Body surface transformation matrix----------
+
+ G_ = torch.cat([R, t_posed], dim=-1) # BxJx3x4 local transformation matrix
+ pad_row = torch.FloatTensor([0, 0, 0, 1]).to(device).view(1, 1, 1, 4).expand(B, Nj, -1, -1) # BxJx1x4
+ G_ = torch.cat([G_, pad_row], dim=2) # BxJx4x4, padded to a 4x4 matrix to enable multiplication along the kinematic chain
+
+ # Global transform
+ G = [G_[:, 0].clone()]
+ for i in range(1, Nj):
+ G.append(torch.matmul(G[self.parent[i - 1]], G_[:, i, :, :]))
+ G = torch.stack(G, dim=1)
+
+ # ------- Pose dependant blend shapes ----------
+ if pose_dep_bs is False:
+ pose_offsets = torch.zeros_like(v_shaped) # No pose dependent offsets in this case (needed for the output below)
+ v_shaped_pd = v_shaped
+ else:
+ # Note: these should be retrained for SKEL, as the SKEL joint locations differ from SMPL's.
+ # But the current version gives decent pose dependent deformations for the shoulders, belly and knees.
+ ident = torch.eye(3, dtype=v_shaped.dtype, device=device)
+
+ # We need the per-SMPL-joint bone transforms to compute the pose dependent blend shapes.
+ # Initialize each joint rotation with identity
+ Rsmpl = ident.unsqueeze(0).unsqueeze(0).expand(B, self.num_joints_smpl, -1, -1).clone() # BxNjx3x3
+
+ Rskin = G_[:, :, :3, :3] # BxNjx3x3
+ Rsmpl[:, smpl_joint_corresp] = Rskin[:] # BxNjx3x3 pose params to rotation
+ pose_feature = Rsmpl[:, 1:].view(B, -1, 3, 3) - ident
+ pose_offsets = torch.matmul(pose_feature.view(B, -1),
+ self.posedirs.view(Ns*3, -1).T).view(B, -1, 3)
+ v_shaped_pd = v_shaped + pose_offsets
+
+ ##########################################################################################
+ #Transform skin mesh
+ ############################################################################################
+
+ # Apply global transformation to the template mesh
+ rest = torch.cat([J, torch.zeros(B, Nj, 1).to(device)], dim=2).view(B, Nj, 4, 1) # BxJx4x1
+ zeros = torch.zeros(B, Nj, 4, 3).to(device) # BxJx4x3
+ rest = torch.cat([zeros, rest], dim=-1) # BxJx4x4
+ rest = torch.matmul(G, rest) # This is a 4x4 transformation matrix that only contains translation to the rest pose joint location
+ Gskin = G - rest
+
+ # Compute per vertex transformation matrix (after weighting)
+ T = torch.matmul(self.skin_weights, Gskin.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Ns, B, 4,4).transpose(0, 1)
+ rest_shape_h = torch.cat([v_shaped_pd, torch.ones_like(v_shaped_pd)[:, :, [0]]], dim=-1)
+ v_posed = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
+
+ # translation
+ v_trans = v_posed + trans[:,None,:]
+
+ ##########################################################################################
+ #Transform joints
+ ############################################################################################
+
+ # import ipdb; ipdb.set_trace()
+ root_transform = with_zeros(torch.cat((R[:,0],J[:,0][:,:,None]),2))
+ results = [root_transform]
+ for i in range(0, self.parent.shape[0]):
+ transform_i = with_zeros(torch.cat((R[:, i + 1], t_posed[:,i+1]), 2))
+ curr_res = torch.matmul(results[self.parent[i]],transform_i)
+ results.append(curr_res)
+ results = torch.stack(results, dim=1)
+ posed_joints = results[:, :, :3, 3]
+ J_transformed = posed_joints + trans[:,None,:]
+
+
+ ##########################################################################################
+ # Transform skeleton
+ ############################################################################################
+
+ if skelmesh:
+ G_bones = None
+ # Shape the skeleton by scaling its bones
+ skel_rest_shape_h = torch.cat([skel_v0, torch.ones_like(skel_v0)[:, :, [0]]], dim=-1).expand(B, Nk, -1) # (1,Nk,3)
+
+ # compute the bones scaling from the kinematic tree and skin mesh
+ #with torch.no_grad():
+ # TODO: when dJ is optimized the shape of the mesh should be affected by the gradients
+ bone_scale = self.compute_bone_scale(J_, v_shaped, skin_v0)
+ # Apply bone meshes scaling:
+ skel_v_shaped = torch.cat([(torch.matmul(bone_scale[:,:,0], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 0])[:, :, None],
+ (torch.matmul(bone_scale[:,:,1], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 1])[:, :, None],
+ (torch.matmul(bone_scale[:,:,2], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 2])[:, :, None],
+ (torch.ones(B, Nk, 1).to(device))
+ ], dim=-1)
+
+ # Align the bones with the proper axis
+ Gk01 = build_homog_matrix(Rk01, J.unsqueeze(-1)) # BxJx4x4
+ T = torch.matmul(self.skel_weights_rigid, Gk01.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Nk, B, 4,4).transpose(0, 1) #[1, 48757, 3, 3]
+ skel_v_align = torch.matmul(T, skel_v_shaped[:, :, :, None])[:, :, :, 0]
+
+ # This transfo will be applied with weights, effectively unposing the whole skeleton mesh in each joint frame.
+ # Then, per joint weighted transformation can then be applied
+ G_tpose_to_unposed = build_homog_matrix(torch.eye(3).view(1,1,3,3).expand(B, Nj, 3, 3).to(device), -J.unsqueeze(-1)) # BxJx4x4
+ G_skel = torch.matmul(G, G_tpose_to_unposed)
+ G_bones = torch.matmul(G, Gk01)
+
+ T = torch.matmul(self.skel_weights, G_skel.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Nk, B, 4,4).transpose(0, 1)
+ skel_v_posed = torch.matmul(T, skel_v_align[:, :, :, None])[:, :, :3, 0]
+
+ skel_trans = skel_v_posed + trans[:,None,:]
+
+ else:
+ skel_trans = skel_v0
+ Gk01 = build_homog_matrix(Rk01, J.unsqueeze(-1)) # BxJx4x4
+ G_bones = torch.matmul(G, Gk01)
+
+ joints = J_transformed
+ skin_verts = v_trans
+ skel_verts = skel_trans
+ joints_ori = G_bones[:,:,:3,:3]
+
+ if skin_verts.max() > 1e3:
+ import ipdb; ipdb.set_trace()
+
+ output = SKELOutput(skin_verts=skin_verts,
+ skel_verts=skel_verts,
+ joints=joints,
+ joints_ori=joints_ori,
+ betas=betas,
+ poses=poses,
+ trans = trans,
+ pose_offsets = pose_offsets,
+ joints_tpose = J_tpose,
+ v_shaped = v_shaped,)
+
+ return output
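+
+ # Illustrative usage only (assumes the SKEL pkl files are available under cg.skel_folder):
+ #   skel = SKEL(gender='female')
+ #   B = 1
+ #   out = skel(poses=torch.zeros(B, 46), betas=torch.zeros(B, 10),
+ #              trans=torch.zeros(B, 3), skelmesh=False)
+ #   out.skin_verts.shape   # (B, 6890, 3)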
+
+
+ def compute_bone_scale(self, J_, v_shaped, skin_v0):
+
+ # index [0, 1, 2, 3 4, 5, , ...] # todo add last one, figure out bone scale indices
+ # J_ bone vectors [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...]
+ # norm(J) = length of the bone [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...]
+ # self.joints_sockets [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...]
+ # self.skel_weights [j0, j1, j2, j3, j4, j5, ...]
+ B = J_.shape[0]
+ Nj = J_.shape[1]
+
+ bone_scale = torch.ones(B, Nj).to(J_.device)
+
+ # BSM template joints location
+ osim_joints_r = self.apose_rel_transfo[:, :3, 3].view(1, Nj, 3).expand(B, Nj, 3).clone()
+
+ length_bones_bsm = torch.norm(osim_joints_r, dim=-1).expand(B, -1)
+ length_bones_smpl = torch.norm(J_, dim=-1) # (B, Nj)
+ bone_scale_parent = length_bones_smpl / length_bones_bsm
+
+ non_leaf_node = (self.child != 0)
+ bone_scale[:,non_leaf_node] = (bone_scale_parent[:,self.child])[:,non_leaf_node]
+
+ # Ulna should have the same scale as radius
+ bone_scale[:,16] = bone_scale[:,17]
+
+ bone_scale[:,21] = bone_scale[:,22]
+
+ # Thorax
+ # Thorax scale is defined by the relative position of the thorax to its child joint, not parent joint as for other bones
+ bone_scale[:, 12] = bone_scale[:, 11]
+
+ # Lumbars
+ # Lumbar scale is defined by the y relative position of the lumbar joint
+ length_bones_bsm = torch.abs(osim_joints_r[:,11, 1])
+ length_bones_smpl = torch.abs(J_[:, 11, 1]) # (B, Nj)
+ bone_scale_lumbar = length_bones_smpl / length_bones_bsm
+ bone_scale[:, 11] = bone_scale_lumbar
+
+ # Expand to 3 dimensions and adjust the scaling to avoid skin-skeleton intersection and to handle the scaling of leaf body parts (hands, feet)
+ bone_scale = bone_scale.reshape(B, Nj, 1).expand(B, Nj, 3).clone()
+
+ for (ji, doi, dsi), (v1, v2) in scaling_keypoints.items():
+ bone_scale[:, ji, doi] = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,dsi] # Top over chin
+ #TODO: Add keypoints for feet scaling in scaling_keypoints
+
+ # Adjust thorax front-back scaling
+ # TODO fix this part
+ v1 = 3027 #thorax back
+ v2 = 3495 #thorax front
+
+ scale_thorax_up = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,2] # good for large people
+ v2 = 3506 #sternum
+ scale_thorax_sternum = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,2] # Good for skinny people
+ bone_scale[:, 12, 0] = torch.min(scale_thorax_up, scale_thorax_sternum) # Avoids super expanded ribcage for large people and sternum outside for skinny people
+
+ #lumbars, adjust width to be same as thorax
+ bone_scale[:, 11, 0] = bone_scale[:, 12, 0]
+
+ return bone_scale
+
+
+
+ def compute_bone_orientation(self, J, J_):
+ """Compute each bone orientation in T pose """
+
+ # method = 'unposed'
+ # method = 'learned'
+ method = 'learn_adjust'
+
+ B = J_.shape[0]
+ Nj = J_.shape[1]
+
+ # Create an array of bone vectors the bone meshes should be aligned to.
+ bone_vect = torch.zeros_like(J_) # / torch.norm(J_, dim=-1)[:, :, None] # (B, Nj, 3)
+ bone_vect[:] = J_[:, self.child] # Most bones are aligned between their parent and child joint
+ bone_vect[:,16] = bone_vect[:,16]+bone_vect[:,17] # We want to align the ulna to the segment joint 16 to 18
+ bone_vect[:,21] = bone_vect[:,21]+bone_vect[:,22] # Same for the other ulna
+
+ # TODO Check indices here
+ # bone_vect[:,13] = bone_vect[:,12].clone()
+ bone_vect[:,12] = bone_vect.clone()[:,11].clone() # We want to align the thorax on the thorax-lumbar segment
+ # bone_vect[:,11] = bone_vect[:,0].clone()
+
+ osim_vect = self.apose_rel_transfo[:, :3, 3].clone().view(1, Nj, 3).expand(B, Nj, 3).clone()
+ osim_vect[:] = osim_vect[:,self.child]
+ osim_vect[:,16] = osim_vect[:,16]+osim_vect[:,17] # We want to align the ulna to the segment joint 16 to 18
+ osim_vect[:,21] = osim_vect[:,21]+osim_vect[:,22] # We want to align the ulna to the segment joint 16 to 18
+
+ # TODO: remove when this has been checked
+ # import matplotlib.pyplot as plt
+ # fig = plt.figure()
+ # ax = fig.add_subplot(111, projection='3d')
+ # ax.plot(osim_vect[:,0,0], osim_vect[:,0,1], osim_vect[:,0,2], color='r')
+ # plt.show()
+
+ Gk = torch.eye(3, device=J_.device).repeat(B, Nj, 1, 1)
+
+ if method == 'unposed':
+ return Gk
+
+ elif method == 'learn_adjust':
+ Gk_learned = self.per_joint_rot.view(1, Nj, 3, 3).expand(B, -1, -1, -1) #load learned rotation
+ osim_vect_corr = torch.matmul(Gk_learned, osim_vect.unsqueeze(-1)).squeeze(-1)
+
+ Gk[:,:] = rotation_matrix_from_vectors(osim_vect_corr, bone_vect)
+ # set nan to zero
+ # TODO: Check again why the following line was required
+ Gk[torch.isnan(Gk)] = 0
+ # Gk[:,[18,23]] = Gk[:,[16,21]] # hand has same orientation as ulna
+ # Gk[:,[5,10]] = Gk[:,[4,9]] # toe has same orientation as calcaneus
+ # Gk[:,[0,11,12,13,14,19]] = torch.eye(3, device=J_.device).view(1,3,3).expand(B, 6, 3, 3) # pelvis, torso and shoulder blade orientation does not vary with beta, leave it
+ Gk[:, self.joint_idx_fixed_beta] = torch.eye(3, device=J_.device).view(1,3,3).expand(B, len(self.joint_idx_fixed_beta), 3, 3) # pelvis, torso and shoulder blade orientation should not vary with beta, leave it
+
+ Gk = torch.matmul(Gk, Gk_learned)
+
+ elif method == 'learned':
+ """ Apply learned transformation"""
+ Gk = self.per_joint_rot.view(1, Nj, 3, 3).expand(B, -1, -1, -1)
+
+ else:
+ raise NotImplementedError
+
+ return Gk
diff --git a/lib/body_models/skel/utils.py b/lib/body_models/skel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4970002a191bff24a4e986ee589b2414b5e40bc5
--- /dev/null
+++ b/lib/body_models/skel/utils.py
@@ -0,0 +1,445 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
+# acting on behalf of its Max Planck Institute for Intelligent Systems and the
+# Max Planck Institute for Biological Cybernetics. All rights reserved.
+#
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
+# on this computer program. You can only use this computer program if you have closed a license agreement
+# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and liable to prosecution.
+# Contact: ps-license@tuebingen.mpg.de
+#
+#
+# If you use this code in a research publication please consider citing the following:
+#
+# STAR: Sparse Trained Articulated Human Body Regressor
+#
+#
+# Code Developed by:
+# Ahmed A. A. Osman, edited by Marilyn Keller
+
+import scipy
+import torch
+import numpy as np
+
+def build_homog_matrix(R, t=None):
+ """ Create a homogeneous matrix from rotation matrix and translation vector
+ @ R: rotation matrix of shape (B, Nj, 3, 3)
+ @ t: translation vector of shape (B, Nj, 3, 1)
+ returns: homogeneous matrices of shape (B, Nj, 4, 4)
+ By Marilyn Keller
+ """
+
+ if t is None:
+ B = R.shape[0]
+ Nj = R.shape[1]
+ t = torch.zeros(B, Nj, 3, 1).to(R.device)
+
+ if R is None:
+ B = t.shape[0]
+ Nj = t.shape[1]
+ R = torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(B, Nj, 1, 1).to(t.device)
+
+ B = t.shape[0]
+ Nj = t.shape[1]
+
+ # import ipdb; ipdb.set_trace()
+ assert R.shape == (B, Nj, 3, 3), f"R.shape: {R.shape}"
+ assert t.shape == (B, Nj, 3, 1), f"t.shape: {t.shape}"
+
+ G = torch.cat([R, t], dim=-1) # BxJx3x4 local transformation matrix
+ pad_row = torch.FloatTensor([0, 0, 0, 1]).to(R.device).view(1, 1, 1, 4).expand(B, Nj, -1, -1) # BxJx1x4
+ G = torch.cat([G, pad_row], dim=2) # BxJx4x4, padded to a 4x4 matrix to enable multiplication along the kinematic chain
+
+ return G
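+
+# Illustrative only (not part of the original file): identity transforms for a batch of 2 with 24 joints.
+#   R = torch.eye(3).view(1, 1, 3, 3).expand(2, 24, 3, 3)
+#   t = torch.zeros(2, 24, 3, 1)
+#   G = build_homog_matrix(R, t)   # (2, 24, 4, 4)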
+
+
+def matmul_chain(rot_list):
+ R_tot = rot_list[-1]
+ for i in range(len(rot_list)-2,-1,-1):
+ R_tot = torch.matmul(rot_list[i], R_tot)
+ return R_tot
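+
+# Illustrative only: matmul_chain([A, B, C]) == torch.matmul(A, torch.matmul(B, C)),
+# i.e. the list is accumulated right-to-left.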
+
+
+def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1].
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ import torch.nn.functional as F
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+
+def rotation_matrix_from_vectors(vec1, vec2):
+ """ Find the rotation matrix that aligns vec1 to vec2
+ :param vec1: A 3d "source" vector (B x Nj x 3)
+ :param vec2: A 3d "destination" vector (B x Nj x 3)
+ :return mat: A rotation matrix (B x Nj x 3 x 3) which when applied to vec1, aligns it with vec2.
+ """
+ for v_id, v in enumerate([vec1, vec2]):
+ # vectors shape should be B x Nj x 3
+ assert len(v.shape) == 3, f"Vectors {v_id} shape should be B x Nj x 3, got {v.shape}"
+ assert v.shape[-1] == 3, f"Vectors {v_id} shape should be B x Nj x 3, got {v.shape}"
+
+ B = vec1.shape[0]
+ Nj = vec1.shape[1]
+ device = vec1.device
+
+ a = vec1 / torch.linalg.norm(vec1, dim=-1, keepdim=True)
+ b = vec2 / torch.linalg.norm(vec2, dim=-1, keepdim=True)
+ v = torch.cross(a, b, dim=-1)
+ # Compute the dot product along the last dimension of a and b
+ c = torch.sum(a * b, dim=-1)
+ s = torch.linalg.norm(v, dim=-1) + torch.finfo(float).eps
+ v0 = torch.zeros_like(v[...,0], device=device).unsqueeze(-1)
+ kmat_l1 = torch.cat([v0, -v[...,2].unsqueeze(-1), v[...,1].unsqueeze(-1)], dim=-1)
+ kmat_l2 = torch.cat([v[...,2].unsqueeze(-1), v0, -v[...,0].unsqueeze(-1)], dim=-1)
+ kmat_l3 = torch.cat([-v[...,1].unsqueeze(-1), v[...,0].unsqueeze(-1), v0], dim=-1)
+ # Stack the matrix rows along the -2 dimension
+ kmat = torch.cat([kmat_l1.unsqueeze(-2), kmat_l2.unsqueeze(-2), kmat_l3.unsqueeze(-2)], dim=-2) # B x Nj x 3 x 3
+ # import ipdb; ipdb.set_trace()
+ rotation_matrix = torch.eye(3, device=device).view(1,1,3,3).expand(B, Nj, 3, 3) + kmat + torch.matmul(kmat, kmat) * ((1 - c) / (s ** 2)).view(B, Nj, 1, 1).expand(B, Nj, 3, 3)
+ return rotation_matrix
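+
+# Illustrative sanity check only (B=1, Nj=1): the returned matrix maps vec1 onto vec2.
+#   v1 = torch.tensor([[[1., 0., 0.]]])
+#   v2 = torch.tensor([[[0., 1., 0.]]])
+#   R = rotation_matrix_from_vectors(v1, v2)
+#   torch.allclose(torch.matmul(R, v1.unsqueeze(-1)).squeeze(-1), v2, atol=1e-6)   # True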
+
+
+def quat_feat(theta):
+ '''
+ Computes a normalized quaternion ([0,0,0,0] when the body is in rest pose)
+ given joint angles
+ :param theta: A tensor of joints axis angles, batch size x number of joints x 3
+ :return:
+ '''
+ l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_sin * normalized,v_cos-1], dim=1)
+ return quat
+
+def quat2mat(quat):
+ '''
+ Converts a quaternion to a rotation matrix
+ :param quat:
+ :return:
+ '''
+ norm_quat = quat
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
+ B = quat.size(0)
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w * x, w * y, w * z
+ xy, xz, yz = x * y, x * z, y * z
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
+ 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
+ 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+
+def rodrigues(theta):
+ '''
+ Computes the rodrigues representation given joint angles
+
+ :param theta: batch_size x number of joints x 3
+ :return: batch_size x number of joints x 3 x 4
+ '''
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat2mat(quat)
+
+
+def with_zeros(input):
+ '''
+ Appends a row of [0,0,0,1] to a batch size x 3 x 4 Tensor
+
+ :param input: A tensor of dimensions batch size x 3 x 4
+ :return: A tensor batch size x 4 x 4 (appended with 0,0,0,1)
+ '''
+ batch_size = input.shape[0]
+ row_append = torch.FloatTensor(([0.0, 0.0, 0.0, 1.0])).to(input.device)
+ row_append.requires_grad = False
+ padded_tensor = torch.cat([input, row_append.view(1, 1, 4).repeat(batch_size, 1, 1)], 1)
+ return padded_tensor
+
+
+def with_zeros_44(input):
+ '''
+ Appends a row of [0,0,0,1] to a batch size x 3 x 4 Tensor
+
+ :param input: A tensor of dimensions batch size x 3 x 4
+ :return: A tensor batch size x 4 x 4 (appended with 0,0,0,1)
+ '''
+ import ipdb; ipdb.set_trace()
+ batch_size = input.shape[0]
+ col_append = torch.FloatTensor(([[[[0.0, 0.0, 0.0]]]])).to(input.device)
+ padded_tensor = torch.cat([input, col_append], dim=-1)
+
+ row_append = torch.FloatTensor(([0.0, 0.0, 0.0, 1.0])).to(input.device)
+ row_append.requires_grad = False
+ padded_tensor = torch.cat([input, row_append.view(1, 1, 4).repeat(batch_size, 1, 1)], 1)
+ return padded_tensor
+
+
+def vector_to_rot():
+
+ def rotation_matrix(A,B):
+ # Aligns vector A to vector B
+
+ ax = A[0]
+ ay = A[1]
+ az = A[2]
+
+ bx = B[0]
+ by = B[1]
+ bz = B[2]
+
+ au = A/(torch.sqrt(ax*ax + ay*ay + az*az))
+ bu = B/(torch.sqrt(bx*bx + by*by + bz*bz))
+
+ R=torch.tensor([[bu[0]*au[0], bu[0]*au[1], bu[0]*au[2]], [bu[1]*au[0], bu[1]*au[1], bu[1]*au[2]], [bu[2]*au[0], bu[2]*au[1], bu[2]*au[2]] ])
+
+
+ return(R)
+
+
+def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to rotation matrices.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to quaternions.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = angles * 0.5
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+ )
+ return quaternions
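+
+# Illustrative only: a rotation of pi/2 around Z maps to (cos(pi/4), 0, 0, sin(pi/4)).
+#   axis_angle_to_quaternion(torch.tensor([0.0, 0.0, 1.5707963]))
+#   # -> tensor([0.7071, 0.0000, 0.0000, 0.7071])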
+
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
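+
+# A small round-trip sketch (illustrative only): the zero rotation maps to the identity quaternion
+# and the identity matrix.
+#   aa = torch.zeros(1, 3)
+#   quat = axis_angle_to_quaternion(aa) # tensor([[1., 0., 0., 0.]])
+#   R = quaternion_to_matrix(quat) # (1, 3, 3), equal to torch.eye(3)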
+
+
+def axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
+ """
+ Return the rotation matrices for rotations about one of the coordinate
+ axes used by Euler angles, for each value of the angle given.
+
+ Args:
+ axis: Axis label "X", "Y", or "Z".
+ angle: Any shape tensor of Euler angles in radians.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ elif axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+ else:
+ raise ValueError("letter must be either X, Y or Z.")
+
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
+ """
+ Convert rotations given as Euler angles in radians to rotation matrices.
+
+ Args:
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
+ convention: Convention string of three uppercase letters from
+ {"X", "Y", and "Z"}.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+ raise ValueError("Invalid input euler angles.")
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ matrices = [
+ _axis_angle_rotation(c, e)
+ for c, e in zip(convention, torch.unbind(euler_angles, -1))
+ ]
+ # return functools.reduce(torch.matmul, matrices)
+ return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
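+
+# Composition sketch (illustrative only): for the "XYZ" convention the result equals the product
+# of the per-axis rotation matrices.
+#   ea = torch.tensor([0.1, 0.2, 0.3])
+#   R = euler_angles_to_matrix(ea, "XYZ")
+#   R_ref = _axis_angle_rotation("X", ea[0]) @ _axis_angle_rotation("Y", ea[1]) @ _axis_angle_rotation("Z", ea[2])
+#   # torch.allclose(R, R_ref) -> True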
+
+
+def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
+ """
+ Return the rotation matrices for rotations about one of the coordinate
+ axes used by Euler angles, for each value of the angle given.
+
+ Args:
+ axis: Axis label "X", "Y", or "Z".
+ angle: Any shape tensor of Euler angles in radians.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ elif axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+ else:
+ raise ValueError("letter must be either X, Y or Z.")
+
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def location_to_spheres(loc, color=(1,0,0), radius=0.02):
+ """Given an array of 3D points, return a list of spheres located at those positions.
+
+ Args:
+ loc (numpy.array): Nx3 array giving 3D positions
+ color (tuple, optional): One RGB float color vector to color the spheres. Defaults to (1,0,0).
+ radius (float, optional): Radius of the spheres in meters. Defaults to 0.02.
+
+ Returns:
+ list: List of spheres Mesh
+ """
+ from psbody.mesh.sphere import Sphere
+ import numpy as np
+ cL = [Sphere(np.asarray([loc[i, 0], loc[i, 1], loc[i, 2]]), radius).to_mesh() for i in range(loc.shape[0])]
+ for spL in cL:
+ spL.set_vertex_colors(np.array(color))
+ return cL
+
+def sparce_coo_matrix2tensor(arr_coo, make_dense=True):
+ assert isinstance(arr_coo, scipy.sparse._coo.coo_matrix), f"arr_coo should be a coo_matrix, got {type(arr_coo)}. Please download the updated SKEL pkl files from https://skel.is.tue.mpg.de/."
+
+ values = arr_coo.data
+ indices = np.vstack((arr_coo.row, arr_coo.col))
+
+ i = torch.LongTensor(indices)
+ v = torch.FloatTensor(values)
+ shape = arr_coo.shape
+
+ tensor_arr = torch.sparse_coo_tensor(i, v, torch.Size(shape))
+
+ if make_dense:
+ tensor_arr = tensor_arr.to_dense()
+
+ return tensor_arr
+
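+# A minimal usage sketch for `sparce_coo_matrix2tensor` (illustrative only; the matrix below
+# is a toy example, not real SKEL data):
+#   import scipy.sparse
+#   arr = scipy.sparse.coo_matrix(np.eye(3))
+#   dense = sparce_coo_matrix2tensor(arr) # dense torch tensor of shape (3, 3)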
diff --git a/lib/body_models/skel_utils/augmentation.py b/lib/body_models/skel_utils/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e5c1c3d855b3efab60ce82b949c7fad2ce649f2
--- /dev/null
+++ b/lib/body_models/skel_utils/augmentation.py
@@ -0,0 +1,47 @@
+import torch
+
+from typing import Optional
+
+from .transforms import real_orient_mat2q, real_orient_q2mat
+
+
+def update_params_after_orient_rotation(
+ poses : torch.Tensor, # (B, 46)
+ rot_mat : torch.Tensor, # the rotation orientation matrix
+ root_offset : Optional[torch.Tensor] = None, # the offset from custom root to model root
+):
+ '''
+
+ ### Args
+ - `poses`: torch.Tensor, shape = (B, 46)
+ - `rot_mat`: torch.Tensor, shape = (B, 3, 3)
+ - `root_offset`: torch.Tensor or None, shape = (B, 3)
+ - If None, the function won't update the translation.
+ - If not None, the function will calculate the root translation offset that makes the model
+ rotate around the custom root instead of the model root.
+
+ ### Returns
+ - If `root_offset` is None:
+ - `poses`: torch.Tensor, shape = (B, 46)
+ - If `root_offset` is not None:
+ - `poses`: torch.Tensor, shape = (B, 46)
+ - `trans_offset`: torch.Tensor, shape = (B, 3)
+ '''
+ poses = poses.clone()
+ # 1. Transform the SKEL orientation to real matrix.
+ orient_q = poses[:, :3] # (B, 3)
+ orient_mat = real_orient_q2mat(orient_q) # (B, 3, 3)
+ orient_mat = torch.einsum('bij,bjk->bik', rot_mat, orient_mat) # (B, 3, 3)
+ orient_q = real_orient_mat2q(orient_mat) # (B, 3)
+ poses[:, :3] = orient_q
+
+ # 2. Update the translation if needed.
+ if root_offset is not None:
+ root_before = root_offset.clone() # (B, 3)
+ root_after = torch.einsum('bij,bj->bi', rot_mat, root_before) # (B, 3)
+ root_offset = root_after - root_before # (B, 3)
+ ret = poses, root_offset
+ else:
+ ret = poses
+
+ return ret
\ No newline at end of file
diff --git a/lib/body_models/skel_utils/definition.py b/lib/body_models/skel_utils/definition.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a8e84c78bb57e38e3c7874f954674033c928b88
--- /dev/null
+++ b/lib/body_models/skel_utils/definition.py
@@ -0,0 +1,100 @@
+from lib.body_models.skel.osim_rot import ConstantCurvatureJoint, CustomJoint, EllipsoidJoint, PinJoint, WalkerKnee
+
+Q_COMPONENTS = [
+ {'qid': 0, 'name': 'pelvis', 'jid': 0},
+ {'qid': 1, 'name': 'pelvis', 'jid': 0},
+ {'qid': 2, 'name': 'pelvis', 'jid': 0},
+ {'qid': 3, 'name': 'femur-r', 'jid': 1},
+ {'qid': 4, 'name': 'femur-r', 'jid': 1},
+ {'qid': 5, 'name': 'femur-r', 'jid': 1},
+ {'qid': 6, 'name': 'tibia-r', 'jid': 2},
+ {'qid': 7, 'name': 'talus-r', 'jid': 3},
+ {'qid': 8, 'name': 'calcn-r', 'jid': 4},
+ {'qid': 9, 'name': 'toes-r', 'jid': 5},
+ {'qid': 10, 'name': 'femur-l', 'jid': 6},
+ {'qid': 11, 'name': 'femur-l', 'jid': 6},
+ {'qid': 12, 'name': 'femur-l', 'jid': 6},
+ {'qid': 13, 'name': 'tibia-l', 'jid': 7},
+ {'qid': 14, 'name': 'talus-l', 'jid': 8},
+ {'qid': 15, 'name': 'calcn-l', 'jid': 9},
+ {'qid': 16, 'name': 'toes-l', 'jid': 10},
+ {'qid': 17, 'name': 'lumbar', 'jid': 11},
+ {'qid': 18, 'name': 'lumbar', 'jid': 11},
+ {'qid': 19, 'name': 'lumbar', 'jid': 11},
+ {'qid': 20, 'name': 'thorax', 'jid': 12},
+ {'qid': 21, 'name': 'thorax', 'jid': 12},
+ {'qid': 22, 'name': 'thorax', 'jid': 12},
+ {'qid': 23, 'name': 'head', 'jid': 13},
+ {'qid': 24, 'name': 'head', 'jid': 13},
+ {'qid': 25, 'name': 'head', 'jid': 13},
+ {'qid': 26, 'name': 'scapula-r', 'jid': 14},
+ {'qid': 27, 'name': 'scapula-r', 'jid': 14},
+ {'qid': 28, 'name': 'scapula-r', 'jid': 14},
+ {'qid': 29, 'name': 'humerus-r', 'jid': 15},
+ {'qid': 30, 'name': 'humerus-r', 'jid': 15},
+ {'qid': 31, 'name': 'humerus-r', 'jid': 15},
+ {'qid': 32, 'name': 'ulna-r', 'jid': 16},
+ {'qid': 33, 'name': 'radius-r', 'jid': 17},
+ {'qid': 34, 'name': 'hand-r', 'jid': 18},
+ {'qid': 35, 'name': 'hand-r', 'jid': 18},
+ {'qid': 36, 'name': 'scapula-l', 'jid': 19},
+ {'qid': 37, 'name': 'scapula-l', 'jid': 19},
+ {'qid': 38, 'name': 'scapula-l', 'jid': 19},
+ {'qid': 39, 'name': 'humerus-l', 'jid': 20},
+ {'qid': 40, 'name': 'humerus-l', 'jid': 20},
+ {'qid': 41, 'name': 'humerus-l', 'jid': 20},
+ {'qid': 42, 'name': 'ulna-l', 'jid': 21},
+ {'qid': 43, 'name': 'radius-l', 'jid': 22},
+ {'qid': 44, 'name': 'hand-l', 'jid': 23},
+ {'qid': 45, 'name': 'hand-l', 'jid': 23},
+]
+
+
+QID2JID = {c['qid']: c['jid'] for c in Q_COMPONENTS}
+
+JID2QIDS = {}
+for c in Q_COMPONENTS:
+ jid = c['jid']
+ JID2QIDS[jid] = [] if jid not in JID2QIDS else JID2QIDS[jid]
+ JID2QIDS[jid].append(c['qid'])
+
+JID2DOF = {jid: len(qids) for jid, qids in JID2QIDS.items()}
+
+DoF1_JIDS = [2, 3, 4, 5, 7, 8, 9, 10, 16, 17, 21, 22] # (J1=12,)
+DoF2_JIDS = [18, 23] # (J2=2,)
+DoF3_JIDS = [0, 1, 6, 11, 12, 13, 14, 15, 19, 20] # (J3=10,)
+DoF1_QIDS = [6, 7, 8, 9, 13, 14, 15, 16, 32, 33, 42, 43] # (Q1=12,)
+DoF2_QIDS = [34, 35, 44, 45] # (Q2=4,)
+DoF3_QIDS = [0, 1, 2, 3, 4, 5, 10, 11, 12, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 36, 37, 38, 39, 40, 41] # (Q3=30,)
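+
+# Sanity note (illustrative): the bookkeeping above is self-consistent,
+# len(Q_COMPONENTS) == 46 == 12 * 1 + 2 * 2 + 10 * 3,
+# i.e. 12 one-DoF joints, 2 two-DoF joints and 10 three-DoF joints.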
+
+
+# Copied from the `skel_model.py`.
+# Change all axes (except those of PinJoint) to positive and update the flips if needed.
+JOINTS_DEF = [
+ CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 0 pelvis
+ CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 1 femur_r
+ WalkerKnee(), # 2 tibia_r
+ PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 3 talus_r Field taken from .osim Joint-> frames -> PhysicalOffsetFrame -> orientation
+ PinJoint(parent_frame_ori = [-1.76818999, 0.906223, 1.8196000]), # 4 calcn_r
+ PinJoint(parent_frame_ori = [-3.141589999, 0.6199010, 0]), # 5 toes_r
+ CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, -1, -1]), # 6 femur_l
+ WalkerKnee(), # 7 tibia_l
+ PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 8 talus_l
+ PinJoint(parent_frame_ori = [1.768189999 ,-0.906223, 1.8196000]), # 9 calcn_l
+ PinJoint(parent_frame_ori = [-3.141589999, -0.6199010, 0]), # 10 toes_l
+ ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 11 lumbar
+ ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 12 thorax
+ ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 13 head
+ EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, -1, -1]), # 14 scapula_r
+ CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 15 humerus_r
+ CustomJoint(axis=[[0.0494, 0.0366, 0.99810825]], axis_flip=[[1]]), # 16 ulna_r
+ CustomJoint(axis=[[-0.01716099, 0.99266564, -0.11966796]], axis_flip=[[1]]), # 17 radius_r
+ CustomJoint(axis=[[1,0,0], [0,0,1]], axis_flip=[1, -1]), # 18 hand_r
+ EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, 1, 1]), # 19 scapula_l
+ CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 20 humerus_l
+ CustomJoint(axis=[[-0.0494, -0.0366, 0.99810825]], axis_flip=[[1]]), # 21 ulna_l
+ CustomJoint(axis=[[0.01716099, -0.99266564, -0.11966796]], axis_flip=[[1]]), # 22 radius_l
+ CustomJoint(axis=[[1,0,0], [0,0,1]], axis_flip=[-1, -1]), # 23 hand_l
+]
+
+N_JOINTS = len(JOINTS_DEF) # 24
diff --git a/lib/body_models/skel_utils/limits.py b/lib/body_models/skel_utils/limits.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61c146ad7465b4660e002b79e8ccd485d6e300a
--- /dev/null
+++ b/lib/body_models/skel_utils/limits.py
@@ -0,0 +1,78 @@
+import math
+import torch
+
+from lib.body_models.skel.kin_skel import pose_param_names
+
+# Different from the original SKEL limits; it contains some modifications.
+pose_limits = {
+ 'scapula_abduction_r' : [-0.628, 0.628], # 26
+ 'scapula_elevation_r' : [-0.4, -0.1], # 27
+ 'scapula_upward_rot_r' : [-0.190, 0.319], # 28
+
+ 'scapula_abduction_l' : [-0.628, 0.628], # 36
+ 'scapula_elevation_l' : [-0.4, -0.1], # 37
+ 'scapula_upward_rot_l' : [-0.210, 0.219], # 38
+
+ 'elbow_flexion_r' : [0, (3/4)*math.pi], # 32
+ 'pro_sup_r' : [-3/4*math.pi/2, 3/4*math.pi/2], # 33
+ 'wrist_flexion_r' : [-math.pi/2, math.pi/2], # 34
+ 'wrist_deviation_r' :[-math.pi/4, math.pi/4], # 35
+
+ 'elbow_flexion_l' : [0, (3/4)*math.pi], # 42
+ 'pro_sup_l' : [-math.pi/2, math.pi/2], # 43
+ 'wrist_flexion_l' : [-math.pi/2, math.pi/2], # 44
+ 'wrist_deviation_l' :[-math.pi/4, math.pi/4], # 45
+
+ 'lumbar_bending' : [-2/3*math.pi/4, 2/3*math.pi/4], # 17
+ 'lumbar_extension' : [-math.pi/4, math.pi/4], # 18
+ 'lumbar_twist' : [-math.pi/4, math.pi/4], # 19
+
+ 'thorax_bending' :[-math.pi/4, math.pi/4], # 20
+ 'thorax_extension' :[-math.pi/4, math.pi/4], # 21
+ 'thorax_twist' :[-math.pi/4, math.pi/4], # 22
+
+ 'head_bending' :[-math.pi/4, math.pi/4], # 23
+ 'head_extension' :[-math.pi/4, math.pi/4], # 24
+ 'head_twist' :[-math.pi/4, math.pi/4], # 25
+
+ 'ankle_angle_r' : [-math.pi/4, math.pi/4], # 7
+ 'subtalar_angle_r' : [-math.pi/4, math.pi/4], # 8
+ 'mtp_angle_r' : [-math.pi/4, math.pi/4], # 9
+
+ 'ankle_angle_l' : [-math.pi/4, math.pi/4], # 14
+ 'subtalar_angle_l' : [-math.pi/4, math.pi/4], # 15
+ 'mtp_angle_l' : [-math.pi/4, math.pi/4], # 16
+
+ 'knee_angle_r' : [0, 3/4*math.pi], # 6
+ 'knee_angle_l' : [0, 3/4*math.pi], # 13
+
+ # Added by HSMR to make optimization more stable.
+ 'hip_flexion_r' : [-math.pi/4, 3/4*math.pi], # 3
+ 'hip_adduction_r' : [-math.pi/4, 2/3*math.pi/4], # 4
+ 'hip_rotation_r' : [-math.pi/4, math.pi/4], # 5
+ 'hip_flexion_l' : [-math.pi/4, 3/4*math.pi], # 10
+ 'hip_adduction_l' : [-math.pi/4, 2/3*math.pi/4], # 11
+ 'hip_rotation_l' : [-math.pi/4, math.pi/4], # 12
+
+ 'shoulder_r_x' : [-math.pi/2, math.pi/2+1.5], # 29, from bsm.osim
+ 'shoulder_r_y' : [-math.pi/2, math.pi/2], # 30
+ 'shoulder_r_z' : [-math.pi/2, math.pi/2], # 31, from bsm.osim
+
+ 'shoulder_l_x' : [-math.pi/2-1.5, math.pi/2], # 39, from bsm.osim
+ 'shoulder_l_y' : [-math.pi/2, math.pi/2], # 40
+ 'shoulder_l_z' : [-math.pi/2, math.pi/2], # 41, from bsm.osim
+ }
+
+pose_param_name2qid = {name: qid for qid, name in enumerate(pose_param_names)}
+qid2pose_param_name = {qid: name for qid, name in enumerate(pose_param_names)}
+
+SKEL_LIM_QIDS = []
+SKEL_LIM_BOUNDS = []
+for name, (low, up) in pose_limits.items():
+ if low > up:
+ low, up = up, low
+ SKEL_LIM_QIDS.append(pose_param_name2qid[name])
+ SKEL_LIM_BOUNDS.append([low, up])
+
+SKEL_LIM_BOUNDS = torch.Tensor(SKEL_LIM_BOUNDS).float()
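+
+# A minimal clamping sketch (illustrative only; `poses` is a hypothetical (B, 46) tensor):
+#   poses[:, SKEL_LIM_QIDS] = poses[:, SKEL_LIM_QIDS].clamp(
+#       min=SKEL_LIM_BOUNDS[:, 0], max=SKEL_LIM_BOUNDS[:, 1])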
+SKEL_LIM_QID2IDX = {qid: i for i, qid in enumerate(SKEL_LIM_QIDS)} # inverse mapping
\ No newline at end of file
diff --git a/lib/body_models/skel_utils/reality.py b/lib/body_models/skel_utils/reality.py
new file mode 100644
index 0000000000000000000000000000000000000000..33d81628fcbec1ac8043f8f7e626a1482b9852a2
--- /dev/null
+++ b/lib/body_models/skel_utils/reality.py
@@ -0,0 +1,36 @@
+from lib.kits.basic import *
+
+from lib.body_models.skel_utils.limits import SKEL_LIM_QID2IDX, SKEL_LIM_BOUNDS
+
+
+qids_cfg = {
+ 'l_knee': [13],
+ 'r_knee': [6],
+ 'l_elbow': [42, 43],
+ 'r_elbow': [32, 33],
+}
+
+
+def eval_rot_delta(poses, tol_deg=5):
+ tol_rad = np.deg2rad(tol_deg)
+
+ res = {}
+ for part in qids_cfg:
+ qids = qids_cfg[part]
+ violation_part = poses.new_zeros(poses.shape[0], len(qids))
+ for i, qid in enumerate(qids):
+ idx = SKEL_LIM_QID2IDX[qid]
+ ea = poses[:, qid]
+ ea = (ea + np.pi) % (2 * np.pi) - np.pi # Normalize to (-pi, pi)
+ exceed_lb = torch.where(
+ ea < SKEL_LIM_BOUNDS[idx][0] - tol_rad,
+ ea - SKEL_LIM_BOUNDS[idx][0] + tol_rad, 0
+ )
+ exceed_ub = torch.where(
+ ea > SKEL_LIM_BOUNDS[idx][1] + tol_rad,
+ ea - SKEL_LIM_BOUNDS[idx][1] - tol_rad, 0
+ )
+ violation_part[:, i] = exceed_lb.abs() + exceed_ub.abs()
+ res[part] = violation_part
+
+ return res
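+
+# A usage sketch (illustrative only; `poses` is a hypothetical (B, 46) SKEL pose tensor):
+#   violations = eval_rot_delta(poses, tol_deg=5)
+#   # -> {'l_knee': (B, 1), 'r_knee': (B, 1), 'l_elbow': (B, 2), 'r_elbow': (B, 2)}; zeros mean within limits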
\ No newline at end of file
diff --git a/lib/body_models/skel_utils/transforms.py b/lib/body_models/skel_utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..300802c9b889ed2c710c3bfc4a0e61da213173ef
--- /dev/null
+++ b/lib/body_models/skel_utils/transforms.py
@@ -0,0 +1,404 @@
+from lib.kits.basic import *
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+from .definition import JOINTS_DEF, N_JOINTS, JID2DOF, JID2QIDS, DoF1_JIDS, DoF2_JIDS, DoF3_JIDS, DoF1_QIDS, DoF2_QIDS, DoF3_QIDS
+
+from lib.utils.data import to_tensor
+from lib.utils.geometry.rotation import (
+ matrix_to_euler_angles,
+ matrix_to_rotation_6d,
+ euler_angles_to_matrix,
+ rotation_6d_to_matrix,
+)
+
+
+# ====== Internal Utils ======
+
+
+def axis2convention(axis:List):
+ ''' [1,0,0] -> 'X', [0,1,0] -> 'Y', [0,0,1] -> 'Z' '''
+ if axis == [1, 0, 0]:
+ return 'X'
+ elif axis == [0, 1, 0]:
+ return 'Y'
+ elif axis == [0, 0, 1]:
+ return 'Z'
+ else:
+ raise ValueError(f'Unsupported axis: {axis}.')
+
+
+def rotation_2d_to_angle(r2d:torch.Tensor):
+ '''
+ Extract a single angle from a 2D rotation vector [cos(a), -sin(a)], i.e. the first row of a 2x2 rotation matrix.
+
+ ### Args
+ - r2d: torch.Tensor
+ - shape = (...B, 2)
+
+ ### Returns
+ - shape = (...B,)
+ '''
+ cos, sin = r2d[..., [0]], -r2d[..., [1]]
+ return torch.atan2(sin, cos)
+
+# ====== Tools ======
+
+
+OS2S_FLIP = [-1, 1, 1]
+OS2S_CONV = 'YZX'
+def real_orient_mat2q(orient_mat:torch.Tensor) -> torch.Tensor:
+ '''
+ The rotation matrix that SKEL uses is different from SMPL's orientation matrix, so the generic
+ rotation-representation functions below cannot be used to transform the orientation directly.
+ This function converts SMPL's orientation matrix to SKEL's orientation q.
+ BUT, is that really important? Maybe we shouldn't align SMPL's orientation with SKEL's; can they be different?
+
+ ### Args
+ - orient_mat: torch.Tensor, shape = (..., 3, 3)
+
+ ### Returns
+ - orient_q: torch.Tensor, shape = (..., 3)
+ '''
+ device = orient_mat.device
+ flip = to_tensor(OS2S_FLIP, device=device) # (3,)
+ orient_ea = matrix_to_euler_angles(orient_mat.clone(), OS2S_CONV) # (..., 3)
+ orient_ea = orient_ea[..., [2, 1, 0]] # Re-permuting the order.
+ orient_q = orient_ea * flip[None]
+ return orient_q
+
+
+def real_orient_q2mat(orient_q:torch.Tensor) -> torch.Tensor:
+ '''
+ The rotation matrix that SKEL uses is different from SMPL's orientation matrix, so the generic
+ rotation-representation functions below cannot be used to transform the orientation directly.
+ This function converts SKEL's orientation q to SMPL's orientation matrix.
+ BUT, is that really important? Maybe we shouldn't align SMPL's orientation with SKEL's; can they be different?
+
+ ### Args
+ - orient_q: torch.Tensor, shape = (..., 3)
+
+ ### Returns
+ - orient_mat: torch.Tensor, shape = (..., 3, 3)
+ '''
+ device = orient_q.device
+ flip = to_tensor(OS2S_FLIP, device=device) # (3,)
+ orient_ea = orient_q * flip[None]
+ orient_ea = orient_ea[..., [2, 1, 0]] # Re-permuting the order.
+ orient_mat = euler_angles_to_matrix(orient_ea, OS2S_CONV)
+ return orient_mat
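+
+# Note (illustrative): `real_orient_mat2q` and `real_orient_q2mat` are intended to be mutual inverses,
+# so real_orient_mat2q(real_orient_q2mat(orient_q)) should recover orient_q
+# (up to the usual Euler-angle wrapping ambiguities).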
+
+
+def flip_params_lr(poses:torch.Tensor) -> torch.Tensor:
+ '''
+ It flips the SKEL parameters by exchanging the params of the left and right parts of the body. It's useful
+ for data augmentation. Note that 'left & right' are defined with the body facing the z+ direction; this
+ only matters for the orientation.
+
+ ### Args
+ - poses: torch.Tensor, shape = (B, L, 46) or (L, 46)
+
+ ### Returns
+ - flipped_poses: torch.Tensor, shape = (B, L, 46) or (L, 46)
+ '''
+ assert len(poses.shape) in [2, 3] and poses.shape[-1] == 46, f'Shape of poses should be (B, L, 46) or (L, 46) but got {poses.shape}.'
+
+ # 1. Swap the values of each left/right pair by re-permuting.
+ flipped_perm = [
+ 0, 1, 2, # pelvis
+ 10, 11, 12, # femur-r -> femur-l
+ 13, # tibia-r -> tibia-l
+ 14, # talus-r -> talus-l
+ 15, # calcn-r -> calcn-l
+ 16, # toes-r -> toes-l
+ 3, 4, 5, # femur-l -> femur-r
+ 6, # tibia-l -> tibia-r
+ 7, # talus-l -> talus-r
+ 8, # calcn-l -> calcn-r
+ 9, # toes-l -> toes-r
+ 17, 18, 19, # lumbar
+ 20, 21, 22, # thorax
+ 23, 24, 25, # head
+ 36, 37, 38, # scapula-r -> scapula-l
+ 39, 40, 41, # humerus-r -> humerus-l
+ 42, # ulna-r -> ulna-l
+ 43, # radius-r -> radius-l
+ 44, 45, # hand-r -> hand-l
+ 26, 27, 28, # scapula-l -> scapula-r
+ 29, 30, 31, # humerus-l -> humerus-r
+ 32, # ulna-l -> ulna-r
+ 33, # radius-l -> radius-r
+ 34, 35 # hand-l -> hand-r
+ ]
+
+ flipped_poses = poses[..., flipped_perm]
+
+ # 2. Mirror the rotation directions by applying -1 where needed.
+ flipped_sign = [
+ 1, -1, -1, # pelvis
+ 1, 1, 1, # femur-r'
+ 1, # tibia-r'
+ 1, # talus-r'
+ 1, # calcn-r'
+ 1, # toes-r'
+ 1, 1, 1, # femur-l'
+ 1, # tibia-l'
+ 1, # talus-l'
+ 1, # calcn-l'
+ 1, # toes-l'
+ -1, 1, -1, # lumbar
+ -1, 1, -1, # thorax
+ -1, 1, -1, # head
+ -1, -1, 1, # scapula-r'
+ -1, -1, 1, # humerus-r'
+ 1, # ulna-r'
+ 1, # radius-r'
+ 1, 1, # hand-r'
+ -1, -1, 1, # scapula-l'
+ -1, -1, 1, # humerus-l'
+ 1, # ulna-l'
+ 1, # radius-l'
+ 1, 1 # hand-l'
+ ]
+ flipped_sign = torch.tensor(flipped_sign, dtype=poses.dtype, device=poses.device) # (46,)
+ flipped_poses = flipped_sign * flipped_poses
+
+ return flipped_poses
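+
+# Note (illustrative): the permutation and sign pattern above form an involution, so flipping twice
+# recovers the input:
+#   torch.allclose(flip_params_lr(flip_params_lr(poses)), poses) # -> True for (B, L, 46) or (L, 46) poses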
+
+
+
+# def rotate_orient_around_z(q, rot):
+# """
+# Rotate SKEL orientation.
+# Args:
+# q (np.ndarray): SKEL style rotation representation (3,).
+# rot (np.ndarray): Rotation angle in degrees.
+# Returns:
+# np.ndarray: Rotated axis-angle vector.
+# """
+# import torch
+# from lib.body_models.skel.osim_rot import CustomJoint
+# # q to mat
+# root = CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]) # pelvis
+# q = torch.from_numpy(q).unsqueeze(0)
+# q = q[:, [2, 1, 0]]
+# Rp = euler_angles_to_matrix(q, convention="YXZ")
+# # rotate around z
+# R = torch.Tensor([[np.deg2rad(-rot), 0, 0]])
+# R = axis_angle_to_matrix(R)
+# R = torch.matmul(R, Rp)
+# # mat to q
+# q = matrix_to_euler_angles(R, convention="YXZ")
+# q = q[:, [2, 1, 0]]
+# q = q.numpy().squeeze()
+
+# return q.astype(np.float32)
+
+
+def params_q2rot(params_q:Union[torch.Tensor, np.ndarray]):
+ '''
+ Transform the euler-like SKEL parameters representation to per-joint rotation matrices.
+
+ ### Args
+ - params_q: Union[torch.Tensor, np.ndarray], shape = (...B, 46)
+
+ ### Returns
+ - shape = (...B, 24, 3, 3) # 24 joints, each with a 3x3 rotation matrix (low-DoF joints don't use the full matrix).
+ '''
+ # Check the type and unify to torch.
+ is_np = isinstance(params_q, np.ndarray)
+ if is_np:
+ params_q = torch.from_numpy(params_q)
+
+ # Prepare for necessary variables.
+ Bs = params_q.shape[:-1]
+ ident = torch.eye(3, dtype=params_q.dtype).to(params_q.device) # (3, 3)
+ params_rot = ident.repeat(*Bs, N_JOINTS, 1, 1) # (...B, 24, 3, 3)
+
+ # Deal with each joint separately. Modified from `skel_model.py`, but as a static version.
+ sid = 0
+ for jid in range(N_JOINTS):
+ joint_obj = JOINTS_DEF[jid]
+ eid = sid + joint_obj.nb_dof.item()
+ params_rot[..., jid, :, :] = joint_obj.q_to_rot(params_q[..., sid:eid])
+ sid = eid
+
+ if is_np:
+ params_rot = params_rot.detach().cpu().numpy()
+ return params_rot
+
+
+def params_q2rep(params_q:Union[torch.Tensor, np.ndarray]):
+ '''
+ Transform the euler-like SKEL parameters representation to the continuous representation.
+ This function is not supposed to be used in the training process, but only for debugging.
+ The function that actually matters is its inverse, `params_rep2q`.
+
+ ### Args
+ - params_q: Union[torch.Tensor, np.ndarray], shape = (...B, 46)
+
+ ### Returns
+ - shape = (...B, 24, 6)
+ - Among the 24 joints, 3-DoF joints use all 6 values to represent the rotation;
+ 1-DoF and 2-DoF joints only use the first 2. The remaining values are zeros.
+ '''
+ # Check the type and unify to torch.
+ is_np = isinstance(params_q, np.ndarray)
+ if is_np:
+ params_q = torch.from_numpy(params_q)
+
+ # Prepare for necessary variables.
+ Bs = params_q.shape[:-1]
+ params_rep = params_q.new_zeros(*Bs, N_JOINTS, 6) # (...B, 24, 6)
+
+ # Deal with each joint separately. Modified from `skel_model.py`, but as a static version.
+ sid = 0
+ for jid in range(N_JOINTS):
+ joint_obj = JOINTS_DEF[jid]
+ dof = joint_obj.nb_dof.item()
+ eid = sid + dof
+ if dof == 3:
+ mat = joint_obj.q_to_rot(params_q[..., sid:eid]) # (...B, 3, 3)
+ params_rep[..., jid, :] = matrix_to_rotation_6d(mat) # (...B, 6)
+ elif dof == 2:
+ # mat = joint_obj.q_to_rot(params_q[..., sid:eid]) # (...B, 3, 3)
+ # params_rep[..., jid, :] = matrix_to_rotation_6d(mat) # (...B, 6)
+ params_rep[..., jid, :2] = params_q[..., sid:eid]
+ elif dof == 1:
+ cos = torch.cos(params_q[..., sid])
+ sin = torch.sin(params_q[..., sid])
+ params_rep[..., jid, 0] = cos
+ params_rep[..., jid, 1] = -sin
+
+ sid = eid
+
+ if is_np:
+ params_rep = params_rep.detach().cpu().numpy()
+ return params_rep
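+
+# Round-trip sketch (illustrative only): `params_rep2q` (defined below) is the inverse mapping,
+# so for poses within the representable range:
+#   q = torch.zeros(2, 46)
+#   rep = params_q2rep(q) # (2, 24, 6)
+#   q_back = params_rep2q(rep) # (2, 46), approximately equal to q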
+
+
+# Deprecated.
+def dof3_to_q(rot, axises:List, flip:List):
+ '''
+ Convert a rotation matrix to SKEL style rotation representation.
+
+ ### Args
+ - rot: torch.Tensor, shape (...B, 3, 3)
+ - The rotation matrix.
+ - axises: list
+ - [[x0, y0, z0], [x1, y1, z1], [x2, y2, z2]]
+ - The axis defined in the SKEL's joint_dict. Only one of xi, yi, zi is 1, the others are 0.
+ - flip: list
+ - [f0, f1, f2]
+ - The flip value defined in the SKEL's joint_dict. fi is 1 or -1.
+
+ ### Returns
+ - shape = (...B, 3)
+ '''
+ convention = [axis2convention(axis) for axis in reversed(axises)] # SKEL uses euler angles in reverse order
+ convention = ''.join(convention)
+ q = matrix_to_euler_angles(rot[..., :, :], convention=convention) # (...B, 3)
+ q = q[..., [2, 1, 0]] # SKEL uses euler angles in reverse order
+ flip = rot.new_tensor(flip) # (3,)
+ q = flip * q
+ return q
+
+
+### Slow version, deprecated. ###
+# def params_rep2q(params_rot:Union[torch.Tensor, np.ndarray]):
+# '''
+# Transform the continuous representation back to the SKEL style euler-like representation.
+#
+# ### Args
+# - params_rot: Union[torch.Tensor, np.ndarray]
+# - shape = (...B, 24, 6)
+#
+# ### Returns
+# - shape = (...B, 46)
+# '''
+#
+# # Check the type and unify to torch.
+# is_np = isinstance(params_rot, np.ndarray)
+# if is_np:
+# params_rot = torch.from_numpy(params_rot)
+#
+# # Prepare for necessary variables.
+# Bs = params_rot.shape[:-2]
+# params_q = params_rot.new_zeros((*Bs, 46)) # (...B, 46)
+#
+# for jid in range(N_JOINTS):
+# joint_obj = JOINTS_DEF[jid]
+# dof = joint_obj.nb_dof.item()
+# sid, eid = JID2QIDS[jid][0], JID2QIDS[jid][-1] + 1
+#
+# if dof == 3:
+# mat = rotation_6d_to_matrix(params_rot[..., jid, :]) # (...B, 3, 3)
+# params_q[..., sid:eid] = dof3_to_q(
+# mat,
+# joint_obj.axis.tolist(),
+# joint_obj.axis_flip.detach().cpu().tolist(),
+# )
+# elif dof == 2:
+# params_q[..., sid:eid] = params_rot[..., jid, :2]
+# else:
+# params_q[..., sid:eid] = rotation_2d_to_angle(params_rot[..., jid, :2])
+#
+# if is_np:
+# params_q = params_q.detach().cpu().numpy()
+# return params_q
+
+def orient_mat2q(orient_mat:torch.Tensor):
+ ''' This is a tool function for inspecting only. orient_mat ~ (...B, 3, 3)'''
+ poses_rep = orient_mat.new_zeros(orient_mat.shape[:-2] + (24, 6)) # (...B, 24, 6)
+ orient_rep = matrix_to_rotation_6d(orient_mat) # (...B, 6)
+ poses_rep[..., 0, :] = orient_rep
+ poses_q = params_rep2q(poses_rep) # (...B, 46)
+ return poses_q[..., :3]
+
+
+# Pre-grouped the joints for different conventions
+CON_GROUP2JIDS = {'YXZ': [0, 1, 6], 'YZX': [11, 12, 13], 'XZY': [14, 19], 'ZYX': [15, 20]}
+CON_GROUP2FLIPS = {'YXZ': [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, -1.0, -1.0]], 'YZX': [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 'XZY': [[1.0, -1.0, -1.0], [1.0, 1.0, 1.0]], 'ZYX': [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]}
+# Faster version.
+def params_rep2q(params_rot:Union[torch.Tensor, np.ndarray]):
+ '''
+ Transform the continuous representation back to the SKEL style euler-like representation.
+
+ ### Args
+ - params_rot: Union[torch.Tensor, np.ndarray], shape = (...B, 24, 6)
+
+ ### Returns
+ - shape = (...B, 46)
+ '''
+
+ with PM.time_monitor('params_rep2q'):
+ with PM.time_monitor('preprocess'):
+ params_rot, recover_type_back = to_tensor(params_rot, device=None, temporary=True)
+
+ # Prepare for necessary variables.
+ Bs = params_rot.shape[:-2]
+ params_q = params_rot.new_zeros((*Bs, 46)) # (...B, 46)
+
+ with PM.time_monitor(f'dof1&dof2'):
+ params_q[..., DoF1_QIDS] = rotation_2d_to_angle(params_rot[..., DoF1_JIDS, :2]).squeeze(-1)
+ params_q[..., DoF2_QIDS] = params_rot[..., DoF2_JIDS, :2].reshape(*Bs, -1) # (...B, J2=2 * 2)
+
+ with PM.time_monitor(f'dof3'):
+ dof3_6ds = params_rot[..., DoF3_JIDS, :].reshape(*Bs, len(DoF3_JIDS), 6) # (...B, J3=10, 6)
+ dof3_mats = rotation_6d_to_matrix(dof3_6ds) # (...B, J3=10, 3, 3)
+
+ for convention, jids in CON_GROUP2JIDS.items():
+ idxs = [DoF3_JIDS.index(jid) for jid in jids]
+ mats = dof3_mats[..., idxs, :, :] # (...B, J', 3, 3)
+ qs = matrix_to_euler_angles(mats, convention=convention) # (...B, J', 3)
+ qs = qs[..., [2, 1, 0]] # SKEL uses euler angles in reverse order
+ flips = qs.new_tensor(CON_GROUP2FLIPS[convention]) # (J', 3)
+ qs = qs * flips # (...B, J', 3)
+ qids = [qid for jid in jids for qid in JID2QIDS[jid]]
+ params_q[..., qids] = qs.reshape(*Bs, -1)
+
+ with PM.time_monitor('post_process'):
+ params_q = recover_type_back(params_q)
+ return params_q
\ No newline at end of file
diff --git a/lib/body_models/skel_wrapper.py b/lib/body_models/skel_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64065ac2a18f055ee11b1de038ddfce077958fa
--- /dev/null
+++ b/lib/body_models/skel_wrapper.py
@@ -0,0 +1,146 @@
+from lib.kits.basic import *
+
+import pickle
+from smplx.vertex_joint_selector import VertexJointSelector
+from smplx.vertex_ids import vertex_ids
+from smplx.lbs import vertices2joints
+
+from lib.body_models.skel.skel_model import SKEL, SKELOutput
+
+class SKELWrapper(SKEL):
+ def __init__(
+ self,
+ *args,
+ joint_regressor_custom: Optional[str] = None,
+ joint_regressor_extra : Optional[str] = None,
+ update_root : bool = False,
+ **kwargs
+ ):
+ ''' This wrapper extends the output joints of the SKEL model so that they fit SMPL's interface. '''
+
+ super(SKELWrapper, self).__init__(*args, **kwargs)
+
+ # The final joints are combined from three parts:
+ # 1. The joints from the standard output.
+ # Map selected joints of interest from SKEL to SMPL. (Not all 24 joints will ultimately be used.)
+ # Notes: Only these SMPL joints will be used: [0, 1, 2, 4, 5, 7, 8, 12, 16, 17, 18, 19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 45, 46, 47, 48, 49, 50, 51, 52, 53]
+ skel_to_smpl = [
+ 0,
+ 6,
+ 1,
+ 11, # not aligned well; not used
+ 7,
+ 2,
+ 11, # not aligned well; not used
+ 8, # or 9
+ 3, # or 4
+ 12, # not aligned well; not used
+ 10, # not used
+ 5, # not used
+ 12,
+ 19, # not aligned well; not used
+ 14, # not aligned well; not used
+ 13, # not used
+ 20, # or 19
+ 15, # or 14
+ 21, # or 22
+ 16, # or 17,
+ 23,
+ 18,
+ 23, # not aligned well; not used
+ 18, # not aligned well; not used
+ ]
+
+ smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
+
+ self.register_buffer('J_skel_to_smpl', torch.tensor(skel_to_smpl, dtype=torch.long))
+ self.register_buffer('J_smpl_to_openpose', torch.tensor(smpl_to_openpose, dtype=torch.long))
+ # (SKEL has the same topology as SMPL as well as SMPL-H, so perform the same operation for the other 2 parts.)
+ # 2. Joints selected from skin vertices.
+ self.vertex_joint_selector = VertexJointSelector(vertex_ids['smplh'])
+ # 3. Extra joints from the J_regressor_extra.
+ if joint_regressor_extra is not None:
+ self.register_buffer(
+ 'J_regressor_extra',
+ torch.tensor(pickle.load(
+ open(joint_regressor_extra, 'rb'),
+ encoding='latin1'
+ ), dtype=torch.float32)
+ )
+
+ self.custom_regress_joints = joint_regressor_custom is not None
+ if self.custom_regress_joints:
+ get_logger().info('Using customized joint regressor.')
+ with open(joint_regressor_custom, 'rb') as f:
+ J_regressor_custom = pickle.load(f, encoding='latin1')
+ if 'scipy.sparse' in str(type(J_regressor_custom)):
+ J_regressor_custom = J_regressor_custom.todense() # (24, 6890)
+ self.register_buffer(
+ 'J_regressor_custom',
+ torch.tensor(
+ J_regressor_custom,
+ dtype=torch.float32
+ )
+ )
+
+ self.update_root = update_root
+
+ def forward(self, **kwargs) -> SKELOutput: # type: ignore
+ ''' Map the order of joints of SKEL to SMPL's. '''
+
+ if 'trans' not in kwargs.keys():
+ kwargs['trans'] = kwargs['poses'].new_zeros((kwargs['poses'].shape[0], 3)) # (B, 3)
+
+ skel_output = super(SKELWrapper, self).forward(**kwargs)
+ verts = skel_output.skin_verts # (B, 6890, 3)
+ joints = skel_output.joints.clone() # (B, 24, 3)
+
+ # Update the root joint position (to avoid placing the root too far forward).
+ if self.update_root:
+ # Pull root joint 0 toward the lumbar-hips plane (joints 11, 1, 6).
+ hips_middle = (joints[:, 1] + joints[:, 6]) / 2 # (B, 3)
+ lumbar2middle = (hips_middle - joints[:, 11]) # (B, 3)
+ lumbar2middle_unit = lumbar2middle / torch.norm(lumbar2middle, dim=1, keepdim=True) # (B, 3)
+ lumbar2root = joints[:, 0] - joints[:, 11]
+ lumbar2root_proj = \
+ torch.einsum('bc,bc->b', lumbar2root, lumbar2middle_unit)[:, None] *\
+ lumbar2middle_unit # (B, 3)
+ root2root_proj = lumbar2root_proj - lumbar2root # (B, 3)
+ joints[:, 0] += root2root_proj * 0.7
+
+ # Combine the joints from three parts:
+ if self.custom_regress_joints:
+ # 1.x. Regress joints from the skin vertices using SMPL's regressor.
+ joints = vertices2joints(self.J_regressor_custom, verts) # (B, 24, 3)
+ else:
+ # 1.y. Map selected joints of interest from SKEL to SMPL.
+ joints = joints[:, self.J_skel_to_smpl] # (B, 24, 3)
+ joints_custom = joints.clone()
+ # 2. Concat joints selected from skin vertices.
+ joints = self.vertex_joint_selector(verts, joints) # (B, 45, 3)
+ # 3. Map selected joints to OpenPose.
+ joints = joints[:, self.J_smpl_to_openpose] # (B, 25, 3)
+ # 4. Add extra joints from the J_regressor_extra.
+ joints_extra = vertices2joints(self.J_regressor_extra, verts) # (B, 19, 3)
+ joints = torch.cat([joints, joints_extra], dim=1) # (B, 44, 3)
+
+ # Update the joints in the output.
+ skel_output.joints_backup = skel_output.joints
+ skel_output.joints_custom = joints_custom
+ skel_output.joints = joints
+
+ return skel_output
+
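+ # A usage sketch (illustrative only; the instance and argument tensors are hypothetical):
+ # output = skel_wrapper(poses=poses, betas=betas, trans=trans) # poses (B, 46), betas (B, 10)
+ # output.joints # (B, 44, 3): 25 OpenPose-ordered joints + 19 extra joints
+ # output.joints_backup # (B, 24, 3): original SKEL joints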
+
+ @staticmethod
+ def get_static_root_offset(skel_output):
+ '''
+ Background:
+ By default, the orientation rotation is always around the original skel_root.
+ In order to make the orientation rotation around the custom_root, we need to calculate the translation offset.
+ This function calculates the translation offset in static pose. (From custom_root to skel_root.)
+ '''
+ custom_root = skel_output.joints_custom[:, 0] # (B, 3)
+ skel_root = skel_output.joints_backup[:, 0] # (B, 3)
+ offset = skel_root - custom_root # (B, 3)
+ return offset
\ No newline at end of file
diff --git a/lib/body_models/smpl_utils/reality.py b/lib/body_models/smpl_utils/reality.py
new file mode 100644
index 0000000000000000000000000000000000000000..f195f19fb48f2ff8ef0f4468cb87e68c75a769d3
--- /dev/null
+++ b/lib/body_models/smpl_utils/reality.py
@@ -0,0 +1,125 @@
+from lib.kits.basic import *
+
+from lib.utils.geometry.rotation import axis_angle_to_matrix
+
+
+def get_lim_cfg(tol_deg=5):
+ tol_limit = np.deg2rad(tol_deg)
+ lim_cfg = {
+ 'l_knee': {
+ 'jid': 4,
+ 'convention': 'XZY',
+ 'limitation': [
+ [-tol_limit, 3/4*np.pi+tol_limit],
+ [-tol_limit, tol_limit],
+ [-tol_limit, tol_limit],
+ ]
+ },
+ 'r_knee': {
+ 'jid': 5,
+ 'convention': 'XZY',
+ 'limitation': [
+ [-tol_limit, 3/4*np.pi+tol_limit],
+ [-tol_limit, tol_limit],
+ [-tol_limit, tol_limit],
+ ]
+ },
+ 'l_elbow': {
+ 'jid': 18,
+ 'convention': 'YZX',
+ 'limitation': [
+ [-(3/4)*np.pi-tol_limit, tol_limit],
+ [-tol_limit, tol_limit],
+ [-3/4*np.pi/2-tol_limit, 3/4*np.pi/2+tol_limit],
+ ]
+ },
+ 'r_elbow': {
+ 'jid': 19,
+ 'convention': 'YZX',
+ 'limitation': [
+ [-tol_limit, (3/4)*np.pi+tol_limit],
+ [-tol_limit, tol_limit],
+ [-3/4*np.pi/2-tol_limit, 3/4*np.pi/2+tol_limit],
+ ]
+ },
+ }
+ return lim_cfg
+
+
+def matrix_to_possible_euler_angles(matrix: torch.Tensor, convention: str):
+ '''
+ Convert rotations given as rotation matrices to Euler angles in radians.
+
+ ### Args
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+ convention: Convention string of three uppercase letters.
+
+ ### Returns
+ List of possible euler angles in radians as tensor of shape (..., 3).
+ '''
+ from lib.utils.geometry.rotation import _index_from_letter, _angle_from_tan
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+ i0 = _index_from_letter(convention[0])
+ i2 = _index_from_letter(convention[2])
+ tait_bryan = i0 != i2
+ central_angle_possible = []
+ if tait_bryan:
+ central_angle = torch.asin(
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+ )
+ central_angle_possible = [central_angle, np.pi - central_angle]
+ else:
+ central_angle = torch.acos(matrix[..., i0, i0])
+ central_angle_possible = [central_angle, -central_angle]
+
+ o_possible = []
+ for central_angle in central_angle_possible:
+ o = (
+ _angle_from_tan(
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
+ ),
+ central_angle,
+ _angle_from_tan(
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
+ ),
+ )
+ o_possible.append(torch.stack(o, -1))
+ return o_possible
+
+
+def eval_rot_delta(body_pose, tol_deg=5):
+ lim_cfg = get_lim_cfg(tol_deg)
+ res ={}
+ for name, cfg in lim_cfg.items():
+ jid = cfg['jid'] - 1
+ cvt = cfg['convention']
+ lim = cfg['limitation']
+ aa = body_pose[:, jid, :] # (B, 3)
+ mt = axis_angle_to_matrix(aa) # (B, 3, 3)
+ ea_possible = matrix_to_possible_euler_angles(mt, cvt) # (B, 3)
+ violation_reasonable = None
+ for ea in ea_possible:
+ violation = ea.new_zeros(ea.shape) # (B, 3)
+
+ for i in range(3):
+ ea_i = ea[:, i]
+ ea_i = (ea_i + np.pi) % (2 * np.pi) - np.pi # Normalize to (-pi, pi)
+ exceed_lb = torch.where(ea_i < lim[i][0], ea_i - lim[i][0], 0)
+ exceed_ub = torch.where(ea_i > lim[i][1], ea_i - lim[i][1], 0)
+ violation[:, i] = exceed_lb.abs() + exceed_ub.abs() # (B, 3)
+ if violation_reasonable is not None: # minimize the violation
+ upd_mask = violation.sum(-1) < violation_reasonable.sum(-1)
+ violation_reasonable[upd_mask] = violation[upd_mask]
+ else:
+ violation_reasonable = violation
+
+ res[name] = violation_reasonable
+ return res
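+
+# A usage sketch (illustrative only; `body_pose` is a hypothetical (B, 23, 3) axis-angle tensor,
+# i.e. the SMPL body pose without the global orientation):
+#   violations = eval_rot_delta(body_pose, tol_deg=5)
+#   # -> {'l_knee': (B, 3), 'r_knee': (B, 3), 'l_elbow': (B, 3), 'r_elbow': (B, 3)}; zeros mean within limits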
\ No newline at end of file
diff --git a/lib/body_models/smpl_utils/transforms.py b/lib/body_models/smpl_utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..e998be7c38a960ba4db3d31923006022f8baaced
--- /dev/null
+++ b/lib/body_models/smpl_utils/transforms.py
@@ -0,0 +1,33 @@
+import torch
+import numpy as np
+
+from typing import Dict
+
+
+def fliplr_params(smpl_params: Dict):
+ global_orient = smpl_params['global_orient'].copy().reshape(-1, 3)
+ body_pose = smpl_params['body_pose'].copy().reshape(-1, 69)
+ betas = smpl_params['betas'].copy()
+
+ body_pose_permutation = [6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
+ 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
+ 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
+ 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
+ 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
+ body_pose_permutation = body_pose_permutation[:body_pose.shape[1]]
+ body_pose_permutation = [i-3 for i in body_pose_permutation]
+
+ body_pose = body_pose[:, body_pose_permutation]
+
+ global_orient[:, 1::3] *= -1
+ global_orient[:, 2::3] *= -1
+ body_pose[:, 1::3] *= -1
+ body_pose[:, 2::3] *= -1
+
+ smpl_params = {'global_orient': global_orient.reshape(-1, 1, 3).astype(np.float32),
+ 'body_pose': body_pose.reshape(-1, 23, 3).astype(np.float32),
+ 'betas': betas.astype(np.float32)
+ }
+
+ return smpl_params
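+
+# Note (illustrative): the permutation plus the sign flips above mirror the pose left/right, and
+# applying `fliplr_params` twice recovers the original parameters (up to the float32 casting and
+# the reshape of 'body_pose' to (-1, 23, 3)).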
+
diff --git a/lib/body_models/smpl_wrapper.py b/lib/body_models/smpl_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..afdc0b0c6abc9021bc699aa953824da286330f58
--- /dev/null
+++ b/lib/body_models/smpl_wrapper.py
@@ -0,0 +1,43 @@
+import torch
+import numpy as np
+import pickle
+from typing import Optional
+import smplx
+from smplx.lbs import vertices2joints
+from smplx.utils import SMPLOutput
+
+
+class SMPLWrapper(smplx.SMPLLayer):
+ def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
+ """
+ Extension of the official SMPL implementation to support more joints.
+ Args:
+ Same as SMPLLayer.
+ joint_regressor_extra (str): Path to extra joint regressor.
+ """
+ super(SMPLWrapper, self).__init__(*args, **kwargs)
+ smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
+ 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
+
+ if joint_regressor_extra is not None:
+ self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
+ self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long))
+ self.update_hips = update_hips
+
+ def forward(self, *args, **kwargs) -> SMPLOutput:
+ """
+ Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified.
+ """
+ smpl_output = super(SMPLWrapper, self).forward(*args, **kwargs)
+ joints_smpl = smpl_output.joints.clone()
+ joints = smpl_output.joints[:, self.joint_map, :]
+ if self.update_hips:
+ joints[:,[9,12]] = joints[:,[9,12]] + \
+ 0.25*(joints[:,[9,12]]-joints[:,[12,9]]) + \
+ 0.5*(joints[:,[8]] - 0.5*(joints[:,[9,12]] + joints[:,[12,9]]))
+ if hasattr(self, 'joint_regressor_extra'):
+ extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
+ joints = torch.cat([joints, extra_joints], dim=1)
+ smpl_output.joints = joints
+ smpl_output.joints_smpl = joints_smpl
+ return smpl_output
diff --git a/lib/data/augmentation/skel.py b/lib/data/augmentation/skel.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1e0567a6e76a07104b87718816994377b0c3a29
--- /dev/null
+++ b/lib/data/augmentation/skel.py
@@ -0,0 +1,90 @@
+from lib.kits.basic import *
+
+from lib.utils.data import *
+from lib.utils.geometry.rotation import (
+ euler_angles_to_matrix,
+ axis_angle_to_matrix,
+ matrix_to_euler_angles,
+)
+from lib.body_models.skel_utils.transforms import params_q2rot
+
+
+def rot_q_orient(
+ q : Union[np.ndarray, torch.Tensor],
+ rot_rad : Union[np.ndarray, torch.Tensor],
+):
+ '''
+ ### Args
+ - q: np.ndarray or tensor, shape = (B, 3)
+ - SKEL style rotation representation.
+ - rot_rad: np.ndarray or tensor, shape = (B,)
+ - Rotation angle in radians.
+ ### Returns
+ - Rotated orientation SKEL q, shape = (B, 3), same type as the input.
+ '''
+ # Transform skel q to rot mat.
+ q, recover_type_back = to_tensor(q, device=None, temporary=True) # (B, 3)
+ q = q[:, [2, 1, 0]]
+ Rp = euler_angles_to_matrix(q, convention="YXZ")
+ # Rotate around z
+ rot = to_tensor(-rot_rad, device=q.device).float() # (B,)
+ padding_zeros = torch.zeros_like(rot) # (B,)
+ R = torch.stack([rot, padding_zeros, padding_zeros], dim=1) # (B, 3)
+ R = axis_angle_to_matrix(R)
+ R = torch.matmul(R, Rp)
+ # Transform rot mat to skel q.
+ q = matrix_to_euler_angles(R, convention="YXZ")
+ q = q[:, [2, 1, 0]]
+ q = recover_type_back(q) # (B, 3)
+
+ return q
+
+
+def rot_skel_on_plane(
+ params : Union[Dict, np.ndarray, torch.Tensor],
+ rot_deg : Union[np.ndarray, torch.Tensor, List[float]]
+):
+ '''
+ Rotate the skel parameters on the plane (around the z-axis),
+ in order to align the skel with the rotated image. To perform
+ this operation, we need to modify the orientation of the skel
+ parameters.
+
+ ### Args
+ - params: Dict or (np.ndarray or torch.Tensor)
+ - If is dict, it should contain the following keys
+ - 'poses': np.ndarray or torch.Tensor (B, 46)
+ - ...
+ - If is np.ndarray or torch.Tensor, it should be the 'poses' part.
+ - rot_deg: np.ndarray, torch.Tensor or List[float]
+ - Rotation angle in degrees.
+
+ ### Returns
+ - One of the following according to the input type:
+ - Dict: Modified skel parameters.
+ - np.ndarray or torch.Tensor: Modified skel poses parameters.
+ '''
+ rot_deg = to_numpy(rot_deg) # (B,)
+ rot_rad = np.deg2rad(rot_deg) # (B,)
+
+ if isinstance(params, Dict):
+ ret = {}
+ for k, v in params.items():
+ if isinstance(v, np.ndarray):
+ ret[k] = v.copy()
+ elif isinstance(v, torch.Tensor):
+ ret[k] = v.clone()
+ else:
+ ret[k] = v
+ ret['poses'][:, :3] = rot_q_orient(ret['poses'][:, :3], rot_rad)
+ elif isinstance(params, (np.ndarray, torch.Tensor)):
+ if isinstance(params, np.ndarray):
+ ret = params.copy()
+ elif isinstance(params, torch.Tensor):
+ ret = params.clone()
+ else:
+ raise TypeError(f'Unsupported type: {type(params)}')
+ ret[:, :3] = rot_q_orient(ret[:, :3], rot_rad)
+ else:
+ raise TypeError(f'Unsupported type: {type(params)}')
+ return ret
\ No newline at end of file
diff --git a/lib/data/datasets/hsmr_eval_3d/eval3d_dataset.py b/lib/data/datasets/hsmr_eval_3d/eval3d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fac7c53eee1b6e98fdfb8b69c0bce6f8e67b4f30
--- /dev/null
+++ b/lib/data/datasets/hsmr_eval_3d/eval3d_dataset.py
@@ -0,0 +1,88 @@
+from lib.kits.basic import *
+
+from torch.utils import data
+from lib.utils.media import load_img, flex_resize_img
+from lib.utils.bbox import cwh_to_cs, cs_to_lurb, crop_with_lurb, fit_bbox_to_aspect_ratio
+from lib.utils.data import to_numpy
+
+IMG_MEAN_255 = to_numpy([0.485, 0.456, 0.406]) * 255.
+IMG_STD_255 = to_numpy([0.229, 0.224, 0.225]) * 255.
+
+
+class Eval3DDataset(data.Dataset):
+
+ def __init__(self, npz_fn:Union[str, Path], ignore_img=False):
+ super().__init__()
+ self.data = None
+ self._load_data(npz_fn)
+ self.ds_root = self._get_ds_root()
+ self.bbox_ratio = (192, 256) # the ViT Backbone's input size is w=192, h=256
+ self.ignore_img = ignore_img # For speed up, if True, don't process images.
+
+ def _load_data(self, npz_fn:Union[str, Path]):
+ supported_datasets = ['MoYo']
+ raw_data = np.load(npz_fn, allow_pickle=True)
+
+ # Load some meta data.
+ self.extra_info = raw_data['extra_info'].item()
+ self.ds_name = self.extra_info.pop('dataset_name')
+ # Load basic information.
+ self.seq_names = raw_data['names'] # (L,)
+ self.img_paths = raw_data['img_paths'] # (L,)
+ self.bbox_centers = raw_data['centers'].astype(np.float32) # (L, 2)
+ self.bbox_scales = raw_data['scales'].astype(np.float32) # (L, 2)
+ self.L = len(self.seq_names)
+ # Load the g.t. SMPL parameters.
+ self.genders = raw_data.get('genders', None) # (L, 2) or None
+ self.global_orient = raw_data['smpl'].item()['global_orient'].reshape(-1, 1 ,3).astype(np.float32) # (L, 1, 3)
+ self.body_pose = raw_data['smpl'].item()['body_pose'].reshape(-1, 23 ,3).astype(np.float32) # (L, 23, 3)
+ self.betas = raw_data['smpl'].item()['betas'].reshape(-1, 10).astype(np.float32) # (L, 10)
+ # Check validity.
+ assert self.ds_name in supported_datasets, f'Unsupported dataset: {self.ds_name}'
+
+
+ def __len__(self):
+ return self.L
+
+
+ def _process_img_patch(self, idx):
+ ''' Load and crop according to bbox. '''
+ if self.ignore_img:
+ return np.zeros((1), dtype=np.float32)
+
+ img, _ = load_img(self.ds_root / self.img_paths[idx]) # (H, W, RGB)
+ scale = self.bbox_scales[idx] # (2,)
+ center = self.bbox_centers[idx] # (2,)
+ bbox_cwh = np.concatenate([center, scale], axis=0) # (4,), cwh format
+ bbox_cwh = fit_bbox_to_aspect_ratio(
+ bbox = bbox_cwh,
+ tgt_ratio = self.bbox_ratio,
+ bbox_type = 'cwh'
+ )
+ bbox_cs = cwh_to_cs(bbox_cwh, reduce='max') # (3,), make it to square
+ bbox_lurb = cs_to_lurb(bbox_cs) # (4,)
+ img_patch = crop_with_lurb(img, bbox_lurb) # (H', W', RGB)
+ img_patch = flex_resize_img(img_patch, tgt_wh=(256, 256))
+ img_patch_normalized = (img_patch - IMG_MEAN_255) / IMG_STD_255 # (H', W', RGB)
+ img_patch_normalized = img_patch_normalized.transpose(2, 0, 1) # (RGB, H', W')
+ return img_patch_normalized.astype(np.float32)
+
+
+ def _get_ds_root(self):
+ return PM.inputs / 'datasets' / self.ds_name.lower()
+
+
+ def __getitem__(self, idx):
+ ret = {}
+ ret['seq_name'] = self.seq_names[idx]
+ ret['smpl'] = {
+ 'global_orient': self.global_orient[idx],
+ 'body_pose' : self.body_pose[idx],
+ 'betas' : self.betas[idx],
+ }
+ if self.genders is not None:
+ ret['gender'] = self.genders[idx]
+ ret['img_patch'] = self._process_img_patch(idx)
+
+ return ret
+
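+# A minimal usage sketch (illustrative only; the npz path below is hypothetical):
+#   ds = Eval3DDataset('path/to/moyo_eval.npz', ignore_img=True)
+#   sample = ds[0] # dict with 'seq_name', 'smpl' (global_orient/body_pose/betas), optional 'gender', and 'img_patch'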
diff --git a/lib/data/datasets/hsmr_v1/crop.py b/lib/data/datasets/hsmr_v1/crop.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c3b8493f0f2aa92f34fcc9e9a93265f73a60869
--- /dev/null
+++ b/lib/data/datasets/hsmr_v1/crop.py
@@ -0,0 +1,295 @@
+from lib.kits.basic import *
+
+# Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L663-L944
+
+def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple:
+ """
+ Extreme cropping: Crop the box up to the hip locations.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
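+# Behavioural note (illustrative): these extreme-cropping helpers only recompute the box when
+# enough keypoints remain confident; e.g. with all-zero confidences `crop_to_hips` returns its
+# inputs unchanged (the other helpers below follow the same pattern):
+#   kps = np.zeros((44, 3))
+#   crop_to_hips(cx, cy, w, h, kps) # -> (cx, cy, w, h) for hypothetical floats cx, cy, w, h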
+
+def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box up to the shoulder locations.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ center, scale = get_bbox(keypoints_2d)
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.2 * scale[0]
+ height = 1.2 * scale[1]
+ return center_x, center_y, width, height
+
+
+def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the head.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.3 * scale[0]
+ height = 1.3 * scale[1]
+ return center_x, center_y, width, height
+
+
+def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the torso.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]]
+ keypoints_2d[nontorso_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+
+def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the right arm.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonrightarm_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+
+def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the left arm.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonleftarm_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+
+def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the legs.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
+ keypoints_2d[nonlegs_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+
+def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the right leg.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonrightleg_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the left leg.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonleftleg_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+
+def full_body(keypoints_2d: np.ndarray) -> bool:
+ """
+ Check if all main body joints are visible.
+ Args:
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ bool: True if all main body joints are visible.
+ """
+
+ body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
+ body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
+ return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints)
+
+
+def upper_body(keypoints_2d: np.ndarray):
+ """
+ Check if all upper body joints are visible.
+ Args:
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ bool: True if the lower body is not visible and at least two upper body joints are.
+ """
+ lower_body_keypoints_openpose = [10, 11, 13, 14]
+ lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
+ upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
+ upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18]
+ return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\
+ and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
+
+
+def get_bbox(keypoints_2d: np.ndarray, rescale: float = 1.2) -> Tuple:
+ """
+ Get center and scale for bounding box from openpose detections.
+ Args:
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
+ Returns:
+ center (np.ndarray): Array of shape (2,) containing the new bounding box center.
+ scale (np.ndarray): Array of shape (2,) containing the rescaled bounding box size (width, height).
+ """
+ valid = keypoints_2d[:,-1] > 0
+ valid_keypoints = keypoints_2d[valid][:,:-1]
+ center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
+ bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
+ # adjust bounding box tightness
+ scale = bbox_size
+ scale *= rescale
+ return center, scale
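+
+
+if __name__ == '__main__':
+ # Minimal sanity check (an illustrative sketch, not used by the training pipeline):
+ # build a fake 44-joint keypoint array (25 OpenPose body joints + 19 extra joints,
+ # confidence in the last column) with all joints visible, then run get_bbox and the
+ # torso-only crop end to end. The coordinate values are arbitrary.
+ rng = np.random.default_rng(0)
+ fake_kp2d = np.concatenate([rng.uniform(100., 300., size=(44, 2)), np.ones((44, 1))], axis=1)
+ center, scale = get_bbox(fake_kp2d)
+ print('bbox center:', center, 'bbox scale:', scale)
+ cx, cy, w, h = crop_torso_only(center[0], center[1], scale[0], scale[1], fake_kp2d)
+ print('torso-only crop:', cx, cy, w, h)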
diff --git a/lib/data/datasets/hsmr_v1/mocap_dataset.py b/lib/data/datasets/hsmr_v1/mocap_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca37c47ac7918d943aa5840d93499d8ac6a23058
--- /dev/null
+++ b/lib/data/datasets/hsmr_v1/mocap_dataset.py
@@ -0,0 +1,32 @@
+from lib.kits.basic import *
+
+
+class MoCapDataset:
+ def __init__(self, dataset_file:str, pve_threshold:Optional[float]=None):
+ '''
+ Dataset class used for loading a dataset of unpaired SMPL parameter annotations
+ ### Args
+ - dataset_file: str
+ - Path to the dataset npz file.
+ - pve_threshold: float
+ - Threshold for PVE quality filtering.
+ '''
+ data = np.load(dataset_file)
+ if pve_threshold is not None:
+ pve = data['pve_max']
+ mask = pve < pve_threshold
+ else:
+ mask = np.ones(len(data['poses']), dtype=bool)  # plain bool, since np.bool is removed in recent NumPy
+ self.poses = data['poses'].astype(np.float32)[mask, 3:]
+ self.betas = data['betas'].astype(np.float32)[mask]
+ self.length = len(self.poses)
+ get_logger().info(f'Loaded {self.length} of {len(mask)} samples from {dataset_file} (using pve_threshold = {pve_threshold})')
+
+ def __getitem__(self, idx: int) -> Dict:
+ poses = self.poses[idx].copy()
+ betas = self.betas[idx].copy()
+ item = {'poses_body': poses, 'betas': betas}
+ return item
+
+ def __len__(self) -> int:
+ return self.length
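+
+
+# Usage sketch (illustrative; the npz path below is a placeholder, and the file is expected
+# to provide 'poses', 'betas' and 'pve_max' arrays):
+#   mocap = MoCapDataset('path/to/mocap_annotations.npz', pve_threshold=0.05)
+#   sample = mocap[0]  # {'poses_body': ..., 'betas': ...}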
diff --git a/lib/data/datasets/hsmr_v1/stream_pipelines.py b/lib/data/datasets/hsmr_v1/stream_pipelines.py
new file mode 100644
index 0000000000000000000000000000000000000000..6816b23192463cb3bdecf51d69a367018c5acdc2
--- /dev/null
+++ b/lib/data/datasets/hsmr_v1/stream_pipelines.py
@@ -0,0 +1,342 @@
+from lib.kits.basic import *
+
+import math
+import webdataset as wds
+
+
+from .utils import (
+ get_augm_args,
+ expand_to_aspect_ratio,
+ generate_image_patch_cv2,
+ flip_lr_keypoints,
+ extreme_cropping_aggressive,
+)
+
+
+def apply_corrupt_filter(dataset:wds.WebDataset):
+ AIC_TRAIN_CORRUPT_KEYS = {
+ '0a047f0124ae48f8eee15a9506ce1449ee1ba669', '1a703aa174450c02fbc9cfbf578a5435ef403689',
+ '0394e6dc4df78042929b891dbc24f0fd7ffb6b6d', '5c032b9626e410441544c7669123ecc4ae077058',
+ 'ca018a7b4c5f53494006ebeeff9b4c0917a55f07', '4a77adb695bef75a5d34c04d589baf646fe2ba35',
+ 'a0689017b1065c664daef4ae2d14ea03d543217e', '39596a45cbd21bed4a5f9c2342505532f8ec5cbb',
+ '3d33283b40610d87db660b62982f797d50a7366b',
+ }
+
+ CORRUPT_KEYS = {
+ *{f'aic-train/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
+ *{f'aic-train-vitpose/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
+ }
+
+ dataset = dataset.select(lambda sample: (sample['__key__'] not in CORRUPT_KEYS))
+ return dataset
+
+
+def apply_multi_ppl_splitter(dataset:wds.WebDataset):
+ '''
+ Each item in the raw dataset contains multiple people, so we split it into one sample per person.
+ Meanwhile, we note down the person id (pid) for each sample.
+ '''
+ def multi_ppl_splitter(source):
+ for item in source:
+ data_multi_ppl = item['data.pyd'] # list of data for multiple people
+ for pid, data in enumerate(data_multi_ppl):
+ data['pid'] = pid
+
+ if 'detection.npz' in item:
+ det_idx = data['extra_info']['detection_npz_idx']
+ mask = item['detection.npz']['masks'][det_idx]
+ else:
+ mask = np.ones_like(item['jpg'][:, :, 0], dtype=bool)
+ yield {
+ '__key__' : item['__key__'] + f'_{pid}',
+ 'img_name' : item['__key__'],
+ 'img' : item['jpg'],
+ 'data' : data,
+ 'mask' : mask,
+ }
+
+ return dataset.compose(multi_ppl_splitter)
+
+
+def apply_keys_adapter(dataset:wds.WebDataset):
+ ''' Adapt the keys of the items so that different versions of the dataset can be handled uniformly. '''
+ def keys_adapter(item):
+ data = item['data']
+ data['kp2d'] = data.pop('keypoints_2d')
+ data['kp3d'] = data.pop('keypoints_3d')
+ return item
+ return dataset.map(keys_adapter)
+
+
+def apply_bad_pgt_params_nan_suppressor(dataset:wds.WebDataset):
+ ''' If the poses or betas contain NaN, we regard it as bad pseudo-GT and zero them out. '''
+ def bad_pgt_params_suppressor(item):
+ for side in ['orig', 'flip']:
+ poses = item['data'][f'{side}_poses'] # (J, 3)
+ betas = item['data'][f'{side}_betas'] # (10,)
+ poses_has_nan = np.isnan(poses).any()
+ betas_has_nan = np.isnan(betas).any()
+ if poses_has_nan or betas_has_nan:
+ item['data'][f'{side}_has_poses'] = False
+ item['data'][f'{side}_has_betas'] = False
+ if poses_has_nan:
+ item['data'][f'{side}_poses'][:] = 0
+ if betas_has_nan:
+ item['data'][f'{side}_betas'][:] = 0
+ return item
+ dataset = dataset.map(bad_pgt_params_suppressor)
+ return dataset
+
+
+def apply_bad_pgt_params_kp2d_err_suppressor(dataset:wds.WebDataset, thresh:float=0.1):
+ ''' If the 2D keypoint error of a single person is higher than the threshold, we regard it as bad pseudo-GT. '''
+ if thresh > 0:
+ def bad_pgt_params_suppressor(item):
+ for side in ['orig', 'flip']:
+ if thresh > 0:
+ kp2d_err = item['data'][f'{side}_kp2d_err']
+ is_valid_pgt = kp2d_err < thresh
+ item['data'][f'{side}_has_poses'] = is_valid_pgt
+ item['data'][f'{side}_has_betas'] = is_valid_pgt
+ return item
+ dataset = dataset.map(bad_pgt_params_suppressor)
+ return dataset
+
+
+def apply_bad_pgt_params_pve_max_suppressor(dataset:wds.WebDataset, thresh:float=0.1):
+ ''' If the PVE-Max of a single person is higher than the threshold, we regard it as bad pseudo-GT. '''
+ if thresh > 0:
+ def bad_pgt_params_suppressor(item):
+ for side in ['orig', 'flip']:
+ if thresh > 0:
+ pve_max = item['data'][f'{side}_pve_max']
+ is_valid_pose = not math.isnan(pve_max)
+ is_valid_pgt = pve_max < thresh and is_valid_pose
+ item['data'][f'{side}_has_poses'] = is_valid_pgt
+ item['data'][f'{side}_has_betas'] = is_valid_pgt
+ return item
+ dataset = dataset.map(bad_pgt_params_suppressor)
+ return dataset
+
+
+def apply_bad_kp_suppressor(dataset:wds.WebDataset, thresh:float=0.0):
+ ''' If the confidence of a keypoint is lower than the threshold, we reset it to 0. '''
+ eps = 1e-6
+ if thresh > eps:
+ def bad_kp_suppressor(item):
+ if thresh > 0:
+ kp2d = item['data']['kp2d']
+ kp2d_conf = np.where(kp2d[:, 2] < thresh, 0.0, kp2d[:, 2]) # suppress bad keypoints
+ item['data']['kp2d'] = np.concatenate([kp2d[:, :2], kp2d_conf[:, None]], axis=1)
+ return item
+ dataset = dataset.map(bad_kp_suppressor)
+ return dataset
+
+
+def apply_bad_betas_suppressor(dataset:wds.WebDataset, thresh:float=3):
+ ''' If the absolute value of any beta exceeds the threshold, we regard the betas as invalid. '''
+ eps = 1e-6
+ if thresh > eps:
+ def bad_betas_suppressor(item):
+ for side in ['orig', 'flip']:
+ has_betas = item['data'][f'{side}_has_betas'] # use this condition to save time
+ if thresh > 0 and has_betas:
+ betas_abs = np.abs(item['data'][f'{side}_betas'])
+ if (betas_abs > thresh).any():
+ item['data'][f'{side}_has_betas'] = False
+ return item
+ dataset = dataset.map(bad_betas_suppressor)
+ return dataset
+
+
+def apply_params_synchronizer(dataset:wds.WebDataset, poses_betas_simultaneous:bool=False):
+ ''' Poses and betas are regarded as valid only when both of them are valid. '''
+ if poses_betas_simultaneous:
+ def params_synchronizer(item):
+ for side in ['orig', 'flip']:
+ has_betas = item['data'][f'{side}_has_betas']
+ has_poses = item['data'][f'{side}_has_poses']
+ has_both = np.array(float((has_poses > 0) and (has_betas > 0)))
+ item['data'][f'{side}_has_betas'] = has_both
+ item['data'][f'{side}_has_poses'] = has_both
+ return item
+ dataset = dataset.map(params_synchronizer)
+ return dataset
+
+
+def apply_insuff_kp_filter(dataset:wds.WebDataset, cnt_thresh:int=4, conf_thresh:float=0.0):
+ '''
+ Count the keypoints whose confidence is higher than the confidence threshold.
+ If the count does not exceed the count threshold, the sample is regarded as having insufficient valid 2D keypoints and is filtered out.
+ '''
+ if cnt_thresh > 0:
+ def insuff_kp_filter(item):
+ kp_conf = item['data']['kp2d'][:, 2]
+ return (kp_conf > conf_thresh).sum() > cnt_thresh
+ dataset = dataset.select(insuff_kp_filter)
+ return dataset
+
+
+def apply_bbox_size_filter(dataset:wds.WebDataset, bbox_size_thresh:Optional[float]=None):
+ if bbox_size_thresh:
+ def bbox_size_filter(item):
+ bbox_size = item['data']['scale'] * 200
+ return bbox_size.min() > bbox_size_thresh # ensure the lower bound is large enough
+ dataset = dataset.select(bbox_size_filter)
+ return dataset
+
+
+def apply_reproj_err_filter(dataset:wds.WebDataset, thresh:float=0.0):
+ ''' If the re-projection error is higher than the threshold, we regard it as a bad sample. '''
+ if thresh > 0:
+ def reproj_err_filter(item):
+ losses = item['data'].get('extra_info', {}).get('fitting_loss', np.array({})).item()
+ reproj_loss = losses.get('reprojection_loss', None)
+ return reproj_loss is None or reproj_loss < thresh
+ dataset = dataset.select(reproj_err_filter)
+ return dataset
+
+
+def apply_invalid_betas_regularizer(dataset:wds.WebDataset, reg_betas:bool=False):
+ ''' For items with invalid betas, reset the betas to zero and mark them as valid. '''
+ if reg_betas:
+ def betas_regularizer(item):
+ # Items without valid betas get zero betas, which are then marked as valid.
+ for side in ['orig', 'flip']:
+ has_betas = item['data'][f'{side}_has_betas']
+ betas = item['data'][f'{side}_betas']
+
+ if not (has_betas > 0):
+ item['data'][f'{side}_has_betas'] = np.array(float((True)))
+ item['data'][f'{side}_betas'] = betas * 0
+ return item
+ dataset = dataset.map(betas_regularizer)
+ return dataset
+
+
+def apply_example_formatter(dataset:wds.WebDataset, cfg:DictConfig):
+ ''' Format the item to the wanted format. '''
+
+ def get_fmt_data(raw_item:Dict, augm_args:Dict, cfg:DictConfig):
+ '''
+ We both augment the image and crop it to a patch according to the bbox. Both processes
+ change the positions of the related keypoints, so afterwards the 2D & 3D keypoints are
+ aligned to the augmented patch.
+ '''
+ # 1. Prepare the raw data that will be used in the following steps.
+ img_rgb = raw_item['img'] # (H, W, 3)
+ img_a = raw_item['mask'].astype(np.uint8)[:, :, None] * 255 # (H, W, 1) mask is 0/1 valued
+ img_rgba = np.concatenate([img_rgb, img_a], axis=2) # (H, W, 4)
+ H, W, C = img_rgb.shape
+ cx, cy = raw_item['data']['center']
+ # bbox_size = (raw_item['data']['scale'] * 200).max()
+ bbox_size = expand_to_aspect_ratio(
+ raw_item['data']['scale'] * 200,
+ target_aspect_ratio = cfg.policy.bbox_shape,
+ ).max()
+
+ kp2d_with_conf = raw_item['data']['kp2d'].astype('float32') # (J, 3)
+ kp3d_with_conf = raw_item['data']['kp3d'].astype('float32') # (J, 4)
+
+ # 2. [img][Augmentation] Extreme cropping according to the 2D keypoints.
+ if augm_args['do_extreme_crop']:
+ cx_, cy_, bbox_size_ = extreme_cropping_aggressive(cx, cy, bbox_size, bbox_size, kp2d_with_conf)
+ # Only apply the crop if the resulting bbox is large enough.
+ THRESH = 4
+ if bbox_size_ > THRESH:
+ cx, cy = cx_, cy_
+ bbox_size = bbox_size_
+
+ # 3. [img][Augmentation] Shift the center of the image.
+ cx += augm_args['tx_ratio'] * bbox_size
+ cy += augm_args['ty_ratio'] * bbox_size
+
+ # 4. [img][Format] Crop the image to the patch.
+ img_patch_cv2, transform_2d = generate_image_patch_cv2(
+ img = img_rgba,
+ c_x = cx,
+ c_y = cy,
+ bb_width = bbox_size,
+ bb_height = bbox_size,
+ patch_width = cfg.policy.img_patch_size,
+ patch_height = cfg.policy.img_patch_size,
+ do_flip = augm_args['do_flip'],
+ scale = augm_args['bbox_scale'],
+ rot = augm_args['rot_deg'],
+ ) # (H, W, 4), (2, 3)
+
+ img_patch_hwc = img_patch_cv2.copy()[:, :, :3] # (H, W, C)
+ img_patch_chw = img_patch_hwc.transpose(2, 0, 1).astype(np.float32)
+
+ # 5. [img][Augmentation] Scale the color
+ for cid in range(min(C, 3)):
+ img_patch_chw[cid] = np.clip(
+ a = img_patch_chw[cid] * augm_args['color_scale'][cid],
+ a_min = 0,
+ a_max = 255,
+ )
+
+ # 6. [img][Format] Normalize the color.
+ img_mean = [255. * x for x in cfg.policy.img_mean]
+ img_std = [255. * x for x in cfg.policy.img_std]
+ for cid in range(min(C, 3)):
+ img_patch_chw[cid] = (img_patch_chw[cid] - img_mean[cid]) / img_std[cid]
+
+ # 7. [kp2d][Alignment] Align the 2D keypoints.
+ # 7.1. Flip.
+ if augm_args['do_flip']:
+ kp2d_with_conf = flip_lr_keypoints(kp2d_with_conf, W)
+ # 7.2. Others. Transform the 2D keypoints according to the same transformation of image.
+ J = len(kp2d_with_conf)
+ kp2d_homo = np.concatenate([kp2d_with_conf[:, :2], np.ones((J, 1))], axis=1) # (J, 3)
+ kp2d = np.einsum('ph, jh -> jp', transform_2d, kp2d_homo) # (J, 2)
+ kp2d_with_conf[:, :2] = kp2d # (J, 3)
+
+ # 8. [kp2d][Format] Normalize the 2D keypoints position to [-0.5, 0.5]-visible space.
+ kp2d_with_conf[:, :2] = kp2d_with_conf[:, :2] / cfg.policy.img_patch_size - 0.5
+
+ # 9. [kp3d][Alignment] Align the 3D keypoints.
+ # 9.1. Flip.
+ if augm_args['do_flip']:
+ kp3d_with_conf = flip_lr_keypoints(kp3d_with_conf, W)
+ # 9.2. In-plane rotation.
+ rot_mat = np.eye(3)
+ # TODO: maybe this part can be packed to a single function.
+ if not augm_args['rot_deg'] == 0:
+ rot_rad = -augm_args['rot_deg'] * np.pi / 180
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0, :2] = [cs, -sn]
+ rot_mat[1, :2] = [sn, cs]
+ kp3d_with_conf[:, :3] = np.einsum('ij, kj -> ki', rot_mat, kp3d_with_conf[:, :3])
+
+ return img_patch_chw, kp2d_with_conf, kp3d_with_conf
+
+ def example_formatter(raw_item):
+ raw_data = raw_item['data']
+ augm_args = get_augm_args(cfg.image_augmentation)
+
+ params_side = 'flip' if augm_args['do_flip'] else 'orig'
+ img_patch_chw, kp2d, kp3d = get_fmt_data(raw_item, augm_args, cfg)
+
+ fmt_item = {}
+ fmt_item['pid'] = raw_item['data']['pid']
+ fmt_item['img_name'] = raw_item['img_name']
+ fmt_item['img_patch'] = img_patch_chw
+ fmt_item['kp2d'] = kp2d
+ fmt_item['kp3d'] = kp3d
+ fmt_item['augm_args'] = augm_args
+ fmt_item['raw_skel_params'] = {
+ 'poses': raw_data[f'{params_side}_poses'],
+ 'betas': raw_data[f'{params_side}_betas'],
+ }
+ fmt_item['has_skel_params'] = {
+ 'poses': raw_data[f'{params_side}_has_poses'],
+ 'betas': raw_data[f'{params_side}_has_betas'],
+ }
+ fmt_item['updated_by_spin'] = False # Only data updated by spin process will be marked as True.
+
+ return fmt_item
+
+ dataset = dataset.map(example_formatter)
+ return dataset
\ No newline at end of file
diff --git a/lib/data/datasets/hsmr_v1/utils.py b/lib/data/datasets/hsmr_v1/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..107286b414e67825ffea5c834ef83e3c04b68ac4
--- /dev/null
+++ b/lib/data/datasets/hsmr_v1/utils.py
@@ -0,0 +1,327 @@
+from lib.kits.basic import *
+
+import os
+import cv2
+import braceexpand
+from typing import List, Union
+
+from .crop import *
+
+
+def expand_urls(urls: Union[str, List[str]]):
+
+ def expand_url(s):
+ return os.path.expanduser(os.path.expandvars(s))
+
+ if isinstance(urls, str):
+ urls = [urls]
+ urls = [u for url in urls for u in braceexpand.braceexpand(expand_url(url))]
+ return urls
+
+
+def get_augm_args(img_augm_cfg:Optional[DictConfig]):
+ '''
+ Generate the random augmentation arguments according to the configuration; these arguments
+ drive the later augmentation of the image and its patch.
+
+ Briefly speaking, the arguments cover: bbox scaling, color scaling, rotation, flipping, extreme cropping and translation.
+ '''
+ sample_args = {
+ 'bbox_scale' : 1.0,
+ 'color_scale' : [1.0, 1.0, 1.0],
+ 'rot_deg' : 0.0,
+ 'do_flip' : False,
+ 'do_extreme_crop' : False,
+ 'tx_ratio' : 0.0,
+ 'ty_ratio' : 0.0,
+ }
+
+ if img_augm_cfg is not None:
+ sample_args['tx_ratio'] += np.clip(np.random.randn(), -1.0, 1.0) * img_augm_cfg.trans_factor
+ sample_args['ty_ratio'] += np.clip(np.random.randn(), -1.0, 1.0) * img_augm_cfg.trans_factor
+ sample_args['bbox_scale'] += np.clip(np.random.randn(), -1.0, 1.0) * img_augm_cfg.bbox_scale_factor
+
+ if np.random.random() <= img_augm_cfg.rot_aug_rate:
+ sample_args['rot_deg'] += np.clip(np.random.randn(), -2.0, 2.0) * img_augm_cfg.rot_factor
+ if np.random.random() <= img_augm_cfg.flip_aug_rate:
+ sample_args['do_flip'] = True
+ if np.random.random() <= img_augm_cfg.extreme_crop_aug_rate:
+ sample_args['do_extreme_crop'] = True
+
+ c_up = 1.0 + img_augm_cfg.half_color_scale
+ c_low = 1.0 - img_augm_cfg.half_color_scale
+ sample_args['color_scale'] = [
+ np.random.uniform(c_low, c_up),
+ np.random.uniform(c_low, c_up),
+ np.random.uniform(c_low, c_up),
+ ]
+ return sample_args
+
+
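+# Illustrative note (not part of the pipeline): the augmentation config is a DictConfig
+# exposing exactly the fields read by get_augm_args(); the values below are made-up examples:
+#   OmegaConf.create({'trans_factor': 0.02, 'bbox_scale_factor': 0.3,
+#                     'rot_aug_rate': 0.6, 'rot_factor': 30.0, 'flip_aug_rate': 0.5,
+#                     'extreme_crop_aug_rate': 0.1, 'half_color_scale': 0.2})
+# would be a valid `img_augm_cfg` argument.
+
+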
+def rotate_2d(pt_2d: np.ndarray, rot_rad: float) -> np.ndarray:
+ '''
+ Rotate a 2D point on the x-y plane.
+ Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L90-L104
+
+ ### Args
+ - pt_2d: np.ndarray
+ - Input 2D point with shape (2,).
+ - rot_rad: float
+ - Rotation angle.
+
+ ### Returns
+ - np.ndarray
+ - Rotated 2D point.
+ '''
+ x = pt_2d[0]
+ y = pt_2d[1]
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ xx = x * cs - y * sn
+ yy = x * sn + y * cs
+ return np.array([xx, yy], dtype=np.float32)
+
+
+def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple:
+ """
+ Perform aggressive extreme cropping.
+ Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L978-L1025
+
+ ### Args
+ - center_x: float
+ - x coordinate of bounding box center.
+ - center_y: float
+ - y coordinate of bounding box center.
+ - width: float
+ - Bounding box width.
+ - height: float
+ - Bounding box height.
+ - keypoints_2d: np.ndarray
+ - Array of shape (N, 3) containing 2D keypoint locations.
+
+ ### Returns
+ - center_x: float
+ - x coordinate of bounding box center.
+ - center_y: float
+ - y coordinate of bounding box center.
+ - bbox_size: float
+ - Bounding box size.
+ """
+ p = torch.rand(1).item()
+ if full_body(keypoints_2d):
+ if p < 0.2:
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.3:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.4:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.5:
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.6:
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.7:
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.8:
+ center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.9:
+ center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
+ elif upper_body(keypoints_2d):
+ if p < 0.2:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.4:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.6:
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.8:
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
+ return center_x, center_y, max(width, height)
+
+
+def gen_trans_from_patch_cv(
+ c_x : float,
+ c_y : float,
+ src_width : float,
+ src_height : float,
+ dst_width : float,
+ dst_height : float,
+ scale : float,
+ rot : float
+) -> np.ndarray:
+ '''
+ Create transformation matrix for the bounding box crop.
+ Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L107-L154
+
+ ### Args
+ - c_x: float
+ - Bounding box center x coordinate in the original image.
+ - c_y: float
+ - Bounding box center y coordinate in the original image.
+ - src_width: float
+ - Bounding box width.
+ - src_height: float
+ - Bounding box height.
+ - dst_width: float
+ - Output box width.
+ - dst_height: float
+ - Output box height.
+ - scale: float
+ - Rescaling factor for the bounding box (augmentation).
+ - rot: float
+ - Random rotation applied to the box.
+
+ ### Returns
+ - trans: np.ndarray
+ - Target geometric transformation.
+ '''
+ # augment size with scale
+ src_w = src_width * scale
+ src_h = src_height * scale
+ src_center = np.zeros(2)
+ src_center[0] = c_x
+ src_center[1] = c_y
+ # augment rotation
+ rot_rad = np.pi * rot / 180
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
+
+ dst_w = dst_width
+ dst_h = dst_height
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = src_center
+ src[1, :] = src_center + src_downdir
+ src[2, :] = src_center + src_rightdir
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = dst_center
+ dst[1, :] = dst_center + dst_downdir
+ dst[2, :] = dst_center + dst_rightdir
+
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) # (2, 3) # type: ignore
+
+ return trans
+
+
+def generate_image_patch_cv2(
+ img : np.ndarray,
+ c_x : float,
+ c_y : float,
+ bb_width : float,
+ bb_height : float,
+ patch_width : float,
+ patch_height : float,
+ do_flip : bool,
+ scale : float,
+ rot : float,
+ border_mode = cv2.BORDER_CONSTANT,
+ border_value = 0,
+) -> Tuple[np.ndarray, np.ndarray]:
+ '''
+ Crop the input image and return the crop and the corresponding transformation matrix.
+ Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L343-L386
+
+ ### Args
+ - img: np.ndarray, shape = (H, W, 3)
+ - c_x: float
+ - Bounding box center x coordinate in the original image.
+ - c_y: float
+ - Bounding box center y coordinate in the original image.
+ - bb_width: float
+ - Bounding box width.
+ - bb_height: float
+ - Bounding box height.
+ - patch_width: float
+ - Output box width.
+ - patch_height: float
+ - Output box height.
+ - do_flip: bool
+ - Whether to flip image or not.
+ - scale: float
+ - Rescaling factor for the bounding box (augmentation).
+ - rot: float
+ - Random rotation applied to the box.
+ ### Returns
+ - img_patch: np.ndarray
+ - Cropped image patch of shape (patch_height, patch_height, 3)
+ - trans: np.ndarray
+ - Transformation matrix.
+ '''
+ img_height, img_width, img_channels = img.shape
+ if do_flip:
+ img = img[:, ::-1, :]
+ c_x = img_width - c_x - 1
+
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot) # (2, 3)
+
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=border_mode,
+ borderValue=border_value,
+ ) # type: ignore
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ )
+
+ return img_patch, trans
+
+
+def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
+ '''
+ Increase the size of the bounding box to match the target shape.
+ Copied from https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L14-L33
+ '''
+ if target_aspect_ratio is None:
+ return input_shape
+
+ try:
+ w , h = input_shape
+ except (ValueError, TypeError):
+ return input_shape
+
+ w_t, h_t = target_aspect_ratio
+ if h / w < h_t / w_t:
+ h_new = max(w * h_t / w_t, h)
+ w_new = w
+ else:
+ h_new = h
+ w_new = max(h * w_t / h_t, w)
+ if h_new < h or w_new < w:
+ breakpoint()
+ return np.array([w_new, h_new])
+
+
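+# Worked example (illustrative) for expand_to_aspect_ratio: with input_shape = (100, 300)
+# and target_aspect_ratio = (192, 256), h/w = 3.0 > 256/192 ~= 1.33, so the width is
+# expanded: w_new = max(300 * 192 / 256, 100) = 225, giving np.array([225., 300.]).
+
+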
+body_permutation = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]
+extra_permutation = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18]
+FLIP_KP_PERMUTATION = body_permutation + [25 + i for i in extra_permutation]
+
+def flip_lr_keypoints(joints: np.ndarray, width: float) -> np.ndarray:
+ """
+ Flip 2D or 3D keypoints.
+ Modified from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L448-L462
+
+ ### Args
+ - joints: np.ndarray
+ - Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
+ - width: float
+ - Image width, used to mirror the x coordinates (the module-level FLIP_KP_PERMUTATION is applied after flipping).
+ ### Returns
+ - np.ndarray
+ - Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
+ """
+ joints = joints.copy()
+ # Flip horizontal
+ joints[:, 0] = width - joints[:, 0] - 1
+ joints = joints[FLIP_KP_PERMUTATION]
+
+ return joints
\ No newline at end of file
diff --git a/lib/data/datasets/hsmr_v1/wds_loader.py b/lib/data/datasets/hsmr_v1/wds_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b318bb61326c244d7da20920c7162aaa57db6d17
--- /dev/null
+++ b/lib/data/datasets/hsmr_v1/wds_loader.py
@@ -0,0 +1,56 @@
+from lib.kits.basic import *
+
+import webdataset as wds
+
+from .utils import *
+from .stream_pipelines import *
+
+# This line is to fix the problem of "OSError: image file is truncated" when loading images.
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+def load_tars_as_wds(
+ cfg : DictConfig,
+ urls : Union[str, List[str]],
+ resampled : bool = False,
+ epoch_size : int = None,
+ cache_dir : str = None,
+ train : bool = True,
+):
+ urls = expand_urls(urls) # to list of URL strings
+
+ dataset : wds.WebDataset = wds.WebDataset(
+ urls,
+ nodesplitter = wds.split_by_node,
+ shardshuffle = True,
+ resampled = resampled,
+ cache_dir = cache_dir,
+ )
+ if train:
+ dataset = dataset.shuffle(100)
+
+ # A series of processing stages initializes the dataset; check the corresponding pipeline functions for more details.
+ # The order of the pipeline matters, since some stages depend on the results of previous ones.
+ dataset = apply_corrupt_filter(dataset)
+ dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png')
+ dataset = apply_multi_ppl_splitter(dataset)
+ dataset = apply_keys_adapter(dataset) #* This adapter is only in HSMR's design, not in the baseline.
+ dataset = apply_bad_pgt_params_nan_suppressor(dataset)
+ dataset = apply_bad_pgt_params_kp2d_err_suppressor(dataset, cfg.get('suppress_pgt_params_kp2d_err_thresh', 0.0))
+ dataset = apply_bad_pgt_params_pve_max_suppressor(dataset, cfg.get('suppress_pgt_params_pve_max_thresh', 0.0))
+ dataset = apply_bad_kp_suppressor(dataset, cfg.get('suppress_kp_conf_thresh', 0.0))
+ dataset = apply_bad_betas_suppressor(dataset, cfg.get('suppress_betas_thresh', 0.0))
+ # dataset = apply_bad_pose_suppressor(dataset, cfg.get('suppress_pose_thresh', 0.0)) # Not used in baseline, so not implemented.
+ dataset = apply_params_synchronizer(dataset, cfg.get('poses_betas_simultaneous', False))
+ # dataset = apply_no_pose_filter(dataset, cfg.get('no_pose_filter', False)) # Not used in baseline, so not implemented.
+ dataset = apply_insuff_kp_filter(dataset, cfg.get('filter_insufficient_kp_cnt', 4), cfg.get('suppress_insufficient_kp_thresh', 0.0))
+ dataset = apply_bbox_size_filter(dataset, cfg.get('filter_bbox_size_thresh', None))
+ dataset = apply_reproj_err_filter(dataset, cfg.get('filter_reproj_err_thresh', 0.0))
+ dataset = apply_invalid_betas_regularizer(dataset, cfg.get('regularize_invalid_betas', False))
+
+ # Final preprocessing / formatting of the data. (Consider extracting the augmentation process.)
+ dataset = apply_example_formatter(dataset, cfg)
+
+ if epoch_size is not None:
+ dataset = dataset.with_epoch(epoch_size)
+ return dataset
\ No newline at end of file
diff --git a/lib/data/datasets/skel_hmr2_fashion/__init__.py b/lib/data/datasets/skel_hmr2_fashion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eefc709bdd1c4d39d7acbc0f8e79739d7b6ba16
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/__init__.py
@@ -0,0 +1,98 @@
+from typing import Dict, Optional
+
+import torch
+import numpy as np
+import pytorch_lightning as pl
+from yacs.config import CfgNode
+
+import webdataset as wds
+from .dataset import Dataset
+from .image_dataset import ImageDataset
+from .mocap_dataset import MoCapDataset
+
+def to_lower(x: Dict) -> Dict:
+ """
+ Convert all dictionary keys to lowercase
+ Args:
+ x (dict): Input dictionary
+ Returns:
+ dict: Output dictionary with all keys converted to lowercase
+ """
+ return {k.lower(): v for k, v in x.items()}
+
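+# e.g. to_lower({'TYPE': 'ImageDataset'}) == {'type': 'ImageDataset'}; this lets the
+# upper-case yacs config keys be passed as keyword arguments to the dataset classes.
+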
+def create_dataset(cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True, **kwargs) -> Dataset:
+ """
+ Instantiate a dataset from a config file.
+ Args:
+ cfg (CfgNode): Model configuration file.
+ dataset_cfg (CfgNode): Dataset configuration info.
+ train (bool): Variable to select between train and val datasets.
+ """
+
+ dataset_type = Dataset.registry[dataset_cfg.TYPE]
+ return dataset_type(cfg, **to_lower(dataset_cfg), train=train, **kwargs)
+
+def create_webdataset(cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True) -> Dataset:
+ """
+ Like `create_dataset` but load data from tars.
+ """
+ dataset_type = Dataset.registry[dataset_cfg.TYPE]
+ return dataset_type.load_tars_as_webdataset(cfg, **to_lower(dataset_cfg), train=train)
+
+
+class MixedWebDataset(wds.WebDataset):
+ def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True) -> None:
+ super(wds.WebDataset, self).__init__()
+ dataset_list = cfg.DATASETS.TRAIN if train else cfg.DATASETS.VAL
+ datasets = [create_webdataset(cfg, dataset_cfg[dataset], train=train) for dataset, v in dataset_list.items()]
+ weights = np.array([v.WEIGHT for dataset, v in dataset_list.items()])
+ weights = weights / weights.sum() # normalize
+ self.append(wds.RandomMix(datasets, weights))
+
+class HMR2DataModule(pl.LightningDataModule):
+
+ def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode) -> None:
+ """
+ Initialize LightningDataModule for HMR2 training
+ Args:
+ cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
+ dataset_cfg (CfgNode): Dataset configuration file
+ """
+ super().__init__()
+ self.cfg = cfg
+ self.dataset_cfg = dataset_cfg
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
+ self.mocap_dataset = None
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """
+ Load datasets necessary for training
+ Args:
+ cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
+ """
+ if self.train_dataset is None:
+ self.train_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=True).with_epoch(100_000).shuffle(4000)
+ # self.val_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=False).shuffle(4000)
+ self.mocap_dataset = MoCapDataset(**to_lower(self.dataset_cfg[self.cfg.DATASETS.MOCAP]))
+
+ def train_dataloader(self) -> Dict:
+ """
+ Setup training data loader.
+ Returns:
+ Dict: Dictionary containing image and mocap data dataloaders
+ """
+ train_dataloader = torch.utils.data.DataLoader(self.train_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, num_workers=self.cfg.GENERAL.NUM_WORKERS, prefetch_factor=self.cfg.GENERAL.PREFETCH_FACTOR)
+ mocap_dataloader = torch.utils.data.DataLoader(self.mocap_dataset, self.cfg.TRAIN.NUM_TRAIN_SAMPLES * self.cfg.TRAIN.BATCH_SIZE, shuffle=True, drop_last=True, num_workers=1)
+ return {'img': train_dataloader, 'mocap': mocap_dataloader}
+ # return {'img': train_dataloader}
+
+ # def val_dataloader(self) -> torch.utils.data.DataLoader:
+ # """
+ # Setup val data loader.
+ # Returns:
+ # torch.utils.data.DataLoader: Validation dataloader
+ # """
+ # val_dataloader = torch.utils.data.DataLoader(self.val_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, num_workers=self.cfg.GENERAL.NUM_WORKERS)
+ # return val_dataloader
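+
+
+# Usage sketch (illustrative; `cfg` and `dataset_cfg` are the project's yacs config nodes):
+#   datamodule = HMR2DataModule(cfg, dataset_cfg)
+#   datamodule.setup()
+#   loaders = datamodule.train_dataloader()  # {'img': image dataloader, 'mocap': mocap dataloader}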
diff --git a/lib/data/datasets/skel_hmr2_fashion/dataset.py b/lib/data/datasets/skel_hmr2_fashion/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..22fc5bc5f4a7b75da672bd89859da14823e71aff
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/dataset.py
@@ -0,0 +1,27 @@
+"""
+This file contains the definition of the base Dataset class.
+"""
+
+class DatasetRegistration(type):
+ """
+ Metaclass for registering different datasets
+ """
+ def __init__(cls, name, bases, nmspc):
+ super().__init__(name, bases, nmspc)
+ if not hasattr(cls, 'registry'):
+ cls.registry = dict()
+ cls.registry[name] = cls
+
+ # Metamethods, called on class objects:
+ def __iter__(cls):
+ return iter(cls.registry)
+
+ def __str__(cls):
+ return str(cls.registry)
+
+class Dataset(metaclass=DatasetRegistration):
+ """
+ Base Dataset class
+ """
+ def __init__(self, *args, **kwargs):
+ pass
\ No newline at end of file
diff --git a/lib/data/datasets/skel_hmr2_fashion/image_dataset.py b/lib/data/datasets/skel_hmr2_fashion/image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d483b32b880d2db9a8b99f0c7b101f7c3d9b58dc
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/image_dataset.py
@@ -0,0 +1,484 @@
+import copy
+import os
+import numpy as np
+import torch
+from typing import Any, Dict, List, Union
+from yacs.config import CfgNode
+import braceexpand
+import cv2
+from ipdb import set_trace
+
+from .dataset import Dataset
+from .utils import get_example, expand_to_aspect_ratio
+
+def expand(s):
+ return os.path.expanduser(os.path.expandvars(s))
+
+def expand_urls(urls: Union[str, List[str]]):
+ if isinstance(urls, str):
+ urls = [urls]
+ urls = [u for url in urls for u in braceexpand.braceexpand(expand(url))]
+ return urls
+
+AIC_TRAIN_CORRUPT_KEYS = {
+ '0a047f0124ae48f8eee15a9506ce1449ee1ba669',
+ '1a703aa174450c02fbc9cfbf578a5435ef403689',
+ '0394e6dc4df78042929b891dbc24f0fd7ffb6b6d',
+ '5c032b9626e410441544c7669123ecc4ae077058',
+ 'ca018a7b4c5f53494006ebeeff9b4c0917a55f07',
+ '4a77adb695bef75a5d34c04d589baf646fe2ba35',
+ 'a0689017b1065c664daef4ae2d14ea03d543217e',
+ '39596a45cbd21bed4a5f9c2342505532f8ec5cbb',
+ '3d33283b40610d87db660b62982f797d50a7366b',
+}
+CORRUPT_KEYS = {
+ *{f'aic-train/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
+ *{f'aic-train-vitpose/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
+}
+
+body_permutation = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21]
+extra_permutation = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18]
+FLIP_KEYPOINT_PERMUTATION = body_permutation + [25 + i for i in extra_permutation]
+
+DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
+DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
+DEFAULT_IMG_SIZE = 256
+
+class ImageDataset(Dataset):
+
+ def __init__(self,
+ cfg: CfgNode,
+ dataset_file: str,
+ img_dir: str,
+ train: bool = True,
+ prune: Dict[str, Any] = {},
+ **kwargs):
+ """
+ Dataset class used for loading images and corresponding annotations.
+ Args:
+ cfg (CfgNode): Model config file.
+ dataset_file (str): Path to npz file containing dataset info.
+ img_dir (str): Path to image folder.
+ train (bool): Whether it is for training or not (enables data augmentation).
+ """
+ super(ImageDataset, self).__init__()
+ self.train = train
+ self.cfg = cfg
+
+ self.img_size = cfg['IMAGE_SIZE']
+ self.mean = 255. * np.array(self.cfg['IMAGE_MEAN'])
+ self.std = 255. * np.array(self.cfg['IMAGE_STD'])
+
+ self.img_dir = img_dir
+ self.data = np.load(dataset_file, allow_pickle=True)
+
+ self.imgname = self.data['imgname']
+ self.personid = np.zeros(len(self.imgname), dtype=np.int32)
+ self.extra_info = self.data.get('extra_info', [{} for _ in range(len(self.imgname))])
+
+ self.flip_keypoint_permutation = copy.copy(FLIP_KEYPOINT_PERMUTATION)
+
+ num_pose = 3 * 24
+
+ # Bounding boxes are assumed to be in the center and scale format
+ self.center = self.data['center']
+ self.scale = self.data['scale'].reshape(len(self.center), -1) / 200.0
+ if self.scale.shape[1] == 1:
+ self.scale = np.tile(self.scale, (1, 2))
+ assert self.scale.shape == (len(self.center), 2)
+
+ # Get gt SMPL parameters, if available
+ try:
+ self.body_pose = self.data['body_pose'].astype(np.float32)
+ self.has_body_pose = self.data['has_body_pose'].astype(np.float32)
+ except KeyError:
+ self.body_pose = np.zeros((len(self.imgname), num_pose), dtype=np.float32)
+ self.has_body_pose = np.zeros(len(self.imgname), dtype=np.float32)
+ try:
+ self.betas = self.data['betas'].astype(np.float32)
+ self.has_betas = self.data['has_betas'].astype(np.float32)
+ except KeyError:
+ self.betas = np.zeros((len(self.imgname), 10), dtype=np.float32)
+ self.has_betas = np.zeros(len(self.imgname), dtype=np.float32)
+
+ # try:
+ # self.trans = self.data['trans'].astype(np.float32)
+ # except KeyError:
+ # self.trans = np.zeros((len(self.imgname), 3), dtype=np.float32)
+
+ # Try to get 2d keypoints, if available
+ try:
+ body_keypoints_2d = self.data['body_keypoints_2d']
+ except KeyError:
+ body_keypoints_2d = np.zeros((len(self.center), 25, 3))
+ # Try to get extra 2d keypoints, if available
+ try:
+ extra_keypoints_2d = self.data['extra_keypoints_2d']
+ except KeyError:
+ extra_keypoints_2d = np.zeros((len(self.center), 19, 3))
+
+ self.keypoints_2d = np.concatenate((body_keypoints_2d, extra_keypoints_2d), axis=1).astype(np.float32)
+
+ # Try to get 3d keypoints, if available
+ try:
+ body_keypoints_3d = self.data['body_keypoints_3d'].astype(np.float32)
+ except KeyError:
+ body_keypoints_3d = np.zeros((len(self.center), 25, 4), dtype=np.float32)
+ # Try to get extra 3d keypoints, if available
+ try:
+ extra_keypoints_3d = self.data['extra_keypoints_3d'].astype(np.float32)
+ except KeyError:
+ extra_keypoints_3d = np.zeros((len(self.center), 19, 4), dtype=np.float32)
+
+ body_keypoints_3d[:, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], -1] = 0
+
+ self.keypoints_3d = np.concatenate((body_keypoints_3d, extra_keypoints_3d), axis=1).astype(np.float32)
+
+ def __len__(self) -> int:
+ return len(self.scale)
+
+ def __getitem__(self, idx: int) -> Dict:
+ """
+ Returns an example from the dataset.
+ """
+ try:
+ image_file_rel = self.imgname[idx].decode('utf-8')
+ except AttributeError:
+ image_file_rel = self.imgname[idx]
+ image_file = os.path.join(self.img_dir, image_file_rel)
+ keypoints_2d = self.keypoints_2d[idx].copy()
+ keypoints_3d = self.keypoints_3d[idx].copy()
+
+ center = self.center[idx].copy()
+ center_x = center[0]
+ center_y = center[1]
+ scale = self.scale[idx]
+ BBOX_SHAPE = self.cfg['BBOX_SHAPE']
+ bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
+ bbox_expand_factor = bbox_size / ((scale*200).max())
+ body_pose = self.body_pose[idx].copy().astype(np.float32)
+ betas = self.betas[idx].copy().astype(np.float32)
+ # trans = self.trans[idx].copy().astype(np.float32)
+
+ has_body_pose = self.has_body_pose[idx].copy()
+ has_betas = self.has_betas[idx].copy()
+
+ smpl_params = {'global_orient': body_pose[:3],
+ 'body_pose': body_pose[3:],
+ 'betas': betas,
+ # 'trans': trans,
+ }
+
+ has_smpl_params = {'global_orient': has_body_pose,
+ 'body_pose': has_body_pose,
+ 'betas': has_betas
+ }
+
+ smpl_params_is_axis_angle = {'global_orient': True,
+ 'body_pose': True,
+ 'betas': False
+ }
+
+ augm_config = self.cfg['augm']
+ # Crop image and (possibly) perform data augmentation
+ img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, augm_record = get_example(image_file,
+ center_x, center_y,
+ bbox_size, bbox_size,
+ keypoints_2d, keypoints_3d,
+ smpl_params, has_smpl_params,
+ self.flip_keypoint_permutation,
+ self.img_size, self.img_size,
+ self.mean, self.std, self.train, augm_config)
+
+ item = {}
+ # These are the keypoints in the original image coordinates (before cropping)
+ orig_keypoints_2d = self.keypoints_2d[idx].copy()
+
+ item['img_patch'] = img_patch
+ item['keypoints_2d'] = keypoints_2d.astype(np.float32)
+ item['keypoints_3d'] = keypoints_3d.astype(np.float32)
+ item['orig_keypoints_2d'] = orig_keypoints_2d
+ item['box_center'] = self.center[idx].copy()
+ item['box_size'] = bbox_size
+ item['bbox_expand_factor'] = bbox_expand_factor
+ item['img_size'] = 1.0 * img_size[::-1].copy()
+ item['smpl_params'] = smpl_params
+ item['has_smpl_params'] = has_smpl_params
+ item['smpl_params_is_axis_angle'] = smpl_params_is_axis_angle
+ item['imgname'] = image_file
+ item['imgname_rel'] = image_file_rel
+ item['personid'] = int(self.personid[idx])
+ item['extra_info'] = copy.deepcopy(self.extra_info[idx])
+ item['idx'] = idx
+ item['_scale'] = scale
+ item['augm_record'] = augm_record # Augmentation record for recovery in self-improvement process.
+ return item
+
+
+ @staticmethod
+ def load_tars_as_webdataset(cfg: CfgNode, urls: Union[str, List[str]], train: bool,
+ resampled=False,
+ epoch_size=None,
+ cache_dir=None,
+ **kwargs) -> Dataset:
+ """
+ Loads the dataset from a webdataset tar file.
+ """
+ from .smplh_prob_filter import poses_check_probable, load_amass_hist_smooth
+
+ IMG_SIZE = cfg['IMAGE_SIZE']
+ BBOX_SHAPE = cfg['BBOX_SHAPE']
+ MEAN = 255. * np.array(cfg['IMAGE_MEAN'])
+ STD = 255. * np.array(cfg['IMAGE_STD'])
+
+ def split_data(source):
+ for item in source:
+ datas = item['data.pyd']
+ for pid, data in enumerate(datas):
+ data['pid'] = pid
+ data['orig_has_body_pose'] = data['orig_pve_max'] < 0.1
+ data['orig_has_betas'] = data['orig_pve_max'] < 0.1
+ if 'flip_pve_mean' in data:
+ data['flip_has_body_pose'] = data['flip_pve_mean'] < 0.05 # TODO: fix this problem
+ data['flip_has_betas'] = data['flip_has_body_pose']
+ else:
+ data['flip_has_body_pose'] = data['flip_pve_max'] < 0.65
+ data['flip_has_betas'] = data['flip_has_body_pose']
+ # data['body_pose'] = data['orig_poses']
+ # data['betas'] = data['orig_betas']
+ if 'detection.npz' in item:
+ det_idx = data['extra_info']['detection_npz_idx']
+ mask = item['detection.npz']['masks'][det_idx]
+ else:
+ mask = np.ones_like(item['jpg'][:,:,0], dtype=bool)
+ yield {
+ '__key__': item['__key__'],
+ 'jpg': item['jpg'],
+ 'data.pyd': data,
+ 'mask': mask,
+ }
+
+ def suppress_bad_kps(item, thresh=0.0):
+ if thresh > 0:
+ kp2d = item['data.pyd']['keypoints_2d']
+ kp2d_conf = np.where(kp2d[:, 2] < thresh, 0.0, kp2d[:, 2])
+ item['data.pyd']['keypoints_2d'] = np.concatenate([kp2d[:,:2], kp2d_conf[:,None]], axis=1)
+ return item
+
+ def filter_numkp(item, numkp=4, thresh=0.0):
+ kp_conf = item['data.pyd']['keypoints_2d'][:, 2]
+ return (kp_conf > thresh).sum() > numkp
+
+ def filter_reproj_error(item, thresh=10**4.5):
+ losses = item['data.pyd'].get('extra_info', {}).get('fitting_loss', np.array({})).item()
+ reproj_loss = losses.get('reprojection_loss', None)
+ return reproj_loss is None or reproj_loss < thresh
+
+ def filter_bbox_size(item, thresh=1):
+ bbox_size_min = item['data.pyd']['scale'].min().item() * 200.
+ return bbox_size_min > thresh
+
+ def filter_no_poses(item):
+ return (item['data.pyd']['has_body_pose'] > 0)
+
+ def supress_bad_betas(item, thresh=3):
+ for side in ['orig', 'flip']:
+ has_betas = item['data.pyd'][f'{side}_has_betas']
+ if thresh > 0 and has_betas:
+ betas_abs = np.abs(item['data.pyd'][f'{side}_betas'])
+ if (betas_abs > thresh).any():
+ item['data.pyd'][f'{side}_has_betas'] = False
+
+ return item
+
+ amass_poses_hist100_smooth = load_amass_hist_smooth()
+ def supress_bad_poses(item):
+ for side in ['orig', 'flip']:
+ has_body_pose = item['data.pyd'][f'{side}_has_body_pose']
+ if has_body_pose:
+ body_pose = item['data.pyd'][f'{side}_body_pose']
+ pose_is_probable = poses_check_probable(torch.from_numpy(body_pose)[None, 3:], amass_poses_hist100_smooth).item()
+ if not pose_is_probable:
+ item['data.pyd'][f'{side}_has_body_pose'] = False
+ return item
+
+ def poses_betas_simultaneous(item):
+ # We either have both body_pose and betas, or neither
+ for side in ['orig', 'flip']:
+ has_betas = item['data.pyd'][f'{side}_has_betas']
+ has_body_pose = item['data.pyd'][f'{side}_has_body_pose']
+ item['data.pyd'][f'{side}_has_betas'] = item['data.pyd'][f'{side}_has_body_pose'] = np.array(float((has_body_pose>0) and (has_betas>0)))
+ return item
+
+ def set_betas_for_reg(item):
+ for side in ['orig', 'flip']:
+ # Always mark betas as valid; invalid ones are reset to zero below
+ has_betas = item['data.pyd'][f'{side}_has_betas']
+ betas = item['data.pyd'][f'{side}_betas']
+
+ if not (has_betas>0):
+ item['data.pyd'][f'{side}_has_betas'] = np.array(float((True)))
+ item['data.pyd'][f'{side}_betas'] = betas * 0
+ return item
+
+ # Load the dataset
+ if epoch_size is not None:
+ resampled = True
+ corrupt_filter = lambda sample: (sample['__key__'] not in CORRUPT_KEYS)
+ import webdataset as wds
+ dataset = wds.WebDataset(expand_urls(urls),
+ nodesplitter=wds.split_by_node,
+ shardshuffle=True,
+ resampled=resampled,
+ cache_dir=cache_dir,
+ ).select(corrupt_filter)
+ if train:
+ dataset = dataset.shuffle(100)
+ dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png')
+
+ # Process the dataset
+ dataset = dataset.compose(split_data)
+
+ # Filter/clean the dataset
+ SUPPRESS_KP_CONF_THRESH = cfg.get('SUPPRESS_KP_CONF_THRESH', 0.0)
+ SUPPRESS_BETAS_THRESH = cfg.get('SUPPRESS_BETAS_THRESH', 0.0)
+ SUPPRESS_BAD_POSES = cfg.get('SUPPRESS_BAD_POSES', False)
+ POSES_BETAS_SIMULTANEOUS = cfg.get('POSES_BETAS_SIMULTANEOUS', False)
+ BETAS_REG = cfg.get('BETAS_REG', False)
+ FILTER_NO_POSES = cfg.get('FILTER_NO_POSES', False)
+ FILTER_NUM_KP = cfg.get('FILTER_NUM_KP', 4)
+ FILTER_NUM_KP_THRESH = cfg.get('FILTER_NUM_KP_THRESH', 0.0)
+ FILTER_REPROJ_THRESH = cfg.get('FILTER_REPROJ_THRESH', 0.0)
+ FILTER_MIN_BBOX_SIZE = cfg.get('FILTER_MIN_BBOX_SIZE', 0.0)
+ if SUPPRESS_KP_CONF_THRESH > 0:
+ dataset = dataset.map(lambda x: suppress_bad_kps(x, thresh=SUPPRESS_KP_CONF_THRESH))
+ if SUPPRESS_BETAS_THRESH > 0:
+ dataset = dataset.map(lambda x: supress_bad_betas(x, thresh=SUPPRESS_BETAS_THRESH))
+ if SUPPRESS_BAD_POSES:
+ dataset = dataset.map(lambda x: supress_bad_poses(x))
+ if POSES_BETAS_SIMULTANEOUS:
+ dataset = dataset.map(lambda x: poses_betas_simultaneous(x))
+ if FILTER_NO_POSES:
+ dataset = dataset.select(lambda x: filter_no_poses(x))
+ if FILTER_NUM_KP > 0:
+ dataset = dataset.select(lambda x: filter_numkp(x, numkp=FILTER_NUM_KP, thresh=FILTER_NUM_KP_THRESH))
+ if FILTER_REPROJ_THRESH > 0:
+ dataset = dataset.select(lambda x: filter_reproj_error(x, thresh=FILTER_REPROJ_THRESH))
+ if FILTER_MIN_BBOX_SIZE > 0:
+ dataset = dataset.select(lambda x: filter_bbox_size(x, thresh=FILTER_MIN_BBOX_SIZE))
+ if BETAS_REG:
+ dataset = dataset.map(lambda x: set_betas_for_reg(x)) # NOTE: Must be at the end
+
+ use_skimage_antialias = cfg.get('USE_SKIMAGE_ANTIALIAS', False)
+ border_mode = {
+ 'constant': cv2.BORDER_CONSTANT,
+ 'replicate': cv2.BORDER_REPLICATE,
+ }[cfg.get('BORDER_MODE', 'constant')]
+
+ # Process the dataset further
+ dataset = dataset.map(lambda x: ImageDataset.process_webdataset_tar_item(x, train,
+ augm_config=cfg['augm'],
+ MEAN=MEAN, STD=STD, IMG_SIZE=IMG_SIZE,
+ BBOX_SHAPE=BBOX_SHAPE,
+ use_skimage_antialias=use_skimage_antialias,
+ border_mode=border_mode,
+ ))
+ if epoch_size is not None:
+ dataset = dataset.with_epoch(epoch_size)
+
+ return dataset
+
+ @staticmethod
+ def process_webdataset_tar_item(item, train,
+ augm_config=None,
+ MEAN=DEFAULT_MEAN,
+ STD=DEFAULT_STD,
+ IMG_SIZE=DEFAULT_IMG_SIZE,
+ BBOX_SHAPE=None,
+ use_skimage_antialias=False,
+ border_mode=cv2.BORDER_CONSTANT,
+ ):
+ # Read data from item
+ key = item['__key__']
+ image = item['jpg']
+ data = item['data.pyd']
+ mask = item['mask']
+ pid = data['pid']
+
+ keypoints_2d = data['keypoints_2d']
+ keypoints_3d = data['keypoints_3d']
+ center = data['center']
+ scale = data['scale']
+ body_pose = (data['orig_poses'], data['flip_poses'])
+ betas = (data['orig_betas'], data['flip_betas'])
+ # trans = data['trans']
+ has_body_pose = (data['orig_has_body_pose'], data['flip_has_body_pose'])
+ has_betas = (data['orig_has_betas'], data['flip_has_betas'])
+ # image_file = data['image_file']
+
+ # Process data
+ orig_keypoints_2d = keypoints_2d.copy()
+ center_x = center[0]
+ center_y = center[1]
+ bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
+ if bbox_size < 1:
+ breakpoint()
+
+
+ smpl_params = {'global_orient': (body_pose[0][:3], body_pose[1][:3]),
+ 'body_pose': (body_pose[0][3:], body_pose[1][3:]),
+ 'betas': betas,
+ # 'trans': trans,
+ }
+
+ has_smpl_params = {'global_orient': has_body_pose,
+ 'body_pose': has_body_pose,
+ 'betas': has_betas
+ }
+
+ smpl_params_is_axis_angle = {'global_orient': True,
+ 'body_pose': True,
+ 'betas': False
+ }
+
+ augm_config = copy.deepcopy(augm_config)
+ # Crop image and (possibly) perform data augmentation
+ img_rgba = np.concatenate([image, mask.astype(np.uint8)[:,:,None]*255], axis=2)
+ img_patch_rgba, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, trans, augm_record = get_example(img_rgba,
+ center_x, center_y,
+ bbox_size, bbox_size,
+ keypoints_2d, keypoints_3d,
+ smpl_params, has_smpl_params,
+ FLIP_KEYPOINT_PERMUTATION,
+ IMG_SIZE, IMG_SIZE,
+ MEAN, STD, train, augm_config,
+ is_bgr=False, return_trans=True,
+ use_skimage_antialias=use_skimage_antialias,
+ border_mode=border_mode,
+ )
+ img_patch = img_patch_rgba[:3,:,:]
+ mask_patch = (img_patch_rgba[3,:,:] / 255.0).clip(0,1)
+ if (mask_patch < 0.5).all():
+ mask_patch = np.ones_like(mask_patch)
+
+ item = {}
+
+ item['img'] = img_patch
+ item['mask'] = mask_patch
+ # item['img_og'] = image
+ # item['mask_og'] = mask
+ item['keypoints_2d'] = keypoints_2d.astype(np.float32)
+ item['keypoints_3d'] = keypoints_3d.astype(np.float32)
+ item['orig_keypoints_2d'] = orig_keypoints_2d
+ item['box_center'] = center.copy()
+ item['box_size'] = bbox_size
+ item['img_size'] = 1.0 * img_size[::-1].copy()
+ item['smpl_params'] = smpl_params
+ item['has_smpl_params'] = has_smpl_params
+ item['smpl_params_is_axis_angle'] = smpl_params_is_axis_angle
+ item['_scale'] = scale
+ item['_trans'] = trans
+ item['imgname'] = key
+ item['pid'] = pid
+ item['augm_record'] = augm_record # Augmentation record for recovery in self-improvement process.
+ return item
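+
+
+# Usage sketch (illustrative; `cfg` is the yacs model config and the tar URL pattern is a
+# placeholder that gets brace-expanded):
+#   dataset = ImageDataset.load_tars_as_webdataset(
+#       cfg, urls='path/to/dataset-tars/{000000..000099}.tar', train=True, epoch_size=100_000)
+#   item = next(iter(dataset))  # dict with 'img', 'mask', 'keypoints_2d', 'smpl_params', ...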
diff --git a/lib/data/datasets/skel_hmr2_fashion/mocap_dataset.py b/lib/data/datasets/skel_hmr2_fashion/mocap_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b1db9ed898ae4454d97ac0039ebd0a182f7ef5e
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/mocap_dataset.py
@@ -0,0 +1,31 @@
+import numpy as np
+from typing import Dict
+
+class MoCapDataset:
+
+ def __init__(self, dataset_file: str, threshold: float = 0.01) -> None:
+ """
+ Dataset class used for loading a dataset of unpaired SMPL parameter annotations
+ Args:
+ dataset_file (str): Path to npz file containing dataset info.
+ threshold (float): Threshold for PVE filtering.
+ """
+ data = np.load(dataset_file)
+ # pve = data['pve']
+ pve = data['pve_max']
+ # pve = data['pve_mean']
+ mask = pve < threshold
+ self.pose = data['poses'].astype(np.float32)[mask, 3:]
+ self.betas = data['betas'].astype(np.float32)[mask]
+ self.length = len(self.pose)
+ print(f'Loaded {self.length} of {len(pve)} samples from {dataset_file} (PVE threshold = {threshold})')
+
+ def __getitem__(self, idx: int) -> Dict:
+ pose = self.pose[idx].copy()
+ betas = self.betas[idx].copy()
+ item = {'body_pose': pose, 'betas': betas}
+ return item
+
+ def __len__(self) -> int:
+ return self.length
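+
+# Usage sketch (illustrative only; the npz path below is hypothetical):
+# mocap = MoCapDataset('data/mocap_unpaired.npz', threshold=0.01)
+# sample = mocap[0] # dict with 'body_pose' and 'betas' arrays
+# n_samples = len(mocap)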
diff --git a/lib/data/datasets/skel_hmr2_fashion/preprocess/lspet_to_npz.py b/lib/data/datasets/skel_hmr2_fashion/preprocess/lspet_to_npz.py
new file mode 100644
index 0000000000000000000000000000000000000000..228c3df502e6bd9643085e80b09013c683fedf2e
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/preprocess/lspet_to_npz.py
@@ -0,0 +1,74 @@
+# Adapted from https://raw.githubusercontent.com/nkolot/SPIN/master/datasets/preprocess/hr_lspet.py
+import os
+import glob
+import numpy as np
+import scipy.io as sio
+# from .read_openpose import read_openpose
+
+def hr_lspet_extract(dataset_path, out_path):
+
+ # training mode
+ png_path = os.path.join(dataset_path, '*.png')
+ imgs = glob.glob(png_path)
+ imgs.sort()
+
+ # structs we use
+ imgnames_, scales_, centers_, parts_, openposes_ = [], [], [], [], []
+
+ # scale factor
+ scaleFactor = 1.2
+
+ # annotation files
+ annot_file = os.path.join(dataset_path, 'joints.mat')
+ joints = sio.loadmat(annot_file)['joints']
+
+ # main loop
+ for i, imgname in enumerate(imgs):
+ # image name
+ imgname = imgname.split('/')[-1]
+ # read keypoints
+ part14 = joints[:,:2,i]
+ # scale and center
+ bbox = [min(part14[:,0]), min(part14[:,1]),
+ max(part14[:,0]), max(part14[:,1])]
+ center = [(bbox[2]+bbox[0])/2, (bbox[3]+bbox[1])/2]
+ # scale = scaleFactor*max(bbox[2]-bbox[0], bbox[3]-bbox[1]) # Don't /200
+ scale = scaleFactor*np.array([bbox[2]-bbox[0], bbox[3]-bbox[1]]) # Don't /200
+ # update keypoints
+ part = np.zeros([24,3])
+ part[:14] = np.hstack([part14, np.ones([14,1])])
+
+ # # read openpose detections
+ # json_file = os.path.join(openpose_path, 'hrlspet',
+ # imgname.replace('.png', '_keypoints.json'))
+ # openpose = read_openpose(json_file, part, 'hrlspet')
+
+ # store the data
+ imgnames_.append(imgname)
+ centers_.append(center)
+ scales_.append(scale)
+ parts_.append(part)
+ # openposes_.append(openpose)
+
+ # Populate extra_keypoints_2d: N,19,3
+ # extra_keypoints_2d[:14] = parts[:14]
+ extra_keypoints_2d = np.zeros((len(parts_), 19, 3))
+ extra_keypoints_2d[:,:14,:] = np.stack(parts_)[:,:14,:3]
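+ # Only the first 14 slots (the LSP joints) carry annotations here; the remaining 5 extra slots stay zero.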
+
+ print(f'{extra_keypoints_2d.shape=}')
+
+ # store the data struct
+ if not os.path.isdir(out_path):
+ os.makedirs(out_path)
+ out_file = os.path.join(out_path, 'hr-lspet_train.npz')
+ np.savez(out_file, imgname=imgnames_,
+ center=centers_,
+ scale=scales_,
+ part=parts_,
+ extra_keypoints_2d=extra_keypoints_2d,
+ # openpose=openposes_
+ )
+
+
+if __name__ == '__main__':
+ hr_lspet_extract('/shared/pavlakos/datasets/hr-lspet/', 'hmr2_evaluation_data/')
diff --git a/lib/data/datasets/skel_hmr2_fashion/preprocess/posetrack_to_npz.py b/lib/data/datasets/skel_hmr2_fashion/preprocess/posetrack_to_npz.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c999b5c80c249ec768aabe4a271cbfc86877817
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/preprocess/posetrack_to_npz.py
@@ -0,0 +1,92 @@
+# Adapted from https://raw.githubusercontent.com/nkolot/SPIN/master/datasets/preprocess/coco.py
+import os
+from os.path import join
+import sys
+import json
+import numpy as np
+from pathlib import Path
+# from .read_openpose import read_openpose
+
+def coco_extract(dataset_path, out_path):
+
+ # # convert joints to global order
+ # joints_idx = [19, 20, 21, 22, 23, 9, 8, 10, 7, 11, 6, 3, 2, 4, 1, 5, 0]
+
+ # bbox expansion factor
+ scaleFactor = 1.2
+
+ # structs we need
+ imgnames_, scales_, centers_, parts_, openposes_ = [], [], [], [], []
+
+ # json annotation file
+ SPLIT='val'
+ json_paths = (Path(dataset_path)/'posetrack_data/annotations'/SPLIT).glob('*.json')
+ for json_path in json_paths:
+ json_data = json.load(open(json_path, 'r'))
+
+ imgs = {}
+ for img in json_data['images']:
+ imgs[img['id']] = img
+
+ for annot in json_data['annotations']:
+ # keypoints processing
+ keypoints = annot['keypoints']
+ keypoints = np.reshape(keypoints, (17,3))
+ keypoints[keypoints[:,2]>0,2] = 1
+ # check if all major body joints are annotated
+ if sum(keypoints[5:,2]>0) < 12:
+ continue
+ # image name
+ image_id = annot['image_id']
+ img_name = str(imgs[image_id]['file_name'])
+ # img_name_full = f'images/{SPLIT}/{json_path.stem}/{img_name}'
+ img_name_full = img_name
+
+ # keypoints
+ part = np.zeros([17,3])
+ # part[joints_idx] = keypoints
+ part = keypoints
+
+ # scale and center
+ bbox = annot['bbox']
+ center = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
+ # scale = scaleFactor*max(bbox[2], bbox[3]) # Don't do /200
+ scale = scaleFactor*np.array([bbox[2], bbox[3]]) # Don't /200
+ # # read openpose detections
+ # json_file = os.path.join(openpose_path, 'coco',
+ # img_name.replace('.jpg', '_keypoints.json'))
+ # openpose = read_openpose(json_file, part, 'coco')
+
+ # store data
+ imgnames_.append(img_name_full)
+ centers_.append(center)
+ scales_.append(scale)
+ parts_.append(part)
+ # openposes_.append(openpose)
+
+
+ # NOTE: PoseTrack val doesn't annotate the ears (OpenPose joints 17, 18),
+ # but it does annotate head and neck, so those will have to live in extra_kps.
+ posetrack_to_op_extra = [0, 37, 38, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11] # 15 indices land in the 25-joint OpenPose body block; head (37) and neck (38) go to the extra block.
+ all_keypoints_2d = np.zeros((len(parts_), 44, 3))
+ all_keypoints_2d[:,posetrack_to_op_extra] = np.stack(parts_)[:,:len(posetrack_to_op_extra),:3]
+ body_keypoints_2d = all_keypoints_2d[:,:25,:]
+ extra_keypoints_2d = all_keypoints_2d[:,25:,:]
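+ # Split the 44-joint array into the 25 OpenPose body joints and the 19 extra keypoints.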
+
+ print(f'{extra_keypoints_2d.shape=}')
+
+ # store the data struct
+ if not os.path.isdir(out_path):
+ os.makedirs(out_path)
+ out_file = os.path.join(out_path, f'posetrack_2018_{SPLIT}.npz')
+ np.savez(out_file, imgname=imgnames_,
+ center=centers_,
+ scale=scales_,
+ part=parts_,
+ body_keypoints_2d=body_keypoints_2d,
+ extra_keypoints_2d=extra_keypoints_2d,
+ # openpose=openposes_
+ )
+
+if __name__ == '__main__':
+ coco_extract('/shared/pavlakos/datasets/posetrack/posetrack2018/', 'hmr2_evaluation_data/')
diff --git a/lib/data/datasets/skel_hmr2_fashion/smplh_prob_filter.py b/lib/data/datasets/skel_hmr2_fashion/smplh_prob_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b544b54999aa726378b0f8b058b1a90b166d7e30
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/smplh_prob_filter.py
@@ -0,0 +1,152 @@
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+from lib.platform import PM
+
+JOINT_NAMES = [
+ 'left_hip',
+ 'right_hip',
+ 'spine1',
+ 'left_knee',
+ 'right_knee',
+ 'spine2',
+ 'left_ankle',
+ 'right_ankle',
+ 'spine3',
+ 'left_foot',
+ 'right_foot',
+ 'neck',
+ 'left_collar',
+ 'right_collar',
+ 'head',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist'
+]
+
+# Manually chosen probability density thresholds for each joint
+# Probabilities computed using SIGMA=2 gaussian blur on AMASS pose 3D histogram for range (-pi,pi) with 100x100x100 bins
+JOINT_NAME_PROB_THRESHOLDS = {
+ 'left_hip': 5e-5,
+ 'right_hip': 5e-5,
+ 'spine1': 2e-3,
+ 'left_knee': 5e-6,
+ 'right_knee': 5e-6,
+ 'spine2': 0.01,
+ 'left_ankle': 5e-6,
+ 'right_ankle': 5e-6,
+ 'spine3': 0.025,
+ 'left_foot': 0,
+ 'right_foot': 0,
+ 'neck': 2e-4,
+ 'left_collar': 4.5e-4 ,
+ 'right_collar': 4.5e-4,
+ 'head': 5e-4,
+ 'left_shoulder': 2e-4,
+ 'right_shoulder': 2e-4,
+ 'left_elbow': 4e-5,
+ 'right_elbow': 4e-5,
+ 'left_wrist': 1e-3,
+ 'right_wrist': 1e-3,
+}
+
+JOINT_IDX_PROB_THRESHOLDS = torch.tensor([JOINT_NAME_PROB_THRESHOLDS[joint_name] for joint_name in JOINT_NAMES])
+
+###################################################################
+POSE_RANGE_MIN = -np.pi
+POSE_RANGE_MAX = np.pi
+
+# Create 21x100x100x100 histogram of all 21 AMASS body pose joints using `create_pose_hist(amass_poses, nbins=100)`
+AMASS_HIST100_PATH = PM.inputs / 'amass_poses_hist100_SMPL+H_G.npy'
+if not os.path.exists(AMASS_HIST100_PATH):
+ raise FileNotFoundError(f'AMASS_HIST100_PATH not found: {AMASS_HIST100_PATH}')
+
+def create_pose_hist(poses: np.ndarray, nbins: int = 100) -> np.ndarray:
+ N,K,C = poses.shape
+ assert C==3, poses.shape
+ poses_21x3 = normalize_axis_angle(torch.from_numpy(poses).view(N*K,3)).numpy().reshape(N, K, 3)
+ assert (poses_21x3 > -np.pi).all() and (poses_21x3 < np.pi).all()
+
+ Hs, Es = [], []
+ for i in range(K):
+ H, edges = np.histogramdd(poses_21x3[:, i, :], bins=nbins, range=[(-np.pi, np.pi)]*3)
+ Hs.append(H)
+ Es.append(edges)
+ Hs = np.stack(Hs, axis=0)
+ return Hs
+
+def load_amass_hist_smooth(sigma=2) -> torch.Tensor:
+ amass_poses_hist100 = np.load(AMASS_HIST100_PATH)
+ amass_poses_hist100 = torch.from_numpy(amass_poses_hist100)
+ assert amass_poses_hist100.shape == (21,100,100,100)
+
+ nbins = amass_poses_hist100.shape[1]
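+ # Normalize counts to a probability density over (-pi, pi)^3: divide by the total count and by the bin volume (2*pi/nbins)^3.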
+ amass_poses_hist100 = amass_poses_hist100/amass_poses_hist100.sum() / (2*np.pi/nbins)**3
+
+ # Gaussian filter on amass_poses_hist100
+ from scipy.ndimage import gaussian_filter
+ amass_poses_hist100_smooth = gaussian_filter(amass_poses_hist100.numpy(), sigma=sigma, mode='constant')
+ amass_poses_hist100_smooth = torch.from_numpy(amass_poses_hist100_smooth)
+ return amass_poses_hist100_smooth
+
+# Normalize axis angle representation s.t. angle is in [-pi, pi]
+def normalize_axis_angle(poses: torch.Tensor) -> torch.Tensor:
+ # poses: N, 3
+ # print(f'normalize_axis_angle ...')
+ assert poses.shape[1] == 3, poses.shape
+ angle = poses.norm(dim=1)
+ axis = F.normalize(poses, p=2, dim=1, eps=1e-8)
+
+ angle_fixed = angle.clone()
+ axis_fixed = axis.clone()
+
+ eps = 1e-6
+ ii = 0
+ while True:
+ # print(f'normalize_axis_angle iter {ii}')
+ ii += 1
+ angle_too_big = (angle_fixed > np.pi + eps)
+ if not angle_too_big.any():
+ break
+
+ angle_fixed[angle_too_big] -= 2 * np.pi
+ angle_too_small = (angle_fixed < -eps)
+ axis_fixed[angle_too_small] *= -1
+ angle_fixed[angle_too_small] *= -1
+
+ return axis_fixed * angle_fixed[:,None]
+
+def poses_to_joint_probs(poses: torch.Tensor, amass_poses_100_smooth: torch.Tensor) -> torch.Tensor:
+ # poses: Nx69
+ # amass_poses_100_smooth: 21xBINSxBINSxBINS
+ # returns: poses_prob: Nx21
+ N=poses.shape[0]
+ assert poses.shape == (N,69)
+ poses = poses[:,:63].reshape(N*21,3)
+
+ nbins = amass_poses_100_smooth.shape[1]
+ assert amass_poses_100_smooth.shape == (21,nbins,nbins,nbins)
+
+ poses_bin = (poses - POSE_RANGE_MIN) / (POSE_RANGE_MAX - POSE_RANGE_MIN) * (nbins - 1e-6)
+ poses_bin = poses_bin.long().clip(0, nbins-1)
+ joint_id = torch.arange(21, device=poses.device).view(1,21).expand(N,21).reshape(N*21)
+ poses_prob = amass_poses_100_smooth[joint_id, poses_bin[:,0], poses_bin[:,1], poses_bin[:,2]]
+
+ poses_bad = ((poses < POSE_RANGE_MIN) | (poses >= POSE_RANGE_MAX)).any(dim=1)
+ poses_prob[poses_bad] = 0
+
+ return poses_prob.view(N,21)
+
+def poses_check_probable(
+ poses: torch.Tensor,
+ amass_poses_100_smooth: torch.Tensor,
+ prob_thresholds: torch.Tensor = JOINT_IDX_PROB_THRESHOLDS
+ ) -> torch.Tensor:
+ N,C=poses.shape
+ poses_norm = normalize_axis_angle(poses.reshape(N*(C//3),3)).reshape(N,C)
+ poses_prob = poses_to_joint_probs(poses_norm, amass_poses_100_smooth)
+ return (poses_prob > prob_thresholds).all(dim=1)
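+
+# Usage sketch (illustrative; the variable names are hypothetical):
+# hist = load_amass_hist_smooth(sigma=2) # (21, 100, 100, 100) smoothed density
+# body_pose = torch.zeros(8, 69) # a batch of SMPL body poses in axis-angle
+# keep = poses_check_probable(body_pose, hist) # (8,) bool mask of plausible poses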
diff --git a/lib/data/datasets/skel_hmr2_fashion/utils.py b/lib/data/datasets/skel_hmr2_fashion/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b22225485fc2e89c5f19a711e4817720969656ea
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/utils.py
@@ -0,0 +1,1069 @@
+"""
+Parts of the code are taken or adapted from
+https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
+"""
+import torch
+import numpy as np
+# from skimage.transform import rotate, resize
+# from skimage.filters import gaussian
+import random
+import cv2
+from typing import List, Dict, Tuple, Union
+from yacs.config import CfgNode
+
+def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
+ """Increase the size of the bounding box to match the target shape."""
+ if target_aspect_ratio is None:
+ return input_shape
+
+ try:
+ w, h = input_shape
+ except (ValueError, TypeError):
+ return input_shape
+
+ w_t, h_t = target_aspect_ratio
+ if h / w < h_t / w_t:
+ h_new = max(w * h_t / w_t, h)
+ w_new = w
+ else:
+ h_new = h
+ w_new = max(h * w_t / h_t, w)
+ if h_new < h or w_new < w:
+ breakpoint()
+ return np.array([w_new, h_new])
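+
+# Example: expand_to_aspect_ratio(np.array([150., 250.]), target_aspect_ratio=(192, 256))
+# widens the box to [187.5, 250.] so that width:height matches 192:256.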
+
+def expand_bbox_to_aspect_ratio(bbox, target_aspect_ratio=None):
+ # bbox: np.ndarray: (N,4) detectron2 bbox format
+ # target_aspect_ratio: (width, height)
+ if target_aspect_ratio is None:
+ return bbox
+
+ is_singleton = (bbox.ndim == 1)
+ if is_singleton:
+ bbox = bbox[None,:]
+
+ if bbox.shape[0] > 0:
+ center = np.stack(((bbox[:,0] + bbox[:,2]) / 2, (bbox[:,1] + bbox[:,3]) / 2), axis=1)
+ scale_wh = np.stack((bbox[:,2] - bbox[:,0], bbox[:,3] - bbox[:,1]), axis=1)
+ scale_wh = np.stack([expand_to_aspect_ratio(wh, target_aspect_ratio) for wh in scale_wh], axis=0)
+ bbox = np.stack([
+ center[:,0] - scale_wh[:,0] / 2,
+ center[:,1] - scale_wh[:,1] / 2,
+ center[:,0] + scale_wh[:,0] / 2,
+ center[:,1] + scale_wh[:,1] / 2,
+ ], axis=1)
+
+ if is_singleton:
+ bbox = bbox[0,:]
+
+ return bbox
+
+def do_augmentation(aug_config: CfgNode) -> Tuple:
+ """
+ Compute random augmentation parameters.
+ Args:
+ aug_config (CfgNode): Config containing augmentation parameters.
+ Returns:
+ scale (float): Box rescaling factor.
+ rot (float): Random image rotation.
+ do_flip (bool): Whether to flip image or not.
+ do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
+ extreme_crop_lvl (int): Extreme cropping level (0: moderate, 1: aggressive).
+ color_scale (List): Per-channel color rescaling factors.
+ tx (float): Random translation along the x axis.
+ ty (float): Random translation along the y axis.
+ """
+
+ tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
+ ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
+ rot = np.clip(np.random.randn(), -2.0,
+ 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
+ do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
+ do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
+ extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
+ # extreme_crop_lvl = 0
+ c_up = 1.0 + aug_config.COLOR_SCALE
+ c_low = 1.0 - aug_config.COLOR_SCALE
+ color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
+ return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
+
+def rotate_2d(pt_2d: np.ndarray, rot_rad: float) -> np.ndarray:
+ """
+ Rotate a 2D point on the x-y plane.
+ Args:
+ pt_2d (np.ndarray): Input 2D point with shape (2,).
+ rot_rad (float): Rotation angle in radians.
+ Returns:
+ np.ndarray: Rotated 2D point.
+ """
+ x = pt_2d[0]
+ y = pt_2d[1]
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ xx = x * cs - y * sn
+ yy = x * sn + y * cs
+ return np.array([xx, yy], dtype=np.float32)
+
+
+def gen_trans_from_patch_cv(c_x: float, c_y: float,
+ src_width: float, src_height: float,
+ dst_width: float, dst_height: float,
+ scale: float, rot: float) -> np.ndarray:
+ """
+ Create transformation matrix for the bounding box crop.
+ Args:
+ c_x (float): Bounding box center x coordinate in the original image.
+ c_y (float): Bounding box center y coordinate in the original image.
+ src_width (float): Bounding box width.
+ src_height (float): Bounding box height.
+ dst_width (float): Output box width.
+ dst_height (float): Output box height.
+ scale (float): Rescaling factor for the bounding box (augmentation).
+ rot (float): Random rotation applied to the box.
+ Returns:
+ trans (np.ndarray): Target geometric transformation.
+ """
+ # augment size with scale
+ src_w = src_width * scale
+ src_h = src_height * scale
+ src_center = np.zeros(2)
+ src_center[0] = c_x
+ src_center[1] = c_y
+ # augment rotation
+ rot_rad = np.pi * rot / 180
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
+
+ dst_w = dst_width
+ dst_h = dst_height
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = src_center
+ src[1, :] = src_center + src_downdir
+ src[2, :] = src_center + src_rightdir
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = dst_center
+ dst[1, :] = dst_center + dst_downdir
+ dst[2, :] = dst_center + dst_rightdir
+
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) # type: ignore
+
+ return trans
+
+
+def trans_point2d(pt_2d: np.ndarray, trans: np.ndarray):
+ """
+ Transform a 2D point using the affine transformation matrix trans.
+ Args:
+ pt_2d (np.ndarray): Input 2D point with shape (2,).
+ trans (np.ndarray): Transformation matrix.
+ Returns:
+ np.ndarray: Transformed 2D point.
+ """
+ src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
+ dst_pt = np.dot(trans, src_pt)
+ return dst_pt[0:2]
+
+def get_transform(center, scale, res, rot=0):
+ """Generate transformation matrix."""
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
+ h = 200 * scale
+ t = np.zeros((3, 3))
+ t[0, 0] = float(res[1]) / h
+ t[1, 1] = float(res[0]) / h
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
+ t[2, 2] = 1
+ if not rot == 0:
+ rot = -rot # To match direction of rotation from cropping
+ rot_mat = np.zeros((3, 3))
+ rot_rad = rot * np.pi / 180
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0, :2] = [cs, -sn]
+ rot_mat[1, :2] = [sn, cs]
+ rot_mat[2, 2] = 1
+ # Need to rotate around center
+ t_mat = np.eye(3)
+ t_mat[0, 2] = -res[1] / 2
+ t_mat[1, 2] = -res[0] / 2
+ t_inv = t_mat.copy()
+ t_inv[:2, 2] *= -1
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
+ return t
+
+
+def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
+ """Transform pixel location to different reference."""
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
+ t = get_transform(center, scale, res, rot=rot)
+ if invert:
+ t = np.linalg.inv(t)
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
+ new_pt = np.dot(t, new_pt)
+ if as_int:
+ new_pt = new_pt.astype(int)
+ return new_pt[:2] + 1
+
+def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
+ c_x = (ul[0] + br[0])/2
+ c_y = (ul[1] + br[1])/2
+ bb_width = patch_width = br[0] - ul[0]
+ bb_height = patch_height = br[1] - ul[1]
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=border_mode,
+ borderValue=border_value
+ ) # type: ignore
+
+ # Force borderMode=cv2.BORDER_CONSTANT for the alpha channel
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ )
+
+ return img_patch
+
+# def generate_image_patch_skimage(img: np.ndarray, c_x: float, c_y: float,
+# bb_width: float, bb_height: float,
+# patch_width: float, patch_height: float,
+# do_flip: bool, scale: float, rot: float,
+# border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.ndarray, np.ndarray]:
+# """
+# Crop image according to the supplied bounding box.
+# Args:
+# img (np.ndarray): Input image of shape (H, W, 3)
+# c_x (float): Bounding box center x coordinate in the original image.
+# c_y (float): Bounding box center y coordinate in the original image.
+# bb_width (float): Bounding box width.
+# bb_height (float): Bounding box height.
+# patch_width (float): Output box width.
+# patch_height (float): Output box height.
+# do_flip (bool): Whether to flip image or not.
+# scale (float): Rescaling factor for the bounding box (augmentation).
+# rot (float): Random rotation applied to the box.
+# Returns:
+# img_patch (np.ndarray): Cropped image patch of shape (patch_height, patch_height, 3)
+# trans (np.ndarray): Transformation matrix.
+# """
+
+# img_height, img_width, img_channels = img.shape
+# if do_flip:
+# img = img[:, ::-1, :]
+# c_x = img_width - c_x - 1
+
+# trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
+
+# #img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)
+
+# # skimage
+# center = np.zeros(2)
+# center[0] = c_x
+# center[1] = c_y
+# res = np.zeros(2)
+# res[0] = patch_width
+# res[1] = patch_height
+# # assumes bb_width = bb_height
+# # assumes patch_width = patch_height
+# assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
+# assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
+# scale1 = scale*bb_width/200.
+
+# # Upper left point
+# ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
+# # Bottom right point
+# br = np.array(transform([res[0] + 1,
+# res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1
+
+# # Padding so that when rotated proper amount of context is included
+# try:
+# pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
+# except:
+# breakpoint()
+# if not rot == 0:
+# ul -= pad
+# br += pad
+
+
+# if False:
+# # Old way of cropping image
+# ul_int = ul.astype(int)
+# br_int = br.astype(int)
+# new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
+# if len(img.shape) > 2:
+# new_shape += [img.shape[2]]
+# new_img = np.zeros(new_shape)
+
+# # Range to fill new array
+# new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
+# new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
+# # Range to sample from original image
+# old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
+# old_y = max(0, ul_int[1]), min(len(img), br_int[1])
+# new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
+# old_x[0]:old_x[1]]
+
+# # New way of cropping image
+# new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)
+
+# # print(f'{new_img.shape=}')
+# # print(f'{new_img1.shape=}')
+# # print(f'{np.allclose(new_img, new_img1)=}')
+# # print(f'{img.dtype=}')
+
+
+# if not rot == 0:
+# # Remove padding
+
+# new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
+# new_img = new_img[pad:-pad, pad:-pad]
+
+# if new_img.shape[0] < 1 or new_img.shape[1] < 1:
+# print(f'{img.shape=}')
+# print(f'{new_img.shape=}')
+# print(f'{ul=}')
+# print(f'{br=}')
+# print(f'{pad=}')
+# print(f'{rot=}')
+
+# breakpoint()
+
+# # resize image
+# new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
+
+# new_img = np.clip(new_img, 0, 255).astype(np.uint8)
+
+# return new_img, trans
+
+
+def generate_image_patch_cv2(img: np.ndarray, c_x: float, c_y: float,
+ bb_width: float, bb_height: float,
+ patch_width: float, patch_height: float,
+ do_flip: bool, scale: float, rot: float,
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Crop the input image and return the crop and the corresponding transformation matrix.
+ Args:
+ img (np.ndarray): Input image of shape (H, W, 3)
+ c_x (float): Bounding box center x coordinate in the original image.
+ c_y (float): Bounding box center y coordinate in the original image.
+ bb_width (float): Bounding box width.
+ bb_height (float): Bounding box height.
+ patch_width (float): Output box width.
+ patch_height (float): Output box height.
+ do_flip (bool): Whether to flip image or not.
+ scale (float): Rescaling factor for the bounding box (augmentation).
+ rot (float): Random rotation applied to the box.
+ Returns:
+ img_patch (np.ndarray): Cropped image patch of shape (patch_height, patch_width, 3)
+ trans (np.ndarray): Transformation matrix.
+ """
+
+ img_height, img_width, img_channels = img.shape
+ if do_flip:
+ img = img[:, ::-1, :]
+ c_x = img_width - c_x - 1
+
+
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
+
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=border_mode,
+ borderValue=border_value,
+ ) # type: ignore
+ # Force borderMode=cv2.BORDER_CONSTANT for the alpha channel
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ )
+
+ return img_patch, trans
+
+
+def convert_cvimg_to_tensor(cvimg: np.ndarray):
+ """
+ Convert image from HWC to CHW format.
+ Args:
+ cvimg (np.ndarray): Image of shape (H, W, 3) as loaded by OpenCV.
+ Returns:
+ np.ndarray: Output image of shape (3, H, W).
+ """
+ # from h,w,c(OpenCV) to c,h,w
+ img = cvimg.copy()
+ img = np.transpose(img, (2, 0, 1))
+ # from int to float
+ img = img.astype(np.float32)
+ return img
+
+def fliplr_params(smpl_params: Dict, has_smpl_params: Dict) -> Tuple[Dict, Dict]:
+ """
+ Flip SMPL parameters when flipping the image.
+ Args:
+ smpl_params (Dict): SMPL parameter annotations.
+ has_smpl_params (Dict): Whether SMPL annotations are valid.
+ Returns:
+ Dict, Dict: Flipped SMPL parameters and valid flags.
+ """
+ global_orient = smpl_params['global_orient'].copy()
+ body_pose = smpl_params['body_pose'].copy()
+ betas = smpl_params['betas'].copy()
+ has_global_orient = has_smpl_params['global_orient'].copy()
+ has_body_pose = has_smpl_params['body_pose'].copy()
+ has_betas = has_smpl_params['betas'].copy()
+
+ body_pose_permutation = [6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
+ 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
+ 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
+ 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
+ 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
+ body_pose_permutation = body_pose_permutation[:len(body_pose)]
+ body_pose_permutation = [i-3 for i in body_pose_permutation]
+
+ body_pose = body_pose[body_pose_permutation]
+
+ global_orient[1::3] *= -1
+ global_orient[2::3] *= -1
+ body_pose[1::3] *= -1
+ body_pose[2::3] *= -1
+
+ smpl_params = {'global_orient': global_orient.astype(np.float32),
+ 'body_pose': body_pose.astype(np.float32),
+ 'betas': betas.astype(np.float32)
+ }
+
+ has_smpl_params = {'global_orient': has_global_orient,
+ 'body_pose': has_body_pose,
+ 'betas': has_betas
+ }
+
+ return smpl_params, has_smpl_params
+
+
+def fliplr_keypoints(joints: np.ndarray, width: float, flip_permutation: List[int]) -> np.ndarray:
+ """
+ Flip 2D or 3D keypoints.
+ Args:
+ joints (np.ndarray): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
+ width (float): Image width, used to mirror the x coordinates.
+ flip_permutation (List): Permutation to apply after flipping.
+ Returns:
+ np.ndarray: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
+ """
+ joints = joints.copy()
+ # Flip horizontal
+ joints[:, 0] = width - joints[:, 0] - 1
+ joints = joints[flip_permutation, :]
+
+ return joints
+
+def keypoint_3d_processing(keypoints_3d: np.ndarray, flip_permutation: List[int], rot: float, do_flip: bool) -> np.ndarray:
+ """
+ Process 3D keypoints (rotation/flipping).
+ Args:
+ keypoints_3d (np.ndarray): Input array of shape (N, 4) containing the 3D keypoints and confidence.
+ flip_permutation (List): Permutation to apply after flipping.
+ rot (float): Random rotation applied to the keypoints.
+ do_flip (bool): Whether to flip keypoints or not.
+ Returns:
+ np.ndarray: Transformed 3D keypoints with shape (N, 4).
+ """
+ if do_flip:
+ keypoints_3d = fliplr_keypoints(keypoints_3d, 1, flip_permutation)
+ # in-plane rotation
+ rot_mat = np.eye(3)
+ if not rot == 0:
+ rot_rad = -rot * np.pi / 180
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0,:2] = [cs, -sn]
+ rot_mat[1,:2] = [sn, cs]
+ keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
+ # cast to float32
+ keypoints_3d = keypoints_3d.astype('float32')
+ return keypoints_3d
+
+def rot_q_orient(q: np.ndarray, rot: float) -> np.ndarray:
+ """
+ Rotate SKEL orientation.
+ Args:
+ q (np.ndarray): SKEL style rotation representation (3,).
+ rot (float): Rotation angle in degrees.
+ Returns:
+ np.ndarray: Rotated SKEL-style orientation (3,).
+ """
+ import torch
+ from lib.body_models.skel.osim_rot import CustomJoint
+ from lib.utils.geometry.rotation import (
+ axis_angle_to_matrix,
+ matrix_to_euler_angles,
+ euler_angles_to_matrix,
+ )
+ # q to mat
+ q = torch.from_numpy(q).unsqueeze(0)
+ q = q[:, [2, 1, 0]]
+ Rp = euler_angles_to_matrix(q, convention="YXZ")
+ # rotate around z
+ R = torch.Tensor([[np.deg2rad(-rot), 0, 0]])
+ R = axis_angle_to_matrix(R)
+ R = torch.matmul(R, Rp)
+ # mat to q
+ q = matrix_to_euler_angles(R, convention="YXZ")
+ q = q[:, [2, 1, 0]]
+ q = q.numpy().squeeze()
+
+ return q.astype(np.float32)
+
+
+def rot_aa(aa: np.ndarray, rot: float) -> np.ndarray:
+ """
+ Rotate axis angle parameters.
+ Args:
+ aa (np.ndarray): Axis-angle vector of shape (3,).
+ rot (float): Rotation angle in degrees.
+ Returns:
+ np.ndarray: Rotated axis-angle vector.
+ """
+ # pose parameters
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
+ [0, 0, 1]])
+ # find the rotation of the body in camera frame
+ per_rdg, _ = cv2.Rodrigues(aa)
+ # apply the global rotation to the global orientation
+ resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
+ aa = (resrot.T)[0]
+ return aa.astype(np.float32)
+
+def smpl_param_processing(smpl_params: Dict, has_smpl_params: Dict, rot: float, do_flip: bool, do_aug: bool) -> Tuple[Dict, Dict]:
+ """
+ Apply random augmentations to the SMPL parameters.
+ Args:
+ smpl_params (Dict): SMPL parameter annotations.
+ has_smpl_params (Dict): Whether SMPL annotations are valid.
+ rot (float): Random rotation applied to the keypoints.
+ do_flip (bool): Whether to flip keypoints or not.
+ do_aug (bool): Whether augmentation is enabled (selects the flipped/original parameter set and the SKEL-style rotation handling).
+ Returns:
+ Dict, Dict: Transformed SMPL parameters and valid flags.
+ """
+ if do_aug:
+ if do_flip:
+ smpl_params, has_smpl_params = {k:v[1] for k, v in smpl_params.items()}, {k:v[1] for k, v in has_smpl_params.items()}
+ else:
+ smpl_params, has_smpl_params = {k:v[0] for k, v in smpl_params.items()}, {k:v[0] for k, v in has_smpl_params.items()}
+ if do_aug:
+ smpl_params['global_orient'] = rot_q_orient(smpl_params['global_orient'], rot)
+ else:
+ smpl_params['global_orient'] = rot_aa(smpl_params['global_orient'], rot)
+ return smpl_params, has_smpl_params
+
+
+
+def get_example(img_path: Union[str, np.ndarray], center_x: float, center_y: float,
+ width: float, height: float,
+ keypoints_2d: np.ndarray, keypoints_3d: np.ndarray,
+ smpl_params: Dict, has_smpl_params: Dict,
+ flip_kp_permutation: List[int],
+ patch_width: int, patch_height: int,
+ mean: np.ndarray, std: np.ndarray,
+ do_augment: bool, augm_config: CfgNode,
+ is_bgr: bool = True,
+ use_skimage_antialias: bool = False,
+ border_mode: int = cv2.BORDER_CONSTANT,
+ return_trans: bool = False) -> Tuple:
+ """
+ Get an example from the dataset and (possibly) apply random augmentations.
+ Args:
+ img_path (str or np.ndarray): Image filename or an already-loaded image array.
+ center_x (float): Bounding box center x coordinate in the original image.
+ center_y (float): Bounding box center y coordinate in the original image.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
+ keypoints_3d (np.ndarray): Array with shape (N,4) containing the 3D keypoints.
+ smpl_params (Dict): SMPL parameter annotations.
+ has_smpl_params (Dict): Whether SMPL annotations are valid.
+ flip_kp_permutation (List): Permutation to apply to the keypoints after flipping.
+ patch_width (float): Output box width.
+ patch_height (float): Output box height.
+ mean (np.ndarray): Array of shape (3,) containing the mean for normalizing the input image.
+ std (np.ndarray): Array of shape (3,) containing the std for normalizing the input image.
+ do_augment (bool): Whether to apply data augmentation or not.
+ aug_config (CfgNode): Config containing augmentation parameters.
+ Returns:
+ img_patch (np.ndarray): Cropped image patch of shape (3, patch_height, patch_width).
+ keypoints_2d (np.ndarray): Array with shape (N,3) containing the transformed 2D keypoints.
+ keypoints_3d (np.ndarray): Array with shape (N,4) containing the transformed 3D keypoints.
+ smpl_params (Dict): Transformed SMPL parameters.
+ has_smpl_params (Dict): Valid flag for transformed SMPL parameters.
+ img_size (np.ndarray): Size of the original image.
+ trans (np.ndarray): Crop transformation matrix (only if return_trans is True).
+ augm_record (Dict): Record of the applied augmentation (flip / rotation).
+ """
+ if isinstance(img_path, str):
+ # 1. load image
+ cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
+ if not isinstance(cvimg, np.ndarray):
+ raise IOError("Fail to read %s" % img_path)
+ elif isinstance(img_path, np.ndarray):
+ cvimg = img_path
+ else:
+ raise TypeError('img_path must be either a string or a numpy array')
+ img_height, img_width, img_channels = cvimg.shape
+
+ img_size = np.array([img_height, img_width])
+
+ # 2. get augmentation params
+ if do_augment:
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
+ else:
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0.
+
+ if width < 1 or height < 1:
+ breakpoint()
+
+ if do_extreme_crop:
+ if extreme_crop_lvl == 0:
+ center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
+ elif extreme_crop_lvl == 1:
+ center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d)
+
+ THRESH = 4
+ if width1 < THRESH or height1 < THRESH:
+ # print(f'{do_extreme_crop=}')
+ # print(f'width: {width}, height: {height}')
+ # print(f'width1: {width1}, height1: {height1}')
+ # print(f'center_x: {center_x}, center_y: {center_y}')
+ # print(f'center_x1: {center_x1}, center_y1: {center_y1}')
+ # print(f'keypoints_2d: {keypoints_2d}')
+ # print(f'\n\n', flush=True)
+ # breakpoint()
+ pass
+ # print(f'skip ==> width1: {width1}, height1: {height1}, width: {width}, height: {height}')
+ else:
+ center_x, center_y, width, height = center_x1, center_y1, width1, height1
+
+ center_x += width * tx
+ center_y += height * ty
+
+ # Process 3D keypoints
+ keypoints_3d = keypoint_3d_processing(keypoints_3d, flip_kp_permutation, rot, do_flip)
+
+ # 3. generate image patch
+ # if use_skimage_antialias:
+ # # Blur image to avoid aliasing artifacts
+ # downsampling_factor = (patch_width / (width*scale))
+ # if downsampling_factor > 1.1:
+ # cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True, truncate=3.0)
+
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
+ center_x, center_y,
+ width, height,
+ patch_width, patch_height,
+ do_flip, scale, rot,
+ border_mode=border_mode)
+ # img_patch_cv, trans = generate_image_patch_skimage(cvimg,
+ # center_x, center_y,
+ # width, height,
+ # patch_width, patch_height,
+ # do_flip, scale, rot,
+ # border_mode=border_mode)
+
+ image = img_patch_cv.copy()
+ if is_bgr:
+ image = image[:, :, ::-1]
+ img_patch_cv = image.copy()
+ img_patch = convert_cvimg_to_tensor(image)
+
+
+ smpl_params, has_smpl_params = smpl_param_processing(smpl_params, has_smpl_params, rot, do_flip, do_augment)
+
+ # apply normalization
+ for n_c in range(min(img_channels, 3)):
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
+ if mean is not None and std is not None:
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
+ if do_flip:
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, flip_kp_permutation)
+
+
+ for n_jt in range(len(keypoints_2d)):
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
+
+
+ augm_record = {
+ 'do_flip' : do_flip,
+ 'rot' : rot,
+ }
+
+ if not return_trans:
+ return img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, augm_record
+ else:
+ return img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, trans, augm_record
+
+def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple:
+ """
+ Extreme cropping: Crop the box up to the hip locations.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+
+def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box up to the shoulder locations.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.2 * scale[0]
+ height = 1.2 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the head.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.3 * scale[0]
+ height = 1.3 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the torso.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]]
+ keypoints_2d[nontorso_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the right arm.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonrightarm_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the left arm.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonleftarm_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the legs.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
+ keypoints_2d[nonlegs_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the right leg.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonrightleg_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray):
+ """
+ Extreme cropping: Crop the box and keep only the left leg.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonleftleg_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def full_body(keypoints_2d: np.ndarray) -> bool:
+ """
+ Check if all main body joints are visible.
+ Args:
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ bool: True if all main body joints are visible.
+ """
+
+ body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
+ body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
+ return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints)
+
+def upper_body(keypoints_2d: np.ndarray):
+ """
+ Check if only the upper body is visible.
+ Args:
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ bool: True if no lower-body joints are visible and at least two upper-body joints are visible.
+ """
+ lower_body_keypoints_openpose = [10, 11, 13, 14]
+ lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
+ upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
+ upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18]
+ return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\
+ and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
+
+def get_bbox(keypoints_2d: np.ndarray, rescale: float = 1.2) -> Tuple:
+ """
+ Get center and scale for bounding box from openpose detections.
+ Args:
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
+ Returns:
+ center (np.ndarray): Array of shape (2,) containing the new bounding box center.
+ scale (np.ndarray): New bounding box size of shape (2,) (width, height).
+ """
+ valid = keypoints_2d[:,-1] > 0
+ valid_keypoints = keypoints_2d[valid][:,:-1]
+ center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
+ bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
+ # adjust bounding box tightness
+ scale = bbox_size
+ scale *= rescale
+ return center, scale
+
+def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple:
+ """
+ Perform extreme cropping
+ Args:
+ center_x (float): x coordinate of bounding box center.
+ center_y (float): y coordinate of bounding box center.
+ width (float): bounding box width.
+ height (float): bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of bounding box center.
+ center_y (float): y coordinate of bounding box center.
+ width (float): bounding box width.
+ height (float): bounding box height.
+ """
+ p = torch.rand(1).item()
+ if full_body(keypoints_2d):
+ if p < 0.7:
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.9:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+ elif upper_body(keypoints_2d):
+ if p < 0.9:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+
+ return center_x, center_y, max(width, height), max(width, height)
+
+def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple:
+ """
+ Perform aggressive extreme cropping
+ Args:
+ center_x (float): x coordinate of bounding box center.
+ center_y (float): y coordinate of bounding box center.
+ width (float): bounding box width.
+ height (float): bounding box height.
+ keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of bounding box center.
+ center_y (float): y coordinate of bounding box center.
+ width (float): bounding box width.
+ height (float): bounding box height.
+ """
+ p = torch.rand(1).item()
+ if full_body(keypoints_2d):
+ if p < 0.2:
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.3:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.4:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.5:
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.6:
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.7:
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.8:
+ center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.9:
+ center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
+ elif upper_body(keypoints_2d):
+ if p < 0.2:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.4:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.6:
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.8:
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
+ return center_x, center_y, max(width, height), max(width, height)
diff --git a/lib/data/datasets/skel_hmr2_fashion/vitdet_dataset.py b/lib/data/datasets/skel_hmr2_fashion/vitdet_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c9d9d4d01259884ae8438e859bb04a1d141a07d
--- /dev/null
+++ b/lib/data/datasets/skel_hmr2_fashion/vitdet_dataset.py
@@ -0,0 +1,89 @@
+from typing import Dict
+
+import cv2
+import numpy as np
+from skimage.filters import gaussian
+from yacs.config import CfgNode
+import torch
+
+from .utils import (convert_cvimg_to_tensor,
+ expand_to_aspect_ratio,
+ generate_image_patch_cv2)
+
+DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
+DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
+
+class ViTDetDataset(torch.utils.data.Dataset):
+
+ def __init__(self,
+ cfg: CfgNode,
+ img_cv2: np.ndarray,
+ boxes: np.ndarray,
+ train: bool = False,
+ **kwargs):
+ super().__init__()
+ self.cfg = cfg
+ self.img_cv2 = img_cv2
+ # self.boxes = boxes
+
+ assert not train, "ViTDetDataset is only for inference"
+ self.train = train
+ self.img_size = cfg.MODEL.IMAGE_SIZE
+ self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
+ self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
+
+ # Preprocess annotations
+ boxes = boxes.astype(np.float32)
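+ # Boxes are (x1, y1, x2, y2); scale stores the box size divided by 200, following the SPIN/HMR convention.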
+ self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
+ self.scale = (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
+ self.personid = np.arange(len(boxes), dtype=np.int32)
+
+ def __len__(self) -> int:
+ return len(self.personid)
+
+ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
+
+ center = self.center[idx].copy()
+ center_x = center[0]
+ center_y = center[1]
+
+ scale = self.scale[idx]
+ BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
+ bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
+
+ patch_width = patch_height = self.img_size
+
+ # 3. generate image patch
+ # if use_skimage_antialias:
+ cvimg = self.img_cv2.copy()
+ if True:
+ # Blur image to avoid aliasing artifacts
+ downsampling_factor = ((bbox_size*1.0) / patch_width)
+ print(f'{downsampling_factor=}')
+ downsampling_factor = downsampling_factor / 2.0
+ if downsampling_factor > 1.1:
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True)
+
+
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
+ center_x, center_y,
+ bbox_size, bbox_size,
+ patch_width, patch_height,
+ False, 1.0, 0,
+ border_mode=cv2.BORDER_CONSTANT)
+ img_patch_cv = img_patch_cv[:, :, ::-1]
+ img_patch = convert_cvimg_to_tensor(img_patch_cv)
+
+ # apply normalization
+ for n_c in range(min(self.img_cv2.shape[2], 3)):
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
+
+ item = {
+ 'img_full': cvimg,
+ 'img': img_patch,
+ 'personid': int(self.personid[idx]),
+ }
+ item['box_center'] = self.center[idx].copy()
+ item['box_size'] = bbox_size
+ item['img_size'] = 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]])
+ return item
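+
+# Usage sketch (illustrative; the detector that produces `boxes` is outside this file):
+# dataset = ViTDetDataset(model_cfg, img_bgr, boxes) # boxes: (N, 4) in xyxy format
+# loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
+# batch = next(iter(loader)) # batch['img']: (B, 3, H, W) normalized crops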
diff --git a/lib/data/modules/hmr2_fashion/README.md b/lib/data/modules/hmr2_fashion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8c41d3b66efe7223d3f410db8e980e2033738ad7
--- /dev/null
+++ b/lib/data/modules/hmr2_fashion/README.md
@@ -0,0 +1,3 @@
+# HMR2 fashion dataset
+
+This folder contains a customized implementation that adapts HMR2-style datasets (e.g. the `ImageDataset` class).
\ No newline at end of file
diff --git a/lib/data/modules/hmr2_fashion/skel_wds.py b/lib/data/modules/hmr2_fashion/skel_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..27155e2223323e7e8f877ae04865849e73726874
--- /dev/null
+++ b/lib/data/modules/hmr2_fashion/skel_wds.py
@@ -0,0 +1,106 @@
+from lib.kits.basic import *
+
+import webdataset as wds
+
+
+from lib.data.datasets.skel_hmr2_fashion.image_dataset import ImageDataset
+
+
+class MixedWebDataset(wds.WebDataset):
+ def __init__(self) -> None:
+ super(wds.WebDataset, self).__init__()
+
+
+class DataModule(pl.LightningDataModule):
+ def __init__(self, name:str, cfg:DictConfig):
+ super().__init__()
+ self.name = name
+ self.cfg = cfg
+ self.cfg_eval = self.cfg.pop('eval', None)
+ self.cfg_train = self.cfg.pop('train', None)
+
+
+ def setup(self, stage=None):
+ if stage in ['test', None, '_debug_eval']:
+ self._setup_eval()
+
+ if stage in ['fit', None, '_debug_train']:
+ self._setup_train()
+
+
+ def train_dataloader(self):
+ return torch.utils.data.DataLoader(
+ dataset = self.train_dataset,
+ **self.cfg_train.dataloader,
+ )
+
+
+ def val_dataloader(self):
+        # No separate validation set is needed here, so just reuse the test dataloader.
+ return self.test_dataloader()
+
+
+ def test_dataloader(self):
+ # return torch.utils.data.DataLoader(
+ # dataset = self.eval_datasets['LSP-EXTENDED'], # TODO: Support multiple datasets through ConcatDataset (but to figure out how to mix with weights)
+ # **self.cfg_eval.dataloader,
+ # )
+ return torch.utils.data.DataLoader(
+ dataset = self.eval_datasets, # TODO: Support multiple datasets through ConcatDataset (but to figure out how to mix with weights)
+ **self.cfg_eval.dataloader,
+ )
+
+
+ # ========== Internal Modules to Setup Datasets ==========
+
+ def _setup_train(self):
+ hack_cfg = {
+ 'IMAGE_SIZE': self.cfg.policy.img_patch_size,
+ 'IMAGE_MEAN': self.cfg.policy.img_mean,
+ 'IMAGE_STD' : self.cfg.policy.img_std,
+ 'BBOX_SHAPE': None,
+ 'augm': self.cfg.augm,
+ }
+
+ self.train_datasets = [] # [(dataset:Dataset, weight:float), ...]
+ datasets, weights = [], []
+ opt = self.cfg_train.get('shared_ds_opt', {})
+ for dataset_cfg in self.cfg_train.datasets:
+ cur_cfg = {**hack_cfg, **opt}
+ dataset = ImageDataset.load_tars_as_webdataset(
+ cfg = cur_cfg,
+ urls = dataset_cfg.item.urls,
+ train = True,
+ epoch_size = dataset_cfg.item.epoch_size,
+ )
+ weights.append(dataset_cfg.weight)
+ datasets.append(dataset)
+ weights = to_numpy(weights)
+ weights = weights / weights.sum()
+ self.train_dataset = MixedWebDataset()
+ self.train_dataset.append(wds.RandomMix(datasets, weights, longest=False))
+ self.train_dataset = self.train_dataset.with_epoch(100_000).shuffle(4000)
+
+
+ def _setup_eval(self):
+ hack_cfg = {
+ 'IMAGE_SIZE' : self.cfg.policy.img_patch_size,
+ 'IMAGE_MEAN' : self.cfg.policy.img_mean,
+ 'IMAGE_STD' : self.cfg.policy.img_std,
+ 'BBOX_SHAPE' : [192, 256],
+ 'augm' : self.cfg.augm,
+ }
+
+ self.eval_datasets = {}
+ opt = self.cfg_train.get('shared_ds_opt', {})
+ for dataset_cfg in self.cfg_eval.datasets:
+ cur_cfg = {**hack_cfg, **opt}
+ dataset = ImageDataset(
+                cfg          = cur_cfg,
+ dataset_file = dataset_cfg.item.dataset_file,
+ img_dir = dataset_cfg.item.img_root,
+ train = False,
+ )
+ dataset._kp_list_ = dataset_cfg.item.kp_list
+ self.eval_datasets[dataset_cfg.name] = dataset
+
diff --git a/lib/data/modules/hsmr_v1/data_module.py b/lib/data/modules/hsmr_v1/data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..6616683879052d01c77dc3d9b00424553c7cc0f8
--- /dev/null
+++ b/lib/data/modules/hsmr_v1/data_module.py
@@ -0,0 +1,117 @@
+from lib.kits.basic import *
+
+from lib.data.datasets.hsmr_v1.mocap_dataset import MoCapDataset
+from lib.data.datasets.hsmr_v1.wds_loader import load_tars_as_wds
+import webdataset as wds
+
+class MixedWebDataset(wds.WebDataset):
+ def __init__(self) -> None:
+ super(wds.WebDataset, self).__init__()
+
+
+class DataModule(pl.LightningDataModule):
+ def __init__(self, name, cfg):
+ super().__init__()
+ self.name = name
+ self.cfg = cfg
+ self.cfg_eval = self.cfg.pop('eval', None)
+ self.cfg_train = self.cfg.pop('train', None)
+ self.cfg_mocap = self.cfg.pop('mocap', None)
+
+
+ def setup(self, stage=None):
+ if stage in ['test', None, '_debug_eval'] and self.cfg_eval is not None:
+ # get_logger().info('Evaluation dataset will be enabled.')
+ self._setup_eval()
+
+ if stage in ['fit', None, '_debug_train'] and self.cfg_train is not None:
+ # get_logger().info('Training dataset will be enabled.')
+ self._setup_train()
+
+ if stage in ['fit', None, '_debug_mocap'] and self.cfg_mocap is not None:
+ # get_logger().info('Mocap dataset will be enabled.')
+ self._setup_mocap()
+
+
+ def train_dataloader(self):
+ img_dataset = torch.utils.data.DataLoader(
+ dataset = self.train_dataset,
+ **self.cfg_train.dataloader,
+ )
+ ret = {'img_ds' : img_dataset }
+
+ if self.cfg_mocap is not None:
+ mocap_dataset = torch.utils.data.DataLoader(
+ dataset = self.mocap_dataset,
+ **self.cfg_mocap.dataloader,
+ )
+ ret['mocap_ds'] = mocap_dataset
+
+ return ret
+
+
+
+ # ========== Internal Modules to Setup Datasets ==========
+
+
+ def _setup_train(self):
+ names, datasets, weights = [], [], []
+ ld_cfg = self.cfg_train.cfg # cfg for initializing wds loading pipeline
+
+ for ds_cfg in self.cfg_train.datasets:
+ dataset = load_tars_as_wds(
+ ld_cfg,
+ ds_cfg.item.urls,
+ ds_cfg.item.epoch_size
+ )
+
+ names.append(ds_cfg.name)
+ datasets.append(dataset)
+ weights.append(ds_cfg.weight)
+
+ # get_logger().info(f"Dataset '{ds_cfg.name}' loaded.")
+
+ # Normalize the weights and mix the datasets.
+ weights = to_numpy(weights)
+ weights = weights / weights.sum()
+ self.train_datasets = datasets
+ self.train_dataset = MixedWebDataset()
+ self.train_dataset.append(wds.RandomMix(datasets, weights))
+ self.train_dataset = self.train_dataset.with_epoch(50_000).shuffle(1000, initial=1000)
+
+
+ def _setup_mocap(self):
+ self.mocap_dataset = MoCapDataset(**self.cfg_mocap.cfg)
+
+
+ def _setup_eval(self, selected_ds_names:Optional[List[str]]=None):
+ from lib.data.datasets.skel_hmr2_fashion.image_dataset import ImageDataset
+ hack_cfg = {
+ 'IMAGE_SIZE' : 256,
+ 'IMAGE_MEAN' : [0.485, 0.456, 0.406],
+ 'IMAGE_STD' : [0.229, 0.224, 0.225],
+ 'BBOX_SHAPE' : [192, 256],
+ 'augm' : self.cfg.image_augmentation,
+ 'SUPPRESS_KP_CONF_THRESH' : 0.3,
+ 'FILTER_NUM_KP' : 4,
+ 'FILTER_NUM_KP_THRESH' : 0.0,
+ 'FILTER_REPROJ_THRESH' : 31000,
+ 'SUPPRESS_BETAS_THRESH' : 3.0,
+ 'SUPPRESS_BAD_POSES' : False,
+ 'POSES_BETAS_SIMULTANEOUS': True,
+ 'FILTER_NO_POSES' : False,
+ 'BETAS_REG' : True,
+ }
+
+ self.eval_datasets = {}
+ for dataset_cfg in self.cfg_eval.datasets:
+ if selected_ds_names is not None and dataset_cfg.name not in selected_ds_names:
+ continue
+ dataset = ImageDataset(
+ cfg = hack_cfg,
+ dataset_file = dataset_cfg.item.dataset_file,
+ img_dir = dataset_cfg.item.img_root,
+ train = False,
+ )
+ dataset._kp_list_ = dataset_cfg.item.kp_list
+ self.eval_datasets[dataset_cfg.name] = dataset
diff --git a/lib/evaluation/__init__.py b/lib/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb203b2c16637c3993b3e6cd2ad68164a2196370
--- /dev/null
+++ b/lib/evaluation/__init__.py
@@ -0,0 +1,93 @@
+from .metrics import *
+from .evaluators import EvaluatorBase
+
+from lib.body_models.smpl_utils.reality import eval_rot_delta as smpl_eval_rot_delta
+
+class HSMR3DEvaluator(EvaluatorBase):
+
+ def eval(self, **kwargs):
+ # Get the predictions and ground truths.
+ pd, gt = kwargs['pd'], kwargs['gt']
+ v3d_pd, v3d_gt = pd['v3d_pose'], gt['v3d_pose']
+ j3d_pd, j3d_gt = pd['j3d_pose'], gt['j3d_pose']
+
+ # Compute the metrics.
+ mpjpe = eval_MPxE(j3d_pd, j3d_gt)
+ pa_mpjpe = eval_PA_MPxE(j3d_pd, j3d_gt)
+ mpve = eval_MPxE(v3d_pd, v3d_gt)
+ pa_mpve = eval_PA_MPxE(v3d_pd, v3d_gt)
+
+ # Append the results.
+ self.accumulator['MPJPE'].append(mpjpe.detach().cpu())
+ self.accumulator['PA-MPJPE'].append(pa_mpjpe.detach().cpu())
+ self.accumulator['MPVE'].append(mpve.detach().cpu())
+ self.accumulator['PA-MPVE'].append(pa_mpve.detach().cpu())
+
+
+
+class SMPLRealityEvaluator(EvaluatorBase):
+
+ def eval(self, **kwargs):
+ # Get the predictions and ground truths.
+ pd = kwargs['pd']
+ body_pose = pd['body_pose'] # (..., 23, 3)
+
+ # Compute the metrics.
+ vio_d0 = smpl_eval_rot_delta(body_pose, tol_deg=0 ) # {k: (N, 3)}
+ vio_d5 = smpl_eval_rot_delta(body_pose, tol_deg=5 ) # {k: (N, 3)}
+ vio_d10 = smpl_eval_rot_delta(body_pose, tol_deg=10) # {k: (N, 3)}
+ vio_d20 = smpl_eval_rot_delta(body_pose, tol_deg=20) # {k: (N, 3)}
+ vio_d30 = smpl_eval_rot_delta(body_pose, tol_deg=30) # {k: (N, 3)}
+
+ # Append the results.
+ parts = vio_d5.keys()
+ for part in parts:
+ self.accumulator[f'VD0_{part}' ].append(vio_d0[part].max(-1)[0].detach().cpu()) # (N,)
+ self.accumulator[f'VD5_{part}' ].append(vio_d5[part].max(-1)[0].detach().cpu()) # (N,)
+ self.accumulator[f'VD10_{part}'].append(vio_d10[part].max(-1)[0].detach().cpu()) # (N,)
+ self.accumulator[f'VD20_{part}'].append(vio_d20[part].max(-1)[0].detach().cpu()) # (N,)
+ self.accumulator[f'VD30_{part}'].append(vio_d30[part].max(-1)[0].detach().cpu()) # (N,)
+
+ def get_results(self, chosen_metric=None):
+ ''' Get the current mean results. '''
+ # Only chosen metrics will be compacted and returned.
+ compacted = self._compact_accumulator(chosen_metric)
+ ret = {}
+ for k, v in compacted.items():
+ vio_max = v.max()
+ vio_mean = v.mean()
+ vio_median = v.median()
+ tot_cnt = len(v)
+ vio_cnt = (v > 0).float().sum()
+ vio_p = vio_cnt / tot_cnt
+ ret[f'{k}_max'] = vio_max.item()
+ ret[f'{k}_mean'] = vio_mean.item()
+ ret[f'{k}_median'] = vio_median.item()
+ ret[f'{k}_percentage'] = vio_p.item()
+ return ret
+
+
+from lib.body_models.skel_utils.reality import eval_rot_delta as skel_eval_rot_delta
+
+class SKELRealityEvaluator(SMPLRealityEvaluator):
+
+ def eval(self, **kwargs):
+ # Get the predictions and ground truths.
+ pd = kwargs['pd']
+ poses = pd['poses'] # (..., 46)
+
+ # Compute the metrics.
+ vio_d0 = skel_eval_rot_delta(poses, tol_deg=0 ) # {k: (N, 3)}
+ vio_d5 = skel_eval_rot_delta(poses, tol_deg=5 ) # {k: (N, 3)}
+ vio_d10 = skel_eval_rot_delta(poses, tol_deg=10) # {k: (N, 3)}
+ vio_d20 = skel_eval_rot_delta(poses, tol_deg=20) # {k: (N, 3)}
+ vio_d30 = skel_eval_rot_delta(poses, tol_deg=30) # {k: (N, 3)}
+
+ # Append the results.
+ parts = vio_d5.keys()
+ for part in parts:
+ self.accumulator[f'VD0_{part}' ].append(vio_d0[part].max(-1)[0].detach().cpu()) # (N,)
+ self.accumulator[f'VD5_{part}' ].append(vio_d5[part].max(-1)[0].detach().cpu()) # (N,)
+ self.accumulator[f'VD10_{part}'].append(vio_d10[part].max(-1)[0].detach().cpu()) # (N,)
+ self.accumulator[f'VD20_{part}'].append(vio_d20[part].max(-1)[0].detach().cpu()) # (N,)
+ self.accumulator[f'VD30_{part}'].append(vio_d30[part].max(-1)[0].detach().cpu()) # (N,)
diff --git a/lib/evaluation/evaluators.py b/lib/evaluation/evaluators.py
new file mode 100644
index 0000000000000000000000000000000000000000..52477f39b26334931e1f3536785b6e3ee78e9c74
--- /dev/null
+++ b/lib/evaluation/evaluators.py
@@ -0,0 +1,34 @@
+import torch
+from collections import defaultdict
+
+from .metrics import *
+
+
+class EvaluatorBase():
+ ''' To use this class, you should inherit it and implement the `eval` method. '''
+ def __init__(self):
+ self.accumulator = defaultdict(list)
+
+ def eval(self, **kwargs):
+ ''' Evaluate the metrics on the data. '''
+ raise NotImplementedError
+
+ def get_results(self, chosen_metric=None):
+ ''' Get the current mean results. '''
+ # Only chosen metrics will be compacted and returned.
+ compacted = self._compact_accumulator(chosen_metric)
+ ret = {}
+ for k, v in compacted.items():
+ ret[k] = v.mean(dim=0).item()
+ return ret
+
+ def _compact_accumulator(self, chosen_metric=None):
+ ''' Compact the accumulator list and return the compacted results. '''
+ ret = {}
+ for k, v in self.accumulator.items():
+ # Only chosen metrics will be compacted.
+ if chosen_metric is None or k in chosen_metric:
+ ret[k] = torch.cat(v, dim=0)
+ self.accumulator[k] = [ret[k]]
+ return ret
+
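As the docstring above states, `EvaluatorBase` is meant to be subclassed. A minimal, hypothetical subclass (not part of the diff) could look like the following; the metric name, tensor shapes, and random inputs are made up for illustration.

```python
import torch

from lib.evaluation.evaluators import EvaluatorBase

class L2Evaluator(EvaluatorBase):
    ''' Hypothetical evaluator: mean L2 distance between two point sets. '''
    def eval(self, **kwargs):
        pd, gt = kwargs['pd'], kwargs['gt']         # (B, N, 3) point sets
        err = (pd - gt).norm(dim=-1).mean(dim=-1)   # (B,) per-sample error
        self.accumulator['L2'].append(err.detach().cpu())

evaluator = L2Evaluator()
for _ in range(3):  # pretend these are batches from a dataloader
    evaluator.eval(pd=torch.rand(4, 17, 3), gt=torch.rand(4, 17, 3))
print(evaluator.get_results())  # e.g. {'L2': 0.57}
```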
diff --git a/lib/evaluation/hmr2_utils/__init__.py b/lib/evaluation/hmr2_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..62a16b522f818b199b0146faff443752d9221093
--- /dev/null
+++ b/lib/evaluation/hmr2_utils/__init__.py
@@ -0,0 +1,314 @@
+"""
+Code adapted from: https://github.com/akanazawa/hmr/blob/master/src/benchmark/eval_util.py
+"""
+
+import torch
+import numpy as np
+from typing import Optional, Dict, List, Tuple
+
+def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
+ """
+ Computes a similarity transform (sR, t) in a batched way that takes
+ a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3),
+ where R is a 3x3 rotation matrix, t 3x1 translation, s scale.
+    i.e., it solves the orthogonal Procrustes problem.
+ Args:
+ S1 (torch.Tensor): First set of points of shape (B, N, 3).
+ S2 (torch.Tensor): Second set of points of shape (B, N, 3).
+ Returns:
+ (torch.Tensor): The first set of points after applying the similarity transformation.
+ """
+
+ batch_size = S1.shape[0]
+ S1 = S1.permute(0, 2, 1)
+ S2 = S2.permute(0, 2, 1)
+ # 1. Remove mean.
+ mu1 = S1.mean(dim=2, keepdim=True)
+ mu2 = S2.mean(dim=2, keepdim=True)
+ X1 = S1 - mu1
+ X2 = S2 - mu2
+
+ # 2. Compute variance of X1 used for scale.
+ var1 = (X1**2).sum(dim=(1,2))
+
+ # 3. The outer product of X1 and X2.
+ K = torch.matmul(X1, X2.permute(0, 2, 1))
+
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
+ U, s, V = torch.svd(K)
+ Vh = V.permute(0, 2, 1)
+
+ # Construct Z that fixes the orientation of R to get det(R)=1.
+ Z = torch.eye(U.shape[1], device=U.device).unsqueeze(0).repeat(batch_size, 1, 1)
+ Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U, Vh)))
+
+ # Construct R.
+ R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1))
+
+ # 5. Recover scale.
+ trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1)
+ scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1)
+
+ # 6. Recover translation.
+ t = mu2 - scale*torch.matmul(R, mu1)
+
+ # 7. Error:
+ S1_hat = scale*torch.matmul(R, S1) + t
+
+ return S1_hat.permute(0, 2, 1)
+
+def reconstruction_error(S1, S2) -> np.array:
+ """
+ Computes the mean Euclidean distance of 2 set of points S1, S2 after performing Procrustes alignment.
+ Args:
+ S1 (torch.Tensor): First set of points of shape (B, N, 3).
+ S2 (torch.Tensor): Second set of points of shape (B, N, 3).
+ Returns:
+ (np.array): Reconstruction error.
+ """
+ S1_hat = compute_similarity_transform(S1, S2)
+ re = torch.sqrt( ((S1_hat - S2)** 2).sum(dim=-1)).mean(dim=-1)
+ return re
+
+def eval_pose(pred_joints, gt_joints) -> Tuple[np.array, np.array]:
+ """
+ Compute joint errors in mm before and after Procrustes alignment.
+ Args:
+ pred_joints (torch.Tensor): Predicted 3D joints of shape (B, N, 3).
+ gt_joints (torch.Tensor): Ground truth 3D joints of shape (B, N, 3).
+ Returns:
+ Tuple[np.array, np.array]: Joint errors in mm before and after alignment.
+ """
+ # Absolute error (MPJPE)
+ mpjpe = torch.sqrt(((pred_joints - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
+
+ # Reconstruction_error
+ r_error = reconstruction_error(pred_joints, gt_joints).cpu().numpy()
+ return 1000 * mpjpe, 1000 * r_error
+
+class Evaluator:
+
+ def __init__(self,
+ dataset_length: int,
+ keypoint_list: List,
+ pelvis_ind: int,
+ metrics: List = ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re'],
+ pck_thresholds: Optional[List] = None):
+ """
+ Class used for evaluating trained models on different 3D pose datasets.
+ Args:
+ dataset_length (int): Total dataset length.
+ keypoint_list [List]: List of keypoints used for evaluation.
+ pelvis_ind (int): Index of pelvis keypoint; used for aligning the predictions and ground truth.
+ metrics [List]: List of evaluation metrics to record.
+ """
+ self.dataset_length = dataset_length
+ self.keypoint_list = keypoint_list
+ self.pelvis_ind = pelvis_ind
+ self.metrics = metrics
+ for metric in self.metrics:
+ setattr(self, metric, np.zeros((dataset_length,)))
+ self.counter = 0
+ if pck_thresholds is None:
+ self.pck_evaluator = None
+ else:
+ self.pck_evaluator = EvaluatorPCK(pck_thresholds)
+
+ def log(self):
+ """
+ Print current evaluation metrics
+ """
+ if self.counter == 0:
+ print('Evaluation has not started')
+ return
+ print(f'{self.counter} / {self.dataset_length} samples')
+ if self.pck_evaluator is not None:
+ self.pck_evaluator.log()
+ for metric in self.metrics:
+ if metric in ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re']:
+ unit = 'mm'
+ else:
+ unit = ''
+ print(f'{metric}: {getattr(self, metric)[:self.counter].mean()} {unit}')
+ print('***')
+
+ def get_metrics_dict(self) -> Dict:
+ """
+ Returns:
+ Dict: Dictionary of evaluation metrics.
+ """
+ d1 = {metric: getattr(self, metric)[:self.counter].mean().item() for metric in self.metrics}
+ if self.pck_evaluator is not None:
+ d2 = self.pck_evaluator.get_metrics_dict()
+ d1.update(d2)
+ return d1
+
+ def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None):
+ """
+ Evaluate current batch.
+ Args:
+ output (Dict): Regression output.
+ batch (Dict): Dictionary containing images and their corresponding annotations.
+ opt_output (Dict): Optimization output.
+ """
+ if self.pck_evaluator is not None:
+ self.pck_evaluator(output, batch, opt_output)
+
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach()
+ pred_keypoints_3d = pred_keypoints_3d[:,None,:,:]
+ batch_size = pred_keypoints_3d.shape[0]
+ num_samples = pred_keypoints_3d.shape[1]
+ gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1)
+
+ # Align predictions and ground truth such that the pelvis location is at the origin
+ pred_keypoints_3d -= pred_keypoints_3d[:, :, [self.pelvis_ind]]
+ gt_keypoints_3d -= gt_keypoints_3d[:, :, [self.pelvis_ind]]
+
+ # Compute joint errors
+ mpjpe, re = eval_pose(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3)[:, self.keypoint_list], gt_keypoints_3d.reshape(batch_size * num_samples, -1 ,3)[:, self.keypoint_list])
+ mpjpe = mpjpe.reshape(batch_size, num_samples)
+ re = re.reshape(batch_size, num_samples)
+
+ # Compute 2d keypoint errors
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach()
+ pred_keypoints_2d = pred_keypoints_2d[:,None,:,:]
+ gt_keypoints_2d = batch['keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1)
+ conf = gt_keypoints_2d[:, :, :, -1].clone()
+ kp_err = torch.nn.functional.mse_loss(
+ pred_keypoints_2d,
+ gt_keypoints_2d[:, :, :, :-1],
+ reduction='none'
+ ).sum(dim=3)
+ kp_l2_loss = (conf * kp_err).mean(dim=2)
+ kp_l2_loss = kp_l2_loss.detach().cpu().numpy()
+
+ # Compute joint errors after optimization, if available.
+ if opt_output is not None:
+ opt_keypoints_3d = opt_output['model_joints']
+ opt_keypoints_3d -= opt_keypoints_3d[:, [self.pelvis_ind]]
+ opt_mpjpe, opt_re = eval_pose(opt_keypoints_3d[:, self.keypoint_list], gt_keypoints_3d[:, 0, self.keypoint_list])
+
+ # The 0-th sample always corresponds to the mode
+ if hasattr(self, 'mode_mpjpe'):
+ mode_mpjpe = mpjpe[:, 0]
+ self.mode_mpjpe[self.counter:self.counter+batch_size] = mode_mpjpe
+ if hasattr(self, 'mode_re'):
+ mode_re = re[:, 0]
+ self.mode_re[self.counter:self.counter+batch_size] = mode_re
+ if hasattr(self, 'mode_kpl2'):
+ mode_kpl2 = kp_l2_loss[:, 0]
+ self.mode_kpl2[self.counter:self.counter+batch_size] = mode_kpl2
+ if hasattr(self, 'min_mpjpe'):
+ min_mpjpe = mpjpe.min(axis=-1)
+ self.min_mpjpe[self.counter:self.counter+batch_size] = min_mpjpe
+ if hasattr(self, 'min_re'):
+ min_re = re.min(axis=-1)
+ self.min_re[self.counter:self.counter+batch_size] = min_re
+ if hasattr(self, 'min_kpl2'):
+ min_kpl2 = kp_l2_loss.min(axis=-1)
+ self.min_kpl2[self.counter:self.counter+batch_size] = min_kpl2
+ if hasattr(self, 'opt_mpjpe'):
+ self.opt_mpjpe[self.counter:self.counter+batch_size] = opt_mpjpe
+ if hasattr(self, 'opt_re'):
+ self.opt_re[self.counter:self.counter+batch_size] = opt_re
+
+ self.counter += batch_size
+
+ if hasattr(self, 'mode_mpjpe') and hasattr(self, 'mode_re'):
+ return {
+ 'mode_mpjpe': mode_mpjpe,
+ 'mode_re': mode_re,
+ }
+ if hasattr(self, 'mode_kpl2'):
+ return {
+ 'mode_kpl2': mode_kpl2,
+ }
+ else:
+ return {}
+
+
+class EvaluatorPCK:
+
+ def __init__(self, thresholds: List = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5],):
+ """
+ Class used for evaluating trained models on different 3D pose datasets.
+ Args:
+ thresholds [List]: List of PCK thresholds to evaluate.
+ metrics [List]: List of evaluation metrics to record.
+ """
+ self.thresholds = thresholds
+ self.pred_kp_2d = []
+ self.gt_kp_2d = []
+ self.gt_conf_2d = []
+ self.counter = 0
+
+ def log(self):
+ """
+ Print current evaluation metrics
+ """
+ if self.counter == 0:
+ print('Evaluation has not started')
+ return
+ print(f'{self.counter} samples')
+ metrics_dict = self.get_metrics_dict()
+ for metric in metrics_dict:
+ print(f'{metric}: {metrics_dict[metric]}')
+ print('***')
+
+ def get_metrics_dict(self) -> Dict:
+ """
+ Returns:
+ Dict: Dictionary of evaluation metrics.
+ """
+ pcks = self.compute_pcks()
+ metrics = {}
+ for thr, (acc,avg_acc,cnt) in zip(self.thresholds, pcks):
+ metrics.update({f'kp{i}_pck_{thr}': float(a) for i, a in enumerate(acc) if a>=0})
+ metrics.update({f'kpAvg_pck_{thr}': float(avg_acc)})
+ return metrics
+
+ def compute_pcks(self):
+ pred_kp_2d = np.concatenate(self.pred_kp_2d, axis=0)
+ gt_kp_2d = np.concatenate(self.gt_kp_2d, axis=0)
+ gt_conf_2d = np.concatenate(self.gt_conf_2d, axis=0)
+ assert pred_kp_2d.shape == gt_kp_2d.shape
+ assert pred_kp_2d[..., 0].shape == gt_conf_2d.shape
+ assert pred_kp_2d.shape[1] == 1 # num_samples
+
+ from .pck_accuracy import keypoint_pck_accuracy
+ pcks = [
+ keypoint_pck_accuracy(
+ pred_kp_2d[:, 0, :, :],
+ gt_kp_2d[:, 0, :, :],
+ gt_conf_2d[:, 0, :]>0.5,
+ thr=thr,
+ normalize = np.ones((len(pred_kp_2d),2)) # Already in [-0.5,0.5] range. No need to normalize
+ )
+ for thr in self.thresholds
+ ]
+ return pcks
+
+ def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None):
+ """
+ Evaluate current batch.
+ Args:
+ output (Dict): Regression output.
+ batch (Dict): Dictionary containing images and their corresponding annotations.
+ opt_output (Dict): Optimization output.
+ """
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach()
+ num_samples = 1
+ batch_size = pred_keypoints_2d.shape[0]
+
+ pred_keypoints_2d = pred_keypoints_2d[:,None,:,:]
+ gt_keypoints_2d = batch['keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1)
+
+ gt_bbox_expand_factor = (batch['box_size']/(batch['_scale']*200).max(dim=-1).values)
+ gt_bbox_expand_factor = gt_bbox_expand_factor[:,None,None,None].repeat(1, num_samples, 1, 1)
+ gt_bbox_expand_factor = gt_bbox_expand_factor.detach().cpu().numpy()
+
+ self.pred_kp_2d.append(pred_keypoints_2d[:, :, :, :2].detach().cpu().numpy() * gt_bbox_expand_factor)
+ self.gt_conf_2d.append(gt_keypoints_2d[:, :, :, -1].detach().cpu().numpy())
+ self.gt_kp_2d.append(gt_keypoints_2d[:, :, :, :2].detach().cpu().numpy() * gt_bbox_expand_factor)
+
+ self.counter += batch_size
diff --git a/lib/evaluation/hmr2_utils/pck_accuracy.py b/lib/evaluation/hmr2_utils/pck_accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fea6c3859a930c805e0f5357b0fb71adac6764c
--- /dev/null
+++ b/lib/evaluation/hmr2_utils/pck_accuracy.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import numpy as np
+
+def _calc_distances(preds, targets, mask, normalize):
+ """Calculate the normalized distances between preds and target.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+ dimension of keypoints: D (normally, D=2 or D=3)
+
+ Args:
+ preds (np.ndarray[N, K, D]): Predicted keypoint location.
+ targets (np.ndarray[N, K, D]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ normalize (np.ndarray[N, D]): Typical value is heatmap_size
+
+ Returns:
+ np.ndarray[K, N]: The normalized distances. \
+ If target keypoints are missing, the distance is -1.
+ """
+ N, K, _ = preds.shape
+ # set mask=0 when normalize==0
+ _mask = mask.copy()
+ _mask[np.where((normalize == 0).sum(1))[0], :] = False
+ distances = np.full((N, K), -1, dtype=np.float32)
+ # handle invalid values
+ normalize[np.where(normalize <= 0)] = 1e6
+ distances[_mask] = np.linalg.norm(
+ ((preds - targets) / normalize[:, None, :])[_mask], axis=-1)
+ return distances.T
+
+
+def _distance_acc(distances, thr=0.5):
+ """Return the percentage below the distance threshold, while ignoring
+ distances values with -1.
+
+ Note:
+ batch_size: N
+ Args:
+ distances (np.ndarray[N, ]): The normalized distances.
+ thr (float): Threshold of the distances.
+
+ Returns:
+ float: Percentage of distances below the threshold. \
+ If all target keypoints are missing, return -1.
+ """
+ distance_valid = distances != -1
+ num_distance_valid = distance_valid.sum()
+ if num_distance_valid > 0:
+ return (distances[distance_valid] < thr).sum() / num_distance_valid
+ return -1
+
+
+def keypoint_pck_accuracy(pred, gt, mask, thr, normalize):
+ """Calculate the pose accuracy of PCK for each individual keypoint and the
+ averaged accuracy across all keypoints for coordinates.
+
+ Note:
+ PCK metric measures accuracy of the localization of the body joints.
+ The distances between predicted positions and the ground-truth ones
+ are typically normalized by the bounding box size.
+ The threshold (thr) of the normalized distance is commonly set
+ as 0.05, 0.1 or 0.2 etc.
+
+ - batch_size: N
+ - num_keypoints: K
+
+ Args:
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ thr (float): Threshold of PCK calculation.
+ normalize (np.ndarray[N, 2]): Normalization factor for H&W.
+
+ Returns:
+ tuple: A tuple containing keypoint accuracy.
+
+ - acc (np.ndarray[K]): Accuracy of each keypoint.
+ - avg_acc (float): Averaged accuracy across all keypoints.
+ - cnt (int): Number of valid keypoints.
+ """
+ distances = _calc_distances(pred, gt, mask, normalize)
+
+ acc = np.array([_distance_acc(d, thr) for d in distances])
+ valid_acc = acc[acc >= 0]
+ cnt = len(valid_acc)
+ avg_acc = valid_acc.mean() if cnt > 0 else 0
+ return acc, avg_acc, cnt
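A small illustrative call to `keypoint_pck_accuracy` (not part of the diff), using random arrays with the shapes described in the docstring; the threshold and normalization factors are placeholder values.

```python
import numpy as np

from lib.evaluation.hmr2_utils.pck_accuracy import keypoint_pck_accuracy

N, K = 8, 17
pred = np.random.rand(N, K, 2)
gt = pred + 0.02 * np.random.randn(N, K, 2)   # small perturbation of the prediction
mask = np.ones((N, K), dtype=bool)            # all keypoints visible
normalize = np.ones((N, 2))                   # coordinates assumed already normalized

acc, avg_acc, cnt = keypoint_pck_accuracy(pred, gt, mask, thr=0.05, normalize=normalize)
print(avg_acc, cnt)  # per-keypoint accuracies are in `acc`
```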
diff --git a/lib/evaluation/metrics/__init__.py b/lib/evaluation/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dbfa03320c4b86f2960a83ffb73917e4dd3d476
--- /dev/null
+++ b/lib/evaluation/metrics/__init__.py
@@ -0,0 +1,2 @@
+from .mpxe_like import *
+from .utils import *
\ No newline at end of file
diff --git a/lib/evaluation/metrics/mpxe_like.py b/lib/evaluation/metrics/mpxe_like.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b656abd3b1834056eb91a241091375ce332b01a
--- /dev/null
+++ b/lib/evaluation/metrics/mpxe_like.py
@@ -0,0 +1,147 @@
+from typing import Optional
+import torch
+
+from .utils import *
+
+
+'''
+All MPxE-like metrics are implemented here.
+
+- Local Metrics: the input motion's translation should be removed (or may be removed automatically).
+ - MPxE: call `eval_MPxE()`
+ - PA-MPxE: call `eval_PA_MPxE()`
+- Global Metrics: the input motion's translation should be kept.
+ - G-MPxE: call `eval_MPxE()`
+ - W2-MPxE: call `eval_Wk_MPxE()`, and set k = 2
+ - WA-MPxE: call `eval_WA_MPxE()`
+'''
+
+
+def eval_MPxE(
+ pred : torch.Tensor,
+ gt : torch.Tensor,
+ scale : float = m2mm,
+):
+ '''
+    Calculate the Mean Per-x Error, which might be over joint positions (MPJPE) or vertices (MPVE).
+
+ The results will be the sequence of MPxE of each multi-dim batch.
+
+ ### Args
+ - `pred`: torch.Tensor
+ - shape = (...B, N, 3), where B is the multi-dim batch size, N is points count in one batch
+ - the predicted joints/vertices position data
+ - `gt`: torch.Tensor
+ - shape = (...B, N, 3), where B is the multi-dim batch size, N is points count in one batch
+ - the ground truth joints/vertices position data
+ - `scale`: float, default = `m2mm`
+
+ ### Returns
+ - torch.Tensor
+ - shape = (...B)
+ - shape = ()
+ '''
+ # Calculate the MPxE.
+ ret = L2_error(pred, gt).mean(dim=-1) * scale # (...B,)
+ return ret
+
+
+def eval_PA_MPxE(
+ pred : torch.Tensor,
+ gt : torch.Tensor,
+ scale : float = m2mm,
+):
+ '''
+    Calculate the Procrustes-Aligned Mean Per-x Error, which might be over joint positions (PA-MPJPE)
+    or vertices (PA-MPVE). The prediction is Procrustes-aligned to the target before computing the per-frame MPxE.
+
+ The results will be the sequence of MPxE of each batch.
+
+ ### Args
+ - `pred`: torch.Tensor
+ - shape = (...B, N, 3), where B is the multi-dim batch size, N is points count in one batch
+ - the predicted joints/vertices position data
+ - `gt`: torch.Tensor
+ - shape = (...B, N, 3), where B is the multi-dim batch size, N is points count in one batch
+ - the ground truth joints/vertices position data
+ - `scale`: float, default = `m2mm`
+
+ ### Returns
+ - torch.Tensor
+ - shape = (...B)
+ - shape = ()
+ '''
+ # Perform Procrustes alignment.
+ pred_aligned = similarity_align_to(pred, gt) # (...B, N, 3)
+ # Calculate the PA-MPxE
+ return eval_MPxE(pred_aligned, gt, scale) # (...B,)
+
+
+def eval_Wk_MPxE(
+ pred : torch.Tensor,
+ gt : torch.Tensor,
+ scale : float = m2mm,
+ k_f : int = 2,
+):
+ '''
+    Calculate the first-k-frames-aligned (world-aligned) Mean Per-x Error, which might be over joint
+    positions (Wk-MPJPE) or vertices (Wk-MPVE). The prediction is aligned to the target using the first
+    k frames before computing the per-frame MPxE.
+
+ The results will be the sequence of MPxE of each batch.
+
+ ### Args
+ - `pred`: torch.Tensor
+ - shape = (..., L, N, 3), where L is the length of the sequence, N is points count in one batch
+ - the predicted joints/vertices position data
+ - `gt`: torch.Tensor
+ - shape = (..., L, N, 3), where L is the length of the sequence, N is points count in one batch
+ - the ground truth joints/vertices position data
+ - `scale`: float, default = `m2mm`
+ - `k_f`: int, default = 2
+ - the number of frames to use for alignment
+
+ ### Returns
+ - torch.Tensor
+ - shape = (..., L)
+ - shape = ()
+ '''
+ L = max(pred.shape[-3], gt.shape[-3])
+ assert L >= 2, f'Length of the sequence should be at least 2, but got {L}.'
+ # Perform first two alignment.
+ pred_aligned = first_k_frames_align_to(pred, gt, k_f) # (..., L, N, 3)
+ # Calculate the PA-MPxE
+ return eval_MPxE(pred_aligned, gt, scale) # (..., L)
+
+
+def eval_WA_MPxE(
+ pred : torch.Tensor,
+ gt : torch.Tensor,
+ scale : float = m2mm,
+):
+ '''
+    Calculate the all-frames-aligned (world-all-aligned) Mean Per-x Error, which might be over joint
+    positions (WA-MPJPE) or vertices (WA-MPVE). The prediction is aligned to the target using all
+    frames before computing the per-frame MPxE.
+
+ The results will be the sequence of MPxE of each batch.
+
+ ### Args
+ - `pred`: torch.Tensor
+ - shape = (..., L, N, 3), where L is the length of the sequence, N is points count in one batch
+ - the predicted joints/vertices position data
+ - `gt`: torch.Tensor
+ - shape = (..., L, N, 3), where L is the length of the sequence, N is points count in one batch
+ - the ground truth joints/vertices position data
+ - `scale`: float, default = `m2mm`
+
+ ### Returns
+ - torch.Tensor
+ - shape = (..., L)
+ - shape = ()
+ '''
+ L_pred = pred.shape[-3]
+ L_gt = gt.shape[-3]
+ assert (L_pred == L_gt), f'Length of the sequence should be the same, but got {L_pred} and {L_gt}.'
+ # WA_MPxE is just Wk_MPxE when k = L.
+ return eval_Wk_MPxE(pred, gt, scale, L_gt)
\ No newline at end of file
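For reference, a hedged usage sketch of the local metrics defined above (not part of the diff); the batch size, joint count, and random tensors are placeholders, and inputs are assumed to be in meters so that the default `m2mm` scale applies.

```python
import torch

from lib.evaluation.metrics.mpxe_like import eval_MPxE, eval_PA_MPxE

pred = torch.rand(4, 24, 3)   # (B, N, 3) predicted joints, assumed in meters
gt = torch.rand(4, 24, 3)     # (B, N, 3) ground-truth joints, assumed in meters

mpjpe = eval_MPxE(pred, gt)        # (B,) mean per-joint error in mm
pa_mpjpe = eval_PA_MPxE(pred, gt)  # (B,) the same after Procrustes alignment
print(mpjpe.shape, pa_mpjpe.mean().item())
```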
diff --git a/lib/evaluation/metrics/utils.py b/lib/evaluation/metrics/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..787e5c70e924a486a089dacc520343ce54037d0c
--- /dev/null
+++ b/lib/evaluation/metrics/utils.py
@@ -0,0 +1,195 @@
+import torch
+
+
+m2mm = 1000.0
+
+
+def L2_error(x:torch.Tensor, y:torch.Tensor):
+ '''
+ Calculate the L2 error across the last dim of the input tensors.
+
+ ### Args
+ - `x`: torch.Tensor, shape (..., D)
+ - `y`: torch.Tensor, shape (..., D)
+
+ ### Returns
+ - torch.Tensor, shape (...)
+ '''
+ return (x - y).norm(dim=-1)
+
+
+def similarity_align_to(
+ S1 : torch.Tensor,
+ S2 : torch.Tensor,
+):
+ '''
+    Computes a similarity transform (sR, t) that maps a set of 3D points S1 (3 x N) as close as
+    possible to a set of 3D points S2, where R is a 3x3 rotation matrix, t a 3x1 translation,
+    and s a scale. That is, it solves the orthogonal Procrustes problem.
+
+ The code was modified from [WHAM](https://github.com/yohanshin/WHAM/blob/d1ade93ae83a91855902fdb8246c129c4b3b8a40/lib/eval/eval_utils.py#L201-L252).
+
+ ### Args
+ - `S1`: torch.Tensor, shape (...B, N, 3)
+ - `S2`: torch.Tensor, shape (...B, N, 3)
+
+ ### Returns
+ - torch.Tensor, shape (...B, N, 3)
+ '''
+ assert (S1.shape[-1] == 3 and S2.shape[-1] == 3), 'The last dimension of `S1` and `S2` must be 3.'
+ assert (S1.shape[:-2] == S2.shape[:-2]), 'The batch size of `S1` and `S2` must be the same.'
+ original_BN3 = S1.shape
+ N = original_BN3[-2]
+ S1 = S1.reshape(-1, N, 3) # (B', N, 3) <- (...B, N, 3)
+ S2 = S2.reshape(-1, N, 3) # (B', N, 3) <- (...B, N, 3)
+ B = S1.shape[0]
+
+ S1 = S1.transpose(-1, -2) # (B', 3, N) <- (B', N, 3)
+ S2 = S2.transpose(-1, -2) # (B', 3, N) <- (B', N, 3)
+ _device = S2.device
+ S1 = S1.to(_device)
+
+ # 1. Remove mean.
+ mu1 = S1.mean(axis=-1, keepdims=True) # (B', 3, 1)
+ mu2 = S2.mean(axis=-1, keepdims=True) # (B', 3, 1)
+ X1 = S1 - mu1 # (B', 3, N)
+ X2 = S2 - mu2 # (B', 3, N)
+
+ # 2. Compute variance of X1 used for scales.
+ var1 = torch.einsum('...BDN->...B', X1**2) # (B',)
+
+ # 3. The outer product of X1 and X2.
+ K = X1 @ X2.transpose(-1, -2) # (B', 3, 3)
+
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
+ U, s, V = torch.svd(K) # (B', 3, 3), (B', 3), (B', 3, 3)
+
+ # Construct Z that fixes the orientation of R to get det(R)=1.
+ Z = torch.eye(3, device=_device)[None].repeat(B, 1, 1) # (B', 3, 3)
+ Z[:, -1, -1] *= (U @ V.transpose(-1, -2)).det().sign()
+
+ # Construct R.
+ R = V @ (Z @ U.transpose(-1, -2)) # (B', 3, 3)
+
+ # 5. Recover scales.
+ traces = [torch.trace(x)[None] for x in (R @ K)]
+ scales = torch.cat(traces) / var1 # (B',)
+ scales = scales[..., None, None] # (B', 1, 1)
+
+ # 6. Recover translation.
+ t = mu2 - (scales * (R @ mu1)) # (B', 3, 1)
+
+ # 7. Error:
+ S1_aligned = scales * (R @ S1) + t # (B', 3, N)
+
+ S1_aligned = S1_aligned.transpose(-1, -2) # (B', N, 3) <- (B', 3, N)
+ S1_aligned = S1_aligned.reshape(original_BN3) # (...B, N, 3)
+ return S1_aligned # (...B, N, 3)
+
+
+def align_pcl(Y: torch.Tensor, X: torch.Tensor, weight=None, fixed_scale=False):
+ '''
+    Compute the similarity transform that aligns X with Y using the Umeyama method, so that X' = s * R * X + t is aligned with Y.
+
+ The code was copied from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/geometry/pcl.py#L10-L60).
+
+ ### Args
+ - `Y`: torch.Tensor, shape (*, N, 3) first trajectory
+ - `X`: torch.Tensor, shape (*, N, 3) second trajectory
+ - `weight`: torch.Tensor, shape (*, N, 1) optional weight of valid correspondences
+ - `fixed_scale`: bool, default = False
+
+ ### Returns
+ - `s` (*, 1)
+ - `R` (*, 3, 3)
+ - `t` (*, 3)
+ '''
+ *dims, N, _ = Y.shape
+ N = torch.ones(*dims, 1, 1) * N
+
+ if weight is not None:
+ Y = Y * weight
+ X = X * weight
+ N = weight.sum(dim=-2, keepdim=True) # (*, 1, 1)
+
+ # subtract mean
+ my = Y.sum(dim=-2) / N[..., 0] # (*, 3)
+ mx = X.sum(dim=-2) / N[..., 0]
+ y0 = Y - my[..., None, :] # (*, N, 3)
+ x0 = X - mx[..., None, :]
+
+ if weight is not None:
+ y0 = y0 * weight
+ x0 = x0 * weight
+
+ # correlation
+ C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3)
+ U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3)
+
+ S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
+ neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
+ S[neg, 2, 2] = -1
+
+ R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3)
+
+ D = torch.diag_embed(D) # (*, 3, 3)
+ if fixed_scale:
+ s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32)
+ else:
+ var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1)
+ s = (
+ torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
+ dim=-1, keepdim=True
+ )
+ / var[..., 0]
+ ) # (*, 1)
+
+ t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3)
+
+ return s, R, t
+
+
+def first_k_frames_align_to(
+ S1 : torch.Tensor,
+ S2 : torch.Tensor,
+ k_f : int,
+):
+ '''
+ Compute the transformation between the first trajectory segment of S1 and S2, and use
+ the transformation to align S1 to S2.
+
+ The code was modified from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/eval/tools.py#L68-L81).
+
+ ### Args
+ - `S1`: torch.Tensor, shape (..., L, N, 3)
+ - `S2`: torch.Tensor, shape (..., L, N, 3)
+ - `k_f`: int
+ - The number of frames to use for alignment.
+
+ ### Returns
+ - `S1_aligned`: torch.Tensor, shape (..., L, N, 3)
+ - The aligned S1.
+ '''
+ assert (len(S1.shape) >= 3 and len(S2.shape) >= 3), 'The input tensors must have at least 3 dimensions.'
+ original_shape = S1.shape # (..., L, N, 3)
+ L, N, _ = original_shape[-3:]
+ S1 = S1.reshape(-1, L, N, 3) # (B, L, N, 3)
+ S2 = S2.reshape(-1, L, N, 3) # (B, L, N, 3)
+ B = S1.shape[0]
+
+ # 1. Prepare the clouds to be aligned.
+    S1_first = S1[:, :k_f, :, :].reshape(B, -1, 3) # (B, k_f * N, 3)
+    S2_first = S2[:, :k_f, :, :].reshape(B, -1, 3) # (B, k_f * N, 3)
+
+ # 2. Get the transformation to perform the alignment.
+ s_first, R_first, t_first = align_pcl(
+ X = S1_first,
+ Y = S2_first,
+ ) # (B, 1), (B, 3, 3), (B, 3)
+ s_first = s_first.reshape(B, 1, 1, 1) # (B, 1, 1, 1)
+ t_first = t_first.reshape(B, 1, 1, 3) # (B, 1, 1, 3)
+
+ # 3. Perform the alignment on the whole sequence.
+ S1_aligned = s_first * torch.einsum('Bij,BLNj->BLNi', R_first, S1) + t_first # (B, L, N, 3)
+ S1_aligned = S1_aligned.reshape(original_shape) # (..., L, N, 3)
+ return S1_aligned # (..., L, N, 3)
\ No newline at end of file
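A brief illustrative sketch (not part of the diff) of the two alignment helpers above, with made-up batch, sequence, and point counts.

```python
import torch

from lib.evaluation.metrics.utils import similarity_align_to, first_k_frames_align_to

pred = torch.rand(2, 30, 24, 3)  # (B, L, N, 3) predicted trajectories
gt = torch.rand(2, 30, 24, 3)    # (B, L, N, 3) ground-truth trajectories

pred_pa = similarity_align_to(pred, gt)         # Procrustes alignment per frame-batch
pred_w2 = first_k_frames_align_to(pred, gt, 2)  # one rigid+scale fit from the first 2 frames
print(pred_pa.shape, pred_w2.shape)             # both (2, 30, 24, 3)
```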
diff --git a/lib/info/log.py b/lib/info/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..74179d0f87d89cfe0831c622ff36698c17662968
--- /dev/null
+++ b/lib/info/log.py
@@ -0,0 +1,75 @@
+import inspect
+import logging
+import torch
+from time import time
+from colorlog import ColoredFormatter
+
+
+def fold_path(fn:str):
+ ''' Fold a path like `from/to/file.py` to relative `f/t/file.py`. '''
+ from lib.platform.proj_manager import ProjManager as PM
+ root_abs = str(PM.root.absolute())
+ if fn.startswith(root_abs):
+ fn = fn[len(root_abs)+1:]
+
+ return '/'.join([p[:1] for p in fn.split('/')[:-1]]) + '/' + fn.split('/')[-1]
+
+
+def sync_time():
+ torch.cuda.synchronize()
+ return time()
+
+
+def get_logger(brief:bool=False, show_stack:bool=False):
+ # 1. Get the caller's file name and function name to identify the logging position.
+ caller_frame = inspect.currentframe().f_back
+ line_num = caller_frame.f_lineno
+ file_name = caller_frame.f_globals["__file__"]
+ file_name = fold_path(file_name)
+ func_name = caller_frame.f_code.co_name
+ frames_stack = inspect.stack()
+
+ # 2. Add a trace method to the logger.
+ def trace_handler(self, message, *args, **kws):
+ if self.isEnabledFor(TRACE):
+ self._log(TRACE, message, args, **kws)
+
+ TRACE = 15 # DEBUG is 10 and INFO is 20
+ logging.addLevelName(TRACE, 'TRACE')
+ logging.Logger.trace = trace_handler
+
+ # 3. Set up the logger.
+ logger = logging.getLogger()
+ logger.time = time
+ logger.sync_time = sync_time
+
+ if logger.hasHandlers():
+ logger.handlers.clear()
+
+ ch = logging.StreamHandler()
+
+ if brief:
+ prefix = f'[%(cyan)s%(asctime)s%(reset)s]'
+ else:
+ prefix = f'[%(cyan)s%(asctime)s%(reset)s @ %(cyan)s{func_name}%(reset)s @ %(cyan)s{file_name}%(reset)s:%(cyan)s{line_num}%(reset)s]'
+ if show_stack:
+ suffix = '\n STACK: ' + ' @ '.join([f'{fold_path(frame.filename)}:{frame.lineno}' for frame in frames_stack[1:]])
+ else:
+ suffix = ''
+ formatstring = f'{prefix}[%(log_color)s%(levelname)s%(reset)s] %(message)s{suffix}'
+ datefmt = '%m/%d %H:%M:%S'
+ ch.setFormatter(ColoredFormatter(formatstring, datefmt=datefmt))
+ logger.addHandler(ch)
+
+ # Modify the logging level here.
+ logger.setLevel(TRACE)
+ ch.setLevel(TRACE)
+
+ return logger
+
+if __name__ == '__main__':
+ get_logger().trace('Test TRACE')
+ get_logger().info('Test INFO')
+ get_logger().warning('Test WARN')
+ get_logger().error('Test ERROR')
+ get_logger().fatal('Test FATAL')
\ No newline at end of file
diff --git a/lib/info/look.py b/lib/info/look.py
new file mode 100644
index 0000000000000000000000000000000000000000..37f124dcdb75245fabe4343e509001d3dd867d18
--- /dev/null
+++ b/lib/info/look.py
@@ -0,0 +1,108 @@
+# Provides methods to summarize the information of data, giving a brief overview in text.
+
+import torch
+import numpy as np
+
+from typing import Optional
+
+from .log import get_logger
+
+
+def look_tensor(
+ x : torch.Tensor,
+ prompt : Optional[str] = None,
+ silent : bool = False,
+):
+ '''
+ Summarize the information of a tensor, including its shape, value range (min, max, mean, std), and dtype.
+ Then return a string containing the information.
+
+ ### Args
+ - x: torch.Tensor
+ - silent: bool, default `False`
+ - If not silent, the function will print the message itself. The information string will always be returned.
+ - prompt: Optional[str], default `None`
+        - If a prompt is given, it will be printed at the very beginning.
+
+ ### Returns
+ - str
+ '''
+ info_list = [] if prompt is None else [prompt]
+ # Convert to float to calculate the statistics.
+ x_num = x.float()
+ info_list.append(f'📐 [{x_num.min():06f} -> {x_num.max():06f}] ~ ({x_num.mean():06f}, {x_num.std():06f})')
+ info_list.append(f'📦 {tuple(x.shape)}')
+ info_list.append(f'🏷️ {x.dtype}')
+ info_list.append(f'🖥️ {x.device}')
+ # Generate the final information and print it if necessary.
+ ret = '\t'.join(info_list)
+ if not silent:
+ get_logger().info(ret)
+ return ret
+
+
+def look_ndarray(
+ x : np.ndarray,
+ silent : bool = False,
+ prompt : Optional[str] = None,
+):
+ '''
+ Summarize the information of a numpy array, including its shape, value range (min, max, mean, std), and dtype.
+ Then return a string containing the information.
+
+ ### Args
+ - x: np.ndarray
+ - silent: bool, default `False`
+ - If not silent, the function will print the message itself. The information string will always be returned.
+ - prompt: Optional[str], default `None`
+        - If a prompt is given, it will be printed at the very beginning.
+
+ ### Returns
+ - str
+ '''
+ info_list = [] if prompt is None else [prompt]
+ # Convert to float to calculate the statistics.
+ x_num = x.astype(np.float32)
+ info_list.append(f'📐 [ {x_num.min():06f} -> {x_num.max():06f} ] ~ ( {x_num.mean():06f}, {x_num.std():06f} )')
+ info_list.append(f'📦 {tuple(x.shape)}')
+ info_list.append(f'🏷️ {x.dtype}')
+ # Generate the final information and print it if necessary.
+ ret = '\t'.join(info_list)
+ if not silent:
+ get_logger().info(ret)
+ return ret
+
+
+def look_dict(
+ d : dict,
+ silent : bool = False,
+):
+ '''
+ Summarize the information of a dictionary, including the keys and the information of the values.
+ Then return a string containing the information.
+
+ ### Args
+ - d: dict
+ - silent: bool, default `False`
+ - If not silent, the function will print the message itself. The information string will always be returned.
+
+ ### Returns
+ - str
+ '''
+ info_list = ['{']
+
+ for k, v in d.items():
+ if isinstance(v, torch.Tensor):
+ info_list.append(f'{k} : tensor: {look_tensor(v, silent=True)}')
+ elif isinstance(v, np.ndarray):
+ info_list.append(f'{k} : ndarray: {look_ndarray(v, silent=True)}')
+ elif isinstance(v, str):
+ info_list.append(f'{k} : {v[:32]}')
+ else:
+ info_list.append(f'{k} : {type(v)}')
+
+ info_list.append('}')
+ ret = '\n'.join(info_list)
+ if not silent:
+ get_logger().info(ret)
+ return ret
\ No newline at end of file
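An illustrative snippet (not part of the diff) showing how these inspection helpers might be called during debugging; all tensors and dictionary keys are placeholders.

```python
import numpy as np
import torch

from lib.info.look import look_tensor, look_dict

feats = torch.randn(8, 256)
look_tensor(feats, prompt='backbone feats')  # logs value range, shape, dtype, device

batch = {
    'img'   : torch.rand(2, 3, 256, 256),
    'boxes' : np.zeros((2, 4), dtype=np.float32),
    'name'  : 'example_sequence_0001',
}
look_dict(batch)  # one summary line per entry
```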
diff --git a/lib/info/show.py b/lib/info/show.py
new file mode 100644
index 0000000000000000000000000000000000000000..579bc9fa6e25a72b43b99f188c4f2ae91ed052e8
--- /dev/null
+++ b/lib/info/show.py
@@ -0,0 +1,93 @@
+# Provides methods to visualize the information of data, giving a brief overview in figure.
+
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+from typing import Optional, Union, List, Dict
+from pathlib import Path
+
+from lib.utils.data import to_numpy
+
+
+def show_distribution(
+ data : Dict,
+ fn : Union[str, Path], # File name of the saved figure.
+ bins : int = 100, # Number of bins in the histogram.
+ annotation : bool = False,
+ title : str = 'Data Distribution',
+ axis_names : List = ['Value', 'Frequency'],
+ bounds : Optional[List] = None, # Left and right bounds of the histogram.
+):
+ '''
+ Visualize the distribution of the data using histogram.
+ The data should be a dictionary with keys as the labels and values as the data.
+ '''
+ labels = list(data.keys())
+ data = np.stack([ to_numpy(x) for x in data.values() ], axis=0)
+ assert data.ndim == 2, f"Data dimension should be 2, but got {data.ndim}."
+ assert bounds is None or len(bounds) == 2, f"Bounds should be a list of length 2, but got {bounds}."
+ # Preparation.
+ N, K = data.shape
+ data = data.transpose(1, 0) # (K, N)
+ # Plot.
+ plt.hist(data, bins=bins, alpha=0.7, label=labels)
+ if annotation:
+ for i in range(K):
+ for j in range(N):
+ plt.text(data[i, j], 0, f'{data[i, j]:.2f}', va='bottom', fontsize=6)
+ plt.title(title)
+ plt.xlabel(axis_names[0])
+ plt.ylabel(axis_names[1])
+ plt.legend()
+ if bounds:
+ plt.xlim(bounds)
+ # Save.
+ plt.savefig(fn)
+ plt.close()
+
+
+
+def show_history(
+ data : Dict,
+ fn : Union[str, Path], # file name of the saved figure
+ annotation : bool = False,
+ title : str = 'Data History',
+ axis_names : List = ['Time', 'Value'],
+ ex_starts : Dict[str, int] = {}, # starting points of the history if not starting from 0
+):
+ '''
+    Visualize how values change across time.
+ The history should be a dictionary with keys as the metric names and values as the metric values.
+ '''
+ # Make sure the fn's parent exists.
+ if isinstance(fn, str):
+ fn = Path(fn)
+ fn.parent.mkdir(parents=True, exist_ok=True)
+
+ # Preparation.
+ history_name = list(data.keys())
+ history_data = [ to_numpy(x) for x in data.values() ]
+ N = len(history_name)
+ Ls = [len(x) for x in history_data]
+ Ss = [
+ ex_starts[history_name[i]]
+ if (history_name[i] in ex_starts.keys()) else 0
+ for i in range(N)
+ ]
+
+ # Plot.
+ for i in range(N):
+ plt.plot(range(Ss[i], Ss[i]+Ls[i]), history_data[i], label=history_name[i])
+ if annotation:
+ for i in range(N):
+ for j in range(Ls[i]):
+ plt.text(Ss[i]+j, history_data[i][j], f'{history_data[i][j]:.2f}', fontsize=6)
+
+ plt.title(title)
+ plt.xlabel(axis_names[0])
+ plt.ylabel(axis_names[1])
+ plt.legend()
+ # Save.
+ plt.savefig(fn)
+ plt.close()
\ No newline at end of file
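An illustrative snippet (not part of the diff) showing how the two plotting helpers might be called; the data and output file names are placeholders. Note that only `show_history` creates missing parent directories for its output path, so plain file names are used here.

```python
import numpy as np

from lib.info.show import show_distribution, show_history

show_distribution(
    data = {'pred_depth': np.random.randn(1000), 'gt_depth': np.random.randn(1000)},
    fn   = 'depth_distribution.png',
    bins = 50,
)

show_history(
    data      = {'train_loss': np.linspace(1.0, 0.1, 100), 'val_loss': np.linspace(1.2, 0.2, 20)},
    fn        = 'loss_history.png',
    ex_starts = {'val_loss': 80},  # the validation curve starts at step 80
)
```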
diff --git a/lib/kits/README.md b/lib/kits/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d1313f8d5fdba5e1e412258729567af5677f11f
--- /dev/null
+++ b/lib/kits/README.md
@@ -0,0 +1,3 @@
+# `lib.kits.*` Usage Instructions
+
+Submodules in this package are pre-prepared sets of imports for common tasks.
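As a sketch of the intended usage (not part of the diff), a consumer module would typically start with a single wildcard import; the helper function below is hypothetical.

```python
from lib.kits.basic import *  # torch, nn, np, pl, hydra/OmegaConf, PM, get_logger, ...

def build_dummy_head(in_dim: int = 256, out_dim: int = 22) -> nn.Module:
    ''' Hypothetical helper: everything it uses comes from the wildcard import. '''
    get_logger(brief=True).info(f'Building a dummy head: {in_dim} -> {out_dim}.')
    return nn.Sequential(nn.Linear(in_dim, 512), nn.ReLU(), nn.Linear(512, out_dim))
```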
diff --git a/lib/kits/basic.py b/lib/kits/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a3ad662a9640910b048d58edd2061082346a6ee
--- /dev/null
+++ b/lib/kits/basic.py
@@ -0,0 +1,29 @@
+# Misc.
+import os
+import sys
+from pathlib import Path
+from typing import Union, Optional, List, Dict, Tuple, Any
+
+
+# Machine Learning Related
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+
+# Framework Supports
+import hydra
+from hydra.utils import instantiate
+from omegaconf import DictConfig, OmegaConf
+
+# Framework Supports - Customized Part
+from lib.utils.data import *
+from lib.platform import PM
+from lib.info.log import get_logger
+
+try:
+ import oven
+except ImportError:
+ get_logger(brief=True).warning('ExpOven is not installed. Will not be able to use the oven related functions. Check https://github.com/IsshikiHugh/ExpOven for more information.')
\ No newline at end of file
diff --git a/lib/kits/debug.py b/lib/kits/debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad6d20bebb171bfef3cdd3f9ccbfbe0015dda342
--- /dev/null
+++ b/lib/kits/debug.py
@@ -0,0 +1,27 @@
+from ipdb import set_trace
+import os
+import inspect as frame_inspect
+from tqdm import tqdm
+from rich import inspect, pretty, print
+
+from lib.platform import PM
+from lib.platform.monitor import GPUMonitor
+from lib.info.log import get_logger
+
+def who_imported_me():
+ # Get the current stack frames.
+ stack = frame_inspect.stack()
+
+ # Traverse the stack to find the first external caller.
+ for frame_info in stack:
+ # Filter out the internal importlib calls and the current file.
+ if 'importlib' not in frame_info.filename and frame_info.filename != __file__:
+ return os.path.abspath(frame_info.filename)
+
+ # If no external file is found, it might be running as the main script.
+ return None
+
+get_logger(brief=True).warning(f'DEBUG kits are imported at {who_imported_me()}, remember to remove them.')
+
+from lib.info.look import *
+from lib.info.show import *
diff --git a/lib/kits/gradio/__init__.py b/lib/kits/gradio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a1135219bdd11a4041e63b5f3f63f63b9f37c0c
--- /dev/null
+++ b/lib/kits/gradio/__init__.py
@@ -0,0 +1,2 @@
+from .backend import HSMRBackend
+from .hsmr_service import HSMRService
diff --git a/lib/kits/gradio/backend.py b/lib/kits/gradio/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f5756d3bcab97cf4ce68258723e2014d6b834c3
--- /dev/null
+++ b/lib/kits/gradio/backend.py
@@ -0,0 +1,82 @@
+from lib.kits.hsmr_demo import *
+
+import gradio as gr
+
+from lib.modeling.pipelines import HSMRPipeline
+
+class HSMRBackend:
+ '''
+    Backend class that maintains the HSMR model for inference.
+    Some Gradio features are included in this class.
+ '''
+ def __init__(self, device:str='cpu') -> None:
+ self.max_img_w = 1920
+ self.max_img_h = 1080
+ self.device = device
+ self.pipeline = self._build_pipeline(self.device)
+ self.detector = build_detector(
+ batch_size = 1,
+ max_img_size = 512,
+ device = self.device,
+ )
+
+
+ def _build_pipeline(self, device) -> HSMRPipeline:
+ return build_inference_pipeline(
+ model_root = DEFAULT_HSMR_ROOT,
+ device = device,
+ )
+
+
+ def _load_limited_img(self, fn) -> List:
+ img, _ = load_img(fn)
+ if img.shape[0] > self.max_img_h:
+ img = flex_resize_img(img, (self.max_img_h, -1), kp_mod=4)
+ if img.shape[1] > self.max_img_w:
+ img = flex_resize_img(img, (-1, self.max_img_w), kp_mod=4)
+ return [img]
+
+
+ def __call__(self, input_path:Union[str, Path], args:Dict):
+ # 1. Initialization.
+ input_type = 'img'
+ if isinstance(input_path, str): input_path = Path(input_path)
+ outputs_root = input_path.parent / 'outputs'
+ outputs_root.mkdir(parents=True, exist_ok=True)
+
+ # 2. Preprocess.
+ gr.Info(f'[1/3] Pre-processing...')
+ raw_imgs = self._load_limited_img(input_path)
+ detector_outputs = self.detector(raw_imgs)
+        patches, det_meta = imgs_det2patches(raw_imgs, *detector_outputs, args['max_instances']) # N * (256, 256, 3)
+
+ # 3. Inference.
+ gr.Info(f'[2/3] HSMR inferencing...')
+ pd_params, pd_cam_t = [], []
+ for bw in bsb(total=len(patches), batch_size=args['rec_bs'], enable_tqdm=True):
+ patches_i = np.concatenate(patches[bw.sid:bw.eid], axis=0) # (N, 256, 256, 3)
+ patches_normalized_i = (patches_i - IMG_MEAN_255) / IMG_STD_255 # (N, 256, 256, 3)
+ patches_normalized_i = patches_normalized_i.transpose(0, 3, 1, 2) # (N, 3, 256, 256)
+ with torch.no_grad():
+ outputs = self.pipeline(patches_normalized_i)
+ pd_params.append({k: v.detach().cpu().clone() for k, v in outputs['pd_params'].items()})
+ pd_cam_t.append(outputs['pd_cam_t'].detach().cpu().clone())
+
+ pd_params = assemble_dict(pd_params, expand_dim=False) # [{k:[x]}, {k:[y]}] -> {k:[x, y]}
+ pd_cam_t = torch.cat(pd_cam_t, dim=0)
+
+ # 4. Render.
+ gr.Info(f'[3/3] Rendering results...')
+ m_skin, m_skel = prepare_mesh(self.pipeline, pd_params)
+ results = visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel)
+
+ outputs = {}
+
+ if input_type == 'img':
+ for k, v in results.items():
+            for k, v in results.items():
+                img_path = str(outputs_root / f'{k}.jpg')
+                save_img(v, img_path)
+                outputs[k] = img_path
+
+ return outputs
\ No newline at end of file
diff --git a/lib/kits/gradio/hsmr_service.py b/lib/kits/gradio/hsmr_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf269384bbbd442a942e9cb3168a79fb57bd2aef
--- /dev/null
+++ b/lib/kits/gradio/hsmr_service.py
@@ -0,0 +1,101 @@
+import gradio as gr
+
+from typing import List, Union
+from pathlib import Path
+
+from lib.platform import PM
+from lib.info.log import get_logger
+from .backend import HSMRBackend
+
+class HSMRService:
+ # ===== Initialization Part =====
+
+ def __init__(self, backend:HSMRBackend) -> None:
+ self.example_imgs_root = PM.inputs / 'example_imgs'
+ self.description = self._load_description()
+ self.backend:HSMRBackend = backend
+
+ def _load_description(self) -> str:
+ description_fn = PM.inputs / 'description.md'
+ with open(description_fn, 'r') as f:
+ description = f.read()
+ return description
+
+
+    # ===== Functional Private Part =====
+
+ def _inference_img(
+ self,
+ raw_img_path:Union[str, Path],
+ max_instances:int = 5,
+    ) -> str:
+ get_logger(brief=True).info(f'Image Path: {raw_img_path}')
+ get_logger(brief=True).info(f'max_instances: {max_instances}')
+ if raw_img_path is None:
+ gr.Warning('No image uploaded yet. Please upload an image first.')
+ return ''
+ if isinstance(raw_img_path, str):
+ raw_img_path = Path(raw_img_path)
+
+ args = {
+ 'max_instances': max_instances,
+ 'rec_bs': 1,
+ }
+ output_paths = self.backend(raw_img_path, args)
+
+ # bbx_img_path = output_paths['bbx_img_path']
+ # mesh_img_path = output_paths['mesh_img_path']
+ # skel_img_path = output_paths['skel_img_path']
+ blend_img_path = output_paths['front_blend']
+
+ return blend_img_path
+
+
+ # ===== Service Part =====
+
+ def serve(self) -> None:
+ ''' Build UI and set up the service. '''
+ with gr.Blocks() as demo:
+ # 1a. Setup UI.
+ gr.Markdown(self.description)
+
+ with gr.Tab(label='HSMR-IMG-CPU'):
+                gr.Markdown('> **Pure CPU** demo for recovering the human mesh and skeleton from a single image. Each inference may take **about 3 minutes**.')
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ input_image = gr.Image(
+ label = 'Input',
+ type = 'filepath',
+ )
+ with gr.Row(equal_height=True):
+ run_button_image = gr.Button(
+ value = 'Inference',
+ variant = 'primary',
+ )
+
+ with gr.Column():
+ output_blend = gr.Image(
+ label = 'Output',
+ type = 'filepath',
+ interactive = False,
+ )
+
+ # 1b. Add examples sections after setting I/O policy.
+ example_fns = sorted(self.example_imgs_root.glob('*'))
+ gr.Examples(
+ examples = example_fns,
+ fn = self._inference_img,
+ inputs = input_image,
+ outputs = output_blend,
+ )
+
+ # 2b. Continue binding I/O logic.
+ run_button_image.click(
+ fn = self._inference_img,
+ inputs = input_image,
+ outputs = output_blend,
+ )
+
+
+ # 3. Launch the service.
+ demo.queue(max_size=20).launch(server_name='0.0.0.0', server_port=7860)
\ No newline at end of file
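A hedged sketch (not part of the diff) of how the two classes exported from `lib.kits.gradio` might be wired together; the actual launcher script is not shown in this hunk, so this wiring is an assumption.

```python
from lib.kits.gradio import HSMRBackend, HSMRService

def main():
    backend = HSMRBackend(device='cpu')  # CPU-only inference backend
    service = HSMRService(backend)       # builds the Gradio UI around it
    service.serve()                      # blocks and serves on 0.0.0.0:7860

if __name__ == '__main__':
    main()
```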
diff --git a/lib/kits/hsmr_demo.py b/lib/kits/hsmr_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a3930db891b43c86e9ed390e3480df8c807ee88
--- /dev/null
+++ b/lib/kits/hsmr_demo.py
@@ -0,0 +1,217 @@
+from lib.kits.basic import *
+
+import argparse
+from tqdm import tqdm
+
+from lib.version import DEFAULT_HSMR_ROOT
+from lib.utils.vis.py_renderer import render_meshes_overlay_img, render_mesh_overlay_img
+from lib.utils.bbox import crop_with_lurb, fit_bbox_to_aspect_ratio, lurb_to_cs, cs_to_lurb
+from lib.utils.media import *
+from lib.platform.monitor import TimeMonitor
+from lib.platform.sliding_batches import bsb
+from lib.modeling.pipelines.hsmr import build_inference_pipeline
+from lib.modeling.pipelines.vitdet import build_detector
+
+IMG_MEAN_255 = np.array([0.485, 0.456, 0.406], dtype=np.float32) * 255.
+IMG_STD_255 = np.array([0.229, 0.224, 0.225], dtype=np.float32) * 255.
+
+
+# ================== Command Line Supports ==================
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-t', '--input_type', type=str, default='auto', help='Specify the input type. auto: file~video, folder~imgs', choices=['auto', 'video', 'imgs'])
+ parser.add_argument('-i', '--input_path', type=str, required=True, help='The input images root or video file path.')
+ parser.add_argument('-o', '--output_path', type=str, default=PM.outputs/'demos', help='The output root.')
+ parser.add_argument('-m', '--model_root', type=str, default=DEFAULT_HSMR_ROOT, help='The model root which contains `.hydra/config.yaml`.')
+ parser.add_argument('-d', '--device', type=str, default='cuda:0', help='The device.')
+ parser.add_argument('--det_bs', type=int, default=10, help='The max batch size for detector.')
+ parser.add_argument('--det_mis', type=int, default=512, help='The max image size for detector.')
+ parser.add_argument('--rec_bs', type=int, default=300, help='The batch size for recovery.')
+ parser.add_argument('--max_instances', type=int, default=5, help='Max instances activated in one image.')
+ parser.add_argument('--ignore_skel', action='store_true', help='Do not render skeleton to boost the rendering.')
+ args = parser.parse_args()
+ return args
+
+
+# ================== Data Process Tools ==================
+
+def load_inputs(args, MAX_IMG_W=1920, MAX_IMG_H=1080):
+ # 1. Inference inputs type.
+ inputs_path = Path(args.input_path)
+ if args.input_type != 'auto': inputs_type = args.input_type
+ else: inputs_type = 'video' if Path(args.input_path).is_file() else 'imgs'
+ get_logger(brief=True).info(f'🚚 Loading inputs from: {inputs_path}, regarded as <{inputs_type}>.')
+
+ # 2. Load inputs.
+ inputs_meta = {'type': inputs_type}
+ if inputs_type == 'video':
+ inputs_meta['seq_name'] = inputs_path.stem
+ frames, _ = load_video(inputs_path)
+ if frames.shape[1] > MAX_IMG_H:
+ frames = flex_resize_video(frames, (MAX_IMG_H, -1), kp_mod=4)
+ if frames.shape[2] > MAX_IMG_W:
+ frames = flex_resize_video(frames, (-1, MAX_IMG_W), kp_mod=4)
+ raw_imgs = [frame for frame in frames]
+ elif inputs_type == 'imgs':
+ img_fns = list(inputs_path.glob('*.*'))
+ img_fns = [fn for fn in img_fns if fn.suffix.lower() in ['.jpg', '.jpeg', '.png', '.webp']]
+ inputs_meta['seq_name'] = f'{inputs_path.stem}-img_cnt={len(img_fns)}'
+ raw_imgs = []
+ for fn in img_fns:
+ img, _ = load_img(fn)
+ if img.shape[0] > MAX_IMG_H:
+ img = flex_resize_img(img, (MAX_IMG_H, -1), kp_mod=4)
+ if img.shape[1] > MAX_IMG_W:
+ img = flex_resize_img(img, (-1, MAX_IMG_W), kp_mod=4)
+ raw_imgs.append(img)
+ inputs_meta['img_fns'] = img_fns
+ else:
+ raise ValueError(f'Unsupported inputs type: {inputs_type}.')
+ get_logger(brief=True).info(f'📦 Loaded {len(raw_imgs)} images in total.')
+
+ return raw_imgs, inputs_meta
+
+
+def imgs_det2patches(imgs, dets, downsample_ratios, max_instances_per_img):
+ ''' Given the raw images and the detection results, return the image patches of human instances. '''
+ assert len(imgs) == len(dets), f'L_img = {len(imgs)}, L_det = {len(dets)}'
+ patches, n_patch_per_img, bbx_cs = [], [], []
+ for i in tqdm(range(len(imgs))):
+ patches_i, bbx_cs_i = _img_det2patches(imgs[i], dets[i], downsample_ratios[i], max_instances_per_img)
+ n_patch_per_img.append(len(patches_i))
+ if len(patches_i) > 0:
+ patches.append(patches_i.astype(np.float32))
+ bbx_cs.append(bbx_cs_i)
+ else:
+ get_logger(brief=True).warning(f'No human detection results on image No.{i}.')
+ det_meta = {
+ 'n_patch_per_img' : n_patch_per_img,
+ 'bbx_cs' : bbx_cs,
+ }
+ return patches, det_meta
+
+
+def _img_det2patches(imgs, det_instances, downsample_ratio:float, max_instances:int=5):
+ '''
+ 1. Keep only the trusted human detections.
+ 2. Enlarge the bounding boxes to the target aspect ratio (the ViT backbone only uses 192*256 pixels,
+ so make sure these pixels capture the main contents) and then to squares (to adapt to the data module).
+ 3. Crop the image with the bounding boxes and resize the crops to 256x256.
+ 4. Normalize the cropped images.
+ '''
+ if det_instances is None: # no human detected
+ return to_numpy([]), to_numpy([])
+ CLASS_HUMAN_ID, DET_THRESHOLD_SCORE = 0, 0.5
+
+ # Keep only the trusted human detections.
+ is_human_mask = det_instances['pred_classes'] == CLASS_HUMAN_ID
+ reliable_mask = det_instances['scores'] > DET_THRESHOLD_SCORE
+ active_mask = is_human_mask & reliable_mask
+
+ # Keep only the top-k human instances.
+ if active_mask.sum().item() > max_instances:
+ humans_scores = det_instances['scores'] * is_human_mask.float()
+ _, top_idx = humans_scores.topk(max_instances)
+ valid_mask = torch.zeros_like(active_mask).bool()
+ valid_mask[top_idx] = True
+ else:
+ valid_mask = active_mask
+
+ # Process the bounding boxes and crop the images.
+ lurb_all = det_instances['pred_boxes'][valid_mask].numpy() / downsample_ratio # (N, 4)
+ lurb_all = [fit_bbox_to_aspect_ratio(bbox=lurb, tgt_ratio=(192, 256)) for lurb in lurb_all] # regularize the bbox size
+ cs_all = [lurb_to_cs(lurb) for lurb in lurb_all]
+ lurb_all = [cs_to_lurb(cs) for cs in cs_all]
+ cropped_imgs = [crop_with_lurb(imgs, lurb) for lurb in lurb_all]
+ patches = to_numpy([flex_resize_img(cropped_img, (256, 256)) for cropped_img in cropped_imgs]) # (N, 256, 256, 3)
+ return patches, cs_all
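+
+# A minimal sketch of how the returned patches are typically consumed downstream (normalizing with the
+# module-level IMG_MEAN_255 / IMG_STD_255 constants is an assumption about the recovery data module):
+#   patches, det_meta = imgs_det2patches(imgs, dets, downsample_ratios, max_instances_per_img)
+#   patches_normalized = [(p - IMG_MEAN_255) / IMG_STD_255 for p in patches]  # each (N_i, 256, 256, 3)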
+
+
+# ================== Secondary Outputs Tools ==================
+
+def prepare_mesh(pipeline, pd_params):
+ B = 720 # full SKEL inference is memory consuming
+ L = pd_params['poses'].shape[0]
+ n_rounds = (L + B - 1) // B
+ v_skin_all, v_skel_all = [], []
+ for rid in range(n_rounds):
+ sid, eid = rid * B, min((rid + 1) * B, L)
+ smpl_outputs = pipeline.skel_model(
+ poses = pd_params['poses'][sid:eid].to(pipeline.device),
+ betas = pd_params['betas'][sid:eid].to(pipeline.device),
+ )
+ v_skin = smpl_outputs.skin_verts.detach().cpu() # (B, Vi, 3)
+ v_skel = smpl_outputs.skel_verts.detach().cpu() # (B, Ve, 3)
+ v_skin_all.append(v_skin)
+ v_skel_all.append(v_skel)
+ v_skel_all = torch.cat(v_skel_all, dim=0)
+ v_skin_all = torch.cat(v_skin_all, dim=0)
+ m_skin = {'v': v_skin_all, 'f': pipeline.skel_model.skin_f}
+ m_skel = {'v': v_skel_all, 'f': pipeline.skel_model.skel_f}
+ return m_skin, m_skel
+
+
+# ================== Visualization Tools ==================
+
+
+def visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel):
+ ''' Render the recovered meshes onto the full images; returns one render dict per image. '''
+ bbx_cs, n_patches_per_img = det_meta['bbx_cs'], det_meta['n_patch_per_img']
+ bbx_cs = np.concatenate(bbx_cs, axis=0)
+
+ results = []
+ pp = 0 # patch pointer
+ for i in tqdm(range(len(raw_imgs)), desc='Rendering'):
+ raw_h, raw_w = raw_imgs[i].shape[:2]
+ raw_cx, raw_cy = raw_w/2, raw_h/2
+ spp, epp = pp, pp + n_patches_per_img[i]
+
+ # Rescale the camera translation.
+ raw_cam_t = pd_cam_t.clone().float()
+ bbx_s = to_tensor(bbx_cs[spp:epp, 2], device=raw_cam_t.device)
+ bbx_cx = to_tensor(bbx_cs[spp:epp, 0], device=raw_cam_t.device)
+ bbx_cy = to_tensor(bbx_cs[spp:epp, 1], device=raw_cam_t.device)
+
+ raw_cam_t[spp:epp, 2] = pd_cam_t[spp:epp, 2] * 256 / bbx_s
+ raw_cam_t[spp:epp, 1] += (bbx_cy - raw_cy) / 5000 * raw_cam_t[spp:epp, 2]
+ raw_cam_t[spp:epp, 0] += (bbx_cx - raw_cx) / 5000 * raw_cam_t[spp:epp, 2]
+ raw_cam_t = raw_cam_t[spp:epp]
+
+ # Render overlays on the full image.
+ full_img_bg = raw_imgs[i].copy()
+ render_results = {}
+ for view in ['front']:
+ full_img_skin = render_meshes_overlay_img(
+ faces_all = m_skin['f'],
+ verts_all = m_skin['v'][spp:epp].float(),
+ cam_t_all = raw_cam_t,
+ mesh_color = 'blue',
+ img = full_img_bg,
+ K4 = [5000, 5000, raw_cx, raw_cy],
+ view = view,
+ )
+
+ if m_skel is not None:
+ full_img_skel = render_meshes_overlay_img(
+ faces_all = m_skel['f'],
+ verts_all = m_skel['v'][spp:epp].float(),
+ cam_t_all = raw_cam_t,
+ mesh_color = 'human_yellow',
+ img = full_img_bg,
+ K4 = [5000, 5000, raw_cx, raw_cy],
+ view = view,
+ )
+
+ if m_skel is not None:
+ full_img_blend = cv2.addWeighted(full_img_skin, 0.7, full_img_skel, 0.3, 0)
+ render_results[f'{view}_blend'] = full_img_blend
+ if view == 'front':
+ render_results[f'{view}_skel'] = full_img_skel
+ else:
+ render_results[f'{view}_skin'] = full_img_skin
+
+ results.append(render_results)
+ pp = epp
+
+ return results
\ No newline at end of file
diff --git a/lib/modeling/callbacks/__init__.py b/lib/modeling/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb0b0f421b2ccc4cca7e257c07fd4f45290d8e27
--- /dev/null
+++ b/lib/modeling/callbacks/__init__.py
@@ -0,0 +1 @@
+from .skelify_spin import SKELifySPIN
\ No newline at end of file
diff --git a/lib/modeling/callbacks/skelify_spin.py b/lib/modeling/callbacks/skelify_spin.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce6198f4ce45210e307a5dfe23b995dad03cd237
--- /dev/null
+++ b/lib/modeling/callbacks/skelify_spin.py
@@ -0,0 +1,347 @@
+from lib.kits.basic import *
+
+from concurrent.futures import ThreadPoolExecutor
+from lightning_fabric.utilities.rank_zero import _get_rank
+
+from lib.data.augmentation.skel import rot_skel_on_plane
+from lib.utils.data import to_tensor, to_numpy
+from lib.utils.camera import perspective_projection, estimate_camera_trans
+from lib.utils.vis import Wis3D
+from lib.modeling.optim.skelify.skelify import SKELify
+from lib.body_models.common import make_SKEL
+
+DEBUG = False
+DEBUG_ROUND = False
+
+
+class SKELifySPIN(pl.Callback):
+ '''
+ Call SKELify to optimize the prediction results.
+ Several kinds of data are involved: gt, opgt, pd, (res), bpgt.
+
+ 1. `gt`: Static Ground Truth: loaded from the static training datasets. They might be real
+ ground truth (like 2D keypoints) or pseudo ground truth (like SKEL parameters), and are
+ expected to be gradually replaced by better pseudo ground truth over the iterations.
+ 2. `opgt`: Old Pseudo Ground Truth: the better ground truth among the static datasets
+ (i.e., `gt`) and the dynamic datasets (maintained in this callback); they serve as the
+ labels for training the network.
+ 3. `pd`: Predicted Results: the network outputs, which will be optimized later. After being
+ optimized, they are referred to as `res` (Results from optimization).
+ 4. `bpgt`: Better Pseudo Ground Truth: the optimized results, stored in an extra file and in
+ memory. They are the highest-quality data, picked among the static ground truth, the cached
+ better pseudo ground truth, and the predicted & optimized data.
+ '''
+
+ # TODO: For now, only kp2d is used to evaluate the performance (not all data provide kp3d).
+ # TODO: In the future, kp3d should also be considered, i.e., use it whenever it is available.
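+
+ # A minimal sketch of the dynamic pseudo-gt cache maintained by this callback (the per-sample
+ # shapes below are inferred from how the cache is read and written in the methods that follow):
+ #   better_pgt = {
+ #       'poses'     : {sample_uid: np.ndarray of shape (46,)},
+ #       'betas'     : {sample_uid: np.ndarray of shape (10,)},
+ #       'has_poses' : {sample_uid: 0. or 1.},
+ #       'has_betas' : {sample_uid: 0. or 1.},
+ #   }
+ #   where sample_uid = f'{seq_key}_flip' or f'{seq_key}_orig'.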
+
+ def __init__(
+ self,
+ cfg : DictConfig,
+ skelify : DictConfig,
+ **kwargs,
+ ):
+ super().__init__()
+ self.interval = cfg.interval
+ self.B = cfg.batch_size
+ self.kb_pr = cfg.get('max_batches_per_round', None) # latest k batches per round are SPINed
+ self.better_pgt_fn = Path(cfg.better_pgt_fn) # load it before training
+ self.skip_warm_up_steps = cfg.skip_warm_up_steps
+ self.update_better_pgt = cfg.update_better_pgt
+ self.skelify_cfg = skelify
+
+ # The threshold to determine if the result is valid. (In case some data
+ # don't have parameters at first but get updated with bad parameters.)
+ self.valid_betas_threshold = cfg.valid_betas_threshold
+
+ self._init_pd_dict()
+
+ self.better_pgt = None
+
+
+ def on_train_batch_start(self, trainer, pl_module, raw_batch, batch_idx):
+ # Lazy initialization for better pgt.
+ if self.better_pgt is None:
+ self._init_better_pgt()
+
+ # GPU_monitor.snapshot('GPU-Mem-Before-Train-Before-SPIN-Update')
+ device = pl_module.device
+ batch = raw_batch['img_ds']
+
+ if not self.update_better_pgt:
+ return
+
+ # 1. Compose the data from batches.
+ seq_key_list = batch['__key__']
+ batch_do_flip_list = batch['augm_args']['do_flip']
+ sample_uid_list = [
+ f'{seq_key}_flip' if do_flip else f'{seq_key}_orig'
+ for seq_key, do_flip in zip(seq_key_list, batch_do_flip_list)
+ ]
+
+ # 2. Update the labels from better_pgt.
+ for i, sample_uid in enumerate(sample_uid_list):
+ if sample_uid in self.better_pgt['poses'].keys():
+ batch['raw_skel_params']['poses'][i] = to_tensor(self.better_pgt['poses'][sample_uid], device=device) # (46,)
+ batch['raw_skel_params']['betas'][i] = to_tensor(self.better_pgt['betas'][sample_uid], device=device) # (10,)
+ batch['has_skel_params']['poses'][i] = self.better_pgt['has_poses'][sample_uid] # 0 or 1
+ batch['has_skel_params']['betas'][i] = self.better_pgt['has_betas'][sample_uid] # 0 or 1
+ batch['updated_by_spin'][i] = True # add information for inspection
+ # get_logger().trace(f'Update the pseudo-gt for {sample_uid}.')
+
+ # GPU_monitor.snapshot('GPU-Mem-Before-Train-After-SPIN-Update')
+
+
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+ # GPU_monitor.snapshot('GPU-Mem-After-Train-Before-SPIN-Update')
+
+ # Since the network predictions might be far from the ground truth before the model is well-trained,
+ # we can skip some steps to avoid meaningless optimization.
+ if trainer.global_step > self.skip_warm_up_steps or DEBUG_ROUND:
+ # Collect the prediction results.
+ self._save_pd(batch['img_ds'], outputs)
+
+ if self.interval > 0 and trainer.global_step % self.interval == 0 or DEBUG_ROUND:
+ torch.cuda.empty_cache()
+ with PM.time_monitor('SPIN'):
+ self._spin(trainer.logger, pl_module.device)
+ torch.cuda.empty_cache()
+
+ # GPU_monitor.snapshot('GPU-Mem-After-Train-After-SPIN-Update')
+ # GPU_monitor.report_latest(k=4)
+
+
+ def _init_pd_dict(self):
+ ''' Memory clean-up for each SPIN round; values are stored as numpy to save GPU memory. '''
+ self.cache = {
+ # Things to identify one sample.
+ 'seq_key_list' : [],
+ # Things for comparison.
+ 'opgt_poses_list' : [],
+ 'opgt_betas_list' : [],
+ 'opgt_has_poses_list' : [],
+ 'opgt_has_betas_list' : [],
+ # Things for optimization and self-iteration.
+ 'gt_kp2d_list' : [],
+ 'pd_poses_list' : [],
+ 'pd_betas_list' : [],
+ 'pd_cam_t_list' : [],
+ 'do_flip_list' : [],
+ 'rot_deg_list' : [],
+ 'do_extreme_crop_list': [], # if the extreme crop is applied, we don't update the pseudo-gt
+ # Things for visualization.
+ 'img_patch': [],
+ # No gt_cam_t_list.
+ }
+
+
+ def _format_pd(self):
+ ''' Format the cache to numpy. '''
+ if self.kb_pr is None:
+ last_k = len(self.cache['seq_key_list'])
+ else:
+ last_k = self.kb_pr * self.B # the latest k samples to be optimized.
+ self.cache['seq_key_list'] = to_numpy(self.cache['seq_key_list'])[-last_k:]
+ self.cache['opgt_poses_list'] = to_numpy(self.cache['opgt_poses_list'])[-last_k:]
+ self.cache['opgt_betas_list'] = to_numpy(self.cache['opgt_betas_list'])[-last_k:]
+ self.cache['opgt_has_poses_list'] = to_numpy(self.cache['opgt_has_poses_list'])[-last_k:]
+ self.cache['opgt_has_betas_list'] = to_numpy(self.cache['opgt_has_betas_list'])[-last_k:]
+ self.cache['gt_kp2d_list'] = to_numpy(self.cache['gt_kp2d_list'])[-last_k:]
+ self.cache['pd_poses_list'] = to_numpy(self.cache['pd_poses_list'])[-last_k:]
+ self.cache['pd_betas_list'] = to_numpy(self.cache['pd_betas_list'])[-last_k:]
+ self.cache['pd_cam_t_list'] = to_numpy(self.cache['pd_cam_t_list'])[-last_k:]
+ self.cache['do_flip_list'] = to_numpy(self.cache['do_flip_list'])[-last_k:]
+ self.cache['rot_deg_list'] = to_numpy(self.cache['rot_deg_list'])[-last_k:]
+ self.cache['do_extreme_crop_list'] = to_numpy(self.cache['do_extreme_crop_list'])[-last_k:]
+
+ if DEBUG:
+ self.cache['img_patch'] = to_numpy(self.cache['img_patch'])[-last_k:]
+
+
+ def _save_pd(self, batch, outputs):
+ ''' Save all the prediction results and related labels from the outputs. '''
+ B = len(batch['__key__'])
+
+ self.cache['seq_key_list'].extend(batch['__key__']) # (NS,)
+
+ self.cache['opgt_poses_list'].extend(to_numpy(batch['raw_skel_params']['poses'])) # (NS, 46)
+ self.cache['opgt_betas_list'].extend(to_numpy(batch['raw_skel_params']['betas'])) # (NS, 10)
+ self.cache['opgt_has_poses_list'].extend(to_numpy(batch['has_skel_params']['poses'])) # (NS,) 0 or 1
+ self.cache['opgt_has_betas_list'].extend(to_numpy(batch['has_skel_params']['betas'])) # (NS,) 0 or 1
+ self.cache['gt_kp2d_list'].extend(to_numpy(batch['kp2d'])) # (NS, 44, 3)
+
+ self.cache['pd_poses_list'].extend(to_numpy(outputs['pd_params']['poses']))
+ self.cache['pd_betas_list'].extend(to_numpy(outputs['pd_params']['betas']))
+ self.cache['pd_cam_t_list'].extend(to_numpy(outputs['pd_cam_t']))
+ self.cache['do_flip_list'].extend(to_numpy(batch['augm_args']['do_flip']))
+ self.cache['rot_deg_list'].extend(to_numpy(batch['augm_args']['rot_deg']))
+ self.cache['do_extreme_crop_list'].extend(to_numpy(batch['augm_args']['do_extreme_crop']))
+
+ if DEBUG:
+ img_patch = batch['img_patch'].clone().permute(0, 2, 3, 1) # (NS, 256, 256, 3)
+ mean = torch.tensor([0.485, 0.456, 0.406], device=img_patch.device).reshape(1, 1, 1, 3)
+ std = torch.tensor([0.229, 0.224, 0.225], device=img_patch.device).reshape(1, 1, 1, 3)
+ img_patch = 255 * (img_patch * std + mean)
+ self.cache['img_patch'].extend(to_numpy(img_patch).astype(np.uint8)) # (NS, 256, 256, 3)
+
+
+ def _init_better_pgt(self):
+ ''' DDP adaptable initialization. '''
+ self.rank = _get_rank()
+ get_logger().info(f'Initializing better pgt cache @ rank {self.rank}')
+
+ if self.rank is not None:
+ self.better_pgt_fn = Path(f'{self.better_pgt_fn}_r{self.rank}')
+ get_logger().info(f'Redirecting better pgt cache to {self.better_pgt_fn}')
+
+ if self.better_pgt_fn.exists():
+ better_pgt_z = np.load(self.better_pgt_fn, allow_pickle=True)
+ self.better_pgt = {k: better_pgt_z[k].item() for k in better_pgt_z.files}
+ else:
+ self.better_pgt = {'poses': {}, 'betas': {}, 'has_poses': {}, 'has_betas': {}}
+
+
+ def _spin(self, tb_logger, device):
+ skelify : SKELify = instantiate(self.skelify_cfg, tb_logger=tb_logger, device=device, _recursive_=False)
+ skel_model = skelify.skel_model
+
+ self._format_pd()
+
+ # 1. Make up the cache to run SKELify.
+ with PM.time_monitor('preparation'):
+ sample_uid_list = [
+ f'{seq_key}_flip' if do_flip else f'{seq_key}_orig'
+ for seq_key, do_flip in zip(self.cache['seq_key_list'], self.cache['do_flip_list'])
+ ]
+
+ all_gt_kp2d = self.cache['gt_kp2d_list'] # (NS, 44, 2)
+ all_init_poses = self.cache['pd_poses_list'] # (NS, 46)
+ all_init_betas = self.cache['pd_betas_list'] # (NS, 10)
+ all_init_cam_t = self.cache['pd_cam_t_list'] # (NS, 3)
+ all_do_extreme_crop = self.cache['do_extreme_crop_list'] # (NS,)
+ all_res_poses = []
+ all_res_betas = []
+ all_res_cam_t = []
+ all_res_kp2d_err = [] # the evaluation of the keypoints 2D error
+
+ # 2. Run SKELify optimization here to get better results.
+ with PM.time_monitor('SKELify') as tm:
+ get_logger().info(f'Start to run SKELify optimization. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
+ n_samples = len(self.cache['seq_key_list'])
+ n_round = (n_samples - 1) // self.B + 1
+ get_logger().info(f'Running SKELify optimization for {n_samples} samples in {n_round} rounds.')
+ for rid in range(n_round):
+ sid = rid * self.B
+ eid = min(sid + self.B, n_samples)
+
+ gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device)
+ init_poses = to_tensor(all_init_poses[sid:eid], device=device)
+ init_betas = to_tensor(all_init_betas[sid:eid], device=device)
+ init_cam_t = to_tensor(all_init_cam_t[sid:eid], device=device)
+
+ # Run the SKELify optimization.
+ outputs = skelify(
+ gt_kp2d = gt_kp2d_with_conf,
+ init_poses = init_poses,
+ init_betas = init_betas,
+ init_cam_t = init_cam_t,
+ img_patch = self.cache['img_patch'][sid:eid] if DEBUG else None,
+ )
+
+ # Store the results.
+ all_res_poses.extend(to_numpy(outputs['poses'])) # (~NS, 46)
+ all_res_betas.extend(to_numpy(outputs['betas'])) # (~NS, 10)
+ all_res_cam_t.extend(to_numpy(outputs['cam_t'])) # (~NS, 3)
+ all_res_kp2d_err.extend(to_numpy(outputs['kp2d_err'])) # (~NS,)
+
+ tm.tick(f'SKELify round {rid} finished.')
+
+ # 3. Initialize the uninitialized better pseudo-gt with old ground truth.
+ with PM.time_monitor('init_bpgt'):
+ get_logger().info(f'Initializing bpgt. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
+ for i in range(n_samples):
+ sample_uid = sample_uid_list[i]
+ if sample_uid not in self.better_pgt['poses'].keys():
+ self.better_pgt['poses'][sample_uid] = self.cache['opgt_poses_list'][i]
+ self.better_pgt['betas'][sample_uid] = self.cache['opgt_betas_list'][i]
+ self.better_pgt['has_poses'][sample_uid] = self.cache['opgt_has_poses_list'][i]
+ self.better_pgt['has_betas'][sample_uid] = self.cache['opgt_has_betas_list'][i]
+
+ # 4. Update the results.
+ with PM.time_monitor('upd_bpgt'):
+ upd_cnt = 0 # Count the number of updated samples.
+ get_logger().info(f'Update the results. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.')
+ for rid in range(n_round):
+ torch.cuda.empty_cache()
+ sid = rid * self.B
+ eid = min(sid + self.B, n_samples)
+
+ focal_length = np.ones(2) * 5000 / 256 # TODO: These data should be loaded from configuration files.
+ focal_length = focal_length.reshape(1, 2).repeat(eid - sid, axis=0) # (B, 2)
+ gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device) # (B, 44, 3)
+ rot_deg = to_tensor(self.cache['rot_deg_list'][sid:eid], device=device)
+
+ # 4.1. Prepare the better pseudo-gt and the results.
+ res_betas = to_tensor(all_res_betas[sid:eid], device=device) # (B, 10)
+ res_poses_after_augm = to_tensor(all_res_poses[sid:eid], device=device) # (B, 46)
+ res_poses_before_augm = rot_skel_on_plane(res_poses_after_augm, -rot_deg) # recover the augmentation rotation
+ res_kp2d_err = to_tensor(all_res_kp2d_err[sid:eid], device=device) # (B,)
+ cur_do_extreme_crop = all_do_extreme_crop[sid:eid]
+
+ # 4.2. Evaluate the quality of the existing better pseudo-gt.
+ uids = sample_uid_list[sid:eid] # [sid ~ eid] -> sample_uids
+ bpgt_betas = to_tensor([self.better_pgt['betas'][uid] for uid in uids], device=device)
+ bpgt_poses_before_augm = to_tensor([self.better_pgt['poses'][uid] for uid in uids], device=device)
+ bpgt_poses_after_augm = rot_skel_on_plane(bpgt_poses_before_augm.clone(), rot_deg) # recover the augmentation rotation
+
+ skel_outputs = skel_model(poses=bpgt_poses_after_augm, betas=bpgt_betas, skelmesh=False)
+ bpgt_kp3d = skel_outputs.joints.detach() # (B, 44, 3)
+ bpgt_est_cam_t = estimate_camera_trans(
+ S = bpgt_kp3d,
+ joints_2d = gt_kp2d_with_conf.clone(),
+ focal_length = 5000,
+ img_size = 256,
+ ) # estimate camera translation from inference 3D keypoints and GT 2D keypoints
+ bpgt_reproj_kp2d = perspective_projection(
+ points = to_tensor(bpgt_kp3d, device=device),
+ translation = to_tensor(bpgt_est_cam_t, device=device),
+ focal_length = to_tensor(focal_length, device=device),
+ )
+ bpgt_kp2d_err = SKELify.eval_kp2d_err(gt_kp2d_with_conf, bpgt_reproj_kp2d) # (B,)
+
+ valid_betas_mask = res_betas.abs().max(dim=-1)[0] < self.valid_betas_threshold # (B,)
+ better_mask = res_kp2d_err < bpgt_kp2d_err # (B,)
+ upd_mask = torch.logical_and(valid_betas_mask, better_mask) # (B,)
+ upd_ids = torch.arange(eid-sid, device=device)[upd_mask] # uids -> ids
+
+ # Update one by one.
+ for upd_id in upd_ids:
+ # `uid` refers to the dynamic dataset's unique id, `id` indexes the in-round batch data.
+ # Note: `id` starts from zero and already indexes the [sid ~ eid] slice directly, so
+ # both `all_res_poses[upd_id]` and `res_poses_before_augm[upd_id - sid]` would be wrong.
+ if cur_do_extreme_crop[upd_id]:
+ # Skip the extreme crop data.
+ continue
+ sample_uid = uids[upd_id]
+ self.better_pgt['poses'][sample_uid] = to_numpy(res_poses_before_augm[upd_id])
+ self.better_pgt['betas'][sample_uid] = to_numpy(res_betas[upd_id])
+ self.better_pgt['has_poses'][sample_uid] = 1. # If updated, then must have.
+ self.better_pgt['has_betas'][sample_uid] = 1. # If updated, then must have.
+ upd_cnt += 1
+
+ get_logger().info(f'Update {upd_cnt} samples among all {n_samples} samples.')
+
+ # 5. [Async] Save the results.
+ with PM.time_monitor('async_dumping'):
+ # TODO: Use lock and other techniques to achieve a better submission system.
+ # TODO: We need to design a better way to solve the synchronization problem.
+ if not hasattr(self, 'dump_executor'):
+ # Keep a persistent single-worker executor so the dump actually runs asynchronously
+ # (a `with ThreadPoolExecutor()` block would block on exit until the task finishes).
+ self.dump_executor = ThreadPoolExecutor(max_workers=1)
+ if hasattr(self, 'dump_thread'):
+ self.dump_thread.result() # Wait for the previous dump to finish.
+ self.dump_thread = self.dump_executor.submit(lambda: np.savez(self.better_pgt_fn, **self.better_pgt))
+
+ # 6. Clean up the memory.
+ del skelify, skel_model
+ self._init_pd_dict()
\ No newline at end of file
diff --git a/lib/modeling/losses/__init__.py b/lib/modeling/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c791dce25e46daa4f728a28e76ac048cdf3f5945
--- /dev/null
+++ b/lib/modeling/losses/__init__.py
@@ -0,0 +1,2 @@
+from .prior import *
+from .losses import *
\ No newline at end of file
diff --git a/lib/modeling/losses/kp.py b/lib/modeling/losses/kp.py
new file mode 100644
index 0000000000000000000000000000000000000000..70a938f90c3d46d6a08ceec3c002f2f5ac66e70c
--- /dev/null
+++ b/lib/modeling/losses/kp.py
@@ -0,0 +1,15 @@
+from lib.kits.basic import *
+
+
+def compute_kp3d_loss(gt_kp3d, pd_kp3d, ref_jid=25+14):
+ conf = gt_kp3d[:, :, 3:].clone() # (B, 44, 1)
+ gt_kp3d_a = gt_kp3d[:, :, :3] - gt_kp3d[:, [ref_jid], :3] # aligned, (B, J=44, 3)
+ pd_kp3d_a = pd_kp3d[:, :, :3] - pd_kp3d[:, [ref_jid], :3] # aligned, (B, J=44, 3)
+ kp3d_loss = conf * F.l1_loss(pd_kp3d_a, gt_kp3d_a, reduction='none') # (B, J=44, 3)
+ return kp3d_loss.sum() # (,)
+
+
+def compute_kp2d_loss(gt_kp2d, pd_kp2d):
+ conf = gt_kp2d[:, :, 2:].clone() # (B, 44, 1)
+ kp2d_loss = conf * F.l1_loss(pd_kp2d, gt_kp2d[:, :, :2], reduction='none') # (B, 44, 2)
+ return kp2d_loss.sum() # (,)
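+
+# A minimal usage sketch (batch size and values are arbitrary; shapes follow the comments above):
+#   gt_kp3d = torch.rand(8, 44, 4)  # xyz + confidence
+#   pd_kp3d = torch.rand(8, 44, 3)
+#   loss_3d = compute_kp3d_loss(gt_kp3d, pd_kp3d)  # scalar
+#   gt_kp2d = torch.rand(8, 44, 3)  # xy + confidence
+#   pd_kp2d = torch.rand(8, 44, 2)
+#   loss_2d = compute_kp2d_loss(gt_kp2d, pd_kp2d)  # scalar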
\ No newline at end of file
diff --git a/lib/modeling/losses/losses.py b/lib/modeling/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..87d2e30453da23e1c172c878444bc6415224e56e
--- /dev/null
+++ b/lib/modeling/losses/losses.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+
+from smplx.body_models import SMPLOutput
+from lib.body_models.skel.skel_model import SKELOutput
+
+
+class Keypoint2DLoss(nn.Module):
+
+ def __init__(self, loss_type: str = 'l1'):
+ """
+ 2D keypoint loss module.
+ Args:
+ loss_type (str): Choose between l1 and l2 losses.
+ """
+ super(Keypoint2DLoss, self).__init__()
+ if loss_type == 'l1':
+ self.loss_fn = nn.L1Loss(reduction='none')
+ elif loss_type == 'l2':
+ self.loss_fn = nn.MSELoss(reduction='none')
+ else:
+ raise NotImplementedError('Unsupported loss function')
+
+ def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
+ """
+ Compute 2D reprojection loss on the keypoints.
+ Args:
+ pred_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 2] containing projected 2D keypoints (B: batch_size, S: num_samples, N: num_keypoints)
+ gt_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the ground truth 2D keypoints and confidence.
+ Returns:
+ torch.Tensor: 2D keypoint loss.
+ """
+ conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
+ batch_size = conf.shape[0]
+ loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1,2))
+ return loss.sum()
+
+
+class Keypoint3DLoss(nn.Module):
+
+ def __init__(self, loss_type: str = 'l1'):
+ """
+ 3D keypoint loss module.
+ Args:
+ loss_type (str): Choose between l1 and l2 losses.
+ """
+ super(Keypoint3DLoss, self).__init__()
+ if loss_type == 'l1':
+ self.loss_fn = nn.L1Loss(reduction='none')
+ elif loss_type == 'l2':
+ self.loss_fn = nn.MSELoss(reduction='none')
+ else:
+ raise NotImplementedError('Unsupported loss function')
+
+ def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 39):
+ """
+ Compute 3D keypoint loss.
+ Args:
+ pred_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the predicted 3D keypoints (B: batch_size, S: num_samples, N: num_keypoints)
+ gt_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 4] containing the ground truth 3D keypoints and confidence.
+ Returns:
+ torch.Tensor: 3D keypoint loss.
+ """
+ batch_size = pred_keypoints_3d.shape[0]
+ gt_keypoints_3d = gt_keypoints_3d.clone()
+ pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1)
+ gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1)
+ conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
+ gt_keypoints_3d = gt_keypoints_3d[:, :, :-1]
+ loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1,2))
+ return loss.sum()
+
+
+class ParameterLoss(nn.Module):
+
+ def __init__(self):
+ """
+ SMPL parameter loss module.
+ """
+ super(ParameterLoss, self).__init__()
+ self.loss_fn = nn.MSELoss(reduction='none')
+
+ def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor):
+ """
+ Compute SMPL parameter loss.
+ Args:
+ pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas)
+ gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth SMPL parameters.
+ Returns:
+ torch.Tensor: L2 parameter loss.
+ """
+ batch_size = pred_param.shape[0]
+ num_dims = len(pred_param.shape)
+ mask_dimension = [batch_size] + [1] * (num_dims-1)
+ has_param = has_param.type(pred_param.type()).view(*mask_dimension)
+ loss_param = (has_param * self.loss_fn(pred_param, gt_param))
+ return loss_param.sum()
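+
+# A minimal usage sketch (dummy shapes; `has_param` masks out samples without annotations):
+#   param_loss = ParameterLoss()
+#   pd_betas, gt_betas = torch.rand(8, 10), torch.rand(8, 10)
+#   has_betas = torch.ones(8)
+#   loss = param_loss(pd_betas, gt_betas, has_betas)  # scalar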
diff --git a/lib/modeling/losses/prior.py b/lib/modeling/losses/prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..2269c756369296282d72b143644f2e713567bfc3
--- /dev/null
+++ b/lib/modeling/losses/prior.py
@@ -0,0 +1,108 @@
+from lib.kits.basic import *
+
+# ============= #
+# Utils #
+# ============= #
+
+
+def soft_bound_loss(x, low, up):
+ '''
+ Softly penalize the violation of the lower and upper bounds.
+ PROBLEM: for joints like the knees, the normal pose sits near the boundary (a standing person tends
+ to have zero rotation, but the limit is zero-bounded, so this loss encourages the leg to bend a bit).
+
+ ### Args:
+ - x: torch.tensor, (B, Q), where Q is the number of components.
+ - low: torch.tensor, (Q,)
+ - Lower bound.
+ - up: torch.tensor, (Q,)
+ - Upper bound.
+
+ ### Returns:
+ - loss: torch.tensor, (B, Q)
+ '''
+ B = len(x)
+ loss = torch.exp(low[None] - x).pow(2) + torch.exp(x - up[None]).pow(2) # (B, Q)
+ return loss # (B, Q)
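+
+# Illustrative behaviour of soft_bound_loss (per-component values, approximate, with low = 0, up = 1):
+#   x = 0.5 -> exp(-0.5)^2 + exp(-0.5)^2 ≈ 0.74  (inside the bounds, small but non-zero)
+#   x = 0.0 -> exp(0.0)^2  + exp(-1.0)^2 ≈ 1.14  (exactly on the lower bound, already penalized,
+#                                                 which is the PROBLEM mentioned in the docstring)
+#   x = -1. -> exp(1.0)^2  + exp(-2.0)^2 ≈ 7.41  (violation, the penalty grows exponentially)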
+
+
+def softer_bound_loss(x, low, up):
+ '''
+ Softly penalize the violation of the lower and upper bounds. This loss doesn't penalize so hard when the
+ value exceeds the bound by a small margin (half of `up - low`), and it's friendlier to cases where the common
+ pose is not centered in the middle of the bound. (E.g., the knee rotation is most likely to be at zero
+ when someone is standing straight, but zero is the lower bound.)
+
+ ### Args:
+ - x: torch.tensor, (B, Q), where Q is the number of components.
+ - low: torch.tensor, (Q,)
+ - Lower bound.
+ - up: torch.tensor, (Q,)
+ - Upper bound.
+
+ ### Returns:
+ - loss: torch.tensor, (B, Q)
+ '''
+ B = len(x)
+ expand = (up - low) / 2
+ loss = torch.exp((low[None] - expand) - x).pow(2) + torch.exp(x - (up[None] + expand)).pow(2) # (B, Q)
+ return loss # (B, Q)
+
+
+def softest_bound_loss(x, low, up):
+ '''
+ Softly penalize the violation of the lower and upper bounds. This loss doesn't penalize so hard when the
+ value exceeds the bound by a small margin (half of `up - low`), and it's friendlier to cases where the common
+ pose is not centered in the middle of the bound. (E.g., the knee rotation is most likely to be at zero
+ when someone is standing straight, but zero is the lower bound.) Compared to `softer_bound_loss`,
+ the penalty inside the (expanded) bounds is clamped to zero here.
+
+ ### Args:
+ - x: torch.tensor, (B, Q), where Q is the number of components.
+ - low: torch.tensor, (Q,)
+ - Lower bound.
+ - up: torch.tensor, (Q,)
+ - Upper bound.
+
+ ### Returns:
+ - loss: torch.tensor, (B, Q)
+ '''
+ B = len(x)
+ expand = (up - low) / 2
+ lower_loss = torch.exp((low[None] - expand) - x).pow(2) - 1 # (B, Q)
+ upper_loss = torch.exp(x - (up[None] + expand)).pow(2) - 1 # (B, Q)
+ lower_loss = torch.where(lower_loss < 0, 0, lower_loss)
+ upper_loss = torch.where(upper_loss < 0, 0, upper_loss)
+ loss = lower_loss + upper_loss
+ return loss # (B, Q)
+
+
+# ============= #
+# Loss #
+# ============= #
+
+
+def compute_poses_angle_prior_loss(poses):
+ '''
+ Some components have upper and lower bound, use exponential loss to apply soft limitation.
+
+ ### Args
+ - poses: torch.tensor, (B, 46)
+
+ ### Returns
+ - loss: torch.tensor, (B, len(SKEL_LIM_QIDS))
+ '''
+ from lib.body_models.skel_utils.limits import SKEL_LIM_QIDS, SKEL_LIM_BOUNDS
+
+ device = poses.device
+ # loss = softer_bound_loss(
+ # loss = softest_bound_loss(
+ loss = soft_bound_loss(
+ x = poses[:, SKEL_LIM_QIDS],
+ low = SKEL_LIM_BOUNDS[:, 0].to(device),
+ up = SKEL_LIM_BOUNDS[:, 1].to(device),
+ ) # (B, len(SKEL_LIM_QIDS))
+
+ return loss
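+
+# A minimal usage sketch (assuming `poses` are SKEL parameters of shape (B, 46)):
+#   poses = torch.zeros(4, 46)
+#   prior = compute_poses_angle_prior_loss(poses)  # (4, len(SKEL_LIM_QIDS)) soft-bound penalties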
diff --git a/lib/modeling/networks/backbones/README.md b/lib/modeling/networks/backbones/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3f6b3321c7688f91ec95ac970af227207a94602d
--- /dev/null
+++ b/lib/modeling/networks/backbones/README.md
@@ -0,0 +1,5 @@
+# ViT Backbone
+
+The implementation of ViT backbone comes from [ViTPose](https://github.com/ViTAE-Transformer/ViTPose/blob/d5216452796c90c6bc29f5c5ec0bdba94366768a/mmpose/models/backbones/vit.py#L103).
+
+We also used **the backbone part** of their released checkpoints.
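+
+For reference, here is a minimal sketch of how the backbone weights could be extracted from a ViTPose checkpoint. The checkpoint layout (a `state_dict` entry with `backbone.`-prefixed keys) is an assumption based on the mmpose format, and the file name is hypothetical:
+
+```python
+import torch
+
+ckpt = torch.load('vitpose_checkpoint.pth', map_location='cpu')
+state_dict = ckpt.get('state_dict', ckpt)
+# Keep only the backbone weights and strip the `backbone.` prefix so the keys match `ViT.state_dict()`.
+backbone_sd = {k[len('backbone.'):]: v for k, v in state_dict.items() if k.startswith('backbone.')}
+```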
\ No newline at end of file
diff --git a/lib/modeling/networks/backbones/__init__.py b/lib/modeling/networks/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e63a8408d15b5af56b2a02ce10a1f1185448160
--- /dev/null
+++ b/lib/modeling/networks/backbones/__init__.py
@@ -0,0 +1 @@
+from .vit import ViT
\ No newline at end of file
diff --git a/lib/modeling/networks/backbones/vit.py b/lib/modeling/networks/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..85569f06a8fa10c6aff6b25356aa57a1283124cc
--- /dev/null
+++ b/lib/modeling/networks/backbones/vit.py
@@ -0,0 +1,335 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+
+def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
+ """
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+ dimension for the original embeddings.
+ Args:
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+ hw (Tuple): size of input image tokens.
+
+ Returns:
+ Absolute positional embeddings after processing with shape (1, H, W, C)
+ """
+ cls_token = None
+ B, L, C = abs_pos.shape
+ if has_cls_token:
+ cls_token = abs_pos[:, 0:1]
+ abs_pos = abs_pos[:, 1:]
+
+ if ori_h != h or ori_w != w:
+ new_abs_pos = F.interpolate(
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
+ size=(h, w),
+ mode="bicubic",
+ align_corners=False,
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
+
+ else:
+ new_abs_pos = abs_pos
+
+ if cls_token is not None:
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
+ return new_abs_pos
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self):
+ return 'p={}'.format(self.drop_prob)
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None,):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.dim = dim
+
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm, attn_head_dim=None
+ ):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
+ )
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ x = self.proj(x)
+ Hp, Wp = x.shape[2], x.shape[3]
+
+ x = x.flatten(2).transpose(1, 2)
+ return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+ """ CNN Feature Map Embedding
+ Extract feature map from CNN, flatten, project to embedding dim.
+ """
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+ super().__init__()
+ assert isinstance(backbone, nn.Module)
+ img_size = to_2tuple(img_size)
+ self.img_size = img_size
+ self.backbone = backbone
+ if feature_size is None:
+ with torch.no_grad():
+ training = backbone.training
+ if training:
+ backbone.eval()
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+ feature_size = o.shape[-2:]
+ feature_dim = o.shape[1]
+ backbone.train(training)
+ else:
+ feature_size = to_2tuple(feature_size)
+ feature_dim = self.backbone.feature_info.channels()[-1]
+ self.num_patches = feature_size[0] * feature_size[1]
+ self.proj = nn.Linear(feature_dim, embed_dim)
+
+ def forward(self, x):
+ x = self.backbone(x)[-1]
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class ViT(nn.Module):
+
+ def __init__(self,
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
+ frozen_stages=-1, ratio=1, last_norm=True,
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
+ ):
+ # Protect mutable default arguments
+ super(ViT, self).__init__()
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.frozen_stages = frozen_stages
+ self.use_checkpoint = use_checkpoint
+ self.patch_padding = patch_padding
+ self.freeze_attn = freeze_attn
+ self.freeze_ffn = freeze_ffn
+ self.depth = depth
+
+ if hybrid_backbone is not None:
+ self.patch_embed = HybridEmbed(
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
+ num_patches = self.patch_embed.num_patches
+
+ # since the pretraining model has class token
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ )
+ for i in range(depth)])
+
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = self.blocks[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ if self.freeze_attn:
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.attn.eval()
+ m.norm1.eval()
+ for param in m.attn.parameters():
+ param.requires_grad = False
+ for param in m.norm1.parameters():
+ param.requires_grad = False
+
+ if self.freeze_ffn:
+ self.pos_embed.requires_grad = False
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.mlp.eval()
+ m.norm2.eval()
+ for param in m.mlp.parameters():
+ param.requires_grad = False
+ for param in m.norm2.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ self.apply(_init_weights)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward_features(self, x):
+ B, C, H, W = x.shape
+ x, (Hp, Wp) = self.patch_embed(x)
+
+ if self.pos_embed is not None:
+ # fit for multiple GPU training
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ x = self.last_norm(x)
+
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
+
+ return xp
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
diff --git a/lib/modeling/networks/discriminators/__init__.py b/lib/modeling/networks/discriminators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d11b6d255ce76003c6d0f18ac2753f976be35679
--- /dev/null
+++ b/lib/modeling/networks/discriminators/__init__.py
@@ -0,0 +1 @@
+from .hsmr_disc import *
\ No newline at end of file
diff --git a/lib/modeling/networks/discriminators/hsmr_disc.py b/lib/modeling/networks/discriminators/hsmr_disc.py
new file mode 100644
index 0000000000000000000000000000000000000000..84b4649ae0052e4ed2e2057c6c971491f822ad6b
--- /dev/null
+++ b/lib/modeling/networks/discriminators/hsmr_disc.py
@@ -0,0 +1,95 @@
+from lib.kits.basic import *
+
+
+
+class HSMRDiscriminator(nn.Module):
+
+ def __init__(self):
+ '''
+ Pose + Shape discriminator proposed in HMR
+ '''
+ super(HSMRDiscriminator, self).__init__()
+
+ self.n_poses = 23
+ # poses_alone
+ self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1)
+ nn.init.xavier_uniform_(self.D_conv1.weight)
+ nn.init.zeros_(self.D_conv1.bias)
+ self.relu = nn.ReLU(inplace=True)
+ self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1)
+ nn.init.xavier_uniform_(self.D_conv2.weight)
+ nn.init.zeros_(self.D_conv2.bias)
+ pose_out = []
+ for i in range(self.n_poses):
+ pose_out_temp = nn.Linear(32, 1)
+ nn.init.xavier_uniform_(pose_out_temp.weight)
+ nn.init.zeros_(pose_out_temp.bias)
+ pose_out.append(pose_out_temp)
+ self.pose_out = nn.ModuleList(pose_out)
+
+ # betas
+ self.betas_fc1 = nn.Linear(10, 10)
+ nn.init.xavier_uniform_(self.betas_fc1.weight)
+ nn.init.zeros_(self.betas_fc1.bias)
+ self.betas_fc2 = nn.Linear(10, 5)
+ nn.init.xavier_uniform_(self.betas_fc2.weight)
+ nn.init.zeros_(self.betas_fc2.bias)
+ self.betas_out = nn.Linear(5, 1)
+ nn.init.xavier_uniform_(self.betas_out.weight)
+ nn.init.zeros_(self.betas_out.bias)
+
+ # poses_joint
+ self.D_alljoints_fc1 = nn.Linear(32*self.n_poses, 1024)
+ nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
+ nn.init.zeros_(self.D_alljoints_fc1.bias)
+ self.D_alljoints_fc2 = nn.Linear(1024, 1024)
+ nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
+ nn.init.zeros_(self.D_alljoints_fc2.bias)
+ self.D_alljoints_out = nn.Linear(1024, 1)
+ nn.init.xavier_uniform_(self.D_alljoints_out.weight)
+ nn.init.zeros_(self.D_alljoints_out.bias)
+
+
+ def forward(self, poses_body: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
+ '''
+ Forward pass of the discriminator.
+ ### Args
+ - poses_body: torch.Tensor, shape (B, 23, 9)
+ - Matrix representation of the SKEL poses excluding the global orientation.
+ - betas: torch.Tensor, shape (B, 10)
+ ### Returns
+ - torch.Tensor, shape (B, 25)
+ '''
+ poses_body = poses_body.reshape(-1, self.n_poses, 1, 9) # (B, n_poses, 1, 9)
+ B = poses_body.shape[0]
+ poses_body = poses_body.permute(0, 3, 1, 2).contiguous() # (B, 9, n_poses, 1)
+
+ # poses_alone
+ poses_body = self.D_conv1(poses_body)
+ poses_body = self.relu(poses_body)
+ poses_body = self.D_conv2(poses_body)
+ poses_body = self.relu(poses_body)
+
+ poses_out = []
+ for i in range(self.n_poses):
+ poses_out_i = self.pose_out[i](poses_body[:, :, i, 0])
+ poses_out.append(poses_out_i)
+ poses_out = torch.cat(poses_out, dim=1)
+
+ # betas
+ betas = self.betas_fc1(betas)
+ betas = self.relu(betas)
+ betas = self.betas_fc2(betas)
+ betas = self.relu(betas)
+ betas_out = self.betas_out(betas)
+
+ # poses_joint
+ poses_body = poses_body.reshape(B, -1)
+ poses_all = self.D_alljoints_fc1(poses_body)
+ poses_all = self.relu(poses_all)
+ poses_all = self.D_alljoints_fc2(poses_all)
+ poses_all = self.relu(poses_all)
+ poses_all_out = self.D_alljoints_out(poses_all)
+
+ disc_out = torch.cat((poses_out, betas_out, poses_all_out), dim=1)
+ return disc_out
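+
+
+# A minimal usage sketch (shapes follow the forward() docstring above):
+#   disc = HSMRDiscriminator()
+#   poses_body = torch.rand(8, 23, 9)  # rotation-matrix representation, w/o global orientation
+#   betas = torch.rand(8, 10)
+#   out = disc(poses_body, betas)  # (8, 25) = 23 per-joint scores + 1 betas score + 1 all-joints score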
diff --git a/lib/modeling/networks/heads/SKEL_mean.npz b/lib/modeling/networks/heads/SKEL_mean.npz
new file mode 100644
index 0000000000000000000000000000000000000000..fb888dc2112a35aa94b17f3edbff75651848866f
Binary files /dev/null and b/lib/modeling/networks/heads/SKEL_mean.npz differ
diff --git a/lib/modeling/networks/heads/__init__.py b/lib/modeling/networks/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5751ba0de3cee06070292cd00ff149dcb240bcf5
--- /dev/null
+++ b/lib/modeling/networks/heads/__init__.py
@@ -0,0 +1 @@
+from .skel_head import SKELTransformerDecoderHead
\ No newline at end of file
diff --git a/lib/modeling/networks/heads/skel_head.py b/lib/modeling/networks/heads/skel_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32c5cb76f1fbbaddf8b504f16a112f2ed33a9ce
--- /dev/null
+++ b/lib/modeling/networks/heads/skel_head.py
@@ -0,0 +1,107 @@
+from lib.kits.basic import *
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import einops
+
+from omegaconf import OmegaConf
+
+from lib.platform import PM
+from lib.body_models.skel_utils.transforms import params_q2rep, params_rep2q
+
+from .utils.pose_transformer import TransformerDecoder
+
+
+class SKELTransformerDecoderHead(nn.Module):
+ """ Cross-attention based SKEL Transformer decoder
+ """
+
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+
+ if cfg.pd_poses_repr == 'rotation_6d':
+ n_poses = 24 * 6
+ elif cfg.pd_poses_repr == 'euler_angle':
+ n_poses = 46
+ else:
+ raise ValueError(f"Unknown pose representation: {cfg.pd_poses_repr}")
+
+ n_betas = 10
+ n_cam = 3
+ self.input_is_mean_shape = False
+
+ # Build transformer decoder.
+ transformer_args = {
+ 'num_tokens' : 1,
+ 'token_dim' : (n_poses + n_betas + n_cam) if self.input_is_mean_shape else 1,
+ 'dim' : 1024,
+ }
+ transformer_args.update(OmegaConf.to_container(cfg.transformer_decoder, resolve=True)) # type: ignore
+ self.transformer = TransformerDecoder(**transformer_args)
+
+ # Build decoders for parameters.
+ dim = transformer_args['dim']
+ self.poses_decoder = nn.Linear(dim, n_poses)
+ self.betas_decoder = nn.Linear(dim, n_betas)
+ self.cam_decoder = nn.Linear(dim, n_cam)
+
+ # Load mean shape parameters as initial values.
+ skel_mean_path = Path(__file__).parent / 'SKEL_mean.npz'
+ skel_mean_params = np.load(skel_mean_path)
+
+ init_poses = torch.from_numpy(skel_mean_params['poses'].astype(np.float32)).unsqueeze(0) # (1, 46)
+ if cfg.pd_poses_repr == 'rotation_6d':
+ init_poses = params_q2rep(init_poses).reshape(1, 24*6) # (1, 24*6)
+ init_betas = torch.from_numpy(skel_mean_params['betas'].astype(np.float32)).unsqueeze(0)
+ init_cam = torch.from_numpy(skel_mean_params['cam'].astype(np.float32)).unsqueeze(0)
+
+ self.register_buffer('init_poses', init_poses)
+ self.register_buffer('init_betas', init_betas)
+ self.register_buffer('init_cam', init_cam)
+
+ def forward(self, x, **kwargs):
+ B = x.shape[0]
+ # vit pretrained backbone is channel-first. Change to token-first
+ x = einops.rearrange(x, 'b c h w -> b (h w) c')
+
+ # Initialize the parameters.
+ init_poses = self.init_poses.expand(B, -1) # (B, 46)
+ init_betas = self.init_betas.expand(B, -1) # (B, 10)
+ init_cam = self.init_cam.expand(B, -1) # (B, 3)
+
+ # Input token to transformer is zero token.
+ with PM.time_monitor('init_token'):
+ if self.input_is_mean_shape:
+ token = torch.cat([init_poses, init_betas, init_cam], dim=1)[:, None, :] # (B, 1, C)
+ else:
+ token = x.new_zeros(B, 1, 1)
+
+ # Pass through transformer.
+ with PM.time_monitor('transformer'):
+ token_out = self.transformer(token, context=x)
+ token_out = token_out.squeeze(1) # (B, C)
+
+ # Parse the SKEL parameters out from token_out.
+ with PM.time_monitor('decode'):
+ pd_poses = self.poses_decoder(token_out) + init_poses
+ pd_betas = self.betas_decoder(token_out) + init_betas
+ pd_cam = self.cam_decoder(token_out) + init_cam
+
+ with PM.time_monitor('rot_repr_transform'):
+ if self.cfg.pd_poses_repr == 'rotation_6d':
+ pd_poses = params_rep2q(pd_poses.reshape(-1, 24, 6)) # (B, 46)
+ elif self.cfg.pd_poses_repr == 'euler_angle':
+ pd_poses = pd_poses.reshape(-1, 46) # (B, 46)
+ else:
+ raise ValueError(f"Unknown pose representation: {self.cfg.pd_poses_repr}")
+
+ pd_skel_params = {
+ 'poses' : pd_poses,
+ 'poses_orient' : pd_poses[:, :3],
+ 'poses_body' : pd_poses[:, 3:],
+ 'betas' : pd_betas
+ }
+ return pd_skel_params, pd_cam
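+
+
+# A minimal usage sketch (the feature-map size below assumes a ViT backbone on 256x192 crops with
+# 1280-dim features; the exact sizes depend on the configuration passed as `cfg`):
+#   head = SKELTransformerDecoderHead(cfg)
+#   feats = torch.rand(2, 1280, 16, 12)  # (B, C, H, W) token grid from the backbone
+#   pd_skel_params, pd_cam = head(feats)  # poses (2, 46), betas (2, 10), cam (2, 3)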
\ No newline at end of file
diff --git a/lib/modeling/networks/heads/utils/__init__.py b/lib/modeling/networks/heads/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/modeling/networks/heads/utils/geometry.py b/lib/modeling/networks/heads/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..86648e9f22866a33b261880bb3953fddd93926a8
--- /dev/null
+++ b/lib/modeling/networks/heads/utils/geometry.py
@@ -0,0 +1,63 @@
+from typing import Optional
+import torch
+from torch.nn import functional as F
+
+
+def aa_to_rotmat(theta: torch.Tensor):
+ """
+ Convert axis-angle representation to rotation matrix.
+ Works by first converting it to a quaternion.
+ Args:
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat_to_rotmat(quat)
+
+
+def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
+ """
+ Convert quaternion representation to rotation matrix.
+ Args:
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Args:
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
+ """
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.linalg.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
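+
+
+# A minimal round-trip sketch for the conventions used above:
+#   aa = torch.tensor([[0., 0., 1.5708]])              # ~90 degrees about the z axis
+#   R = aa_to_rotmat(aa)                               # (1, 3, 3)
+#   x6d = torch.cat([R[:, :, 0], R[:, :, 1]], dim=-1)  # first two columns -> 6D representation
+#   R_rec = rot6d_to_rotmat(x6d)                       # recovers R up to numerical error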
\ No newline at end of file
diff --git a/lib/modeling/networks/heads/utils/pose_transformer.py b/lib/modeling/networks/heads/utils/pose_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac04971407cb59637490cc4842f048b9bc4758be
--- /dev/null
+++ b/lib/modeling/networks/heads/utils/pose_transformer.py
@@ -0,0 +1,358 @@
+from inspect import isfunction
+from typing import Callable, Optional
+
+import torch
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import nn
+
+from .t_cond_mlp import (
+ AdaptiveLayerNorm1D,
+ FrequencyEmbedder,
+ normalization_layer,
+)
+# from .vit import Attention, FeedForward
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+class PreNorm(nn.Module):
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
+ super().__init__()
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
+ self.fn = fn
+
+ def forward(self, x: torch.Tensor, *args, **kwargs):
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
+ return self.fn(self.norm(x, *args), **kwargs)
+ else:
+ return self.fn(self.norm(x), **kwargs)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout=0.0):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.dropout = nn.Dropout(dropout)
+
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x):
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+ attn = self.dropout(attn)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.dropout = nn.Dropout(dropout)
+
+ context_dim = default(context_dim, dim)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x, context=None):
+ context = default(context, x)
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+ q = self.to_q(x)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+ attn = self.dropout(attn)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ mlp_dim: int,
+ dropout: float = 0.0,
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor, *args):
+ for attn, ff in self.layers:
+ x = attn(x, *args) + x
+ x = ff(x, *args) + x
+ return x
+
+
+class TransformerCrossAttn(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ mlp_dim: int,
+ dropout: float = 0.0,
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ context_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+ ca = CrossAttention(
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
+ )
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
+ if context_list is None:
+ context_list = [context] * len(self.layers)
+ if len(context_list) != len(self.layers):
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
+
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
+ x = self_attn(x, *args) + x
+ x = cross_attn(x, *args, context=context_list[i]) + x
+ x = ff(x, *args) + x
+ return x
+
+
+class DropTokenDropout(nn.Module):
+ def __init__(self, p: float = 0.1):
+ super().__init__()
+ if p < 0 or p > 1:
+ raise ValueError(
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
+ )
+ self.p = p
+
+ def forward(self, x: torch.Tensor):
+ # x: (batch_size, seq_len, dim)
+ if self.training and self.p > 0:
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
+ # TODO: permutation idx for each batch using torch.argsort
+ if zero_mask.any():
+ x = x[:, ~zero_mask, :]
+ return x
+
+
+class ZeroTokenDropout(nn.Module):
+ def __init__(self, p: float = 0.1):
+ super().__init__()
+ if p < 0 or p > 1:
+ raise ValueError(
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
+ )
+ self.p = p
+
+ def forward(self, x: torch.Tensor):
+ # x: (batch_size, seq_len, dim)
+ if self.training and self.p > 0:
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
+ # Zero-out the masked tokens
+ x[zero_mask, :] = 0
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ num_tokens: int,
+ token_dim: int,
+ dim: int,
+ depth: int,
+ heads: int,
+ mlp_dim: int,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ emb_dropout: float = 0.0,
+ emb_dropout_type: str = "drop",
+ emb_dropout_loc: str = "token",
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ token_pe_numfreq: int = -1,
+ ):
+ super().__init__()
+ if token_pe_numfreq > 0:
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
+ self.to_token_embedding = nn.Sequential(
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
+ nn.Linear(token_dim_new, dim),
+ )
+ else:
+ self.to_token_embedding = nn.Linear(token_dim, dim)
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
+ if emb_dropout_type == "drop":
+ self.dropout = DropTokenDropout(emb_dropout)
+ elif emb_dropout_type == "zero":
+ self.dropout = ZeroTokenDropout(emb_dropout)
+ else:
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
+ self.emb_dropout_loc = emb_dropout_loc
+
+ self.transformer = Transformer(
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
+ )
+
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
+ x = inp
+
+ if self.emb_dropout_loc == "input":
+ x = self.dropout(x)
+ x = self.to_token_embedding(x)
+
+ if self.emb_dropout_loc == "token":
+ x = self.dropout(x)
+ b, n, _ = x.shape
+ x += self.pos_embedding[:, :n]
+
+ if self.emb_dropout_loc == "token_afterpos":
+ x = self.dropout(x)
+ x = self.transformer(x, *args)
+ return x
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ num_tokens: int,
+ token_dim: int,
+ dim: int,
+ depth: int,
+ heads: int,
+ mlp_dim: int,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ emb_dropout: float = 0.0,
+ emb_dropout_type: str = 'drop',
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ context_dim: Optional[int] = None,
+ skip_token_embedding: bool = False,
+ ):
+ super().__init__()
+ if not skip_token_embedding:
+ self.to_token_embedding = nn.Linear(token_dim, dim)
+ else:
+ self.to_token_embedding = nn.Identity()
+ if token_dim != dim:
+ raise ValueError(
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
+ )
+
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
+ if emb_dropout_type == "drop":
+ self.dropout = DropTokenDropout(emb_dropout)
+ elif emb_dropout_type == "zero":
+ self.dropout = ZeroTokenDropout(emb_dropout)
+ elif emb_dropout_type == "normal":
+ self.dropout = nn.Dropout(emb_dropout)
+
+ self.transformer = TransformerCrossAttn(
+ dim,
+ depth,
+ heads,
+ dim_head,
+ mlp_dim,
+ dropout,
+ norm=norm,
+ norm_cond_dim=norm_cond_dim,
+ context_dim=context_dim,
+ )
+
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
+ x = self.to_token_embedding(inp)
+ b, n, _ = x.shape
+
+ x = self.dropout(x)
+ x += self.pos_embedding[:, :n]
+
+ x = self.transformer(x, *args, context=context, context_list=context_list)
+ return x
+
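+
+if __name__ == "__main__":
+    # Illustrative shape-check sketch; the dimensions below are made up for the
+    # example and are not the values used by the actual head configs.
+    dec = TransformerDecoder(
+        num_tokens=1, token_dim=1, dim=64, depth=2, heads=4, mlp_dim=128,
+        dim_head=16, context_dim=32,
+    )
+    tokens = torch.zeros(2, 1, 1)       # (B, num_tokens, token_dim) query tokens
+    context = torch.randn(2, 197, 32)   # (B, context_len, context_dim), e.g. backbone features
+    out = dec(tokens, context=context)  # cross-attends the query tokens to the context
+    assert out.shape == (2, 1, 64)      # (B, num_tokens, dim)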
diff --git a/lib/modeling/networks/heads/utils/t_cond_mlp.py b/lib/modeling/networks/heads/utils/t_cond_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d5a09bf54f67712a69953039b7b5af41c3f029
--- /dev/null
+++ b/lib/modeling/networks/heads/utils/t_cond_mlp.py
@@ -0,0 +1,199 @@
+import copy
+from typing import List, Optional
+
+import torch
+
+
+class AdaptiveLayerNorm1D(torch.nn.Module):
+ def __init__(self, data_dim: int, norm_cond_dim: int):
+ super().__init__()
+ if data_dim <= 0:
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
+ if norm_cond_dim <= 0:
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
+ self.norm = torch.nn.LayerNorm(
+ data_dim
+ ) # TODO: Check if elementwise_affine=True is correct
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
+ torch.nn.init.zeros_(self.linear.weight)
+ torch.nn.init.zeros_(self.linear.bias)
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ # x: (batch, ..., data_dim)
+ # t: (batch, norm_cond_dim)
+        # return: (batch, ..., data_dim)
+ x = self.norm(x)
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
+
+ # Add singleton dimensions to alpha and beta
+ if x.dim() > 2:
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
+
+ return x * (1 + alpha) + beta
+
+
+class SequentialCond(torch.nn.Sequential):
+ def forward(self, input, *args, **kwargs):
+ for module in self:
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
+ # print(f'Passing on args to {module}', [a.shape for a in args])
+ input = module(input, *args, **kwargs)
+ else:
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
+ input = module(input)
+ return input
+
+
+def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
+ if norm == "batch":
+ return torch.nn.BatchNorm1d(dim)
+ elif norm == "layer":
+ return torch.nn.LayerNorm(dim)
+ elif norm == "ada":
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
+ elif norm is None:
+ return torch.nn.Identity()
+ else:
+ raise ValueError(f"Unknown norm: {norm}")
+
+
+def linear_norm_activ_dropout(
+ input_dim: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+) -> SequentialCond:
+ layers = []
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
+ if norm is not None:
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
+ layers.append(copy.deepcopy(activation))
+ if dropout > 0.0:
+ layers.append(torch.nn.Dropout(dropout))
+ return SequentialCond(*layers)
+
+
+def create_simple_mlp(
+ input_dim: int,
+ hidden_dims: List[int],
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+) -> SequentialCond:
+ layers = []
+ prev_dim = input_dim
+ for hidden_dim in hidden_dims:
+ layers.extend(
+ linear_norm_activ_dropout(
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ )
+ )
+ prev_dim = hidden_dim
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
+ return SequentialCond(*layers)
+
+
+class ResidualMLPBlock(torch.nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ num_hidden_layers: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ if not (input_dim == output_dim == hidden_dim):
+ raise NotImplementedError(
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
+ )
+
+ layers = []
+ prev_dim = input_dim
+ for i in range(num_hidden_layers):
+ layers.append(
+ linear_norm_activ_dropout(
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ )
+ )
+ prev_dim = hidden_dim
+ self.model = SequentialCond(*layers)
+ self.skip = torch.nn.Identity()
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return x + self.model(x, *args, **kwargs)
+
+
+class ResidualMLP(torch.nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ num_hidden_layers: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ num_blocks: int = 1,
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.model = SequentialCond(
+ linear_norm_activ_dropout(
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ ),
+ *[
+ ResidualMLPBlock(
+ hidden_dim,
+ hidden_dim,
+ num_hidden_layers,
+ hidden_dim,
+ activation,
+ bias,
+ norm,
+ dropout,
+ norm_cond_dim,
+ )
+ for _ in range(num_blocks)
+ ],
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
+ )
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return self.model(x, *args, **kwargs)
+
+
+class FrequencyEmbedder(torch.nn.Module):
+ def __init__(self, num_frequencies, max_freq_log2):
+ super().__init__()
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
+ self.register_buffer("frequencies", frequencies)
+
+ def forward(self, x):
+ # x should be of size (N,) or (N, D)
+ N = x.size(0)
+ if x.dim() == 1: # (N,)
+ x = x.unsqueeze(1) # (N, D) where D=1
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
+ s = torch.sin(scaled)
+ c = torch.cos(scaled)
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
+ N, -1
+ ) # (N, D * 2 * num_frequencies + D)
+ return embedded
+
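+
+if __name__ == "__main__":
+    # Illustrative shape checks; all dimensions here are arbitrary sketch values.
+    emb = FrequencyEmbedder(num_frequencies=4, max_freq_log2=3)
+    y = emb(torch.randn(8, 2))                 # (N, D * (2 * num_frequencies + 1))
+    assert y.shape == (8, 2 * (2 * 4 + 1))
+
+    # With norm="ada", the AdaptiveLayerNorm1D layers consume an extra conditioning
+    # tensor t of width norm_cond_dim (FiLM-style scale-and-shift modulation).
+    mlp = ResidualMLP(16, 32, 2, 16, norm="ada", num_blocks=1, norm_cond_dim=6)
+    out = mlp(torch.randn(8, 16), torch.randn(8, 6))
+    assert out.shape == (8, 16)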
diff --git a/lib/modeling/optim/__init__.py b/lib/modeling/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8b9f9d290513239bc3f7c764d460e38b5dccb8f
--- /dev/null
+++ b/lib/modeling/optim/__init__.py
@@ -0,0 +1,2 @@
+from .skelify.skelify import SKELify
+from .skelify_refiner import SKELifyRefiner
\ No newline at end of file
diff --git a/lib/modeling/optim/skelify/closure.py b/lib/modeling/optim/skelify/closure.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bc733b82bd32c97abcf4bc780dee551edccb23
--- /dev/null
+++ b/lib/modeling/optim/skelify/closure.py
@@ -0,0 +1,202 @@
+from lib.kits.basic import *
+
+from lib.utils.camera import perspective_projection
+from lib.modeling.losses import compute_poses_angle_prior_loss
+
+from .utils import (
+ gmof,
+ guess_cam_z,
+ estimate_kp2d_scale,
+ get_kp_active_jids,
+ get_params_active_q_masks,
+)
+
+def build_closure(
+ self,
+ cfg,
+ optimizer,
+ inputs,
+ focal_length : float,
+ gt_kp2d,
+ log_data,
+):
+ B = len(gt_kp2d)
+
+ act_parts = instantiate(cfg.parts)
+ act_q_masks = None
+ if not (act_parts == 'all' or 'all' in act_parts):
+ act_q_masks = get_params_active_q_masks(act_parts)
+
+ # Shortcuts for the inference of the skeleton model.
+ def inference_skel(inputs):
+ poses_active = torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1) # (B, 46)
+ if act_q_masks is not None:
+ poses_hidden = poses_active.clone().detach() # (B, 46)
+ poses = poses_active * act_q_masks + poses_hidden * (1 - act_q_masks) # (B, 46)
+ else:
+ poses = poses_active
+ skel_params = {
+ 'poses' : poses, # (B, 46)
+ 'betas' : inputs['betas'], # (B, 10)
+ }
+ skel_output = self.skel_model(**skel_params, skelmesh=False)
+ return skel_params, skel_output
+
+ # Estimate the camera depth as an initialization if depth loss is enabled.
+ gs_cam_z = None
+ if 'w_depth' in cfg.losses:
+ with torch.no_grad():
+ _, skel_output = inference_skel(inputs)
+ gs_cam_z = guess_cam_z(
+ pd_kp3d = skel_output.joints,
+ gt_kp2d = gt_kp2d,
+ focal_length = focal_length,
+ )
+
+ # Prepare the focal length for the perspective projection.
+ focal_length_xy = np.ones((B, 2)) * focal_length # (B, 2)
+
+
+ def closure():
+ optimizer.zero_grad()
+
+ # 📦 Data preparation.
+ with PM.time_monitor('SKEL-forward'):
+ skel_params, skel_output = inference_skel(inputs)
+
+ with PM.time_monitor('reproj'):
+ pd_kp2d = perspective_projection(
+ points = to_tensor(skel_output.joints, device=self.device),
+ translation = to_tensor(inputs['cam_t'], device=self.device),
+ focal_length = to_tensor(focal_length_xy, device=self.device),
+ )
+
+ with PM.time_monitor('compute_losses'):
+ loss, losses = compute_losses(
+ # Loss configuration.
+ loss_cfg = instantiate(cfg.losses),
+ parts = act_parts,
+ # Data inputs.
+ gt_kp2d = gt_kp2d,
+ pd_kp2d = pd_kp2d,
+ pd_params = skel_params,
+ pd_cam_z = inputs['cam_t'][:, 2],
+ gs_cam_z = gs_cam_z,
+ )
+
+ with PM.time_monitor('visualization'):
+ VISUALIZE = True
+ if VISUALIZE:
+                # For visualizing the optimization process.
+ kp2d_err = torch.sum((pd_kp2d - gt_kp2d[..., :2]) ** 2, dim=-1) * gt_kp2d[..., 2] # (B, J)
+ kp2d_err = kp2d_err.sum(dim=-1) / (torch.sum(gt_kp2d[..., 2], dim=-1) + 1e-6) # (B,)
+
+ # Store logging data.
+ if self.tb_logger is not None:
+ log_data.update({
+ 'losses' : losses,
+ 'pd_kp2d' : pd_kp2d[:self.n_samples].detach().clone(),
+ 'pd_verts' : skel_output.skin_verts[:self.n_samples].detach().clone(),
+ 'cam_t' : inputs['cam_t'][:self.n_samples].detach().clone(),
+ 'optim_betas' : inputs['betas'][:self.n_samples].detach().clone(),
+ 'kp2d_err' : kp2d_err[:self.n_samples].detach().clone(),
+ })
+
+ with PM.time_monitor('backwards'):
+ loss.backward()
+ return loss.item()
+
+ return closure
+
+
+def compute_losses(
+ loss_cfg : Dict[str, Union[bool, float]],
+ parts : List[str],
+ gt_kp2d : torch.Tensor,
+ pd_kp2d : torch.Tensor,
+ pd_params : Dict[str, torch.Tensor],
+ pd_cam_z : torch.Tensor,
+ gs_cam_z : Optional[torch.Tensor] = None,
+):
+ '''
+ ### Args
+ - loss_cfg: Dict[str, Union[bool, float]]
+ - Special option flags (`f_xxx`) or loss weights (`w_xxx`).
+ - parts: List[str]
+ - The list of the active joint parts groups.
+ - Among {'all', 'torso', 'torso-lite', 'limbs', 'head', 'limbs_proximal', 'limbs_distal'}.
+ - gt_kp2d: torch.Tensor (B, 44, 3)
+ - The ground-truth 2D keypoints with confidence.
+ - pd_kp2d: torch.Tensor (B, 44, 2)
+ - The predicted 2D keypoints.
+ - pd_params: Dict[str, torch.Tensor]
+ - poses: torch.Tensor (B, 46)
+ - betas: torch.Tensor (B, 10)
+ - pd_cam_z: torch.Tensor (B,)
+ - The predicted camera depth translation.
+ - gs_cam_z: Optional[torch.Tensor] (B,)
+ - The guessed camera depth translation.
+
+ ### Returns
+ - loss: torch.Tensor (,)
+ - The weighted loss value for optimization.
+ - losses: Dict[str, float]
+ - The dictionary of the loss values for logging.
+ '''
+
+ losses = {}
+ loss = torch.tensor(0.0, device=gt_kp2d.device)
+ kp2d_conf = gt_kp2d[:, :, 2] # (B, J)
+ gt_kp2d = gt_kp2d[:, :, :2] # (B, J, 2)
+
+ # Special option flags.
+ f_normalize_kp2d = loss_cfg.get('f_normalize_kp2d', False)
+
+ if f_normalize_kp2d:
+ scale2mean = loss_cfg.get('f_normalize_kp2d_to_mean', False)
+ scale2d = estimate_kp2d_scale(gt_kp2d) # (B,)
+ pd_kp2d = pd_kp2d / (scale2d[:, None, None] + 1e-6) # (B, J, 2)
+ gt_kp2d = gt_kp2d / (scale2d[:, None, None] + 1e-6) # (B, J, 2)
+ if scale2mean:
+ scale2d_mean = scale2d.mean()
+ pd_kp2d = pd_kp2d * scale2d_mean # (B, J, 2)
+ gt_kp2d = gt_kp2d * scale2d_mean # (B, J, 2)
+
+ # Mask the keypoints.
+ act_jids = get_kp_active_jids(parts)
+ kp2d_conf = kp2d_conf[:, act_jids] # (B, J)
+ gt_kp2d = gt_kp2d[:, act_jids, :] # (B, J, 2)
+ pd_kp2d = pd_kp2d[:, act_jids, :] # (B, J, 2)
+
+ # Calculate weighted losses.
+ w_depth = loss_cfg.get('w_depth', None)
+ w_reprojection = loss_cfg.get('w_reprojection', None)
+ w_shape_prior = loss_cfg.get('w_shape_prior', None)
+ w_angle_prior = loss_cfg.get('w_angle_prior', None)
+
+ if w_depth:
+ assert gs_cam_z is not None, 'The guessed camera depth is required for the depth loss.'
+ depth_loss = (gs_cam_z - pd_cam_z).pow(2) # (B,)
+ loss += (w_depth ** 2) * depth_loss.mean() # (,)
+ losses['depth'] = (w_depth ** 2) * depth_loss.mean().item() # float
+
+ if w_reprojection:
+ reproj_err_j = gmof(pd_kp2d - gt_kp2d).sum(dim=-1) # (B, J)
+ reproj_err_j = kp2d_conf.pow(2) * reproj_err_j # (B, J)
+ reproj_loss = reproj_err_j.sum(-1) # (B,)
+ loss += (w_reprojection ** 2) * reproj_loss.mean() # (,)
+ losses['reprojection'] = (w_reprojection ** 2) * reproj_loss.mean().item() # float
+
+ if w_shape_prior:
+ shape_prior_loss = pd_params['betas'].pow(2).sum(dim=-1) # (B,)
+ loss += (w_shape_prior ** 2) * shape_prior_loss.mean() # (,)
+ losses['shape_prior'] = (w_shape_prior ** 2) * shape_prior_loss.mean().item() # float
+
+ if w_angle_prior:
+ w_angle_prior *= loss_cfg.get('w_angle_prior_scale', 1.0)
+ angle_prior_loss = compute_poses_angle_prior_loss(pd_params['poses']) # (B,)
+ loss += (w_angle_prior ** 2) * angle_prior_loss.mean() # (,)
+ losses['angle_prior'] = (w_angle_prior ** 2) * angle_prior_loss.mean().item() # float
+
+ losses['weighted_sum'] = loss.item() # float
+ return loss, losses
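+
+
+# A sketch of what `loss_cfg` may look like (the keys mirror the `f_xxx` / `w_xxx` lookups
+# above; the numeric weights are placeholders, the real values live in the Hydra configs):
+#   {'f_normalize_kp2d': True, 'w_reprojection': 1.0, 'w_shape_prior': 5.0,
+#    'w_angle_prior': 15.2, 'w_depth': 100.0}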
\ No newline at end of file
diff --git a/lib/modeling/optim/skelify/skelify.py b/lib/modeling/optim/skelify/skelify.py
new file mode 100644
index 0000000000000000000000000000000000000000..499d9c5b34e68dc37f378ab96591737d08b05199
--- /dev/null
+++ b/lib/modeling/optim/skelify/skelify.py
@@ -0,0 +1,296 @@
+from lib.kits.basic import *
+
+import cv2
+import traceback
+from tqdm import tqdm
+
+from lib.body_models.common import make_SKEL
+from lib.body_models.abstract_skeletons import Skeleton_OpenPose25
+from lib.utils.vis import render_mesh_overlay_img
+from lib.utils.data import to_tensor
+from lib.utils.media import draw_kp2d_on_img, annotate_img, splice_img
+from lib.utils.camera import perspective_projection
+
+from .utils import (
+ compute_rel_change,
+ gmof,
+)
+
+from .closure import build_closure
+
+class SKELify():
+
+ def __init__(self, cfg, tb_logger=None, device='cuda:0', name='SKELify'):
+ self.cfg = cfg
+ self.name = name
+ self.eq_thre = cfg.early_quit_thresholds
+
+ self.tb_logger = tb_logger
+
+ self.device = device
+ # self.skel_model = make_SKEL(device=device)
+ self.skel_model = instantiate(cfg.skel_model).to(device)
+
+ # Shortcuts.
+ self.n_samples = cfg.logger.samples_per_record
+
+ # Dirty implementation for visualization.
+ self.render_frames = []
+
+
+ def __call__(
+ self,
+ gt_kp2d : Union[torch.Tensor, np.ndarray],
+ init_poses : Union[torch.Tensor, np.ndarray],
+ init_betas : Union[torch.Tensor, np.ndarray],
+ init_cam_t : Union[torch.Tensor, np.ndarray],
+ img_patch : Optional[np.ndarray] = None,
+ **kwargs
+ ):
+ '''
+ Use optimization to fit the SKEL parameters to the 2D keypoints.
+
+ ### Args:
+ - gt_kp2d: torch.Tensor or np.ndarray, (B, J, 3)
+            - The last dimension contains [x, y, conf].
+            - The 2D keypoints to fit; they are defined in the zero-centered [-0.5, 0.5] space.
+ - init_poses: torch.Tensor or np.ndarray, (B, 46)
+ - init_betas: torch.Tensor or np.ndarray, (B, 10)
+ - init_cam_t: torch.Tensor or np.ndarray, (B, 3)
+ - img_patch: np.ndarray or None, (B, H, W, 3)
+ - The image patch for visualization. H, W are defined in normalized bounding box space.
+ - If it is None, the visualization will simply use a black background.
+
+ ### Returns:
+ - dict, containing the optimized parameters.
+ - poses: torch.Tensor, (B, 46)
+ - betas: torch.Tensor, (B, 10)
+ - cam_t: torch.Tensor, (B, 3)
+ '''
+ with PM.time_monitor('input preparation'):
+ gt_kp2d = to_tensor(gt_kp2d, device=self.device).detach().float().clone() # (B, J, 3)
+ init_poses = to_tensor(init_poses, device=self.device).detach().float().clone() # (B, 46)
+ init_betas = to_tensor(init_betas, device=self.device).detach().float().clone() # (B, 10)
+ init_cam_t = to_tensor(init_cam_t, device=self.device).detach().float().clone() # (B, 3)
+ inputs = {
+ 'poses_orient' : init_poses[:, :3], # (B, 3)
+ 'poses_body' : init_poses[:, 3:], # (B, 43)
+ 'betas' : init_betas, # (B, 10)
+ 'cam_t' : init_cam_t, # (B, 3)
+ }
+ focal_length = float(self.cfg.focal_length / self.cfg.img_patch_size) # float
+
+ # ⛩️ Optimization phases, controlled by config file.
+ with PM.time_monitor('optim') as tm:
+            prev_steps = 0  # accumulate the steps that are *supposed* to have been done in the previous phases
+ n_phases = len(self.cfg.phases)
+ for phase_id, phase_name in enumerate(self.cfg.phases):
+ phase_cfg = self.cfg.phases[phase_name]
+ # 📦 Data preparation.
+ optim_params = []
+ for k in inputs.keys():
+ if k in phase_cfg.params_keys:
+ inputs[k].requires_grad = True
+ optim_params.append(inputs[k]) # (B, D)
+ else:
+ inputs[k].requires_grad = False
+ log_data = {}
+ tm.tick(f'Data preparation')
+
+ # ⚙️ Optimization preparation.
+ optimizer = instantiate(phase_cfg.optimizer, optim_params, _recursive_=True)
+ closure = self._build_closure(
+ cfg=phase_cfg, optimizer=optimizer, # basic
+ inputs=inputs, focal_length=focal_length, gt_kp2d=gt_kp2d, # data reference
+ log_data=log_data, # monitoring
+ )
+                tm.tick(f'Optimizer & closure prepared.')
+
+ # 🚀 Optimization loop.
+ with tqdm(range(phase_cfg.max_loop)) as bar:
+ prev_loss = None
+ bar.set_description(f'[{phase_name}] Loss: ???')
+ for i in bar:
+ # 1. Main part of the optimization loop.
+ log_data.clear()
+ curr_loss = optimizer.step(closure)
+
+ # 2. Log.
+ if self.tb_logger is not None:
+ log_data.update({
+ 'img_patch' : img_patch[:self.n_samples] if img_patch is not None else None,
+ 'gt_kp2d' : gt_kp2d[:self.n_samples].detach().clone(),
+ })
+ self._tb_log(prev_steps + i, phase_name, log_data)
+
+ # 3. The end of one optimization loop.
+ bar.set_description(f'[{phase_id+1}/{n_phases}] @ {phase_name} - Loss: {curr_loss:.4f}')
+ if self._can_early_quit(optim_params, prev_loss, curr_loss):
+ break
+ prev_loss = curr_loss
+
+ prev_steps += phase_cfg.max_loop
+ tm.tick(f'{phase_name} finished.')
+
+ with PM.time_monitor('last infer'):
+ poses = torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1).detach().clone() # (B, 46)
+ betas = inputs['betas'].detach().clone() # (B, 10)
+ cam_t = inputs['cam_t'].detach().clone() # (B, 3)
+ skel_outputs = self.skel_model(poses=poses, betas=betas, skelmesh=False) # (B, 44, 3)
+ optim_kp3d = skel_outputs.joints # (B, 44, 3)
+ # Evaluate the confidence of the results.
+ focal_length_xy = np.ones((len(poses), 2)) * focal_length # (B, 2)
+ optim_kp2d = perspective_projection(
+ points = optim_kp3d,
+ translation = cam_t,
+ focal_length = to_tensor(focal_length_xy, device=self.device),
+ )
+ kp2d_err = SKELify.eval_kp2d_err(gt_kp2d, optim_kp2d) # (B,)
+
+ # ⛩️ Prepare the output data.
+ outputs = {
+ 'poses' : poses, # (B, 46)
+ 'betas' : betas, # (B, 10)
+ 'cam_t' : cam_t, # (B, 3)
+ 'kp2d_err' : kp2d_err, # (B,)
+ }
+ return outputs
+
+
+ def _can_early_quit(self, opt_params, prev_loss, curr_loss):
+        ''' Decide whether to quit the optimization early. Returns True to quit, False otherwise. '''
+ if self.cfg.early_quit_thresholds is None:
+ # Never early quit.
+ return False
+
+ # Relative change test.
+ if prev_loss is not None:
+ loss_rel_change = compute_rel_change(prev_loss, curr_loss)
+ if loss_rel_change < self.cfg.early_quit_thresholds.rel:
+ get_logger().info(f'Early quit due to relative change: {loss_rel_change} = rel({prev_loss}, {curr_loss})')
+ return True
+
+        # Absolute gradient test: quit once all parameter gradients are below the threshold.
+ if all([
+ torch.abs(param.grad.max()).item() < self.cfg.early_quit_thresholds.abs
+ for param in opt_params if param.grad is not None
+ ]):
+ get_logger().info(f'Early quit due to absolute change.')
+ return True
+
+ return False
+
+
+ def _build_closure(self, *args, **kwargs):
+        # Delegate to an external builder to hide the details and keep this class compact.
+ return build_closure(self, *args, **kwargs)
+
+
+ @staticmethod
+ def eval_kp2d_err(gt_kp2d_with_conf:torch.Tensor, pd_kp2d:torch.Tensor):
+ ''' Evaluate the mean 2D keypoints L2 error. The formula is: ∑(gt - pd)^2 * conf / ∑conf. '''
+        assert len(gt_kp2d_with_conf.shape) == len(pd_kp2d.shape), f'gt_kp2d_with_conf.shape={gt_kp2d_with_conf.shape}, pd_kp2d.shape={pd_kp2d.shape} but they should both be ((B,) J, D).'
+ if len(gt_kp2d_with_conf.shape) == 2:
+ gt_kp2d_with_conf, pd_kp2d = gt_kp2d_with_conf[None], pd_kp2d[None]
+ assert len(gt_kp2d_with_conf.shape) == 3, f'gt_kp2d_with_conf.shape={gt_kp2d_with_conf.shape}, pd_kp2d.shape={pd_kp2d.shape} but they should both be ((B,) J, D).'
+ B, J, _ = gt_kp2d_with_conf.shape
+ assert gt_kp2d_with_conf.shape == (B, J, 3), f'gt_kp2d_with_conf.shape={gt_kp2d_with_conf.shape} but it should be ((B,) J, 3).'
+ assert pd_kp2d.shape == (B, J, 2), f'pd_kp2d.shape={pd_kp2d.shape} but it should be ((B,) J, 2).'
+
+ conf = gt_kp2d_with_conf[..., 2] # (B, J)
+ gt_kp2d = gt_kp2d_with_conf[..., :2] # (B, J, 2)
+ kp2d_err = torch.sum((gt_kp2d - pd_kp2d) ** 2, dim=-1) * conf # (B, J)
+ kp2d_err = kp2d_err.sum(dim=-1) / (torch.sum(conf, dim=-1) + 1e-6) # (B,)
+ return kp2d_err
+
+
+ @rank_zero_only
+ def _tb_log(self, step_cnt:int, phase_name:str, log_data:Dict, *args, **kwargs):
+ ''' Write the logging information to the TensorBoard. '''
+ if step_cnt != 0 and (step_cnt + 1) % self.cfg.logger.interval_skelify != 0:
+ return
+
+ summary_writer = self.tb_logger.experiment
+
+ # Save losses.
+ for loss_name, loss_val in log_data['losses'].items():
+ summary_writer.add_scalar(f'skelify/{loss_name}', loss_val, step_cnt)
+
+ # Visualization of the optimization process. TODO: Maybe we can make this more elegant.
+ if log_data['img_patch'] is None:
+ log_data['img_patch'] = [np.zeros((self.cfg.img_patch_size, self.cfg.img_patch_size, 3), dtype=np.uint8)] \
+ * len(log_data['gt_kp2d'])
+
+ if len(self.render_frames) < 1:
+ self.init_v = log_data['pd_verts']
+ self.init_kp2d_err = log_data['kp2d_err']
+ self.init_ct = log_data['cam_t']
+
+ # Overlay the skin mesh of the results on the original image.
+ try:
+ imgs_spliced = []
+ for i, img_patch in enumerate(log_data['img_patch']):
+ kp2d_err = log_data['kp2d_err'][i].item()
+
+ img_with_init = render_mesh_overlay_img(
+ faces = self.skel_model.skin_f,
+ verts = self.init_v[i],
+ K4 = [self.cfg.focal_length, self.cfg.focal_length, 128, 128],
+ img = img_patch,
+ Rt = [torch.eye(3), self.init_ct[i]],
+ mesh_color = 'pink',
+ )
+ img_with_init = annotate_img(img_with_init, 'init')
+ img_with_init = annotate_img(img_with_init, f'Quality: {self.init_kp2d_err[i].item()*1000:.3f}/1e3', pos='tl')
+
+ img_with_mesh = render_mesh_overlay_img(
+ faces = self.skel_model.skin_f,
+ verts = log_data['pd_verts'][i],
+ K4 = [self.cfg.focal_length, self.cfg.focal_length, 128, 128],
+ img = img_patch,
+ Rt = [torch.eye(3), log_data['cam_t'][i]],
+ mesh_color = 'pink',
+ )
+ betas_max = log_data['optim_betas'][i].abs().max().item()
+ img_patch_raw = annotate_img(img_patch, 'raw')
+
+ log_data['gt_kp2d'][i][..., :2] = (log_data['gt_kp2d'][i][..., :2] + 0.5) * self.cfg.img_patch_size
+ img_with_gt = annotate_img(img_patch, 'gt_kp2d')
+ img_with_gt = draw_kp2d_on_img(
+ img_with_gt,
+ log_data['gt_kp2d'][i],
+ Skeleton_OpenPose25.bones,
+ Skeleton_OpenPose25.bone_colors,
+ )
+
+ log_data['pd_kp2d'][i] = (log_data['pd_kp2d'][i] + 0.5) * self.cfg.img_patch_size
+ img_with_pd = cv2.addWeighted(img_with_mesh, 0.7, img_patch, 0.3, 0)
+ img_with_pd = draw_kp2d_on_img(
+ img_with_pd,
+ log_data['pd_kp2d'][i],
+ Skeleton_OpenPose25.bones,
+ Skeleton_OpenPose25.bone_colors,
+ )
+
+ img_with_pd = annotate_img(img_with_pd, 'pd')
+ img_with_pd = annotate_img(img_with_pd, f'Quality: {kp2d_err*1000:.3f}/1e3\nbetas_max: {betas_max:.3f}', pos='tl')
+ img_with_mesh = annotate_img(img_with_mesh, f'Quality: {kp2d_err*1000:.3f}/1e3\nbetas_max: {betas_max:.3f}', pos='tl')
+ img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh')
+
+ img_spliced = splice_img(
+ img_grids = [img_patch_raw, img_with_gt, img_with_pd, img_with_mesh, img_with_init],
+ # grid_ids = [[0, 1, 2, 3, 4]],
+ grid_ids = [[1, 2, 3, 4]],
+ )
+ img_spliced = annotate_img(img_spliced, f'{phase_name}/{step_cnt}', pos='tl')
+ imgs_spliced.append(img_spliced)
+
+ img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(log_data['img_patch']))])
+
+ img_final = to_tensor(img_final, device=None).permute(2, 0, 1) # (3, H, W)
+ summary_writer.add_image('skelify/visualization', img_final, step_cnt)
+
+ self.render_frames.append(img_final)
+ except Exception as e:
+ get_logger().error(f'Failed to visualize the optimization process: {e}')
+ traceback.print_exc()
\ No newline at end of file
diff --git a/lib/modeling/optim/skelify/utils.py b/lib/modeling/optim/skelify/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..902bc9223fc14547d991db2df737e5832d274db0
--- /dev/null
+++ b/lib/modeling/optim/skelify/utils.py
@@ -0,0 +1,169 @@
+from lib.kits.basic import *
+
+from lib.body_models.skel_utils.definition import JID2QIDS
+
+
+def gmof(x, sigma=100):
+ '''
+ Geman-McClure error function, to be used as a robust loss function.
+ '''
+ x_squared = x ** 2
+ sigma_squared = sigma ** 2
+ return (sigma_squared * x_squared) / (sigma_squared + x_squared)
+
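+
+# Behaviour sketch (values rounded): residuals far below sigma stay roughly quadratic,
+# while large residuals saturate near sigma**2, which is what makes the loss robust
+# to outlier keypoints. E.g., with the default sigma=100:
+#   gmof(torch.tensor([1., 10., 100., 1000.]))
+#   -> tensor([   1.,   99., 5000., 9901.])  # approximately; the cap is 100**2 = 10000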
+
+def compute_rel_change(prev_val: float, curr_val: float) -> float:
+ '''
+ Compute the relative change between two values.
+ Copied from:
+ https://github.com/vchoutas/smplify-x
+
+ ### Args:
+ - prev_val: float
+ - curr_val: float
+
+ ### Returns:
+ - float
+ '''
+ return np.abs(prev_val - curr_val) / max([np.abs(prev_val), np.abs(curr_val), 1])
+
+
+INVALID_JIDS = [37, 38] # These joints are not reliable.
+
+def get_kp_active_j_masks(parts:Union[str, List[str]], device='cuda'):
+    # Generate per-keypoint masks (over the 44 keypoints) used to mask the loss.
+ act_jids = get_kp_active_jids(parts)
+ masks = torch.zeros(44).to(device)
+ masks[act_jids] = 1.0
+
+ return masks
+
+
+def get_kp_active_jids(parts:Union[str, List[str]]):
+ if isinstance(parts, str):
+ if parts == 'all':
+ return get_kp_active_jids(['torso', 'limbs', 'head'])
+ elif parts == 'hips':
+ return [8, 9, 12, 27, 28, 39]
+ elif parts == 'torso-lite':
+ return [2, 5, 9, 12]
+ elif parts == 'torso':
+ return [1, 2, 5, 8, 9, 12, 27, 28, 33, 34, 37, 39, 40, 41]
+ elif parts == 'limbs':
+ return get_kp_active_jids(['limbs_proximal', 'limbs_distal'])
+ elif parts == 'head':
+ return [0, 15, 16, 17, 18, 38, 42, 43]
+ elif parts == 'limbs_proximal':
+ return [3, 6, 10, 13, 26, 29, 32, 35]
+ elif parts == 'limbs_distal':
+ return [4, 7, 11, 14, 19, 20, 21, 22, 23, 24, 25, 30, 31, 36]
+ else:
+ raise ValueError(f'Unsupported parts: {parts}')
+ else:
+ jids = []
+ for part in parts:
+ jids.extend(get_kp_active_jids(part))
+ jids = set(jids) - set(INVALID_JIDS)
+ return sorted(list(jids))
+
+
+def get_params_active_j_masks(parts:Union[str, List[str]], device='cuda'):
+    # Generate per-joint masks (over the 24 SKEL parameter joints) used to mask the loss.
+ act_jids = get_params_active_jids(parts)
+ masks = torch.zeros(24).to(device)
+ masks[act_jids] = 1.0
+
+ return masks
+
+
+def get_params_active_jids(parts:Union[str, List[str]]):
+ if isinstance(parts, str):
+ if parts == 'all':
+ return get_params_active_jids(['torso', 'limbs', 'head'])
+ elif parts == 'torso-lite':
+ return get_params_active_jids('torso')
+ elif parts == 'hips': # Enable `hips` if `poses_orient` is enabled.
+ return [0]
+ elif parts == 'torso':
+ return [0, 11]
+ elif parts == 'limbs':
+ return get_params_active_jids(['limbs_proximal', 'limbs_distal'])
+ elif parts == 'head':
+ return [12, 13]
+ elif parts == 'limbs_proximal':
+ return [1, 6, 14, 15, 19, 20]
+ elif parts == 'limbs_distal':
+ return [2, 3, 4, 5, 7, 8, 9, 10, 16, 17, 18, 21, 22, 23]
+ else:
+ raise ValueError(f'Unsupported parts: {parts}')
+ else:
+ qids = []
+ for part in parts:
+ qids.extend(get_params_active_jids(part))
+ return sorted(list(set(qids)))
+
+
+def get_params_active_q_masks(parts:Union[str, List[str]], device='cuda'):
+    # Generate per-dof masks (over the 46 SKEL pose parameters) used to mask the loss.
+ act_qids = get_params_active_qids(parts)
+ masks = torch.zeros(46).to(device)
+ masks[act_qids] = 1.0
+
+ return masks
+
+
+def get_params_active_qids(parts:Union[str, List[str]]):
+ act_jids = get_params_active_jids(parts)
+ qids = []
+ for act_jid in act_jids:
+ qids.extend(JID2QIDS[act_jid])
+ return sorted(list(set(qids)))
+
+
+def estimate_kp2d_scale(
+ kp2d : torch.Tensor,
+ edge_idxs : List[Tuple[int, int]] = [[5, 12], [2, 9]], # shoulders to hips
+):
+ diff2d = []
+ for edge in edge_idxs:
+ diff2d.append(kp2d[:, edge[0]] - kp2d[:, edge[1]]) # list of (B, 2)
+ scale2d = torch.stack(diff2d, dim=1).norm(dim=-1) # (B, E)
+ return scale2d.mean(dim=1) # (B,)
+
+
+@torch.no_grad()
+def guess_cam_z(
+ pd_kp3d : torch.Tensor,
+ gt_kp2d : torch.Tensor,
+ focal_length : float,
+ edge_idxs : List[Tuple[int, int]] = [[5, 12], [2, 9]], # shoulders to hips
+):
+ '''
+ Initializes the camera depth translation (i.e. z value) according to the ground truth 2D
+ keypoints and the predicted 3D keypoints.
+ Modified from: https://github.com/vchoutas/smplify-x/blob/68f8536707f43f4736cdd75a19b18ede886a4d53/smplifyx/fitting.py#L36-L110
+
+ ### Args
+ - pd_kp3d: torch.Tensor, (B, J, 3)
+ - gt_kp2d: torch.Tensor, (B, J, 2)
+ - Without confidence.
+ - focal_length: float
+ - edge_idxs: List[Tuple[int, int]], default=[[5, 12], [2, 9]], i.e. shoulders to hips
+        - The edges used to estimate the apparent scale of the person.
+ '''
+ diff3d, diff2d = [], []
+ for edge in edge_idxs:
+ diff3d.append(pd_kp3d[:, edge[0]] - pd_kp3d[:, edge[1]]) # list of (B, 3)
+ diff2d.append(gt_kp2d[:, edge[0]] - gt_kp2d[:, edge[1]]) # list of (B, 2)
+
+ diff3d = torch.stack(diff3d, dim=1) # (B, E, 3)
+ diff2d = torch.stack(diff2d, dim=1) # (B, E, 2)
+
+ length3d = diff3d.norm(dim=-1) # (B, E)
+ length2d = diff2d.norm(dim=-1) # (B, E)
+
+ height3d = length3d.mean(dim=1) # (B,)
+ height2d = length2d.mean(dim=1) # (B,)
+
+ z_estim = focal_length * (height3d / height2d) # (B,)
+ return z_estim # (B,)
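+
+
+if __name__ == '__main__':
+    # Worked sketch of the depth guess above; all numbers are illustrative assumptions
+    # (a normalized focal length of 5000 / 256 is used only for this example).
+    # Pinhole relation: z ~= focal_length * height3d / height2d.
+    pd_kp3d = torch.zeros(1, 44, 3)
+    pd_kp3d[0, 5] = torch.tensor([0.0, 0.5, 0.0])  # one shoulder-hip edge, 0.5 m long in 3D
+    gt_kp2d = torch.zeros(1, 44, 2)
+    gt_kp2d[0, 5] = torch.tensor([0.0, 0.1])       # the same edge is 0.1 long in the image
+    z = guess_cam_z(pd_kp3d, gt_kp2d, focal_length=5000 / 256)
+    print(z)  # ~97.7 = (5000 / 256) * 0.5 / 0.1, since the second edge has zero length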
\ No newline at end of file
diff --git a/lib/modeling/optim/skelify_refiner.py b/lib/modeling/optim/skelify_refiner.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab5eb7376797563dcdb35e91a95b103560977db
--- /dev/null
+++ b/lib/modeling/optim/skelify_refiner.py
@@ -0,0 +1,425 @@
+from lib.kits.basic import *
+
+import traceback
+from tqdm import tqdm
+
+from lib.body_models.common import make_SKEL
+from lib.body_models.skel_wrapper import SKELWrapper, SKELOutput
+from lib.body_models.abstract_skeletons import Skeleton_OpenPose25
+from lib.utils.data import to_tensor, to_list
+from lib.utils.camera import perspective_projection
+from lib.utils.media import draw_kp2d_on_img, annotate_img, splice_img
+from lib.utils.vis import render_mesh_overlay_img
+
+from lib.modeling.losses import compute_poses_angle_prior_loss
+
+from .skelify.utils import get_kp_active_j_masks
+
+
+def compute_rel_change(prev_val: float, curr_val: float) -> float:
+ '''
+ Compute the relative change between two values.
+    Copied from: https://github.com/vchoutas/smplify-x
+
+ ### Args:
+ - prev_val: float
+ - curr_val: float
+
+ ### Returns:
+ - float
+ '''
+ return np.abs(prev_val - curr_val) / max([np.abs(prev_val), np.abs(curr_val), 1])
+
+
+def gmof(x, sigma):
+ '''
+ Geman-McClure error function, to be used as a robust loss function.
+ '''
+ x_squared = x ** 2
+ sigma_squared = sigma ** 2
+ return (sigma_squared * x_squared) / (sigma_squared + x_squared)
+
+
+class SKELifyRefiner():
+
+ def __init__(self, cfg, name='SKELify', tb_logger=None, device='cuda:0'):
+ self.cfg = cfg
+ self.name = name
+ self.eq_thre = cfg.early_quit_thresholds
+
+ self.tb_logger = tb_logger
+
+ self.device = device
+ self.skel_model = instantiate(cfg.skel_model).to(device)
+
+ # Dirty implementation for visualization.
+ self.render_frames = []
+
+
+
+ def __call__(
+ self,
+ gt_kp2d : Union[torch.Tensor, np.ndarray],
+ init_poses : Union[torch.Tensor, np.ndarray],
+ init_betas : Union[torch.Tensor, np.ndarray],
+ init_cam_t : Union[torch.Tensor, np.ndarray],
+ img_patch : Optional[np.ndarray] = None,
+ **kwargs
+ ):
+ '''
+ Use optimization to fit the SKEL parameters to the 2D keypoints.
+
+ ### Args:
+ - gt_kp2d : torch.Tensor or np.ndarray, (B, J, 3)
+            - The last dimension contains [x, y, conf].
+            - The 2D keypoints to fit; they are defined in the zero-centered [-0.5, 0.5] space.
+ - init_poses : torch.Tensor or np.ndarray, (B, 46)
+ - init_betas : torch.Tensor or np.ndarray, (B, 10)
+ - init_cam_t : torch.Tensor or np.ndarray, (B, 3)
+ - img_patch : np.ndarray or None, (B, H, W, 3)
+ - The image patch for visualization. H, W are defined in normalized bounding box space.
+ - If None, the visualization will simply use a black image.
+
+ ### Returns:
+ - TODO:
+ '''
+ # ⛩️ Prepare the input data.
+ gt_kp2d = to_tensor(gt_kp2d, device=self.device).detach().float().clone() # (B, J, 3)
+ init_poses = to_tensor(init_poses, device=self.device).detach().float().clone() # (B, 46)
+ init_betas = to_tensor(init_betas, device=self.device).detach().float().clone() # (B, 10)
+ init_cam_t = to_tensor(init_cam_t, device=self.device).detach().float().clone() # (B, 3)
+ inputs = {
+ 'poses_orient': init_poses[:, :3], # (B, 3)
+ 'poses_body' : init_poses[:, 3:], # (B, 43)
+ 'betas' : init_betas, # (B, 10)
+ 'cam_t' : init_cam_t, # (B, 3)
+ }
+
+        focal_length = np.ones(2) * self.cfg.focal_length / self.cfg.img_patch_size
+        focal_length = focal_length.reshape(1, 2).repeat(inputs['cam_t'].shape[0], axis=0)  # (B, 2); np.repeat needs an explicit axis here
+
+ # ⛩️ Optimization phases, controlled by config file.
+        prev_phase_steps = 0  # accumulate the steps that are *supposed* to have been done in the previous phases
+ for phase_id, phase_name in enumerate(self.cfg.phases):
+ phase_cfg = self.cfg.phases[phase_name]
+ # Preparation.
+ optim_params = []
+ for k in inputs.keys():
+ if k in phase_cfg.params_keys:
+ inputs[k].requires_grad = True
+ optim_params.append(inputs[k]) # (B, D)
+ else:
+ inputs[k].requires_grad = False
+
+ optimizer = instantiate(phase_cfg.optimizer, optim_params, _recursive_=True)
+
+ def closure():
+ optimizer.zero_grad()
+
+ # Data preparation.
+ cam_t = inputs['cam_t']
+ skel_params = {
+ 'poses' : torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1), # (B, 46)
+ 'betas' : inputs['betas'], # (B, 10)
+ 'skelmesh' : False,
+ }
+
+ # Optimize steps.
+ skel_output = self.skel_model(**skel_params)
+
+ pd_kp2d = perspective_projection(
+ points = to_tensor(skel_output.joints, device=self.device),
+ translation = to_tensor(cam_t, device=self.device),
+ focal_length = to_tensor(focal_length, device=self.device),
+ )
+
+ loss, losses = self._compute_losses(
+ act_losses = phase_cfg.losses,
+ act_parts = phase_cfg.get('parts', 'all'),
+ gt_kp2d = gt_kp2d,
+ pd_kp2d = pd_kp2d,
+ pd_params = skel_params,
+ **phase_cfg.get('weights', {}),
+ )
+
+                # For visualizing the optimization process.
+ _conf = gt_kp2d[..., 2] # (B, J)
+ metric = torch.sum((pd_kp2d - gt_kp2d[..., :2]) ** 2, dim=-1) * _conf # (B, J)
+ metric = metric.sum(dim=-1) / (torch.sum(_conf, dim=-1) + 1e-6) # (B,)
+
+ # Store logging data.
+ if self.tb_logger is not None:
+ log_data.update({
+ 'losses' : losses,
+ 'pd_kp2d' : pd_kp2d[:self.cfg.logger.samples_per_record].detach().clone(),
+ 'pd_verts' : skel_output.skin_verts[:self.cfg.logger.samples_per_record].detach().clone(),
+ 'cam_t' : cam_t[:self.cfg.logger.samples_per_record].detach().clone(),
+ 'metric' : metric[:self.cfg.logger.samples_per_record].detach().clone(),
+ 'optim_betas' : inputs['betas'][:self.cfg.logger.samples_per_record].detach().clone(),
+ })
+
+ loss.backward()
+ return loss.item()
+
+ # Optimization loop.
+ prev_loss = None
+ with tqdm(range(phase_cfg.max_loop)) as bar:
+ bar.set_description(f'[{phase_name}] Loss: ???')
+ for i in bar:
+ log_data = {}
+ curr_loss = optimizer.step(closure)
+
+ # Logging.
+ if self.tb_logger is not None:
+ log_data.update({
+ 'img_patch' : img_patch[:self.cfg.logger.samples_per_record] if img_patch is not None else None,
+ 'gt_kp2d' : gt_kp2d[:self.cfg.logger.samples_per_record].detach().clone(),
+ })
+ self._tb_log(prev_phase_steps + i, log_data)
+ # self._tb_log_for_report(prev_phase_steps + i, log_data)
+
+ bar.set_description(f'[{phase_name}] Loss: {curr_loss:.4f}')
+ if self._can_early_quit(optim_params, prev_loss, curr_loss):
+ break
+
+ prev_loss = curr_loss
+
+ prev_phase_steps += phase_cfg.max_loop
+
+ # ⛩️ Prepare the output data.
+ outputs = {
+ 'poses': torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1).detach().clone(), # (B, 46)
+ 'betas': inputs['betas'].detach().clone(), # (B, 10)
+ 'cam_t': inputs['cam_t'].detach().clone(), # (B, 3)
+ }
+ return outputs
+
+
+ def _compute_losses(
+ self,
+ act_losses : List[str],
+ act_parts : List[str],
+ gt_kp2d : torch.Tensor,
+ pd_kp2d : torch.Tensor,
+ pd_params : Dict,
+ robust_sigma : float = 100,
+ shape_prior_weight : float = 5,
+ angle_prior_weight : float = 15.2,
+ *args, **kwargs,
+ ):
+ '''
+ Compute the weighted losses according to the config file.
+ Follow: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/smplify/losses.py#L26-L58s
+ '''
+ B = len(gt_kp2d)
+ act_j_masks = get_kp_active_j_masks(act_parts, device=gt_kp2d.device) # (44,)
+
+        # Compare the reprojected 2D keypoints with the G.T. 2D keypoints using a robustified (Geman-McClure) L2 error.
+ kp_conf = gt_kp2d[..., 2] # (B, J)
+ gt_kp2d = gt_kp2d[..., :2] # (B, J, 2)
+ reproj_err = gmof(pd_kp2d - gt_kp2d, robust_sigma) # (B, J, 2)
+ reproj_loss = ((kp_conf ** 2) * reproj_err.sum(dim=-1) * act_j_masks[None]).sum(-1) # (B,)
+
+ # Regularize the shape parameters.
+ shape_prior_loss = (shape_prior_weight ** 2) * (pd_params['betas'] ** 2).sum(dim=-1) # (B,)
+
+        # Use the SKEL angle priors (e.g., joint rotation limits) to regularize the optimization process.
+ # TODO: Is that necessary?
+ angle_prior_loss = (angle_prior_weight ** 2) * compute_poses_angle_prior_loss(pd_params['poses']).mean() # (,)
+
+ losses = {
+ 'reprojection' : reproj_loss.mean(), # (,)
+ 'shape_prior' : shape_prior_loss.mean(), # (,)
+ 'angle_prior' : angle_prior_loss, # (,)
+ }
+ loss = torch.tensor(0., device=gt_kp2d.device)
+ for k in act_losses:
+ loss += losses[k]
+ losses = {k: v.detach() for k, v in losses.items()}
+ losses['sum'] = loss.detach() # (,)
+ return loss, losses
+
+
+ def _can_early_quit(self, opt_params, prev_loss, curr_loss):
+        ''' Decide whether to quit the optimization early. Returns True to quit, False otherwise. '''
+ if self.cfg.early_quit_thresholds is None:
+ # Never early quit.
+ return False
+
+ # Relative change test.
+ if prev_loss is not None:
+ loss_rel_change = compute_rel_change(prev_loss, curr_loss)
+ if loss_rel_change < self.cfg.early_quit_thresholds.rel:
+ get_logger().info(f'Early quit due to relative change: {loss_rel_change:.4f} = rel({prev_loss}, {curr_loss})')
+ return True
+
+        # Absolute gradient test: quit once all parameter gradients are below the threshold.
+ if all([
+ torch.abs(param.grad.max()).item() < self.cfg.early_quit_thresholds.abs
+ for param in opt_params if param.grad is not None
+ ]):
+ get_logger().info(f'Early quit due to absolute change.')
+ return True
+
+ return False
+
+
+ @rank_zero_only
+ def _tb_log(self, step_cnt:int, log_data:Dict, *args, **kwargs):
+ ''' Write the logging information to the TensorBoard. '''
+ if step_cnt != 0 and (step_cnt + 1) % self.cfg.logger.interval != 0:
+ return
+
+ summary_writer = self.tb_logger.experiment
+
+ # Save losses.
+ for loss_name, loss_val in log_data['losses'].items():
+ summary_writer.add_scalar(f'skelify/{loss_name}', loss_val.detach().item(), step_cnt)
+
+ # Visualization of the optimization process. TODO: Maybe we can make this more elegant.
+ if log_data['img_patch'] is None:
+ log_data['img_patch'] = [np.zeros((self.cfg.img_patch_size, self.cfg.img_patch_size, 3), dtype=np.uint8)] \
+ * len(log_data['gt_kp2d'])
+
+ if len(self.render_frames) < 1:
+ self.init_v = log_data['pd_verts']
+ self.init_metric = log_data['metric']
+ self.init_ct = log_data['cam_t']
+
+ # Overlay the skin mesh of the results on the original image.
+ try:
+ imgs_spliced = []
+ for i, img_patch in enumerate(log_data['img_patch']):
+ metric = log_data['metric'][i].item()
+
+ img_with_init = render_mesh_overlay_img(
+ faces = self.skel_model.skin_f,
+ verts = self.init_v[i],
+ K4 = [self.cfg.focal_length, self.cfg.focal_length, 0, 0],
+ img = img_patch,
+ Rt = [torch.eye(3), self.init_ct[i]],
+ mesh_color = 'pink',
+ )
+ img_with_init = annotate_img(img_with_init, 'init')
+ img_with_init = annotate_img(img_with_init, f'Quality: {self.init_metric[i].item()*1000:.3f}/1e3', pos='tl')
+
+ img_with_mesh = render_mesh_overlay_img(
+ faces = self.skel_model.skin_f,
+ verts = log_data['pd_verts'][i],
+ K4 = [self.cfg.focal_length, self.cfg.focal_length, 0, 0],
+ img = img_patch,
+ Rt = [torch.eye(3), log_data['cam_t'][i]],
+ mesh_color = 'pink',
+ )
+ img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh')
+ betas_max = log_data['optim_betas'][i].abs().max().item()
+ img_with_mesh = annotate_img(img_with_mesh, f'Quality: {metric*1000:.3f}/1e3\nbetas_max: {betas_max:.3f}', pos='tl')
+ img_patch_raw = annotate_img(img_patch, 'raw')
+
+ log_data['gt_kp2d'][i][..., :2] = (log_data['gt_kp2d'][i][..., :2] + 0.5) * self.cfg.img_patch_size
+ img_with_gt = annotate_img(img_patch, 'gt_kp2d')
+ img_with_gt = draw_kp2d_on_img(
+ img_with_gt,
+ log_data['gt_kp2d'][i],
+ Skeleton_OpenPose25.bones,
+ Skeleton_OpenPose25.bone_colors,
+ )
+
+ log_data['pd_kp2d'][i] = (log_data['pd_kp2d'][i] + 0.5) * self.cfg.img_patch_size
+ img_with_pd = annotate_img(img_patch, 'pd_kp2d')
+ img_with_pd = draw_kp2d_on_img(
+ img_with_pd,
+ log_data['pd_kp2d'][i],
+ Skeleton_OpenPose25.bones,
+ Skeleton_OpenPose25.bone_colors,
+ )
+
+ img_spliced = splice_img(
+ img_grids = [img_patch_raw, img_with_gt, img_with_pd, img_with_init, img_with_mesh],
+ # grid_ids = [[0, 1, 2, 3, 4]],
+ grid_ids = [[1, 2, 3, 4]],
+ )
+ imgs_spliced.append(img_spliced)
+
+ img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(log_data['img_patch']))])
+
+ img_final = to_tensor(img_final, device=None).permute(2, 0, 1) # (3, H, W)
+ summary_writer.add_image('skelify/visualization', img_final, step_cnt)
+
+ self.render_frames.append(img_final)
+ except Exception as e:
+ get_logger().error(f'Failed to visualize the optimization process: {e}')
+ # traceback.print_exc()
+
+
+ @rank_zero_only
+ def _tb_log_for_report(self, step_cnt:int, log_data:Dict, *args, **kwargs):
+ ''' Write the logging information to the TensorBoard. '''
+
+        get_logger().warning(f'This logging function is only for presentation purposes.')
+
+ if len(self.render_frames) < 1:
+ self.init_v = log_data['pd_verts']
+ self.init_ct = log_data['cam_t']
+
+ if step_cnt != 0 and (step_cnt + 1) % self.cfg.logger.interval != 0:
+ return
+
+ summary_writer = self.tb_logger.experiment
+
+ # Save losses.
+ for loss_name, loss_val in log_data['losses'].items():
+ summary_writer.add_scalar(f'losses/{loss_name}', loss_val.detach().item(), step_cnt)
+
+ # Visualization of the optimization process. TODO: Maybe we can make this more elegant.
+ if log_data['img_patch'] is None:
+ log_data['img_patch'] = [np.zeros((self.cfg.img_patch_size, self.cfg.img_patch_size, 3), dtype=np.uint8)] \
+ * len(log_data['gt_kp2d'])
+
+ # Overlay the skin mesh of the results on the original image.
+ try:
+ imgs_spliced = []
+ for i, img_patch in enumerate(log_data['img_patch']):
+ img_with_init = render_mesh_overlay_img(
+ faces = self.skel_model.skin_f,
+ verts = self.init_v[i],
+ K4 = [self.cfg.focal_length, self.cfg.focal_length, 0, 0],
+ img = img_patch,
+ Rt = [torch.eye(3), self.init_ct[i]],
+ mesh_color = 'pink',
+ )
+ img_with_init = annotate_img(img_with_init, 'init')
+
+ img_with_mesh = render_mesh_overlay_img(
+ faces = self.skel_model.skin_f,
+ verts = log_data['pd_verts'][i],
+ K4 = [self.cfg.focal_length, self.cfg.focal_length, 0, 0],
+ img = img_patch,
+ Rt = [torch.eye(3), log_data['cam_t'][i]],
+ mesh_color = 'pink',
+ )
+ img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh')
+
+ img_patch_raw = annotate_img(img_patch, 'raw')
+
+ log_data['gt_kp2d'][i][..., :2] = (log_data['gt_kp2d'][i][..., :2] + 0.5) * self.cfg.img_patch_size
+ img_with_gt = annotate_img(img_patch, 'gt_kp2d')
+ img_with_gt = draw_kp2d_on_img(
+ img_with_gt,
+ log_data['gt_kp2d'][i],
+ Skeleton_OpenPose25.bones,
+ Skeleton_OpenPose25.bone_colors,
+ )
+
+ img_spliced = splice_img([img_patch_raw, img_with_gt, img_with_init, img_with_mesh], grid_ids=[[0, 1, 2, 3]])
+ imgs_spliced.append(img_spliced)
+
+ img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(log_data['img_patch']))])
+
+ img_final = to_tensor(img_final, device=None).permute(2, 0, 1)
+ summary_writer.add_image('visualization', img_final, step_cnt)
+
+ self.render_frames.append(img_final)
+ except Exception as e:
+ get_logger().error(f'Failed to visualize the optimization process: {e}')
+ traceback.print_exc()
diff --git a/lib/modeling/pipelines/__init__.py b/lib/modeling/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..84111117cb52fb9475d6d06685bdd2f74f808120
--- /dev/null
+++ b/lib/modeling/pipelines/__init__.py
@@ -0,0 +1 @@
+from .hsmr import HSMRPipeline
\ No newline at end of file
diff --git a/lib/modeling/pipelines/hsmr.py b/lib/modeling/pipelines/hsmr.py
new file mode 100644
index 0000000000000000000000000000000000000000..be3c0ac4bc78a5a94aacf0300b49d108cba97aa4
--- /dev/null
+++ b/lib/modeling/pipelines/hsmr.py
@@ -0,0 +1,649 @@
+from lib.kits.basic import *
+
+import traceback
+
+from lib.utils.vis import Wis3D
+from lib.utils.vis.py_renderer import render_mesh_overlay_img
+from lib.utils.data import to_tensor
+from lib.utils.media import draw_kp2d_on_img, annotate_img, splice_img
+from lib.utils.camera import perspective_projection
+from lib.body_models.abstract_skeletons import Skeleton_OpenPose25
+from lib.modeling.losses import *
+from lib.modeling.networks.discriminators import HSMRDiscriminator
+from lib.platform.config_utils import get_PM_info_dict
+
+
+def build_inference_pipeline(
+ model_root: Union[Path, str],
+ ckpt_fn : Optional[Union[Path, str]] = None,
+ tuned_bcb : bool = True,
+ device : str = 'cpu',
+):
+ # 1.1. Load the config file.
+ if isinstance(model_root, str):
+ model_root = Path(model_root)
+ cfg_path = model_root / '.hydra' / 'config.yaml'
+ cfg = OmegaConf.load(cfg_path)
+ # 1.2. Override PM info dict.
+ PM_overrides = get_PM_info_dict()._pm_
+ cfg._pm_ = PM_overrides
+ get_logger(brief=True).info(f'Building inference pipeline of {cfg.exp_name}')
+
+ # 2.1. Instantiate the pipeline.
+ init_bcb = not tuned_bcb
+ pipeline = instantiate(cfg.pipeline, init_backbone=init_bcb, _recursive_=False)
+ pipeline.set_data_adaption(data_module_name='IMG_PATCHES')
+ # 2.2. Load the checkpoint.
+ if ckpt_fn is None:
+ ckpt_fn = model_root / 'checkpoints' / 'last.ckpt'
+ pipeline.load_state_dict(torch.load(ckpt_fn, map_location=device)['state_dict'])
+ get_logger(brief=True).info(f'Load checkpoint from {ckpt_fn}.')
+
+ pipeline.eval()
+ return pipeline.to(device)
+
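+# Minimal usage sketch (the run directory below is a placeholder; the expected layout
+# is a Hydra run folder containing `.hydra/config.yaml` and `checkpoints/last.ckpt`):
+#
+#   pipeline = build_inference_pipeline('outputs/exp/hsmr_run', device='cuda:0')
+#   # Returned in eval mode with the 'IMG_PATCHES' data adaption, i.e. it expects
+#   # pre-cropped image patches as input.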
+
+class HSMRPipeline(pl.LightningModule):
+
+ def __init__(self, cfg:DictConfig, name:str, init_backbone=True):
+ super(HSMRPipeline, self).__init__()
+ self.name = name
+
+ self.skel_model = instantiate(cfg.SKEL)
+ self.backbone = instantiate(cfg.backbone)
+ self.head = instantiate(cfg.head)
+ self.cfg = cfg
+
+ if init_backbone:
+ # For inference mode with tuned backbone checkpoints, we don't need to initialize the backbone here.
+ self._init_backbone()
+
+ # Loss layers.
+ self.kp_3d_loss = Keypoint3DLoss(loss_type='l1')
+ self.kp_2d_loss = Keypoint2DLoss(loss_type='l1')
+ self.params_loss = ParameterLoss()
+
+ # Discriminator.
+ self.enable_disc = self.cfg.loss_weights.get('adversarial', 0) > 0
+ if self.enable_disc:
+ self.discriminator = HSMRDiscriminator()
+            get_logger().warning(f'Discriminator enabled; the global_steps will be doubled. Use the checkpoints carefully.')
+ else:
+ self.discriminator = None
+ self.cfg.loss_weights.pop('adversarial', None) # pop the adversarial term if not enabled
+
+ # Manually control the optimization since we have an adversarial process.
+ self.automatic_optimization = False
+ self.set_data_adaption()
+
+ # For visualization debug.
+ if False:
+ self.wis3d = Wis3D(seq_name=PM.cfg.exp_name)
+ else:
+ self.wis3d = None
+
+ def set_data_adaption(self, data_module_name:Optional[str]=None):
+ if data_module_name is None:
+ # get_logger().warning('Data adapter schema is not defined. The input will be regarded as image patches.')
+ self.adapt_batch = self._adapt_img_inference
+ elif data_module_name == 'IMG_PATCHES':
+ self.adapt_batch = self._adapt_img_inference
+ elif data_module_name.startswith('SKEL_HSMR_V1'):
+ self.adapt_batch = self._adapt_hsmr_v1
+ else:
+ raise ValueError(f'Unknown data module: {data_module_name}')
+
+ def print_summary(self, max_depth=1):
+ from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary
+ print(ModelSummary(self, max_depth=max_depth))
+
+ def configure_optimizers(self):
+ optimizers = []
+
+ params_main = filter(lambda p: p.requires_grad, self._params_main())
+ optimizer_main = instantiate(self.cfg.optimizer, params=params_main)
+ optimizers.append(optimizer_main)
+
+ if len(self._params_disc()) > 0:
+ params_disc = filter(lambda p: p.requires_grad, self._params_disc())
+ optimizer_disc = instantiate(self.cfg.optimizer, params=params_disc)
+ optimizers.append(optimizer_disc)
+
+ return optimizers
+
+ def training_step(self, raw_batch, batch_idx):
+ with PM.time_monitor('training_step'):
+ return self._training_step(raw_batch, batch_idx)
+
+ def _training_step(self, raw_batch, batch_idx):
+ # GPU_monitor = GPUMonitor()
+ # GPU_monitor.snapshot('HSMR training start')
+
+ batch = self.adapt_batch(raw_batch['img_ds'])
+ # GPU_monitor.snapshot('HSMR adapt batch')
+
+ # Get the optimizer.
+ optimizers = self.optimizers(use_pl_optimizer=True)
+ if isinstance(optimizers, List):
+ optimizer_main, optimizer_disc = optimizers
+ else:
+ optimizer_main = optimizers
+ # GPU_monitor.snapshot('HSMR get optimizer')
+
+ # 1. Main parts forward pass.
+ with PM.time_monitor('forward_step'):
+ img_patch = to_tensor(batch['img_patch'], self.device) # (B, C, H, W)
+ B = len(img_patch)
+ outputs = self.forward_step(img_patch) # {...}
+ # GPU_monitor.snapshot('HSMR forward')
+ pd_skel_params = HSMRPipeline._adapt_skel_params(outputs['pd_params'])
+ # GPU_monitor.snapshot('HSMR adapt SKEL params')
+
+ # 2. [Optional] Discriminator forward pass in main training step.
+ if self.enable_disc:
+ with PM.time_monitor('disc_forward'):
+ pd_poses_mat, _ = self.skel_model.pose_params_to_rot(pd_skel_params['poses']) # (B, J=24, 3, 3)
+ pd_poses_body_mat = pd_poses_mat[:, 1:, :, :] # (B, J=23, 3, 3)
+ pd_betas = pd_skel_params['betas'] # (B, 10)
+ disc_out = self.discriminator(
+ poses_body = pd_poses_body_mat, # (B, J=23, 3, 3)
+ betas = pd_betas, # (B, 10)
+ )
+ else:
+ disc_out = None
+
+ # 3. Prepare the secondary products
+ with PM.time_monitor('Secondary Products Preparation'):
+ # 3.1. Body model outputs.
+ with PM.time_monitor('SKEL Forward'):
+ skel_outputs = self.skel_model(**pd_skel_params, skelmesh=False)
+ pd_kp3d = skel_outputs.joints # (B, Q=44, 3)
+ pd_skin = skel_outputs.skin_verts # (B, V=6890, 3)
+            # 3.2. Reproject the 3D joints to the 2D plane.
+ with PM.time_monitor('Reprojection'):
+ pd_kp2d = perspective_projection(
+ points = pd_kp3d, # (B, K=Q=44, 3)
+ translation = outputs['pd_cam_t'], # (B, 3)
+ focal_length = outputs['focal_length'] / self.cfg.policy.img_patch_size, # (B, 2)
+ ) # (B, 44, 2)
+ # 3.3. Extract G.T. from inputs.
+ gt_kp2d_with_conf = batch['kp2d'].clone() # (B, 44, 3)
+ gt_kp3d_with_conf = batch['kp3d'].clone() # (B, 44, 4)
+ # 3.4. Extract G.T. skin mesh only for visualization.
+ gt_skel_params = HSMRPipeline._adapt_skel_params(batch['gt_params']) # {poses, betas}
+ gt_skel_params = {k: v[:self.cfg.logger.samples_per_record] for k, v in gt_skel_params.items()}
+ skel_outputs = self.skel_model(**gt_skel_params, skelmesh=False)
+ gt_skin = skel_outputs.skin_verts # (B', V=6890, 3)
+ gt_valid_body = batch['has_gt_params']['poses_body'][:self.cfg.logger.samples_per_record] # {poses_orient, poses_body, betas}
+ gt_valid_orient = batch['has_gt_params']['poses_orient'][:self.cfg.logger.samples_per_record] # {poses_orient, poses_body, betas}
+ gt_valid_betas = batch['has_gt_params']['betas'][:self.cfg.logger.samples_per_record] # {poses_orient, poses_body, betas}
+ gt_valid = torch.logical_and(torch.logical_and(gt_valid_body, gt_valid_orient), gt_valid_betas)
+ # GPU_monitor.snapshot('HSMR secondary products')
+
+ # 4. Compute losses.
+ with PM.time_monitor('Compute Loss'):
+ loss_main, losses_main = self._compute_losses_main(
+ self.cfg.loss_weights,
+ pd_kp3d, # (B, 44, 3)
+ gt_kp3d_with_conf, # (B, 44, 4)
+ pd_kp2d, # (B, 44, 2)
+ gt_kp2d_with_conf, # (B, 44, 3)
+ outputs['pd_params'], # {'poses_orient':..., 'poses_body':..., 'betas':...}
+ batch['gt_params'], # {'poses_orient':..., 'poses_body':..., 'betas':...}
+ batch['has_gt_params'],
+ disc_out,
+ )
+ # GPU_monitor.snapshot('HSMR compute losses')
+ if torch.isnan(loss_main):
+            get_logger().error(f'NaN detected in loss computation. Losses: {losses_main}')
+
+ # 5. Main parts backward pass.
+ with PM.time_monitor('Backward Step'):
+ optimizer_main.zero_grad()
+ self.manual_backward(loss_main)
+ optimizer_main.step()
+ # GPU_monitor.snapshot('HSMR backwards')
+
+ # 6. [Optional] Discriminator training part.
+ if self.enable_disc:
+ with PM.time_monitor('Train Discriminator'):
+ losses_disc = self._train_discriminator(
+ mocap_batch = raw_batch['mocap_ds'],
+ pd_poses_body_mat = pd_poses_body_mat,
+ pd_betas = pd_betas,
+ optimizer = optimizer_disc,
+ )
+ else:
+ losses_disc = {}
+
+ # 7. Logging.
+ with PM.time_monitor('Tensorboard Logging'):
+ vis_data = {
+ 'img_patch' : to_numpy(img_patch[:self.cfg.logger.samples_per_record]).transpose((0, 2, 3, 1)).copy(),
+ 'pd_kp2d' : pd_kp2d[:self.cfg.logger.samples_per_record].clone(),
+ 'pd_kp3d' : pd_kp3d[:self.cfg.logger.samples_per_record].clone(),
+ 'gt_kp2d_with_conf' : gt_kp2d_with_conf[:self.cfg.logger.samples_per_record].clone(),
+ 'gt_kp3d_with_conf' : gt_kp3d_with_conf[:self.cfg.logger.samples_per_record].clone(),
+ 'pd_skin' : pd_skin[:self.cfg.logger.samples_per_record].clone(),
+ 'gt_skin' : gt_skin.clone(),
+ 'gt_skin_valid' : gt_valid,
+ 'cam_t' : outputs['pd_cam_t'][:self.cfg.logger.samples_per_record].clone(),
+ 'img_key' : batch['__key__'][:self.cfg.logger.samples_per_record],
+ }
+ self._tb_log(losses_main=losses_main, losses_disc=losses_disc, vis_data=vis_data)
+ # GPU_monitor.snapshot('HSMR logging')
+ self.log('_/loss_main', losses_main['weighted_sum'], on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=B)
+
+ # GPU_monitor.report_all()
+ return outputs
+
+ def forward(self, batch):
+ '''
+ ### Returns
+ - outputs: Dict
+ - pd_kp3d: torch.Tensor, shape (B, Q=44, 3)
+ - pd_kp2d: torch.Tensor, shape (B, Q=44, 2)
+ - pred_keypoints_2d: torch.Tensor, shape (B, Q=44, 2)
+ - pred_keypoints_3d: torch.Tensor, shape (B, Q=44, 3)
+ - pd_params: Dict
+ - poses: torch.Tensor, shape (B, 46)
+ - betas: torch.Tensor, shape (B, 10)
+ - pd_cam: torch.Tensor, shape (B, 3)
+ - pd_cam_t: torch.Tensor, shape (B, 3)
+ - focal_length: torch.Tensor, shape (B, 2)
+ '''
+ batch = self.adapt_batch(batch)
+
+ # 1. Main parts forward pass.
+ img_patch = to_tensor(batch['img_patch'], self.device) # (B, C, H, W)
+ outputs = self.forward_step(img_patch) # {...}
+
+ # 2. Prepare the secondary products
+ # 2.1. Body model outputs.
+ pd_skel_params = HSMRPipeline._adapt_skel_params(outputs['pd_params'])
+ skel_outputs = self.skel_model(**pd_skel_params, skelmesh=False)
+ pd_kp3d = skel_outputs.joints # (B, Q=44, 3)
+ pd_skin_verts = skel_outputs.skin_verts.detach().cpu().clone() # (B, V=6890, 3)
+        # 2.2. Reproject the 3D joints to the 2D plane.
+ pd_kp2d = perspective_projection(
+ points = to_tensor(pd_kp3d, device=self.device), # (B, K=Q=44, 3)
+ translation = to_tensor(outputs['pd_cam_t'], device=self.device), # (B, 3)
+ focal_length = to_tensor(outputs['focal_length'], device=self.device) / self.cfg.policy.img_patch_size, # (B, 2)
+ )
+
+ outputs['pd_kp3d'] = pd_kp3d
+ outputs['pd_kp2d'] = pd_kp2d
+ outputs['pred_keypoints_2d'] = pd_kp2d # adapt HMR2.0's script
+ outputs['pred_keypoints_3d'] = pd_kp3d # adapt HMR2.0's script
+ outputs['pd_params'] = pd_skel_params
+ outputs['pd_skin_verts'] = pd_skin_verts
+
+ return outputs
+
+ def forward_step(self, x:torch.Tensor):
+ '''
+ Run an inference step on the model.
+
+ ### Args
+ - x: torch.Tensor, shape (B, C, H, W)
+ - The input image patch.
+
+ ### Returns
+ - outputs: Dict
+ - 'pd_cam': torch.Tensor, shape (B, 3)
+ - The predicted camera parameters.
+ - 'pd_params': Dict
+ - The predicted body model parameters.
+ - 'focal_length': float
+ '''
+ # GPU_monitor = GPUMonitor()
+ B = len(x)
+
+ # 1. Extract features from image.
+ # The input size is 256*256, but ViT needs 256*192. TODO: make this more elegant.
+ with PM.time_monitor('Backbone Forward'):
+ feats = self.backbone(x[:, :, :, 32:-32])
+ # GPU_monitor.snapshot('HSMR forward backbone')
+
+
+ # 2. Run the head to predict the body model parameters.
+ with PM.time_monitor('Predict Head Forward'):
+ pd_params, pd_cam = self.head(feats)
+ # GPU_monitor.snapshot('HSMR forward head')
+
+ # 3. Transform the camera parameters to camera translation.
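+        #    `pd_cam` = (s, tx, ty) is a weak-perspective camera; the depth is recovered as
+        #    t_z = 2 * f / (s * img_patch_size + eps), so a larger predicted scale means a closer subject.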
+ focal_length = self.cfg.policy.focal_length * torch.ones(B, 2, device=self.device, dtype=pd_cam.dtype) # (B, 2)
+ pd_cam_t = torch.stack([
+ pd_cam[:, 1],
+ pd_cam[:, 2],
+ 2 * focal_length[:, 0] / (self.cfg.policy.img_patch_size * pd_cam[:, 0] + 1e-9)
+ ], dim=-1) # (B, 3)
+
+ # 4. Store the results.
+ outputs = {
+ 'pd_cam' : pd_cam,
+ 'pd_cam_t' : pd_cam_t,
+ 'pd_params' : pd_params,
+ # 'pd_params' : {k: v.clone() for k, v in pd_params.items()},
+ 'focal_length' : focal_length, # (B, 2)
+ }
+ # GPU_monitor.report_all()
+ return outputs
+
+
+ # ========== Internal Functions ==========
+
+ def _params_main(self):
+ return list(self.head.parameters()) + list(self.backbone.parameters())
+
+ def _params_disc(self):
+ if self.discriminator is None:
+ return []
+ else:
+ return list(self.discriminator.parameters())
+
+ @staticmethod
+ def _adapt_skel_params(params:Dict):
+ ''' Change the parameters formed like [pose_orient, pose_body, betas, trans] to [poses, betas, trans]. '''
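+        # Shape sketch (B = batch size):
+        #   {'poses_orient': (B, 3), 'poses_body': (B, 43), 'betas': (B, 10)} -> {'poses': (B, 46), 'betas': (B, 10)}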
+ adapted_params = {}
+
+ if 'poses' in params.keys():
+ adapted_params['poses'] = params['poses']
+ elif 'poses_orient' in params.keys() and 'poses_body' in params.keys():
+ poses_orient = params['poses_orient'] # (B, 3)
+ poses_body = params['poses_body'] # (B, 43)
+ adapted_params['poses'] = torch.cat([poses_orient, poses_body], dim=1) # (B, 46)
+ else:
+ raise ValueError(f'Cannot find the poses parameters among {list(params.keys())}.')
+
+ if 'betas' in params.keys():
+ adapted_params['betas'] = params['betas'] # (B, 10)
+ else:
+ raise ValueError(f'Cannot find the betas parameters among {list(params.keys())}.')
+
+ return adapted_params
+
+ def _init_backbone(self):
+ # 1. Loading the backbone weights.
+ get_logger().info(f'Loading backbone weights from {self.cfg.backbone_ckpt}')
+ state_dict = torch.load(self.cfg.backbone_ckpt, map_location='cpu')['state_dict']
+ if 'backbone.cls_token' in state_dict.keys():
+ state_dict = {k: v for k, v in state_dict.items() if 'backbone' in k and 'cls_token' not in k}
+ state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}
+ missing, unexpected = self.backbone.load_state_dict(state_dict)
+ if len(missing) > 0:
+ get_logger().warning(f'Missing keys in backbone: {missing}')
+ if len(unexpected) > 0:
+ get_logger().warning(f'Unexpected keys in backbone: {unexpected}')
+
+ # 2. Freeze the backbone if needed.
+ if self.cfg.get('freeze_backbone', False):
+ self.backbone.eval()
+ self.backbone.requires_grad_(False)
+
+ def _compute_losses_main(
+ self,
+ loss_weights : Dict,
+ pd_kp3d : torch.Tensor,
+ gt_kp3d : torch.Tensor,
+ pd_kp2d : torch.Tensor,
+ gt_kp2d : torch.Tensor,
+ pd_params : Dict,
+ gt_params : Dict,
+ has_params : Dict,
+ disc_out : Optional[torch.Tensor]=None,
+ *args, **kwargs,
+ ) -> Tuple[torch.Tensor, Dict]:
+ ''' Compute the weighted losses according to the config file. '''
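+        # The final objective is the weighted sum: loss = sum_k loss_weights[k] * losses[k], over
+        # {kp3d, kp2d, prior, poses_orient, poses_body, betas} (+ 'adversarial' when the discriminator is enabled).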
+
+ # 1. Preparation.
+ with PM.time_monitor('Preparation'):
+ B = len(pd_kp3d)
+ gt_skel_params = HSMRPipeline._adapt_skel_params(gt_params) # {poses, betas}
+ pd_skel_params = HSMRPipeline._adapt_skel_params(pd_params) # {poses, betas}
+
+ gt_betas = gt_skel_params['betas'].reshape(-1, 10)
+ pd_betas = pd_skel_params['betas'].reshape(-1, 10)
+ gt_poses = gt_skel_params['poses'].reshape(-1, 46)
+ pd_poses = pd_skel_params['poses'].reshape(-1, 46)
+
+ # 2. Keypoints losses.
+ with PM.time_monitor('kp2d & kp3d Loss'):
+ kp2d_loss = self.kp_2d_loss(pd_kp2d, gt_kp2d) / B
+ kp3d_loss = self.kp_3d_loss(pd_kp3d, gt_kp3d) / B
+
+ # 3. Prior losses.
+ with PM.time_monitor('Prior Loss'):
+ prior_loss = compute_poses_angle_prior_loss(pd_poses).mean() # (,)
+
+ # 4. Parameters losses.
+ if self.cfg.sp_poses_repr == 'rotation_matrix':
+ with PM.time_monitor('q2mat'):
+ gt_poses_mat, _ = self.skel_model.pose_params_to_rot(gt_poses) # (B, J=24, 3, 3)
+ pd_poses_mat, _ = self.skel_model.pose_params_to_rot(pd_poses) # (B, J=24, 3, 3)
+
+ gt_poses = gt_poses_mat.reshape(-1, 24*3*3) # (B, 24*3*3)
+ pd_poses = pd_poses_mat.reshape(-1, 24*3*3) # (B, 24*3*3)
+
+ with PM.time_monitor('Parameters Loss'):
+ poses_orient_loss = self.params_loss(pd_poses[:, :9], gt_poses[:, :9], has_params['poses_orient']) / B
+ poses_body_loss = self.params_loss(pd_poses[:, 9:], gt_poses[:, 9:], has_params['poses_body']) / B
+ betas_loss = self.params_loss(pd_betas, gt_betas, has_params['betas']) / B
+
+ # 5. Collect main losses.
+ with PM.time_monitor('Accumulate'):
+ losses = {
+ 'kp3d' : kp3d_loss, # (,)
+ 'kp2d' : kp2d_loss, # (,)
+ 'prior' : prior_loss, # (,)
+ 'poses_orient' : poses_orient_loss, # (,)
+ 'poses_body' : poses_body_loss, # (,)
+ 'betas' : betas_loss, # (,)
+ }
+
+ # 6. Consider adversarial loss.
+ if disc_out is not None:
+ with PM.time_monitor('Adversarial Loss'):
+ adversarial_loss = ((disc_out - 1.0) ** 2).sum() / B # (,)
+ losses['adversarial'] = adversarial_loss
+
+ with PM.time_monitor('Accumulate'):
+ loss = torch.tensor(0., device=self.device)
+ for k, v in losses.items():
+ loss += v * loss_weights[k]
+ losses = {k: v.item() for k, v in losses.items()}
+ losses['weighted_sum'] = loss.item()
+ return loss, losses
+
+ def _train_discriminator(self, mocap_batch, pd_poses_body_mat, pd_betas, optimizer):
+ '''
+ Train the discriminator using the regressed body model parameters and the realistic MoCap data.
+
+ ### Args
+ - mocap_batch: Dict
+ - 'poses_body': torch.Tensor, shape (B, 43)
+ - 'betas': torch.Tensor, shape (B, 10)
+ - pd_poses_body_mat: torch.Tensor, shape (B, J=23, 3, 3)
+ - pd_betas: torch.Tensor, shape (B, 10)
+ - optimizer: torch.optim.Optimizer
+
+ ### Returns
+ - losses: Dict
+ - 'pd_disc': float
+ - 'mc_disc': float
+ '''
+ pd_B = len(pd_poses_body_mat)
+ mc_B = len(mocap_batch['poses_body'])
+        if pd_B != mc_B:
+            get_logger().warning(f'pd_B: {pd_B} != mc_B: {mc_B}')
+
+ # 1. Extract the realistic 3D MoCap label.
+ mc_poses_body = mocap_batch['poses_body'] # (B, 43)
+ padding_zeros = mc_poses_body.new_zeros(mc_B, 3) # (B, 3)
+ mc_poses = torch.cat([padding_zeros, mc_poses_body], dim=1) # (B, 46)
+ mc_poses_mat, _ = self.skel_model.pose_params_to_rot(mc_poses) # (B, J=24, 3, 3)
+ mc_poses_body_mat = mc_poses_mat[:, 1:, :, :] # (B, J=23, 3, 3)
+ mc_betas = mocap_batch['betas'] # (B, 10)
+
+ # 2. Forward pass.
+ # Discriminator forward pass for the predicted data.
+ pd_disc_out = self.discriminator(pd_poses_body_mat.detach(), pd_betas.detach())
+ pd_disc_loss = ((pd_disc_out - 0.0) ** 2).sum() / pd_B # (,)
+ # Discriminator forward pass for the realistic MoCap data.
+ mc_disc_out = self.discriminator(mc_poses_body_mat, mc_betas)
+ mc_disc_loss = ((mc_disc_out - 1.0) ** 2).sum() / pd_B # (,) TODO: This 'pd_B' is from HMR2, not sure if it's a bug.
+
+ # 3. Backward pass.
+ disc_loss = self.cfg.loss_weights.adversarial * (pd_disc_loss + mc_disc_loss)
+ optimizer.zero_grad()
+ self.manual_backward(disc_loss)
+ optimizer.step()
+
+ return {
+ 'pd_disc': pd_disc_loss.item(),
+ 'mc_disc': mc_disc_loss.item(),
+ }
+
+ @rank_zero_only
+ def _tb_log(self, losses_main:Dict, losses_disc:Dict, vis_data:Dict, mode:str='train'):
+ ''' Write the logging information to the TensorBoard. '''
+ if self.logger is None:
+ return
+
+ if self.global_step != 1 and self.global_step % self.cfg.logger.interval != 0:
+ return
+
+ # 1. Losses.
+ summary_writer = self.logger.experiment
+ for loss_name, loss_val in losses_main.items():
+ summary_writer.add_scalar(f'{mode}/losses_main/{loss_name}', loss_val, self.global_step)
+ for loss_name, loss_val in losses_disc.items():
+ summary_writer.add_scalar(f'{mode}/losses_disc/{loss_name}', loss_val, self.global_step)
+
+ # 2. Visualization.
+ try:
+ pelvis_id = 39
+ # 2.1. Visualize 3D information.
+ self.wis3d.add_motion_mesh(
+ verts = vis_data['pd_skin'] - vis_data['pd_kp3d'][:, pelvis_id:pelvis_id+1], # center the mesh
+ faces = self.skel_model.skin_f,
+ name = 'pd_skin',
+ )
+ self.wis3d.add_motion_mesh(
+ verts = vis_data['gt_skin'] - vis_data['gt_kp3d_with_conf'][:, pelvis_id:pelvis_id+1, :3], # center the mesh
+ faces = self.skel_model.skin_f,
+ name = 'gt_skin',
+ )
+ self.wis3d.add_motion_skel(
+ joints = vis_data['pd_kp3d'] - vis_data['pd_kp3d'][:, pelvis_id:pelvis_id+1],
+ bones = Skeleton_OpenPose25.bones,
+ colors = Skeleton_OpenPose25.bone_colors,
+ name = 'pd_kp3d',
+ )
+
+ aligned_gt_kp3d = vis_data['gt_kp3d_with_conf']
+ aligned_gt_kp3d[..., :3] -= vis_data['gt_kp3d_with_conf'][:, pelvis_id:pelvis_id+1, :3]
+ self.wis3d.add_motion_skel(
+ joints = aligned_gt_kp3d,
+ bones = Skeleton_OpenPose25.bones,
+ colors = Skeleton_OpenPose25.bone_colors,
+ name = 'gt_kp3d',
+ )
+ except Exception as e:
+ get_logger().error(f'Failed to visualize the current performance on wis3d: {e}')
+
+ try:
+ # 2.2. Visualize 2D information.
+ if vis_data['img_patch'] is not None:
+ # Overlay the skin mesh of the results on the original image.
+ imgs_spliced = []
+ for i, img_patch in enumerate(vis_data['img_patch']):
+ # TODO: make this more elegant.
+ img_mean = to_numpy(OmegaConf.to_container(self.cfg.policy.img_mean))[None, None] # (1, 1, 3)
+ img_std = to_numpy(OmegaConf.to_container(self.cfg.policy.img_std))[None, None] # (1, 1, 3)
+ img_patch = ((img_mean + img_patch * img_std) * 255).astype(np.uint8)
+
+ img_patch_raw = annotate_img(img_patch, 'raw')
+
+ img_with_mesh = render_mesh_overlay_img(
+ faces = self.skel_model.skin_f,
+ verts = vis_data['pd_skin'][i].float(),
+ K4 = [self.cfg.policy.focal_length, self.cfg.policy.focal_length, 128, 128],
+ img = img_patch,
+ Rt = [torch.eye(3).float(), vis_data['cam_t'][i].float()],
+ mesh_color = 'pink',
+ )
+ img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh')
+
+ img_with_gt_mesh = render_mesh_overlay_img(
+ faces = self.skel_model.skin_f,
+ verts = vis_data['gt_skin'][i].float(),
+ K4 = [self.cfg.policy.focal_length, self.cfg.policy.focal_length, 128, 128],
+ img = img_patch,
+ Rt = [torch.eye(3).float(), vis_data['cam_t'][i].float()],
+ mesh_color = 'pink',
+ )
+ valid = 'valid' if vis_data['gt_skin_valid'][i] else 'invalid'
+ img_with_gt_mesh = annotate_img(img_with_gt_mesh, f'gt_mesh_{valid}')
+
+ img_with_gt = annotate_img(img_patch, 'gt_kp2d')
+ gt_kp2d_with_conf = vis_data['gt_kp2d_with_conf'][i]
+ gt_kp2d_with_conf[:, :2] = (gt_kp2d_with_conf[:, :2] + 0.5) * self.cfg.policy.img_patch_size
+ img_with_gt = draw_kp2d_on_img(
+ img_with_gt,
+ gt_kp2d_with_conf,
+ Skeleton_OpenPose25.bones,
+ Skeleton_OpenPose25.bone_colors,
+ )
+
+ img_with_pd = annotate_img(img_patch, 'pd_kp2d')
+                    pd_kp2d_vis = (vis_data['pd_kp2d'][i] + 0.5) * self.cfg.policy.img_patch_size
+                    img_with_pd = draw_kp2d_on_img(
+                        img_with_pd,
+                        pd_kp2d_vis,
+ Skeleton_OpenPose25.bones,
+ Skeleton_OpenPose25.bone_colors,
+ )
+
+ img_spliced = splice_img([img_patch_raw, img_with_gt, img_with_pd, img_with_mesh, img_with_gt_mesh], grid_ids=[[0, 1, 2, 3, 4]])
+ img_spliced = annotate_img(img_spliced, vis_data['img_key'][i], pos='tl')
+ imgs_spliced.append(img_spliced)
+
+ try:
+ self.wis3d.set_scene_id(i)
+ self.wis3d.add_image(
+ image = img_spliced,
+ name = 'image',
+ )
+ except Exception as e:
+ get_logger().error(f'Failed to visualize the current performance on wis3d: {e}')
+
+ img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(vis_data['img_patch']))])
+
+ img_final = to_tensor(img_final, device=None).permute(2, 0, 1)
+ summary_writer.add_image(f'{mode}/visualization', img_final, self.global_step)
+
+ except Exception as e:
+ get_logger().error(f'Failed to visualize the current performance: {e}')
+ # traceback.print_exc()
+
+
+ def _adapt_hsmr_v1(self, batch):
+ from lib.data.augmentation.skel import rot_skel_on_plane
+ rot_deg = batch['augm_args']['rot_deg'] # (B,)
+
+ skel_params = rot_skel_on_plane(batch['raw_skel_params'], rot_deg)
+ batch['gt_params'] = {}
+ batch['gt_params']['poses_orient'] = skel_params['poses'][:, :3]
+ batch['gt_params']['poses_body'] = skel_params['poses'][:, 3:]
+ batch['gt_params']['betas'] = skel_params['betas']
+
+ has_skel_params = batch['has_skel_params']
+ batch['has_gt_params'] = {}
+ batch['has_gt_params']['poses_orient'] = has_skel_params['poses']
+ batch['has_gt_params']['poses_body'] = has_skel_params['poses']
+ batch['has_gt_params']['betas'] = has_skel_params['betas']
+ return batch
+
+ def _adapt_img_inference(self, img_patches):
+ return {'img_patch': img_patches}
\ No newline at end of file
diff --git a/lib/modeling/pipelines/vitdet/__init__.py b/lib/modeling/pipelines/vitdet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..001fdc725e611574500b06c0c6024929db75968a
--- /dev/null
+++ b/lib/modeling/pipelines/vitdet/__init__.py
@@ -0,0 +1,20 @@
+from lib.kits.basic import *
+
+from detectron2.config import LazyConfig
+from .utils_detectron2 import DefaultPredictor_Lazy
+
+
+def build_detector(batch_size, max_img_size, device):
+ cfg_path = Path(__file__).parent / 'cascade_mask_rcnn_vitdet_h_75ep.py'
+ detectron2_cfg = LazyConfig.load(str(cfg_path))
+ detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl"
+ for i in range(3):
+ detectron2_cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25
+
+ detector = DefaultPredictor_Lazy(
+ cfg = detectron2_cfg,
+ batch_size = batch_size,
+ max_img_size = max_img_size,
+ device = device,
+ )
+ return detector
\ No newline at end of file
diff --git a/lib/modeling/pipelines/vitdet/cascade_mask_rcnn_vitdet_h_75ep.py b/lib/modeling/pipelines/vitdet/cascade_mask_rcnn_vitdet_h_75ep.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c6ae0eaf48c2c2d3b70529a0d2d915432e43db6
--- /dev/null
+++ b/lib/modeling/pipelines/vitdet/cascade_mask_rcnn_vitdet_h_75ep.py
@@ -0,0 +1,129 @@
+## coco_loader_lsj.py
+
+import detectron2.data.transforms as T
+from detectron2 import model_zoo
+from detectron2.config import LazyCall as L
+
+# Data using LSJ
+image_size = 1024
+dataloader = model_zoo.get_config("common/data/coco.py").dataloader
+dataloader.train.mapper.augmentations = [
+ L(T.RandomFlip)(horizontal=True), # flip first
+ L(T.ResizeScale)(
+ min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
+ ),
+ L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False),
+]
+dataloader.train.mapper.image_format = "RGB"
+dataloader.train.total_batch_size = 64
+# recompute boxes due to cropping
+dataloader.train.mapper.recompute_boxes = True
+
+dataloader.test.mapper.augmentations = [
+ L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size),
+]
+
+from functools import partial
+from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+from detectron2 import model_zoo
+from detectron2.config import LazyCall as L
+from detectron2.solver import WarmupParamScheduler
+from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
+
+# mask_rcnn_vitdet_b_100ep.py
+
+model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model
+
+# Initialization and trainer settings
+train = model_zoo.get_config("common/train.py").train
+train.amp.enabled = True
+train.ddp.fp16_compression = True
+train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
+
+
+# Schedule
+# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
+train.max_iter = 184375
+
+lr_multiplier = L(WarmupParamScheduler)(
+ scheduler=L(MultiStepParamScheduler)(
+ values=[1.0, 0.1, 0.01],
+ milestones=[163889, 177546],
+ num_updates=train.max_iter,
+ ),
+ warmup_length=250 / train.max_iter,
+ warmup_factor=0.001,
+)
+
+# Optimizer
+optimizer = model_zoo.get_config("common/optim.py").AdamW
+optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
+optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
+
+# cascade_mask_rcnn_vitdet_b_100ep.py
+
+from detectron2.config import LazyCall as L
+from detectron2.layers import ShapeSpec
+from detectron2.modeling.box_regression import Box2BoxTransform
+from detectron2.modeling.matcher import Matcher
+from detectron2.modeling.roi_heads import (
+ FastRCNNOutputLayers,
+ FastRCNNConvFCHead,
+ CascadeROIHeads,
+)
+
+# arguments that don't exist for Cascade R-CNN
+[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]]
+
+model.roi_heads.update(
+ _target_=CascadeROIHeads,
+ box_heads=[
+ L(FastRCNNConvFCHead)(
+ input_shape=ShapeSpec(channels=256, height=7, width=7),
+ conv_dims=[256, 256, 256, 256],
+ fc_dims=[1024],
+ conv_norm="LN",
+ )
+ for _ in range(3)
+ ],
+ box_predictors=[
+ L(FastRCNNOutputLayers)(
+ input_shape=ShapeSpec(channels=1024),
+ test_score_thresh=0.05,
+ box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)),
+ cls_agnostic_bbox_reg=True,
+ num_classes="${...num_classes}",
+ )
+ for (w1, w2) in [(10, 5), (20, 10), (30, 15)]
+ ],
+ proposal_matchers=[
+ L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False)
+ for th in [0.5, 0.6, 0.7]
+ ],
+)
+
+# cascade_mask_rcnn_vitdet_h_75ep.py
+
+from functools import partial
+
+train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth"
+
+model.backbone.net.embed_dim = 1280
+model.backbone.net.depth = 32
+model.backbone.net.num_heads = 16
+model.backbone.net.drop_path_rate = 0.5
+# 7, 15, 23, 31 for global attention
+model.backbone.net.window_block_indexes = (
+ list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31))
+)
+
+optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32)
+optimizer.params.overrides = {}
+optimizer.params.weight_decay_norm = None
+
+train.max_iter = train.max_iter * 3 // 4 # 100ep -> 75ep
+lr_multiplier.scheduler.milestones = [
+ milestone * 3 // 4 for milestone in lr_multiplier.scheduler.milestones
+]
+lr_multiplier.scheduler.num_updates = train.max_iter
diff --git a/lib/modeling/pipelines/vitdet/utils_detectron2.py b/lib/modeling/pipelines/vitdet/utils_detectron2.py
new file mode 100644
index 0000000000000000000000000000000000000000..299331b8c0a1056e03d856ca32efe9b2ef9b1e3a
--- /dev/null
+++ b/lib/modeling/pipelines/vitdet/utils_detectron2.py
@@ -0,0 +1,117 @@
+from lib.kits.basic import *
+
+from tqdm import tqdm
+
+from lib.utils.media import flex_resize_img
+
+import detectron2.data.transforms as T
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import instantiate as instantiate_detectron2
+from detectron2.data import MetadataCatalog
+
+
+class DefaultPredictor_Lazy:
+ '''
+    Create a simple end-to-end predictor with the given config that runs on a single device for
+    several input images.
+    Compared to using the model directly, this class adds the following:
+
+ Modified from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/utils/utils_detectron2.py#L9-L93
+
+    1. Load the checkpoint from the weights specified in the config (`train.init_checkpoint`).
+    2. Take RGB images as input and apply format conversion internally.
+ 3. Apply resizing defined by the parameter `max_img_size`.
+ 4. Take input images and produce outputs, and filter out only the `instances` data.
+ 5. Use an auto-tuned batch size to process the images in a batch.
+        - Start with the given batch size; on failure, halve the batch size.
+        - If the batch size has been reduced to 1 and it still fails, skip the image.
+ - The implementation is abstracted to `lib.platform.sliding_batches`.
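+
+    A minimal usage sketch (batch size and device are illustrative):
+
+        detector = DefaultPredictor_Lazy(cfg, batch_size=8, max_img_size=512, device='cuda:0')
+        preds, ratios = detector(rgb_imgs)  # one dict (or None) per image, plus per-image resize ratios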
+ '''
+
+ def __init__(self, cfg, batch_size=20, max_img_size=512, device='cuda:0'):
+ self.batch_size = batch_size
+ self.max_img_size = max_img_size
+ self.device = device
+ self.model = instantiate_detectron2(cfg.model)
+
+ test_dataset = OmegaConf.select(cfg, 'dataloader.test.dataset.names', default=None)
+ if isinstance(test_dataset, (List, Tuple)):
+ test_dataset = test_dataset[0]
+
+ checkpointer = DetectionCheckpointer(self.model)
+ checkpointer.load(OmegaConf.select(cfg, 'train.init_checkpoint', default=''))
+
+ mapper = instantiate_detectron2(cfg.dataloader.test.mapper)
+ self.aug = mapper.augmentations
+ self.input_format = mapper.image_format
+
+ self.model.eval().to(self.device)
+ if test_dataset:
+ self.metadata = MetadataCatalog.get(test_dataset)
+
+ assert self.input_format in ['RGB'], f'Invalid input format: {self.input_format}'
+ # assert self.input_format in ['RGB', 'BGR'], f'Invalid input format: {self.input_format}'
+
+ def __call__(self, imgs):
+ '''
+ ### Args
+        - `imgs`: List[np.ndarray], a list of images, each of shape (Hi, Wi, RGB).
+ - Shapes of each image may be different.
+
+ ### Returns
+        - `preds`: List[dict], one entry per image (None if detection failed even with batch_size=1),
+            - each dict contains `pred_classes`, `scores`, and `pred_boxes` (on CPU).
+        - `downsample_ratios`: List[float], the resize ratio applied to each image before detection.
+ '''
+ with torch.no_grad():
+ inputs = []
+ downsample_ratios = []
+ for img in imgs:
+ img_size = max(img.shape[:2])
+ if img_size > self.max_img_size: # exceed the max size, make it smaller
+ downsample_ratio = self.max_img_size / img_size
+ img = flex_resize_img(img, ratio=downsample_ratio)
+ downsample_ratios.append(downsample_ratio)
+ else:
+ downsample_ratios.append(1.0)
+ h, w, _ = img.shape
+ img = self.aug(T.AugInput(img)).apply_image(img)
+ img = to_tensor(img.astype('float32').transpose(2, 0, 1), 'cpu')
+ inputs.append({'image': img, 'height': h, 'width': w})
+
+ preds = []
+ N_imgs = len(inputs)
+ prog_bar = tqdm(total=N_imgs, desc='Batch Detection')
+ sid, last_fail_id = 0, 0
+ cur_bs = self.batch_size
+ while sid < N_imgs:
+ eid = min(sid + cur_bs, N_imgs)
+ try:
+ preds_round = self.model(inputs[sid:eid])
+ except Exception as e:
+                    get_logger(brief=True).error(f'Image No.{sid}: {e}. Trying to recover.')
+ if cur_bs > 1:
+ cur_bs = (cur_bs - 1) // 2 + 1 # reduce the batch size by half
+ assert cur_bs > 0, 'Invalid batch size.'
+ get_logger(brief=True).info(f'Adjust the batch size to {cur_bs}.')
+ else:
+ get_logger(brief=True).error(f'Can\'t afford image No.{sid} even with batch_size=1, skip.')
+ preds.append(None) # placeholder for the failed image
+ sid += 1
+ last_fail_id = sid
+ continue
+ # Save the results.
+ preds.extend([{
+ 'pred_classes' : pred['instances'].pred_classes.cpu(),
+ 'scores' : pred['instances'].scores.cpu(),
+ 'pred_boxes' : pred['instances'].pred_boxes.tensor.cpu(),
+ } for pred in preds_round])
+
+ prog_bar.update(eid - sid)
+ sid = eid
+ # # Adjust the batch size.
+ # if last_fail_id < sid - cur_bs * 2:
+ # cur_bs = min(cur_bs * 2, self.batch_size) # gradually recover the batch size
+ prog_bar.close()
+
+ return preds, downsample_ratios
diff --git a/lib/platform/__init__.py b/lib/platform/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d1a44439fbe9a87308e732a94a334e069e842fc
--- /dev/null
+++ b/lib/platform/__init__.py
@@ -0,0 +1,6 @@
+from .proj_manager import ProjManager as PM
+from .config_utils import (
+ print_cfg,
+ entrypoint,
+ entrypoint_with_args,
+)
\ No newline at end of file
diff --git a/lib/platform/config_utils.py b/lib/platform/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c327d42cd460bb36ca16a33aa0bc12e27de1dff9
--- /dev/null
+++ b/lib/platform/config_utils.py
@@ -0,0 +1,251 @@
+import sys
+import json
+import rich
+import rich.text
+import rich.tree
+import rich.syntax
+import hydra
+from typing import List, Optional, Union, Any
+from pathlib import Path
+from omegaconf import OmegaConf, DictConfig, ListConfig
+from pytorch_lightning.utilities import rank_zero_only
+
+from lib.info.log import get_logger
+
+from .proj_manager import ProjManager as PM
+
+
+def get_PM_info_dict():
+ ''' Get a OmegaConf object containing the information from the ProjManager. '''
+ PM_info = OmegaConf.create({
+ '_pm_': {
+ 'root' : str(PM.root),
+ 'inputs' : str(PM.inputs),
+ 'outputs': str(PM.outputs),
+ }
+ })
+ return PM_info
+
+
+def get_PM_info_list():
+ ''' Get a list containing the information from the ProjManager. '''
+ PM_info = [
+ f'_pm_.root={str(PM.root)}',
+ f'_pm_.inputs={str(PM.inputs)}',
+ f'_pm_.outputs={str(PM.outputs)}',
+ ]
+ return PM_info
+
+
+def entrypoint_with_args(*args, log_cfg=True, **kwargs):
+ '''
+    This decorator extends the `hydra.main` decorator in the following ways:
+    - Inject some runtime-known arguments, e.g., `proj_root`.
+    - Enable additional arguments that need not be specified on the command line.
+        - Positional arguments are appended to the command line arguments directly, so make sure they are valid.
+            - e.g., 'exp=<...>', '+extra=<...>', etc.
+        - Keyword arguments have the same effect as command line arguments `{k}={v}`.
+    - Validate the experiment name.
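+
+    Example (a sketch; the override string below is illustrative):
+
+        @entrypoint_with_args('exp=<your_exp_config>')
+        def main(cfg: DictConfig):
+            ...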
+ '''
+
+ overrides = get_PM_info_list()
+
+ for arg in args:
+ overrides.append(arg)
+
+ for k, v in kwargs.items():
+ overrides.append(f'{k}={v}')
+
+ overrides.extend(sys.argv[1:])
+
+ def entrypoint_wrapper(func):
+ # Import extra pre-specified arguments.
+ if len(overrides) > 0:
+ # The args from command line have higher priority, so put them in the back.
+ sys.argv = sys.argv[:1] + overrides + sys.argv[1:]
+ _log_exp_info(func.__name__, overrides)
+
+ @hydra.main(version_base=None, config_path=str(PM.configs), config_name='base.yaml')
+ def entrypoint_preprocess(cfg:DictConfig):
+ # Resolve the references and make it editable.
+ cfg = unfold_cfg(cfg)
+
+ # Print out the configuration files.
+ if log_cfg and cfg.get('show_cfg', True):
+ sum_keys = ['output_dir', 'pipeline.name', 'data.name', 'exp_name', 'exp_tag']
+ print_cfg(cfg, sum_keys=sum_keys)
+
+ # Check the validation of experiment name.
+ if cfg.get('exp_name') is None:
+ get_logger(brief=True).fatal(f'`exp_name` is not given! You may need to add `exp=` to the command line.')
+ raise ValueError('`exp_name` is not given!')
+
+ # Bind config.
+ PM.init_with_cfg(cfg)
+ try:
+ with PM.time_monitor('exp', f'Main part of experiment `{cfg.exp_name}`.'):
+ # Enter the main function.
+ func(cfg)
+ except Exception as e:
+ raise e
+ finally:
+ PM.time_monitor.report(level='global')
+
+ # TODO: Wrap a notifier here.
+
+ return entrypoint_preprocess
+
+
+ return entrypoint_wrapper
+
+ #! This implementation can't dump the config files in default ways. In order to keep c
+ # def entrypoint_wrapper(func):
+ # def entrypoint_preprocess():
+ # # Initialize the configuration module.
+ # with hydra.initialize_config_dir(version_base=None, config_dir=str(PM.configs)):
+ # get_logger(brief=True).info(f'Exp entry `{func.__name__}` is called with overrides: {overrides}')
+ # cfg = hydra.compose(config_name='base', overrides=overrides)
+
+ # cfg4dump_raw = cfg.copy() # store the folded raw configuration files
+ # # Resolve the references and make it editable.
+ # cfg = unfold_cfg(cfg)
+
+ # # Print out the configuration files.
+ # if log_cfg:
+ # sum_keys = ['pipeline.name', 'data.name', 'exp_name']
+ # print_cfg(cfg, sum_keys=sum_keys)
+ # # Check the validation of experiment name.
+ # if cfg.get('exp_name') is None:
+ # get_logger().fatal(f'`exp_name` is not given! You may need to add `exp=` to the command line.')
+ # raise ValueError('`exp_name` is not given!')
+ # # Enter the main function.
+ # func(cfg)
+ # return entrypoint_preprocess
+ # return entrypoint_wrapper
+
+def entrypoint(func):
+ '''
+    This decorator extends the `hydra.main` decorator in the following ways:
+    - Inject some runtime-known arguments, e.g., `proj_root`.
+    - Validate the experiment name.
+ '''
+ return entrypoint_with_args()(func)
+
+
+def unfold_cfg(
+ cfg : Union[DictConfig, Any],
+):
+ '''
+    Unfold the configuration, i.e., convert it from structured mode to container mode and recreate
+    it. This resolves all the references and makes the config editable.
+
+ ### Args
+ - cfg: DictConfig or None
+
+ ### Returns
+ - cfg: DictConfig or None
+ '''
+ if cfg is None:
+ return None
+
+ cfg_container = OmegaConf.to_container(cfg, resolve=True)
+ cfg = OmegaConf.create(cfg_container)
+ return cfg
+
+
+def recursively_simplify_cfg(
+ node : DictConfig,
+ hide_misc : bool = True,
+):
+ if isinstance(node, DictConfig):
+ for k in list(node.keys()):
+ # We delete some terms that are not commonly concerned.
+ if hide_misc:
+ if k in ['_hub_', 'hydra', 'job_logging']:
+ node.__delattr__(k)
+ continue
+ node[k] = recursively_simplify_cfg(node[k], hide_misc)
+ elif isinstance(node, ListConfig):
+ if len(node) > 0 and all([
+ not isinstance(x, DictConfig) \
+ and not isinstance(x, ListConfig) \
+ for x in node
+ ]):
+ # We fold all lists of basic elements (int, float, ...) into a single line if possible.
+ folded_list_str = '*' + str(list(node))
+ node = folded_list_str if len(folded_list_str) < 320 else node
+ else:
+ for i in range(len(node)):
+ node[i] = recursively_simplify_cfg(node[i], hide_misc)
+ return node
+
+
+@rank_zero_only
+def print_cfg(
+ cfg : Optional[DictConfig],
+ title : str ='cfg',
+ sum_keys: List[str] = [],
+ show_all: bool = False
+):
+ '''
+ Print configuration files using rich.
+
+ ### Args
+ - cfg: DictConfig or None
+ - If None, print nothing.
+ - sum_keys: List[str], default []
+        - Keys (dot-separated paths) given in the list that exist in the configuration
+          will be printed in the summary part.
+ - show_all: bool, default False
+        - If False, hide terms that are not commonly concerned (e.g., `_hub_` and some
+          Hydra supporting configs) and fold plain lists into single lines.
+ '''
+
+ theme = 'coffee'
+ style = 'dim'
+
+ tf_dict = { True: '◼', False: '◻' }
+ print_setting = f'<< {tf_dict[show_all]} SHOW_ALL >>'
+ tree = rich.tree.Tree(f'⌾ {title} - {print_setting}', style=style, guide_style=style)
+
+ if cfg is None:
+ tree.add('None')
+ rich.print(tree)
+ return
+
+ # Clone a new one to avoid changing the original configuration files.
+ cfg = cfg.copy()
+ cfg = unfold_cfg(cfg)
+
+ if not show_all:
+ cfg = recursively_simplify_cfg(cfg)
+
+ cfg_yaml = OmegaConf.to_yaml(cfg)
+ cfg_yaml = rich.syntax.Syntax(cfg_yaml, 'yaml', theme=theme, line_numbers=True)
+ tree.add(cfg_yaml)
+
+    # Add a summary containing only the commonly concerned information.
+ if len(sum_keys) > 0:
+ concerned = {}
+ for k_str in sum_keys:
+ k_list = k_str.split('.')
+ tgt = cfg
+ for k in k_list:
+ if tgt is not None:
+ tgt = tgt.get(k)
+ if tgt is not None:
+ concerned[k_str] = tgt
+ else:
+ get_logger().warning(f'Key `{k_str}` is not found in the configuration files.')
+
+ tree.add(rich.syntax.Syntax(OmegaConf.to_yaml(concerned), 'yaml', theme=theme))
+
+ rich.print(tree)
+
+
+@rank_zero_only
+def _log_exp_info(
+ func_name : str,
+ overrides : List[str],
+):
+ get_logger(brief=True).info(f'Exp entry `{func_name}` is called with overrides: {overrides}')
\ No newline at end of file
diff --git a/lib/platform/monitor/__init__.py b/lib/platform/monitor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8165bfcd2187bee08b1a0ec904bccfd3b306322d
--- /dev/null
+++ b/lib/platform/monitor/__init__.py
@@ -0,0 +1,2 @@
+from .time import *
+from .gpu import *
\ No newline at end of file
diff --git a/lib/platform/monitor/gpu.py b/lib/platform/monitor/gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e6aa1eb1d26890b3dae68bdbd8ba6fc90bb42a
--- /dev/null
+++ b/lib/platform/monitor/gpu.py
@@ -0,0 +1,92 @@
+import time
+import torch
+import inspect
+
+
+def fold_path(fn:str):
+ ''' Fold a path like `from/to/file.py` to relative `f/t/file.py`. '''
+ return '/'.join([p[:1] for p in fn.split('/')[:-1]]) + '/' + fn.split('/')[-1]
+
+
+def summary_frame_info(frame:inspect.FrameInfo):
+ ''' Convert a FrameInfo object to a summary string. '''
+ return f'{frame.function} @ {fold_path(frame.filename)}:{frame.lineno}'
+
+
+class GPUMonitor():
+ '''
+    This monitor is designed for GPU memory analysis. It records the peak memory usage over a period of time.
+    A snapshot records the peak memory usage since the last reset (i.e., since init, `reset()`, or the previous snapshot).
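+
+    A minimal usage sketch (the `desc` labels are arbitrary):
+
+        monitor = GPUMonitor()
+        out = model(x)
+        monitor.snapshot('after forward')  # peak memory since the last reset / snapshot
+        monitor.report_all()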
+ '''
+
+ def __init__(self):
+ self.reset()
+ self.clear()
+ self.log_fn = 'gpu_monitor.log'
+
+
+ def snapshot(self, desc:str='snapshot'):
+ timestamp = time.time()
+ caller_frame = inspect.stack()[1]
+ peak_MB = torch.cuda.max_memory_allocated() / 1024 / 1024
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
+ free_mem_MB, total_mem_MB = free_mem / 1024 / 1024, total_mem / 1024 / 1024
+
+ record = {
+ 'until' : time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp)),
+ 'until_raw' : timestamp,
+ 'position' : summary_frame_info(caller_frame),
+ 'peak' : peak_MB,
+ 'peak_msg' : f'{peak_MB:.2f} MB',
+ 'free' : free_mem_MB,
+ 'total' : total_mem_MB,
+ 'free_msg' : f'{free_mem_MB:.2f} MB',
+ 'total_msg' : f'{total_mem_MB:.2f} MB',
+ 'desc' : desc,
+ }
+
+        self.max_peak_MB = max(self.max_peak_MB, peak_MB)
+
+ self.records.append(record)
+ self._update_log(record)
+
+ self.reset()
+ return record
+
+
+ def report_latest(self, k:int=1):
+ import rich
+ caller_frame = inspect.stack()[1]
+ caller_info = summary_frame_info(caller_frame)
+ rich.print(f'{caller_info} -> latest {k} records:')
+ for rid, record in enumerate(self.records[-k:]):
+ msg = self._generate_log_msg(record)
+ rich.print(msg)
+
+
+ def report_all(self):
+ self.report_latest(len(self.records))
+
+
+ def reset(self):
+ torch.cuda.reset_peak_memory_stats()
+ return
+
+
+ def clear(self):
+ self.records = []
+ self.max_peak_MB = 0
+
+ def _generate_log_msg(self, record):
+ time = record['until']
+ peak = record['peak']
+ desc = record['desc']
+ position = record['position']
+ msg = f'[{time}] ⛰️ {peak:>8.2f} MB 📌 {desc} 🌐 {position}'
+ return msg
+
+
+ def _update_log(self, record):
+ msg = self._generate_log_msg(record)
+ with open(self.log_fn, 'a') as f:
+ f.write(msg + '\n')
diff --git a/lib/platform/monitor/time.py b/lib/platform/monitor/time.py
new file mode 100644
index 0000000000000000000000000000000000000000..7682fa689a5d53f59714ddbf612a170fced2c28f
--- /dev/null
+++ b/lib/platform/monitor/time.py
@@ -0,0 +1,233 @@
+import time
+import atexit
+import inspect
+import torch
+
+from typing import Optional, Union, List
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor
+
+
+def fold_path(fn:str):
+ ''' Fold a path like `from/to/file.py` to relative `f/t/file.py`. '''
+ return '/'.join([p[:1] for p in fn.split('/')[:-1]]) + '/' + fn.split('/')[-1]
+
+
+def summary_frame_info(frame:inspect.FrameInfo):
+ ''' Convert a FrameInfo object to a summary string. '''
+ return f'{frame.function} @ {fold_path(frame.filename)}:{frame.lineno}'
+
+
+class TimeMonitorDisabled:
+ def foo(self, *args, **kwargs):
+ return
+
+ def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False):
+ self.tick = self.foo
+ self.report = self.foo
+ self.clear = self.foo
+ self.dump_statistics = self.foo
+
+ def __call__(self, *args, **kwargs):
+ return self
+ def __enter__(self):
+ return self
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ return
+
+
+class TimeMonitor:
+ '''
+ It is supposed to be used like this:
+
+ time_monitor = TimeMonitor()
+
+ with time_monitor('test_block', 'Block that does something.') as tm:
+ do_something()
+
+ time_monitor.report()
+ '''
+
+ def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False):
+ if log_folder is not None:
+ self.log_folder = Path(log_folder) if isinstance(log_folder, str) else log_folder
+ self.log_folder.mkdir(parents=True, exist_ok=True)
+ log_fn = self.log_folder / 'readable.log'
+ self.log_fh = open(log_fn, 'w') # Log file handler.
+ self.log_fh.write('=== New Exp ===\n')
+ else:
+ self.log_folder = None
+ self.clear()
+ self.current_block_uid_stack : List = [] # Unique block id stack for recording.
+ self.current_block_aid_stack : List = [] # Block id stack for accumulated cost analysis.
+
+ # Specially add a global start and end block.
+ self.record_birth_block = record_birth_block and log_folder is not None
+ if self.record_birth_block:
+ self.__call__('monitor_birth', 'Since the monitor is constructed.')
+ self.__enter__()
+
+ # Register the exit hook to dump the data safely.
+ atexit.register(self._die_hook)
+
+
+ def __call__(self, block_name:str, block_desc:Optional[str]=None):
+ ''' Set up the name of the context for a block. '''
+ # 1. Format the block name.
+ block_name = block_name.replace('/', '-').replace(' ', '-')
+ block_name_recursive = '/'.join([s.split('/')[-1] for s in self.current_block_aid_stack] + [block_name]) # Tree structure block name.
+ # 2. Get a unique name for the block record.
+ block_postfixed = 0
+ while f'{block_name_recursive}_{block_postfixed}' in self.block_info:
+ block_postfixed += 1
+ # 3. Get the caller frame information.
+ caller_frame = inspect.stack()[1]
+ block_position = summary_frame_info(caller_frame)
+ # 4. Initialize the block information.
+ self.current_block_uid_stack.append(f'{block_name_recursive}_{block_postfixed}')
+ self.current_block_aid_stack.append(block_name)
+ self.block_info[self.current_block_uid_stack[-1]] = {
+ 'records' : [],
+ 'position' : block_position,
+ 'desc' : block_desc,
+ }
+
+ return self
+
+
+ def __enter__(self):
+ caller_frame = inspect.stack()[1]
+ record = self._tick_record(caller_frame, 'Start of the block.')
+ self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
+ return self
+
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ caller_frame = inspect.stack()[1]
+ record = self._tick_record(caller_frame, 'End of the block.')
+ self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
+
+ # Finish one block.
+ curr_block_uid = self.current_block_uid_stack.pop()
+ curr_block_aid = '/'.join(self.current_block_aid_stack)
+ self.current_block_aid_stack.pop()
+
+ self.finished_blocks.append(curr_block_uid)
+ elapsed = self.block_info[curr_block_uid]['records'][-1]['timestamp'] \
+ - self.block_info[curr_block_uid]['records'][0]['timestamp']
+ self.block_cost[curr_block_aid] = self.block_cost.get(curr_block_aid, 0) + elapsed
+
+ if hasattr(self, 'dump_thread'):
+ self.dump_thread.result()
+ with ThreadPoolExecutor() as executor:
+ self.dump_thread = executor.submit(self.dump_statistics)
+
+
+
+ def tick(self, desc:str=''):
+ '''
+        Record an intermediate timestamp. These records are only for in-block analysis,
+        and will be ignored when analyzing at the global level.
+ '''
+ caller_frame = inspect.stack()[1]
+ record = self._tick_record(caller_frame, desc)
+ self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
+ return
+
+
+ def report(self, level:Union[str, List[str]]='global'):
+ import rich
+
+ caller_frame = inspect.stack()[1]
+ caller_info = summary_frame_info(caller_frame)
+
+ if isinstance(level, str):
+ level = [level]
+
+ for lv in level: # To make sure we can output in order.
+ if lv == 'block':
+ rich.print(f'[bold underline][EA-B][/bold underline] {caller_info} -> blocks level records:')
+ for block_name in self.finished_blocks:
+ msg = '\t' + self._generate_block_msg(block_name).replace('\n\t', '\n\t\t')
+ rich.print(msg)
+ elif lv == 'global':
+ rich.print(f'[bold underline][EA-G][/bold underline] {caller_info} -> global efficiency analysis:')
+ for block_name, cost in self.block_cost.items():
+ rich.print(f'\t{block_name}: {cost:.2f} sec')
+
+
+ def clear(self):
+ self.finished_blocks = []
+ self.block_info = {}
+ self.block_cost = {}
+
+
+ def dump_statistics(self):
+ ''' Dump the logging raw data for post analysis. '''
+ if self.log_folder is None:
+ return
+
+ dump_fn = self.log_folder / 'statistics.pkl'
+ with open(dump_fn, 'wb') as f:
+ import pickle
+ pickle.dump({
+ 'finished_blocks' : self.finished_blocks,
+ 'block_info' : self.block_info,
+ 'block_cost' : self.block_cost,
+                'curr_aid_stack'  : self.current_block_aid_stack, # nonempty when errors happen inside a block
+ }, f)
+
+
+ # TODO: Draw a graph to visualize the time consumption.
+
+
+ def _tick_record(self, caller_frame, desc:Optional[str]=''):
+ # 1. Generate the record.
+ torch.cuda.synchronize()
+ timestamp = time.time()
+ readable_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))
+ position = summary_frame_info(caller_frame)
+ record = {
+ 'time' : readable_time,
+ 'timestamp' : timestamp,
+ 'position' : position,
+ 'desc' : desc,
+ }
+
+ # 2. Log the record.
+ if self.log_folder is not None:
+ block_uid = self.current_block_uid_stack[-1]
+ log_msg = f'[{readable_time}] 🗂️ {block_uid} 📌 {desc} 🌐 {position}'
+ self.log_fh.write(log_msg + '\n')
+
+ return record
+
+
+ def _generate_block_msg(self, block_name):
+ block_info = self.block_info[block_name]
+ block_position = block_info['position']
+ block_desc = block_info['desc']
+ records = block_info['records']
+ msg = f'🗂️ {block_name} 📌 {block_desc} 🌐 {block_position}'
+ for rid, record in enumerate(records):
+ readable_time = record['time']
+ tick_desc = record['desc']
+ tick_position = record['position']
+ if rid > 0:
+ prev_record = records[rid-1]
+ tick_elapsed = record['timestamp'] - prev_record['timestamp']
+ tick_elapsed = f'{tick_elapsed:.2f} s'
+ else:
+ tick_elapsed = 'N/A'
+ msg += f'\n\t[{readable_time}] ⏳ {tick_elapsed} 📌 {tick_desc} 🌐 {tick_position}'
+ return msg
+
+
+ def _die_hook(self):
+ if self.record_birth_block:
+ self.__exit__(None, None, None)
+
+ self.dump_statistics()
+
+ if self.log_folder is not None:
+ self.log_fh.close()
\ No newline at end of file
diff --git a/lib/platform/proj_manager.py b/lib/platform/proj_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..8145b5008126cd0c9e95094e3ee62396db983ce9
--- /dev/null
+++ b/lib/platform/proj_manager.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+from .monitor import TimeMonitor, TimeMonitorDisabled
+
+class ProjManager():
+    root = Path(__file__).parent.parent.parent # this file is at <root>/lib/platform/proj_manager.py
+ assert (root.exists()), 'Can\'t find the path of project root.'
+
+ configs = root / 'configs' # Generally, you are not supposed to access deep config through path.
+ inputs = root / 'data_inputs'
+ outputs = root / 'data_outputs'
+ assert (configs.exists()), 'Make sure you have a \'configs\' folder in the root directory.'
+ assert (inputs.exists()), 'Make sure you have a \'data_inputs\' folder in the root directory.'
+ assert (outputs.exists()), 'Make sure you have a \'data_outputs\' folder in the root directory.'
+
+ # Default values.
+ cfg = None
+ time_monitor = TimeMonitorDisabled()
+
+ @staticmethod
+ def init_with_cfg(cfg):
+ ProjManager.cfg = cfg
+ ProjManager.exp_outputs = Path(cfg.output_dir)
+ if cfg.get('enable_time_monitor', False):
+ ProjManager.time_monitor = TimeMonitor(ProjManager.exp_outputs, record_birth_block=False)
\ No newline at end of file
diff --git a/lib/platform/sliding_batches/__init__.py b/lib/platform/sliding_batches/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9029de44049bfaec9c5048d43ef0f15903335e32
--- /dev/null
+++ b/lib/platform/sliding_batches/__init__.py
@@ -0,0 +1,42 @@
+from .basic import bsb
+from .adaptable.v1 import asb
+
+# The batch manager is used to quickly initialize a batched task.
+# The usage of the batch manager is as follows:
+
+
+def eg_bbm():
+ ''' Basic version of the batch manager. '''
+ task_len = 1e6
+ task_things = [i for i in range(int(task_len))]
+
+ for bw in bsb(total=task_len, batch_size=300, enable_tqdm=True):
+ sid = bw.sid
+ eid = bw.eid
+ round_things = task_things[sid:eid]
+ # Do something with `round_things`.
+
+
+def eg_asb():
+    ''' Adaptable version of the batch manager. '''
+ task_len = 1024
+ task_things = [i for i in range(int(task_len))]
+
+ lb, ub = 1, 300 # lower & upper bound of batch size
+ for bw in asb(total=task_len, bs_scope=(lb, ub), enable_tqdm=True):
+ sid = bw.sid
+ eid = bw.eid
+ round_things = task_things[sid:eid]
+
+ try:
+ # Do something with `round_things`.
+ pass
+ except Exception as e:
+ if not bw.shrink():
+ # In this case, it means task_things[sid:sid+lb] is still too large to handle.
+ # So you need to do something to handle this situation.
+ pass
+ continue #! DO NOT FORGET CONTINUE
+
+ # Do something with `round_things` if no exception is raised.
diff --git a/lib/platform/sliding_batches/adaptable/v1.py b/lib/platform/sliding_batches/adaptable/v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..969a6ed1b1fdb4b2ae215984f996e6331bc3ff19
--- /dev/null
+++ b/lib/platform/sliding_batches/adaptable/v1.py
@@ -0,0 +1,69 @@
+import math
+from tqdm import tqdm
+from contextlib import contextmanager
+from typing import Tuple, Union
+from ..basic import BasicBatchWindow, bsb
+
+class AdaptableBatchWindow(BasicBatchWindow):
+ def __init__(self, sid, eid, min_B):
+ self.start_id = sid
+ self.end_id = eid
+ self.min_B = min_B
+ self.shrinking = False
+
+ def shrink(self):
+ if self.size <= self.min_B:
+ return False
+ else:
+ self.shrinking = True
+ return True
+
+
+class asb(bsb):
+ def __init__(
+ self,
+ total : int,
+ bs_scope : Union[Tuple[int, int], int],
+ enable_tqdm : bool = False,
+ ):
+ ''' Simple binary strategy. '''
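+        # Strategy sketch: start at the upper bound; on `shrink()` halve the batch size (rounding up,
+        # clipped to the lower bound); once the last failure is far enough behind, double it back up.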
+ # Static hyperparameters.
+ self.total = int(total)
+ if isinstance(bs_scope, int):
+ self.min_B = 1
+ self.max_B = bs_scope
+ else:
+ self.min_B, self.max_B = bs_scope # lower & upper bound of batch size
+ # Dynamic state.
+ self.B = self.max_B # current batch size
+ self.tqdm = tqdm(total=self.total) if enable_tqdm else None
+ self.cur_window = AdaptableBatchWindow(sid=-1, eid=0, min_B=self.min_B) # starting window
+ self.last_shrink_id = None
+
+ def __next__(self):
+ if self.cur_window.shrinking:
+ sid = self.cur_window.sid
+ self.shrink_B(sid)
+ else:
+ sid = self.cur_window.eid
+ self.recover_B(sid)
+
+ if sid >= self.total:
+ if self.tqdm: self.tqdm.close()
+ raise StopIteration
+
+ eid = min(sid + self.B, self.total)
+ self.cur_window = AdaptableBatchWindow(sid, eid, min_B=self.min_B)
+ if self.tqdm: self.tqdm.update(eid - sid)
+ return self.cur_window
+
+ def shrink_B(self, cur_id:int):
+ self.last_shrink_id = cur_id
+ self.cur_window.shrinking = False
+ self.B = max(math.ceil(self.B/2), self.min_B)
+
+ def recover_B(self, cur_id:int):
+        if self.last_shrink_id is not None and self.B < self.max_B:
+ newer_B = min(self.B * 2, self.max_B)
+ if self.last_shrink_id < cur_id - newer_B:
+ self.B = newer_B
\ No newline at end of file
diff --git a/lib/platform/sliding_batches/basic.py b/lib/platform/sliding_batches/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b243720ea3d3bd2a4de9a0de4950ae34f577711
--- /dev/null
+++ b/lib/platform/sliding_batches/basic.py
@@ -0,0 +1,47 @@
+from tqdm import tqdm
+
+class BasicBatchWindow():
+ def __init__(self, sid, eid):
+ self.start_id = sid
+ self.end_id = eid
+
+ @property
+ def sid(self):
+ return self.start_id
+
+ @property
+ def eid(self):
+ return self.end_id
+
+ @property
+ def size(self):
+ return self.eid - self.sid
+
+
+class bsb():
+ def __init__(
+ self,
+ total : int,
+ batch_size : int,
+ enable_tqdm : bool = False,
+ ):
+ # Static hyperparameters.
+ self.total = int(total)
+ self.B = batch_size
+ # Dynamic state.
+ self.tqdm = tqdm(total=self.total) if enable_tqdm else None
+ self.cur_window = BasicBatchWindow(-1, 0) # starting window
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.cur_window.eid >= self.total:
+ if self.tqdm: self.tqdm.close()
+ raise StopIteration
+ if self.tqdm: self.tqdm.update(self.cur_window.eid - self.cur_window.sid)
+
+ sid = self.cur_window.eid
+ eid = min(sid + self.B, self.total)
+ self.cur_window = BasicBatchWindow(sid, eid)
+ return self.cur_window
\ No newline at end of file
diff --git a/lib/utils/bbox.py b/lib/utils/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca208d77240fcffda3ada04769d001ef0d297ef9
--- /dev/null
+++ b/lib/utils/bbox.py
@@ -0,0 +1,310 @@
+from lib.kits.basic import *
+
+from .data import to_tensor
+
+
+def lurb_to_cwh(
+ lurb : Union[list, np.ndarray, torch.Tensor],
+):
+ '''
+ Convert the left-upper-right-bottom format to the center-width-height format.
+
+ ### Args
+ - lurb: Union[list, np.ndarray, torch.Tensor], (..., 4)
+ - The left-upper-right-bottom format bounding box.
+
+ ### Returns
+ - Union[list, np.ndarray, torch.Tensor], (..., 4)
+ - The center-width-height format bounding box.
+ '''
+ lurb, recover_type_back = to_tensor(lurb, device=None, temporary=True)
+ assert lurb.shape[-1] == 4, f"Invalid shape: {lurb.shape}, should be (..., 4)"
+
+ c = (lurb[..., :2] + lurb[..., 2:]) / 2 # (..., 2)
+ wh = lurb[..., 2:] - lurb[..., :2] # (..., 2)
+
+ cwh = torch.cat([c, wh], dim=-1) # (..., 4)
+ return recover_type_back(cwh)
+
+
+def cwh_to_lurb(
+ cwh : Union[list, np.ndarray, torch.Tensor],
+):
+ '''
+ Convert the center-width-height format to the left-upper-right-bottom format.
+
+ ### Args
+ - cwh: Union[list, np.ndarray, torch.Tensor], (..., 4)
+ - The center-width-height format bounding box.
+
+ ### Returns
+ - Union[list, np.ndarray, torch.Tensor], (..., 4)
+ - The left-upper-right-bottom format bounding box.
+ '''
+ cwh, recover_type_back = to_tensor(cwh, device=None, temporary=True)
+ assert cwh.shape[-1] == 4, f"Invalid shape: {cwh.shape}, should be (..., 4)"
+
+ l = cwh[..., :2] - cwh[..., 2:] / 2 # (..., 2)
+ r = cwh[..., :2] + cwh[..., 2:] / 2 # (..., 2)
+
+ lurb = torch.cat([l, r], dim=-1) # (..., 4)
+ return recover_type_back(lurb)
+
+
+def cwh_to_cs(
+ cwh : Union[list, np.ndarray, torch.Tensor],
+ reduce : Optional[str] = None,
+):
+ '''
+ Convert the center-width-height format to the center-scale format.
+ *Only works when width and height are the same.*
+
+ ### Args
+ - cwh: Union[list, np.ndarray, torch.Tensor], (..., 4)
+ - The center-width-height format bounding box.
+ - reduce: Optional[str], default None, valid values: None, 'max'
+ - Determine how to reduce the width and height to a single scale.
+
+ ### Returns
+ - Union[list, np.ndarray, torch.Tensor], (..., 3)
+ - The center-scale format bounding box.
+ '''
+ cwh, recover_type_back = to_tensor(cwh, device=None, temporary=True)
+ assert cwh.shape[-1] == 4, f"Invalid shape: {cwh.shape}, should be (..., 4)"
+
+ if reduce is None:
+ if (cwh[..., 2] != cwh[..., 3]).any():
+ get_logger().warning(f"Width and height are supposed to be the same, but they're not. The larger one will be used.")
+
+ c = cwh[..., :2] # (..., 2)
+ s = cwh[..., 2:].max(dim=-1)[0] # (...,)
+
+ cs = torch.cat([c, s[..., None]], dim=-1) # (..., 3)
+ return recover_type_back(cs)
+
+
+def cs_to_cwh(
+ cs : Union[list, np.ndarray, torch.Tensor],
+):
+ '''
+ Convert the center-scale format to the center-width-height format.
+
+ ### Args
+ - cs: Union[list, np.ndarray, torch.Tensor], (..., 3)
+ - The center-scale format bounding box.
+
+ ### Returns
+ - Union[list, np.ndarray, torch.Tensor], (..., 4)
+ - The center-width-height format bounding box.
+ '''
+ cs, recover_type_back = to_tensor(cs, device=None, temporary=True)
+ assert cs.shape[-1] == 3, f"Invalid shape: {cs.shape}, should be (..., 3)"
+
+ c = cs[..., :2] # (..., 2)
+ s = cs[..., 2] # (...,)
+
+ cwh = torch.cat([c, s[..., None], s[..., None]], dim=-1) # (..., 4)
+ return recover_type_back(cwh)
+
+
+def lurb_to_cs(
+ lurb : Union[list, np.ndarray, torch.Tensor],
+):
+ '''
+ Convert the left-upper-right-bottom format to the center-scale format.
+ *Only works when width and height are the same.*
+
+ ### Args
+ - lurb: Union[list, np.ndarray, torch.Tensor], (..., 4)
+ - The left-upper-right-bottom format bounding box.
+
+ ### Returns
+ - Union[list, np.ndarray, torch.Tensor], (..., 3)
+ - The center-scale format bounding box.
+ '''
+ return cwh_to_cs(lurb_to_cwh(lurb), reduce='max')
+
+
+def cs_to_lurb(
+ cs : Union[list, np.ndarray, torch.Tensor],
+):
+ '''
+ Convert the center-scale format to the left-upper-right-bottom format.
+
+ ### Args
+ - cs: Union[list, np.ndarray, torch.Tensor], (..., 3)
+ - The center-scale format bounding box.
+
+ ### Returns
+ - Union[list, np.ndarray, torch.Tensor], (..., 4)
+ - The left-upper-right-bottom format bounding box.
+ '''
+ return cwh_to_lurb(cs_to_cwh(cs))
+
+
+def lurb_to_luwh(
+ lurb : Union[list, np.ndarray, torch.Tensor],
+):
+ '''
+ Convert the left-upper-right-bottom format to the left-upper-width-height format.
+
+ ### Args
+ - lurb: Union[list, np.ndarray, torch.Tensor]
+ - The left-upper-right-bottom format bounding box.
+
+ ### Returns
+ - Union[list, np.ndarray, torch.Tensor]
+ - The left-upper-width-height format bounding box.
+ '''
+ lurb, recover_type_back = to_tensor(lurb, device=None, temporary=True)
+ assert lurb.shape[-1] == 4, f"Invalid shape: {lurb.shape}, should be (..., 4)"
+
+ lu = lurb[..., :2] # (..., 2)
+ wh = lurb[..., 2:] - lurb[..., :2] # (..., 2)
+
+ luwh = torch.cat([lu, wh], dim=-1) # (..., 4)
+ return recover_type_back(luwh)
+
+
+def luwh_to_lurb(
+ luwh : Union[list, np.ndarray, torch.Tensor],
+):
+ '''
+ Convert the left-upper-width-height format to the left-upper-right-bottom format.
+
+ ### Args
+ - luwh: Union[list, np.ndarray, torch.Tensor]
+ - The left-upper-width-height format bounding box.
+
+ ### Returns
+ - Union[list, np.ndarray, torch.Tensor]
+ - The left-upper-right-bottom format bounding box.
+ '''
+ luwh, recover_type_back = to_tensor(luwh, device=None, temporary=True)
+ assert luwh.shape[-1] == 4, f"Invalid shape: {luwh.shape}, should be (..., 4)"
+
+ l = luwh[..., :2] # (..., 2)
+ r = luwh[..., :2] + luwh[..., 2:] # (..., 2)
+
+ lurb = torch.cat([l, r], dim=-1) # (..., 4)
+ return recover_type_back(lurb)
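+
+
+# A minimal round-trip sketch of the converters above (added for illustration; `_demo_*` is not
+# part of the original module and the box values are made up).
+def _demo_bbox_formats():
+    lurb = np.array([10., 20., 110., 120.])  # left, upper, right, bottom
+    cwh  = lurb_to_cwh(lurb)                 # [ 60.,  70., 100., 100.]
+    cs   = cwh_to_cs(cwh)                    # [ 60.,  70., 100.]
+    luwh = lurb_to_luwh(lurb)                # [ 10.,  20., 100., 100.]
+    assert np.allclose(cs_to_lurb(cs), lurb)
+    assert np.allclose(luwh_to_lurb(luwh), lurb)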
+
+
+def crop_with_lurb(data, lurb, padding=0):
+ """
+ Crop the img-like data according to the lurb bounding box.
+
+ ### Args
+ - data: Union[np.ndarray, torch.Tensor], shape (H, W, C)
+ - Data like image.
+ - lurb: Union[list, np.ndarray, torch.Tensor], shape (4,)
+ - Bounding box with [left, upper, right, bottom] coordinates.
+ - padding: int, default 0
+ - Padding value for out-of-bound areas.
+
+ ### Returns
+ - Union[np.ndarray, torch.Tensor], shape (H', W', C)
+ - Cropped image with padding if necessary.
+ """
+ data, recover_type_back = to_tensor(data, device=None, temporary=True)
+
+ # Ensure lurb is in numpy array format for indexing
+ lurb = np.array(lurb).astype(np.int64)
+ l_, u_, r_, b_ = lurb
+
+ # Determine the shape of the data.
+ H_raw, W_raw, C_raw = data.size()
+
+ # Compute the cropped patch size.
+ H_patch = b_ - u_
+ W_patch = r_ - l_
+
+    # Create an output buffer of the crop size, initialized to padding.
+    # (`data` is always a torch.Tensor at this point, since it was converted by `to_tensor` above.)
+    output = torch.full((H_patch, W_patch, C_raw), padding, dtype=data.dtype)
+
+ # Calculate the valid region in the original data
+ valid_l_ = max(0, l_)
+ valid_u_ = max(0, u_)
+ valid_r_ = min(W_raw, r_)
+ valid_b_ = min(H_raw, b_)
+
+ # Calculate the corresponding valid region in the output
+ target_l_ = valid_l_ - l_
+ target_u_ = valid_u_ - u_
+ target_r_ = target_l_ + (valid_r_ - valid_l_)
+ target_b_ = target_u_ + (valid_b_ - valid_u_)
+
+ # Copy the valid region into the output buffer
+ output[target_u_:target_b_, target_l_:target_r_, :] = data[valid_u_:valid_b_, valid_l_:valid_r_, :]
+
+ return recover_type_back(output)
+
+
+def fit_bbox_to_aspect_ratio(
+ bbox : np.ndarray,
+ tgt_ratio : Optional[Tuple[int, int]] = None,
+ bbox_type : str = 'lurb'
+):
+ '''
+    Fit an arbitrary bounding box to a target aspect ratio by enlarging it as little as possible.
+
+ ### Args
+ - bbox: np.ndarray, shape is determined by `bbox_type`, e.g. for 'lurb', shape is (4,)
+ - The bounding box to be modified. The format is determined by `bbox_type`.
+ - tgt_ratio: Optional[Tuple[int, int]], default None
+ - The target aspect ratio to be matched.
+ - bbox_type: str, default 'lurb', valid values: 'lurb', 'cwh'.
+
+ ### Returns
+ - np.ndarray, shape is determined by `bbox_type`, e.g. for 'lurb', shape is (4,)
+ - The modified bounding box.
+ '''
+ bbox = bbox.copy()
+ if bbox_type == 'lurb':
+ bbx_cwh = lurb_to_cwh(bbox)
+ bbx_wh = bbx_cwh[2:]
+ elif bbox_type == 'cwh':
+ bbx_wh = bbox[2:]
+ else:
+ raise ValueError(f"Unsupported bbox type: {bbox_type}")
+
+ new_bbx_wh = expand_wh_to_aspect_ratio(bbx_wh, tgt_ratio)
+
+ if bbox_type == 'lurb':
+ bbx_cwh[2:] = new_bbx_wh
+ new_bbox = cwh_to_lurb(bbx_cwh)
+ elif bbox_type == 'cwh':
+ new_bbox = np.concatenate([bbox[:2], new_bbx_wh])
+ else:
+ raise ValueError(f"Unsupported bbox type: {bbox_type}")
+
+ return new_bbox
+
+
+def expand_wh_to_aspect_ratio(bbx_wh:np.ndarray, tgt_aspect_ratio:Optional[Tuple[int, int]]=None):
+ '''
+ Increase the size of the bounding box to match the target shape.
+ Modified from https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L14-L33
+ '''
+ if tgt_aspect_ratio is None:
+ return bbx_wh
+
+ try:
+ bbx_w , bbx_h = bbx_wh
+ except (ValueError, TypeError):
+ get_logger().warning(f"Invalid bbox_wh content: {bbx_wh}")
+ return bbx_wh
+
+ tgt_w, tgt_h = tgt_aspect_ratio
+ if bbx_h / bbx_w < tgt_h / tgt_w:
+ new_h = max(bbx_w * tgt_h / tgt_w, bbx_h)
+ new_w = bbx_w
+ else:
+ new_h = bbx_h
+ new_w = max(bbx_h * tgt_w / tgt_h, bbx_w)
+ assert new_h >= bbx_h and new_w >= bbx_w
+
+ return to_numpy([new_w, new_h])
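+
+
+# A minimal sketch of the aspect-ratio fitting above (added for illustration; not part of the original module).
+def _demo_fit_bbox():
+    lurb = np.array([0., 0., 100., 50.])                      # A 100x50 box.
+    fitted = fit_bbox_to_aspect_ratio(lurb, tgt_ratio=(1, 1))  # -> [0., -25., 100., 75.]
+    assert np.allclose(lurb_to_cwh(fitted)[2:], [100., 100.])  # Enlarged to 1:1 around the same center.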
\ No newline at end of file
diff --git a/lib/utils/camera.py b/lib/utils/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..bea7824246dd7bc48eb26aafa53c573ad769409c
--- /dev/null
+++ b/lib/utils/camera.py
@@ -0,0 +1,251 @@
+from lib.kits.basic import *
+
+
+def T_to_Rt(
+ T : Union[torch.Tensor, np.ndarray],
+):
+ ''' Get (..., 3, 3) rotation matrix and (..., 3) translation vector from (..., 4, 4) transformation matrix. '''
+ if isinstance(T, np.ndarray):
+ T = torch.from_numpy(T).float()
+ assert T.shape[-2:] == (4, 4), f'T.shape[-2:] = {T.shape[-2:]}'
+
+ R = T[..., :3, :3]
+ t = T[..., :3, 3]
+
+ return R, t
+
+
+def Rt_to_T(
+ R : Union[torch.Tensor, np.ndarray],
+ t : Union[torch.Tensor, np.ndarray],
+):
+ ''' Get (..., 4, 4) transformation matrix from (..., 3, 3) rotation matrix and (..., 3) translation vector. '''
+ if isinstance(R, np.ndarray):
+ R = torch.from_numpy(R).float()
+ if isinstance(t, np.ndarray):
+ t = torch.from_numpy(t).float()
+ assert R.shape[-2:] == (3, 3), f'R should be a (..., 3, 3) matrix, but R.shape = {R.shape}'
+ assert t.shape[-1] == 3, f't should be a (..., 3) vector, but t.shape = {t.shape}'
+ assert R.shape[:-2] == t.shape[:-1], f'R and t should have the same shape prefix but {R.shape[:-2]} != {t.shape[:-1]}'
+
+ T = torch.eye(4, device=R.device, dtype=R.dtype).repeat(R.shape[:-2] + (1, 1)) # (..., 4, 4)
+ T[..., :3, :3] = R
+ T[..., :3, 3] = t
+
+ return T
+
+
+def apply_Ts_on_pts(Ts:torch.Tensor, pts:torch.Tensor):
+ '''
+ Apply transformation matrix `T` on the points `pts`.
+
+ ### Args
+ - Ts: torch.Tensor, (...B, 4, 4)
+ - pts: torch.Tensor, (...B, N, 3)
+ '''
+
+ assert len(pts.shape) >= 3 and pts.shape[-1] == 3, f'Shape of pts should be (...B, N, 3) but {pts.shape}'
+ assert Ts.shape[-2:] == (4, 4), f'Shape of Ts should be (..., 4, 4) but {Ts.shape}'
+ assert Ts.device == pts.device, f'Device of Ts and pts should be the same but {Ts.device} != {pts.device}'
+
+    ret_pts = torch.einsum('...ij,...nj->...ni', Ts[..., :3, :3], pts) + Ts[..., None, :3, 3]  # (...B, N, 3)
+    # Note: no squeeze here; squeezing dim 0 would silently drop a batch dimension of size 1.
+
+ return ret_pts
+
+
+def apply_T_on_pts(T:torch.Tensor, pts:torch.Tensor):
+ '''
+ Apply transformation matrix `T` on the points `pts`.
+
+ ### Args
+ - T: torch.Tensor, (4, 4)
+ - pts: torch.Tensor, (B, N, 3) or (N, 3)
+ '''
+ unbatched = len(pts.shape) == 2
+ if unbatched:
+ pts = pts[None]
+ ret = apply_Ts_on_pts(T[None], pts)
+ return ret.squeeze(0) if unbatched else ret
+
+
+def apply_Ks_on_pts(Ks:torch.Tensor, pts:torch.Tensor):
+ '''
+ Apply intrinsic camera matrix `K` on the points `pts`, i.e. project the 3D points to 2D.
+
+ ### Args
+ - Ks: torch.Tensor, (...B, 3, 3)
+ - pts: torch.Tensor, (...B, N, 3)
+ '''
+
+ assert len(pts.shape) >= 3 and pts.shape[-1] == 3, f'Shape of pts should be (...B, N, 3) but {pts.shape}'
+ assert Ks.shape[-2:] == (3, 3), f'Shape of Ks should be (..., 3, 3) but {Ks.shape}'
+ assert Ks.device == pts.device, f'Device of Ks and pts should be the same but {Ks.device} != {pts.device}'
+
+ pts_proj_homo = torch.einsum('...ij,...vj->...vi', Ks, pts)
+ pts_proj = pts_proj_homo[..., :2] / pts_proj_homo[..., 2:3]
+ return pts_proj
+
+
+def apply_K_on_pts(K:torch.Tensor, pts:torch.Tensor):
+ '''
+ Apply intrinsic camera matrix `K` on the points `pts`, i.e. project the 3D points to 2D.
+
+ ### Args
+ - K: torch.Tensor, (3, 3)
+ - pts: torch.Tensor, (B, N, 3) or (N, 3)
+ '''
+ unbatched = len(pts.shape) == 2
+ if unbatched:
+ pts = pts[None]
+ ret = apply_Ks_on_pts(K[None], pts)
+ return ret.squeeze(0) if unbatched else ret
+
+
+def perspective_projection(
+ points : torch.Tensor,
+ translation : torch.Tensor,
+ focal_length : torch.Tensor,
+ camera_center : Optional[torch.Tensor] = None,
+ rotation : Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ '''
+ Computes the perspective projection of a set of 3D points.
+ https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/utils/geometry.py#L64-L102
+
+ ### Args
+ - points: torch.Tensor, (B, N, 3)
+ - The input 3D points.
+ - translation: torch.Tensor, (B, 3)
+ - The 3D camera translation.
+ - focal_length: torch.Tensor, (B, 2)
+ - The focal length in pixels.
+ - camera_center: torch.Tensor, (B, 2)
+ - The camera center in pixels.
+ - rotation: torch.Tensor, (B, 3, 3)
+ - The camera rotation.
+
+ ### Returns
+ - torch.Tensor, (B, N, 2)
+ - The projection of the input points.
+ '''
+ B = points.shape[0]
+ if rotation is None:
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(B, -1, -1)
+ if camera_center is None:
+ camera_center = torch.zeros(B, 2, device=points.device, dtype=points.dtype)
+ # Populate intrinsic camera matrix K.
+ K = torch.zeros([B, 3, 3], device=points.device, dtype=points.dtype)
+ K[:, 0, 0] = focal_length[:, 0]
+ K[:, 1, 1] = focal_length[:, 1]
+ K[:, 2, 2] = 1.
+ K[:, :-1, -1] = camera_center
+
+ # Transform points
+ points = torch.einsum('bij, bkj -> bki', rotation, points)
+ points = points + translation.unsqueeze(1)
+
+ # Apply perspective distortion
+ projected_points = points / points[:, :, -1].unsqueeze(-1)
+
+ # Apply camera intrinsics
+ projected_points = torch.einsum('bij, bkj -> bki', K, projected_points)
+
+ return projected_points[:, :, :-1]
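+
+
+# A minimal numeric check of the projection above (added for illustration; not part of the original
+# module). The camera parameters below are made up: 500px focal length, principal point at (128, 128).
+def _demo_perspective_projection():
+    points        = torch.tensor([[[0.1, 0.2, 0.0]]])  # (B=1, N=1, 3)
+    translation   = torch.tensor([[0.0, 0.0, 2.0]])    # (B=1, 3), i.e. the point ends up 2m in front.
+    focal_length  = torch.tensor([[500.0, 500.0]])     # (B=1, 2)
+    camera_center = torch.tensor([[128.0, 128.0]])     # (B=1, 2)
+    kp2d = perspective_projection(points, translation, focal_length, camera_center)
+    # u = fx * X/Z + cx = 500 * 0.1/2 + 128 = 153 ; v = fy * Y/Z + cy = 500 * 0.2/2 + 128 = 178.
+    assert torch.allclose(kp2d, torch.tensor([[[153.0, 178.0]]]))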
+
+
+def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_size=224):
+    '''
+    Find the camera translation that brings the 3D joints S closest to the corresponding 2D joints.
+    Copied from: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/utils/geometry.py#L94-L132
+
+    ### Args
+    - S: shape = (25, 3)
+        - 3D joint locations.
+    - joints_2d: shape = (25, 2)
+        - 2D joint locations.
+    - joints_conf: shape = (25,)
+        - 2D joint confidences.
+
+    ### Returns
+    - shape = (3,)
+        - Camera translation vector.
+    '''
+
+ num_joints = S.shape[0]
+ # focal length
+ f = np.array([focal_length,focal_length])
+ # optical center
+ center = np.array([img_size/2., img_size/2.])
+
+ # transformations
+ Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1)
+ XY = np.reshape(S[:,0:2],-1)
+ O = np.tile(center,num_joints)
+ F = np.tile(f,num_joints)
+ weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1)
+
+ # least squares
+ Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T
+ c = (np.reshape(joints_2d,-1)-O)*Z - F*XY
+
+ # weighted least squares
+ W = np.diagflat(weight2)
+ Q = np.dot(W,Q)
+ c = np.dot(W,c)
+
+ # square matrix
+ A = np.dot(Q.T,Q)
+ b = np.dot(Q.T,c)
+
+ # solution
+ trans = np.linalg.solve(A, b)
+
+ return trans
+
+
+def estimate_camera_trans(
+ S : torch.Tensor,
+ joints_2d : torch.Tensor,
+ focal_length : float = 5000.,
+ img_size : float = 224.,
+ conf_thre : float = 4.,
+):
+ '''
+    Find the camera translation that brings the 3D joints S closest to the corresponding 2D joints joints_2d.
+ Modified from: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/utils/geometry.py#L135-L157
+
+ ### Args
+ - S: torch.Tensor, shape = (B, J, 3)
+ - 3D joint locations.
+    - joints_2d: torch.Tensor, shape = (B, J, 3)
+        - Ground-truth 2D joint locations concatenated with per-joint confidence.
+ - focal_length: float
+ - img_size: float
+ - conf_thre: float
+ - Confidence threshold to judge whether we use gt_kp2d or that from OpenPose.
+
+ ### Returns
+ - torch.Tensor, shape = (B, 3)
+ - Camera translation vectors.
+ '''
+ device = S.device
+ B = len(S)
+
+ S = to_numpy(S)
+ joints_2d = to_numpy(joints_2d)
+ joints_conf = joints_2d[:, :, -1] # (B, J)
+ joints_2d = joints_2d[:, :, :-1] # (B, J, 2)
+ trans = np.zeros((S.shape[0], 3), dtype=np.float32)
+ # Find the translation for each example in the batch
+ for i in range(B):
+ conf_i = joints_conf[i]
+ # When the ground truth joints are not enough, use all the joints.
+ if conf_i[25:].sum() < conf_thre:
+ S_i = S[i]
+ joints_i = joints_2d[i]
+ else:
+ S_i = S[i, 25:]
+ conf_i = joints_conf[i, 25:]
+ joints_i = joints_2d[i, 25:]
+
+        trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
+ return torch.from_numpy(trans).to(device)
\ No newline at end of file
diff --git a/lib/utils/ckpt.py b/lib/utils/ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcf341f7b30a0f5ee9af44ffbd7ca0f2377aad34
--- /dev/null
+++ b/lib/utils/ckpt.py
@@ -0,0 +1,75 @@
+from typing import List, Dict
+
+def replace_state_dict_name_prefix(state_dict:Dict[str, object], old_prefix:str, new_prefix:str):
+ ''' Replace the prefix of the keys in the state_dict. '''
+ for old_name in list(state_dict.keys()):
+ if old_name.startswith(old_prefix):
+ new_name = new_prefix + old_name[len(old_prefix):]
+ state_dict[new_name] = state_dict.pop(old_name)
+
+ return state_dict
+
+
+def match_prefix_and_remove_state_dict(state_dict:Dict[str, object], prefix:str):
+ ''' Remove the keys in the state_dict that start with the prefix. '''
+ for name in list(state_dict.keys()):
+ if name.startswith(prefix):
+ state_dict.pop(name)
+ return state_dict
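+
+
+# A minimal usage sketch of the two helpers above (added for illustration; the key names are made up).
+def _demo_prefix_utils():
+    sd = {'backbone.conv.weight': 1, 'head.fc.weight': 2}
+    sd = replace_state_dict_name_prefix(sd, 'backbone.', 'encoder.')
+    sd = match_prefix_and_remove_state_dict(sd, 'head.')
+    assert sd == {'encoder.conv.weight': 1}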
+
+
+class StateDictTree:
+ def __init__(self, keys:List[str]):
+ self.tree = {}
+ for key in keys:
+ parts = key.split('.')
+ self._recursively_add_leaf(self.tree, parts, key)
+
+
+ def rich_print(self, depth:int=-1):
+ from rich.tree import Tree
+ from rich import print
+ rich_tree = Tree('.')
+ self._recursively_build_rich_tree(rich_tree, self.tree, 0, depth)
+ print(rich_tree)
+
+    def update_node_name(self, old_name:str, new_name:str):
+        ''' Input a full node name; the whole node (leaf or sub-tree) will be moved to the new name. '''
+        old_parts = old_name.split('.')
+        # 1. Detach the old node.
+        try:
+            parent = None
+            node = self.tree
+            for part in old_parts:
+                parent = node
+                node = node[part]
+            moved = parent.pop(old_parts[-1])
+        except KeyError:
+            raise KeyError(f'Key {old_name} not found.')
+        # 2. Re-attach it under the new name, creating intermediate levels if needed.
+        new_parts = new_name.split('.')
+        node = self.tree
+        for part in new_parts[:-1]:
+            node = node.setdefault(part, {})
+        assert new_parts[-1] not in node, f'Key {new_name} already exists.'
+        # Leaves store their own full key, so refresh it; sub-trees are moved as-is.
+        node[new_parts[-1]] = new_name if isinstance(moved, str) else moved
+
+
+ def _recursively_add_leaf(self, node, parts, full_key):
+ cur_part, rest_parts = parts[0], parts[1:]
+ if len(rest_parts) == 0:
+ assert cur_part not in node, f'Key {full_key} already exists.'
+ node[cur_part] = full_key
+ else:
+ if cur_part not in node:
+ node[cur_part] = {}
+ self._recursively_add_leaf(node[cur_part], rest_parts, full_key)
+
+
+ def _recursively_build_rich_tree(self, rich_node, dict_node, depth, max_depth:int=-1):
+ if max_depth > 0 and depth >= max_depth:
+ rich_node.add(f'... {len(dict_node)} more')
+ return
+
+ keys = sorted(dict_node.keys())
+ for key in keys:
+ next_dict_node = dict_node[key]
+ next_rich_node = rich_node.add(key)
+ if isinstance(next_dict_node, Dict):
+ self._recursively_build_rich_tree(next_rich_node, next_dict_node, depth+1, max_depth)
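+
+
+# A minimal usage sketch (added for illustration; the key names are made up).
+def _demo_state_dict_tree():
+    tree = StateDictTree(['backbone.conv.weight', 'backbone.conv.bias', 'head.fc.weight'])
+    tree.update_node_name('backbone.conv.bias', 'backbone.conv.beta')  # Rename one key in place.
+    tree.rich_print(depth=2)  # Only print the first two levels of the key hierarchy.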
diff --git a/lib/utils/data/__init__.py b/lib/utils/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dc1ca57b54f8b17c26b58de681b5eefea345cc3
--- /dev/null
+++ b/lib/utils/data/__init__.py
@@ -0,0 +1,7 @@
+# Utils here are used to manipulate basic data structures, such as dict, list, tensor, etc.
+# It has nothing to do with Machine Learning concepts (like dataset, dataloader, etc.).
+
+from .types import *
+from .dict import *
+from .mem import *
+from .io import *
\ No newline at end of file
diff --git a/lib/utils/data/dict.py b/lib/utils/data/dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..95fc880ce9dc9464a26d2e7a67fe9a650b4d749b
--- /dev/null
+++ b/lib/utils/data/dict.py
@@ -0,0 +1,79 @@
+import torch
+import numpy as np
+from typing import Dict, List
+
+
+def disassemble_dict(d, keep_dim=False):
+ '''
+ Unpack a dictionary into a list of dictionaries. The values should be in the same length.
+ If not keep dim: {k: [...] * N} -> [{k: [...]}] * N.
+ If keep dim: {k: [...] * N} -> [{k: [[...]]}] * N.
+ '''
+ Ls = [len(v) for v in d.values()]
+ assert len(set(Ls)) == 1, 'The lengths of the values should be the same!'
+
+ N = Ls[0]
+ if keep_dim:
+ return [{k: v[[i]] for k, v in d.items()} for i in range(N)]
+ else:
+ return [{k: v[i] for k, v in d.items()} for i in range(N)]
+
+
+def assemble_dict(d, expand_dim=False, keys=None):
+ '''
+ Pack a list of dictionaries into one dictionary.
+ If expand dim, perform stack, else, perform concat.
+ '''
+ keys = list(d[0].keys()) if keys is None else keys
+ if isinstance(d[0][keys[0]], np.ndarray):
+ if expand_dim:
+ return {k: np.stack([v[k] for v in d], axis=0) for k in keys}
+ else:
+ return {k: np.concatenate([v[k] for v in d], axis=0) for k in keys}
+ elif isinstance(d[0][keys[0]], torch.Tensor):
+ if expand_dim:
+ return {k: torch.stack([v[k] for v in d], dim=0) for k in keys}
+ else:
+ return {k: torch.cat([v[k] for v in d], dim=0) for k in keys}
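+
+
+# A minimal round-trip sketch of the two helpers above (added for illustration; shapes are made up).
+def _demo_assemble_roundtrip():
+    batch = {'kp2d': np.zeros((4, 17, 2)), 'conf': np.ones((4, 17))}
+    samples = disassemble_dict(batch)                  # 4 dicts holding (17, 2) / (17,) arrays.
+    rebuilt = assemble_dict(samples, expand_dim=True)  # Stacked back to (4, 17, 2) / (4, 17).
+    assert rebuilt['kp2d'].shape == (4, 17, 2)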
+
+
+def filter_dict(d:Dict, keys:List, full:bool=False, strict:bool=False):
+ '''
+ Use path-like syntax to filter the embedded dictionary.
+ The `'*'` string is regarded as a wildcard, and will return the matched keys.
+ For control flags:
+ - If `full`, return the full path, otherwise, only return the matched values.
+ - If `strict`, raise error if the key is not found, otherwise, simply ignore.
+
+ Eg.
+ - `x = {'fruit': {'yellow': 'banana', 'red': 'apple'}, 'recycle': {'yellow': 'trash', 'blue': 'recyclable'}}`
+ - `filter_dict(x, ['*', 'yellow'])` -> `{'fruit': 'banana', 'recycle': 'trash'}`
+ - `filter_dict(x, ['*', 'yellow'], full=True)` -> `{'fruit': {'yellow': 'banana'}, 'recycle': {'yellow': 'trash'}}`
+ - `filter_dict(x, ['*', 'blue'])` -> `{'recycle': 'recyclable'}`
+ - `filter_dict(x, ['*', 'blue'], strict=True)` -> `KeyError: 'blue'`
+ '''
+
+ ret = {}
+ if keys:
+ cur_key, rest_keys = keys[0], keys[1:]
+ if cur_key == '*':
+ for match in d.keys():
+ try:
+ res = filter_dict(d[match], rest_keys, full=full, strict=strict)
+ if res:
+ ret[match] = res
+ except Exception as e:
+ if strict:
+ raise e
+ else:
+ try:
+ res = filter_dict(d[cur_key], rest_keys, full=full, strict=strict)
+ if res:
+ ret = { cur_key : res } if full else res
+ except Exception as e:
+ if strict:
+ raise e
+ else:
+ ret = d
+
+ return ret
\ No newline at end of file
diff --git a/lib/utils/data/io.py b/lib/utils/data/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..830f2148c538daaf4eeebdc0c75d97547799f1bb
--- /dev/null
+++ b/lib/utils/data/io.py
@@ -0,0 +1,11 @@
+
+import pickle
+from pathlib import Path
+
+
+def load_pickle(fn, mode='rb', encoding=None, pickle_encoding='ASCII'):
+ if isinstance(fn, Path):
+ fn = str(fn)
+ with open(fn, mode=mode, encoding=encoding) as f:
+ data = pickle.load(f, encoding=pickle_encoding)
+ return data
\ No newline at end of file
diff --git a/lib/utils/data/mem.py b/lib/utils/data/mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8d21854790d9f0d330733ca56473877207d1c6c
--- /dev/null
+++ b/lib/utils/data/mem.py
@@ -0,0 +1,14 @@
+import torch
+from typing import List, Dict, Tuple
+
+
+def recursive_detach(x):
+ if isinstance(x, torch.Tensor):
+ return x.detach()
+ elif isinstance(x, Dict):
+ return {k: recursive_detach(v) for k, v in x.items()}
+ elif isinstance(x, (List, Tuple)):
+ return [recursive_detach(v) for v in x]
+ else:
+ return x
+
diff --git a/lib/utils/data/types.py b/lib/utils/data/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bf92fdcead22c14fc83b174c99eb2a7b83db9c9
--- /dev/null
+++ b/lib/utils/data/types.py
@@ -0,0 +1,82 @@
+import torch
+import numpy as np
+
+from typing import Any, List
+from omegaconf import ListConfig
+
+def to_numpy(x, temporary:bool=False) -> Any:
+ if isinstance(x, torch.Tensor):
+ if temporary:
+ recover_type_back = lambda x_: torch.from_numpy(x_).type_as(x).to(x.device)
+ return x.detach().cpu().numpy(), recover_type_back
+ else:
+ return x.detach().cpu().numpy()
+ if isinstance(x, np.ndarray):
+ if temporary:
+ recover_type_back = lambda x_: x_
+ return x.copy(), recover_type_back
+ else:
+ return x
+ if isinstance(x, List):
+ if temporary:
+ recover_type_back = lambda x_: x_.tolist()
+ return np.array(x), recover_type_back
+ else:
+ return np.array(x)
+ raise ValueError(f"Unsupported type: {type(x)}")
+
+def to_tensor(x, device, temporary:bool=False) -> Any:
+ '''
+    Unify the conversion of supported types to torch.Tensor.
+    If `device` is None, tensors keep their current device and other inputs default to CPU.
+ '''
+ if isinstance(x, torch.Tensor):
+ device = x.device if device is None else device
+ if temporary:
+ recover_type_back = lambda x_: x_.to(x.device) # recover the device
+ return x.to(device), recover_type_back
+ else:
+ return x.to(device)
+
+ device = 'cpu' if device is None else device
+ if isinstance(x, np.ndarray):
+ if temporary:
+ recover_type_back = lambda x_: x_.detach().cpu().numpy()
+ return torch.from_numpy(x).to(device), recover_type_back
+ else:
+ return torch.from_numpy(x).to(device)
+ if isinstance(x, List):
+ if temporary:
+ recover_type_back = lambda x_: x_.tolist()
+ return torch.from_numpy(np.array(x)).to(device), recover_type_back
+ else:
+ return torch.from_numpy(np.array(x)).to(device)
+ raise ValueError(f"Unsupported type: {type(x)}")
+
+
+def to_list(x, temporary:bool=False) -> Any:
+ if isinstance(x, List):
+ if temporary:
+ recover_type_back = lambda x_: x_
+ return x.copy(), recover_type_back
+ else:
+ return x
+ if isinstance(x, torch.Tensor):
+ if temporary:
+ recover_type_back = lambda x_: torch.tensor(x_, device=x.device, dtype=x.dtype)
+ return x.detach().cpu().numpy().tolist(), recover_type_back
+ else:
+ return x.detach().cpu().numpy().tolist()
+ if isinstance(x, np.ndarray):
+ if temporary:
+ recover_type_back = lambda x_: np.array(x_)
+ return x.tolist(), recover_type_back
+ else:
+ return x.tolist()
+ if isinstance(x, ListConfig):
+ if temporary:
+ recover_type_back = lambda x_: ListConfig(x_)
+ return list(x), recover_type_back
+ else:
+ return list(x)
+ raise ValueError(f"Unsupported type: {type(x)}")
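+
+
+# A minimal sketch of the `temporary=True` contract (added for illustration; not part of the original module).
+def _demo_temporary_conversion():
+    kp2d = np.zeros((17, 2), dtype=np.float32)
+    kp2d_pt, recover_type_back = to_tensor(kp2d, device=None, temporary=True)
+    kp2d_pt = kp2d_pt + 1.0            # Do some tensor-side work...
+    kp2d = recover_type_back(kp2d_pt)  # ...then hand the result back as the original type.
+    assert isinstance(kp2d, np.ndarray)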
\ No newline at end of file
diff --git a/lib/utils/device.py b/lib/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a0a22f8812dd5fcd4cdba7f79b31e03a8ec417f
--- /dev/null
+++ b/lib/utils/device.py
@@ -0,0 +1,24 @@
+import torch
+from typing import Any, Dict, List
+
+
+def recursive_to(x: Any, target: torch.device):
+ '''
+ Recursively transfer data to the target device.
+ Modified from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/utils/__init__.py#L9-L25
+
+ ### Args
+ - x: Any
+ - target: torch.device
+
+ ### Returns
+ - Data transferred to the target device.
+ '''
+ if isinstance(x, Dict):
+ return {k: recursive_to(v, target) for k, v in x.items()}
+ elif isinstance(x, torch.Tensor):
+ return x.to(target)
+ elif isinstance(x, List):
+ return [recursive_to(i, target) for i in x]
+ else:
+ return x
diff --git a/lib/utils/geometry/rotation.py b/lib/utils/geometry/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..de8b95831c6d1c8d556f400ffdb7001b35190245
--- /dev/null
+++ b/lib/utils/geometry/rotation.py
@@ -0,0 +1,517 @@
+# Edited by Yan Xia
+# Modified from https://github.com/facebookresearch/pytorch3d/tree/main/pytorch3d/transforms
+#
+# # Copyright (c) Meta Platforms, Inc. and affiliates.
+# # All rights reserved.
+# #
+# # This source code is licensed under the BSD-style license found in the
+# # LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+
+'''
+The transformation matrices returned from the functions in this file assume
+the points on which the transformation will be applied are column vectors.
+i.e. the R matrix is structured as
+
+ R = [
+ [Rxx, Rxy, Rxz],
+ [Ryx, Ryy, Ryz],
+ [Rzx, Rzy, Rzz],
+ ] # (3, 3)
+
+This matrix can be applied to column vectors by post multiplication
+by the points e.g.
+
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
+ transformed_points = R * points
+
+To apply the same matrix to points which are row vectors, the R matrix
+can be transposed and pre multiplied by the points:
+
+e.g.
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
+ transformed_points = points * R.transpose(1, 0)
+'''
+
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ '''
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ '''
+ r, i, j, k = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ '''
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ '''
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+
+def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ '''
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ '''
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
+ return standardize_quaternion(out)
+
+
+def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
+ '''
+ Return the rotation matrices for one of the rotations about an axis
+ of which Euler angles describe, for each value of the angle given.
+
+ Args:
+        axis: Axis label "X", "Y" or "Z".
+ angle: any shape tensor of Euler angles in radians
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ '''
+
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ elif axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+ else:
+ raise ValueError("letter must be either X, Y or Z.")
+
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
+ '''
+ Convert rotations given as Euler angles in radians to rotation matrices.
+
+ Args:
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
+ convention: Convention string of three uppercase letters from
+ {"X", "Y", and "Z"}.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ '''
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+ raise ValueError("Invalid input euler angles.")
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ matrices = [
+ _axis_angle_rotation(c, e)
+ for c, e in zip(convention, torch.unbind(euler_angles, -1))
+ ]
+ # return functools.reduce(torch.matmul, matrices)
+ return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
+
+
+def _angle_from_tan(
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
+) -> torch.Tensor:
+ '''
+ Extract the first or third Euler angle from the two members of
+ the matrix which are positive constant times its sine and cosine.
+
+ Args:
+        axis: Axis label "X", "Y" or "Z" for the angle we are finding.
+        other_axis: Axis label "X", "Y" or "Z" for the middle axis in the
+ convention.
+ data: Rotation matrices as tensor of shape (..., 3, 3).
+ horizontal: Whether we are looking for the angle for the third axis,
+ which means the relevant entries are in the same row of the
+ rotation matrix. If not, they are in the same column.
+ tait_bryan: Whether the first and third axes in the convention differ.
+
+ Returns:
+ Euler Angles in radians for each matrix in data as a tensor
+ of shape (...).
+ '''
+
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+ if horizontal:
+ i2, i1 = i1, i2
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
+ if horizontal == even:
+ return torch.atan2(data[..., i1], data[..., i2])
+ if tait_bryan:
+ return torch.atan2(-data[..., i2], data[..., i1])
+ return torch.atan2(data[..., i2], -data[..., i1])
+
+
+def _index_from_letter(letter: str) -> int:
+ if letter == "X":
+ return 0
+ if letter == "Y":
+ return 1
+ if letter == "Z":
+ return 2
+ raise ValueError("letter must be either X, Y or Z.")
+
+
+def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
+ '''
+ Convert rotations given as rotation matrices to Euler angles in radians.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+ convention: Convention string of three uppercase letters.
+
+ Returns:
+ Euler angles in radians as tensor of shape (..., 3).
+ '''
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+ i0 = _index_from_letter(convention[0])
+ i2 = _index_from_letter(convention[2])
+ tait_bryan = i0 != i2
+ if tait_bryan:
+ central_angle = torch.asin(
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+ )
+ else:
+ central_angle = torch.acos(matrix[..., i0, i0])
+
+ o = (
+ _angle_from_tan(
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
+ ),
+ central_angle,
+ _angle_from_tan(
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
+ ),
+ )
+ return torch.stack(o, -1)
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ '''
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ '''
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+
+def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ '''
+ Multiply two quaternions.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ a: Quaternions as tensor of shape (..., 4), real part first.
+ b: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of a and b, a tensor of quaternions shape (..., 4).
+ '''
+ aw, ax, ay, az = torch.unbind(a, -1)
+ bw, bx, by, bz = torch.unbind(b, -1)
+ ow = aw * bw - ax * bx - ay * by - az * bz
+ ox = aw * bx + ax * bw + ay * bz - az * by
+ oy = aw * by - ax * bz + ay * bw + az * bx
+ oz = aw * bz + ax * by - ay * bx + az * bw
+ return torch.stack((ow, ox, oy, oz), -1)
+
+
+def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ '''
+ Multiply two quaternions representing rotations, returning the quaternion
+ representing their composition, i.e. the versor with nonnegative real part.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ a: Quaternions as tensor of shape (..., 4), real part first.
+ b: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of a and b, a tensor of quaternions of shape (..., 4).
+ '''
+ ab = quaternion_raw_multiply(a, b)
+ return standardize_quaternion(ab)
+
+
+def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
+ '''
+ Given a quaternion representing rotation, get the quaternion representing
+ its inverse.
+
+ Args:
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
+ first, which must be versors (unit quaternions).
+
+ Returns:
+ The inverse, a tensor of quaternions of shape (..., 4).
+ '''
+
+ scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
+ return quaternion * scaling
+
+
+def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
+ '''
+ Apply the rotation given by a quaternion to a 3D point.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
+ point: Tensor of 3D points of shape (..., 3).
+
+ Returns:
+ Tensor of rotated points of shape (..., 3).
+ '''
+ if point.size(-1) != 3:
+ raise ValueError(f"Points are not in 3D, {point.shape}.")
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
+ point_as_quaternion = torch.cat((real_parts, point), -1)
+ out = quaternion_raw_multiply(
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
+ quaternion_invert(quaternion),
+ )
+ return out[..., 1:]
+
+
+def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
+ '''
+ Convert rotations given as axis/angle to rotation matrices.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ '''
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
+ '''
+ Convert rotations given as rotation matrices to axis/angle.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ '''
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
+
+
+def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
+ '''
+ Convert rotations given as axis/angle to quaternions.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ '''
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = angles * 0.5
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+ )
+ return quaternions
+
+
+def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
+ '''
+ Convert rotations given as quaternions to axis/angle.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ '''
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
+ half_angles = torch.atan2(norms, quaternions[..., :1])
+ angles = 2 * half_angles
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ return quaternions[..., 1:] / sin_half_angles_over_angles
+
+
+def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+ '''
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1].
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ '''
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+
+def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+ '''
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+ by dropping the last row. Note that 6D representation is not unique.
+ Args:
+ matrix: batch of rotation matrices of size (*, 3, 3)
+
+ Returns:
+ 6D rotation representation, of size (*, 6)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ '''
+ batch_dim = matrix.size()[:-2]
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
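+
+
+# A minimal round-trip sketch (added for illustration; not part of the original module).
+def _demo_rotation_roundtrip():
+    aa = torch.tensor([[0.0, 1.2, 0.0]])  # 1.2 rad about the Y axis.
+    R  = axis_angle_to_matrix(aa)         # (1, 3, 3)
+    d6 = matrix_to_rotation_6d(R)         # (1, 6)
+    assert torch.allclose(rotation_6d_to_matrix(d6), R, atol=1e-5)
+    assert torch.allclose(matrix_to_axis_angle(R), aa, atol=1e-5)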
diff --git a/lib/utils/geometry/volume.py b/lib/utils/geometry/volume.py
new file mode 100644
index 0000000000000000000000000000000000000000..dafbd03bbca636a8a5c9bdf34039060669fec088
--- /dev/null
+++ b/lib/utils/geometry/volume.py
@@ -0,0 +1,43 @@
+from lib.kits.basic import *
+
+
+def compute_mesh_volume(
+ verts: Union[torch.Tensor, np.ndarray],
+ faces: Union[torch.Tensor, np.ndarray],
+) -> torch.Tensor:
+ '''
+ Computes the volume of a mesh object through triangles.
+ References:
+ 1. https://github.com/muelea/shapy/blob/a5daa70ce619cbd2a218cebbe63ae3a4c0b771fd/mesh-mesh-intersection/body_measurements/body_measurements.py#L201-L215
+ 2. https://stackoverflow.com/questions/1406029/how-to-calculate-the-volume-of-a-3d-mesh-object-the-surface-of-which-is-made-up
+
+ ### Args
+ - verts: `torch.Tensor` or `np.ndarray`, shape = ((...B,) V, C=3)
+ - faces: `torch.Tensor` or `np.ndarray`, shape = (T, K=3) where T = #triangles
+
+ ### Returns
+ - volume: `torch.Tensor`, shape = (...B,) or (,), in m^3.
+ '''
+    faces = to_numpy(faces)
+    if isinstance(verts, np.ndarray):
+        verts = torch.from_numpy(verts).float()  # The torch ops below (`.sum(dim=...)`, `.abs()`) need a tensor.
+
+ # Get triangles' xyz.
+ batch_shape = verts.shape[:-2]
+ V = verts.shape[-2]
+ verts = verts.reshape(-1, V, 3) # (B', V, C=3)
+
+ tris = verts[:, faces] # (B', T, K=3, C=3)
+ tris = tris.reshape(*batch_shape, -1, 3, 3) # (..., T, K=3, C=3)
+
+ x = tris[..., 0] # (..., T, K=3)
+ y = tris[..., 1] # (..., T, K=3)
+ z = tris[..., 2] # (..., T, K=3)
+
+ volume = (
+ -x[..., 2] * y[..., 1] * z[..., 0] +
+ x[..., 1] * y[..., 2] * z[..., 0] +
+ x[..., 2] * y[..., 0] * z[..., 1] -
+ x[..., 0] * y[..., 2] * z[..., 1] -
+ x[..., 1] * y[..., 0] * z[..., 2] +
+ x[..., 0] * y[..., 1] * z[..., 2]
+ ).sum(dim=-1).abs() / 6.0
+ return volume
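+
+
+# A minimal numeric check (added for illustration; not part of the original module):
+# the right tetrahedron spanned by the unit axes has volume 1/6.
+def _demo_tetrahedron_volume():
+    verts = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
+    faces = torch.tensor([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])
+    assert torch.allclose(compute_mesh_volume(verts, faces), torch.tensor(1. / 6.))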
\ No newline at end of file
diff --git a/lib/utils/media/__init__.py b/lib/utils/media/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..40bafce9930369184b2ffe40cb551d9187de7bb7
--- /dev/null
+++ b/lib/utils/media/__init__.py
@@ -0,0 +1,3 @@
+from .io import *
+from .edit import *
+from .draw import *
\ No newline at end of file
diff --git a/lib/utils/media/draw.py b/lib/utils/media/draw.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8568609c648e9f62065bde86e341c629fc96ab7
--- /dev/null
+++ b/lib/utils/media/draw.py
@@ -0,0 +1,245 @@
+from lib.kits.basic import *
+
+import cv2
+import imageio
+
+from lib.utils.data import to_numpy
+from lib.utils.vis import ColorPalette
+
+
+def annotate_img(
+ img : np.ndarray,
+ text : str,
+ pos : Union[str, Tuple[int, int]] = 'bl',
+):
+ '''
+ Annotate the image with the given text.
+
+ ### Args
+ - img: np.ndarray, (H, W, 3)
+ - text: str
+ - pos: str or tuple(int, int), default 'bl'
+ - If str, one of ['tl', 'bl'].
+ - If tuple, the position of the text.
+
+ ### Returns
+ - np.ndarray, (H, W, 3)
+ - The annotated image.
+ '''
+ assert len(img.shape) == 3, 'img must have 3 dimensions.'
+ return annotate_video(frames=img[None], text=text, pos=pos)[0]
+
+
+def annotate_video(
+ frames : np.ndarray,
+ text : str,
+ pos : Union[str, Tuple[int, int]] = 'bl',
+ alpha : float = 0.75,
+):
+ '''
+ Annotate the video frames with the given text.
+
+ ### Args
+ - frames: np.ndarray, (L, H, W, 3)
+ - text: str
+ - pos: str or tuple(int, int), default 'bl'
+ - If str, one of ['tl', 'bl'].
+ - If tuple, the position of the text.
+    - alpha: float, default 0.75
+ - The transparency of the text.
+
+ ### Returns
+ - np.ndarray, (L, H, W, 3)
+ - The annotated video.
+ '''
+ assert len(frames.shape) == 4, 'frames must have 4 dimensions.'
+ frames = frames.copy()
+ L, H, W = frames.shape[:3]
+
+ if isinstance(pos, str):
+ if pos == 'tl':
+ offset = (int(0.1 * W), int(0.1 * H))
+ elif pos == 'bl':
+ offset = (int(0.1 * W), int(0.9 * H))
+ else:
+ raise ValueError(f'Invalid position: {pos}')
+ else:
+ offset = pos
+
+ for i, frame in enumerate(frames):
+ overlay = frame.copy()
+ _put_text(overlay, text, offset)
+ frames[i] = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
+
+ return frames
+
+
+def draw_bbx_on_img(
+ img : np.ndarray,
+ lurb : np.ndarray,
+ color : str = 'red',
+):
+ '''
+ Draw the bounding box on the image.
+
+ ### Args
+ - img: np.ndarray, (H, W, 3)
+ - lurb: np.ndarray, (4,)
+ - The bounding box in the format of left, up, right, bottom.
+ - color: str, default 'red'
+
+ ### Returns
+ - np.ndarray, (H, W, 3)
+ - The image with the bounding box.
+ '''
+ assert len(img.shape) == 3, 'img must have 3 dimensions.'
+
+ img = img.copy()
+ l, u, r, b = lurb.astype(int)
+ color_rgb_int8 = ColorPalette.presets_int8[color]
+ cv2.rectangle(img, (l, u), (r, b), color_rgb_int8, 3)
+
+ return img
+
+
+def draw_bbx_on_video(
+ frames : np.ndarray,
+ lurbs : np.ndarray,
+ color : str = 'red',
+):
+ '''
+ Draw the bounding box on the video frames.
+
+ ### Args
+ - frames: np.ndarray, (L, H, W, 3)
+ - lurbs: np.ndarray, (L, 4,)
+ - The bounding box in the format of left, up, right, bottom.
+ - color: str, default 'red'
+
+ ### Returns
+ - np.ndarray, (L, H, W, 3)
+ - The video with the bounding box.
+ '''
+ assert len(frames.shape) == 4, 'frames must have 4 dimensions.'
+ frames = frames.copy()
+
+ for i, frame in enumerate(frames):
+ frames[i] = draw_bbx_on_img(frame, lurbs[i], color)
+
+ return frames
+
+
+def draw_kp2d_on_img(
+ img : np.ndarray,
+ kp2d : Union[np.ndarray, torch.Tensor],
+ links : list = [],
+ link_colors : list = [],
+ show_conf : bool = False,
+ show_idx : bool = False,
+):
+ '''
+ Draw the 2d keypoints (and connection lines if exists) on the image.
+
+ ### Args
+ - img: np.ndarray, (H, W, 3)
+ - The image.
+ - kp2d: np.ndarray or torch.Tensor, (N, 2) or (N, 3)
+ - The 2d keypoints without/with confidence.
+ - links: list of [int, int] or (int, int), default []
+ - The connections between keypoints. Each element is a tuple of two indices.
+ - If empty, only keypoints will be drawn.
+ - link_colors: list of [int, int, int] or (int, int, int), default []
+ - The colors of the connections.
+ - If empty, the connections will be drawn in white.
+ - show_conf: bool, default False
+ - Whether to show the confidence of keypoints.
+ - show_idx: bool, default False
+ - Whether to show the index of keypoints.
+
+ ### Returns
+ - img: np.ndarray, (H, W, 3)
+ - The image with skeleton.
+ '''
+ img = img.copy()
+ kp2d = to_numpy(kp2d) # (N, 2) or (N, 3)
+ assert len(img.shape) == 3, f'`img`\'s shape should be (H, W, 3) but got {img.shape}'
+ assert len(kp2d.shape) == 2, f'`kp2d`\'s shape should be (N, 2) or (N, 3) but got {kp2d.shape}'
+
+ if kp2d.shape[1] == 2:
+ kp2d = np.concatenate([kp2d, np.ones((kp2d.shape[0], 1))], axis=-1) # (N, 3)
+
+ kp_has_drawn = [False] * kp2d.shape[0]
+ # Draw connections.
+ for lid, link in enumerate(links):
+        # Skip links whose endpoint keypoints have low confidence.
+ if kp2d[link[0], 2] < 0.5 or kp2d[link[1], 2] < 0.5:
+ continue
+
+ pt1 = tuple(kp2d[link[0], :2].astype(int))
+ pt2 = tuple(kp2d[link[1], :2].astype(int))
+ color = (255, 255, 255) if len(link_colors) == 0 else tuple(link_colors[lid])
+ cv2.line(img, pt1, pt2, color, 2)
+ if not kp_has_drawn[link[0]]:
+ cv2.circle(img, pt1, 3, color, -1)
+ if not kp_has_drawn[link[1]]:
+ cv2.circle(img, pt2, 3, color, -1)
+ kp_has_drawn[link[0]] = kp_has_drawn[link[1]] = True
+
+ # Draw keypoints and annotate the confidence.
+ for i, kp in enumerate(kp2d):
+ conf = kp[2]
+ pos = tuple(kp[:2].astype(int))
+
+ if not kp_has_drawn[i]:
+ cv2.circle(img, pos, 4, (255, 255, 255), -1)
+ cv2.circle(img, pos, 2, ( 0, 255, 0), -1)
+ kp_has_drawn[i] = True
+
+ if show_conf:
+ _put_text(img, f'{conf:.2f}', pos)
+ if show_idx:
+ if i >= 40:
+ continue
+ _put_text(img, f'{i}', pos, scale=0.03)
+
+ return img
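+
+
+# A minimal usage sketch (added for illustration; the canvas and keypoints below are made up).
+def _demo_draw_kp2d():
+    canvas = np.zeros((256, 256, 3), dtype=np.uint8)
+    kp2d = np.array([[64., 128., 1.], [192., 128., 1.]])  # Two confident keypoints.
+    return draw_kp2d_on_img(canvas, kp2d, links=[(0, 1)], link_colors=[(255, 0, 0)])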
+
+
+# ====== Internal Utils ======
+
+def _put_text(
+ img : np.ndarray,
+ text : str,
+ pos : Tuple[int, int],
+ scale : float = 0.05,
+ color_inside : Tuple[int, int, int] = ColorPalette.presets_int8['black'],
+ color_stroke : Tuple[int, int, int] = ColorPalette.presets_int8['white'],
+ **kwargs
+):
+ fontFace = cv2.FONT_HERSHEY_SIMPLEX
+ if 'fontFace' in kwargs:
+ fontFace = kwargs['fontFace']
+ kwargs.pop('fontFace')
+
+ H, W = img.shape[:2]
+ # https://stackoverflow.com/a/55772676/22331129
+ font_scale = scale * min(H, W) / 25 * 1.5
+ thickness_inside = max(int(font_scale), 1)
+ thickness_stroke = max(int(font_scale * 6), 6)
+
+ # Deal with the multi-line text.
+
+ ((fw, fh), baseline) = cv2.getTextSize(
+ text = text,
+ fontFace = fontFace,
+ fontScale = font_scale,
+ thickness = thickness_stroke,
+ ) # https://stackoverflow.com/questions/73664883/opencv-python-draw-text-with-fontsize-in-pixels
+
+ lines = text.split('\n')
+ line_height = baseline + fh
+
+ for i, line in enumerate(lines):
+ pos_ = (pos[0], pos[1] + line_height * i)
+ cv2.putText(img, line, pos_, fontFace, font_scale, color_stroke, thickness_stroke)
+ cv2.putText(img, line, pos_, fontFace, font_scale, color_inside, thickness_inside)
\ No newline at end of file
diff --git a/lib/utils/media/edit.py b/lib/utils/media/edit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64f8ca39d04aa7740951813e9aa20fb8f5a4e3f
--- /dev/null
+++ b/lib/utils/media/edit.py
@@ -0,0 +1,288 @@
+import cv2
+import imageio
+import numpy as np
+
+from typing import Union, Tuple, List
+from pathlib import Path
+
+
+def flex_resize_img(
+ img : np.ndarray,
+ tgt_wh : Union[Tuple[int, int], None] = None,
+ ratio : Union[float, None] = None,
+ kp_mod : int = 1,
+):
+ '''
+ Resize the image to the target width and height. Set one of width and height to -1 to keep the aspect ratio.
+ Only one of `tgt_wh` and `ratio` can be set, if both are set, `tgt_wh` will be used.
+
+ ### Args
+ - img: np.ndarray, (H, W, 3)
+ - tgt_wh: Tuple[int, int], default=None
+ - The target width and height, set one of them to -1 to keep the aspect ratio.
+ - ratio: float, default=None
+ - The ratio to resize the frames. It will be used if `tgt_wh` is not set.
+ - kp_mod: int, default 1
+ - Keep the width and height as multiples of `kp_mod`.
+ - For example, if `kp_mod=16`, the width and height will be rounded to the nearest multiple of 16.
+
+ ### Returns
+ - np.ndarray, (H', W', 3)
+        - The resized image.
+ '''
+ assert len(img.shape) == 3, 'img must have 3 dimensions.'
+ return flex_resize_video(img[None], tgt_wh, ratio, kp_mod)[0]
+
+
+def flex_resize_video(
+ frames : np.ndarray,
+ tgt_wh : Union[Tuple[int, int], None] = None,
+ ratio : Union[float, None] = None,
+ kp_mod : int = 1,
+):
+ '''
+ Resize the frames to the target width and height. Set one of width and height to -1 to keep the aspect ratio.
+ Only one of `tgt_wh` and `ratio` can be set, if both are set, `tgt_wh` will be used.
+
+ ### Args
+ - frames: np.ndarray, (L, H, W, 3)
+ - tgt_wh: Tuple[int, int], default=None
+ - The target width and height, set one of them to -1 to keep the aspect ratio.
+ - ratio: float, default=None
+ - The ratio to resize the frames. It will be used if `tgt_wh` is not set.
+ - kp_mod: int, default 1
+ - Keep the width and height as multiples of `kp_mod`.
+ - For example, if `kp_mod=16`, the width and height will be rounded to the nearest multiple of 16.
+
+ ### Returns
+ - np.ndarray, (L, H', W', 3)
+ - The resized frames.
+ '''
+ assert tgt_wh is not None or ratio is not None, 'At least one of tgt_wh and ratio must be set.'
+ if tgt_wh is not None:
+ assert len(tgt_wh) == 2, 'tgt_wh must be a tuple of 2 elements.'
+ assert tgt_wh[0] > 0 or tgt_wh[1] > 0, 'At least one of width and height must be positive.'
+ if ratio is not None:
+ assert ratio > 0, 'ratio must be positive.'
+    assert len(frames.shape) == 4, 'frames must have 4 dimensions.'
+
+ def align_size(val:float):
+ ''' It will round the value to the nearest multiple of `kp_mod`. '''
+ return int(round(val / kp_mod) * kp_mod)
+
+ # Calculate the target width and height.
+ orig_h, orig_w = frames.shape[1], frames.shape[2]
+ tgt_wh = (int(orig_w * ratio), int(orig_h * ratio)) if tgt_wh is None else tgt_wh # Get wh from ratio if not given. # type: ignore
+ tgt_w, tgt_h = tgt_wh
+ tgt_w = align_size(orig_w * tgt_h / orig_h) if tgt_w == -1 else align_size(tgt_w)
+ tgt_h = align_size(orig_h * tgt_w / orig_w) if tgt_h == -1 else align_size(tgt_h)
+ # Resize the frames.
+ resized_frames = np.stack([cv2.resize(frame, (tgt_w, tgt_h)) for frame in frames])
+
+ return resized_frames
+
+
+def splice_img(
+ img_grids : Union[List[np.ndarray], np.ndarray],
+ grid_ids : Union[List[int], np.ndarray],
+):
+ '''
+ Splice the images with the same size, according to the grid index.
+ For example, you have 3 images [i1, i2, i3], and a `grid_ids` matrix:
+ [[ 0, 1], |i1|i2|
+ [ 2, -1], , then the results will be |i3|ib| , where ib means a black place holder.
+ [-1, -1]] |ib|ib|
+
+ ### Args
+ - img_grids: List[np.ndarray] or np.ndarray, (K, H, W, 3)
+ - The source images to splice. It indicates that all the images have the same size.
+ - grid_ids: List[int] or np.ndarray, (Y, X)
+ - The grid index of each image. It should be a 2D matrix with integers as the type of elements.
+        - The value in this matrix indexes the image in `img_grids`, so it ranges from 0 to K-1.
+ - Specially, set the grid index to -1 to use a black place holder.
+
+ ### Returns
+ - np.ndarray, (H*Y, W*X, 3)
+ - The spliced images.
+ '''
+ if isinstance(img_grids, List):
+ img_grids = np.stack(img_grids)
+ if isinstance(grid_ids, List):
+ grid_ids = np.array(grid_ids)
+
+ assert len(img_grids.shape) == 4, 'img_grids must be in shape (K, H, W, 3).'
+ return splice_video(img_grids[:, None], grid_ids)[0]
+
+
+def splice_video(
+ video_grids : Union[List[np.ndarray], np.ndarray],
+ grid_ids : Union[List[int], np.ndarray],
+):
+ '''
+ Splice the videos with the same size, according to the grid index.
+ For example, you have 3 videos [v1, v2, v3], and a `grid_ids` matrix:
+ [[ 0, 1], |v1|v2|
+       [ 2, -1],  , then the results will be |v3|vb| , where vb means a black place holder.
+ [-1, -1]] |vb|vb|
+
+ ### Args
+ - video_grids: List[np.ndarray] or np.ndarray, (K, L, H, W, C)
+ - The source videos to splice. It indicates that all the videos have the same size.
+ - grid_ids: List[int] or np.ndarray, (Y, X)
+ - The grid index of each video. It should be a 2D matrix with integers as the type of elements.
+        - The value in this matrix indexes the video in `video_grids`, so it ranges from 0 to K-1.
+ - Specially, set the grid index to -1 to use a black place holder.
+
+ ### Returns
+ - np.ndarray, (L, H*Y, W*X, C)
+ - The spliced video.
+ '''
+ if isinstance(video_grids, List):
+ video_grids = np.stack(video_grids)
+ if isinstance(grid_ids, List):
+ grid_ids = np.array(grid_ids)
+
+ assert len(video_grids.shape) == 5, 'video_grids must be in shape (K, L, H, W, 3).'
+ assert len(grid_ids.shape) == 2, 'grid_ids must be a 2D matrix.'
+ assert isinstance(grid_ids[0, 0].item(), int), f'grid_ids must be an integer matrix, but got {grid_ids.dtype}.'
+
+ # Splice the videos.
+ K, L, H, W, C = video_grids.shape
+ Y, X = grid_ids.shape
+
+ # Initialize the spliced video.
+ spliced_video = np.zeros((L, H*Y, W*X, C), dtype=np.uint8)
+ for x in range(X):
+ for y in range(Y):
+ grid_id = grid_ids[y, x]
+ if grid_id == -1:
+ continue
+ spliced_video[:, y*H:(y+1)*H, x*W:(x+1)*W, :] = video_grids[grid_id]
+
+ return spliced_video
+
+
+def crop_img(
+ img : np.ndarray,
+ lurb : Union[np.ndarray, List],
+):
+ '''
+ Crop the image with the given bounding box.
+ The data should be represented in uint8.
+ If the bounding box is out of the image, pad the image with zeros.
+
+ ### Args
+ - img: np.ndarray, (H, W, C)
+ - lurb: np.ndarray or list, (4,)
+ - The bounding box in the format of left, up, right, bottom.
+
+ ### Returns
+ - np.ndarray, (H', W', C)
+ - The cropped image.
+ '''
+
+ return crop_video(img[None], lurb)[0]
+
+
+def crop_video(
+ frames : np.ndarray,
+ lurb : Union[np.ndarray, List],
+):
+ '''
+ Crop the video with the given bounding box.
+ The data should be represented in uint8.
+ If the bounding box is out of the video, pad the frames with zeros.
+
+ ### Args
+ - frames: np.ndarray, (L, H, W, C)
+ - lurb: np.ndarray or list, (4,)
+ - The bounding box in the format of left, up, right, bottom.
+
+ ### Returns
+ - np.ndarray, (L, H', W', C)
+ - The cropped video.
+ '''
+    assert len(frames.shape) == 4, 'frames must have 4 dimensions.'
+ if isinstance(lurb, List):
+ lurb = np.array(lurb)
+
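+    # Clamp the box to the frame bounds, then paste the valid region into a zero-padded canvas.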
+ l, u, r, b = lurb.astype(int)
+ L, H, W = frames.shape[:3]
+ l_, u_, r_, b_ = max(0, l), max(0, u), min(W, r), min(H, b)
+    cropped_frames = np.zeros((L, b-u, r-l, frames.shape[-1]), dtype=np.uint8)
+    cropped_frames[:, u_-u:b_-u, l_-l:r_-l] = frames[:, u_:b_, l_:r_]
+
+ return cropped_frames
+
+def pad_img(
+ img : np.ndarray,
+ tgt_wh : Tuple[int, int],
+ pad_val : int = 0,
+ align : str = 'c-c',
+):
+ '''
+ Pad the image to the target width and height.
+
+ ### Args
+ - img: np.ndarray, (H, W, 3)
+ - tgt_wh: Tuple[int, int]
+ - The target width and height. Use -1 to indicate the original scale.
+    - pad_val: int, default 0
+        - The value to pad the image with.
+    - align: str, default 'c-c'
+        - The alignment of the image, in the format of 'h-v', where both 'h' and 'v' take one of
+          'l', 'c', 'r' ('l' = left/top, 'c' = center, 'r' = right/bottom).
+
+ ### Returns
+ - np.ndarray, (H', W', 3)
+ - The padded image.
+ '''
+ assert len(img.shape) == 3, 'img must have 3 dimensions.'
+ return pad_video(img[None], tgt_wh, pad_val, align)[0]
+
+def pad_video(
+ frames : np.ndarray,
+ tgt_wh : Tuple[int, int],
+ pad_val : int = 0,
+ align : str = 'c-c',
+):
+ '''
+ Pad the video to the target width and height.
+
+ ### Args
+ - frames: np.ndarray, (L, H, W, 3)
+ - tgt_wh: Tuple[int, int]
+ - The target width and height. Use -1 to indicate the original scale.
+    - pad_val: int, default 0
+        - The value to pad the frames with.
+    - align: str, default 'c-c'
+        - The alignment pattern, in the same 'h-v' format as `pad_img`.
+
+ ### Returns
+ - np.ndarray, (L, H', W', 3)
+ - The padded frames.
+ '''
+ # Check data validity.
+ assert len(frames.shape) == 4, 'frames must have 4 dimensions.'
+ assert len(tgt_wh) == 2, 'tgt_wh must be a tuple of 2 elements.'
+ H, W = frames.shape[1], frames.shape[2]
+ if tgt_wh[0] == -1: tgt_wh = (W, tgt_wh[1])
+ if tgt_wh[1] == -1: tgt_wh = (tgt_wh[0], H)
+ assert tgt_wh[0] >= frames.shape[2] and tgt_wh[1] >= frames.shape[1], 'The target size must be larger than the original size.'
+ assert pad_val >= 0 and pad_val <= 255, 'The pad value must be in the range of [0, 255].'
+ # Check align pattern.
+ align = align.split('-')
+ assert len(align) == 2, 'align must be in the format of "h-v".'
+ assert align[0] in ['l', 'c', 'r'] and align[1] in ['l', 'c', 'r'], 'align must be in ["l", "c", "r"].'
+
+ tgt_w, tgt_h = tgt_wh
+ pad_pix = [tgt_w - W, tgt_h - H] # indicate how many pixels to be padded
+ pad_lu = [0, 0] # how many pixels to pad on the left and the up side
+ for direction in [0, 1]:
+ if align[direction] == 'c':
+ pad_lu[direction] = pad_pix[direction] // 2
+ elif align[direction] == 'r':
+ pad_lu[direction] = pad_pix[direction]
+ pad_l, pad_r, pad_u, pad_b = pad_lu[0], pad_pix[0] - pad_lu[0], pad_lu[1], pad_pix[1] - pad_lu[1]
+
+ padded_frames = np.pad(frames, ((0, 0), (pad_u, pad_b), (pad_l, pad_r), (0, 0)), 'constant', constant_values=pad_val)
+
+ return padded_frames
\ No newline at end of file
diff --git a/lib/utils/media/io.py b/lib/utils/media/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2cbe9e21c33c96034c4acc37427360209f33e79
--- /dev/null
+++ b/lib/utils/media/io.py
@@ -0,0 +1,107 @@
+import imageio
+import numpy as np
+
+from tqdm import tqdm
+from typing import Union, List
+from pathlib import Path
+from glob import glob
+
+from .edit import flex_resize_img, flex_resize_video
+
+
+def load_img_meta(
+ img_path : Union[str, Path],
+):
+ ''' Read the image meta from the given path without opening image. '''
+ assert Path(img_path).exists(), f'Image not found: {img_path}'
+ H, W = imageio.v3.improps(img_path).shape[:2]
+ meta = {'w': W, 'h': H}
+ return meta
+
+
+def load_img(
+ img_path : Union[str, Path],
+ mode : str = 'RGB',
+):
+ ''' Read the image from the given path. '''
+ assert Path(img_path).exists(), f'Image not found: {img_path}'
+
+ img = imageio.v3.imread(img_path, plugin='pillow', mode=mode)
+
+ meta = {
+ 'w': img.shape[1],
+ 'h': img.shape[0],
+ }
+ return img, meta
+
+
+def save_img(
+ img : np.ndarray,
+ output_path : Union[str, Path],
+ resize_ratio : Union[float, None] = None,
+ **kwargs,
+):
+ ''' Save the image. '''
+ assert img.ndim == 3, f'Invalid image shape: {img.shape}'
+
+ if resize_ratio is not None:
+ img = flex_resize_img(img, ratio=resize_ratio)
+
+ imageio.v3.imwrite(output_path, img, **kwargs)
+
+
+def load_video(
+ video_path : Union[str, Path],
+):
+ ''' Read the video from the given path. '''
+ if isinstance(video_path, str):
+ video_path = Path(video_path)
+
+ assert video_path.exists(), f'Video not found: {video_path}'
+
+ if video_path.is_dir():
+        print(f'Found {video_path} is a directory. It will be regarded as an image folder.')
+ imgs_path = sorted(glob(str(video_path / '*')))
+ frames = []
+ for img_path in tqdm(imgs_path):
+ frames.append(imageio.imread(img_path))
+ fps = 30 # default fps
+ else:
+ print(f'Found {video_path} is a file. It will be regarded as a video file.')
+ reader = imageio.get_reader(video_path, format='FFMPEG')
+ frames = []
+ for frame in tqdm(reader, total=reader.count_frames()):
+ frames.append(frame)
+ fps = reader.get_meta_data()['fps']
+ frames = np.stack(frames, axis=0) # (L, H, W, 3)
+ meta = {
+ 'fps': fps,
+ 'w' : frames.shape[2],
+ 'h' : frames.shape[1],
+ 'L' : frames.shape[0],
+ }
+
+ return frames, meta
+
+
+def save_video(
+ frames : Union[np.ndarray, List[np.ndarray]],
+ output_path : Union[str, Path],
+ fps : float = 30,
+ resize_ratio : Union[float, None] = None,
+ quality : Union[int, None] = None,
+):
+ ''' Save the frames as a video. '''
+ if isinstance(frames, List):
+ frames = np.stack(frames, axis=0)
+ assert frames.ndim == 4, f'Invalid frames shape: {frames.shape}'
+
+ if resize_ratio is not None:
+ frames = flex_resize_video(frames, ratio=resize_ratio)
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+
+ writer = imageio.get_writer(output_path, fps=fps, quality=quality)
+ output_seq_name = str(output_path).split('/')[-1]
+ for frame in tqdm(frames, desc=f'Saving {output_seq_name}'):
+ writer.append_data(frame)
+ writer.close()
\ No newline at end of file
diff --git a/lib/utils/vis/__init__.py b/lib/utils/vis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a59b8fbac08376a6807bbf250a26728662bbd747
--- /dev/null
+++ b/lib/utils/vis/__init__.py
@@ -0,0 +1,8 @@
+from .colors import ColorPalette
+
+from .wis3d import HWis3D as Wis3D
+
+from .py_renderer import (
+ render_mesh_overlay_img,
+ render_mesh_overlay_video,
+)
\ No newline at end of file
diff --git a/lib/utils/vis/colors.py b/lib/utils/vis/colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..26b8d622720005991faf8f614191400d057883f4
--- /dev/null
+++ b/lib/utils/vis/colors.py
@@ -0,0 +1,45 @@
+from lib.kits.basic import *
+
+def float_to_int8(color: List[float]) -> List[int]:
+ return [int(c * 255) for c in color]
+
+def int8_to_float(color: List[int]) -> List[float]:
+ return [c / 255 for c in color]
+
+def int8_to_hex(color: List[int]) -> str:
+ return '#%02x%02x%02x' % tuple(color)
+
+def float_to_hex(color: List[float]) -> str:
+ return int8_to_hex(float_to_int8(color))
+
+def hex_to_int8(color: str) -> List[int]:
+ return [int(color[i+1:i+3], 16) for i in (0, 2, 4)]
+
+def hex_to_float(color: str) -> List[float]:
+ return int8_to_float(hex_to_int8(color))
+
+
+# TODO: incorporate https://github.com/vye16/slahmr/blob/main/slahmr/vis/colors.txt
+
+class ColorPalette:
+
+ # Picked from: https://colorsite.librian.net/
+ presets = {
+ 'black' : '#2b2b2b',
+ 'white' : '#eaedf7',
+ 'pink' : '#e6cde3',
+ 'light_pink' : '#fdeff2',
+ 'blue' : '#89c3eb',
+ 'purple' : '#a6a5c4',
+ 'light_purple' : '#bbc8e6',
+ 'red' : '#d3381c',
+ 'orange' : '#f9c89b',
+ 'light_orange' : '#fddea5',
+ 'brown' : '#b48a76',
+ 'human_yellow' : '#f1bf99',
+ 'green' : '#a8c97f',
+ }
+
+ presets_int8 = {k: hex_to_int8(v) for k, v in presets.items()}
+ presets_float = {k: int8_to_float(v) for k, v in presets_int8.items()}
+ presets_hex = {k: v for k, v in presets.items()}
\ No newline at end of file
diff --git a/lib/utils/vis/p3d_renderer/README.md b/lib/utils/vis/p3d_renderer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e692f02bd6a6cc6ff9a06811e4c85c48f65f9cd
--- /dev/null
+++ b/lib/utils/vis/p3d_renderer/README.md
@@ -0,0 +1,29 @@
+## Pytorch3D Renderer
+
+> The code was modified from GVHMR: https://github.com/zju3dv/GVHMR/tree/main/hmr4d/utils/vis
+
+### Dependency
+
+```shell
+pip install "git+https://github.com/facebookresearch/pytorch3d.git@v0.7.6"
+```
+
+### Example
+
+```python
+import imageio
+import numpy as np
+from tqdm import tqdm
+
+from lib.utils.vis.p3d_renderer.renderer import Renderer
+
+fps = 30
+focal_length = data["cam_int"][0][0, 0]
+width, height = img_hw
+faces = smplh[data["gender"]].bm.faces
+renderer = Renderer(width, height, focal_length, "cuda", faces)
+writer = imageio.get_writer("tmp_debug.mp4", fps=fps, mode="I", format="FFMPEG", macro_block_size=1)
+
+for i in tqdm(range(length)):
+ img = np.zeros((height, width, 3), dtype=np.uint8)
+ img = renderer.render_mesh(smplh_out.vertices[i].cuda(), img)
+ writer.append_data(img)
+writer.close()
+```
\ No newline at end of file
diff --git a/lib/utils/vis/p3d_renderer/__init__.py b/lib/utils/vis/p3d_renderer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be01f37fc2629ed50bc654130037ad86fea57f62
--- /dev/null
+++ b/lib/utils/vis/p3d_renderer/__init__.py
@@ -0,0 +1,125 @@
+from lib.kits.basic import *
+
+import cv2
+import imageio
+
+from tqdm import tqdm
+
+from lib.utils.vis import ColorPalette
+from lib.utils.media import save_img
+
+from .renderer import *
+
+
+def render_mesh_overlay_img(
+ faces : Union[torch.Tensor, np.ndarray],
+ verts : torch.Tensor,
+ K4 : List,
+ img : np.ndarray,
+ output_fn : Optional[Union[str, Path]] = None,
+ device : str = 'cuda',
+ resize : float = 1.0,
+ Rt : Optional[Tuple[torch.Tensor]] = None,
+ mesh_color : Optional[Union[List[float], str]] = 'blue',
+) -> Any:
+ '''
+ Render the mesh overlay on the input video frames.
+
+ ### Args
+ - faces: Union[torch.Tensor, np.ndarray], (V, 3)
+ - verts: torch.Tensor, (V, 3)
+ - K4: List
+ - [fx, fy, cx, cy], the components of intrinsic camera matrix.
+ - img: np.ndarray, (H, W, 3)
+ - output_fn: Union[str, Path] or None
+ - The output file path, if None, return the rendered img.
+ - device: str, default 'cuda'
+ - resize: float, default 1.0
+ - The resize factor of the output video.
+ - Rt: Tuple of Tensor, default None
+ - The extrinsic camera matrix, in the form of (R, t).
+ '''
+ frame = render_mesh_overlay_video(
+ faces = faces,
+ verts = verts[None],
+ K4 = K4,
+ frames = img[None],
+ device = device,
+ resize = resize,
+ Rt = Rt,
+ mesh_color = mesh_color,
+ )[0]
+
+ if output_fn is None:
+ return frame
+ else:
+ save_img(frame, output_fn)
+
+
+def render_mesh_overlay_video(
+ faces : Union[torch.Tensor, np.ndarray],
+ verts : torch.Tensor,
+ K4 : List,
+ frames : np.ndarray,
+ output_fn : Optional[Union[str, Path]] = None,
+ fps : int = 30,
+ device : str = 'cuda',
+ resize : float = 1.0,
+ Rt : Tuple = None,
+ mesh_color : Optional[Union[List[float], str]] = 'blue',
+) -> Any:
+ '''
+ Render the mesh overlay on the input video frames.
+
+ ### Args
+ - faces: Union[torch.Tensor, np.ndarray], (V, 3)
+ - verts: torch.Tensor, (L, V, 3)
+ - K4: List
+ - [fx, fy, cx, cy], the components of intrinsic camera matrix.
+ - frames: np.ndarray, (L, H, W, 3)
+ - output_fn: Union[str, Path] or None
+ - The output file path, if None, return the rendered frames.
+ - fps: int, default 30
+ - device: str, default 'cuda'
+ - resize: float, default 1.0
+ - The resize factor of the output video.
+ - Rt: Tuple, default None
+ - The extrinsic camera matrix, in the form of (R, t).
+ '''
+ if isinstance(faces, torch.Tensor):
+ faces = faces.cpu().numpy()
+ assert len(K4) == 4, 'K4 must be a list of 4 elements.'
+ assert frames.shape[0] == verts.shape[0], 'The length of frames and verts must be the same.'
+ assert frames.shape[-1] == 3, 'The last dimension of frames must be 3.'
+ if isinstance(mesh_color, str):
+ mesh_color = ColorPalette.presets_float[mesh_color]
+
+ # Prepare the data.
+ L = frames.shape[0]
+ focal_length = (K4[0] + K4[1]) / 2 # f = (fx + fy) / 2
+ width, height = frames.shape[-2], frames.shape[-3]
+ cx2, cy2 = int(K4[2] * 2), int(K4[3] * 2)
+ # Prepare the renderer.
+ renderer = Renderer(cx2, cy2, focal_length, device, faces)
+ if Rt is not None:
+ Rt = (to_tensor(Rt[0], device), to_tensor(Rt[1], device))
+ renderer.create_camera(*Rt)
+
+ if output_fn is None:
+ result_frames = []
+ for i in range(L):
+ img = renderer.render_mesh(verts[i].to(device), frames[i], mesh_color)
+ img = cv2.resize(img, (int(width * resize), int(height * resize)))
+ result_frames.append(img)
+ result_frames = np.stack(result_frames, axis=0)
+ return result_frames
+ else:
+ writer = imageio.get_writer(output_fn, fps=fps, mode='I', format='FFMPEG', macro_block_size=1)
+ # Render the video.
+ output_seq_name = str(output_fn).split('/')[-1]
+ for i in tqdm(range(L), desc=f'Rendering [{output_seq_name}]...'):
+            img = renderer.render_mesh(verts[i].to(device), frames[i], mesh_color)
+            img = cv2.resize(img, (int(width * resize), int(height * resize)))
+            writer.append_data(img)
+ writer.close()
\ No newline at end of file
diff --git a/lib/utils/vis/p3d_renderer/renderer.py b/lib/utils/vis/p3d_renderer/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dc58781b17d7bccb11a325f938c01cb05cc885f
--- /dev/null
+++ b/lib/utils/vis/p3d_renderer/renderer.py
@@ -0,0 +1,340 @@
+import cv2
+import torch
+import numpy as np
+
+from pytorch3d.renderer import (
+ PerspectiveCameras,
+ TexturesVertex,
+ PointLights,
+ Materials,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+)
+from pytorch3d.structures import Meshes
+from pytorch3d.structures.meshes import join_meshes_as_scene
+from pytorch3d.renderer.cameras import look_at_rotation
+from pytorch3d.transforms import axis_angle_to_matrix
+
+from .utils import get_colors, checkerboard_geometry
+
+
+colors_str_map = {
+ "gray": [0.8, 0.8, 0.8],
+ "green": [39, 194, 128],
+}
+
+
+def overlay_image_onto_background(image, mask, bbox, background):
+ if isinstance(image, torch.Tensor):
+ image = image.detach().cpu().numpy()
+ if isinstance(mask, torch.Tensor):
+ mask = mask.detach().cpu().numpy()
+
+ out_image = background.copy()
+ bbox = bbox[0].int().cpu().numpy().copy()
+ roi_image = out_image[bbox[1] : bbox[3], bbox[0] : bbox[2]]
+
+ roi_image[mask] = image[mask]
+ out_image[bbox[1] : bbox[3], bbox[0] : bbox[2]] = roi_image
+
+ return out_image
+
+
+def update_intrinsics_from_bbox(K_org, bbox):
+ device, dtype = K_org.device, K_org.dtype
+
+ K = torch.zeros((K_org.shape[0], 4, 4)).to(device=device, dtype=dtype)
+ K[:, :3, :3] = K_org.clone()
+ K[:, 2, 2] = 0
+ K[:, 2, -1] = 1
+ K[:, -1, 2] = 1
+
+ image_sizes = []
+ for idx, bbox in enumerate(bbox):
+ left, upper, right, lower = bbox
+ cx, cy = K[idx, 0, 2], K[idx, 1, 2]
+
+ new_cx = cx - left
+ new_cy = cy - upper
+ new_height = max(lower - upper, 1)
+ new_width = max(right - left, 1)
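+            # Mirror the principal point within the crop; this presumably pairs with the
+            # torch.flip in Renderer.render_mesh, which flips the rendered image back.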
+ new_cx = new_width - new_cx
+ new_cy = new_height - new_cy
+
+ K[idx, 0, 2] = new_cx
+ K[idx, 1, 2] = new_cy
+ image_sizes.append((int(new_height), int(new_width)))
+
+ return K, image_sizes
+
+
+def perspective_projection(x3d, K, R=None, T=None):
+    if R is not None:
+        x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2)
+    if T is not None:
+        x3d = x3d + T.transpose(1, 2)
+
+ x2d = torch.div(x3d, x3d[..., 2:])
+ x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
+ return x2d
+
+
+def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2):
+ left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w)
+ right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w)
+ top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h)
+ bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h)
+
+ cx = (left + right) / 2
+ cy = (top + bottom) / 2
+ width = right - left
+ height = bottom - top
+
+ new_left = torch.clamp(cx - width / 2 * scaleFactor, min=0, max=img_w - 1)
+ new_right = torch.clamp(cx + width / 2 * scaleFactor, min=1, max=img_w)
+ new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h - 1)
+ new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h)
+
+ bbox = torch.stack((new_left.detach(), new_top.detach(), new_right.detach(), new_bottom.detach())).int().float().T
+
+ return bbox
+
+
+class Renderer:
+ def __init__(self, width, height, focal_length=None, device="cuda", faces=None, K=None):
+ self.width = width
+ self.height = height
+ assert (focal_length is not None) ^ (K is not None), "focal_length and K are mutually exclusive"
+
+ self.device = device
+ if faces is not None:
+ if isinstance(faces, np.ndarray):
+ faces = torch.from_numpy((faces).astype("int"))
+ if len(faces.shape) == 2:
+ self.faces = faces.unsqueeze(0).to(self.device)
+ elif len(faces.shape) == 3:
+ self.faces = faces.to(self.device)
+ else:
+ raise ValueError("faces should have shape of (F, 3) or (N, F, 3)")
+
+ self.initialize_camera_params(focal_length, K)
+ self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]])
+ self.create_renderer()
+
+ def create_renderer(self):
+ self.renderer = MeshRenderer(
+ rasterizer = MeshRasterizer(
+ raster_settings = RasterizationSettings(
+ image_size = self.image_sizes[0],
+ blur_radius = 1e-5,
+ bin_size = 0,
+ ),
+ ),
+ shader = SoftPhongShader(
+ device=self.device,
+ lights=self.lights,
+ ),
+ )
+
+ def create_camera(self, R=None, T=None):
+ if R is not None:
+ self.R = R.clone().view(1, 3, 3).to(self.device)
+ if T is not None:
+ self.T = T.clone().view(1, 3).to(self.device)
+
+ return PerspectiveCameras(
+ device=self.device, R=self.R.mT, T=self.T, K=self.K_full, image_size=self.image_sizes, in_ndc=False
+ )
+
+ def initialize_camera_params(self, focal_length, K):
+ # Extrinsics
+ self.R = torch.diag(torch.tensor([1, 1, 1])).float().to(self.device).unsqueeze(0)
+
+ self.T = torch.tensor([0, 0, 0]).unsqueeze(0).float().to(self.device)
+
+ # Intrinsics
+ if K is not None:
+ self.K = K.float().reshape(1, 3, 3).to(self.device)
+ else:
+ assert focal_length is not None, "focal_length or K should be provided"
+ self.K = (
+ torch.tensor([[focal_length, 0, self.width / 2], [0, focal_length, self.height / 2], [0, 0, 1]])
+ .float()
+ .reshape(1, 3, 3)
+ .to(self.device)
+ )
+ self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float()
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)
+ self.cameras = self.create_camera()
+
+ def set_intrinsic(self, K):
+ self.K = K.reshape(1, 3, 3)
+
+ def set_ground(self, length, center_x, center_z):
+ device = self.device
+ length, center_x, center_z = map(float, (length, center_x, center_z))
+ v, f, vc, fc = map(torch.from_numpy, checkerboard_geometry(length=length * 2, c1=center_x, c2=center_z, up="y"))
+ v, f, vc = v.to(device), f.to(device), vc.to(device)
+ self.ground_geometry = [v, f, vc]
+
+ def update_bbox(self, x3d, scale=2.0, mask=None):
+ """Update bbox of cameras from the given 3d points
+
+ x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3)
+ """
+
+ if x3d.size(-1) != 3:
+ x2d = x3d.unsqueeze(0)
+ else:
+ x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1))
+
+ if mask is not None:
+ x2d = x2d[:, ~mask]
+ bbox = compute_bbox_from_points(x2d, self.width, self.height, scale)
+ self.bboxes = bbox
+
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
+ self.cameras = self.create_camera()
+ self.create_renderer()
+
+ def reset_bbox(
+ self,
+ ):
+ bbox = torch.zeros((1, 4)).float().to(self.device)
+ bbox[0, 2] = self.width
+ bbox[0, 3] = self.height
+ self.bboxes = bbox
+
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
+ self.cameras = self.create_camera()
+ self.create_renderer()
+
+ def render_mesh(self, vertices, background=None, colors=[0.8, 0.8, 0.8], VI=50):
+ if vertices.dim() == 2:
+ vertices = vertices.unsqueeze(0) # (V, 3) -> (1, V, 3)
+ elif vertices.dim() != 3:
+ raise ValueError("vertices should have shape of ((Nm,) V, 3)")
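+        # Estimate the rendering bbox from every VI-th vertex (subsampled for speed).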
+ self.update_bbox(vertices.view(-1, 3)[::VI], scale=1.2)
+
+ if isinstance(colors, torch.Tensor):
+ # per-vertex color
+ verts_features = colors.to(device=vertices.device, dtype=vertices.dtype)
+ colors = [0.8, 0.8, 0.8]
+ else:
+ if colors[0] > 1:
+ colors = [c / 255.0 for c in colors]
+ verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype)
+ verts_features = verts_features.repeat(vertices.shape[0], vertices.shape[1], 1)
+ textures = TexturesVertex(verts_features=verts_features)
+
+ mesh = Meshes(
+ verts=vertices,
+ faces=self.faces,
+ textures=textures,
+ )
+
+ materials = Materials(device=self.device, specular_color=(colors,), shininess=0)
+
+ results = torch.flip(self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights), [1, 2])
+ image = results[0, ..., :3] * 255
+ mask = results[0, ..., -1] > 1e-3
+
+ if background is None:
+ background = np.ones((self.height, self.width, 3)).astype(np.uint8) * 255
+
+ image = overlay_image_onto_background(image, mask, self.bboxes, background.copy())
+ self.reset_bbox()
+ return image
+
+ def render_with_ground(self, verts, colors, cameras, lights, faces=None):
+ """
+ :param verts (N, V, 3), potential multiple people
+ :param colors (N, 3) or (N, V, 3)
+        :param faces (N, F, 3), optional, otherwise self.faces will be used
+ """
+ # Sanity check of input verts, colors and faces: (B, V, 3), (B, F, 3), (B, V, 3)
+ N, V, _ = verts.shape
+ if faces is None:
+ faces = self.faces.clone().expand(N, -1, -1)
+ else:
+ assert len(faces.shape) == 3, "faces should have shape of (N, F, 3)"
+
+ assert len(colors.shape) in [2, 3]
+ if len(colors.shape) == 2:
+ assert len(colors) == N, "colors of shape 2 should be (N, 3)"
+ colors = colors[:, None]
+ colors = colors.expand(N, V, -1)[..., :3]
+
+ # (V, 3), (F, 3), (V, 3)
+ gv, gf, gc = self.ground_geometry
+ verts = list(torch.unbind(verts, dim=0)) + [gv]
+ faces = list(torch.unbind(faces, dim=0)) + [gf]
+ colors = list(torch.unbind(colors, dim=0)) + [gc[..., :3]]
+ mesh = create_meshes(verts, faces, colors)
+
+ materials = Materials(device=self.device, shininess=0)
+
+ results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials)
+ image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+
+ return image
+
+
+def create_meshes(verts, faces, colors):
+ """
+ :param verts (B, V, 3)
+ :param faces (B, F, 3)
+ :param colors (B, V, 3)
+ """
+ textures = TexturesVertex(verts_features=colors)
+ meshes = Meshes(verts=verts, faces=faces, textures=textures)
+ return join_meshes_as_scene(meshes)
+
+
+def get_global_cameras(verts, device="cuda", distance=5, position=(-5.0, 5.0, 0.0)):
+    """This always puts the object at the center of the view."""
+ positions = torch.tensor([position]).repeat(len(verts), 1)
+ targets = verts.mean(1)
+
+ directions = targets - positions
+ directions = directions / torch.norm(directions, dim=-1).unsqueeze(-1) * distance
+ positions = targets - directions
+
+ rotation = look_at_rotation(positions, targets).mT
+ translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)
+
+ lights = PointLights(device=device, location=[position])
+ return rotation, translation, lights
+
+
+def get_global_cameras_static(verts, beta=4.0, cam_height_degree=30, target_center_height=0.75, device="cuda"):
+ L, V, _ = verts.shape
+
+ # Compute target trajectory, denote as center + scale
+ targets = verts.mean(1) # (L, 3)
+ targets[:, 1] = 0 # project to xz-plane
+ target_center = targets.mean(0) # (3,)
+ target_scale, target_idx = torch.norm(targets - target_center, dim=-1).max(0)
+
+ # a 45 degree vec from longest axis
+ long_vec = targets[target_idx] - target_center # (x, 0, z)
+ long_vec = long_vec / torch.norm(long_vec)
+ R = axis_angle_to_matrix(torch.tensor([0, np.pi / 4, 0])).to(long_vec)
+ vec = R @ long_vec
+
+ # Compute camera position (center + scale * vec * beta) + y=4
+ target_scale = max(target_scale, 1.0) * beta
+ position = target_center + vec * target_scale
+ position[1] = target_scale * np.tan(np.pi * cam_height_degree / 180) + target_center_height
+
+ # Compute camera rotation and translation
+ positions = position.unsqueeze(0).repeat(L, 1)
+ target_centers = target_center.unsqueeze(0).repeat(L, 1)
+ target_centers[:, 1] = target_center_height
+ rotation = look_at_rotation(positions, target_centers).mT
+ translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)
+
+ lights = PointLights(device=device, location=[position.tolist()])
+ return rotation, translation, lights
diff --git a/lib/utils/vis/p3d_renderer/utils.py b/lib/utils/vis/p3d_renderer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..646e1b21f4b0ad7924ddbcd822dc36808794e421
--- /dev/null
+++ b/lib/utils/vis/p3d_renderer/utils.py
@@ -0,0 +1,804 @@
+import os
+import cv2
+import math
+import numpy as np
+import torch
+from PIL import Image
+
+
+def read_image(path, scale=1):
+ im = Image.open(path)
+ if scale == 1:
+ return np.array(im)
+ W, H = im.size
+ w, h = int(scale * W), int(scale * H)
+    return np.array(im.resize((w, h), Image.LANCZOS))
+
+
+def transform_torch3d(T_c2w):
+ """
+ :param T_c2w (*, 4, 4)
+ returns (*, 3, 3), (*, 3)
+ """
+ R1 = torch.tensor(
+ [
+ [-1.0, 0.0, 0.0],
+ [0.0, -1.0, 0.0],
+ [0.0, 0.0, 1.0],
+ ],
+ device=T_c2w.device,
+ )
+ R2 = torch.tensor(
+ [
+ [1.0, 0.0, 0.0],
+ [0.0, -1.0, 0.0],
+ [0.0, 0.0, -1.0],
+ ],
+ device=T_c2w.device,
+ )
+ cam_R, cam_t = T_c2w[..., :3, :3], T_c2w[..., :3, 3]
+ cam_R = torch.einsum("...ij,jk->...ik", cam_R, R1)
+ cam_t = torch.einsum("ij,...j->...i", R2, cam_t)
+ return cam_R, cam_t
+
+
+def transform_pyrender(T_c2w):
+ """
+ :param T_c2w (*, 4, 4)
+ """
+ T_vis = torch.tensor(
+ [
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, -1.0, 0.0, 0.0],
+ [0.0, 0.0, -1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ device=T_c2w.device,
+ )
+ return torch.einsum("...ij,jk->...ik", torch.einsum("ij,...jk->...ik", T_vis, T_c2w), T_vis)
+
+
+def smpl_to_geometry(verts, faces, vis_mask=None, track_ids=None):
+ """
+ :param verts (B, T, V, 3)
+ :param faces (F, 3)
+ :param vis_mask (optional) (B, T) visibility of each person
+ :param track_ids (optional) (B,)
+ returns list of T verts (B, V, 3), faces (F, 3), colors (B, 3)
+ where B is different depending on the visibility of the people
+ """
+ B, T = verts.shape[:2]
+ device = verts.device
+
+ # (B, 3)
+    colors = track_to_colors(track_ids) if track_ids is not None else torch.ones(B, 3, device=device) * 0.5
+
+ # list T (B, V, 3), T (B, 3), T (F, 3)
+ return filter_visible_meshes(verts, colors, faces, vis_mask)
+
+
+def filter_visible_meshes(verts, colors, faces, vis_mask=None, vis_opacity=False):
+ """
+ :param verts (B, T, V, 3)
+ :param colors (B, 3)
+ :param faces (F, 3)
+ :param vis_mask (optional tensor, default None) (B, T) ternary mask
+ -1 if not in frame
+ 0 if temporarily occluded
+ 1 if visible
+ :param vis_opacity (optional bool, default False)
+ if True, make occluded people alpha=0.5, otherwise alpha=1
+ returns a list of T lists verts (Bi, V, 3), colors (Bi, 4), faces (F, 3)
+ """
+ # import ipdb; ipdb.set_trace()
+ B, T = verts.shape[:2]
+ faces = [faces for t in range(T)]
+ if vis_mask is None:
+ verts = [verts[:, t] for t in range(T)]
+ colors = [colors for t in range(T)]
+ return verts, colors, faces
+
+ # render occluded and visible, but not removed
+ vis_mask = vis_mask >= 0
+ if vis_opacity:
+ alpha = 0.5 * (vis_mask[..., None] + 1)
+ else:
+ alpha = (vis_mask[..., None] >= 0).float()
+ vert_list = [verts[vis_mask[:, t], t] for t in range(T)]
+ colors = [torch.cat([colors[vis_mask[:, t]], alpha[vis_mask[:, t], t]], dim=-1) for t in range(T)]
+ bounds = get_bboxes(verts, vis_mask)
+ return vert_list, colors, faces, bounds
+
+
+def get_bboxes(verts, vis_mask):
+ """
+ return bb_min, bb_max, and mean for each track (B, 3) over entire trajectory
+ :param verts (B, T, V, 3)
+ :param vis_mask (B, T)
+ """
+ B, T, *_ = verts.shape
+ bb_min, bb_max, mean = [], [], []
+ for b in range(B):
+ v = verts[b, vis_mask[b, :T]] # (Tb, V, 3)
+ bb_min.append(v.amin(dim=(0, 1)))
+ bb_max.append(v.amax(dim=(0, 1)))
+ mean.append(v.mean(dim=(0, 1)))
+ bb_min = torch.stack(bb_min, dim=0)
+ bb_max = torch.stack(bb_max, dim=0)
+ mean = torch.stack(mean, dim=0)
+ # point to a track that's long and close to the camera
+ zs = mean[:, 2]
+ counts = vis_mask[:, :T].sum(dim=-1) # (B,)
+ mask = counts < 0.8 * T
+ zs[mask] = torch.inf
+ sel = torch.argmin(zs)
+ return bb_min.amin(dim=0), bb_max.amax(dim=0), mean[sel]
+
+
+def track_to_colors(track_ids):
+ """
+ :param track_ids (B)
+ """
+ color_map = torch.from_numpy(get_colors()).to(track_ids)
+ return color_map[track_ids] / 255 # (B, 3)
+
+
+def get_colors():
+ # color_file = os.path.abspath(os.path.join(__file__, "../colors_phalp.txt"))
+ color_file = os.path.abspath(os.path.join(__file__, "../colors.txt"))
+ RGB_tuples = np.vstack(
+ [
+ np.loadtxt(color_file, skiprows=0),
+ # np.loadtxt(color_file, skiprows=1),
+ np.random.uniform(0, 255, size=(10000, 3)),
+ [[0, 0, 0]],
+ ]
+ )
+ b = np.where(RGB_tuples == 0)
+ RGB_tuples[b] = 1
+ return RGB_tuples.astype(np.float32)
+
+
+def checkerboard_geometry(
+ length=12.0,
+ color0=[0.8, 0.9, 0.9],
+ color1=[0.6, 0.7, 0.7],
+ tile_width=0.5,
+ alpha=1.0,
+ up="y",
+ c1=0.0,
+ c2=0.0,
+):
+ assert up == "y" or up == "z"
+ color0 = np.array(color0 + [alpha])
+ color1 = np.array(color1 + [alpha])
+ radius = length / 2.0
+ num_rows = num_cols = max(2, int(length / tile_width))
+ vertices = []
+ vert_colors = []
+ faces = []
+ face_colors = []
+ for i in range(num_rows):
+ for j in range(num_cols):
+ u0, v0 = j * tile_width - radius, i * tile_width - radius
+ us = np.array([u0, u0, u0 + tile_width, u0 + tile_width])
+ vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0])
+ zs = np.zeros(4)
+ if up == "y":
+ cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3)
+ cur_verts[:, 0] += c1
+ cur_verts[:, 2] += c2
+ else:
+ cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3)
+ cur_verts[:, 0] += c1
+ cur_verts[:, 1] += c2
+
+ cur_faces = np.array([[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64)
+ cur_faces += 4 * (i * num_cols + j) # the number of previously added verts
+ use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1)
+ cur_color = color0 if use_color0 else color1
+ cur_colors = np.array([cur_color, cur_color, cur_color, cur_color])
+
+ vertices.append(cur_verts)
+ faces.append(cur_faces)
+ vert_colors.append(cur_colors)
+ face_colors.append(cur_colors)
+
+ vertices = np.concatenate(vertices, axis=0).astype(np.float32)
+ vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32)
+ faces = np.concatenate(faces, axis=0).astype(np.float32)
+ face_colors = np.concatenate(face_colors, axis=0).astype(np.float32)
+
+ return vertices, faces, vert_colors, face_colors
+
+
+def camera_marker_geometry(radius, height, up):
+ assert up == "y" or up == "z"
+ if up == "y":
+ vertices = np.array(
+ [
+ [-radius, -radius, 0],
+ [radius, -radius, 0],
+ [radius, radius, 0],
+ [-radius, radius, 0],
+ [0, 0, height],
+ ]
+ )
+ else:
+ vertices = np.array(
+ [
+ [-radius, 0, -radius],
+ [radius, 0, -radius],
+ [radius, 0, radius],
+ [-radius, 0, radius],
+ [0, -height, 0],
+ ]
+ )
+
+ faces = np.array(
+ [
+ [0, 3, 1],
+ [1, 3, 2],
+ [0, 1, 4],
+ [1, 2, 4],
+ [2, 3, 4],
+ [3, 0, 4],
+ ]
+ )
+
+ face_colors = np.array(
+ [
+ [1.0, 1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0, 1.0],
+ [0.0, 1.0, 0.0, 1.0],
+ [1.0, 0.0, 0.0, 1.0],
+ [0.0, 1.0, 0.0, 1.0],
+ [1.0, 0.0, 0.0, 1.0],
+ ]
+ )
+ return vertices, faces, face_colors
+
+
+def vis_keypoints(
+ keypts_list,
+ img_size,
+ radius=6,
+ thickness=3,
+ kpt_score_thr=0.3,
+ dataset="TopDownCocoDataset",
+):
+ """
+ Visualize keypoints
+ From ViTPose/mmpose/apis/inference.py
+ """
+ palette = np.array(
+ [
+ [255, 128, 0],
+ [255, 153, 51],
+ [255, 178, 102],
+ [230, 230, 0],
+ [255, 153, 255],
+ [153, 204, 255],
+ [255, 102, 255],
+ [255, 51, 255],
+ [102, 178, 255],
+ [51, 153, 255],
+ [255, 153, 153],
+ [255, 102, 102],
+ [255, 51, 51],
+ [153, 255, 153],
+ [102, 255, 102],
+ [51, 255, 51],
+ [0, 255, 0],
+ [0, 0, 255],
+ [255, 0, 0],
+ [255, 255, 255],
+ ]
+ )
+
+ if dataset in (
+ "TopDownCocoDataset",
+ "BottomUpCocoDataset",
+ "TopDownOCHumanDataset",
+ "AnimalMacaqueDataset",
+ ):
+ # show the results
+ skeleton = [
+ [15, 13],
+ [13, 11],
+ [16, 14],
+ [14, 12],
+ [11, 12],
+ [5, 11],
+ [6, 12],
+ [5, 6],
+ [5, 7],
+ [6, 8],
+ [7, 9],
+ [8, 10],
+ [1, 2],
+ [0, 1],
+ [0, 2],
+ [1, 3],
+ [2, 4],
+ [3, 5],
+ [4, 6],
+ ]
+
+ pose_link_color = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
+ pose_kpt_color = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]
+
+ elif dataset == "TopDownCocoWholeBodyDataset":
+ # show the results
+ skeleton = [
+ [15, 13],
+ [13, 11],
+ [16, 14],
+ [14, 12],
+ [11, 12],
+ [5, 11],
+ [6, 12],
+ [5, 6],
+ [5, 7],
+ [6, 8],
+ [7, 9],
+ [8, 10],
+ [1, 2],
+ [0, 1],
+ [0, 2],
+ [1, 3],
+ [2, 4],
+ [3, 5],
+ [4, 6],
+ [15, 17],
+ [15, 18],
+ [15, 19],
+ [16, 20],
+ [16, 21],
+ [16, 22],
+ [91, 92],
+ [92, 93],
+ [93, 94],
+ [94, 95],
+ [91, 96],
+ [96, 97],
+ [97, 98],
+ [98, 99],
+ [91, 100],
+ [100, 101],
+ [101, 102],
+ [102, 103],
+ [91, 104],
+ [104, 105],
+ [105, 106],
+ [106, 107],
+ [91, 108],
+ [108, 109],
+ [109, 110],
+ [110, 111],
+ [112, 113],
+ [113, 114],
+ [114, 115],
+ [115, 116],
+ [112, 117],
+ [117, 118],
+ [118, 119],
+ [119, 120],
+ [112, 121],
+ [121, 122],
+ [122, 123],
+ [123, 124],
+ [112, 125],
+ [125, 126],
+ [126, 127],
+ [127, 128],
+ [112, 129],
+ [129, 130],
+ [130, 131],
+ [131, 132],
+ ]
+
+ pose_link_color = palette[
+ [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
+ + [16, 16, 16, 16, 16, 16]
+ + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
+ + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]
+ ]
+ pose_kpt_color = palette[
+ [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0] + [19] * (68 + 42)
+ ]
+
+ elif dataset == "TopDownAicDataset":
+ skeleton = [
+ [2, 1],
+ [1, 0],
+ [0, 13],
+ [13, 3],
+ [3, 4],
+ [4, 5],
+ [8, 7],
+ [7, 6],
+ [6, 9],
+ [9, 10],
+ [10, 11],
+ [12, 13],
+ [0, 6],
+ [3, 9],
+ ]
+
+ pose_link_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7]]
+ pose_kpt_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0]]
+
+ elif dataset == "TopDownMpiiDataset":
+ skeleton = [
+ [0, 1],
+ [1, 2],
+ [2, 6],
+ [6, 3],
+ [3, 4],
+ [4, 5],
+ [6, 7],
+ [7, 8],
+ [8, 9],
+ [8, 12],
+ [12, 11],
+ [11, 10],
+ [8, 13],
+ [13, 14],
+ [14, 15],
+ ]
+
+ pose_link_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9]]
+ pose_kpt_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9]]
+
+ elif dataset == "TopDownMpiiTrbDataset":
+ skeleton = [
+ [12, 13],
+ [13, 0],
+ [13, 1],
+ [0, 2],
+ [1, 3],
+ [2, 4],
+ [3, 5],
+ [0, 6],
+ [1, 7],
+ [6, 7],
+ [6, 8],
+ [7, 9],
+ [8, 10],
+ [9, 11],
+ [14, 15],
+ [16, 17],
+ [18, 19],
+ [20, 21],
+ [22, 23],
+ [24, 25],
+ [26, 27],
+ [28, 29],
+ [30, 31],
+ [32, 33],
+ [34, 35],
+ [36, 37],
+ [38, 39],
+ ]
+
+ pose_link_color = palette[[16] * 14 + [19] * 13]
+ pose_kpt_color = palette[[16] * 14 + [0] * 26]
+
+ elif dataset in ("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"):
+ skeleton = [
+ [0, 1],
+ [1, 2],
+ [2, 3],
+ [3, 4],
+ [0, 5],
+ [5, 6],
+ [6, 7],
+ [7, 8],
+ [0, 9],
+ [9, 10],
+ [10, 11],
+ [11, 12],
+ [0, 13],
+ [13, 14],
+ [14, 15],
+ [15, 16],
+ [0, 17],
+ [17, 18],
+ [18, 19],
+ [19, 20],
+ ]
+
+ pose_link_color = palette[[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]]
+ pose_kpt_color = palette[[0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]]
+
+ elif dataset == "InterHand2DDataset":
+ skeleton = [
+ [0, 1],
+ [1, 2],
+ [2, 3],
+ [4, 5],
+ [5, 6],
+ [6, 7],
+ [8, 9],
+ [9, 10],
+ [10, 11],
+ [12, 13],
+ [13, 14],
+ [14, 15],
+ [16, 17],
+ [17, 18],
+ [18, 19],
+ [3, 20],
+ [7, 20],
+ [11, 20],
+ [15, 20],
+ [19, 20],
+ ]
+
+ pose_link_color = palette[[0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, 16]]
+ pose_kpt_color = palette[[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 0]]
+
+ elif dataset == "Face300WDataset":
+ # show the results
+ skeleton = []
+
+ pose_link_color = palette[[]]
+ pose_kpt_color = palette[[19] * 68]
+ kpt_score_thr = 0
+
+ elif dataset == "FaceAFLWDataset":
+ # show the results
+ skeleton = []
+
+ pose_link_color = palette[[]]
+ pose_kpt_color = palette[[19] * 19]
+ kpt_score_thr = 0
+
+ elif dataset == "FaceCOFWDataset":
+ # show the results
+ skeleton = []
+
+ pose_link_color = palette[[]]
+ pose_kpt_color = palette[[19] * 29]
+ kpt_score_thr = 0
+
+ elif dataset == "FaceWFLWDataset":
+ # show the results
+ skeleton = []
+
+ pose_link_color = palette[[]]
+ pose_kpt_color = palette[[19] * 98]
+ kpt_score_thr = 0
+
+ elif dataset == "AnimalHorse10Dataset":
+ skeleton = [
+ [0, 1],
+ [1, 12],
+ [12, 16],
+ [16, 21],
+ [21, 17],
+ [17, 11],
+ [11, 10],
+ [10, 8],
+ [8, 9],
+ [9, 12],
+ [2, 3],
+ [3, 4],
+ [5, 6],
+ [6, 7],
+ [13, 14],
+ [14, 15],
+ [18, 19],
+ [19, 20],
+ ]
+
+ pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 + [7] * 2]
+ pose_kpt_color = palette[[4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7, 4]]
+
+ elif dataset == "AnimalFlyDataset":
+ skeleton = [
+ [1, 0],
+ [2, 0],
+ [3, 0],
+ [4, 3],
+ [5, 4],
+ [7, 6],
+ [8, 7],
+ [9, 8],
+ [11, 10],
+ [12, 11],
+ [13, 12],
+ [15, 14],
+ [16, 15],
+ [17, 16],
+ [19, 18],
+ [20, 19],
+ [21, 20],
+ [23, 22],
+ [24, 23],
+ [25, 24],
+ [27, 26],
+ [28, 27],
+ [29, 28],
+ [30, 3],
+ [31, 3],
+ ]
+
+ pose_link_color = palette[[0] * 25]
+ pose_kpt_color = palette[[0] * 32]
+
+ elif dataset == "AnimalLocustDataset":
+ skeleton = [
+ [1, 0],
+ [2, 1],
+ [3, 2],
+ [4, 3],
+ [6, 5],
+ [7, 6],
+ [9, 8],
+ [10, 9],
+ [11, 10],
+ [13, 12],
+ [14, 13],
+ [15, 14],
+ [17, 16],
+ [18, 17],
+ [19, 18],
+ [21, 20],
+ [22, 21],
+ [24, 23],
+ [25, 24],
+ [26, 25],
+ [28, 27],
+ [29, 28],
+ [30, 29],
+ [32, 31],
+ [33, 32],
+ [34, 33],
+ ]
+
+ pose_link_color = palette[[0] * 26]
+ pose_kpt_color = palette[[0] * 35]
+
+ elif dataset == "AnimalZebraDataset":
+ skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2], [8, 7]]
+
+ pose_link_color = palette[[0] * 8]
+ pose_kpt_color = palette[[0] * 9]
+
+ elif dataset in "AnimalPoseDataset":
+ skeleton = [
+ [0, 1],
+ [0, 2],
+ [1, 3],
+ [0, 4],
+ [1, 4],
+ [4, 5],
+ [5, 7],
+ [6, 7],
+ [5, 8],
+ [8, 12],
+ [12, 16],
+ [5, 9],
+ [9, 13],
+ [13, 17],
+ [6, 10],
+ [10, 14],
+ [14, 18],
+ [6, 11],
+ [11, 15],
+ [15, 19],
+ ]
+
+ pose_link_color = palette[[0] * 20]
+ pose_kpt_color = palette[[0] * 20]
+ else:
+        raise NotImplementedError()
+
+ img_w, img_h = img_size
+ img = 255 * np.ones((img_h, img_w, 3), dtype=np.uint8)
+ img = imshow_keypoints(
+ img,
+ keypts_list,
+ skeleton,
+ kpt_score_thr,
+ pose_kpt_color,
+ pose_link_color,
+ radius,
+ thickness,
+ )
+ alpha = 255 * (img != 255).any(axis=-1, keepdims=True).astype(np.uint8)
+ return np.concatenate([img, alpha], axis=-1)
+
+
+def imshow_keypoints(
+ img,
+ pose_result,
+ skeleton=None,
+ kpt_score_thr=0.3,
+ pose_kpt_color=None,
+ pose_link_color=None,
+ radius=4,
+ thickness=1,
+ show_keypoint_weight=False,
+):
+ """Draw keypoints and links on an image.
+ From ViTPose/mmpose/core/visualization/image.py
+
+ Args:
+ img (H, W, 3) array
+ pose_result (list[kpts]): The poses to draw. Each element kpts is
+ a set of K keypoints as an Kx3 numpy.ndarray, where each
+ keypoint is represented as x, y, score.
+ kpt_score_thr (float, optional): Minimum score of keypoints
+ to be shown. Default: 0.3.
+ pose_kpt_color (np.ndarray[Nx3]`): Color of N keypoints. If None,
+ the keypoint will not be drawn.
+ pose_link_color (np.ndarray[Mx3]): Color of M links. If None, the
+ links will not be drawn.
+ thickness (int): Thickness of lines.
+ show_keypoint_weight (bool): If True, opacity indicates keypoint score
+ """
+ img_h, img_w, _ = img.shape
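+    # Reorder the input keypoints (assumed OpenPose BODY_25 indexing) into the COCO-17 order used by `skeleton`.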
+ idcs = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11]
+ for kpts in pose_result:
+ kpts = np.array(kpts, copy=False)[idcs]
+
+ # draw each point on image
+ if pose_kpt_color is not None:
+ assert len(pose_kpt_color) == len(kpts)
+ for kid, kpt in enumerate(kpts):
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+ if kpt_score > kpt_score_thr:
+ color = tuple(int(c) for c in pose_kpt_color[kid])
+ if show_keypoint_weight:
+ img_copy = img.copy()
+ cv2.circle(img_copy, (int(x_coord), int(y_coord)), radius, color, -1)
+ transparency = max(0, min(1, kpt_score))
+ cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img)
+ else:
+ cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
+
+ # draw links
+ if skeleton is not None and pose_link_color is not None:
+ assert len(pose_link_color) == len(skeleton)
+ for sk_id, sk in enumerate(skeleton):
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+ if (
+ pos1[0] > 0
+ and pos1[0] < img_w
+ and pos1[1] > 0
+ and pos1[1] < img_h
+ and pos2[0] > 0
+ and pos2[0] < img_w
+ and pos2[1] > 0
+ and pos2[1] < img_h
+ and kpts[sk[0], 2] > kpt_score_thr
+ and kpts[sk[1], 2] > kpt_score_thr
+ ):
+ color = tuple(int(c) for c in pose_link_color[sk_id])
+ if show_keypoint_weight:
+ img_copy = img.copy()
+ X = (pos1[0], pos2[0])
+ Y = (pos1[1], pos2[1])
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
+ stickwidth = 2
+ polygon = cv2.ellipse2Poly(
+ (int(mX), int(mY)),
+ (int(length / 2), int(stickwidth)),
+ int(angle),
+ 0,
+ 360,
+ 1,
+ )
+ cv2.fillConvexPoly(img_copy, polygon, color)
+ transparency = max(0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2])))
+ cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img)
+ else:
+ cv2.line(img, pos1, pos2, color, thickness=thickness)
+
+ return img
diff --git a/lib/utils/vis/py_renderer/README.md b/lib/utils/vis/py_renderer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a0cbedec89aa3a1aa9b4c7fadb9ffbf08e835498
--- /dev/null
+++ b/lib/utils/vis/py_renderer/README.md
@@ -0,0 +1,7 @@
+## [Pyrender Renderer](https://github.com/mmatl/pyrender)
+
+> The code was modified from HMR2.0: https://github.com/shubham-goel/4D-Humans/blob/main/hmr2/utils/renderer.py
+
+This renderer works around the compatibility issues of the [Pytorch3D Renderer](../p3d_renderer/README.md). Its API mirrors the Pytorch3D one, but not every feature is implemented (e.g., `output_fn` is accepted for compatibility yet ignored).
+
+Although this renderer is fully torch-independent, I found that the mesh colors it produces look strange, so I use the Pytorch3D Renderer by default instead.
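+
+### Example
+
+A minimal sketch of the overlay API (the mesh, image, and intrinsics below are made-up placeholders):
+
+```python
+import numpy as np
+import torch
+
+from lib.utils.vis import render_mesh_overlay_img
+
+# A tiny valid mesh (a tetrahedron) placed ~2m in front of the camera (z-forward convention).
+verts = torch.tensor([[0.0, 0.0, 2.0], [0.1, 0.0, 2.0], [0.0, 0.1, 2.0], [0.0, 0.0, 2.1]])
+faces = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])
+
+H, W = 480, 640
+img = np.zeros((H, W, 3), dtype=np.uint8)  # black background frame
+K4 = [500.0, 500.0, W / 2, H / 2]          # [fx, fy, cx, cy]
+
+overlay = render_mesh_overlay_img(faces, verts, K4, img, mesh_color='green')  # (H, W, 3) uint8
+```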
\ No newline at end of file
diff --git a/lib/utils/vis/py_renderer/__init__.py b/lib/utils/vis/py_renderer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c022e170b5a3e2aca7c0ae5bfbe2b43c26eebb3
--- /dev/null
+++ b/lib/utils/vis/py_renderer/__init__.py
@@ -0,0 +1,367 @@
+import os
+if 'PYOPENGL_PLATFORM' not in os.environ:
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
+import torch
+import numpy as np
+import trimesh
+import pyrender
+
+from typing import List, Optional, Union, Tuple
+from pathlib import Path
+
+from lib.utils.vis import ColorPalette
+from lib.utils.data import to_numpy
+from lib.utils.media import save_img
+
+from .utils import *
+
+
+def render_mesh_overlay_img(
+ faces : Union[torch.Tensor, np.ndarray],
+ verts : torch.Tensor,
+ K4 : List,
+ img : np.ndarray,
+ output_fn : Optional[Union[str, Path]] = None,
+ device : str = 'cuda',
+ resize : float = 1.0,
+ Rt : Optional[Tuple[torch.Tensor]] = None,
+ mesh_color : Optional[Union[List[float], str]] = 'green',
+):
+ '''
+ Render the mesh overlay on the input video frames.
+
+ ### Args
+ - faces: Union[torch.Tensor, np.ndarray], (V, 3)
+ - verts: torch.Tensor, (V, 3)
+ - K4: List
+ - [fx, fy, cx, cy], the components of intrinsic camera matrix.
+ - img: np.ndarray, (H, W, 3)
+ - output_fn: Union[str, Path] or None
+ - The output file path, if None, return the rendered img.
+ - device: str, default 'cuda'
+ - resize: float, default 1.0
+ - The resize factor of the output video.
+ - Rt: Tuple of Tensor, default None
+ - The extrinsic camera matrix, in the form of (R, t).
+ '''
+ frame = render_mesh_overlay_video(
+ faces = faces,
+ verts = verts[None],
+ K4 = K4,
+ frames = img[None],
+ device = device,
+ resize = resize,
+ Rt = Rt,
+ mesh_color = mesh_color,
+ )[0]
+
+ if output_fn is None:
+ return frame
+ else:
+ save_img(frame, output_fn)
+
+
+def render_mesh_overlay_video(
+ faces : Union[torch.Tensor, np.ndarray],
+ verts : Union[torch.Tensor, np.ndarray],
+ K4 : List,
+ frames : np.ndarray,
+ output_fn : Optional[Union[str, Path]] = None,
+ fps : int = 30,
+ device : str = 'cuda',
+ resize : float = 1.0,
+ Rt : Tuple = None,
+ mesh_color : Optional[Union[List[float], str]] = 'green',
+):
+ '''
+ Render the mesh overlay on the input video frames.
+
+ ### Args
+ - faces: Union[torch.Tensor, np.ndarray], (V, 3)
+ - verts: Union[torch.Tensor, np.ndarray], (L, V, 3)
+ - K4: List
+ - [fx, fy, cx, cy], the components of intrinsic camera matrix.
+ - frames: np.ndarray, (L, H, W, 3)
+ - output_fn: useless, only for compatibility.
+ - fps: useless, only for compatibility.
+ - device: useless, only for compatibility.
+ - resize: useless, only for compatibility.
+ - Rt: Tuple, default None
+ - The extrinsic camera matrix, in the form of (R, t).
+ '''
+ faces, verts = to_numpy(faces), to_numpy(verts)
+ assert len(K4) == 4, 'K4 must be a list of 4 elements.'
+ assert frames.shape[0] == verts.shape[0], 'The length of frames and verts must be the same.'
+ assert frames.shape[-1] == 3, 'The last dimension of frames must be 3.'
+ if isinstance(mesh_color, str):
+ mesh_color = ColorPalette.presets_float[mesh_color]
+
+ # Prepare the data.
+ L = len(frames)
+ frame_w, frame_h = frames.shape[-2], frames.shape[-3]
+
+ renderer = pyrender.OffscreenRenderer(
+ viewport_width = frame_w,
+ viewport_height = frame_h,
+ point_size = 1.0
+ )
+
+ # Camera
+ camera, cam_pose = create_camera(K4, Rt)
+
+ # Scene.
+ material = pyrender.MetallicRoughnessMaterial(
+ metallicFactor = 0.0,
+ alphaMode = 'OPAQUE',
+ baseColorFactor = (*mesh_color, 1.0)
+ )
+
+ # Light.
+ light_nodes = create_raymond_lights()
+
+ results = []
+ for i in range(L):
+ mesh = trimesh.Trimesh(verts[i].copy(), faces.copy())
+ # if side_view:
+ # rot = trimesh.transformations.rotation_matrix(
+ # np.radians(rot_angle), [0, 1, 0])
+ # mesh.apply_transform(rot)
+ # elif top_view:
+ # rot = trimesh.transformations.rotation_matrix(
+ # np.radians(rot_angle), [1, 0, 0])
+ # mesh.apply_transform(rot)
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(180), [1, 0, 0])
+ mesh.apply_transform(rot)
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+
+ scene = pyrender.Scene(
+ bg_color = [0.0, 0.0, 0.0, 0.0],
+ ambient_light = (0.3, 0.3, 0.3),
+ )
+
+ scene.add(mesh, 'mesh')
+ scene.add(camera, pose=cam_pose)
+
+ # Light.
+ for node in light_nodes:
+ scene.add_node(node)
+
+ # Render.
+ result_rgba, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+
+ valid_mask = result_rgba.astype(np.float32)[:, :, [-1]] / 255.0 # (H, W, 1)
+ bg = frames[i] # (H, W, 3)
+ final = result_rgba[:, :, :3] * valid_mask + bg * (1 - valid_mask)
+ final = final.astype(np.uint8) # (H, W, 3)
+ results.append(final)
+ results = np.stack(results, axis=0) # (L, H, W, 3)
+
+ renderer.delete()
+ return results
+
+
+
+def render_meshes_overlay_img(
+ faces_all : Union[torch.Tensor, np.ndarray],
+ verts_all : Union[torch.Tensor, np.ndarray],
+ cam_t_all : Union[torch.Tensor, np.ndarray],
+ K4 : List,
+ img : np.ndarray,
+ output_fn : Optional[Union[str, Path]] = None,
+ device : str = 'cuda',
+ resize : float = 1.0,
+ Rt : Optional[Tuple[torch.Tensor]] = None,
+ mesh_color : Optional[Union[List[float], str]] = 'green',
+ view : str = 'front',
+ ret_rgba : bool = False,
+):
+ '''
+ Render the mesh overlay on the input video frames.
+
+ ### Args
+ - faces_all: Union[torch.Tensor, np.ndarray], ((Nm,) V, 3)
+ - verts_all: Union[torch.Tensor, np.ndarray], ((Nm,) V, 3)
+ - cam_t_all: Union[torch.Tensor, np.ndarray], ((Nm,) 3)
+ - K4: List
+ - [fx, fy, cx, cy], the components of intrinsic camera matrix.
+ - img: np.ndarray, (H, W, 3)
+ - output_fn: Union[str, Path] or None
+ - The output file path, if None, return the rendered img.
+ - fps: int, default 30
+ - device: str, default 'cuda'
+ - resize: float, default 1.0
+ - The resize factor of the output video.
+ - Rt: Tuple of Tensor, default None
+ - The extrinsic camera matrix, in the form of (R, t).
+ - view: str, default 'front', {'front', 'side90d', 'side60d', 'top90d'}
+ - ret_rgba: bool, default False
+ - If True, return rgba images, otherwise return rgb images.
+ - For view is not 'front', the background will become transparent.
+ '''
+ if len(verts_all.shape) == 2:
+ verts_all = verts_all[None] # (1, V, 3)
+ elif len(verts_all.shape) == 3:
+ verts_all = verts_all[:, None] # ((Nm,) 1, V, 3)
+ else:
+ raise ValueError('The shape of verts_all is not correct.')
+ if len(cam_t_all.shape) == 1:
+ cam_t_all = cam_t_all[None] # (1, 3)
+ elif len(cam_t_all.shape) == 2:
+ cam_t_all = cam_t_all[:, None] # ((Nm,) 1, 3)
+ else:
+        raise ValueError('The shape of cam_t_all is not correct.')
+ frame = render_meshes_overlay_video(
+ faces_all = faces_all,
+ verts_all = verts_all,
+ cam_t_all = cam_t_all,
+ K4 = K4,
+ frames = img[None],
+ device = device,
+ resize = resize,
+ Rt = Rt,
+ mesh_color = mesh_color,
+ view = view,
+ ret_rgba = ret_rgba,
+ )[0]
+
+ if output_fn is None:
+ return frame
+ else:
+ save_img(frame, output_fn)
+
+
+def render_meshes_overlay_video(
+ faces_all : Union[torch.Tensor, np.ndarray],
+ verts_all : Union[torch.Tensor, np.ndarray],
+ cam_t_all : Union[torch.Tensor, np.ndarray],
+ K4 : List,
+ frames : np.ndarray,
+ output_fn : Optional[Union[str, Path]] = None,
+ fps : int = 30,
+ device : str = 'cuda',
+ resize : float = 1.0,
+ Rt : Tuple = None,
+ mesh_color : Optional[Union[List[float], str]] = 'green',
+ view : str = 'front',
+ ret_rgba : bool = False,
+):
+ '''
+ Render the mesh overlay on the input video frames.
+
+ ### Args
+ - faces_all: Union[torch.Tensor, np.ndarray], ((Nm,) V, 3)
+ - verts_all: Union[torch.Tensor, np.ndarray], ((Nm,) L, V, 3)
+ - cam_t_all: Union[torch.Tensor, np.ndarray], ((Nm,) L, 3)
+ - K4: List
+ - [fx, fy, cx, cy], the components of intrinsic camera matrix.
+ - frames: np.ndarray, (L, H, W, 3)
+ - output_fn: useless, only for compatibility.
+ - fps: useless, only for compatibility.
+ - device: useless, only for compatibility.
+ - resize: useless, only for compatibility.
+ - Rt: Tuple, default None
+ - The extrinsic camera matrix, in the form of (R, t).
+ - view: str, default 'front', {'front', 'side90d', 'side60d', 'top90d'}
+ - ret_rgba: bool, default False
+ - If True, return rgba images, otherwise return rgb images.
+ - For view is not 'front', the background will become transparent.
+ '''
+ faces_all, verts_all = to_numpy(faces_all), to_numpy(verts_all)
+ if len(verts_all.shape) == 3:
+ verts_all = verts_all[None] # (1, L, V, 3)
+ if len(cam_t_all.shape) == 2:
+ cam_t_all = cam_t_all[None] # (1, L, 3)
+ Nm, L, _, _ = verts_all.shape
+ if len(faces_all.shape) == 2:
+ faces_all = faces_all[None].repeat(Nm, axis=0) # (Nm, V, 3)
+
+ assert len(K4) == 4, 'K4 must be a list of 4 elements.'
+ assert frames.shape[0] == L, 'The length of frames and verts must be the same.'
+ assert frames.shape[-1] == 3, 'The last dimension of frames must be 3.'
+ assert len(verts_all.shape) == 4, 'The shape of verts_all must be (Nm, L, V, 3).'
+ assert len(faces_all.shape) == 3, 'The shape of faces_all must be (Nm, V, 3).'
+ if isinstance(mesh_color, str):
+ mesh_color = ColorPalette.presets_float[mesh_color]
+
+ # Prepare the data.
+ frame_w, frame_h = frames.shape[-2], frames.shape[-3]
+
+ renderer = pyrender.OffscreenRenderer(
+ viewport_width = frame_w,
+ viewport_height = frame_h,
+ point_size = 1.0
+ )
+
+ # Camera
+ camera, cam_pose = create_camera(K4, Rt)
+
+ # Scene.
+ material = pyrender.MetallicRoughnessMaterial(
+ metallicFactor = 0.0,
+ alphaMode = 'OPAQUE',
+ baseColorFactor = (*mesh_color, 1.0)
+ )
+
+ # Light.
+ light_nodes = create_raymond_lights()
+
+ results = []
+ for i in range(L):
+ scene = pyrender.Scene(
+ bg_color = [0.0, 0.0, 0.0, 0.0],
+ ambient_light = (0.3, 0.3, 0.3),
+ )
+
+ for mid in range(Nm):
+ mesh = trimesh.Trimesh(verts_all[mid][i].copy(), faces_all[mid].copy())
+ if view == 'front':
+ pass
+ elif view == 'side90d':
+ rot = trimesh.transformations.rotation_matrix(np.radians(-90), [0, 1, 0])
+ mesh.apply_transform(rot)
+ elif view == 'side60d':
+ rot = trimesh.transformations.rotation_matrix(np.radians(-60), [0, 1, 0])
+ mesh.apply_transform(rot)
+ elif view == 'top90d':
+ rot = trimesh.transformations.rotation_matrix(np.radians(90), [1, 0, 0])
+ mesh.apply_transform(rot)
+ else:
+ raise ValueError('The view is not supported.')
+ trans = trimesh.transformations.translation_matrix(to_numpy(cam_t_all[mid][i]))
+ mesh.apply_transform(trans)
+ rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
+ mesh.apply_transform(rot)
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+
+ scene.add(mesh, f'mesh_{mid}')
+ scene.add(camera, pose=cam_pose)
+
+ # Light.
+ for node in light_nodes:
+ scene.add_node(node)
+
+ # Render.
+ result_rgba, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+
+ valid_mask = result_rgba.astype(np.float32)[:, :, [-1]] / 255.0 # (H, W, 1)
+ if view == 'front':
+ bg = frames[i] # (H, W, 3)
+ else:
+ bg = np.ones_like(frames[i]) * 255 # (H, W, 3)
+ if ret_rgba:
+ if view == 'front':
+ bg_alpha = np.ones_like(bg[..., [0]]) * 255 # (H, W, 1)
+ else:
+ bg_alpha = np.zeros_like(bg[..., [0]]) * 255 # (H, W, 1)
+ bg = np.concatenate([bg, bg_alpha], axis=-1) # (H, W, 4)
+ final = result_rgba * valid_mask + bg * (1 - valid_mask) # (H, W, 4)
+ else:
+ final = result_rgba[:, :, :3] * valid_mask + bg * (1 - valid_mask)
+ final = final.astype(np.uint8) # (H, W, 3) or (H, W, 4) if ret_rgba
+ results.append(final)
+ results = np.stack(results, axis=0) # (L, H, W, 3) or (L, H, W, 4) if ret_rgba
+
+ renderer.delete()
+ return results
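+
+# Comment-only usage sketch (kept as a comment so nothing runs at import time).
+# All values below are illustrative placeholders, not defaults of this module:
+#   L, V, F, H, W = 16, 6890, 13776, 512, 512
+#   verts  = np.zeros((L, V, 3), dtype=np.float32)   # one mesh over L frames
+#   faces  = np.zeros((F, 3), dtype=np.int64)        # shared topology
+#   cam_t  = np.tile([0.0, 0.0, 5.0], (L, 1))        # (L, 3)
+#   frames = np.zeros((L, H, W, 3), dtype=np.uint8)  # background video
+#   K4     = [5000.0, 5000.0, W / 2, H / 2]          # [fx, fy, cx, cy]
+#   overlay = render_meshes_overlay_video(faces, verts, cam_t, K4, frames)  # (L, H, W, 3) uint8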
diff --git a/lib/utils/vis/py_renderer/utils.py b/lib/utils/vis/py_renderer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd977e3ce62922a392bac6a38b3ea5d90c1c7ec6
--- /dev/null
+++ b/lib/utils/vis/py_renderer/utils.py
@@ -0,0 +1,52 @@
+import numpy as np
+import pyrender
+
+from typing import List, Optional, Union, Tuple
+from lib.utils.data import to_numpy
+
+
+def create_camera(K4:List, Rt:Optional[Tuple]):
+ if Rt is not None:
+ cam_R, cam_t = Rt
+ cam_R = to_numpy(cam_R).copy()
+ cam_t = to_numpy(cam_t).copy()
+ cam_t[0] *= -1
+ else:
+ cam_R = np.eye(3)
+ cam_t = np.zeros(3)
+ fx, fy, cx, cy = K4
+ cam_pose = np.eye(4)
+ cam_pose[:3, :3] = cam_R
+ cam_pose[:3, 3] = cam_t
+ camera = pyrender.IntrinsicsCamera( fx=fx, fy=fy, cx=cx, cy=cy, zfar=1e12 )
+ return camera, cam_pose
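+
+# Comment-only example (placeholder intrinsics, not values used elsewhere in the repo):
+#   camera, cam_pose = create_camera(K4=[5000.0, 5000.0, 256.0, 256.0], Rt=None)
+#   scene.add(camera, pose=cam_pose)  # `cam_pose` is a 4x4 pose matrix (identity when Rt is None)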
+
+
+def create_raymond_lights() -> List[pyrender.Node]:
+ ''' Return raymond light nodes for the scene. '''
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
+
+ nodes = []
+
+ for phi, theta in zip(phis, thetas):
+ xp = np.sin(theta) * np.cos(phi)
+ yp = np.sin(theta) * np.sin(phi)
+ zp = np.cos(theta)
+
+ z = np.array([xp, yp, zp])
+ z = z / np.linalg.norm(z)
+ x = np.array([-z[1], z[0], 0.0])
+ if np.linalg.norm(x) == 0:
+ x = np.array([1.0, 0.0, 0.0])
+ x = x / np.linalg.norm(x)
+ y = np.cross(z, x)
+
+ matrix = np.eye(4)
+ matrix[:3,:3] = np.c_[x,y,z]
+ nodes.append(pyrender.Node(
+ light = pyrender.DirectionalLight(color=np.ones(3), intensity=0.75),
+ matrix = matrix
+ ))
+
+ return nodes
\ No newline at end of file
diff --git a/lib/utils/vis/wis3d.py b/lib/utils/vis/wis3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c5aa5d4398d191091a2c47ebefb0ef2e659e36
--- /dev/null
+++ b/lib/utils/vis/wis3d.py
@@ -0,0 +1,493 @@
+import torch
+import numpy as np
+
+from typing import Union, List, overload
+from wis3d import Wis3D
+
+from lib.platform import PM
+from lib.utils.geometry.rotation import axis_angle_to_matrix
+
+
+class HWis3D(Wis3D):
+ ''' Abstraction of Wis3D for human motion. '''
+
+ def __init__(
+ self,
+ out_path : str = PM.outputs / 'wis3d',
+ seq_name : str = 'debug',
+ xyz_pattern : tuple = ('x', 'y', 'z'),
+ ):
+ seq_name = seq_name.replace('/', '-')
+ super().__init__(out_path, seq_name, xyz_pattern)
+
+
+ def add_text(self, text:str):
+ '''
+ Add a placeholder point cloud whose item name is used to carry the text message. *Dirty use!*
+
+ ### Args
+ - text: str
+ '''
+ fake_verts = np.array([[0, 0, 0]])
+ self.add_point_cloud(
+ vertices = fake_verts,
+ colors = None,
+ name = text,
+ )
+
+
+ def add_text_seq(self, texts:List[str], offset:int=0):
+ '''
+ Add placeholder point clouds whose item names carry the text messages, one per scene index. *Dirty use!*
+
+ ### Args
+ - texts: List[str]
+ - The list of text messages.
+ - offset: int, default = 0
+ - The offset for the sequence index.
+ '''
+ fake_verts = np.array([[0, 0, 0]])
+ for i, text in enumerate(texts):
+ self.set_scene_id(i + offset)
+ self.add_point_cloud(
+ vertices = fake_verts,
+ colors = None,
+ name = text,
+ )
+
+ def add_image_seq(self, imgs:List[np.ndarray], name:str, offset:int=0):
+ '''
+ Add a sequence of images to the viewer, one image per scene index.
+
+ ### Args
+ - imgs: List[np.ndarray]
+ - The list of images.
+ - name: str
+ - The name of the image item.
+ - offset: int, default = 0
+ - The offset for the sequence index.
+ '''
+ for i, img in enumerate(imgs):
+ self.set_scene_id(i + offset)
+ self.add_image(
+ image = img,
+ name = name,
+ )
+
+ def add_motion_mesh(
+ self,
+ verts : Union[torch.Tensor, np.ndarray],
+ faces : Union[torch.Tensor, np.ndarray],
+ name : str,
+ offset: int = 0,
+ ):
+ '''
+ Add sequence of vertices and face(s) to the wis3d viewer.
+
+ ### Args
+ - verts: torch.Tensor or np.ndarray, (L, V, 3), L ~ sequence length, V ~ number of vertices
+ - faces: torch.Tensor or np.ndarray, (F, 3) or (L, F, 3), F ~ number of faces, L ~ sequence length
+ - name: str
+ - The name of the point cloud.
+ - offset: int, default = 0
+ - The offset for the sequence index.
+ '''
+ assert (len(verts.shape) == 3), 'The input `verts` should have 3 dimensions: (L, V, 3).'
+ assert (verts.shape[-1] == 3), 'The last dimension of `verts` should be 3.'
+ if isinstance(verts, np.ndarray):
+ verts = torch.from_numpy(verts)
+ if isinstance(faces, torch.Tensor):
+ faces = faces.detach().cpu().numpy()
+ if len(faces.shape) == 2:
+ faces = faces[None].repeat(verts.shape[0], 0)
+ assert (len(faces.shape) == 3), 'The input `faces` should have 2 or 3 dimensions: (F, 3) or (L, F, 3).'
+ assert (faces.shape[-1] == 3), 'The last dimension of `faces` should be 3.'
+ assert (verts.shape[0] == faces.shape[0]), 'The first dimension of `verts` and `faces` should be the same.'
+
+ L, _, _ = verts.shape
+ verts = verts.detach().cpu()
+
+ # Add vertices frame by frame.
+ for i in range(L):
+ self.set_scene_id(i + offset)
+ self.add_mesh(
+ vertices = verts[i],
+ faces = faces[i],
+ name = name,
+ ) # type: ignore
+
+ # Reset Wis3D scene id.
+ self.set_scene_id(0)
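+
+ # Comment-only usage sketch (placeholder shapes; 6890/13776 are SMPL-like sizes, not a requirement):
+ #   w3d = HWis3D(seq_name='demo')
+ #   verts = torch.rand(16, 6890, 3)            # (L, V, 3)
+ #   faces = torch.zeros(13776, 3).long()       # (F, 3) placeholder topology, shared across frames
+ #   w3d.add_motion_mesh(verts, faces, name='body')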
+
+
+ def add_motion_verts(
+ self,
+ verts : Union[torch.Tensor, np.ndarray],
+ name : str,
+ offset: int = 0,
+ ):
+ '''
+ Add sequence of vertices to the wis3d viewer.
+
+ ### Args
+ - verts: torch.Tensor or np.ndarray, (L, V, 3), L ~ sequence length, V ~ number of vertices
+ - name: str
+ - The name of the point cloud.
+ - offset: int, default = 0
+ - The offset for the sequence index.
+ '''
+ assert (len(verts.shape) == 3), 'The input `verts` should have 3 dimensions: (L, V, 3).'
+ assert (verts.shape[-1] == 3), 'The last dimension of `verts` should be 3.'
+ if isinstance(verts, np.ndarray):
+ verts = torch.from_numpy(verts)
+
+ L, _, _ = verts.shape
+ verts = verts.detach().cpu()
+
+ # Add vertices frame by frame.
+ for i in range(L):
+ self.set_scene_id(i + offset)
+ self.add_point_cloud(
+ vertices = verts[i],
+ colors = None,
+ name = name,
+ )
+
+ # Reset Wis3D scene id.
+ self.set_scene_id(0)
+
+
+ def add_motion_skel(
+ self,
+ joints : Union[torch.Tensor, np.ndarray],
+ bones : Union[list, torch.Tensor],
+ colors : Union[list, torch.Tensor],
+ name : str,
+ offset : int = 0,
+ threshold : float = 0.5,
+ ):
+ '''
+ Add sequence of joints with specific skeleton to the wis3d viewer.
+
+ ### Args
+ - joints: torch.Tensor or np.ndarray, shape = (L, J, 3) or (L, J, 4), L ~ sequence length, J ~ number of joints
+ - bones: list
+ - A list of bones of the skeleton, i.e. the edge in the kinematic trees.
+ - colors: list
+ - name: str
+ - The name of the point cloud.
+ - offset: int, default = 0
+ - The offset for the sequence index.
+ - threshold: float, default = 0.5
+ - Threshold to filter the joints by confidence; ignored when no confidence is provided.
+ '''
+ assert (len(joints.shape) == 3), 'The input `joints` should have 3 dimensions: (L, J, 3) or (L, J, 4).'
+ assert (joints.shape[-1] == 3 or joints.shape[-1] == 4), 'The last dimension of `joints` should be 3 or 4.'
+ if isinstance(joints, np.ndarray):
+ joints = torch.from_numpy(joints)
+ if isinstance(bones, List):
+ bones = torch.tensor(bones)
+ if isinstance(colors, List):
+ colors = torch.tensor(colors)
+
+ # Get the sequence length.
+ joints = joints.detach().cpu() # (L, J, 3) or (L, J, 4)
+ L, J, D = joints.shape
+ if D == 4:
+ conf = joints[:, :, 3]
+ joints = joints[:, :, :3]
+ else:
+ conf = None
+
+ # Add vertices frame by frame.
+ for i in range(L):
+ self.set_scene_id(i + offset)
+ bones_s = joints[i][bones[:, 0]]
+ bones_e = joints[i][bones[:, 1]]
+ if conf is not None:
+ mask = torch.logical_and(conf[i][bones[:, 0]] > threshold, conf[i][bones[:, 1]] > threshold)
+ bones_s, bones_e = bones_s[mask], bones_e[mask]
+ if len(bones_s) > 0:
+ self.add_lines(
+ start_points = bones_s,
+ end_points = bones_e,
+ colors = colors,
+ name = name,
+ )
+
+ # Reset Wis3D scene id.
+ self.set_scene_id(0)
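+
+ # Comment-only sketch, continuing the `w3d` example above (bones/colors are placeholders):
+ #   joints = torch.rand(16, 22, 3)             # (L, J, 3); use (L, J, 4) to attach confidences
+ #   bones  = [[0, 1], [1, 2], [2, 3]]          # index pairs of the kinematic tree
+ #   colors = [[255, 0, 0]] * len(bones)        # one RGB (0~255) color per bone
+ #   w3d.add_motion_skel(joints, bones, colors, name='skeleton')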
+
+
+ def add_vec_seq(
+ self,
+ vecs : torch.Tensor,
+ name : str,
+ offset : int = 0,
+ seg_num : int = 16,
+ ):
+ '''
+ Add directional line sequence to the wis3d viewer.
+
+ The line will be gradient colored, and the direction of the vector is visualized as from dark to light.
+
+ ### Args
+ - vecs: torch.Tensor, (L, 2, 3) or (L, N, 2, 3), L ~ sequence length, N ~ number of vectors per frame;
+ the last two dimensions hold the start and end 3D points of each vector.
+ - name: str
+ - The name of the vector.
+ - offset: int, default = 0
+ - The offset for the sequence index.
+ - seg_num: int, default = 16
+ - The number of segments for gradient color, will just change the visualization effect.
+ '''
+ if len(vecs.shape) == 3:
+ vecs = vecs[:, None, :, :] # (L, 2, 3) -> (L, 1, 2, 3)
+ assert (len(vecs.shape) == 4), 'The input `vecs` should have 3 or 4 dimensions: (L, 2, 3) or (L, N, 2, 3).'
+ assert (vecs.shape[-2:] == (2, 3)), f'The last two dimension of `vecs` should be (2, 3), but got vecs.shape = {vecs.shape}.'
+
+ # Get the sequence length.
+ L, N, _, _ = vecs.shape
+ vecs = vecs.detach().cpu()
+
+ # Cut the line into segments.
+ steps_delta = (vecs[:, :, [1]] - vecs[:, :, [0]]) / (seg_num + 1) # (L, N, 1, 3)
+ steps_cnt = torch.arange(seg_num + 1).reshape((1, 1, seg_num + 1, 1)) # (1, 1, seg_num+1, 1)
+ segs = steps_delta * steps_cnt + vecs[:, :, [0]] # (L, N, seg_num+1, 3)
+ start_pts = segs[:, :, :-1] # (L, N, seg_num, 3)
+ end_pts = segs[:, :, 1:] # (L, N, seg_num, 3)
+
+ # Prepare the gradient colors.
+ grad_colors = torch.linspace(0, 255, seg_num).reshape((1, seg_num, 1)).repeat(N, 1, 3) # (N, seg_num, 3)
+
+ # Add vertices frame by frame.
+ for i in range(L):
+ self.set_scene_id(i + offset)
+ self.add_lines(
+ start_points = start_pts[i].reshape(-1, 3),
+ end_points = end_pts[i].reshape(-1, 3),
+ colors = grad_colors.reshape(-1, 3),
+ name = name,
+ )
+
+ # Reset Wis3D scene id.
+ self.set_scene_id(0)
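+
+ # Comment-only sketch (placeholder data): one arrow per frame from the origin along +x.
+ #   vecs = torch.zeros(16, 2, 3)               # (L, 2, 3)
+ #   vecs[:, 1, 0] = 1.0                        # end points at (1, 0, 0)
+ #   w3d.add_vec_seq(vecs, name='heading')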
+
+
+ def add_traj(
+ self,
+ positions : torch.Tensor,
+ name : str,
+ offset : int = 0,
+ ):
+ '''
+ Visualize how the positions change over time as a trajectory. The newer positions will be brighter.
+
+ ### Args
+ - positions: torch.Tensor, (L, 3), L ~ sequence length
+ - name: str
+ - The name of the trajectory.
+ - offset: int, default = 0
+ - The offset for the sequence index.
+ '''
+ assert (len(positions.shape) == 2), 'The input `positions` should have 2 dimensions: (L, 3).'
+ assert (positions.shape[-1] == 3), 'The last dimension of `positions` should be 3.'
+
+ # Get the sequence length.
+ L, _ = positions.shape
+ positions = positions.detach().cpu()
+ traj = positions[[0]] # (1, 3)
+
+ # Prepare the gradient colors.
+ grad_colors = torch.linspace(208, 48, L).reshape((L, 1)).repeat(1, 3) # (L, 3)
+
+ for i in range(L):
+ traj = torch.cat((traj, positions[[i]]), dim=0) # (i+2, 3)
+ self.set_scene_id(i + offset)
+ self.add_lines(
+ start_points = traj[:-1],
+ end_points = traj[1:],
+ colors = grad_colors[-(i+1):],
+ name = name,
+ )
+
+ # Reset Wis3D scene id.
+ self.set_scene_id(0)
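+
+ # Comment-only sketch (placeholder data): a straight trajectory along +z.
+ #   positions = torch.linspace(0, 1, 16)[:, None] * torch.tensor([0.0, 0.0, 1.0])  # (L, 3)
+ #   w3d.add_traj(positions, name='root_traj')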
+
+
+ def add_sphere_sensors(
+ self,
+ positions : torch.Tensor,
+ radius : Union[torch.Tensor, float],
+ activities : torch.Tensor,
+ name : str,
+ ):
+ '''
+ Draw the sphere sensors with different colors to represent the activities. The color is from white to red.
+
+ ### Args
+ - positions: torch.Tensor, (N, 3), N ~ number of sensors
+ - radius: torch.Tensor or float, (N,), N ~ number of sensors
+ - activities: torch.Tensor, (N)
+ - The activities of the sensors, from 0 to 1.
+ - name: str
+ - The name of the spheres.
+ '''
+ assert (len(positions.shape) == 2), 'The input `positions` should have 2 dimensions: (N, 3).'
+ assert (positions.shape[-1] == 3), 'The last dimension of `positions` should be 3.'
+ N, _ = positions.shape
+ if isinstance(radius, float):
+ radius = torch.tensor(radius).reshape(1).repeat(N) # (N)
+ elif len(radius.shape) == 0:
+ radius = radius.reshape(1).repeat(N)
+
+ colors = torch.ones(size=(N, 3)).float()
+ colors[:, 0] = 255
+ colors[:, 1] = (1 - activities) ** 2 * 255
+ colors[:, 2] = (1 - activities) ** 2 * 255
+ self.add_spheres(
+ centers = positions,
+ radius = radius,
+ colors = colors,
+ name = name,
+ )
+
+
+ def add_sphere_sensors_seq(
+ self,
+ positions : torch.Tensor,
+ radius : Union[torch.Tensor, float],
+ activities : torch.Tensor,
+ name : str,
+ offset : int = 0,
+ ):
+ '''
+ Draw the sphere sensors with different colors to represent the activities. The color is from white to red.
+
+ ### Args
+ - positions: torch.Tensor, (L, N, 3), N ~ number of sensors
+ - radius: torch.Tensor or float, (N,) or a scalar, shared across all frames, N ~ number of sensors
+ - activities: torch.Tensor, (L, N)
+ - The activities of the sensors, from 0 to 1.
+ - name: str
+ - The name of the spheres.
+ - offset: int, default = 0
+ - The offset for the sequence index.
+ '''
+ assert (len(positions.shape) == 3), 'The input `positions` should have 3 dimensions: (L, N, 3).'
+ assert (positions.shape[-1] == 3), 'The last dimension of `positions` should be 3.'
+ L, N, _ = positions.shape
+
+ for i in range(L):
+ self.set_scene_id(i + offset)
+ self.add_sphere_sensors(
+ positions = positions[i],
+ radius = radius,
+ activities = activities[i],
+ name = name,
+ )
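+
+ # Comment-only sketch (placeholder data): 8 sensors with a shared radius.
+ #   positions  = torch.rand(16, 8, 3)          # (L, N, 3)
+ #   activities = torch.rand(16, 8)             # (L, N), values in [0, 1]
+ #   w3d.add_sphere_sensors_seq(positions, radius=0.05, activities=activities, name='sensors')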
+
+
+ # ===== Overriding methods from original Wis3D. =====
+
+
+ def add_lines(
+ self,
+ start_points: torch.Tensor,
+ end_points : torch.Tensor,
+ colors : Union[list, torch.Tensor] = None,
+ name : str = None,
+ thickness : float = 0.01,
+ resolution : int = 4,
+ ):
+ '''
+ Add lines by start/end points. Overrides the original `add_lines` method to draw meshes instead of native lines, which prevents the browser from crashing.
+
+ ### Args
+ - start_points: torch.Tensor, (N, 3), N ~ number of lines
+ - end_points: torch.Tensor, (N, 3), N ~ number of lines
+ - colors: list or torch.Tensor, (N, 3)
+ - The color of the lines, from 0 to 255.
+ - name: str
+ - The name of the vector.
+ - thickness: float, default = 0.01
+ - The thickness of the lines.
+ - resolution: int, default = 4
+ - The 'line' is actually a poly-cylinder; the resolution controls how closely it approximates a round cylinder.
+ '''
+ if isinstance(colors, List):
+ colors = torch.tensor(colors)
+
+ assert (len(start_points.shape) == 2), 'The input `start_points` should have 2 dimensions: (N, 3).'
+ assert (len(end_points.shape) == 2), 'The input `end_points` should have 2 dimensions: (N, 3).'
+ assert (start_points.shape == end_points.shape), 'The input `start_points` and `end_points` should have the same shape.'
+
+ # ===== Prepare the data. =====
+ N, _ = start_points.shape
+ device = start_points.device
+ dir = end_points - start_points # (N, 3)
+ dis = torch.norm(dir, dim=-1, keepdim=True) # (N, 1)
+ dir = dir / dis # (N, 3)
+ K = resolution + 1 # the first & the last point share the position
+ # Find out directions that are negative to the y-axis.
+ vec_y = torch.Tensor([[0, 1, 0]]).float().to(device) # (1, 3)
+ neg_mask = (dir @ vec_y.transpose(-1, -2) < 0).squeeze() # (N,)
+
+ # ===== Get the ending surface vertices of the cylinder. =====
+ # 1. Get the surface vertices template in x-z plain.
+ angles = torch.linspace(0, 2*torch.pi, K) # (K,), angles around the ring
+ v_ending_temp = \
+ torch.stack(
+ [torch.cos(angles), torch.zeros_like(angles), torch.sin(angles)],
+ dim = -1
+ ) # (K, 3)
+ v_ending_temp *= thickness # (K, 3)
+ v_ending_temp = v_ending_temp[None].repeat(N, 1, 1) # (N, K, 3)
+
+ # 2. Rotate the template plane to the direction of the line.
+ rot_axis = torch.linalg.cross(vec_y, dir) # (N, 3)
+ rot_axis[neg_mask] *= -1
+ rot_mat = axis_angle_to_matrix(rot_axis) # (N, 3, 3)
+ v_ending_temp = v_ending_temp @ rot_mat.transpose(-1, -2)
+ v_ending_temp = v_ending_temp.to(device)
+
+ # 3. Move the template plane to the start and end points and get the cylinder vertices.
+ v_cylinder_start = v_ending_temp + start_points[:, None] # (N, K, 3)
+ v_cylinder_end = v_ending_temp + end_points[:, None] # (N, K, 3)
+ # Swap the start and end points for the negative direction to adjust the normal direction.
+ v_cylinder_start[neg_mask], v_cylinder_end[neg_mask] = v_cylinder_end[neg_mask], v_cylinder_start[neg_mask]
+ v_cylinder = torch.cat([v_cylinder_start, v_cylinder_end], dim=1) # (N, 2*K, 3)
+
+ # ===== Calculate the face index. =====
+ idx = torch.arange(0, 2*K, device=device).to(device) # (2*K,)
+ idx_s, idx_e = idx[:K], idx[K:]
+ f_cylinder = torch.cat([
+ # Two ending surface.
+ torch.stack([idx_s[0].repeat(K-2), idx_s[1:-1], idx_s[2:]], dim=-1),
+ torch.stack([idx_e[0].repeat(K-2), idx_e[2:], idx_e[1:-1]], dim=-1),
+ # The side surface.
+ torch.stack([idx_e[:-1], idx_s[1:], idx_s[:-1]], dim=-1),
+ torch.stack([idx_e[:-1], idx_e[1:], idx_s[1:]], dim=-1),
+ ], dim=0) # (4*K-4, 3)
+ f_cylinder = f_cylinder[None].repeat(N, 1, 1) # (N, 4*K-4, 3)
+
+ # ===== Prepare the vertex colors. =====
+ if colors is not None:
+ c_cylinder = colors / 255.0 # (N, 3)
+ c_cylinder = c_cylinder[:, None].repeat(1, 2*K, 1) # (N, 2*K, 3)
+ else:
+ c_cylinder = None
+
+ N, V = v_cylinder.shape[:2]
+ v_cylinder = v_cylinder.reshape(-1, 3) # (N*(2*K), 3)
+
+ # ===== Manually match the points index before flatten. =====
+ f_cylinder = f_cylinder + torch.arange(0, N, device=device).unsqueeze(1).unsqueeze(1) * V
+ f_cylinder = f_cylinder.reshape(-1, 3) # (N*(4*K-4), 3)
+ if c_cylinder is not None:
+ c_cylinder = c_cylinder.reshape(-1, 3) # (N*(2*K), 3)
+
+ self.add_mesh(
+ vertices = v_cylinder,
+ vertex_colors = c_cylinder,
+ faces = f_cylinder,
+ name = name,
+ )
\ No newline at end of file
diff --git a/lib/version.py b/lib/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc37b911090d58edb9b4a5eba6557137a44832d4
--- /dev/null
+++ b/lib/version.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+__version__ = 'v1.0'
+DEFAULT_HSMR_ROOT = 'data_inputs/released_models/HSMR-ViTH-r1d0'
\ No newline at end of file
diff --git a/requirements_part1.txt b/requirements_part1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..162970ea39a1b5bea9737f6e4a4ebc0d8de87fe5
--- /dev/null
+++ b/requirements_part1.txt
@@ -0,0 +1,24 @@
+braceexpand
+colorlog
+einops
+hydra-core
+imageio[ffmpeg]
+ipdb
+matplotlib
+numpy>=1.24.4
+omegaconf
+opencv_python
+pyrender
+pytorch_lightning
+Requests
+rich
+setuptools
+smplx
+timm
+torch
+tqdm
+trimesh
+webdataset
+wis3d
+yacs
+gradio
\ No newline at end of file
diff --git a/requirements_part2.txt b/requirements_part2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3134d06ff9136f12c4daf3e77de8131cbedc7db6
--- /dev/null
+++ b/requirements_part2.txt
@@ -0,0 +1,2 @@
+git+https://github.com/facebookresearch/detectron2.git
+git+https://github.com/mattloper/chumpy.git
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..287ccc9f960d2e646a920ce5e1e07ef0096a8eaa
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,15 @@
+from setuptools import setup, find_packages
+
+
+# import `__version__`
+with open('lib/version.py') as f:
+ exec(f.read())
+
+setup(
+ name = 'lib',
+ version = __version__, # type: ignore
+ author = 'Yan Xia',
+ author_email = 'isshikihugh@gmail.com',
+ description = 'Official implementation of HSMR.',
+ packages = find_packages(),
+)
\ No newline at end of file
diff --git a/tools/.gitignore b/tools/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..58f8d2d185d1d3ea32760f85cdd1d87aeaa602c1
--- /dev/null
+++ b/tools/.gitignore
@@ -0,0 +1 @@
+*.flag
\ No newline at end of file
diff --git a/tools/prepare_data.sh b/tools/prepare_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..20a7c1d3864ea59afa03569d8180ad15ce746c0d
--- /dev/null
+++ b/tools/prepare_data.sh
@@ -0,0 +1,10 @@
+cd data_inputs
+echo "[SCRIPT-LOG] Downloading body models..." ; date
+wget -q -c 'https://huggingface.co/IsshikiHugh/HSMR-data_inputs/resolve/main/body_models.tar.gz' -O body_models.tar.gz
+tar -xzf body_models.tar.gz
+
+mkdir -p released_models
+cd released_models
+echo "[SCRIPT-LOG] Downloading checkpoints..." ; date
+wget -q -c 'https://huggingface.co/IsshikiHugh/HSMR-data_inputs/resolve/main/released_models/HSMR-ViTH-r1d0.tar.gz' -O HSMR-ViTH-r1d0.tar.gz
+tar -xzf HSMR-ViTH-r1d0.tar.gz
\ No newline at end of file
diff --git a/tools/service.py b/tools/service.py
new file mode 100644
index 0000000000000000000000000000000000000000..e806bbf93f15d84c22cd2618edbcebce83255c92
--- /dev/null
+++ b/tools/service.py
@@ -0,0 +1,12 @@
+from lib.kits.gradio import *
+
+import os
+
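+# PYOPENGL_PLATFORM='egl' makes pyrender/PyOpenGL render headlessly via EGL (no X display
+# in the container). If `lib.kits.gradio` already imports pyrender, the variable may need to
+# be exported before that import instead (e.g. in tools/start.sh).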
+os.environ['PYOPENGL_PLATFORM'] = 'egl'
+
+
+if __name__ == '__main__':
+ # Start serving.
+ backend = HSMRBackend(device='cpu')
+ hsmr_service = HSMRService(backend)
+ hsmr_service.serve()
\ No newline at end of file
diff --git a/tools/start.sh b/tools/start.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7de99574e0d84a456c2e0cbb7d54919a6efee446
--- /dev/null
+++ b/tools/start.sh
@@ -0,0 +1,7 @@
+bash tools/prepare_data.sh
+
+echo "[SCRIPT-LOG] Installing requirements..." ; date
+pip install -e .
+
+echo "[SCRIPT-LOG] Starting service..." ; date
+python tools/service.py
\ No newline at end of file