diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..837587841c89bc2865e84e36f2fe4382e752cc0d --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +# Caches. +.DS_Store +__pycache__ +*.egg-info + +# VEnv. +.hsmr_env + +# Hydra dev. +outputs + +# Development tools. +.vscode +pyrightconfig.json + +# Logs. +gpu_monitor.log diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..6a08aa69d6c15d28ffa17d6f46c9e997a6784971 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,37 @@ +# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker +# you will also find guides on how best to write your Dockerfile + +FROM python:3.10 + +WORKDIR /code + +COPY ./requirements_part1.txt /code/requirements_part1.txt +COPY ./requirements_part2.txt /code/requirements_part2.txt + +RUN apt-get update -y && \ + apt-get upgrade -y && \ + apt-get install -y libglfw3-dev && \ + apt-get install -y libgles2-mesa-dev && \ + apt-get install -y aria2 && \ + pip install --no-cache-dir --upgrade -r /code/requirements_part1.txt && \ + pip install --no-cache-dir --upgrade -r /code/requirements_part2.txt + +# Set up a new user named "user" with user ID 1000 +RUN useradd -m -u 1000 user + +# Switch to the "user" user +USER user + + +# Set home to the user's home directory +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +# Set the working directory to the user's home directory +WORKDIR $HOME/app + +# Copy the current directory contents into the container at $HOME/app setting the owner to the user +COPY --chown=user . $HOME/app + + +CMD ["bash", "tools/start.sh"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2d868c8b74175093a195fac31ede8a547a01b718 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Yan XIA + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md index 5c14c7bbc6b7a4b6fb2b2643e82be65d74862b2d..1adfd10b13d9891ac6e20be5da4caae3a34f65ca 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,10 @@ --- title: HSMR emoji: 💀 -colorFrom: purple -colorTo: indigo -sdk: gradio -sdk_version: 5.20.0 -app_file: tools/service.py +colorFrom: blue +colorTo: pink +sdk: docker +app_port: 7860 pinned: true --- diff --git a/configs/.gitignore b/configs/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/README.md b/configs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fcf75522caf15419d5ce98e24fc6db855da4adde --- /dev/null +++ b/configs/README.md @@ -0,0 +1,31 @@ +# Instructions for the Configuration System + +The configuration system I used here is based on [Hydra](https://hydra.cc/). However, I made some small 'hacks' to enable a few extra features, such as `_hub_`. It might be a little hard to understand at first, but comprehensive guidance is provided. Check `README.md` in each directory for more details. + +## Philosophy + +- Easy to modify and maintain. +- Helps you structure your code clearly. +- Consistency. +- Easy to trace and identify specific items. + +## Some Ideas + +- Less ListConfig; use ListConfig only for real list data. + - Dumped lists are unfolded so that each element occupies one line, which is annoying to present. + - Lists are not friendly to command-line argument overrides. +- For the defaults list, `_self_` must be explicitly specified. + - Items before `_self_` mean 'based on those items'. + - Items after `_self_` mean 'import those items'. + +### COMPOSITION OVER INHERITANCE + +Avoid "overrides" as much as possible, except when the changes are really tiny, since it is hard to identify which value is actually used without running the code. Instead, `default.yaml` serves as a template: you are supposed to copy it and modify it to create a new configuration. + +### REFERENCE OVER COPY + +If you want to use one thing many times (across experiments, or across components within one experiment), use `${...}` to reference the Hydra object, so that you only need to modify one place while keeping everything consistent. Think about where to put the source of the object; `_hub_` is recommended but may not be the best choice every time. + +### PREPARE EVERYTHING YOU NEED LOCALLY + +Sometimes you might want to use some configuration values outside the class configuration (i.e., directly in the code). In that case, I recommend referencing them again in the local configuration package. It might feel redundant, but it leads to cleaner code. \ No newline at end of file diff --git a/configs/_hub_/README.md b/configs/_hub_/README.md new file mode 100644 index 0000000000000000000000000000000000000000..db39e698657362975a38edb30a6932e2467c6f32 --- /dev/null +++ b/configs/_hub_/README.md @@ -0,0 +1,7 @@ +# _hub + +Configs here shouldn't be used as what Hydra calls a "config group". Configs here serve as a hub for other configs to reference. In any given experiment, most of the entries here will go unused. + +For example, the details of each dataset can be defined here, and then I only need to reference the `Human3.6M` dataset through `${_hub_.datasets.h36m}` in other configs. + +Each `.yaml` file here is loaded separately in `base.yaml`. Check `base.yaml` to understand how it works.
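To make the reference-over-copy idea concrete, here is a minimal sketch (illustration only, not an additional file in this diff; both fragments mirror `configs/_hub_/datasets.yaml` and the data configs that appear later in this patch):

```yaml
# Fragment 1: configs/_hub_/datasets.yaml defines the dataset exactly once.
train:
  hsmr:
    h36m:
      name: 'HSMR-H36M-train'
      urls: ${_pm_.inputs}/hsmr_training_data/h36m-tars/{000000..000312}.tar
      epoch_size: 314_000

# Fragment 2: a data config only references that entry, so there is a single source of truth.
datasets:
  - name: 'H36M'
    item: ${_hub_.datasets.train.hsmr.h36m}
    weight: 0.3
```

Changing the tar URLs or the epoch size in the hub then propagates to every experiment that references the entry.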
\ No newline at end of file diff --git a/configs/_hub_/datasets.yaml b/configs/_hub_/datasets.yaml new file mode 100644 index 0000000000000000000000000000000000000000..134ed49ef786aa6fe76cd2f26a398d1e0448381f --- /dev/null +++ b/configs/_hub_/datasets.yaml @@ -0,0 +1,71 @@ +train: + hsmr: # Fitted according to the SMPL's vertices. + # Standard 4 datasets. + mpi_inf_3dhp: + name: 'HSMR-MPI-INF-3DHP-train-pruned' + urls: ${_pm_.inputs}/hsmr_training_data/mpi_inf_3dhp-tars/{000000..000012}.tar + epoch_size: 12_000 + h36m: + name: 'HSMR-H36M-train' + urls: ${_pm_.inputs}/hsmr_training_data/h36m-tars/{000000..000312}.tar + epoch_size: 314_000 + mpii: + name: 'HSMR-MPII-train' + urls: ${_pm_.inputs}/hsmr_training_data/mpii-tars/{000000..000009}.tar + epoch_size: 10_000 + coco14: + name: 'HSMR-COCO14-train' + urls: ${_pm_.inputs}/hsmr_training_data/coco14-tars/{000000..000017}.tar + epoch_size: 18_000 + # The rest full datasets for HMR2.0. + coco14vit: + name: 'HSMR-COCO14-vit-train' + urls: ${_pm_.inputs}/hsmr_training_data/coco14vit-tars/{000000..000044}.tar + epoch_size: 45_000 + aic: + name: 'HSMR-AIC-train' + urls: ${_pm_.inputs}/hsmr_training_data/aic-tars/{000000..000209}.tar + epoch_size: 210_000 + ava: + name: 'HSMR-AVA-train' + urls: ${_pm_.inputs}/hsmr_training_data/ava-tars/{000000..000184}.tar + epoch_size: 185_000 + insta: + name: 'HSMR-INSTA-train' + urls: ${_pm_.inputs}/hsmr_training_data/insta-tars/{000000..003657}.tar + epoch_size: 3_658_000 + +mocap: + bioamass_v1: + dataset_file: ${_pm_.inputs}/datasets/amass_skel/data_v1.npz + pve_threshold: 0.05 + cmu_mocap: + dataset_file: ${_pm_.inputs}/hmr2_training_data/cmu_mocap.npz + +eval: + h36m_val_p2: + dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/h36m_val_p2.npz + img_root: ${_pm_.inputs}/datasets/h36m/images + kp_list: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 43] + use_hips: true + + 3dpw_test: + dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/3dpw_test.npz + img_root: ${_pm_.inputs}/datasets/3dpw + kp_list: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 43] + use_hips: false + + posetrack_val: + dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/posetrack_2018_val.npz + img_root: ${_pm_.inputs}/datasets/posetrack/posetrack2018/posetrack_data/ + kp_list: [0] # dummy + + lsp_extended: + dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/hr-lspet_train.npz + img_root: ${_pm_.inputs}/datasets/hr-lspet + kp_list: [0] # dummy + + coco_val: + dataset_file: ${_pm_.inputs}/hmr2_evaluation_data/coco_val.npz + img_root: ${_pm_.inputs}/datasets/coco + kp_list: [0] # dummy diff --git a/configs/_hub_/models.yaml b/configs/_hub_/models.yaml new file mode 100644 index 0000000000000000000000000000000000000000..241ee7e4f3da2975367f1a9aa9696e2806b53472 --- /dev/null +++ b/configs/_hub_/models.yaml @@ -0,0 +1,57 @@ +body_models: + + skel_mix_hsmr: + _target_: lib.body_models.skel_wrapper.SKELWrapper + model_path: '${_pm_.inputs}/body_models/skel' + gender: male # Use male since we don't have neutral model. + joint_regressor_extra: '${_pm_.inputs}/body_models/SMPL_to_J19.pkl' + joint_regressor_custom: '${_pm_.inputs}/body_models/J_regressor_SKIL_mix_MALE.pkl' + + skel_hsmr: + _target_: lib.body_models.skel_wrapper.SKELWrapper + model_path: '${_pm_.inputs}/body_models/skel' + gender: male # Use male since we don't have neutral model. 
+ joint_regressor_extra: '${_pm_.inputs}/body_models/SMPL_to_J19.pkl' + # joint_regressor_custom: '${_pm_.inputs}/body_models/J_regressor_SMPL_MALE.pkl' + + smpl_hsmr: + _target_: lib.body_models.smpl_wrapper.SMPLWrapper + model_path: '${_pm_.inputs}/body_models/smpl' + gender: male # align with skel_hsmr + num_body_joints: 23 + joint_regressor_extra: '${_pm_.inputs}/body_models/SMPL_to_J19.pkl' + + smpl_hsmr_neutral: + _target_: lib.body_models.smpl_wrapper.SMPLWrapper + model_path: '${_pm_.inputs}/body_models/smpl' + gender: neutral # align with skel_hsmr + num_body_joints: 23 + joint_regressor_extra: '${_pm_.inputs}/body_models/SMPL_to_J19.pkl' + +backbones: + + vit_b: + _target_: lib.modeling.networks.backbones.ViT + img_size: [256, 192] + patch_size: 16 + embed_dim: 768 + depth: 12 + num_heads: 12 + ratio: 1 + use_checkpoint: False + mlp_ratio: 4 + qkv_bias: True + drop_path_rate: 0.3 + + vit_h: + _target_: lib.modeling.networks.backbones.ViT + img_size: [256, 192] + patch_size: 16 + embed_dim: 1280 + depth: 32 + num_heads: 16 + ratio: 1 + use_checkpoint: False + mlp_ratio: 4 + qkv_bias: True + drop_path_rate: 0.55 \ No newline at end of file diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..151fbc5d81e94a75455db64b211e0acace2cfefc --- /dev/null +++ b/configs/base.yaml @@ -0,0 +1,33 @@ +defaults: + # Load each sub-hub. + - _hub_/datasets@_hub_.datasets + - _hub_/models@_hub_.models + + # Set up the template. + - _self_ + + # Register some defaults. + - exp: default + - policy: default + + +# exp_name: !!null # Name of experiment will determine the output folder. +exp_name: '${exp_topic}-${exp_tag}' +exp_topic: !!null # Theme of the experiment. +exp_tag: 'debug' # Tag of the experiment. + + +output_dir: ${_pm_.root}/data_outputs/exp/${exp_name} # Output directory for the experiment. +# output_dir: ${_pm_.root}/data_outputs/exp/${now:%Y-%m-%d}/${exp_name} # Output directory for the experiment. + +hydra: + run: + dir: ${output_dir} + + +# Information from `PathManager`, which will be automatically set by entrypoint wrapper. +# The related implementation is in `lib/utils/cfg_utils.py`:`entrypoint_with_args`. 
+_pm_: + root: !!null + inputs: !!null + outputs: !!null \ No newline at end of file diff --git a/configs/callback/ckpt/all-p1k.yaml b/configs/callback/ckpt/all-p1k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..40d208d77f80c7d9f856271031f80d8e42bb1ef4 --- /dev/null +++ b/configs/callback/ckpt/all-p1k.yaml @@ -0,0 +1,5 @@ +_target_: pytorch_lightning.callbacks.ModelCheckpoint +dirpath: '${output_dir}/checkpoints' +save_last: True +every_n_train_steps: 1000 +save_top_k: -1 # save all ckpts \ No newline at end of file diff --git a/configs/callback/ckpt/top1-p1k.yaml b/configs/callback/ckpt/top1-p1k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f7a912720e27db3128a827be624e7694dc5a4d7 --- /dev/null +++ b/configs/callback/ckpt/top1-p1k.yaml @@ -0,0 +1,5 @@ +_target_: pytorch_lightning.callbacks.ModelCheckpoint +dirpath: '${output_dir}/checkpoints' +save_last: True +every_n_train_steps: 1000 +save_top_k: 1 # save only the last checkpoint \ No newline at end of file diff --git a/configs/callback/skelify-spin/i10kb1.yaml b/configs/callback/skelify-spin/i10kb1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6570b331397608ef4af54eac66a86eb957b88062 --- /dev/null +++ b/configs/callback/skelify-spin/i10kb1.yaml @@ -0,0 +1,19 @@ +defaults: + - _self_ + # Import the SKELify-SPIN callback. + - /pipeline/skelify-refiner@skelify + +_target_: lib.modeling.callbacks.SKELifySPIN + +# This configuration are for frozen backbone exp with batch_size = 24000 + +cfg: + interval: 10 + batch_size: 24000 # when greater than `interval * dataloader's batch_size`, it's equivalent to that + max_batches_per_round: 1 # only the latest k * batch_size items are SPINed to save time + # better_pgt_fn: '${_pm_.inputs}/datasets/skel_training_data/spin/better_pseudo_gt.npz' + better_pgt_fn: '${output_dir}/better_pseudo_gt.npz' + skip_warm_up_steps: 10000 + # skip_warm_up_steps: 0 + update_better_pgt: True + valid_betas_threshold: 2 diff --git a/configs/callback/skelify-spin/i230kb1.yaml b/configs/callback/skelify-spin/i230kb1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..97969670712fc1d1547e5719ff39b6b68f5993a4 --- /dev/null +++ b/configs/callback/skelify-spin/i230kb1.yaml @@ -0,0 +1,16 @@ +defaults: + - _self_ + # Import the SKELify-SPIN callback. + - /pipeline/skelify-refiner@skelify + +_target_: lib.modeling.callbacks.SKELifySPIN + +cfg: + interval: 230 + batch_size: 18000 # when greater than `interval * dataloader's batch_size`, it's equivalent to that + max_batches_per_round: 1 # only the latest k * batch_size items are SPINed to save time + # better_pgt_fn: '${_pm_.inputs}/datasets/skel_training_data/spin/better_pseudo_gt.npz' + better_pgt_fn: '${output_dir}/better_pseudo_gt.npz' + skip_warm_up_steps: 5000 + update_better_pgt: True + valid_betas_threshold: 2 diff --git a/configs/callback/skelify-spin/i80.yaml b/configs/callback/skelify-spin/i80.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d1026e524f47b2a9be77fc22b9de2049c6f76983 --- /dev/null +++ b/configs/callback/skelify-spin/i80.yaml @@ -0,0 +1,15 @@ +defaults: + - _self_ + # Import the SKELify-SPIN callback. 
+ - /pipeline/skelify-refiner@skelify + +_target_: lib.modeling.callbacks.SKELifySPIN + +cfg: + interval: 80 + batch_size: 24000 # when greater than `interval * dataloader's batch_size`, it's equivalent to that + # better_pgt_fn: '${_pm_.inputs}/datasets/skel_training_data/spin/better_pseudo_gt.npz' + better_pgt_fn: '${output_dir}/better_pseudo_gt.npz' + skip_warm_up_steps: 5000 + update_better_pgt: True + valid_betas_threshold: 2 diff --git a/configs/callback/skelify-spin/read_only.yaml b/configs/callback/skelify-spin/read_only.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f21087355c220fc56d64e1608d3108659529e683 --- /dev/null +++ b/configs/callback/skelify-spin/read_only.yaml @@ -0,0 +1,18 @@ +defaults: + - _self_ + # Import the SKELify-SPIN callback. + - /pipeline/skelify-refiner@skelify + +_target_: lib.modeling.callbacks.SKELifySPIN + +# This configuration are for frozen backbone exp with batch_size = 24000 + +cfg: + interval: 0 + batch_size: 0 # when greater than `interval * dataloader's batch_size`, it's equivalent to that + max_batches_per_round: 0 # only the latest k * batch_size items are SPINed to save time + # better_pgt_fn: '${_pm_.inputs}/datasets/skel_training_data/spin/better_pseudo_gt.npz' + better_pgt_fn: '${output_dir}/better_pseudo_gt.npz' + skip_warm_up_steps: 0 + update_better_pgt: True + valid_betas_threshold: 2 diff --git a/configs/data/skel-hmr2_fashion.yaml b/configs/data/skel-hmr2_fashion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..66868282ec2e1066bded8f821ccbe324fb63e8ce --- /dev/null +++ b/configs/data/skel-hmr2_fashion.yaml @@ -0,0 +1,71 @@ +_target_: lib.data.modules.hmr2_fashion.skel_wds.DataModule + +name: HMR2_fashion_WDS + +cfg: + + train: + shared_ds_opt: # TODO: modify this + SUPPRESS_KP_CONF_THRESH: 0.3 + FILTER_NUM_KP: 4 + FILTER_NUM_KP_THRESH: 0.0 + FILTER_REPROJ_THRESH: 31000 + SUPPRESS_BETAS_THRESH: 3.0 + SUPPRESS_BAD_POSES: False + POSES_BETAS_SIMULTANEOUS: True + FILTER_NO_POSES: False + BETAS_REG: True + + datasets: + - name: 'H36M' + item: ${_hub_.datasets.train.hsmr.h36m} + weight: 0.3 + - name: 'MPII' + item: ${_hub_.datasets.train.hsmr.mpii} + weight: 0.1 + - name: 'COCO14' + item: ${_hub_.datasets.train.hsmr.coco14} + weight: 0.4 + - name: 'MPI-INF-3DHP' + item: ${_hub_.datasets.train.hsmr.mpi_inf_3dhp} + weight: 0.2 + + dataloader: + drop_last: True + batch_size: 300 + num_workers: 6 + prefetch_factor: 2 + + eval: + datasets: + - name: 'LSP-EXTENDED' + item: ${_hub_.datasets.eval.lsp_extended} + - name: 'H36M-VAL-P2' + item: ${_hub_.datasets.eval.h36m_val_p2} + - name: '3DPW-TEST' + item: ${_hub_.datasets.eval.3dpw_test} + - name: 'POSETRACK-VAL' + item: ${_hub_.datasets.eval.posetrack_val} + - name: 'COCO-VAL' + item: ${_hub_.datasets.eval.coco_val} + + + dataloader: + shuffle: False + batch_size: 300 + num_workers: 6 + + policy: ${policy} + + # TODO: modify this + augm: + SCALE_FACTOR: 0.3 + ROT_FACTOR: 30 + TRANS_FACTOR: 0.02 + COLOR_SCALE: 0.2 + ROT_AUG_RATE: 0.6 + TRANS_AUG_RATE: 0.5 + DO_FLIP: True + FLIP_AUG_RATE: 0.5 + EXTREME_CROP_AUG_RATE: 0.10 + EXTREME_CROP_AUG_LEVEL: 1 \ No newline at end of file diff --git a/configs/data/skel-hsr_v1_4ds.yaml b/configs/data/skel-hsr_v1_4ds.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5e5c3d39f6dbeb73c644dbd44656c03d4e8c083 --- /dev/null +++ b/configs/data/skel-hsr_v1_4ds.yaml @@ -0,0 +1,84 @@ +_target_: lib.data.modules.hsmr_v1.data_module.DataModule + +name: SKEL_HSMR_V1 + +cfg: + + train: + cfg: + # 
Loader settings. + suppress_pgt_params_pve_max_thresh: 0.06 # Mark PVE-MAX larger than this value as invalid. + suppress_kp_conf_thresh: 0.3 # Mark key-point confidence smaller than this value as invalid. + suppress_betas_thresh: 3.0 # Mark betas having components larger than this value as invalid. + poses_betas_simultaneous: True # Sync poses and betas for the small person. + filter_insufficient_kp_cnt: 4 + suppress_insufficient_kp_thresh: 0.0 + filter_reproj_err_thresh: 31000 + regularize_invalid_betas: True + # Others. + image_augmentation: ${...image_augmentation} + policy: ${...policy} + + datasets: + - name: 'H36M' + item: ${_hub_.datasets.train.hsmr.h36m} + weight: 0.3 + - name: 'MPII' + item: ${_hub_.datasets.train.hsmr.mpii} + weight: 0.1 + - name: 'COCO14' + item: ${_hub_.datasets.train.hsmr.coco14} + weight: 0.4 + - name: 'MPI-INF-3DHP' + item: ${_hub_.datasets.train.hsmr.mpi_inf_3dhp} + weight: 0.2 + + dataloader: + drop_last: True + batch_size: 300 + num_workers: 6 + prefetch_factor: 2 + + # mocap: + # cfg: ${_hub_.datasets.mocap.bioamass_v1} + # dataloader: + # batch_size: 600 # num_train:2 * batch_size:300 (from HMR2.0's cfg) + # drop_last: True + # shuffle: True + # num_workers: 1 + + eval: + cfg: + image_augmentation: ${...image_augmentation} + policy: ${...policy} + + datasets: + - name: 'LSP-EXTENDED' + item: ${_hub_.datasets.eval.lsp_extended} + - name: 'H36M-VAL-P2' + item: ${_hub_.datasets.eval.h36m_val_p2} + - name: '3DPW-TEST' + item: ${_hub_.datasets.eval.3dpw_test} + - name: 'POSETRACK-VAL' + item: ${_hub_.datasets.eval.posetrack_val} + - name: 'COCO-VAL' + item: ${_hub_.datasets.eval.coco_val} + + dataloader: + shuffle: False + batch_size: 300 + num_workers: 6 + + # Augmentation settings. + image_augmentation: + trans_factor: 0.02 + bbox_scale_factor: 0.3 + rot_aug_rate: 0.6 + rot_factor: 30 + do_flip: True + flip_aug_rate: 0.5 + extreme_crop_aug_rate: 0.10 + half_color_scale: 0.2 + + # Others. + policy: ${policy} \ No newline at end of file diff --git a/configs/data/skel-hsr_v1_full.yaml b/configs/data/skel-hsr_v1_full.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3377ebfdbcea6311efef32497bee64d90bcf59de --- /dev/null +++ b/configs/data/skel-hsr_v1_full.yaml @@ -0,0 +1,96 @@ +_target_: lib.data.modules.hsmr_v1.data_module.DataModule + +name: SKEL_HSMR_V1 + +cfg: + + train: + cfg: + # Loader settings. + suppress_pgt_params_pve_max_thresh: 0.06 # Mark PVE-MAX larger than this value as invalid. + suppress_kp_conf_thresh: 0.3 # Mark key-point confidence smaller than this value as invalid. + suppress_betas_thresh: 3.0 # Mark betas having components larger than this value as invalid. + poses_betas_simultaneous: True # Sync poses and betas for the small person. + filter_insufficient_kp_cnt: 4 + suppress_insufficient_kp_thresh: 0.0 + filter_reproj_err_thresh: 31000 + regularize_invalid_betas: True + # Others. 
+ image_augmentation: ${...image_augmentation} + policy: ${...policy} + + datasets: + - name: 'H36M' + item: ${_hub_.datasets.train.hsmr.h36m} + weight: 0.1 + - name: 'MPII' + item: ${_hub_.datasets.train.hsmr.mpii} + weight: 0.1 + - name: 'COCO14' + item: ${_hub_.datasets.train.hsmr.coco14} + weight: 0.1 + - name: 'COCO14-ViTPose' + item: ${_hub_.datasets.train.hsmr.coco14vit} + weight: 0.1 + - name: 'MPI-INF-3DHP' + item: ${_hub_.datasets.train.hsmr.mpi_inf_3dhp} + weight: 0.02 + - name: 'AVA' + item: ${_hub_.datasets.train.hsmr.ava} + weight: 0.19 + - name: 'AIC' + item: ${_hub_.datasets.train.hsmr.aic} + weight: 0.19 + - name: 'INSTA' + item: ${_hub_.datasets.train.hsmr.insta} + weight: 0.2 + + dataloader: + drop_last: True + batch_size: 300 + num_workers: 6 + prefetch_factor: 2 + + mocap: + cfg: ${_hub_.datasets.mocap.bioamass_v1} + dataloader: + batch_size: 600 # num_train:2 * batch_size:300 (from HMR2.0's cfg) + drop_last: True + shuffle: True + num_workers: 1 + + eval: + cfg: + image_augmentation: ${...image_augmentation} + policy: ${...policy} + + datasets: + - name: 'LSP-EXTENDED' + item: ${_hub_.datasets.eval.lsp_extended} + - name: 'H36M-VAL-P2' + item: ${_hub_.datasets.eval.h36m_val_p2} + - name: '3DPW-TEST' + item: ${_hub_.datasets.eval.3dpw_test} + - name: 'POSETRACK-VAL' + item: ${_hub_.datasets.eval.posetrack_val} + - name: 'COCO-VAL' + item: ${_hub_.datasets.eval.coco_val} + + dataloader: + shuffle: False + batch_size: 300 + num_workers: 6 + + # Augmentation settings. + image_augmentation: + trans_factor: 0.02 + bbox_scale_factor: 0.3 + rot_aug_rate: 0.6 + rot_factor: 30 + do_flip: True + flip_aug_rate: 0.5 + extreme_crop_aug_rate: 0.10 + half_color_scale: 0.2 + + # Others. + policy: ${policy} \ No newline at end of file diff --git a/configs/exp/default.yaml b/configs/exp/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ffb509d4aea8925b5ee865eefae04d8fca4d59fd --- /dev/null +++ b/configs/exp/default.yaml @@ -0,0 +1,16 @@ +# @package _global_ +defaults: + # Use absolute path, cause the root in this file is `configs/exp/` rather than `configs/`. + # - /data: ... + + # Configurations in this file should have the highest priority. + - _self_ + +# ======= Overwrite Section ======= +# (Do not use as much as possible!) + + + +# ====== Main Section ====== + +exp_topic: !!null # Name of experiment will determine the output folder. \ No newline at end of file diff --git a/configs/exp/hsr/train.yaml b/configs/exp/hsr/train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..866c1e69c8e8f957c7a01e50ffc131671ad06b43 --- /dev/null +++ b/configs/exp/hsr/train.yaml @@ -0,0 +1,47 @@ +# @package _global_ +defaults: + - /pipeline: hsmr + - /data: skel-hsmr_v1_full + # Configurations in this file + - _self_ + # Import callbacks. + - /callback/ckpt/top1-p1k@callbacks.ckpt + # - /callback/skelify-spin/i230kb1@callbacks.SKELifySPIN + - /callback/skelify-spin/read_only@callbacks.SKELifySPIN + +# ======= Overwrite Section ======= +# (Do not use as much as possible!) 
+ +pipeline: + cfg: + backbone: ${_hub_.models.backbones.vit_h} + backbone_ckpt: '${_pm_.inputs}/backbone/vitpose_backbone.pth' + freeze_backbone: False + +data: + cfg: + train: + dataloader: + batch_size: 78 + + +# ====== Main Section ====== + +exp_topic: 'HSMR-train-vit_h' + +enable_time_monitor: False + +seed: NULL +ckpt_path: NULL + +logger: + interval: 1000 + interval_skelify: 10 + samples_per_record: 5 + +task: 'fit' +pl_trainer: + devices: 8 + max_epochs: 100 + deterministic: false + precision: 16 \ No newline at end of file diff --git a/configs/exp/skelify-full.yaml b/configs/exp/skelify-full.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d269becd692494eb5c9a97be0f0b5c2495029a3 --- /dev/null +++ b/configs/exp/skelify-full.yaml @@ -0,0 +1,21 @@ +# @package _global_ +defaults: + - /pipeline: skelify-full + - /data: skel-hsmr_v1_full + # Configurations in this file should have the highest priority. + - _self_ + +# ======= Overwrite Section ======= +# (Do not use as much as possible!) + + + +# ====== Main Section ====== + +exp_topic: 'SKELify-Full' + +enable_time_monitor: True + +logger: + interval: 1 + samples_per_record: 8 \ No newline at end of file diff --git a/configs/exp/skelify-refiner.yaml b/configs/exp/skelify-refiner.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2f644e7f0d2a0fdb08931defb313c4774544765 --- /dev/null +++ b/configs/exp/skelify-refiner.yaml @@ -0,0 +1,21 @@ +# @package _global_ +defaults: + - /pipeline: skelify-refiner + - /data: skel-hsmr_v1_full + # Configurations in this file should have the highest priority. + - _self_ + +# ======= Overwrite Section ======= +# (Do not use as much as possible!) + + + +# ====== Main Section ====== + +exp_topic: 'SKELify-Refiner' + +enable_time_monitor: True + +logger: + interval: 2 + samples_per_record: 8 \ No newline at end of file diff --git a/configs/pipeline/hmr2.yaml b/configs/pipeline/hmr2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d32bbbc8a7bee6b0c2f113abd3df85cf13330f7e --- /dev/null +++ b/configs/pipeline/hmr2.yaml @@ -0,0 +1,43 @@ +_target_: lib.modeling.pipelines.HMR2Pipeline + +name: HMR2 + +cfg: + # Body Models + SMPL: ${_hub_.models.body_models.smpl_hsmr_neutral} + + # Backbone and its checkpoint. + backbone: ${_hub_.models.backbones.vit_h} + backbone_ckpt: ??? + + # Head to get the parameters. + head: + _target_: lib.modeling.networks.heads.SMPLTransformerDecoderHead + cfg: + transformer_decoder: + depth: 6 + heads: 8 + mlp_dim: 1024 + dim_head: 64 + dropout: 0.0 + emb_dropout: 0.0 + norm: 'layer' + context_dim: ${....backbone.embed_dim} + + optimizer: + _target_: torch.optim.AdamW + lr: 1e-5 + weight_decay: 1e-4 + + # This may be redesigned, e.g., we can add a loss object to maintain the calculation of the loss, or a callback. 
+ loss_weights: + kp3d: 0.05 + kp2d: 0.01 + poses_orient: 0.002 + poses_body: 0.001 + betas: 0.0005 + adversarial: 0.0005 + # adversarial: 0.0 + + policy: ${policy} + logger: ${logger} \ No newline at end of file diff --git a/configs/pipeline/hsr.yaml b/configs/pipeline/hsr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2d5c6183ecccf7ad0af94d922562c656266899b --- /dev/null +++ b/configs/pipeline/hsr.yaml @@ -0,0 +1,49 @@ +_target_: lib.modeling.pipelines.HSMRPipeline + +name: HSMR + +cfg: + pd_poses_repr: 'rotation_6d' # poses representation for prediction, choices: 'euler_angle' | 'rotation_6d' + sp_poses_repr: 'rotation_matrix' # poses representation for supervision, choices: 'euler_angle' | 'rotation_matrix' + + # Body Models + # SKEL: ${_hub_.models.body_models.skel_hsmr} + SKEL: ${_hub_.models.body_models.skel_mix_hsmr} + + # Backbone and its checkpoint. + backbone: ??? + backbone_ckpt: ??? + + # Head to get the parameters. + head: + _target_: lib.modeling.networks.heads.SKELTransformerDecoderHead + cfg: + pd_poses_repr: ${...pd_poses_repr} + transformer_decoder: + depth: 6 + heads: 8 + mlp_dim: 1024 + dim_head: 64 + dropout: 0.0 + emb_dropout: 0.0 + norm: 'layer' + context_dim: ${....backbone.embed_dim} + + optimizer: + _target_: torch.optim.AdamW + lr: 1e-5 + weight_decay: 1e-4 + + # This may be redesigned, e.g., we can add a loss object to maintain the calculation of the loss, or a callback. + loss_weights: + kp3d: 0.05 + kp2d: 0.01 + # prior: 0.0005 + prior: 0.0 + poses_orient: 0.002 + poses_body: 0.001 + betas: 0.0005 + # adversarial: 0.0005 + + policy: ${policy} + logger: ${logger} \ No newline at end of file diff --git a/configs/pipeline/skelify-full.yaml b/configs/pipeline/skelify-full.yaml new file mode 100644 index 0000000000000000000000000000000000000000..662f90d5bc35df34ccbfa63303269ce135236839 --- /dev/null +++ b/configs/pipeline/skelify-full.yaml @@ -0,0 +1,98 @@ +_target_: lib.modeling.optim.SKELify + +name: SKELify + +cfg: + skel_model: ${_hub_.models.body_models.skel_mix_hsmr} + + _f_normalize_kp2d: True + _f_normalize_kp2d_to_mean: False + _w_angle_prior_scale: 1.7 + + phases: + # ================================ + # ⛩️ Part 1: Camera initialization. + # -------------------------------- + STAGE-camera-init: + max_loop: 30 + params_keys: ['cam_t', 'poses_orient'] + parts: ['torso'] + optimizer: ${...optimizer} + losses: + f_normalize_kp2d: ${...._f_normalize_kp2d} + f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean} + w_depth: 100.0 + w_reprojection: 1.78 + # ================================ + + # ================================ + # ⛩️ Part 2: Overall optimization. + # -------------------------------- + STAGE-overall-1: + max_loop: 30 + params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas'] + parts: ['all'] + optimizer: ${...optimizer} + losses: + f_normalize_kp2d: ${...._f_normalize_kp2d} + f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean} + w_reprojection: 1.0 + w_shape_prior: 100.0 + w_angle_prior: 404.0 + w_angle_prior_scale: ${...._w_angle_prior_scale} # TODO: Finalize it. + # -------------------------------- + STAGE-overall-2: + max_loop: 30 + params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas'] + optimizer: ${...optimizer} + parts: ['all'] + losses: + f_normalize_kp2d: ${...._f_normalize_kp2d} + f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean} + w_reprojection: 1.0 + w_shape_prior: 50.0 + w_angle_prior: 404.0 + w_angle_prior_scale: ${...._w_angle_prior_scale} # TODO: Finalize it. 
+ # -------------------------------- + STAGE-overall-3: + max_loop: 30 + params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas'] + parts: ['all'] + optimizer: ${...optimizer} + losses: + f_normalize_kp2d: ${...._f_normalize_kp2d} + f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean} + w_reprojection: 1.0 + w_shape_prior: 10.0 + w_angle_prior: 57.4 + w_angle_prior_scale: ${...._w_angle_prior_scale} # TODO: Finalize it. + # -------------------------------- + STAGE-overall-4: + max_loop: 30 + params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas'] + parts: ['all'] + optimizer: ${...optimizer} + losses: + f_normalize_kp2d: ${...._f_normalize_kp2d} + f_normalize_kp2d_to_mean: ${...._f_normalize_kp2d_to_mean} + w_reprojection: 1.0 + w_shape_prior: 5.0 + w_angle_prior: 4.78 + w_angle_prior_scale: ${...._w_angle_prior_scale} # TODO: Finalize it. + # ================================ + + optimizer: + _target_: torch.optim.LBFGS + lr: 1 + line_search_fn: 'strong_wolfe' + tolerance_grad: ${..early_quit_thresholds.abs} + tolerance_change: ${..early_quit_thresholds.rel} + + early_quit_thresholds: + abs: 1e-9 + rel: 1e-9 + + img_patch_size: ${policy.img_patch_size} + focal_length: ${policy.focal_length} + + logger: ${logger} \ No newline at end of file diff --git a/configs/pipeline/skelify-refiner.yaml b/configs/pipeline/skelify-refiner.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4e2195579374246d7d4a3e9fc13cdfc4381b6301 --- /dev/null +++ b/configs/pipeline/skelify-refiner.yaml @@ -0,0 +1,37 @@ +_target_: lib.modeling.optim.SKELify + +name: SKELify-Refiner + +cfg: + skel_model: ${_hub_.models.body_models.skel_mix_hsmr} + + phases: + STAGE-refine: + max_loop: 10 + params_keys: ['cam_t', 'poses_orient', 'poses_body', 'betas'] + optimizer: ${...optimizer} + losses: + f_normalize_kp2d: True + f_normalize_kp2d_to_mean: False + w_reprojection: 1.0 + w_shape_prior: 5.0 + w_angle_prior: 4.78 + w_angle_prior_scale: 0.17 + parts: ['all'] + # ================================ + + optimizer: + _target_: torch.optim.LBFGS + lr: 1 + line_search_fn: 'strong_wolfe' + tolerance_grad: ${..early_quit_thresholds.abs} + tolerance_change: ${..early_quit_thresholds.rel} + + early_quit_thresholds: + abs: 1e-7 + rel: 1e-9 + + img_patch_size: ${policy.img_patch_size} + focal_length: ${policy.focal_length} + + logger: ${logger} \ No newline at end of file diff --git a/configs/policy/README.md b/configs/policy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7af8f5e62fa39f47dd1d90bd0d13e2b44a90f562 --- /dev/null +++ b/configs/policy/README.md @@ -0,0 +1,3 @@ +# policy + +Configs here define frequently used values that are shared across multiple sub-modules. They usually cannot belongs to any specific sub-module. 
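For illustration, here is a minimal sketch of how these shared policy values are consumed (it mirrors `configs/policy/default.yaml` and the pipeline configs above; it is not an additional file in this diff):

```yaml
# configs/policy/default.yaml defines the shared values once, e.g.:
img_patch_size: 256
focal_length: 5000

# A sub-module config then interpolates them instead of hard-coding the numbers:
cfg:
  img_patch_size: ${policy.img_patch_size}
  focal_length: ${policy.focal_length}
  policy: ${policy}  # or forward the whole policy node
```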
\ No newline at end of file diff --git a/configs/policy/default.yaml b/configs/policy/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d92856e58e8a9da9231f0313a836f2f2776fdec3 --- /dev/null +++ b/configs/policy/default.yaml @@ -0,0 +1,5 @@ +img_patch_size: 256 +focal_length: 5000 +img_mean: [0.485, 0.456, 0.406] +img_std: [0.229, 0.224, 0.225] +bbox_shape: null \ No newline at end of file diff --git a/data_inputs/.gitignore b/data_inputs/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9964fcd73181748d80bc2783fdad0ed777484ada --- /dev/null +++ b/data_inputs/.gitignore @@ -0,0 +1,2 @@ +body_models +released_models \ No newline at end of file diff --git a/data_inputs/body_models.tar.gz b/data_inputs/body_models.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_inputs/description.md b/data_inputs/description.md new file mode 100644 index 0000000000000000000000000000000000000000..21c940ed0da5fc5d36743a44ca0cb168ad10a39e --- /dev/null +++ b/data_inputs/description.md @@ -0,0 +1,21 @@ +
+ HSMR +
+
+ Reconstructing Humans with a
Biomechanically Accurate Skeleton +
+
+Project Page +| +GitHub Repo +| +Paper +
+ +> Reconstructing Humans with a Biomechanically Accurate Skeleton
+> [Yan Xia](https://scholar.isshikih.top), +> [Xiaowei Zhou](https://xzhou.me), +> [Etienne Vouga](https://www.cs.utexas.edu/~evouga/), +> [Qixing Huang](https://www.cs.utexas.edu/~huangqx/), +> [Georgios Pavlakos](https://geopavlakos.github.io/)
+> *CVPR 2025* diff --git a/data_inputs/example_imgs/ballerina.png b/data_inputs/example_imgs/ballerina.png new file mode 100644 index 0000000000000000000000000000000000000000..c40509ecda6bd2e5280a2033a82d885276162dc1 Binary files /dev/null and b/data_inputs/example_imgs/ballerina.png differ diff --git a/data_inputs/example_imgs/exercise.jpeg b/data_inputs/example_imgs/exercise.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..c2c1e6403eeb7228ca8d605dad28225f1a98d039 Binary files /dev/null and b/data_inputs/example_imgs/exercise.jpeg differ diff --git a/data_inputs/example_imgs/jump_high.jpg b/data_inputs/example_imgs/jump_high.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fc75dcb8da75e19da1a7a3d794ae3cc957726ffb Binary files /dev/null and b/data_inputs/example_imgs/jump_high.jpg differ diff --git a/data_outputs/.gitignore b/data_outputs/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7152555990d5a8a9d489ea8dc9c69c528d034ad6 --- /dev/null +++ b/lib/__init__.py @@ -0,0 +1 @@ +from .version import __version__ \ No newline at end of file diff --git a/lib/body_models/abstract_skeletons.py b/lib/body_models/abstract_skeletons.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3a3614b53aa7f0687843d9d2367e437fb18442 --- /dev/null +++ b/lib/body_models/abstract_skeletons.py @@ -0,0 +1,123 @@ +# "Skeletons" here means virtual "bones" between joints. It is used to draw the skeleton on the image. + +class Skeleton(): + bones = [] + bone_colors = [] + chains = [] + parent = [] + + +class Skeleton_SMPL24(Skeleton): + # The joints definition are copied from + # [ROMP](https://github.com/Arthur151/ROMP/blob/4eebd3647f57d291d26423e51f0d514ff7197cb3/romp/lib/constants.py#L58). 
+ chains = [ + [0, 1, 4, 7, 10], # left leg + [0, 2, 5, 8, 11], # right leg + [0, 3, 6, 9, 12, 15], # spine & head + [12, 13, 16, 18, 20, 22], # left arm + [12, 14, 17, 19, 21, 23], # right arm + ] + bones = [ + [ 0, 1], [ 1, 4], [ 4, 7], [ 7, 10], # left leg + [ 0, 2], [ 2, 5], [ 5, 8], [ 8, 11], # right leg + [ 0, 3], [ 3, 6], [ 6, 9], [ 9, 12], [12, 15], # spine & head + [12, 13], [13, 16], [16, 18], [18, 20], [20, 22], # left arm + [12, 14], [14, 17], [17, 19], [19, 21], [21, 23], # right arm + ] + bone_colors = [ + [127, 0, 0], [148, 21, 21], [169, 41, 41], [191, 63, 63], # red + [ 0, 127, 0], [ 21, 148, 21], [ 41, 169, 41], [ 63, 191, 63], # green + [ 0, 0, 127], [ 15, 15, 143], [ 31, 31, 159], [ 47, 47, 175], [ 63, 63, 191], # blue + [ 0, 127, 127], [ 15, 143, 143], [ 31, 159, 159], [ 47, 175, 175], [ 63, 191, 191], # cyan + [127, 0, 127], [143, 15, 143], [159, 31, 159], [175, 47, 175], [191, 63, 191], # magenta + ] + parent = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 12, 12, 13, 14, 16, 17, 18, 19, 20, 21] + + +class Skeleton_SMPL22(Skeleton): + chains = [ + [0, 1, 4, 7, 10], # left leg + [0, 2, 5, 8, 11], # right leg + [0, 3, 6, 9, 12, 15], # spine & head + [12, 13, 16, 18, 20], # left arm + [12, 14, 17, 19, 21], # right arm + ] + bones = [ + [ 0, 1], [ 1, 4], [ 4, 7], [ 7, 10], # left leg + [ 0, 2], [ 2, 5], [ 5, 8], [ 8, 11], # right leg + [ 0, 3], [ 3, 6], [ 6, 9], [ 9, 12], [12, 15], # spine & head + [12, 13], [13, 16], [16, 18], [18, 20], # left arm + [12, 14], [14, 17], [17, 19], [19, 21], # right arm + ] + bone_colors = [ + [127, 0, 0], [148, 21, 21], [169, 41, 41], [191, 63, 63], # red + [ 0, 127, 0], [ 21, 148, 21], [ 41, 169, 41], [ 63, 191, 63], # green + [ 0, 0, 127], [ 15, 15, 143], [ 31, 31, 159], [ 47, 47, 175], [ 63, 63, 191], # blue + [ 0, 127, 127], [ 15, 143, 143], [ 31, 159, 159], [ 47, 175, 175], # cyan + [127, 0, 127], [143, 15, 143], [159, 31, 159], [175, 47, 175], # magenta + ] + parent = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 12, 12, 13, 14, 16, 17, 18, 19] + + +class Skeleton_SKEL24(Skeleton): + chains = [ + [ 0, 6, 7, 8, 9, 10], # left leg + [ 0, 1, 2, 3, 4, 5], # right leg + [ 0, 11, 12, 13], # spine & head + [12, 19, 20, 21, 22, 23], # left arm + [12, 14, 15, 16, 17, 18], # right arm + ] + bones = [ + [ 0, 6], [ 6, 7], [ 7, 8], [ 8, 9], [ 9, 10], # left leg + [ 0, 1], [ 1, 2], [ 2, 3], [ 3, 4], [ 4, 5], # right leg + [ 0, 11], [11, 12], [12, 13], # spine & head + [12, 19], [19, 20], [20, 21], [21, 22], [22, 23], # left arm + [12, 14], [14, 15], [15, 16], [16, 17], [17, 18], # right arm + ] + bone_colors = [ + [127, 0, 0], [148, 21, 21], [169, 41, 41], [191, 63, 63], [191, 63, 63], # red + [ 0, 127, 0], [ 21, 148, 21], [ 41, 169, 41], [ 63, 191, 63], [ 63, 191, 63], # green + [ 0, 0, 127], [ 31, 31, 159], [ 63, 63, 191], # blue + [ 0, 127, 127], [ 15, 143, 143], [ 31, 159, 159], [ 47, 175, 175], [ 63, 191, 191], # cyan + [127, 0, 127], [143, 15, 143], [159, 31, 159], [175, 47, 175], [191, 63, 191], # magenta + ] + parent = [-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 12, 19, 20, 21, 22, 12, 14, 15, 16, 17] + + +class Skeleton_OpenPose25(Skeleton): + ''' https://www.researchgate.net/figure/Twenty-five-keypoints-of-the-OpenPose-software-model_fig1_374116819 ''' + chain = [ + [ 8, 12, 13, 14, 19, 20], # left leg + [14, 21], # left heel + [ 8, 9, 10, 11, 22, 23], # right leg + [11, 24], # right heel + [ 8, 1, 0], # spine & head + [ 0, 16, 18], # left face + [ 0, 15, 17], # right face + [ 1, 5, 6, 7], # left arm + [ 1, 2, 3, 4], # right arm + ] + 
bones = [ + [ 8, 12], [12, 13], [13, 14], # left leg + [14, 19], [19, 20], [14, 21], # left foot + [ 8, 9], [ 9, 10], [10, 11], # right leg + [11, 22], [22, 23], [11, 24], # right foot + [ 8, 1], [ 1, 0], # spine & head + [ 0, 16], [16, 18], # left face + [ 0, 15], [15, 17], # right face + [ 1, 5], [ 5, 6], [ 6, 7], # left arm + [ 1, 2], [ 2, 3], [ 3, 4], # right arm + ] + bone_colors = [ + [ 95, 0, 255], [ 79, 0, 255], [ 83, 0, 255], # dark blue + [ 31, 0, 255], [ 15, 0, 255], [ 0, 0, 255], # dark blue + [127, 205, 255], [127, 205, 255], [ 95, 205, 255], # light blue + [ 63, 205, 255], [ 31, 205, 255], [ 0, 205, 255], # light blue + [255, 0, 0], [255, 0, 0], # red + [191, 63, 63], [191, 63, 191], # magenta + [255, 0, 127], [255, 0, 255], # purple + [127, 255, 0], [ 63, 255, 0], [ 0, 255, 0], # green + [255, 127, 0], [255, 191, 0], [255, 255, 0], # yellow + + ] + parent = [1, 8, 1, 2, 3, 1, 5, 6, -1, 8, 9, 10, 8, 12, 13, 0, 0, 15, 16, 14, 19, 14, 11, 22, 11] diff --git a/lib/body_models/common.py b/lib/body_models/common.py new file mode 100644 index 0000000000000000000000000000000000000000..e7b72a77400353c1c5210da4320023e79b06894e --- /dev/null +++ b/lib/body_models/common.py @@ -0,0 +1,69 @@ +from lib.kits.basic import * + +from smplx import SMPL + +from lib.platform import PM +from lib.body_models.skel_wrapper import SKELWrapper as SKEL +from lib.body_models.smpl_wrapper import SMPLWrapper + +def make_SMPL(gender='neutral', device='cuda:0'): + return SMPL( + gender = gender, + model_path = PM.inputs / 'body_models' / 'smpl', + ).to(device) + + +def make_SMPL_hmr2(gender='neutral', device='cuda:0'): + ''' SKEL doesn't have neutral model, so align with SKEL, using male. ''' + return SMPLWrapper( + gender = gender, + model_path = PM.inputs / 'body_models' / 'smpl', + num_body_joints = 23, + joint_regressor_extra = PM.inputs / 'body_models/SMPL_to_J19.pkl', + ).to(device) + + + +def make_SKEL(gender='male', device='cuda:0'): + ''' We don't have neutral model for SKEL, so use male for now. ''' + return make_SKEL_mix_joints(gender, device) + + +def make_SKEL_smpl_joints(gender='male', device='cuda:0'): + ''' We don't have neutral model for SKEL, so use male for now. ''' + return SKEL( + gender = gender, + model_path = PM.inputs / 'body_models' / 'skel', + joint_regressor_extra = PM.inputs / 'body_models' / 'SMPL_to_J19.pkl', + joint_regressor_custom = PM.inputs / 'body_models' / 'J_regressor_SMPL_MALE.pkl', + ).to(device) + + +def make_SKEL_skel_joints(gender='male', device='cuda:0'): + ''' We don't have neutral model for SKEL, so use male for now. ''' + return SKEL( + gender = gender, + model_path = PM.inputs / 'body_models' / 'skel', + joint_regressor_extra = PM.inputs / 'body_models' / 'SMPL_to_J19.pkl', + ).to(device) + + +def make_SKEL_mix_joints(gender='male', device='cuda:0'): + ''' We don't have neutral model for SKEL, so use male for now. 
''' + return SKEL( + gender = gender, + model_path = PM.inputs / 'body_models' / 'skel', + joint_regressor_extra = PM.inputs / 'body_models' / 'SMPL_to_J19.pkl', + joint_regressor_custom = PM.inputs / 'body_models' / 'J_regressor_SKEL_mix_MALE.pkl', + ).to(device) + + +def make_SMPLX_moyo(v_template_path:Union[str, Path], batch_size:int=1, device='cuda:0'): + from lib.body_models.moyo_smplx_wrapper import MoYoSMPLX + + return MoYoSMPLX( + model_path = PM.inputs / 'body_models' / 'smplx', + v_template_path = v_template_path, + batch_size = batch_size, + device = device, + ) \ No newline at end of file diff --git a/lib/body_models/moyo_smplx_wrapper.py b/lib/body_models/moyo_smplx_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..df391ce8da8838b60549d8990b583b37a8707824 --- /dev/null +++ b/lib/body_models/moyo_smplx_wrapper.py @@ -0,0 +1,77 @@ +from lib.kits.basic import * + +import smplx +from psbody.mesh import Mesh + +class MoYoSMPLX(smplx.SMPLX): + + def __init__( + self, + model_path : Union[str, Path], + v_template_path : Union[str, Path], + batch_size = 1, + n_betas = 10, + device = 'cpu' + ): + + if isinstance(v_template_path, Path): + v_template_path = str(v_template_path) + + # Load the `v_template`. + v_template_mesh = Mesh(filename=v_template_path) + v_template = to_tensor(v_template_mesh.v, device=device) + + self.n_betas = n_betas + + # Create the `body_model_params`. + body_model_params = { + 'model_path' : model_path, + 'gender' : 'neutral', + 'v_template' : v_template.float(), + 'batch_size' : batch_size, + 'create_global_orient' : True, + 'create_body_pose' : True, + 'create_betas' : True, + 'num_betas' : self.n_betas, # They actually don't use num_betas. + 'create_left_hand_pose' : True, + 'create_right_hand_pose' : True, + 'create_expression' : True, + 'create_jaw_pose' : True, + 'create_leye_pose' : True, + 'create_reye_pose' : True, + 'create_transl' : True, + 'use_pca' : False, + 'flat_hand_mean' : True, + 'dtype' : torch.float32, + } + + super().__init__(**body_model_params) + self = self.to(device) + + def forward(self, **kwargs): + ''' Only all parameters are passed, the batch_size will be flexible adjusted. ''' + assert 'global_orient' in kwargs, '`global_orient` is required for the forward pass.' + assert 'body_pose' in kwargs, '`body_pose` is required for the forward pass.' 
+ B = kwargs['global_orient'].shape[0] + body_pose = kwargs['body_pose'] + + if 'left_hand_pose' not in kwargs: + kwargs['left_hand_pose'] = body_pose.new_zeros((B, 45)) + get_logger().warning('`left_hand_pose` is not provided, but it\'s expected, set to zeros.') + if 'right_hand_pose' not in kwargs: + kwargs['right_hand_pose'] = body_pose.new_zeros((B, 45)) + get_logger().warning('`left_hand_pose` is not provided, but it\'s expected, set to zeros.') + if 'transl' not in kwargs: + kwargs['transl'] = body_pose.new_zeros((B, 3)) + if 'betas' not in kwargs: + kwargs['betas'] = body_pose.new_zeros((B, self.n_betas)) + if 'expression' not in kwargs: + kwargs['expression'] = body_pose.new_zeros((B, 10)) + if 'jaw_pose' not in kwargs: + kwargs['jaw_pose'] = body_pose.new_zeros((B, 3)) + if 'leye_pose' not in kwargs: + kwargs['leye_pose'] = body_pose.new_zeros((B, 3)) + if 'reye_pose' not in kwargs: + kwargs['reye_pose'] = body_pose.new_zeros((B, 3)) + + return super().forward(**kwargs) \ No newline at end of file diff --git a/lib/body_models/skel/__init__.py b/lib/body_models/skel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/body_models/skel/alignment/align_config.py b/lib/body_models/skel/alignment/align_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0a743fd556d952ba57ee3f73f6f047265ff6dc72 --- /dev/null +++ b/lib/body_models/skel/alignment/align_config.py @@ -0,0 +1,72 @@ +""" Optimization config file. In 'optim_steps' we define each optimization steps. +Each step inherit and overwrite the parameters of the previous step.""" + +config = { + 'keepalive_meshviewer': False, + 'optim_steps': + [ + { + 'description' : 'Adjust the root orientation and translation', + 'use_basic_loss': True, + 'lr': 1, + 'max_iter': 20, + 'num_steps': 10, + 'line_search_fn': 'strong_wolfe', #'strong_wolfe', + 'tolerance_change': 1e-7,# 1e-4, #0.01 + 'mode' : 'root_only', + + 'l_verts_loose': 300, + 'l_time_loss': 0,#5e2, + + 'l_joint': 0.0, + 'l_verts': 0, + 'l_scapula_loss': 0.0, + 'l_spine_loss': 0.0, + 'l_pose_loss': 0.0, + + + }, + # Adjust the upper limbs + { + 'description' : 'Adjust the upper limbs pose', + 'lr': 0.1, + 'max_iter': 20, + 'num_steps': 10, + 'tolerance_change': 1e-7, + 'mode' : 'fixed_upper_limbs', #'fixed_root', + + 'l_verts_loose': 600, + 'l_joint': 1e3, + 'l_time_loss': 0,# 5e2, + 'l_pose_loss': 1e-4, + }, + # Adjust the whole body + { + 'description' : 'Adjust the whole body pose with fixed root', + 'lr': 0.1, + 'max_iter': 20, + 'num_steps': 10, + 'tolerance_change': 1e-7, + 'mode' : 'fixed_root', #'fixed_root', + + 'l_verts_loose': 600, + 'l_joint': 1e3, + 'l_time_loss': 0, + 'l_pose_loss': 1e-4, + }, + # + { + 'description' : 'Free optimization', + 'lr': 0.1, + 'max_iter': 20, + 'num_steps': 10, + 'tolerance_change': 1e-7, + 'mode' : 'free', #'fixed_root', + + 'l_verts_loose': 600, + 'l_joint': 1e3, + 'l_time_loss':0, + 'l_pose_loss': 1e-4, + }, + ] +} \ No newline at end of file diff --git a/lib/body_models/skel/alignment/align_config_joint.py b/lib/body_models/skel/alignment/align_config_joint.py new file mode 100644 index 0000000000000000000000000000000000000000..99b0b55068e9c6f45fb403eed0b39f60c2e67c7b --- /dev/null +++ b/lib/body_models/skel/alignment/align_config_joint.py @@ -0,0 +1,72 @@ +""" Optimization config file. In 'optim_steps' we define each optimization steps. 
+Each step inherit and overwrite the parameters of the previous step.""" + +config = { + 'keepalive_meshviewer': False, + 'optim_steps': + [ + { + 'description' : 'Adjust the root orientation and translation', + 'use_basic_loss': True, + 'lr': 1, + 'max_iter': 20, + 'num_steps': 10, + 'line_search_fn': 'strong_wolfe', #'strong_wolfe', + 'tolerance_change': 1e-7,# 1e-4, #0.01 + 'mode' : 'root_only', + + 'l_verts_loose': 300, + 'l_time_loss': 0,#5e2, + + 'l_joint': 0.0, + 'l_verts': 0, + 'l_scapula_loss': 0.0, + 'l_spine_loss': 0.0, + 'l_pose_loss': 0.0, + + + }, + # Adjust the upper limbs + { + 'description' : 'Adjust the upper limbs pose', + 'lr': 0.1, + 'max_iter': 20, + 'num_steps': 10, + 'tolerance_change': 1e-7, + 'mode' : 'fixed_upper_limbs', #'fixed_root', + + 'l_verts_loose': 600, + 'l_joint': 2e3, + 'l_time_loss': 0,# 5e2, + 'l_pose_loss': 1e-3, + }, + # Adjust the whole body + { + 'description' : 'Adjust the whole body pose with fixed root', + 'lr': 0.1, + 'max_iter': 20, + 'num_steps': 10, + 'tolerance_change': 1e-7, + 'mode' : 'fixed_root', #'fixed_root', + + 'l_verts_loose': 600, + 'l_joint': 1e3, + 'l_time_loss': 0, + 'l_pose_loss': 1e-4, + }, + # + { + 'description' : 'Free optimization', + 'lr': 0.1, + 'max_iter': 20, + 'num_steps': 10, + 'tolerance_change': 1e-7, + 'mode' : 'free', #'fixed_root', + + 'l_verts_loose': 300, + 'l_joint': 2e3, + 'l_time_loss':0, + 'l_pose_loss': 1e-4, + }, + ] +} \ No newline at end of file diff --git a/lib/body_models/skel/alignment/aligner.py b/lib/body_models/skel/alignment/aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..e3a953538dcdb2df6792a32015102d314f45baf8 --- /dev/null +++ b/lib/body_models/skel/alignment/aligner.py @@ -0,0 +1,469 @@ + +""" +Copyright©2023 Max-Planck-Gesellschaft zur Förderung +der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +for Intelligent Systems. All rights reserved. + +Author: Soyong Shin, Marilyn Keller +See https://skel.is.tue.mpg.de/license.html for licensing and contact information. 
+""" + +import traceback +import math +import os +import pickle +import torch +import smplx +import omegaconf +import torch.nn.functional as F +from psbody.mesh import Mesh, MeshViewer, MeshViewers +from tqdm import trange +from pathlib import Path + +import lib.body_models.skel.config as cg +from lib.body_models.skel.skel_model import SKEL +from .losses import compute_anchor_pose, compute_anchor_trans, compute_pose_loss, compute_scapula_loss, compute_spine_loss, compute_time_loss, pretty_loss_print +from .utils import location_to_spheres, to_numpy, to_params, to_torch +from .align_config import config +from .align_config_joint import config as config_joint + +class SkelFitter(object): + + def __init__(self, gender, device, num_betas=10, export_meshes=False, joint_optim=False) -> None: + + self.smpl = smplx.create(cg.smpl_folder, model_type='smpl', gender=gender, num_betas=num_betas, batch_size=1, export_meshes=False).to(device) + self.skel = SKEL(gender).to(device) + self.gender = gender + self.device = device + self.num_betas = num_betas + # Instanciate masks used for the vertex to vertex fitting + fitting_mask_file = Path(__file__).parent / 'riggid_parts_mask.pkl' + fitting_indices = pickle.load(open(fitting_mask_file, 'rb')) + fitting_mask = torch.zeros(6890, dtype=torch.bool, device=self.device) + fitting_mask[fitting_indices] = 1 + self.fitting_mask = fitting_mask.reshape(1, -1, 1).to(self.device) # 1xVx1 to be applied to verts that are BxVx3 + + smpl_torso_joints = [0,3] + verts_mask = (self.smpl.lbs_weights[:,smpl_torso_joints]>0.5).sum(dim=-1)>0 + self.torso_verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) # Because verts are of shape BxVx3 + + self.export_meshes = export_meshes + + + # make the cfg being an object using omegaconf + if joint_optim: + self.cfg = omegaconf.OmegaConf.create(config_joint) + else: + self.cfg = omegaconf.OmegaConf.create(config) + + # Instanciate the mesh viewer to visualize the fitting + if('DISABLE_VIEWER' in os.environ): + self.mv = None + print("\n DISABLE_VIEWER flag is set, running in headless mode") + else: + self.mv = MeshViewers((1,2), keepalive=self.cfg.keepalive_meshviewer) + + + def run_fit(self, + trans_in, + betas_in, + poses_in, + batch_size=20, + skel_data_init=None, + force_recompute=False, + debug=False, + watch_frame=0, + freevert_mesh=None, + opt_sequence=False, + fix_poses=False, + variant_exp=''): + """Align SKEL to a SMPL sequence.""" + + self.nb_frames = poses_in.shape[0] + self.watch_frame = watch_frame + self.is_skel_data_init = skel_data_init is not None + self.force_recompute = force_recompute + + print('Fitting {} frames'.format(self.nb_frames)) + print('Watching frame: {}'.format(watch_frame)) + + # Initialize SKEL torch params + body_params = self._init_params(betas_in, poses_in, trans_in, skel_data_init, variant_exp) + + # We cut the whole sequence in batches for parallel optimization + if batch_size > self.nb_frames: + batch_size = self.nb_frames + print('Batch size is larger than the number of frames. 
Setting batch size to {}'.format(batch_size)) + + n_batch = math.ceil(self.nb_frames/batch_size) + pbar = trange(n_batch, desc='Running batch optimization') + + # Initialize the res dict to store the per frame result skel parameters + out_keys = ['poses', 'betas', 'trans'] + if self.export_meshes: + out_keys += ['skel_v', 'skin_v', 'smpl_v'] + res_dict = {key: [] for key in out_keys} + + res_dict['gender'] = self.gender + if self.export_meshes: + res_dict['skel_f'] = self.skel.skel_f.cpu().numpy().copy() + res_dict['skin_f'] = self.skel.skin_f.cpu().numpy().copy() + res_dict['smpl_f'] = self.smpl.faces + + # Iterate over the batches to fit the whole sequence + for i in pbar: + + if debug: + # Only run the first batch to test, ignore the rest + if i > 1: + continue + + # Get batch start and end indices + i_start = i * batch_size + i_end = min((i+1) * batch_size, self.nb_frames) + + # Fit the batch + betas, poses, trans, verts = self._fit_batch(body_params, i, i_start, i_end, enable_time=opt_sequence, fix_poses=fix_poses) + # if torch.isnan(betas).any() \ + # or torch.isnan(poses).any() \ + # or torch.isnan(trans).any(): + # print(f'Nan values detected.') + # raise ValueError('Nan values detected in the output.') + + # Store ethe results + res_dict['poses'].append(poses) + res_dict['betas'].append(betas) + res_dict['trans'].append(trans) + if self.export_meshes: + # Store the meshes vertices + skel_output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', skelmesh=True) + res_dict['skel_v'].append(skel_output.skel_verts) + res_dict['skin_v'].append(skel_output.skin_verts) + res_dict['smpl_v'].append(verts) + + if opt_sequence: + # Initialize the next frames with current frame + body_params['poses_skel'][i_end:] = poses[-1:].detach() + body_params['trans_skel'][i_end:] = trans[-1].detach() + body_params['betas_skel'][i_end:] = betas[-1:].detach() + + # Concatenate the batches and convert to numpy + for key, val in res_dict.items(): + if isinstance(val, list): + res_dict[key] = torch.cat(val, dim=0).detach().cpu().numpy() + + return res_dict + + def _init_params(self, betas_smpl, poses_smpl, trans_smpl, skel_data_init=None, variant_exp=''): + """ Return initial SKEL parameters from SMPL data dictionary and an optional SKEL data dictionary.""" + + if skel_data_init is None or self.force_recompute: + + poses_skel = torch.zeros((self.nb_frames, self.skel.num_q_params), device=self.device) + if variant_exp == '' or variant_exp == '_official_old': + poses_skel[:, :3] = poses_smpl[:, :3] # Global orient are similar between SMPL and SKEL, so init with SMPL angles + elif variant_exp == '_official_fix': + # https://github.com/MarilynKeller/SKEL/commit/d1f6ff62235c142ba010158e00e21fd4fe25807f#diff-09188717a56a42e9589e9bd289f9ddb4fb53160e03c81a7ced70b3a84c1d9d0bR157 + pass + elif variant_exp == '_my_fix': + gt_orient_aa = poses_smpl[:, :3] + # IMPORTANT: The alignment comes from `exp/inspect_skel/archive/orientation.py`. 
+ from lib.utils.geometry.rotation import axis_angle_to_matrix, matrix_to_euler_angles + gt_orient_mat = axis_angle_to_matrix(gt_orient_aa) + gt_orient_ea = matrix_to_euler_angles(gt_orient_mat, 'YXZ') + flip = torch.tensor([-1, 1, 1], device=self.device) + poses_skel[:, :3] = gt_orient_ea[:, [2, 1, 0]] * flip + else: + raise ValueError(f'Unknown variant_exp {variant_exp}') + + betas_skel = torch.zeros((self.nb_frames, 10), device=self.device) + betas_skel[:] = betas_smpl[..., :10] + + trans_skel = trans_smpl # Translation is similar between SMPL and SKEL, so init with SMPL translation + + else: + # Load from previous alignment + betas_skel = to_torch(skel_data_init['betas'], self.device) + poses_skel = to_torch(skel_data_init['poses'], self.device) + trans_skel = to_torch(skel_data_init['trans'], self.device) + + # Make a dictionary out of the necessary body parameters + body_params = { + 'betas_skel': betas_skel, + 'poses_skel': poses_skel, + 'trans_skel': trans_skel, + 'betas_smpl': betas_smpl, + 'poses_smpl': poses_smpl, + 'trans_smpl': trans_smpl + } + + return body_params + + + def _fit_batch(self, body_params, i, i_start, i_end, enable_time=False, fix_poses=False): + """ Create parameters for the batch and run the optimization.""" + + # Sample a batch ver + body_params = { key: val[i_start:i_end] for key, val in body_params.items()} + + # SMPL params + betas_smpl = body_params['betas_smpl'] + poses_smpl = body_params['poses_smpl'] + trans_smpl = body_params['trans_smpl'] + + # SKEL params + betas = to_params(body_params['betas_skel'], device=self.device) + poses = to_params(body_params['poses_skel'], device=self.device) + trans = to_params(body_params['trans_skel'], device=self.device) + + if 'verts' in body_params: + verts = body_params['verts'] + else: + # Run a SMPL forward pass to get the SMPL body vertices + smpl_output = self.smpl(betas=betas_smpl, body_pose=poses_smpl[:,3:], transl=trans_smpl, global_orient=poses_smpl[:,:3]) + verts = smpl_output.vertices + + # Optimize + config = self.cfg.optim_steps + current_cfg = config[0] + + # from lib.kits.debug import set_trace + # set_trace() + + try: + if fix_poses: + # for ci, cfg in enumerate(config[1:]): + for ci, cfg in enumerate([config[-1]]): # To debug, only run the last step + current_cfg.update(cfg) + print(f'Step {ci+1}: {current_cfg.description}') + self._optim([trans,betas], poses, betas, trans, verts, current_cfg, enable_time) + else: + if not enable_time or not self.is_skel_data_init: + # Optimize the global rotation and translation for the initial fitting + print(f'Step 0: {current_cfg.description}') + self._optim([trans,poses], poses, betas, trans, verts, current_cfg, enable_time) + + for ci, cfg in enumerate(config[1:]): + # for ci, cfg in enumerate([config[-1]]): # To debug, only run the last step + current_cfg.update(cfg) + print(f'Step {ci+1}: {current_cfg.description}') + self._optim([poses], poses, betas, trans, verts, current_cfg, enable_time) + + + # # Refine by optimizing the whole body + # cfg.update(self.cfg_optim[]) + # cfg.update({'mode' : 'free', 'tolerance_change': 0.0001, 'l_joint': 0.2e4}) + # self._optim([trans, poses], poses, betas, trans, verts, cfg) + except Exception as e: + print(e) + traceback.print_exc() + # from lib.kits.debug import set_trace + # set_trace() + + return betas, poses, trans, verts + + def _optim(self, + params, + poses, + betas, + trans, + verts, + cfg, + enable_time=False, + ): + + # regress anatomical joints from SMPL's vertices + anat_joints = torch.einsum('bik,ji->bjk', 
[verts, self.skel.J_regressor_osim]) + dJ=torch.zeros((poses.shape[0], 24, 3), device=betas.device) + + # Create the optimizer + optimizer = torch.optim.LBFGS(params, + lr=cfg.lr, + max_iter=cfg.max_iter, + line_search_fn=cfg.line_search_fn, + tolerance_change=cfg.tolerance_change) + + poses_init = poses.detach().clone() + trans_init = trans.detach().clone() + + def closure(): + optimizer.zero_grad() + + # fi = self.watch_frame #frame of the batch to display + # output = self.skel.forward(poses=poses[fi:fi+1], + # betas=betas[fi:fi+1], + # trans=trans[fi:fi+1], + # poses_type='skel', + # dJ=dJ[fi:fi+1], + # skelmesh=True) + + # self._fstep_plot(output, cfg, verts[fi:fi+1], anat_joints[fi:fi+1], ) + + loss_dict = self._fitting_loss(poses, + poses_init, + betas, + trans, + trans_init, + dJ, + anat_joints, + verts, + cfg, + enable_time) + + # print(pretty_loss_print(loss_dict)) + + loss = sum(loss_dict.values()) + loss.backward() + + return loss + + for step_i in range(cfg.num_steps): + loss = optimizer.step(closure).item() + + def _get_masks(self, cfg): + pose_mask = torch.ones((self.skel.num_q_params)).to(self.device).unsqueeze(0) + verts_mask = torch.ones_like(self.fitting_mask) + joint_mask = torch.ones((self.skel.num_joints, 3)).to(self.device).unsqueeze(0).bool() + + # Mask vertices + if cfg.mode=='root_only': + # Only optimize the global rotation of the body, i.e. the first 3 angles of the pose + pose_mask[:] = 0 # Only optimize for the global rotation + pose_mask[:,:3] = 1 + # Only fit the thorax vertices to recover the proper body orientation and translation + verts_mask = self.torso_verts_mask + + elif cfg.mode=='fixed_upper_limbs': + upper_limbs_joints = [0,1,2,3,6,9,12,15,17] + verts_mask = (self.smpl.lbs_weights[:,upper_limbs_joints]>0.5).sum(dim=-1)>0 + verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) + + joint_mask[:, [3,4,5,8,9,10,18,23], :] = 0 # Do not try to match the joints of the upper limbs + + pose_mask[:] = 1 + pose_mask[:,:3] = 0 # Block the global rotation + pose_mask[:,19] = 0 # block the lumbar twist + # pose_mask[:, 36:39] = 0 + # pose_mask[:, 43:46] = 0 + # pose_mask[:, 62:65] = 0 + # pose_mask[:, 62:65] = 0 + + elif cfg.mode=='fixed_root': + pose_mask[:] = 1 + pose_mask[:,:3] = 0 # Block the global rotation + # pose_mask[:,19] = 0 # block the lumbar twist + + # The orientation of the upper limbs is often wrong in SMPL so ignore these vertices for the finale step + upper_limbs_joints = [1,2,16,17] + verts_mask = (self.smpl.lbs_weights[:,upper_limbs_joints]>0.5).sum(dim=-1)>0 + verts_mask = torch.logical_not(verts_mask) + verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) + + elif cfg.mode=='free': + verts_mask = torch.ones_like(self.fitting_mask ) + + joint_mask[:]=0 + joint_mask[:, [19,14], :] = 1 # Only fir the scapula join to avoid collapsing shoulders + + else: + raise ValueError(f'Unknown mode {cfg.mode}') + + return pose_mask, verts_mask, joint_mask + + def _fitting_loss(self, + poses, + poses_init, + betas, + trans, + trans_init, + dJ, + anat_joints, + verts, + cfg, + enable_time=False): + + loss_dict = {} + + + pose_mask, verts_mask, joint_mask = self._get_masks(cfg) + poses = poses * pose_mask + poses_init * (1-pose_mask) + + # Mask joints to not optimize before computing the losses + + output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', dJ=dJ, skelmesh=False) + + # Fit the SMPL vertices + # We know the skinning of the forearm and the neck are not perfect, + # so we create a mask of the SMPL vertices that are important 
to fit, like the hands and the head + loss_dict['verts_loss_loose'] = cfg.l_verts_loose * (verts_mask * (output.skin_verts - verts)**2).sum() / (((verts_mask).sum()*self.nb_frames)) + + # Fit the regressed joints, this avoids collapsing shoulders + # loss_dict['joint_loss'] = cfg.l_joint * F.mse_loss(output.joints, anat_joints) + loss_dict['joint_loss'] = cfg.l_joint * (joint_mask * (output.joints - anat_joints)**2).mean() + + # Time consistancy + if poses.shape[0] > 1 and enable_time: + # This avoids unstable hips orientationZ + loss_dict['time_loss'] = cfg.l_time_loss * F.mse_loss(poses[1:], poses[:-1]) + + loss_dict['pose_loss'] = cfg.l_pose_loss * compute_pose_loss(poses, poses_init) + + if cfg.use_basic_loss is False: + # These losses can be used to regularize the optimization but are not always necessary + loss_dict['anch_rot'] = cfg.l_anch_pose * compute_anchor_pose(poses, poses_init) + loss_dict['anch_trans'] = cfg.l_anch_trans * compute_anchor_trans(trans, trans_init) + + loss_dict['verts_loss'] = cfg.l_verts * (verts_mask * self.fitting_mask * (output.skin_verts - verts)**2).sum() / (self.fitting_mask*verts_mask).sum() + + # Regularize the pose + loss_dict['scapula_loss'] = cfg.l_scapula_loss * compute_scapula_loss(poses) + loss_dict['spine_loss'] = cfg.l_spine_loss * compute_spine_loss(poses) + + # Adjust the losses of all the pose regularizations sub losses with the pose_reg_factor value + for key in ['scapula_loss', 'spine_loss', 'pose_loss']: + loss_dict[key] = cfg.pose_reg_factor * loss_dict[key] + + return loss_dict + + def _fstep_plot(self, output, cfg, verts, anat_joints): + "Function to plot each step" + + if('DISABLE_VIEWER' in os.environ): + return + + pose_mask, verts_mask, joint_mask = self._get_masks(cfg) + + skin_err_value = ((output.skin_verts[0] - verts[0])**2).sum(dim=-1).sqrt() + skin_err_value = skin_err_value / 0.05 + skin_err_value = to_numpy(skin_err_value) + + skin_mesh = Mesh(v=to_numpy(output.skin_verts[0]), f=[], vc='white') + skel_mesh = Mesh(v=to_numpy(output.skel_verts[0]), f=self.skel.skel_f.cpu().numpy(), vc='white') + + # Display vertex distance on SMPL + smpl_verts = to_numpy(verts[0]) + smpl_mesh = Mesh(v=smpl_verts, f=self.smpl.faces) + smpl_mesh.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False) + + smpl_mesh_masked = Mesh(v=smpl_verts[to_numpy(verts_mask[0,:,0])], f=[], vc='green') + smpl_mesh_pc = Mesh(v=smpl_verts, f=[], vc='green') + + skin_mesh_err = Mesh(v=to_numpy(output.skin_verts[0]), f=self.skel.skin_f.cpu().numpy(), vc='white') + skin_mesh_err.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False) + # List the meshes to display + meshes_left = [skin_mesh_err, smpl_mesh_pc] + meshes_right = [smpl_mesh_masked, skin_mesh, skel_mesh] + + if cfg.l_joint > 0: + # Plot the joints + meshes_right += location_to_spheres(to_numpy(output.joints[joint_mask[:,:,0]]), color=(1,0,0), radius=0.02) + meshes_right += location_to_spheres(to_numpy(anat_joints[joint_mask[:,:,0]]), color=(0,1,0), radius=0.02) \ + + + self.mv[0][0].set_dynamic_meshes(meshes_left) + self.mv[0][1].set_dynamic_meshes(meshes_right) + + # print(poses[frame_to_watch, :3]) + # print(trans[frame_to_watch]) + # print(betas[frame_to_watch, :3]) + # mv.get_keypress() diff --git a/lib/body_models/skel/alignment/losses.py b/lib/body_models/skel/alignment/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..51238fe784ff0c65ecd570cea76b6e11c338158c --- /dev/null +++ b/lib/body_models/skel/alignment/losses.py @@ -0,0 +1,47 @@ 
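+# Loss terms used by SkelFitter (alignment/aligner.py): L2 regularizers on the scapula,
+# spine and body pose (global rotation excluded), anchor terms keeping the global
+# orientation and translation close to their initial values, a temporal smoothness term
+# on pose deltas, and a helper that pretty-prints the loss dictionary.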
+import torch + +def compute_scapula_loss(poses): + + scapula_indices = [26, 27, 28, 36, 37, 38] + + scapula_poses = poses[:, scapula_indices] + scapula_loss = torch.linalg.norm(scapula_poses, ord=2) + return scapula_loss + +def compute_spine_loss(poses): + + spine_indices = range(17, 25) + + spine_poses = poses[:, spine_indices] + spine_loss = torch.linalg.norm(spine_poses, ord=2) + return spine_loss + +def compute_pose_loss(poses, pose_init): + + pose_loss = torch.linalg.norm(poses[:, 3:], ord=2) # The global rotation should not be constrained + return pose_loss + +def compute_anchor_pose(poses, pose_init): + + pose_loss = torch.nn.functional.mse_loss(poses[:, :3], pose_init[:, :3]) + return pose_loss + +def compute_anchor_trans(trans, trans_init): + + trans_loss = torch.nn.functional.mse_loss(trans, trans_init) + return trans_loss + +def compute_time_loss(poses): + + pose_delta = poses[1:] - poses[:-1] + time_loss = torch.linalg.norm(pose_delta, ord=2) + return time_loss + +def pretty_loss_print(loss_dict): + # Pretty print the loss on the form loss val | loss1 val1 | loss2 val2 + # Start with the total loss + loss = sum(loss_dict.values()) + pretty_loss = f'{loss:.4f}' + for key, val in loss_dict.items(): + pretty_loss += f' | {key} {val:.4f}' + return pretty_loss diff --git a/lib/body_models/skel/alignment/riggid_parts_mask.pkl b/lib/body_models/skel/alignment/riggid_parts_mask.pkl new file mode 100644 index 0000000000000000000000000000000000000000..53c9b8baf1992053946e27b33d6f677b8d52f68c Binary files /dev/null and b/lib/body_models/skel/alignment/riggid_parts_mask.pkl differ diff --git a/lib/body_models/skel/alignment/utils.py b/lib/body_models/skel/alignment/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1353100ad4f2bdf55956d5f1cbfd6ae17002406 --- /dev/null +++ b/lib/body_models/skel/alignment/utils.py @@ -0,0 +1,119 @@ + +import os +import pickle +import torch +import numpy as np +from psbody.mesh.sphere import Sphere + +# to_params = lambda x: torch.from_numpy(x).float().to(self.device).requires_grad_(True) +# to_torch = lambda x: torch.from_numpy(x).float().to(self.device) + +def to_params(x, device): + return x.to(device).requires_grad_(True) + +def to_torch(x, device): + return torch.from_numpy(x).float().to(device) + +def to_numpy(x): + return x.detach().cpu().numpy() + +def load_smpl_seq(smpl_seq_path, gender=None, straighten_hands=False): + + if not os.path.exists(smpl_seq_path): + raise Exception('Path does not exist: {}'.format(smpl_seq_path)) + + if smpl_seq_path.endswith('.pkl'): + data_dict = pickle.load(open(smpl_seq_path, 'rb')) + + elif smpl_seq_path.endswith('.npz'): + data_dict = np.load(smpl_seq_path, allow_pickle=True) + + if data_dict.files == ['pred_smpl_parms', 'verts', 'pred_cam_t']: + data_dict = data_dict['pred_smpl_parms'].item()# ['global_orient', 'body_pose', 'body_pose_axis_angle', 'global_orient_axis_angle', 'betas'] + else: + data_dict = {key: data_dict[key] for key in data_dict.keys()} # convert to python dict + else: + raise Exception('Unknown file format: {}. 
Supported formats are .pkl and .npz'.format(smpl_seq_path)) + + # Instanciate a dictionary with the keys expected by the fitter + data_fixed = {} + + # Get gender + if 'gender' not in data_dict: + assert gender is not None, f"The provided SMPL data dictionary does not contain gender, you need to pass it in command line" + data_fixed['gender'] = gender + elif not isinstance(data_dict['gender'], str): + # In some npz, the gender type happens to be: array('male', dtype=' None: + super().__init__() + pass + + def q_to_translation(self, q, **kwargs): + return torch.zeros(q.shape[0], 3).to(q.device) + + +class CustomJoint(OsimJoint): + + def __init__(self, axis, axis_flip) -> None: + super().__init__() + self.register_buffer('axis', torch.FloatTensor(axis)) + self.register_buffer('axis_flip', torch.FloatTensor(axis_flip)) + self.register_buffer('nb_dof', torch.tensor(len(axis))) + + def q_to_rot(self, q, **kwargs): + + ident = torch.eye(3, dtype=q.dtype).to(q.device) + + Rp = ident.unsqueeze(0).expand(q.shape[0],3,3) # torch.eye(q.shape[0], 3, 3) + for i in range(self.nb_dof): + axis = self.axis[i].to(q.device) + angle_axis = q[:, i:i+1] * self.axis_flip[i].to(q.device) * axis + Rp_i = axis_angle_to_matrix(angle_axis) + Rp = torch.matmul(Rp_i, Rp) + return Rp + + + +class CustomJoint1D(OsimJoint): + + def __init__(self, axis, axis_flip) -> None: + super().__init__() + self.axis = torch.FloatTensor(axis) + self.axis = self.axis / torch.linalg.norm(self.axis) + self.axis_flip = torch.FloatTensor(axis_flip) + self.nb_dof = 1 + + def q_to_rot(self, q, **kwargs): + axis = self.axis.to(q.device) + angle_axis = q[:, 0:1] * self.axis_flip.to(q.device) * axis + Rp_i = axis_angle_to_matrix(angle_axis) + return Rp_i + + +class WalkerKnee(OsimJoint): + + def __init__(self) -> None: + super().__init__() + self.register_buffer('nb_dof', torch.tensor(1)) + # self.nb_dof = 1 + + def q_to_rot(self, q, **kwargs): + # Todo : for now implement a basic knee + theta_i = torch.zeros(q.shape[0], 3).to(q.device) + theta_i[:, 2] = -q[:, 0] + Rp_i = axis_angle_to_matrix(theta_i) + return Rp_i + +class PinJoint(OsimJoint): + + def __init__(self, parent_frame_ori) -> None: + super().__init__() + self.register_buffer('parent_frame_ori', torch.FloatTensor(parent_frame_ori)) + self.register_buffer('nb_dof', torch.tensor(1)) + + + def q_to_rot(self, q, **kwargs): + + talus_orient_torch = self.parent_frame_ori.to(q.device) + Ra_i = euler_angles_to_matrix(talus_orient_torch, 'XYZ') + + z_axis = torch.FloatTensor([0,0,1]).to(q.device) + axis = torch.matmul(Ra_i, z_axis).to(q.device) + + axis_angle = q[:, 0:1] * axis + Rp_i = axis_angle_to_matrix(axis_angle) + + return Rp_i + + +class ConstantCurvatureJoint(CustomJoint): + + def __init__(self, **kwargs ) -> None: + super().__init__( **kwargs) + + + +class EllipsoidJoint(CustomJoint): + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) diff --git a/lib/body_models/skel/skel_model.py b/lib/body_models/skel/skel_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d6cd0695b201be2dd909279050e63d4fd2075c4f --- /dev/null +++ b/lib/body_models/skel/skel_model.py @@ -0,0 +1,675 @@ + +""" +Copyright©2024 Max-Planck-Gesellschaft zur Förderung +der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +for Intelligent Systems. All rights reserved. + +Author: Marilyn Keller +See https://skel.is.tue.mpg.de/license.html for licensing and contact information. 
+""" + +import os +import torch.nn as nn +import torch +import numpy as np +import pickle as pkl +from typing import NewType, Optional + +from lib.body_models.skel.joints_def import curve_torch_3d, left_scapula, right_scapula +from lib.body_models.skel.osim_rot import ConstantCurvatureJoint, CustomJoint, EllipsoidJoint, PinJoint, WalkerKnee +from lib.body_models.skel.utils import build_homog_matrix, rotation_matrix_from_vectors, sparce_coo_matrix2tensor, with_zeros, matmul_chain +from dataclasses import dataclass, fields + +from lib.body_models.skel.kin_skel import scaling_keypoints, pose_param_names, smpl_joint_corresp +import lib.body_models.skel.config as cg + +Tensor = NewType('Tensor', torch.Tensor) + +@dataclass +class ModelOutput: + vertices: Optional[Tensor] = None + joints: Optional[Tensor] = None + full_pose: Optional[Tensor] = None + global_orient: Optional[Tensor] = None + transl: Optional[Tensor] = None + v_shaped: Optional[Tensor] = None + + def __getitem__(self, key): + return getattr(self, key) + + def get(self, key, default=None): + return getattr(self, key, default) + + def __iter__(self): + return self.keys() + + def keys(self): + keys = [t.name for t in fields(self)] + return iter(keys) + + def values(self): + values = [getattr(self, t.name) for t in fields(self)] + return iter(values) + + def items(self): + data = [(t.name, getattr(self, t.name)) for t in fields(self)] + return iter(data) + +@dataclass +class SKELOutput(ModelOutput): + betas: Optional[Tensor] = None + body_pose: Optional[Tensor] = None + skin_verts: Optional[Tensor] = None + skel_verts: Optional[Tensor] = None + joints: Optional[Tensor] = None + joints_ori: Optional[Tensor] = None + betas: Optional[Tensor] = None + poses: Optional[Tensor] = None + trans : Optional[Tensor] = None + pose_offsets : Optional[Tensor] = None + joints_tpose : Optional[Tensor] = None + v_skin_shaped : Optional[Tensor] = None + + +class SKEL(nn.Module): + + num_betas = 10 + + def __init__(self, gender, model_path=None, custom_joint_reg_path=None, **kwargs): + super(SKEL, self).__init__() + + if gender not in ['male', 'female']: + raise RuntimeError(f'Invalid Gender, got {gender}') + + self.gender = gender + + if model_path is None: + # skel_file = f"/Users/mkeller2/Data/skel_models_v1.0/skel_{gender}.pkl" + skel_file = os.path.join(cg.skel_folder, f"skel_{gender}.pkl") + else: + skel_file = os.path.join(model_path, f"skel_{gender}.pkl") + assert os.path.exists(skel_file), f"Skel model file {skel_file} does not exist" + + skel_data = pkl.load(open(skel_file, 'rb')) + + # Check that the version of the skel model is compatible with this loader + assert 'version' in skel_data, f"Expected version 1.1.1 of the SKEL picke. Please download the latest skel pkl versions from https://skel.is.tue.mpg.de/download.html" + version = skel_data['version'] + assert version == '1.1.1', f"Expected version 1.1.1, got {version}. 
Please download the latest skel pkl versions from https://skel.is.tue.mpg.de/download.html" + + self.num_betas = 10 + self.num_q_params = 46 + self.bone_names = skel_data['bone_names'] + self.num_joints = skel_data['J_regressor_osim'].shape[0] + self.num_joints_smpl = skel_data['J_regressor'].shape[0] + + self.joints_name = skel_data['joints_name'] + self.pose_params_name = skel_data['pose_params_name'] + + # register the template meshes + self.register_buffer('skin_template_v', torch.FloatTensor(skel_data['skin_template_v'])) + self.register_buffer('skin_f', torch.LongTensor(skel_data['skin_template_f'])) + + self.register_buffer('skel_template_v', torch.FloatTensor(skel_data['skel_template_v'])) + self.register_buffer('skel_f', torch.LongTensor(skel_data['skel_template_f'])) + + # Shape corrective blend shapes + self.register_buffer('shapedirs', torch.FloatTensor(np.array(skel_data['shapedirs'][:,:,:self.num_betas]))) + self.register_buffer('posedirs', torch.FloatTensor(np.array(skel_data['posedirs']))) + + # Model sparse joints regressor, regresses joints location from a mesh + self.register_buffer('J_regressor', sparce_coo_matrix2tensor(skel_data['J_regressor'])) + + # Regress the anatomical joint location with a regressor learned from BioAmass + if custom_joint_reg_path is not None: + J_regressor_skel = pkl.load(open(custom_joint_reg_path, 'rb')) + if 'scipy.sparse' in str(type(J_regressor_skel)): + J_regressor_skel = J_regressor_skel.todense() + self.register_buffer('J_regressor_osim', torch.FloatTensor(J_regressor_skel)) + print('WARNING: Using custom joint regressor') + else: + self.register_buffer('J_regressor_osim', sparce_coo_matrix2tensor(skel_data['J_regressor_osim'], make_dense=True)) + + self.register_buffer('per_joint_rot', torch.FloatTensor(skel_data['per_joint_rot'])) + + # Skin model skinning weights + self.register_buffer('skin_weights', sparce_coo_matrix2tensor(skel_data['skin_weights'])) + + # Skeleton model skinning weights + self.register_buffer('skel_weights', sparce_coo_matrix2tensor(skel_data['skel_weights'])) + self.register_buffer('skel_weights_rigid', sparce_coo_matrix2tensor(skel_data['skel_weights_rigid'])) + + # Kinematic tree of the model + self.register_buffer('kintree_table', torch.from_numpy(skel_data['osim_kintree_table'].astype(np.int64))) + self.register_buffer('parameter_mapping', torch.from_numpy(skel_data['parameter_mapping'].astype(np.int64))) + + # transformation from osim can pose to T pose + self.register_buffer('tpose_transfo', torch.FloatTensor(skel_data['tpose_transfo'])) + + # transformation from osim can pose to A pose + self.register_buffer('apose_transfo', torch.FloatTensor(skel_data['apose_transfo'])) + self.register_buffer('apose_rel_transfo', torch.FloatTensor(skel_data['apose_rel_transfo'])) + + # Indices of bones which orientation should not vary with beta in T pose: + joint_idx_fixed_beta = [0, 5, 10, 13, 18, 23] + self.register_buffer('joint_idx_fixed_beta', torch.IntTensor(joint_idx_fixed_beta)) + + id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])} + self.register_buffer('parent', torch.LongTensor( + [id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])])) + + + # child array + # TODO create this array in the SKEL creator + child_array = [] + Nj = self.num_joints + for i in range(0, Nj): + try: + j_array = torch.where(self.kintree_table[0] == i)[0] # candidate child lines + if len(j_array) == 0: + child_index = 0 + else: + + j = j_array[0] + if 
j>=len(self.kintree_table[1]): + child_index = 0 + else: + child_index = self.kintree_table[1,j].item() + child_array.append(child_index) + except: + import ipdb; ipdb.set_trace() + + # print(f"child_array: ") + # [print(i,child_array[i]) for i in range(0, Nj)] + self.register_buffer('child', torch.LongTensor(child_array)) + + # Instantiate joints + self.joints_dict = nn.ModuleList([ + CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 0 pelvis + CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 1 femur_r + WalkerKnee(), # 2 tibia_r + PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 3 talus_r Field taken from .osim Joint-> frames -> PhysicalOffsetFrame -> orientation + PinJoint(parent_frame_ori = [-1.76818999, 0.906223, 1.8196000]), # 4 calcn_r + PinJoint(parent_frame_ori = [-3.141589999, 0.6199010, 0]), # 5 toes_r + CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, -1, -1]), # 6 femur_l + WalkerKnee(), # 7 tibia_l + PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 8 talus_l + PinJoint(parent_frame_ori = [1.768189999 ,-0.906223, 1.8196000]), # 9 calcn_l + PinJoint(parent_frame_ori = [-3.141589999, -0.6199010, 0]), # 10 toes_l + ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 11 lumbar + ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 12 thorax + ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 13 head + EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, -1, -1]), # 14 scapula_r + CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 15 humerus_r + CustomJoint(axis=[[0.0494, 0.0366, 0.99810825]], axis_flip=[[1]]), # 16 ulna_r + CustomJoint(axis=[[-0.01716099, 0.99266564, -0.11966796]], axis_flip=[[1]]), # 17 radius_r + CustomJoint(axis=[[1,0,0], [0,0,-1]], axis_flip=[1, 1]), # 18 hand_r + EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, 1, 1]), # 19 scapula_l + CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 20 humerus_l + CustomJoint(axis=[[-0.0494, -0.0366, 0.99810825]], axis_flip=[[1]]), # 21 ulna_l + CustomJoint(axis=[[0.01716099, -0.99266564, -0.11966796]], axis_flip=[[1]]), # 22 radius_l + CustomJoint(axis=[[-1,0,0], [0,0,-1]], axis_flip=[1, 1]), # 23 hand_l + ]) + + def pose_params_to_rot(self, osim_poses): + """ Transform the pose parameters to 3x3 rotation matrices + Each parameter is mapped to a joint as described in joint_dict. + The specific joint object is then used to compute the rotation matrix. 
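+ params
+ osim_poses : B x 46 tensor of SKEL pose parameters (q)
+ return
+ Rp : B x Nj x 3 x 3 rotation matrix for each joint
+ tp : B x Nj x 3 joint translations (returned as zeros here)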
+ """ + + B = osim_poses.shape[0] + Nj = self.num_joints + + ident = torch.eye(3, dtype=osim_poses.dtype).to(osim_poses.device) + Rp = ident.unsqueeze(0).unsqueeze(0).repeat(B, Nj,1,1) + tp = torch.zeros(B, Nj, 3).to(osim_poses.device) + start_index = 0 + for i in range(0, Nj): + joint_object = self.joints_dict[i] + end_index = start_index + joint_object.nb_dof + Rp[:, i] = joint_object.q_to_rot(osim_poses[:, start_index:end_index]) + start_index = end_index + return Rp, tp + + + def params_name_to_index(self, param_name): + + assert param_name in pose_param_names + param_index = pose_param_names.index(param_name) + return param_index + + + def forward(self, poses, betas, trans, poses_type='skel', skelmesh=True, dJ=None, pose_dep_bs=True): + """ + params + poses : B x 46 tensor of pose parameters + betas : B x 10 tensor of shape parameters, same as SMPL + trans : B x 3 tensor of translation + poses_type : str, 'skel', should not be changed + skelmesh : bool, if True, returns the skeleton vertices. The skeleton mesh is heavy so to fit on GPU memory, set to False when not needed. + dJ : B x 24 x 3 tensor of the offset of the joints location from the anatomical regressor. If None, the offset is set to 0. + pose_dep_bs : bool, if True (default), applies the pose dependant blend shapes. If False, the pose dependant blend shapes are not applied. + + return SKELOutput class with the following fields: + betas : Bx10 tensor of shape parameters + poses : Bx46 tensor of pose parameters + skin_verts : Bx6890x3 tensor of skin vertices + skel_verts : tensor of skeleton vertices + joints : Bx24x3 tensor of joints location + joints_ori : Bx24x3x3 tensor of joints orientation + trans : Bx3 pose dependant blend shapes offsets + pose_offsets : Bx6080x3 pose dependant blend shapes offsets + joints_tpose : Bx24x3 3D joints location in T pose + + In this function we use the following conventions: + B : batch size + Ns : skin vertices + Nk : skeleton vertices + """ + + Ns = self.skin_template_v.shape[0] # nb skin vertices + Nk = self.skel_template_v.shape[0] # nb skeleton vertices + Nj = self.num_joints + B = poses.shape[0] + device = poses.device + + # Check the shapes of the inputs + assert len(betas.shape) == 2, f"Betas should be of shape (B, {self.num_betas}), but got {betas.shape}" + assert poses.shape[0] == betas.shape[0], f"Expected poses and betas to have the same batch size, but got {poses.shape[0]} and {betas.shape[0]}" + assert poses.shape[0] == trans.shape[0], f"Expected poses and betas to have the same batch size, but got {poses.shape[0]} and {trans.shape[0]}" + + if dJ is not None: + assert len(dJ.shape) == 3, f"Expected dJ to have shape (B, {Nj}, 3), but got {dJ.shape}" + assert dJ is None or dJ.shape[0] == B, f"Expected dJ to have the same batch size as poses, but got {dJ.shape[0]} and {poses.shape[0]}" + assert dJ.shape[1] == Nj, f"Expected dJ to have the same number of joints as the model, but got {dJ.shape[1]} and {Nj}" + + # Check the device of the inputs + assert betas.device == device, f"Betas should be on device {device}, but got {betas.device}" + assert trans.device == device, f"Trans should be on device {device}, but got {trans.device}" + + skin_v0 = self.skin_template_v[None, :] + skel_v0 = self.skel_template_v[None, :] + betas = betas[:, :, None] # TODO Name the expanded beta differently + + # TODO clean this part + assert poses_type in ['skel', 'bsm'], f"got {poses_type}" + if poses_type == 'bsm': + assert poses.shape[1] == self.num_q_params - 3, f'With poses_type bsm, expected 
parameters of shape (B, {self.num_q_params - 3}, got {poses.shape}' + poses_bsm = poses + poses_skel = torch.zeros(B, self.num_q_params) + poses_skel[:,:3] = poses_bsm[:, :3] + trans = poses_bsm[:, 3:6] # In BSM parametrization, the hips translation is given by params 3 to 5 + poses_skel[:, 3:] = poses_bsm + poses = poses_skel + + else: + assert poses.shape[1] == self.num_q_params, f'With poses_type skel, expected parameters of shape (B, {self.num_q_params}), got {poses.shape}' + pass + # Load poses as expected + # Distinction bsm skel. by default it will be bsm + + # ------- Shape ---------- + # Apply the beta offset to the template + shapedirs = self.shapedirs.view(-1, self.num_betas)[None, :].expand(B, -1, -1) # B x D*Ns x num_betas + v_shaped = skin_v0 + torch.matmul(shapedirs, betas).view(B, Ns, 3) + + # ------- Joints ---------- + # Regress the anatomical joint location + J = torch.einsum('bik,ji->bjk', [v_shaped, self.J_regressor_osim]) # BxJx3 # osim regressor + # J = self.apose_transfo[:, :3, -1].view(1, Nj, 3).expand(B, -1, -1) # Osim default pose joints location + + if dJ is not None: + J = J + dJ + J_tpose = J.clone() + + # Local translation + J_ = J.clone() # BxJx3 + J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :] + t = J_[:, :, :, None] # BxJx3x1 + + # ------- Bones transformation matrix---------- + + # Bone initial transform to go from unposed to SMPL T pose + Rk01 = self.compute_bone_orientation(J, J_) + + # BSM default pose rotations + Ra = self.apose_rel_transfo[:, :3, :3].view(1, Nj, 3,3).expand(B, Nj, 3, 3) + + # Local bone rotation given by the pose param + Rp, tp = self.pose_params_to_rot(poses) # BxNjx3x3 pose params to rotation + + R = matmul_chain([Rk01, Ra.transpose(2,3), Rp, Ra, Rk01.transpose(2,3)]) + + ###### Compute translation for non pure rotation joints + t_posed = t.clone() + + # Scapula + thorax_width = torch.norm(J[:, 19, :] - J[:, 14, :], dim=1) # Distance between the two scapula joints, size B + thorax_height = torch.norm(J[:, 12, :] - J[:, 11, :], dim=1) # Distance between the two scapula joints, size B + + angle_abduction = poses[:,26] + angle_elevation = poses[:,27] + angle_rot = poses[:,28] + angle_zero = torch.zeros_like(angle_abduction) + t_posed[:,14] = t_posed[:,14] + \ + (right_scapula(angle_abduction, angle_elevation, angle_rot, thorax_width, thorax_height).view(-1,3,1) + - right_scapula(angle_zero, angle_zero, angle_zero, thorax_width, thorax_height).view(-1,3,1)) + + + angle_abduction = poses[:,36] + angle_elevation = poses[:,37] + angle_rot = poses[:,38] + angle_zero = torch.zeros_like(angle_abduction) + t_posed[:,19] = t_posed[:,19] + \ + (left_scapula(angle_abduction, angle_elevation, angle_rot, thorax_width, thorax_height).view(-1,3,1) + - left_scapula(angle_zero, angle_zero, angle_zero, thorax_width, thorax_height).view(-1,3,1)) + + + # Knee_r + # TODO add the Walker knee offset + # bone_scale = self.compute_bone_scale(J_,J, skin_v0, v_shaped) + # f1 = poses[:, 2*3+2].clone() + # scale_femur = bone_scale[:, 2] + # factor = 0.076/0.080 * scale_femur # The template femur medial laterak spacing #66 + # f = -f1*180/torch.pi #knee_flexion + # varus = (0.12367*f)-0.0009*f**2 + # introt = 0.3781*f-0.001781*f**2 + # ydis = (-0.0683*f + # + 8.804e-4 * f**2 + # - 3.750e-06*f**3 + # )/1000*factor # up-down + # zdis = (-0.1283*f + # + 4.796e-4 * f**2)/1000*factor # + # import ipdb; ipdb.set_trace() + # poses[:, 9] = poses[:, 9] + varus + # t_posed[:,2] = t_posed[:,2] + torch.stack([torch.zeros_like(ydis), ydis, zdis], dim=1).view(-1,3,1) + # 
poses[:, 2*3+2]=0 + + # t_unposed = torch.zeros_like(t_posed) + # t_unposed[:,2] = torch.stack([torch.zeros_like(ydis), ydis, zdis], dim=1).view(-1,3,1) + + + # Spine + lumbar_bending = poses[:,17] + lumbar_extension = poses[:,18] + angle_zero = torch.zeros_like(lumbar_bending) + interp_t = torch.ones_like(lumbar_bending) + l = torch.abs(J[:, 11, 1] - J[:, 0, 1]) # Length of the spine section along y axis + t_posed[:,11] = t_posed[:,11] + \ + (curve_torch_3d(lumbar_bending, lumbar_extension, t=interp_t, l=l) + - curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l)) + + thorax_bending = poses[:,20] + thorax_extension = poses[:,21] + angle_zero = torch.zeros_like(thorax_bending) + interp_t = torch.ones_like(thorax_bending) + l = torch.abs(J[:, 12, 1] - J[:, 11, 1]) # Length of the spine section + + t_posed[:,12] = t_posed[:,12] + \ + (curve_torch_3d(thorax_bending, thorax_extension, t=interp_t, l=l) + - curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l)) + + head_bending = poses[:, 23] + head_extension = poses[:,24] + angle_zero = torch.zeros_like(head_bending) + interp_t = torch.ones_like(head_bending) + l = torch.abs(J[:, 13, 1] - J[:, 12, 1]) # Length of the spine section + t_posed[:,13] = t_posed[:,13] + \ + (curve_torch_3d(head_bending, head_extension, t=interp_t, l=l) + - curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l)) + + + # ------- Body surface transformation matrix---------- + + G_ = torch.cat([R, t_posed], dim=-1) # BxJx3x4 local transformation matrix + pad_row = torch.FloatTensor([0, 0, 0, 1]).to(device).view(1, 1, 1, 4).expand(B, Nj, -1, -1) # BxJx1x4 + G_ = torch.cat([G_, pad_row], dim=2) # BxJx4x4 padded to be 4x4 matrix an enable multiplication for the kinematic chain + + # Global transform + G = [G_[:, 0].clone()] + for i in range(1, Nj): + G.append(torch.matmul(G[self.parent[i - 1]], G_[:, i, :, :])) + G = torch.stack(G, dim=1) + + # ------- Pose dependant blend shapes ---------- + if pose_dep_bs is False: + v_shaped_pd = v_shaped + else: + # Note : Those should be retrained for SKEL as the SKEL joints location are different from SMPL. + # But the current version lets use get decent pose dependant deformations for the shoulders, belly and knies + ident = torch.eye(3, dtype=v_shaped.dtype, device=device) + + # We need the per SMPL joint bone transform to compute pose dependant blend shapes. 
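+ # The SKEL joint rotations are scattered into the corresponding SMPL joint slots via
+ # smpl_joint_corresp; any SMPL joint without a SKEL counterpart keeps the identity
+ # rotation and therefore contributes no pose-dependent offset.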
+ # Initialize each joint rotation with identity + Rsmpl = ident.unsqueeze(0).unsqueeze(0).expand(B, self.num_joints_smpl, -1, -1).clone() # BxNjx3x3 + + Rskin = G_[:, :, :3, :3] # BxNjx3x3 + Rsmpl[:, smpl_joint_corresp] = Rskin[:] # BxNjx3x3 pose params to rotation + pose_feature = Rsmpl[:, 1:].view(B, -1, 3, 3) - ident + pose_offsets = torch.matmul(pose_feature.view(B, -1), + self.posedirs.view(Ns*3, -1).T).view(B, -1, 3) + v_shaped_pd = v_shaped + pose_offsets + + ########################################################################################## + #Transform skin mesh + ############################################################################################ + + # Apply global transformation to the template mesh + rest = torch.cat([J, torch.zeros(B, Nj, 1).to(device)], dim=2).view(B, Nj, 4, 1) # BxJx4x1 + zeros = torch.zeros(B, Nj, 4, 3).to(device) # BxJx4x3 + rest = torch.cat([zeros, rest], dim=-1) # BxJx4x4 + rest = torch.matmul(G, rest) # This is a 4x4 transformation matrix that only contains translation to the rest pose joint location + Gskin = G - rest + + # Compute per vertex transformation matrix (after weighting) + T = torch.matmul(self.skin_weights, Gskin.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Ns, B, 4,4).transpose(0, 1) + rest_shape_h = torch.cat([v_shaped_pd, torch.ones_like(v_shaped_pd)[:, :, [0]]], dim=-1) + v_posed = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0] + + # translation + v_trans = v_posed + trans[:,None,:] + + ########################################################################################## + #Transform joints + ############################################################################################ + + # import ipdb; ipdb.set_trace() + root_transform = with_zeros(torch.cat((R[:,0],J[:,0][:,:,None]),2)) + results = [root_transform] + for i in range(0, self.parent.shape[0]): + transform_i = with_zeros(torch.cat((R[:, i + 1], t_posed[:,i+1]), 2)) + curr_res = torch.matmul(results[self.parent[i]],transform_i) + results.append(curr_res) + results = torch.stack(results, dim=1) + posed_joints = results[:, :, :3, 3] + J_transformed = posed_joints + trans[:,None,:] + + + ########################################################################################## + # Transform skeleton + ############################################################################################ + + if skelmesh: + G_bones = None + # Shape the skeleton by scaling its bones + skel_rest_shape_h = torch.cat([skel_v0, torch.ones_like(skel_v0)[:, :, [0]]], dim=-1).expand(B, Nk, -1) # (1,Nk,3) + + # compute the bones scaling from the kinematic tree and skin mesh + #with torch.no_grad(): + # TODO: when dJ is optimized the shape of the mesh should be affected by the gradients + bone_scale = self.compute_bone_scale(J_, v_shaped, skin_v0) + # Apply bone meshes scaling: + skel_v_shaped = torch.cat([(torch.matmul(bone_scale[:,:,0], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 0])[:, :, None], + (torch.matmul(bone_scale[:,:,1], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 1])[:, :, None], + (torch.matmul(bone_scale[:,:,2], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 2])[:, :, None], + (torch.ones(B, Nk, 1).to(device)) + ], dim=-1) + + # Align the bones with the proper axis + Gk01 = build_homog_matrix(Rk01, J.unsqueeze(-1)) # BxJx4x4 + T = torch.matmul(self.skel_weights_rigid, Gk01.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Nk, B, 4,4).transpose(0, 1) #[1, 48757, 3, 3] + skel_v_align = torch.matmul(T, skel_v_shaped[:, :, 
:, None])[:, :, :, 0] + + # This transfo will be applied with weights, effectively unposing the whole skeleton mesh in each joint frame. + # Then, per joint weighted transformation can then be applied + G_tpose_to_unposed = build_homog_matrix(torch.eye(3).view(1,1,3,3).expand(B, Nj, 3, 3).to(device), -J.unsqueeze(-1)) # BxJx4x4 + G_skel = torch.matmul(G, G_tpose_to_unposed) + G_bones = torch.matmul(G, Gk01) + + T = torch.matmul(self.skel_weights, G_skel.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Nk, B, 4,4).transpose(0, 1) + skel_v_posed = torch.matmul(T, skel_v_align[:, :, :, None])[:, :, :3, 0] + + skel_trans = skel_v_posed + trans[:,None,:] + + else: + skel_trans = skel_v0 + Gk01 = build_homog_matrix(Rk01, J.unsqueeze(-1)) # BxJx4x4 + G_bones = torch.matmul(G, Gk01) + + joints = J_transformed + skin_verts = v_trans + skel_verts = skel_trans + joints_ori = G_bones[:,:,:3,:3] + + if skin_verts.max() > 1e3: + import ipdb; ipdb.set_trace() + + output = SKELOutput(skin_verts=skin_verts, + skel_verts=skel_verts, + joints=joints, + joints_ori=joints_ori, + betas=betas, + poses=poses, + trans = trans, + pose_offsets = pose_offsets, + joints_tpose = J_tpose, + v_shaped = v_shaped,) + + return output + + + def compute_bone_scale(self, J_, v_shaped, skin_v0): + + # index [0, 1, 2, 3 4, 5, , ...] # todo add last one, figure out bone scale indices + # J_ bone vectors [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...] + # norm(J) = length of the bone [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...] + # self.joints_sockets [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...] + # self.skel_weights [j0, j1, j2, j3, j4, j5, ...] + B = J_.shape[0] + Nj = J_.shape[1] + + bone_scale = torch.ones(B, Nj).to(J_.device) + + # BSM template joints location + osim_joints_r = self.apose_rel_transfo[:, :3, 3].view(1, Nj, 3).expand(B, Nj, 3).clone() + + length_bones_bsm = torch.norm(osim_joints_r, dim=-1).expand(B, -1) + length_bones_smpl = torch.norm(J_, dim=-1) # (B, Nj) + bone_scale_parent = length_bones_smpl / length_bones_bsm + + non_leaf_node = (self.child != 0) + bone_scale[:,non_leaf_node] = (bone_scale_parent[:,self.child])[:,non_leaf_node] + + # Ulna should have the same scale as radius + bone_scale[:,16] = bone_scale[:,17] + bone_scale[:,16] = bone_scale[:,17] + + bone_scale[:,21] = bone_scale[:,22] + bone_scale[:,21] = bone_scale[:,22] + + # Thorax + # Thorax scale is defined by the relative position of the thorax to its child joint, not parent joint as for other bones + bone_scale[:, 12] = bone_scale[:, 11] + + # Lumbars + # Lumbar scale is defined by the y relative position of the lumbar joint + length_bones_bsm = torch.abs(osim_joints_r[:,11, 1]) + length_bones_smpl = torch.abs(J_[:, 11, 1]) # (B, Nj) + bone_scale_lumbar = length_bones_smpl / length_bones_bsm + bone_scale[:, 11] = bone_scale_lumbar + + # Expand to 3 dimensions and adjest scaling to avoid skin-skeleton intersection and handle the scaling of leaf body parts (hands, feet) + bone_scale = bone_scale.reshape(B, Nj, 1).expand(B, Nj, 3).clone() + + for (ji, doi, dsi), (v1, v2) in scaling_keypoints.items(): + bone_scale[:, ji, doi] = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,dsi] # Top over chin + #TODO: Add keypoints for feet scaling in scaling_keypoints + + # Adjust thorax front-back scaling + # TODO fix this part + v1 = 3027 #thorax back + v2 = 3495 #thorax front + + scale_thorax_up = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,2] # good for large people + v2 = 3506 #sternum + 
scale_thorax_sternum = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,2] # Good for skinny people + bone_scale[:, 12, 0] = torch.min(scale_thorax_up, scale_thorax_sternum) # Avoids super expanded ribcage for large people and sternum outside for skinny people + + #lumbars, adjust width to be same as thorax + bone_scale[:, 11, 0] = bone_scale[:, 12, 0] + + return bone_scale + + + + def compute_bone_orientation(self, J, J_): + """Compute each bone orientation in T pose """ + + # method = 'unposed' + # method = 'learned' + method = 'learn_adjust' + + B = J_.shape[0] + Nj = J_.shape[1] + + # Create an array of bone vectors the bone meshes should be aligned to. + bone_vect = torch.zeros_like(J_) # / torch.norm(J_, dim=-1)[:, :, None] # (B, Nj, 3) + bone_vect[:] = J_[:, self.child] # Most bones are aligned between their parent and child joint + bone_vect[:,16] = bone_vect[:,16]+bone_vect[:,17] # We want to align the ulna to the segment joint 16 to 18 + bone_vect[:,21] = bone_vect[:,21]+bone_vect[:,22] # Same other ulna + + # TODO Check indices here + # bone_vect[:,13] = bone_vect[:,12].clone() + bone_vect[:,12] = bone_vect.clone()[:,11].clone() # We want to align the thorax on the thorax-lumbar segment + # bone_vect[:,11] = bone_vect[:,0].clone() + + osim_vect = self.apose_rel_transfo[:, :3, 3].clone().view(1, Nj, 3).expand(B, Nj, 3).clone() + osim_vect[:] = osim_vect[:,self.child] + osim_vect[:,16] = osim_vect[:,16]+osim_vect[:,17] # We want to align the ulna to the segment joint 16 to 18 + osim_vect[:,21] = osim_vect[:,21]+osim_vect[:,22] # We want to align the ulna to the segment joint 16 to 18 + + # TODO: remove when this has been checked + # import matplotlib.pyplot as plt + # fig = plt.figure() + # ax = fig.add_subplot(111, projection='3d') + # ax.plot(osim_vect[:,0,0], osim_vect[:,0,1], osim_vect[:,0,2], color='r') + # plt.show() + + Gk = torch.eye(3, device=J_.device).repeat(B, Nj, 1, 1) + + if method == 'unposed': + return Gk + + elif method == 'learn_adjust': + Gk_learned = self.per_joint_rot.view(1, Nj, 3, 3).expand(B, -1, -1, -1) #load learned rotation + osim_vect_corr = torch.matmul(Gk_learned, osim_vect.unsqueeze(-1)).squeeze(-1) + + Gk[:,:] = rotation_matrix_from_vectors(osim_vect_corr, bone_vect) + # set nan to zero + # TODO: Check again why the following line was required + Gk[torch.isnan(Gk)] = 0 + # Gk[:,[18,23]] = Gk[:,[16,21]] # hand has same orientation as ulna + # Gk[:,[5,10]] = Gk[:,[4,9]] # toe has same orientation as calcaneus + # Gk[:,[0,11,12,13,14,19]] = torch.eye(3, device=J_.device).view(1,3,3).expand(B, 6, 3, 3) # pelvis, torso and shoulder blade orientation does not vary with beta, leave it + Gk[:, self.joint_idx_fixed_beta] = torch.eye(3, device=J_.device).view(1,3,3).expand(B, len(self.joint_idx_fixed_beta), 3, 3) # pelvis, torso and shoulder blade orientation should not vary with beta, leave it + + Gk = torch.matmul(Gk, Gk_learned) + + elif method == 'learned': + """ Apply learned transformation""" + Gk = self.per_joint_rot.view(1, Nj, 3, 3).expand(B, -1, -1, -1) + + else: + raise NotImplementedError + + return Gk diff --git a/lib/body_models/skel/utils.py b/lib/body_models/skel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4970002a191bff24a4e986ee589b2414b5e40bc5 --- /dev/null +++ b/lib/body_models/skel/utils.py @@ -0,0 +1,445 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2020 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# STAR: Sparse Trained Articulated Human Body Regressor +# +# +# Code Developed by: +# Ahmed A. A. Osman, edited by Marilyn Keller + +import scipy +import torch +import numpy as np + +def build_homog_matrix(R, t=None): + """ Create a homogeneous matrix from rotation matrix and translation vector + @ R: rotation matrix of shape (B, Nj, 3, 3) + @ t: translation vector of shape (B, Nj, 3, 1) + returns: homogeneous matrix of shape (B, 4, 4) + By Marilyn Keller + """ + + if t is None: + B = R.shape[0] + Nj = R.shape[1] + t = torch.zeros(B, Nj, 3, 1).to(R.device) + + if R is None: + B = t.shape[0] + Nj = t.shape[1] + R = torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(B, Nj, 1, 1).to(t.device) + + B = t.shape[0] + Nj = t.shape[1] + + # import ipdb; ipdb.set_trace() + assert R.shape == (B, Nj, 3, 3), f"R.shape: {R.shape}" + assert t.shape == (B, Nj, 3, 1), f"t.shape: {t.shape}" + + G = torch.cat([R, t], dim=-1) # BxJx3x4 local transformation matrix + pad_row = torch.FloatTensor([0, 0, 0, 1]).to(R.device).view(1, 1, 1, 4).expand(B, Nj, -1, -1) # BxJx1x4 + G = torch.cat([G, pad_row], dim=2) # BxJx4x4 padded to be 4x4 matrix an enable multiplication for the kinematic chain + + return G + + +def matmul_chain(rot_list): + R_tot = rot_list[-1] + for i in range(len(rot_list)-2,-1,-1): + R_tot = torch.matmul(rot_list[i], R_tot) + return R_tot + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + import torch.nn.functional as F + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def rotation_matrix_from_vectors(vec1, vec2): + """ Find the rotation matrix that aligns vec1 to vec2 + :param vec1: A 3d "source" vector (B x Nj x 3) + :param vec2: A 3d "destination" vector (B x Nj x 3) + :return mat: A rotation matrix (B x Nj x 3 x 3) which when applied to vec1, aligns it with vec2. 
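+ Note: both inputs are normalized internally and a small epsilon is added to the norm of
+ their cross product, so the division is never exactly by zero; exactly opposite
+ (180-degree) vector pairs remain a degenerate case, and callers such as
+ compute_bone_orientation additionally zero out any resulting NaN entries.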
+ """ + for v_id, v in enumerate([vec1, vec2]): + # vectors shape should be B x Nj x 3 + assert len(v.shape) == 3, f"Vectors {v_id} shape should be B x Nj x 3, got {v.shape}" + assert v.shape[-1] == 3, f"Vectors {v_id} shape should be B x Nj x 3, got {v.shape}" + + B = vec1.shape[0] + Nj = vec1.shape[1] + device = vec1.device + + a = vec1 / torch.linalg.norm(vec1, dim=-1, keepdim=True) + b = vec2 / torch.linalg.norm(vec2, dim=-1, keepdim=True) + v = torch.cross(a, b, dim=-1) + # Compute the dot product along the last dimension of a and b + c = torch.sum(a * b, dim=-1) + s = torch.linalg.norm(v, dim=-1) + torch.finfo(float).eps + v0 = torch.zeros_like(v[...,0], device=device).unsqueeze(-1) + kmat_l1 = torch.cat([v0, -v[...,2].unsqueeze(-1), v[...,1].unsqueeze(-1)], dim=-1) + kmat_l2 = torch.cat([v[...,2].unsqueeze(-1), v0, -v[...,0].unsqueeze(-1)], dim=-1) + kmat_l3 = torch.cat([-v[...,1].unsqueeze(-1), v[...,0].unsqueeze(-1), v0], dim=-1) + # Stack the matrix lines along a the -2 dimension + kmat = torch.cat([kmat_l1.unsqueeze(-2), kmat_l2.unsqueeze(-2), kmat_l3.unsqueeze(-2)], dim=-2) # B x Nj x 3 x 3 + # import ipdb; ipdb.set_trace() + rotation_matrix = torch.eye(3, device=device).view(1,1,3,3).expand(B, Nj, 3, 3) + kmat + torch.matmul(kmat, kmat) * ((1 - c) / (s ** 2)).view(B, Nj, 1, 1).expand(B, Nj, 3, 3) + return rotation_matrix + + +def quat_feat(theta): + ''' + Computes a normalized quaternion ([0,0,0,0] when the body is in rest pose) + given joint angles + :param theta: A tensor of joints axis angles, batch size x number of joints x 3 + :return: + ''' + l1norm = torch.norm(theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_sin * normalized,v_cos-1], dim=1) + return quat + +def quat2mat(quat): + ''' + Converts a quaternion to a rotation matrix + :param quat: + :return: + ''' + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] + B = quat.size(0) + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, + 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, + 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + + +def rodrigues(theta): + ''' + Computes the rodrigues representation given joint angles + + :param theta: batch_size x number of joints x 3 + :return: batch_size x number of joints x 3 x 4 + ''' + l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat2mat(quat) + + +def with_zeros(input): + ''' + Appends a row of [0,0,0,1] to a batch size x 3 x 4 Tensor + + :param input: A tensor of dimensions batch size x 3 x 4 + :return: A tensor batch size x 4 x 4 (appended with 0,0,0,1) + ''' + batch_size = input.shape[0] + row_append = torch.FloatTensor(([0.0, 0.0, 0.0, 1.0])).to(input.device) + row_append.requires_grad = False + padded_tensor = torch.cat([input, row_append.view(1, 1, 4).repeat(batch_size, 1, 1)], 1) + return padded_tensor + + +def with_zeros_44(input): + ''' + Appends a row of [0,0,0,1] to a batch size x 3 x 4 
Tensor + + :param input: A tensor of dimensions batch size x 3 x 4 + :return: A tensor batch size x 4 x 4 (appended with 0,0,0,1) + ''' + import ipdb; ipdb.set_trace() + batch_size = input.shape[0] + col_append = torch.FloatTensor(([[[[0.0, 0.0, 0.0]]]])).to(input.device) + padded_tensor = torch.cat([input, col_append], dim=-1) + + row_append = torch.FloatTensor(([0.0, 0.0, 0.0, 1.0])).to(input.device) + row_append.requires_grad = False + padded_tensor = torch.cat([input, row_append.view(1, 1, 4).repeat(batch_size, 1, 1)], 1) + return padded_tensor + + +def vector_to_rot(): + + def rotation_matrix(A,B): + # Aligns vector A to vector B + + ax = A[0] + ay = A[1] + az = A[2] + + bx = B[0] + by = B[1] + bz = B[2] + + au = A/(torch.sqrt(ax*ax + ay*ay + az*az)) + bu = B/(torch.sqrt(bx*bx + by*by + bz*bz)) + + R=torch.tensor([[bu[0]*au[0], bu[0]*au[1], bu[0]*au[2]], [bu[1]*au[0], bu[1]*au[1], bu[1]*au[2]], [bu[2]*au[0], bu[2]*au[1], bu[2]*au[2]] ]) + + + return(R) + + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". 
+ angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def location_to_spheres(loc, color=(1,0,0), radius=0.02): + """Given an array of 3D points, return a list of spheres located at those positions. + + Args: + loc (numpy.array): Nx3 array giving 3D positions + color (tuple, optional): One RGB float color vector to color the spheres. Defaults to (1,0,0). + radius (float, optional): Radius of the spheres in meters. Defaults to 0.02. 
+ + Returns: + list: List of spheres Mesh + """ + from psbody.mesh.sphere import Sphere + import numpy as np + cL = [Sphere(np.asarray([loc[i, 0], loc[i, 1], loc[i, 2]]), radius).to_mesh() for i in range(loc.shape[0])] + for spL in cL: + spL.set_vertex_colors(np.array(color)) + return cL + +def sparce_coo_matrix2tensor(arr_coo, make_dense=True): + assert isinstance(arr_coo, scipy.sparse._coo.coo_matrix), f"arr_coo should be a coo_matrix, got {type(arr_coo)}. Please download the updated SKEL pkl files from https://skel.is.tue.mpg.de/." + + values = arr_coo.data + indices = np.vstack((arr_coo.row, arr_coo.col)) + + i = torch.LongTensor(indices) + v = torch.FloatTensor(values) + shape = arr_coo.shape + + tensor_arr = torch.sparse_coo_tensor(i, v, torch.Size(shape)) + + if make_dense: + tensor_arr = tensor_arr.to_dense() + + return tensor_arr + diff --git a/lib/body_models/skel_utils/augmentation.py b/lib/body_models/skel_utils/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..2e5c1c3d855b3efab60ce82b949c7fad2ce649f2 --- /dev/null +++ b/lib/body_models/skel_utils/augmentation.py @@ -0,0 +1,47 @@ +import torch + +from typing import Optional + +from .transforms import real_orient_mat2q, real_orient_q2mat + + +def update_params_after_orient_rotation( + poses : torch.Tensor, # (B, 46) + rot_mat : torch.Tensor, # the rotation orientation matrix + root_offset : Optional[torch.Tensor] = None, # the offset from custom root to model root +): + ''' + + ### Args + - `poses`: torch.Tensor, shape = (B, 46) + - `rot_mat`: torch.Tensor, shape = (B, 3, 3) + - `root_offset`: torch.Tensor or None, shape = (B, 3) + - If None, the function won't update the translation. + - If not None, the function will calculate the root translation offset that make the model + rotate around the custom root instead of the model root. + + ### Returns + - If `root_offset` is None: + - `poses`: torch.Tensor, shape = (B, 46) + - If `root_offset` is not None: + - `poses`: torch.Tensor, shape = (B, 46) + - `trans_offset`: torch.Tensor, shape = (B, 3) + ''' + poses = poses.clone() + # 1. Transform the SKEL orientation to real matrix. + orient_q = poses[:, :3] # (B, 3) + orient_mat = real_orient_q2mat(orient_q) # (B, 3, 3) + orient_mat = torch.einsum('bij,bjk->bik', rot_mat, orient_mat) # (B, 3, 3) + orient_q = real_orient_mat2q(orient_mat) # (B, 3) + poses[:, :3] = orient_q + + # 2. Update the translation if needed. 
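+    # (Reasoning sketch: to make the body appear to rotate about the custom root c instead of
+    #  the model root r, the translation needs an extra term R @ (r - c) - (r - c); since
+    #  `root_offset` is r - c, that is the quantity computed below.)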
+ if root_offset is not None: + root_before = root_offset.clone() # (B, 3) + root_after = torch.einsum('bij,bj->bi', rot_mat, root_before) # (B, 3) + root_offset = root_after - root_before # (B, 3) + ret = poses, root_offset + else: + ret = poses + + return ret \ No newline at end of file diff --git a/lib/body_models/skel_utils/definition.py b/lib/body_models/skel_utils/definition.py new file mode 100644 index 0000000000000000000000000000000000000000..8a8e84c78bb57e38e3c7874f954674033c928b88 --- /dev/null +++ b/lib/body_models/skel_utils/definition.py @@ -0,0 +1,100 @@ +from lib.body_models.skel.osim_rot import ConstantCurvatureJoint, CustomJoint, EllipsoidJoint, PinJoint, WalkerKnee + +Q_COMPONENTS = [ + {'qid': 0, 'name': 'pelvis', 'jid': 0}, + {'qid': 1, 'name': 'pelvis', 'jid': 0}, + {'qid': 2, 'name': 'pelvis', 'jid': 0}, + {'qid': 3, 'name': 'femur-r', 'jid': 1}, + {'qid': 4, 'name': 'femur-r', 'jid': 1}, + {'qid': 5, 'name': 'femur-r', 'jid': 1}, + {'qid': 6, 'name': 'tibia-r', 'jid': 2}, + {'qid': 7, 'name': 'talus-r', 'jid': 3}, + {'qid': 8, 'name': 'calcn-r', 'jid': 4}, + {'qid': 9, 'name': 'toes-r', 'jid': 5}, + {'qid': 10, 'name': 'femur-l', 'jid': 6}, + {'qid': 11, 'name': 'femur-l', 'jid': 6}, + {'qid': 12, 'name': 'femur-l', 'jid': 6}, + {'qid': 13, 'name': 'tibia-l', 'jid': 7}, + {'qid': 14, 'name': 'talus-l', 'jid': 8}, + {'qid': 15, 'name': 'calcn-l', 'jid': 9}, + {'qid': 16, 'name': 'toes-l', 'jid': 10}, + {'qid': 17, 'name': 'lumbar', 'jid': 11}, + {'qid': 18, 'name': 'lumbar', 'jid': 11}, + {'qid': 19, 'name': 'lumbar', 'jid': 11}, + {'qid': 20, 'name': 'thorax', 'jid': 12}, + {'qid': 21, 'name': 'thorax', 'jid': 12}, + {'qid': 22, 'name': 'thorax', 'jid': 12}, + {'qid': 23, 'name': 'head', 'jid': 13}, + {'qid': 24, 'name': 'head', 'jid': 13}, + {'qid': 25, 'name': 'head', 'jid': 13}, + {'qid': 26, 'name': 'scapula-r', 'jid': 14}, + {'qid': 27, 'name': 'scapula-r', 'jid': 14}, + {'qid': 28, 'name': 'scapula-r', 'jid': 14}, + {'qid': 29, 'name': 'humerus-r', 'jid': 15}, + {'qid': 30, 'name': 'humerus-r', 'jid': 15}, + {'qid': 31, 'name': 'humerus-r', 'jid': 15}, + {'qid': 32, 'name': 'ulna-r', 'jid': 16}, + {'qid': 33, 'name': 'radius-r', 'jid': 17}, + {'qid': 34, 'name': 'hand-r', 'jid': 18}, + {'qid': 35, 'name': 'hand-r', 'jid': 18}, + {'qid': 36, 'name': 'scapula-l', 'jid': 19}, + {'qid': 37, 'name': 'scapula-l', 'jid': 19}, + {'qid': 38, 'name': 'scapula-l', 'jid': 19}, + {'qid': 39, 'name': 'humerus-l', 'jid': 20}, + {'qid': 40, 'name': 'humerus-l', 'jid': 20}, + {'qid': 41, 'name': 'humerus-l', 'jid': 20}, + {'qid': 42, 'name': 'ulna-l', 'jid': 21}, + {'qid': 43, 'name': 'radius-l', 'jid': 22}, + {'qid': 44, 'name': 'hand-l', 'jid': 23}, + {'qid': 45, 'name': 'hand-l', 'jid': 23}, +] + + +QID2JID = {c['qid']: c['jid'] for c in Q_COMPONENTS} + +JID2QIDS = {} +for c in Q_COMPONENTS: + jid = c['jid'] + JID2QIDS[jid] = [] if jid not in JID2QIDS else JID2QIDS[jid] + JID2QIDS[jid].append(c['qid']) + +JID2DOF = {jid: len(qids) for jid, qids in JID2QIDS.items()} + +DoF1_JIDS = [2, 3, 4, 5, 7, 8, 9, 10, 16, 17, 21, 22] # (J1=12,) +DoF2_JIDS = [18, 23] # (J2=2,) +DoF3_JIDS = [0, 1, 6, 11, 12, 13, 14, 15, 19, 20] # (J3=10,) +DoF1_QIDS = [6, 7, 8, 9, 13, 14, 15, 16, 32, 33, 42, 43] # (Q1=12,) +DoF2_QIDS = [34, 35, 44, 45] # (Q2=4,) +DoF3_QIDS = [0, 1, 2, 3, 4, 5, 10, 11, 12, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 36, 37, 38, 39, 40, 41] # (Q3=30,) + + +# Copied from the `skel_model.py`. 
+# Change all axis (except those PinJoint) to positive and update the flip if needed. +JOINTS_DEF = [ + CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 0 pelvis + CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 1 femur_r + WalkerKnee(), # 2 tibia_r + PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 3 talus_r Field taken from .osim Joint-> frames -> PhysicalOffsetFrame -> orientation + PinJoint(parent_frame_ori = [-1.76818999, 0.906223, 1.8196000]), # 4 calcn_r + PinJoint(parent_frame_ori = [-3.141589999, 0.6199010, 0]), # 5 toes_r + CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, -1, -1]), # 6 femur_l + WalkerKnee(), # 7 tibia_l + PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 8 talus_l + PinJoint(parent_frame_ori = [1.768189999 ,-0.906223, 1.8196000]), # 9 calcn_l + PinJoint(parent_frame_ori = [-3.141589999, -0.6199010, 0]), # 10 toes_l + ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 11 lumbar + ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 12 thorax + ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 13 head + EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, -1, -1]), # 14 scapula_r + CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 15 humerus_r + CustomJoint(axis=[[0.0494, 0.0366, 0.99810825]], axis_flip=[[1]]), # 16 ulna_r + CustomJoint(axis=[[-0.01716099, 0.99266564, -0.11966796]], axis_flip=[[1]]), # 17 radius_r + CustomJoint(axis=[[1,0,0], [0,0,1]], axis_flip=[1, -1]), # 18 hand_r + EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, 1, 1]), # 19 scapula_l + CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 20 humerus_l + CustomJoint(axis=[[-0.0494, -0.0366, 0.99810825]], axis_flip=[[1]]), # 21 ulna_l + CustomJoint(axis=[[0.01716099, -0.99266564, -0.11966796]], axis_flip=[[1]]), # 22 radius_l + CustomJoint(axis=[[1,0,0], [0,0,1]], axis_flip=[-1, -1]), # 23 hand_l +] + +N_JOINTS = len(JOINTS_DEF) # 24 diff --git a/lib/body_models/skel_utils/limits.py b/lib/body_models/skel_utils/limits.py new file mode 100644 index 0000000000000000000000000000000000000000..e61c146ad7465b4660e002b79e8ccd485d6e300a --- /dev/null +++ b/lib/body_models/skel_utils/limits.py @@ -0,0 +1,78 @@ +import math +import torch + +from lib.body_models.skel.kin_skel import pose_param_names + +# Different from the original one, has some modifications. 
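+# A minimal usage sketch (illustrative only, assuming a (B, 46) SKEL pose tensor `poses_q`
+# on the same device as SKEL_LIM_BOUNDS, which is built at the bottom of this file):
+#
+#   clamped = poses_q.clone()
+#   clamped[:, SKEL_LIM_QIDS] = torch.clamp(
+#       poses_q[:, SKEL_LIM_QIDS],
+#       min=SKEL_LIM_BOUNDS[:, 0],
+#       max=SKEL_LIM_BOUNDS[:, 1],
+#   )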
+pose_limits = { + 'scapula_abduction_r' : [-0.628, 0.628], # 26 + 'scapula_elevation_r' : [-0.4, -0.1], # 27 + 'scapula_upward_rot_r' : [-0.190, 0.319], # 28 + + 'scapula_abduction_l' : [-0.628, 0.628], # 36 + 'scapula_elevation_l' : [-0.4, -0.1], # 37 + 'scapula_upward_rot_l' : [-0.210, 0.219], # 38 + + 'elbow_flexion_r' : [0, (3/4)*math.pi], # 32 + 'pro_sup_r' : [-3/4*math.pi/2, 3/4*math.pi/2], # 33 + 'wrist_flexion_r' : [-math.pi/2, math.pi/2], # 34 + 'wrist_deviation_r' :[-math.pi/4, math.pi/4], # 35 + + 'elbow_flexion_l' : [0, (3/4)*math.pi], # 42 + 'pro_sup_l' : [-math.pi/2, math.pi/2], # 43 + 'wrist_flexion_l' : [-math.pi/2, math.pi/2], # 44 + 'wrist_deviation_l' :[-math.pi/4, math.pi/4], # 45 + + 'lumbar_bending' : [-2/3*math.pi/4, 2/3*math.pi/4], # 17 + 'lumbar_extension' : [-math.pi/4, math.pi/4], # 18 + 'lumbar_twist' : [-math.pi/4, math.pi/4], # 19 + + 'thorax_bending' :[-math.pi/4, math.pi/4], # 20 + 'thorax_extension' :[-math.pi/4, math.pi/4], # 21 + 'thorax_twist' :[-math.pi/4, math.pi/4], # 22 + + 'head_bending' :[-math.pi/4, math.pi/4], # 23 + 'head_extension' :[-math.pi/4, math.pi/4], # 24 + 'head_twist' :[-math.pi/4, math.pi/4], # 25 + + 'ankle_angle_r' : [-math.pi/4, math.pi/4], # 7 + 'subtalar_angle_r' : [-math.pi/4, math.pi/4], # 8 + 'mtp_angle_r' : [-math.pi/4, math.pi/4], # 9 + + 'ankle_angle_l' : [-math.pi/4, math.pi/4], # 14 + 'subtalar_angle_l' : [-math.pi/4, math.pi/4], # 15 + 'mtp_angle_l' : [-math.pi/4, math.pi/4], # 16 + + 'knee_angle_r' : [0, 3/4*math.pi], # 6 + 'knee_angle_l' : [0, 3/4*math.pi], # 13 + + # Added by HSMR to make optimization more stable. + 'hip_flexion_r' : [-math.pi/4, 3/4*math.pi], # 3 + 'hip_adduction_r' : [-math.pi/4, 2/3*math.pi/4], # 4 + 'hip_rotation_r' : [-math.pi/4, math.pi/4], # 5 + 'hip_flexion_l' : [-math.pi/4, 3/4*math.pi], # 10 + 'hip_adduction_l' : [-math.pi/4, 2/3*math.pi/4], # 11 + 'hip_rotation_l' : [-math.pi/4, math.pi/4], # 12 + + 'shoulder_r_x' : [-math.pi/2, math.pi/2+1.5], # 29, from bsm.osim + 'shoulder_r_y' : [-math.pi/2, math.pi/2], # 30 + 'shoulder_r_z' : [-math.pi/2, math.pi/2], # 31, from bsm.osim + + 'shoulder_l_x' : [-math.pi/2-1.5, math.pi/2], # 39, from bsm.osim + 'shoulder_l_y' : [-math.pi/2, math.pi/2], # 40 + 'shoulder_l_z' : [-math.pi/2, math.pi/2], # 41, from bsm.osim + } + +pose_param_name2qid = {name: qid for qid, name in enumerate(pose_param_names)} +qid2pose_param_name = {qid: name for qid, name in enumerate(pose_param_names)} + +SKEL_LIM_QIDS = [] +SKEL_LIM_BOUNDS = [] +for name, (low, up) in pose_limits.items(): + if low > up: + low, up = up, low + SKEL_LIM_QIDS.append(pose_param_name2qid[name]) + SKEL_LIM_BOUNDS.append([low, up]) + +SKEL_LIM_BOUNDS = torch.Tensor(SKEL_LIM_BOUNDS).float() +SKEL_LIM_QID2IDX = {qid: i for i, qid in enumerate(SKEL_LIM_QIDS)} # inverse mapping \ No newline at end of file diff --git a/lib/body_models/skel_utils/reality.py b/lib/body_models/skel_utils/reality.py new file mode 100644 index 0000000000000000000000000000000000000000..33d81628fcbec1ac8043f8f7e626a1482b9852a2 --- /dev/null +++ b/lib/body_models/skel_utils/reality.py @@ -0,0 +1,36 @@ +from lib.kits.basic import * + +from lib.body_models.skel_utils.limits import SKEL_LIM_QID2IDX, SKEL_LIM_BOUNDS + + +qids_cfg = { + 'l_knee': [13], + 'r_knee': [6], + 'l_elbow': [42, 43], + 'r_elbow': [32, 33], +} + + +def eval_rot_delta(poses, tol_deg=5): + tol_rad = np.deg2rad(tol_deg) + + res = {} + for part in qids_cfg: + qids = qids_cfg[part] + violation_part = poses.new_zeros(poses.shape[0], len(qids)) + for i, qid in 
enumerate(qids): + idx = SKEL_LIM_QID2IDX[qid] + ea = poses[:, qid] + ea = (ea + np.pi) % (2 * np.pi) - np.pi # Normalize to (-pi, pi) + exceed_lb = torch.where( + ea < SKEL_LIM_BOUNDS[idx][0] - tol_rad, + ea - SKEL_LIM_BOUNDS[idx][0] + tol_rad, 0 + ) + exceed_ub = torch.where( + ea > SKEL_LIM_BOUNDS[idx][1] + tol_rad, + ea - SKEL_LIM_BOUNDS[idx][1] - tol_rad, 0 + ) + violation_part[:, i] = exceed_lb.abs() + exceed_ub.abs() + res[part] = violation_part + + return res \ No newline at end of file diff --git a/lib/body_models/skel_utils/transforms.py b/lib/body_models/skel_utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..300802c9b889ed2c710c3bfc4a0e61da213173ef --- /dev/null +++ b/lib/body_models/skel_utils/transforms.py @@ -0,0 +1,404 @@ +from lib.kits.basic import * + +import torch +import numpy as np +import torch.nn.functional as F + +from .definition import JOINTS_DEF, N_JOINTS, JID2DOF, JID2QIDS, DoF1_JIDS, DoF2_JIDS, DoF3_JIDS, DoF1_QIDS, DoF2_QIDS, DoF3_QIDS + +from lib.utils.data import to_tensor +from lib.utils.geometry.rotation import ( + matrix_to_euler_angles, + matrix_to_rotation_6d, + euler_angles_to_matrix, + rotation_6d_to_matrix, +) + + +# ====== Internal Utils ====== + + +def axis2convention(axis:List): + ''' [1,0,0] -> 'X', [0,1,0] -> 'Y', [0,0,1] -> 'Z' ''' + if axis == [1, 0, 0]: + return 'X' + elif axis == [0, 1, 0]: + return 'Y' + elif axis == [0, 0, 1]: + return 'Z' + else: + raise ValueError(f'Unsupported axis: {axis}.') + + +def rotation_2d_to_angle(r2d:torch.Tensor): + ''' + Extract single angle from a 2D rotation vector, which is the first column of a 2x2 rotation matrix. + + ### Args + - r2d: torch.Tensor + - shape = (...B, 2) + + ### Returns + - shape = (...B,) + ''' + cos, sin = r2d[..., [0]], -r2d[..., [1]] + return torch.atan2(sin, cos) + +# ====== Tools ====== + + +OS2S_FLIP = [-1, 1, 1] +OS2S_CONV = 'YZX' +def real_orient_mat2q(orient_mat:torch.Tensor) -> torch.Tensor: + ''' + The rotation matrix that SKEL uses is different from the SMPL's orientation matrix. + The rotation to representation functions below can not be used to transform the rotaiton matrix. + This function is used to convert the SMPL's orientation matrix to the SKEL's orientation q. + BUT, is that really important? Maybe we shouldn't align SMPL's orientation with SKEL's, can they be different? + + ### Args + - orient_mat: torch.Tensor, shape = (..., 3, 3) + + ### Returns + - orient_q: torch.Tensor, shape = (..., 3) + ''' + device = orient_mat.device + flip = to_tensor(OS2S_FLIP, device=device) # (3,) + orient_ea = matrix_to_euler_angles(orient_mat.clone(), OS2S_CONV) # (..., 3) + orient_ea = orient_ea[..., [2, 1, 0]] # Re-permuting the order. + orient_q = orient_ea * flip[None] + return orient_q + + +def real_orient_q2mat(orient_q:torch.Tensor) -> torch.Tensor: + ''' + The rotation matrix that SKEL uses is different from the SMPL's orientation matrix. + The rotation to representation functions below can not be used to transform the rotation matrix. + This function is used to convert the SKEL's orientation q to the SMPL's orientation matrix. + BUT, is that really important? Maybe we shouldn't align SMPL's orientation with SKEL's, can they be different? 
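+
+    A rough round-trip sketch (illustrative only, `orient_mat` being a batch of SMPL global
+    orientation matrices of shape (B, 3, 3)):
+
+        orient_q = real_orient_mat2q(orient_mat)   # (B, 3), SKEL-style pelvis q
+        mat_back = real_orient_q2mat(orient_q)     # (B, 3, 3), should recover orient_mat up to Euler-angle edge cases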
+ + ### Args + - orient_q: torch.Tensor, shape = (..., 3) + + ### Returns + - orient_mat: torch.Tensor, shape = (..., 3, 3) + ''' + device = orient_q.device + flip = to_tensor(OS2S_FLIP, device=device) # (3,) + orient_ea = orient_q * flip[None] + orient_ea = orient_ea[..., [2, 1, 0]] # Re-permuting the order. + orient_mat = euler_angles_to_matrix(orient_ea, OS2S_CONV) + return orient_mat + + +def flip_params_lr(poses:torch.Tensor) -> torch.Tensor: + ''' + It flips the skel through exchanging the params of left part and right part of the body. It's useful for + data augmentation. Note that the 'left & right' defined when the body is facing z+ direction, this is + only important for the orientation. + + ### Args + - poses: torch.Tensor, shape = (B, L, 46) or (L, 46) + + ### Returns + - flipped_poses: torch.Tensor, shape = (B, L, 46) or (L, 46) + ''' + assert len(poses.shape) in [2, 3] and poses.shape[-1] == 46, f'Shape of poses should be (B, L, 46) or (L, 46) but get {poses.shape}.' + + # 1. Switch the value of each pair through re-permuting. + flipped_perm = [ + 0, 1, 2, # pelvis + 10, 11, 12, # femur-r -> femur-l + 13, # tibia-r -> tibia-l + 14, # talus-r -> talus-l + 15, # calcn-r -> calcn-l + 16, # toes-r -> toes-l + 3, 4, 5, # femur-l -> femur-r + 6, # tibia-l -> tibia-r + 7, # talus-l -> talus-r + 8, # calcn-l -> calcn-r + 9, # toes-l -> toes-r + 17, 18, 19, # lumbar + 20, 21, 22, # thorax + 23, 24, 25, # head + 36, 37, 38, # scapula-r -> scapula-l + 39, 40, 41, # humerus-r -> humerus-l + 42, # ulna-r -> ulna-l + 43, # radius-r -> radius-l + 44, 45, # hand-r -> hand-l + 26, 27, 28, # scapula-l -> scapula-r + 29, 30, 31, # humerus-l -> humerus-r + 32, # ulna-l -> ulna-r + 33, # radius-l -> radius-r + 34, 35 # hand-l -> hand-r + ] + + flipped_poses = poses[..., flipped_perm] + + # 2. Mirror the rotation direction through apply -1. + flipped_sign = [ + 1, -1, -1, # pelvis + 1, 1, 1, # femur-r' + 1, # tibia-r' + 1, # talus-r' + 1, # calcn-r' + 1, # toes-r' + 1, 1, 1, # femur-l' + 1, # tibia-l' + 1, # talus-l' + 1, # calcn-l' + 1, # toes-l' + -1, 1, -1, # lumbar + -1, 1, -1, # thorax + -1, 1, -1, # head + -1, -1, 1, # scapula-r' + -1, -1, 1, # humerus-r' + 1, # ulna-r' + 1, # radius-r' + 1, 1, # hand-r' + -1, -1, 1, # scapula-l' + -1, -1, 1, # humerus-l' + 1, # ulna-l' + 1, # radius-l' + 1, 1 # hand-l' + ] + flipped_sign = torch.tensor(flipped_sign, dtype=poses.dtype, device=poses.device) # (46,) + flipped_poses = flipped_sign * flipped_poses + + return flipped_poses + + + +# def rotate_orient_around_z(q, rot): +# """ +# Rotate SKEL orientation. +# Args: +# q (np.ndarray): SKEL style rotation representation (3,). +# rot (np.ndarray): Rotation angle in degrees. +# Returns: +# np.ndarray: Rotated axis-angle vector. +# """ +# import torch +# from lib.body_models.skel.osim_rot import CustomJoint +# # q to mat +# root = CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]) # pelvis +# q = torch.from_numpy(q).unsqueeze(0) +# q = q[:, [2, 1, 0]] +# Rp = euler_angles_to_matrix(q, convention="YXZ") +# # rotate around z +# R = torch.Tensor([[np.deg2rad(-rot), 0, 0]]) +# R = axis_angle_to_matrix(R) +# R = torch.matmul(R, Rp) +# # mat to q +# q = matrix_to_euler_angles(R, convention="YXZ") +# q = q[:, [2, 1, 0]] +# q = q.numpy().squeeze() + +# return q.astype(np.float32) + + +def params_q2rot(params_q:Union[torch.Tensor, np.ndarray]): + ''' + Transform parts of the euler-like SKEL parameters representation all to rotation matrix. 
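+
+    Roughly (illustrative only): the (..., 46) q vector is split per joint following JOINTS_DEF,
+    and each joint's `q_to_rot` maps its 1~3 DoF values to a 3x3 matrix, e.g.
+
+        rots = params_q2rot(poses)   # poses: (B, 46) -> per-joint 3x3 rotation matrices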
+ + ### Args + - params_q: Union[torch.Tensor, np.ndarray], shape = (...B, 46) or (...B, 46) + + ### Returns + - shape = (...B, 24, 9) # 24 joints, each joint has a 3x3 matrix, but for some joints, the matrix is not all used. + ''' + # Check the type and unify to torch. + is_np = isinstance(params_q, np.ndarray) + if is_np: + params_q = torch.from_numpy(params_q) + + # Prepare for necessary variables. + Bs = params_q.shape[:-1] + ident = torch.eye(3, dtype=params_q.dtype).to(params_q.device) # (3, 3) + params_rot = ident.repeat(*Bs, N_JOINTS, 1, 1) # (...B, 24, 3, 3) + + # Deal with each joints separately. Modified from the `skel_model.py` but a static version. + sid = 0 + for jid in range(N_JOINTS): + joint_obj = JOINTS_DEF[jid] + eid = sid + joint_obj.nb_dof.item() + params_rot[..., jid, :, :] = joint_obj.q_to_rot(params_q[..., sid:eid]) + sid = eid + + if is_np: + params_rot = params_rot.detach().cpu().numpy() + return params_rot + + +def params_q2rep(params_q:Union[torch.Tensor, np.ndarray]): + ''' + Transform the euler-like SKEL parameters representation to the continuous representation. + This function is not supposed to be used in the training process, but only for debugging. + The function that matters actually is the inverse of this function. + + ### Args + - params_q: Union[torch.Tensor, np.ndarray], shape = (...B, 46) or (...B, 46) + + ### Returns + - shape = (...B, 24, 6) + - Among 24 joints, for 3 DoF ones, all 6 values are used to represent the rotation; + but for 1 DoF joints, only the first 2 are used. The rest will be represented as zeros. + ''' + # Check the type and unify to torch. + is_np = isinstance(params_q, np.ndarray) + if is_np: + params_q = torch.from_numpy(params_q) + + # Prepare for necessary variables. + Bs = params_q.shape[:-1] + params_rep = params_q.new_zeros(*Bs, N_JOINTS, 6) # (...B, 24, 6) + + # Deal with each joints separately. Modified from the `skel_model.py` but a static version. + sid = 0 + for jid in range(N_JOINTS): + joint_obj = JOINTS_DEF[jid] + dof = joint_obj.nb_dof.item() + eid = sid + dof + if dof == 3: + mat = joint_obj.q_to_rot(params_q[..., sid:eid]) # (...B, 3, 3) + params_rep[..., jid, :] = matrix_to_rotation_6d(mat) # (...B, 6) + elif dof == 2: + # mat = joint_obj.q_to_rot(params_q[..., sid:eid]) # (...B, 3, 3) + # params_rep[..., jid, :] = matrix_to_rotation_6d(mat) # (...B, 6) + params_rep[..., jid, :2] = params_q[..., sid:eid] + elif dof == 1: + cos = torch.cos(params_q[..., sid]) + sin = torch.sin(params_q[..., sid]) + params_rep[..., jid, 0] = cos + params_rep[..., jid, 1] = -sin + + sid = eid + + if is_np: + params_rep = params_rep.detach().cpu().numpy() + return params_rep + + +# Deprecated. +def dof3_to_q(rot, axises:List, flip:List): + ''' + Convert a rotation matrix to SKEL style rotation representation. + + ### Args + - rot: torch.Tensor, shape (...B, 3, 3) + - The rotation matrix. + - axises: list + - [[x0, y0, z0], [x1, y1, z1], [x2, y2, z2]] + - The axis defined in the SKEL's joint_dict. Only one of xi, yi, zi is 1, the others are 0. + - flip: list + - [f0, f1, f2] + - The flip value defined in the SKEL's joint_dict. fi is 1 or -1. 
+ + ### Returns + - shape = (...B, 3) + ''' + convention = [axis2convention(axis) for axis in reversed(axises)] # SKEL use euler angle in reverse order + convention = ''.join(convention) + q = matrix_to_euler_angles(rot[..., :, :], convention=convention) # (...B, 3) + q = q[..., [2, 1, 0]] # SKEL use euler angle in reverse order + flip = rot.new_tensor(flip) # (3,) + q = flip * q + return q + + +### Slow version, deprecated. ### +# def params_rep2q(params_rot:Union[torch.Tensor, np.ndarray]): +# ''' +# Transform the continuous representation back to the SKEL style euler-like representation. +# +# ### Args +# - params_rot: Union[torch.Tensor, np.ndarray] +# - shape = (...B, 24, 6) +# +# ### Returns +# - shape = (...B, 46) +# ''' +# +# # Check the type and unify to torch. +# is_np = isinstance(params_rot, np.ndarray) +# if is_np: +# params_rot = torch.from_numpy(params_rot) +# +# # Prepare for necessary variables. +# Bs = params_rot.shape[:-2] +# params_q = params_rot.new_zeros((*Bs, 46)) # (...B, 46) +# +# for jid in range(N_JOINTS): +# joint_obj = JOINTS_DEF[jid] +# dof = joint_obj.nb_dof.item() +# sid, eid = JID2QIDS[jid][0], JID2QIDS[jid][-1] + 1 +# +# if dof == 3: +# mat = rotation_6d_to_matrix(params_rot[..., jid, :]) # (...B, 3, 3) +# params_q[..., sid:eid] = dof3_to_q( +# mat, +# joint_obj.axis.tolist(), +# joint_obj.axis_flip.detach().cpu().tolist(), +# ) +# elif dof == 2: +# params_q[..., sid:eid] = params_rot[..., jid, :2] +# else: +# params_q[..., sid:eid] = rotation_2d_to_angle(params_rot[..., jid, :2]) +# +# if is_np: +# params_q = params_q.detach().cpu().numpy() +# return params_q + +def orient_mat2q(orient_mat:torch.Tensor): + ''' This is a tool function for inspecting only. orient_mat ~ (...B, 3, 3)''' + poses_rep = orient_mat.new_zeros(orient_mat.shape[:-2] + (24, 6)) # (...B, 24, 6) + orient_rep = matrix_to_rotation_6d(orient_mat) # (...B, 6) + poses_rep[..., 0, :] = orient_rep + poses_q = params_rep2q(poses_rep) # (...B, 46) + return poses_q[..., :3] + + +# Pre-grouped the joints for different conventions +CON_GROUP2JIDS = {'YXZ': [0, 1, 6], 'YZX': [11, 12, 13], 'XZY': [14, 19], 'ZYX': [15, 20]} +CON_GROUP2FLIPS = {'YXZ': [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, -1.0, -1.0]], 'YZX': [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 'XZY': [[1.0, -1.0, -1.0], [1.0, 1.0, 1.0]], 'ZYX': [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]} +# Faster version. +def params_rep2q(params_rot:Union[torch.Tensor, np.ndarray]): + ''' + Transform the continuous representation back to the SKEL style euler-like representation. + + ### Args + - params_rot: Union[torch.Tensor, np.ndarray], shape = (...B, 24, 6) + + ### Returns + - shape = (...B, 46) + ''' + + with PM.time_monitor('params_rep2q'): + with PM.time_monitor('preprocess'): + params_rot, recover_type_back = to_tensor(params_rot, device=None, temporary=True) + + # Prepare for necessary variables. 
+ Bs = params_rot.shape[:-2] + params_q = params_rot.new_zeros((*Bs, 46)) # (...B, 46) + + with PM.time_monitor(f'dof1&dof2'): + params_q[..., DoF1_QIDS] = rotation_2d_to_angle(params_rot[..., DoF1_JIDS, :2]).squeeze(-1) + params_q[..., DoF2_QIDS] = params_rot[..., DoF2_JIDS, :2].reshape(*Bs, -1) # (...B, J2=2 * 2) + + with PM.time_monitor(f'dof3'): + dof3_6ds = params_rot[..., DoF3_JIDS, :].reshape(*Bs, len(DoF3_JIDS), 6) # (...B, J3=10, 3, 6) + dof3_mats = rotation_6d_to_matrix(dof3_6ds) # (...B, J3=10, 3, 3) + + for convention, jids in CON_GROUP2JIDS.items(): + idxs = [DoF3_JIDS.index(jid) for jid in jids] + mats = dof3_mats[..., idxs, :, :] # (...B, J', 3, 3) + qs = matrix_to_euler_angles(mats, convention=convention) # (...B, J', 3) + qs = qs[..., [2, 1, 0]] # SKEL use euler angle in reverse order + flips = qs.new_tensor(CON_GROUP2FLIPS[convention]) # (J', 3) + qs = qs * flips # (...B, J', 3) + qids = [qid for jid in jids for qid in JID2QIDS[jid]] + params_q[..., qids] = qs.reshape(*Bs, -1) + + with PM.time_monitor('post_process'): + params_q = recover_type_back(params_q) + return params_q \ No newline at end of file diff --git a/lib/body_models/skel_wrapper.py b/lib/body_models/skel_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..d64065ac2a18f055ee11b1de038ddfce077958fa --- /dev/null +++ b/lib/body_models/skel_wrapper.py @@ -0,0 +1,146 @@ +from lib.kits.basic import * + +import pickle +from smplx.vertex_joint_selector import VertexJointSelector +from smplx.vertex_ids import vertex_ids +from smplx.lbs import vertices2joints + +from lib.body_models.skel.skel_model import SKEL, SKELOutput + +class SKELWrapper(SKEL): + def __init__( + self, + *args, + joint_regressor_custom: Optional[str] = None, + joint_regressor_extra : Optional[str] = None, + update_root : bool = False, + **kwargs + ): + ''' This wrapper aims to extend the output joints of the SKEL model which fits SMPL's portal. ''' + + super(SKELWrapper, self).__init__(*args, **kwargs) + + # The final joints are combined from three parts: + # 1. The joints from the standard output. + # Map selected joints of interests from SKEL to SMPL. (Not all 24 joints will be used finally.) + # Notes: Only these SMPL joints will be used: [0, 1, 2, 4, 5, 7, 8, 12, 16, 17, 18, 19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 45, 46, 47, 48, 49, 50, 51, 52, 53] + skel_to_smpl = [ + 0, + 6, + 1, + 11, # not aligned well; not used + 7, + 2, + 11, # not aligned well; not used + 8, # or 9 + 3, # or 4 + 12, # not aligned well; not used + 10, # not used + 5, # not used + 12, + 19, # not aligned well; not used + 14, # not aligned well; not used + 13, # not used + 20, # or 19 + 15, # or 14 + 21, # or 22 + 16, # or 17, + 23, + 18, + 23, # not aligned well; not used + 18, # not aligned well; not used + ] + + smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34] + + self.register_buffer('J_skel_to_smpl', torch.tensor(skel_to_smpl, dtype=torch.long)) + self.register_buffer('J_smpl_to_openpose', torch.tensor(smpl_to_openpose, dtype=torch.long)) + # (SKEL has the same topology as SMPL as well as SMPL-H, so perform the same operation for the other 2 parts.) + # 2. Joints selected from skin vertices. + self.vertex_joint_selector = VertexJointSelector(vertex_ids['smplh']) + # 3. Extra joints from the J_regressor_extra. 
+ if joint_regressor_extra is not None: + self.register_buffer( + 'J_regressor_extra', + torch.tensor(pickle.load( + open(joint_regressor_extra, 'rb'), + encoding='latin1' + ), dtype=torch.float32) + ) + + self.custom_regress_joints = joint_regressor_custom is not None + if self.custom_regress_joints: + get_logger().info('Using customized joint regressor.') + with open(joint_regressor_custom, 'rb') as f: + J_regressor_custom = pickle.load(f, encoding='latin1') + if 'scipy.sparse' in str(type(J_regressor_custom)): + J_regressor_custom = J_regressor_custom.todense() # (24, 6890) + self.register_buffer( + 'J_regressor_custom', + torch.tensor( + J_regressor_custom, + dtype=torch.float32 + ) + ) + + self.update_root = update_root + + def forward(self, **kwargs) -> SKELOutput: # type: ignore + ''' Map the order of joints of SKEL to SMPL's. ''' + + if 'trans' not in kwargs.keys(): + kwargs['trans'] = kwargs['poses'].new_zeros((kwargs['poses'].shape[0], 3)) # (B, 3) + + skel_output = super(SKELWrapper, self).forward(**kwargs) + verts = skel_output.skin_verts # (B, 6890, 3) + joints = skel_output.joints.clone() # (B, 24, 3) + + # Update the root joint position (to avoid the root too forward). + if self.update_root: + # make root 0 to plane 11-1-6 + hips_middle = (joints[:, 1] + joints[:, 6]) / 2 # (B, 3) + lumbar2middle = (hips_middle - joints[:, 11]) # (B, 3) + lumbar2middle_unit = lumbar2middle / torch.norm(lumbar2middle, dim=1, keepdim=True) # (B, 3) + lumbar2root = joints[:, 0] - joints[:, 11] + lumbar2root_proj = \ + torch.einsum('bc,bc->b', lumbar2root, lumbar2middle_unit)[:, None] *\ + lumbar2middle_unit # (B, 3) + root2root_proj = lumbar2root_proj - lumbar2root # (B, 3) + joints[:, 0] += root2root_proj * 0.7 + + # Combine the joints from three parts: + if self.custom_regress_joints: + # 1.x. Regress joints from the skin vertices using SMPL's regressor. + joints = vertices2joints(self.J_regressor_custom, verts) # (B, 24, 3) + else: + # 1.y. Map selected joints of interests from SKEL to SMPL. + joints = joints[:, self.J_skel_to_smpl] # (B, 24, 3) + joints_custom = joints.clone() + # 2. Concat joints selected from skin vertices. + joints = self.vertex_joint_selector(verts, joints) # (B, 45, 3) + # 3. Map selected joints to OpenPose. + joints = joints[:, self.J_smpl_to_openpose] # (B, 25, 3) + # 4. Add extra joints from the J_regressor_extra. + joints_extra = vertices2joints(self.J_regressor_extra, verts) # (B, 19, 3) + joints = torch.cat([joints, joints_extra], dim=1) # (B, 44, 3) + + # Update the joints in the output. + skel_output.joints_backup = skel_output.joints + skel_output.joints_custom = joints_custom + skel_output.joints = joints + + return skel_output + + + @staticmethod + def get_static_root_offset(skel_output): + ''' + Background: + By default, the orientation rotation is always around the original skel_root. + In order to make the orientation rotation around the custom_root, we need to calculate the translation offset. + This function calculates the translation offset in static pose. (From custom_root to skel_root.) 
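+
+        A typical call site might look like this (illustrative sketch; `skel` is a SKELWrapper
+        instance and `poses` / `betas` are zero tensors describing the static pose):
+
+            static_out = skel(poses=poses, betas=betas)
+            offset = SKELWrapper.get_static_root_offset(static_out)   # (B, 3)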
+ ''' + custom_root = skel_output.joints_custom[:, 0] # (B, 3) + skel_root = skel_output.joints_backup[:, 0] # (B, 3) + offset = skel_root - custom_root # (B, 3) + return offset \ No newline at end of file diff --git a/lib/body_models/smpl_utils/reality.py b/lib/body_models/smpl_utils/reality.py new file mode 100644 index 0000000000000000000000000000000000000000..f195f19fb48f2ff8ef0f4468cb87e68c75a769d3 --- /dev/null +++ b/lib/body_models/smpl_utils/reality.py @@ -0,0 +1,125 @@ +from lib.kits.basic import * + +from lib.utils.geometry.rotation import axis_angle_to_matrix + + +def get_lim_cfg(tol_deg=5): + tol_limit = np.deg2rad(tol_deg) + lim_cfg = { + 'l_knee': { + 'jid': 4, + 'convention': 'XZY', + 'limitation': [ + [-tol_limit, 3/4*np.pi+tol_limit], + [-tol_limit, tol_limit], + [-tol_limit, tol_limit], + ] + }, + 'r_knee': { + 'jid': 5, + 'convention': 'XZY', + 'limitation': [ + [-tol_limit, 3/4*np.pi+tol_limit], + [-tol_limit, tol_limit], + [-tol_limit, tol_limit], + ] + }, + 'l_elbow': { + 'jid': 18, + 'convention': 'YZX', + 'limitation': [ + [-(3/4)*np.pi-tol_limit, tol_limit], + [-tol_limit, tol_limit], + [-3/4*np.pi/2-tol_limit, 3/4*np.pi/2+tol_limit], + ] + }, + 'r_elbow': { + 'jid': 19, + 'convention': 'YZX', + 'limitation': [ + [-tol_limit, (3/4)*np.pi+tol_limit], + [-tol_limit, tol_limit], + [-3/4*np.pi/2-tol_limit, 3/4*np.pi/2+tol_limit], + ] + }, + } + return lim_cfg + + +def matrix_to_possible_euler_angles(matrix: torch.Tensor, convention: str): + ''' + Convert rotations given as rotation matrices to Euler angles in radians. + + ### Args + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + ### Returns + List of possible euler angles in radians as tensor of shape (..., 3). 
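+
+    Illustrative sketch, using `axis_angle_to_matrix` imported at the top of this file: a
+    rotation of pi/2 about X decomposed under the 'XZY' convention returns a list with two
+    candidate decompositions:
+
+        mats = axis_angle_to_matrix(torch.tensor([[np.pi / 2, 0.0, 0.0]]))
+        solutions = matrix_to_possible_euler_angles(mats, 'XZY')   # list of two (1, 3) tensors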
+ ''' + from lib.utils.geometry.rotation import _index_from_letter, _angle_from_tan + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + central_angle_possible = [] + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + central_angle_possible = [central_angle, np.pi - central_angle] + else: + central_angle = torch.acos(matrix[..., i0, i0]) + central_angle_possible = [central_angle, -central_angle] + + o_possible = [] + for central_angle in central_angle_possible: + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + o_possible.append(torch.stack(o, -1)) + return o_possible + + +def eval_rot_delta(body_pose, tol_deg=5): + lim_cfg = get_lim_cfg(tol_deg) + res ={} + for name, cfg in lim_cfg.items(): + jid = cfg['jid'] - 1 + cvt = cfg['convention'] + lim = cfg['limitation'] + aa = body_pose[:, jid, :] # (B, 3) + mt = axis_angle_to_matrix(aa) # (B, 3, 3) + ea_possible = matrix_to_possible_euler_angles(mt, cvt) # (B, 3) + violation_reasonable = None + for ea in ea_possible: + violation = ea.new_zeros(ea.shape) # (B, 3) + + for i in range(3): + ea_i = ea[:, i] + ea_i = (ea_i + np.pi) % (2 * np.pi) - np.pi # Normalize to (-pi, pi) + exceed_lb = torch.where(ea_i < lim[i][0], ea_i - lim[i][0], 0) + exceed_ub = torch.where(ea_i > lim[i][1], ea_i - lim[i][1], 0) + violation[:, i] = exceed_lb.abs() + exceed_ub.abs() # (B, 3) + if violation_reasonable is not None: # minimize the violation + upd_mask = violation.sum(-1) < violation_reasonable.sum(-1) + violation_reasonable[upd_mask] = violation[upd_mask] + else: + violation_reasonable = violation + + res[name] = violation_reasonable + return res \ No newline at end of file diff --git a/lib/body_models/smpl_utils/transforms.py b/lib/body_models/smpl_utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..e998be7c38a960ba4db3d31923006022f8baaced --- /dev/null +++ b/lib/body_models/smpl_utils/transforms.py @@ -0,0 +1,33 @@ +import torch +import numpy as np + +from typing import Dict + + +def fliplr_params(smpl_params: Dict): + global_orient = smpl_params['global_orient'].copy().reshape(-1, 3) + body_pose = smpl_params['body_pose'].copy().reshape(-1, 69) + betas = smpl_params['betas'].copy() + + body_pose_permutation = [6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, + 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, + 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41, + 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55, + 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68] + body_pose_permutation = body_pose_permutation[:body_pose.shape[1]] + body_pose_permutation = [i-3 for i in body_pose_permutation] + + body_pose = body_pose[:, body_pose_permutation] + + global_orient[:, 1::3] *= -1 + global_orient[:, 2::3] *= -1 + body_pose[:, 1::3] *= -1 + body_pose[:, 2::3] *= -1 + + smpl_params = 
{'global_orient': global_orient.reshape(-1, 1, 3).astype(np.float32), + 'body_pose': body_pose.reshape(-1, 23, 3).astype(np.float32), + 'betas': betas.astype(np.float32) + } + + return smpl_params + diff --git a/lib/body_models/smpl_wrapper.py b/lib/body_models/smpl_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..afdc0b0c6abc9021bc699aa953824da286330f58 --- /dev/null +++ b/lib/body_models/smpl_wrapper.py @@ -0,0 +1,43 @@ +import torch +import numpy as np +import pickle +from typing import Optional +import smplx +from smplx.lbs import vertices2joints +from smplx.utils import SMPLOutput + + +class SMPLWrapper(smplx.SMPLLayer): + def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs): + """ + Extension of the official SMPL implementation to support more joints. + Args: + Same as SMPLLayer. + joint_regressor_extra (str): Path to extra joint regressor. + """ + super(SMPLWrapper, self).__init__(*args, **kwargs) + smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, + 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34] + + if joint_regressor_extra is not None: + self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32)) + self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long)) + self.update_hips = update_hips + + def forward(self, *args, **kwargs) -> SMPLOutput: + """ + Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified. + """ + smpl_output = super(SMPLWrapper, self).forward(*args, **kwargs) + joints_smpl = smpl_output.joints.clone() + joints = smpl_output.joints[:, self.joint_map, :] + if self.update_hips: + joints[:,[9,12]] = joints[:,[9,12]] + \ + 0.25*(joints[:,[9,12]]-joints[:,[12,9]]) + \ + 0.5*(joints[:,[8]] - 0.5*(joints[:,[9,12]] + joints[:,[12,9]])) + if hasattr(self, 'joint_regressor_extra'): + extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices) + joints = torch.cat([joints, extra_joints], dim=1) + smpl_output.joints = joints + smpl_output.joints_smpl = joints_smpl + return smpl_output diff --git a/lib/data/augmentation/skel.py b/lib/data/augmentation/skel.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e0567a6e76a07104b87718816994377b0c3a29 --- /dev/null +++ b/lib/data/augmentation/skel.py @@ -0,0 +1,90 @@ +from lib.kits.basic import * + +from lib.utils.data import * +from lib.utils.geometry.rotation import ( + euler_angles_to_matrix, + axis_angle_to_matrix, + matrix_to_euler_angles, +) +from lib.body_models.skel_utils.transforms import params_q2rot + + +def rot_q_orient( + q : Union[np.ndarray, torch.Tensor], + rot_rad : Union[np.ndarray, torch.Tensor], +): + ''' + ### Args + - q: np.ndarray or tensor, shape = (B, 3) + - SKEL style rotation representation. + - rot: np.ndarray or tensor, shape = (B,) + - Rotation angle in radian. + ### Returns + - np.ndarray: Rotated orientation SKEL q. + ''' + # Transform skel q to rot mat. + q, recover_type_back = to_tensor(q, device=None, temporary=True) # (B, 3) + q = q[:, [2, 1, 0]] + Rp = euler_angles_to_matrix(q, convention="YXZ") + # Rotate around z + rot = to_tensor(-rot_rad, device=q.device).float() # (B,) + padding_zeros = torch.zeros_like(rot) # (B,) + R = torch.stack([rot, padding_zeros, padding_zeros], dim=1) # (B, 3) + R = axis_angle_to_matrix(R) + R = torch.matmul(R, Rp) + # Transform rot mat to skel q. 
+ q = matrix_to_euler_angles(R, convention="YXZ") + q = q[:, [2, 1, 0]] + q = recover_type_back(q) # (B, 3) + + return q + + +def rot_skel_on_plane( + params : Union[Dict, np.ndarray, torch.Tensor], + rot_deg : Union[np.ndarray, torch.Tensor, List[float]] +): + ''' + Rotate the skel parameters on the plane (around the z-axis), + in order to align the skel with the rotated image. To perform + this operation, we need to modify the orientation of the skel + parameters. + + ### Args + - params: Dict or (np.ndarray or torch.Tensor) + - If is dict, it should contain the following keys + - 'poses': np.ndarray or torch.Tensor (B, 72) + - ... + - If is np.ndarray or torch.Tensor, it should be the 'poses' part. + - rot_deg: np.ndarray, torch.Tensor or List[float] + - Rotation angle in degrees. + + ### Returns + - One of the following according to the input type: + - Dict: Modified skel parameters. + - np.ndarray or torch.Tensor: Modified skel poses parameters. + ''' + rot_deg = to_numpy(rot_deg) # (B,) + rot_rad = np.deg2rad(rot_deg) # (B,) + + if isinstance(params, Dict): + ret = {} + for k, v in params.items(): + if isinstance(v, np.ndarray): + ret[k] = v.copy() + elif isinstance(v, torch.Tensor): + ret[k] = v.clone() + else: + ret[k] = v + ret['poses'][:, :3] = rot_q_orient(ret['poses'][:, :3], rot_rad) + elif isinstance(params, (np.ndarray, torch.Tensor)): + if isinstance(params, np.ndarray): + ret = params.copy() + elif isinstance(params, torch.Tensor): + ret = params.clone() + else: + raise TypeError(f'Unsupported type: {type(params)}') + ret[:, :3] = rot_q_orient(ret[:, :3], rot_rad) + else: + raise TypeError(f'Unsupported type: {type(params)}') + return ret \ No newline at end of file diff --git a/lib/data/datasets/hsmr_eval_3d/eval3d_dataset.py b/lib/data/datasets/hsmr_eval_3d/eval3d_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fac7c53eee1b6e98fdfb8b69c0bce6f8e67b4f30 --- /dev/null +++ b/lib/data/datasets/hsmr_eval_3d/eval3d_dataset.py @@ -0,0 +1,88 @@ +from lib.kits.basic import * + +from torch.utils import data +from lib.utils.media import load_img, flex_resize_img +from lib.utils.bbox import cwh_to_cs, cs_to_lurb, crop_with_lurb, fit_bbox_to_aspect_ratio +from lib.utils.data import to_numpy + +IMG_MEAN_255 = to_numpy([0.485, 0.456, 0.406]) * 255. +IMG_STD_255 = to_numpy([0.229, 0.224, 0.225]) * 255. + + +class Eval3DDataset(data.Dataset): + + def __init__(self, npz_fn:Union[str, Path], ignore_img=False): + super().__init__() + self.data = None + self._load_data(npz_fn) + self.ds_root = self._get_ds_root() + self.bbox_ratio = (192, 256) # the ViT Backbone's input size is w=192, h=256 + self.ignore_img = ignore_img # For speed up, if True, don't process images. + + def _load_data(self, npz_fn:Union[str, Path]): + supported_datasets = ['MoYo'] + raw_data = np.load(npz_fn, allow_pickle=True) + + # Load some meta data. + self.extra_info = raw_data['extra_info'].item() + self.ds_name = self.extra_info.pop('dataset_name') + # Load basic information. + self.seq_names = raw_data['names'] # (L,) + self.img_paths = raw_data['img_paths'] # (L,) + self.bbox_centers = raw_data['centers'].astype(np.float32) # (L, 2) + self.bbox_scales = raw_data['scales'].astype(np.float32) # (L, 2) + self.L = len(self.seq_names) + # Load the g.t. SMPL parameters. 
+ self.genders = raw_data.get('genders', None) # (L, 2) or None + self.global_orient = raw_data['smpl'].item()['global_orient'].reshape(-1, 1 ,3).astype(np.float32) # (L, 1, 3) + self.body_pose = raw_data['smpl'].item()['body_pose'].reshape(-1, 23 ,3).astype(np.float32) # (L, 23, 3) + self.betas = raw_data['smpl'].item()['betas'].reshape(-1, 10).astype(np.float32) # (L, 10) + # Check validity. + assert self.ds_name in supported_datasets, f'Unsupported dataset: {self.ds_name}' + + + def __len__(self): + return self.L + + + def _process_img_patch(self, idx): + ''' Load and crop according to bbox. ''' + if self.ignore_img: + return np.zeros((1), dtype=np.float32) + + img, _ = load_img(self.ds_root / self.img_paths[idx]) # (H, W, RGB) + scale = self.bbox_scales[idx] # (2,) + center = self.bbox_centers[idx] # (2,) + bbox_cwh = np.concatenate([center, scale], axis=0) # (4,) lurb format + bbox_cwh = fit_bbox_to_aspect_ratio( + bbox = bbox_cwh, + tgt_ratio = self.bbox_ratio, + bbox_type = 'cwh' + ) + bbox_cs = cwh_to_cs(bbox_cwh, reduce='max') # (3,), make it to square + bbox_lurb = cs_to_lurb(bbox_cs) # (4,) + img_patch = crop_with_lurb(img, bbox_lurb) # (H', W', RGB) + img_patch = flex_resize_img(img_patch, tgt_wh=(256, 256)) + img_patch_normalized = (img_patch - IMG_MEAN_255) / IMG_STD_255 # (H', W', RGB) + img_patch_normalized = img_patch_normalized.transpose(2, 0, 1) # (RGB, H', W') + return img_patch_normalized.astype(np.float32) + + + def _get_ds_root(self): + return PM.inputs / 'datasets' / self.ds_name.lower() + + + def __getitem__(self, idx): + ret = {} + ret['seq_name'] = self.seq_names[idx] + ret['smpl'] = { + 'global_orient': self.global_orient[idx], + 'body_pose' : self.body_pose[idx], + 'betas' : self.betas[idx], + } + if self.genders is not None: + ret['gender'] = self.genders[idx] + ret['img_patch'] = self._process_img_patch(idx) + + return ret + diff --git a/lib/data/datasets/hsmr_v1/crop.py b/lib/data/datasets/hsmr_v1/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3b8493f0f2aa92f34fcc9e9a93265f73a60869 --- /dev/null +++ b/lib/data/datasets/hsmr_v1/crop.py @@ -0,0 +1,295 @@ +from lib.kits.basic import * + +# Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L663-L944 + +def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple: + """ + Extreme cropping: Crop the box up to the hip locations. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. 
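+    Example (illustrative only, assuming a (44, 3) keypoint array `kp2d` in the
+    OpenPose + extra-joints order used by this codebase):
+        center_x, center_y, width, height = crop_to_hips(128., 128., 256., 256., kp2d)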
+ """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5] + keypoints_2d[lower_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + + +def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box up to the shoulder locations. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]] + keypoints_2d[lower_body_keypoints, :] = 0 + center, scale = get_bbox(keypoints_2d) + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.2 * scale[0] + height = 1.2 * scale[1] + return center_x, center_y, width, height + + +def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the head. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]] + keypoints_2d[lower_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.3 * scale[0] + height = 1.3 * scale[1] + return center_x, center_y, width, height + + +def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the torso. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. 
+ """ + keypoints_2d = keypoints_2d.copy() + nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]] + keypoints_2d[nontorso_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + + +def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the right arm. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonrightarm_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + + +def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the left arm. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonleftarm_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + + +def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the legs. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. 
+ height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]] + keypoints_2d[nonlegs_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + + +def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the right leg. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonrightleg_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the left leg. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonleftleg_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + + +def full_body(keypoints_2d: np.ndarray) -> bool: + """ + Check if all main body joints are visible. + Args: + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + bool: True if all main body joints are visible. + """ + + body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14] + body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]] + return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints) + + +def upper_body(keypoints_2d: np.ndarray): + """ + Check if all upper body joints are visible. 
+ Args: + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + bool: True if all main body joints are visible. + """ + lower_body_keypoints_openpose = [10, 11, 13, 14] + lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]] + upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18] + upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18] + return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\ + and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2) + + +def get_bbox(keypoints_2d: np.ndarray, rescale: float = 1.2) -> Tuple: + """ + Get center and scale for bounding box from openpose detections. + Args: + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + rescale (float): Scale factor to rescale bounding boxes computed from the keypoints. + Returns: + center (np.ndarray): Array of shape (2,) containing the new bounding box center. + scale (float): New bounding box scale. + """ + valid = keypoints_2d[:,-1] > 0 + valid_keypoints = keypoints_2d[valid][:,:-1] + center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0)) + bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)) + # adjust bounding box tightness + scale = bbox_size + scale *= rescale + return center, scale diff --git a/lib/data/datasets/hsmr_v1/mocap_dataset.py b/lib/data/datasets/hsmr_v1/mocap_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ca37c47ac7918d943aa5840d93499d8ac6a23058 --- /dev/null +++ b/lib/data/datasets/hsmr_v1/mocap_dataset.py @@ -0,0 +1,32 @@ +from lib.kits.basic import * + + +class MoCapDataset: + def __init__(self, dataset_file:str, pve_threshold:Optional[float]=None): + ''' + Dataset class used for loading a dataset of unpaired SMPL parameter annotations + ### Args + - dataset_file: str + - Path to the dataset npz file. + - pve_threshold: float + - Threshold for PVE quality filtering. 
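+            - Samples whose `pve_max` exceeds this threshold are dropped; `None` keeps all samples.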
+        '''
+        data = np.load(dataset_file)
+        if pve_threshold is not None:
+            pve = data['pve_max']
+            mask = pve < pve_threshold
+        else:
+            mask = np.ones(len(data['poses']), dtype=bool)
+        self.poses = data['poses'].astype(np.float32)[mask, 3:]
+        self.betas = data['betas'].astype(np.float32)[mask]
+        self.length = len(self.poses)
+        get_logger().info(f'Loaded {self.length} items among {len(mask)} samples filtered from {dataset_file} (using threshold = {pve_threshold})')
+
+    def __getitem__(self, idx: int) -> Dict:
+        poses = self.poses[idx].copy()
+        betas = self.betas[idx].copy()
+        item = {'poses_body': poses, 'betas': betas}
+        return item
+
+    def __len__(self) -> int:
+        return self.length
diff --git a/lib/data/datasets/hsmr_v1/stream_pipelines.py b/lib/data/datasets/hsmr_v1/stream_pipelines.py
new file mode 100644
index 0000000000000000000000000000000000000000..6816b23192463cb3bdecf51d69a367018c5acdc2
--- /dev/null
+++ b/lib/data/datasets/hsmr_v1/stream_pipelines.py
@@ -0,0 +1,342 @@
+from lib.kits.basic import *
+
+import math
+import webdataset as wds
+
+
+from .utils import (
+    get_augm_args,
+    expand_to_aspect_ratio,
+    generate_image_patch_cv2,
+    flip_lr_keypoints,
+    extreme_cropping_aggressive,
+)
+
+
+def apply_corrupt_filter(dataset:wds.WebDataset):
+    AIC_TRAIN_CORRUPT_KEYS = {
+        '0a047f0124ae48f8eee15a9506ce1449ee1ba669', '1a703aa174450c02fbc9cfbf578a5435ef403689',
+        '0394e6dc4df78042929b891dbc24f0fd7ffb6b6d', '5c032b9626e410441544c7669123ecc4ae077058',
+        'ca018a7b4c5f53494006ebeeff9b4c0917a55f07', '4a77adb695bef75a5d34c04d589baf646fe2ba35',
+        'a0689017b1065c664daef4ae2d14ea03d543217e', '39596a45cbd21bed4a5f9c2342505532f8ec5cbb',
+        '3d33283b40610d87db660b62982f797d50a7366b',
+    }
+
+    CORRUPT_KEYS = {
+        *{f'aic-train/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
+        *{f'aic-train-vitpose/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
+    }
+
+    dataset = dataset.select(lambda sample: (sample['__key__'] not in CORRUPT_KEYS))
+    return dataset
+
+
+def apply_multi_ppl_splitter(dataset:wds.WebDataset):
+    '''
+    Each item in the raw dataset contains multiple people, so we need to split them into individual samples.
+    Meanwhile, we also need to note down the person id (pid) for each sample.
+    '''
+    def multi_ppl_splitter(source):
+        for item in source:
+            data_multi_ppl = item['data.pyd']  # list of data for multiple people
+            for pid, data in enumerate(data_multi_ppl):
+                data['pid'] = pid
+
+                if 'detection.npz' in item:
+                    det_idx = data['extra_info']['detection_npz_idx']
+                    mask = item['detection.npz']['masks'][det_idx]
+                else:
+                    mask = np.ones_like(item['jpg'][:, :, 0], dtype=bool)
+                yield {
+                    '__key__'  : item['__key__'] + f'_{pid}',
+                    'img_name' : item['__key__'],
+                    'img'      : item['jpg'],
+                    'data'     : data,
+                    'mask'     : mask,
+                }
+
+    return dataset.compose(multi_ppl_splitter)
+
+
+def apply_keys_adapter(dataset:wds.WebDataset):
+    ''' Adapt the keys of the items so that we can handle different versions of the dataset. '''
+    def keys_adapter(item):
+        data = item['data']
+        data['kp2d'] = data.pop('keypoints_2d')
+        data['kp3d'] = data.pop('keypoints_3d')
+        return item
+    return dataset.map(keys_adapter)
+
+
+def apply_bad_pgt_params_nan_suppressor(dataset:wds.WebDataset):
+    ''' If the poses or betas contain NaN, we regard them as bad pseudo-GT and zero them out.
''' + def bad_pgt_params_suppressor(item): + for side in ['orig', 'flip']: + poses = item['data'][f'{side}_poses'] # (J, 3) + betas = item['data'][f'{side}_betas'] # (10,) + poses_has_nan = np.isnan(poses).any() + betas_has_nan = np.isnan(betas).any() + if poses_has_nan or betas_has_nan: + item['data'][f'{side}_has_poses'] = False + item['data'][f'{side}_has_betas'] = False + if poses_has_nan: + item['data'][f'{side}_poses'][:] = 0 + if betas_has_nan: + item['data'][f'{side}_betas'][:] = 0 + return item + dataset = dataset.map(bad_pgt_params_suppressor) + return dataset + + +def apply_bad_pgt_params_kp2d_err_suppressor(dataset:wds.WebDataset, thresh:float=0.1): + ''' If the 2D keypoints error of one single person is higher than the threshold, we regard it as bad pseudo-GT. ''' + if thresh > 0: + def bad_pgt_params_suppressor(item): + for side in ['orig', 'flip']: + if thresh > 0: + kp2d_err = item['data'][f'{side}_kp2d_err'] + is_valid_pgt = kp2d_err < thresh + item['data'][f'{side}_has_poses'] = is_valid_pgt + item['data'][f'{side}_has_betas'] = is_valid_pgt + return item + dataset = dataset.map(bad_pgt_params_suppressor) + return dataset + + +def apply_bad_pgt_params_pve_max_suppressor(dataset:wds.WebDataset, thresh:float=0.1): + ''' If the PVE-Max of one single person is higher than the threshold, we regard it as bad pseudo-GT. ''' + if thresh > 0: + def bad_pgt_params_suppressor(item): + for side in ['orig', 'flip']: + if thresh > 0: + pve_max = item['data'][f'{side}_pve_max'] + is_valid_pose = not math.isnan(pve_max) + is_valid_pgt = pve_max < thresh and is_valid_pose + item['data'][f'{side}_has_poses'] = is_valid_pgt + item['data'][f'{side}_has_betas'] = is_valid_pgt + return item + dataset = dataset.map(bad_pgt_params_suppressor) + return dataset + + +def apply_bad_kp_suppressor(dataset:wds.WebDataset, thresh:float=0.0): + ''' If the confidence of a keypoint is lower than the threshold, we reset it to 0. ''' + eps = 1e-6 + if thresh > eps: + def bad_kp_suppressor(item): + if thresh > 0: + kp2d = item['data']['kp2d'] + kp2d_conf = np.where(kp2d[:, 2] < thresh, 0.0, kp2d[:, 2]) # suppress bad keypoints + item['data']['kp2d'] = np.concatenate([kp2d[:, :2], kp2d_conf[:, None]], axis=1) + return item + dataset = dataset.map(bad_kp_suppressor) + return dataset + + +def apply_bad_betas_suppressor(dataset:wds.WebDataset, thresh:float=3): + ''' If the absolute value of betas is higher than the threshold, we regard it as bad betas. ''' + eps = 1e-6 + if thresh > eps: + def bad_betas_suppressor(item): + for side in ['orig', 'flip']: + has_betas = item['data'][f'{side}_has_betas'] # use this condition to save time + if thresh > 0 and has_betas: + betas_abs = np.abs(item['data'][f'{side}_betas']) + if (betas_abs > thresh).any(): + item['data'][f'{side}_has_betas'] = False + return item + dataset = dataset.map(bad_betas_suppressor) + return dataset + + +def apply_params_synchronizer(dataset:wds.WebDataset, poses_betas_simultaneous:bool=False): + ''' Only when both poses and betas are valid, we regard them as valid. 
''' + if poses_betas_simultaneous: + def params_synchronizer(item): + for side in ['orig', 'flip']: + has_betas = item['data'][f'{side}_has_betas'] + has_poses = item['data'][f'{side}_has_poses'] + has_both = np.array(float((has_poses > 0) and (has_betas > 0))) + item['data'][f'{side}_has_betas'] = has_both + item['data'][f'{side}_has_poses'] = has_both + return item + dataset = dataset.map(params_synchronizer) + return dataset + + +def apply_insuff_kp_filter(dataset:wds.WebDataset, cnt_thresh:int=4, conf_thresh:float=0.0): + ''' + Counting the number of keypoints with confidence higher than the threshold. + If the number is less than the threshold, we regard it has insufficient valid 2D keypoints. + ''' + if cnt_thresh > 0: + def insuff_kp_filter(item): + kp_conf = item['data']['kp2d'][:, 2] + return (kp_conf > conf_thresh).sum() > cnt_thresh + dataset = dataset.select(insuff_kp_filter) + return dataset + + +def apply_bbox_size_filter(dataset:wds.WebDataset, bbox_size_thresh:Optional[float]=None): + if bbox_size_thresh: + def bbox_size_filter(item): + bbox_size = item['data']['scale'] * 200 + return bbox_size.min() > bbox_size_thresh # ensure the lower bound is large enough + dataset = dataset.select(bbox_size_filter) + return dataset + + +def apply_reproj_err_filter(dataset:wds.WebDataset, thresh:float=0.0): + ''' If the re-projection error is higher than the threshold, we regard it as bad sample. ''' + if thresh > 0: + def reproj_err_filter(item): + losses = item['data'].get('extra_info', {}).get('fitting_loss', np.array({})).item() + reproj_loss = losses.get('reprojection_loss', None) + return reproj_loss is None or reproj_loss < thresh + dataset = dataset.select(reproj_err_filter) + return dataset + + +def apply_invalid_betas_regularizer(dataset:wds.WebDataset, reg_betas:bool=False): + ''' For those items with invalid betas, set them to zero. ''' + if reg_betas: + def betas_regularizer(item): + # Always have betas set to zero, and all valid. + for side in ['orig', 'flip']: + has_betas = item['data'][f'{side}_has_betas'] + betas = item['data'][f'{side}_betas'] + + if not (has_betas > 0): + item['data'][f'{side}_has_betas'] = np.array(float((True))) + item['data'][f'{side}_betas'] = betas * 0 + return item + dataset = dataset.map(betas_regularizer) + return dataset + + +def apply_example_formatter(dataset:wds.WebDataset, cfg:DictConfig): + ''' Format the item to the wanted format. ''' + + def get_fmt_data(raw_item:Dict, augm_args:Dict, cfg:DictConfig): + ''' + On the one hand, we will perform the augmentation to the image, on the other hand, we need to + crop the image to the patch according to the bbox. Both processes would influence the position + of related keypoints. + After that, we need to align the 2D & 3D keypoints to the augmented image. + ''' + # 1. Prepare the raw data that will be used in the following steps. + img_rgb = raw_item['img'] # (H, W, 3) + img_a = raw_item['mask'].astype(np.uint8)[:, :, None] * 255 # (H, W, 1) mask is 0/1 valued + img_rgba = np.concatenate([img_rgb, img_a], axis=2) # (H, W, 4) + H, W, C = img_rgb.shape + cx, cy = raw_item['data']['center'] + # bbox_size = (raw_item['data']['scale'] * 200).max() + bbox_size = expand_to_aspect_ratio( + raw_item['data']['scale'] * 200, + target_aspect_ratio = cfg.policy.bbox_shape, + ).max() + + kp2d_with_conf = raw_item['data']['kp2d'].astype('float32') # (J, 3) + kp3d_with_conf = raw_item['data']['kp3d'].astype('float32') # (J, 4) + + # 2. [img][Augmentation] Extreme cropping according to the 2D keypoints. 
+ if augm_args['do_extreme_crop']: + cx_, cy_, bbox_size_ = extreme_cropping_aggressive(cx, cy, bbox_size, bbox_size, kp2d_with_conf) + # Only apply the crop if the results is large enough. + THRESH = 4 + if bbox_size_ > THRESH: + cx, cy = cx_, cy_ + bbox_size = bbox_size_ + + # 3. [img][Augmentation] Shift the center of the image. + cx += augm_args['tx_ratio'] * bbox_size + cy += augm_args['ty_ratio'] * bbox_size + + # 4. [img][Format] Crop the image to the patch. + img_patch_cv2, transform_2d = generate_image_patch_cv2( + img = img_rgba, + c_x = cx, + c_y = cy, + bb_width = bbox_size, + bb_height = bbox_size, + patch_width = cfg.policy.img_patch_size, + patch_height = cfg.policy.img_patch_size, + do_flip = augm_args['do_flip'], + scale = augm_args['bbox_scale'], + rot = augm_args['rot_deg'], + ) # (H, W, 4), (2, 3) + + img_patch_hwc = img_patch_cv2.copy()[:, :, :3] # (H, W, C) + img_patch_chw = img_patch_hwc.transpose(2, 0, 1).astype(np.float32) + + # 5. [img][Augmentation] Scale the color + for cid in range(min(C, 3)): + img_patch_chw[cid] = np.clip( + a = img_patch_chw[cid] * augm_args['color_scale'][cid], + a_min = 0, + a_max = 255, + ) + + # 6. [img][Format] Normalize the color. + img_mean = [255. * x for x in cfg.policy.img_mean] + img_std = [255. * x for x in cfg.policy.img_std] + for cid in range(min(C, 3)): + img_patch_chw[cid] = (img_patch_chw[cid] - img_mean[cid]) / img_std[cid] + + # 7. [kp2d][Alignment] Align the 2D keypoints. + # 7.1. Flip. + if augm_args['do_flip']: + kp2d_with_conf = flip_lr_keypoints(kp2d_with_conf, W) + # 7.2. Others. Transform the 2D keypoints according to the same transformation of image. + J = len(kp2d_with_conf) + kp2d_homo = np.concatenate([kp2d_with_conf[:, :2], np.ones((J, 1))], axis=1) # (J, 3) + kp2d = np.einsum('ph, jh -> jp', transform_2d, kp2d_homo) # (J, 2) + kp2d_with_conf[:, :2] = kp2d # (J, 3) + + # 8. [kp2d][Format] Normalize the 2D keypoints position to [-0.5, 0.5]-visible space. + kp2d_with_conf[:, :2] = kp2d_with_conf[:, :2] / cfg.policy.img_patch_size - 0.5 + + # 9. [kp3d][Alignment] Align the 3D keypoints. + # 9.1. Flip. + if augm_args['do_flip']: + kp3d_with_conf = flip_lr_keypoints(kp3d_with_conf, W) + # 9.2. In-plane rotation. + rot_mat = np.eye(3) + # TODO: maybe this part can be packed to a single function. + if not augm_args['rot_deg'] == 0: + rot_rad = -augm_args['rot_deg'] * np.pi / 180 + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0, :2] = [cs, -sn] + rot_mat[1, :2] = [sn, cs] + kp3d_with_conf[:, :3] = np.einsum('ij, kj -> ki', rot_mat, kp3d_with_conf[:, :3]) + + return img_patch_chw, kp2d_with_conf, kp3d_with_conf + + def example_formatter(raw_item): + raw_data = raw_item['data'] + augm_args = get_augm_args(cfg.image_augmentation) + + params_side = 'flip' if augm_args['do_flip'] else 'orig' + img_patch_chw, kp2d, kp3d = get_fmt_data(raw_item, augm_args, cfg) + + fmt_item = {} + fmt_item['pid'] = raw_item['data']['pid'] + fmt_item['img_name'] = raw_item['img_name'] + fmt_item['img_patch'] = img_patch_chw + fmt_item['kp2d'] = kp2d + fmt_item['kp3d'] = kp3d + fmt_item['augm_args'] = augm_args + fmt_item['raw_skel_params'] = { + 'poses': raw_data[f'{params_side}_poses'], + 'betas': raw_data[f'{params_side}_betas'], + } + fmt_item['has_skel_params'] = { + 'poses': raw_data[f'{params_side}_has_poses'], + 'betas': raw_data[f'{params_side}_has_betas'], + } + fmt_item['updated_by_spin'] = False # Only data updated by spin process will be marked as True. 
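+        # At this point `img_patch` is a (3, patch_size, patch_size) float32 array (color-scaled and
+        # mean/std normalized), `kp2d` is normalized to the [-0.5, 0.5] patch space with confidence, `kp3d`
+        # keeps its confidence column, and the SKEL pseudo-GT comes from the side ('orig' or 'flip') that
+        # matches the flip augmentation.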
+ + return fmt_item + + dataset = dataset.map(example_formatter) + return dataset \ No newline at end of file diff --git a/lib/data/datasets/hsmr_v1/utils.py b/lib/data/datasets/hsmr_v1/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..107286b414e67825ffea5c834ef83e3c04b68ac4 --- /dev/null +++ b/lib/data/datasets/hsmr_v1/utils.py @@ -0,0 +1,327 @@ +from lib.kits.basic import * + +import os +import cv2 +import braceexpand +from typing import List, Union + +from .crop import * + + +def expand_urls(urls: Union[str, List[str]]): + + def expand_url(s): + return os.path.expanduser(os.path.expandvars(s)) + + if isinstance(urls, str): + urls = [urls] + urls = [u for url in urls for u in braceexpand.braceexpand(expand_url(url))] + return urls + + +def get_augm_args(img_augm_cfg:Optional[DictConfig]): + ''' + Perform some random augmentation to the image and patch it. Here we perform generate augmentation arguments + according to the configuration and random seed. + + Briefly speaking, things done here are: size scale, color scale, rotate, flip, extreme crop, translate. + ''' + sample_args = { + 'bbox_scale' : 1.0, + 'color_scale' : [1.0, 1.0, 1.0], + 'rot_deg' : 0.0, + 'do_flip' : False, + 'do_extreme_crop' : False, + 'tx_ratio' : 0.0, + 'ty_ratio' : 0.0, + } + + if img_augm_cfg is not None: + sample_args['tx_ratio'] += np.clip(np.random.randn(), -1.0, 1.0) * img_augm_cfg.trans_factor + sample_args['ty_ratio'] += np.clip(np.random.randn(), -1.0, 1.0) * img_augm_cfg.trans_factor + sample_args['bbox_scale'] += np.clip(np.random.randn(), -1.0, 1.0) * img_augm_cfg.bbox_scale_factor + + if np.random.random() <= img_augm_cfg.rot_aug_rate: + sample_args['rot_deg'] += np.clip(np.random.randn(), -2.0, 2.0) * img_augm_cfg.rot_factor + if np.random.random() <= img_augm_cfg.flip_aug_rate: + sample_args['do_flip'] = True + if np.random.random() <= img_augm_cfg.extreme_crop_aug_rate: + sample_args['do_extreme_crop'] = True + + c_up = 1.0 + img_augm_cfg.half_color_scale + c_low = 1.0 - img_augm_cfg.half_color_scale + sample_args['color_scale'] = [ + np.random.uniform(c_low, c_up), + np.random.uniform(c_low, c_up), + np.random.uniform(c_low, c_up), + ] + return sample_args + + +def rotate_2d(pt_2d: np.ndarray, rot_rad: float) -> np.ndarray: + ''' + Rotate a 2D point on the x-y plane. + Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L90-L104 + + ### Args + - pt_2d: np.ndarray + - Input 2D point with shape (2,). + - rot_rad: float + - Rotation angle. + + ### Returns + - np.ndarray + - Rotated 2D point. + ''' + x = pt_2d[0] + y = pt_2d[1] + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + xx = x * cs - y * sn + yy = x * sn + y * cs + return np.array([xx, yy], dtype=np.float32) + + +def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple: + """ + Perform aggressive extreme cropping. + Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L978-L1025 + + ### Args + - center_x: float + - x coordinate of bounding box center. + - center_y: float + - y coordinate of bounding box center. + - width: float + - Bounding box width. + - height: float + - Bounding box height. + - keypoints_2d: np.ndarray + - Array of shape (N, 3) containing 2D keypoint locations. + - rescale: float + - Scale factor to rescale bounding boxes computed from the keypoints. 
+ + ### Returns + - center_x: float + - x coordinate of bounding box center. + - center_y: float + - y coordinate of bounding box center. + - bbox_size: float + - Bounding box size. + """ + p = torch.rand(1).item() + if full_body(keypoints_2d): + if p < 0.2: + center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d) + elif p < 0.3: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + elif p < 0.4: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + elif p < 0.5: + center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.6: + center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.7: + center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.8: + center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.9: + center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d) + elif upper_body(keypoints_2d): + if p < 0.2: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + elif p < 0.4: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + elif p < 0.6: + center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.8: + center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d) + return center_x, center_y, max(width, height) + + +def gen_trans_from_patch_cv( + c_x : float, + c_y : float, + src_width : float, + src_height : float, + dst_width : float, + dst_height : float, + scale : float, + rot : float +) -> np.ndarray: + ''' + Create transformation matrix for the bounding box crop. + Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L107-L154 + + ### Args + - c_x: float + - Bounding box center x coordinate in the original image. + - c_y: float + - Bounding box center y coordinate in the original image. + - src_width: float + - Bounding box width. + - src_height: float + - Bounding box height. + - dst_width: float + - Output box width. + - dst_height: float + - Output box height. + - scale: float + - Rescaling factor for the bounding box (augmentation). + - rot: float + - Random rotation applied to the box. + + ### Returns + - trans: np.ndarray + - Target geometric transformation. 
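+        - A 2x3 affine matrix that maps the (scaled, rotated) source box onto the destination patch; it is the matrix later consumed by `cv2.warpAffine`.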
+ ''' + # augment size with scale + src_w = src_width * scale + src_h = src_height * scale + src_center = np.zeros(2) + src_center[0] = c_x + src_center[1] = c_y + # augment rotation + rot_rad = np.pi * rot / 180 + src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad) + src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad) + + dst_w = dst_width + dst_h = dst_height + dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32) + dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32) + dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = src_center + src[1, :] = src_center + src_downdir + src[2, :] = src_center + src_rightdir + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = dst_center + dst[1, :] = dst_center + dst_downdir + dst[2, :] = dst_center + dst_rightdir + + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) # (2, 3) # type: ignore + + return trans + + +def generate_image_patch_cv2( + img : np.ndarray, + c_x : float, + c_y : float, + bb_width : float, + bb_height : float, + patch_width : float, + patch_height : float, + do_flip : bool, + scale : float, + rot : float, + border_mode = cv2.BORDER_CONSTANT, + border_value = 0, +) -> Tuple[np.ndarray, np.ndarray]: + ''' + Crop the input image and return the crop and the corresponding transformation matrix. + Copied from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L343-L386 + + ### Args + - img: np.ndarray, shape = (H, W, 3) + - c_x: float + - Bounding box center x coordinate in the original image. + - c_y: float + - Bounding box center y coordinate in the original image. + - bb_width: float + - Bounding box width. + - bb_height: float + - Bounding box height. + - patch_width: float + - Output box width. + - patch_height: float + - Output box height. + - do_flip: bool + - Whether to flip image or not. + - scale: float + - Rescaling factor for the bounding box (augmentation). + - rot: float + - Random rotation applied to the box. + ### Returns + - img_patch: np.ndarray + - Cropped image patch of shape (patch_height, patch_height, 3) + - trans: np.ndarray + - Transformation matrix. + ''' + img_height, img_width, img_channels = img.shape + if do_flip: + img = img[:, ::-1, :] + c_x = img_width - c_x - 1 + + trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot) # (2, 3) + + img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=border_mode, + borderValue=border_value, + ) # type: ignore + # Force borderValue=cv2.BORDER_CONSTANT for alpha channel + if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT): + img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + ) + + return img_patch, trans + + +def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None): + ''' + Increase the size of the bounding box to match the target shape. 
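+    The box is only ever grown, never shrunk: the side that is too short for the target aspect ratio is
+    enlarged. For example, with target_aspect_ratio = (192, 256), an input of (100., 100.) becomes
+    roughly (100., 133.3).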
+ Copied from https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L14-L33 + ''' + if target_aspect_ratio is None: + return input_shape + + try: + w , h = input_shape + except (ValueError, TypeError): + return input_shape + + w_t, h_t = target_aspect_ratio + if h / w < h_t / w_t: + h_new = max(w * h_t / w_t, h) + w_new = w + else: + h_new = h + w_new = max(h * w_t / h_t, w) + if h_new < h or w_new < w: + breakpoint() + return np.array([w_new, h_new]) + + +body_permutation = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21] +extra_permutation = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18] +FLIP_KP_PERMUTATION = body_permutation + [25 + i for i in extra_permutation] + +def flip_lr_keypoints(joints: np.ndarray, width: float) -> np.ndarray: + """ + Flip 2D or 3D keypoints. + Modified from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L448-L462 + + ### Args + - joints: np.ndarray + - Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence. + - flip_permutation: list + - Permutation to apply after flipping. + ### Returns + - np.ndarray + - Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively. + """ + joints = joints.copy() + # Flip horizontal + joints[:, 0] = width - joints[:, 0] - 1 + joints = joints[FLIP_KP_PERMUTATION] + + return joints \ No newline at end of file diff --git a/lib/data/datasets/hsmr_v1/wds_loader.py b/lib/data/datasets/hsmr_v1/wds_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..b318bb61326c244d7da20920c7162aaa57db6d17 --- /dev/null +++ b/lib/data/datasets/hsmr_v1/wds_loader.py @@ -0,0 +1,56 @@ +from lib.kits.basic import * + +import webdataset as wds + +from .utils import * +from .stream_pipelines import * + +# This line is to fix the problem of "OSError: image file is truncated" when loading images. +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +def load_tars_as_wds( + cfg : DictConfig, + urls : Union[str, List[str]], + resampled : bool = False, + epoch_size : int = None, + cache_dir : str = None, + train : bool = True, +): + urls = expand_urls(urls) # to list of URL strings + + dataset : wds.WebDataset = wds.WebDataset( + urls, + nodesplitter = wds.split_by_node, + shardshuffle = True, + resampled = resampled, + cache_dir = cache_dir, + ) + if train: + dataset = dataset.shuffle(100) + + # A lot of processes to initialize the dataset. Check the pipeline generator function for more details. + # The order of the pipeline is important, since some of the process are dependent on some previous ones. + dataset = apply_corrupt_filter(dataset) + dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png') + dataset = apply_multi_ppl_splitter(dataset) + dataset = apply_keys_adapter(dataset) #* This adapter is only in HSMR's design, not in the baseline. 
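+    # The suppressors below invalidate bad annotations inside a sample (NaN or high-error pseudo-GT
+    # parameters, low-confidence keypoints, extreme betas), while the filters further down drop whole
+    # samples; a threshold of 0.0 (or None) leaves the corresponding step disabled.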
+ dataset = apply_bad_pgt_params_nan_suppressor(dataset) + dataset = apply_bad_pgt_params_kp2d_err_suppressor(dataset, cfg.get('suppress_pgt_params_kp2d_err_thresh', 0.0)) + dataset = apply_bad_pgt_params_pve_max_suppressor(dataset, cfg.get('suppress_pgt_params_pve_max_thresh', 0.0)) + dataset = apply_bad_kp_suppressor(dataset, cfg.get('suppress_kp_conf_thresh', 0.0)) + dataset = apply_bad_betas_suppressor(dataset, cfg.get('suppress_betas_thresh', 0.0)) + # dataset = apply_bad_pose_suppressor(dataset, cfg.get('suppress_pose_thresh', 0.0)) # Not used in baseline, so not implemented. + dataset = apply_params_synchronizer(dataset, cfg.get('poses_betas_simultaneous', False)) + # dataset = apply_no_pose_filter(dataset, cfg.get('no_pose_filter', False)) # Not used in baseline, so not implemented. + dataset = apply_insuff_kp_filter(dataset, cfg.get('filter_insufficient_kp_cnt', 4), cfg.get('suppress_insufficient_kp_thresh', 0.0)) + dataset = apply_bbox_size_filter(dataset, cfg.get('filter_bbox_size_thresh', None)) + dataset = apply_reproj_err_filter(dataset, cfg.get('filter_reproj_err_thresh', 0.0)) + dataset = apply_invalid_betas_regularizer(dataset, cfg.get('regularize_invalid_betas', False)) + + # Final preprocess / format of the data. (Consider to extract the augmentation process.) + dataset = apply_example_formatter(dataset, cfg) + + if epoch_size is not None: + dataset = dataset.with_epoch(epoch_size) + return dataset \ No newline at end of file diff --git a/lib/data/datasets/skel_hmr2_fashion/__init__.py b/lib/data/datasets/skel_hmr2_fashion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3eefc709bdd1c4d39d7acbc0f8e79739d7b6ba16 --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/__init__.py @@ -0,0 +1,98 @@ +from typing import Dict, Optional + +import torch +import numpy as np +import pytorch_lightning as pl +from yacs.config import CfgNode + +import webdataset as wds +from .dataset import Dataset +from .image_dataset import ImageDataset +from .mocap_dataset import MoCapDataset + +def to_lower(x: Dict) -> Dict: + """ + Convert all dictionary keys to lowercase + Args: + x (dict): Input dictionary + Returns: + dict: Output dictionary with all keys converted to lowercase + """ + return {k.lower(): v for k, v in x.items()} + +def create_dataset(cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True, **kwargs) -> Dataset: + """ + Instantiate a dataset from a config file. + Args: + cfg (CfgNode): Model configuration file. + dataset_cfg (CfgNode): Dataset configuration info. + train (bool): Variable to select between train and val datasets. + """ + + dataset_type = Dataset.registry[dataset_cfg.TYPE] + return dataset_type(cfg, **to_lower(dataset_cfg), train=train, **kwargs) + +def create_webdataset(cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True) -> Dataset: + """ + Like `create_dataset` but load data from tars. 
+ """ + dataset_type = Dataset.registry[dataset_cfg.TYPE] + return dataset_type.load_tars_as_webdataset(cfg, **to_lower(dataset_cfg), train=train) + + +class MixedWebDataset(wds.WebDataset): + def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode, train: bool = True) -> None: + super(wds.WebDataset, self).__init__() + dataset_list = cfg.DATASETS.TRAIN if train else cfg.DATASETS.VAL + datasets = [create_webdataset(cfg, dataset_cfg[dataset], train=train) for dataset, v in dataset_list.items()] + weights = np.array([v.WEIGHT for dataset, v in dataset_list.items()]) + weights = weights / weights.sum() # normalize + self.append(wds.RandomMix(datasets, weights)) + +class HMR2DataModule(pl.LightningDataModule): + + def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode) -> None: + """ + Initialize LightningDataModule for HMR2 training + Args: + cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info. + dataset_cfg (CfgNode): Dataset configuration file + """ + super().__init__() + self.cfg = cfg + self.dataset_cfg = dataset_cfg + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.mocap_dataset = None + + def setup(self, stage: Optional[str] = None) -> None: + """ + Load datasets necessary for training + Args: + cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info. + """ + if self.train_dataset == None: + self.train_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=True).with_epoch(100_000).shuffle(4000) + # self.val_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=False).shuffle(4000) + self.mocap_dataset = MoCapDataset(**to_lower(self.dataset_cfg[self.cfg.DATASETS.MOCAP])) + + def train_dataloader(self) -> Dict: + """ + Setup training data loader. + Returns: + Dict: Dictionary containing image and mocap data dataloaders + """ + train_dataloader = torch.utils.data.DataLoader(self.train_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, num_workers=self.cfg.GENERAL.NUM_WORKERS, prefetch_factor=self.cfg.GENERAL.PREFETCH_FACTOR) + mocap_dataloader = torch.utils.data.DataLoader(self.mocap_dataset, self.cfg.TRAIN.NUM_TRAIN_SAMPLES * self.cfg.TRAIN.BATCH_SIZE, shuffle=True, drop_last=True, num_workers=1) + return {'img': train_dataloader, 'mocap': mocap_dataloader} + # return {'img': train_dataloader} + + # def val_dataloader(self) -> torch.utils.data.DataLoader: + # """ + # Setup val data loader. + # Returns: + # torch.utils.data.DataLoader: Validation dataloader + # """ + # val_dataloader = torch.utils.data.DataLoader(self.val_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, num_workers=self.cfg.GENERAL.NUM_WORKERS) + # return val_dataloader diff --git a/lib/data/datasets/skel_hmr2_fashion/dataset.py b/lib/data/datasets/skel_hmr2_fashion/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..22fc5bc5f4a7b75da672bd89859da14823e71aff --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/dataset.py @@ -0,0 +1,27 @@ +""" +This file contains the defition of the base Dataset class. 
+""" + +class DatasetRegistration(type): + """ + Metaclass for registering different datasets + """ + def __init__(cls, name, bases, nmspc): + super().__init__(name, bases, nmspc) + if not hasattr(cls, 'registry'): + cls.registry = dict() + cls.registry[name] = cls + + # Metamethods, called on class objects: + def __iter__(cls): + return iter(cls.registry) + + def __str__(cls): + return str(cls.registry) + +class Dataset(metaclass=DatasetRegistration): + """ + Base Dataset class + """ + def __init__(self, *args, **kwargs): + pass \ No newline at end of file diff --git a/lib/data/datasets/skel_hmr2_fashion/image_dataset.py b/lib/data/datasets/skel_hmr2_fashion/image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d483b32b880d2db9a8b99f0c7b101f7c3d9b58dc --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/image_dataset.py @@ -0,0 +1,484 @@ +import copy +import os +import numpy as np +import torch +from typing import Any, Dict, List, Union +from yacs.config import CfgNode +import braceexpand +import cv2 +from ipdb import set_trace + +from .dataset import Dataset +from .utils import get_example, expand_to_aspect_ratio + +def expand(s): + return os.path.expanduser(os.path.expandvars(s)) + +def expand_urls(urls: Union[str, List[str]]): + if isinstance(urls, str): + urls = [urls] + urls = [u for url in urls for u in braceexpand.braceexpand(expand(url))] + return urls + +AIC_TRAIN_CORRUPT_KEYS = { + '0a047f0124ae48f8eee15a9506ce1449ee1ba669', + '1a703aa174450c02fbc9cfbf578a5435ef403689', + '0394e6dc4df78042929b891dbc24f0fd7ffb6b6d', + '5c032b9626e410441544c7669123ecc4ae077058', + 'ca018a7b4c5f53494006ebeeff9b4c0917a55f07', + '4a77adb695bef75a5d34c04d589baf646fe2ba35', + 'a0689017b1065c664daef4ae2d14ea03d543217e', + '39596a45cbd21bed4a5f9c2342505532f8ec5cbb', + '3d33283b40610d87db660b62982f797d50a7366b', +} +CORRUPT_KEYS = { + *{f'aic-train/{k}' for k in AIC_TRAIN_CORRUPT_KEYS}, + *{f'aic-train-vitpose/{k}' for k in AIC_TRAIN_CORRUPT_KEYS}, +} + +body_permutation = [0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21] +extra_permutation = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18] +FLIP_KEYPOINT_PERMUTATION = body_permutation + [25 + i for i in extra_permutation] + +DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406]) +DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225]) +DEFAULT_IMG_SIZE = 256 + +class ImageDataset(Dataset): + + def __init__(self, + cfg: CfgNode, + dataset_file: str, + img_dir: str, + train: bool = True, + prune: Dict[str, Any] = {}, + **kwargs): + """ + Dataset class used for loading images and corresponding annotations. + Args: + cfg (CfgNode): Model config file. + dataset_file (str): Path to npz file containing dataset info. + img_dir (str): Path to image folder. + train (bool): Whether it is for training or not (enables data augmentation). + """ + super(ImageDataset, self).__init__() + self.train = train + self.cfg = cfg + + self.img_size = cfg['IMAGE_SIZE'] + self.mean = 255. * np.array(self.cfg['IMAGE_MEAN']) + self.std = 255. 
* np.array(self.cfg['IMAGE_STD']) + + self.img_dir = img_dir + self.data = np.load(dataset_file, allow_pickle=True) + + self.imgname = self.data['imgname'] + self.personid = np.zeros(len(self.imgname), dtype=np.int32) + self.extra_info = self.data.get('extra_info', [{} for _ in range(len(self.imgname))]) + + self.flip_keypoint_permutation = copy.copy(FLIP_KEYPOINT_PERMUTATION) + + num_pose = 3 * 24 + + # Bounding boxes are assumed to be in the center and scale format + self.center = self.data['center'] + self.scale = self.data['scale'].reshape(len(self.center), -1) / 200.0 + if self.scale.shape[1] == 1: + self.scale = np.tile(self.scale, (1, 2)) + assert self.scale.shape == (len(self.center), 2) + + # Get gt SMPLX parameters, if available + try: + self.body_pose = self.data['body_pose'].astype(np.float32) + self.has_body_pose = self.data['has_body_pose'].astype(np.float32) + except KeyError: + self.body_pose = np.zeros((len(self.imgname), num_pose), dtype=np.float32) + self.has_body_pose = np.zeros(len(self.imgname), dtype=np.float32) + try: + self.betas = self.data['betas'].astype(np.float32) + self.has_betas = self.data['has_betas'].astype(np.float32) + except KeyError: + self.betas = np.zeros((len(self.imgname), 10), dtype=np.float32) + self.has_betas = np.zeros(len(self.imgname), dtype=np.float32) + + # try: + # self.trans = self.data['trans'].astype(np.float32) + # except KeyError: + # self.trans = np.zeros((len(self.imgname), 3), dtype=np.float32) + + # Try to get 2d keypoints, if available + try: + body_keypoints_2d = self.data['body_keypoints_2d'] + except KeyError: + body_keypoints_2d = np.zeros((len(self.center), 25, 3)) + # Try to get extra 2d keypoints, if available + try: + extra_keypoints_2d = self.data['extra_keypoints_2d'] + except KeyError: + extra_keypoints_2d = np.zeros((len(self.center), 19, 3)) + + self.keypoints_2d = np.concatenate((body_keypoints_2d, extra_keypoints_2d), axis=1).astype(np.float32) + + # Try to get 3d keypoints, if available + try: + body_keypoints_3d = self.data['body_keypoints_3d'].astype(np.float32) + except KeyError: + body_keypoints_3d = np.zeros((len(self.center), 25, 4), dtype=np.float32) + # Try to get extra 3d keypoints, if available + try: + extra_keypoints_3d = self.data['extra_keypoints_3d'].astype(np.float32) + except KeyError: + extra_keypoints_3d = np.zeros((len(self.center), 19, 4), dtype=np.float32) + + body_keypoints_3d[:, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], -1] = 0 + + self.keypoints_3d = np.concatenate((body_keypoints_3d, extra_keypoints_3d), axis=1).astype(np.float32) + + def __len__(self) -> int: + return len(self.scale) + + def __getitem__(self, idx: int) -> Dict: + """ + Returns an example from the dataset. 
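+        The returned dict holds the cropped image patch, the 2D/3D keypoints aligned to that patch, the
+        (possibly zero-filled) SMPL parameters with their validity flags, and bookkeeping fields such as
+        the original box center/size, the image name and the person id.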
+ """ + try: + image_file_rel = self.imgname[idx].decode('utf-8') + except AttributeError: + image_file_rel = self.imgname[idx] + image_file = os.path.join(self.img_dir, image_file_rel) + keypoints_2d = self.keypoints_2d[idx].copy() + keypoints_3d = self.keypoints_3d[idx].copy() + + center = self.center[idx].copy() + center_x = center[0] + center_y = center[1] + scale = self.scale[idx] + BBOX_SHAPE = self.cfg['BBOX_SHAPE'] + bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max() + bbox_expand_factor = bbox_size / ((scale*200).max()) + body_pose = self.body_pose[idx].copy().astype(np.float32) + betas = self.betas[idx].copy().astype(np.float32) + # trans = self.trans[idx].copy().astype(np.float32) + + has_body_pose = self.has_body_pose[idx].copy() + has_betas = self.has_betas[idx].copy() + + smpl_params = {'global_orient': body_pose[:3], + 'body_pose': body_pose[3:], + 'betas': betas, + # 'trans': trans, + } + + has_smpl_params = {'global_orient': has_body_pose, + 'body_pose': has_body_pose, + 'betas': has_betas + } + + smpl_params_is_axis_angle = {'global_orient': True, + 'body_pose': True, + 'betas': False + } + + augm_config = self.cfg['augm'] + # Crop image and (possibly) perform data augmentation + img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, augm_record = get_example(image_file, + center_x, center_y, + bbox_size, bbox_size, + keypoints_2d, keypoints_3d, + smpl_params, has_smpl_params, + self.flip_keypoint_permutation, + self.img_size, self.img_size, + self.mean, self.std, self.train, augm_config) + + item = {} + # These are the keypoints in the original image coordinates (before cropping) + orig_keypoints_2d = self.keypoints_2d[idx].copy() + + item['img_patch'] = img_patch + item['keypoints_2d'] = keypoints_2d.astype(np.float32) + item['keypoints_3d'] = keypoints_3d.astype(np.float32) + item['orig_keypoints_2d'] = orig_keypoints_2d + item['box_center'] = self.center[idx].copy() + item['box_size'] = bbox_size + item['bbox_expand_factor'] = bbox_expand_factor + item['img_size'] = 1.0 * img_size[::-1].copy() + item['smpl_params'] = smpl_params + item['has_smpl_params'] = has_smpl_params + item['smpl_params_is_axis_angle'] = smpl_params_is_axis_angle + item['imgname'] = image_file + item['imgname_rel'] = image_file_rel + item['personid'] = int(self.personid[idx]) + item['extra_info'] = copy.deepcopy(self.extra_info[idx]) + item['idx'] = idx + item['_scale'] = scale + item['augm_record'] = augm_record # Augmentation record for recovery in self-improvement process. + return item + + + @staticmethod + def load_tars_as_webdataset(cfg: CfgNode, urls: Union[str, List[str]], train: bool, + resampled=False, + epoch_size=None, + cache_dir=None, + **kwargs) -> Dataset: + """ + Loads the dataset from a webdataset tar file. + """ + from .smplh_prob_filter import poses_check_probable, load_amass_hist_smooth + + IMG_SIZE = cfg['IMAGE_SIZE'] + BBOX_SHAPE = cfg['BBOX_SHAPE'] + MEAN = 255. * np.array(cfg['IMAGE_MEAN']) + STD = 255. 
* np.array(cfg['IMAGE_STD']) + + def split_data(source): + for item in source: + datas = item['data.pyd'] + for pid, data in enumerate(datas): + data['pid'] = pid + data['orig_has_body_pose'] = data['orig_pve_max'] < 0.1 + data['orig_has_betas'] = data['orig_pve_max'] < 0.1 + if 'flip_pve_mean' in data: + data['flip_has_body_pose'] = data['flip_pve_mean'] < 0.05 # TODO: fix this problems + data['flip_has_betas'] = data['flip_has_body_pose'] + else: + data['flip_has_body_pose'] = data['flip_pve_max'] < 0.65 + data['flip_has_betas'] = data['flip_has_body_pose'] + # data['body_pose'] = data['orig_poses'] + # data['betas'] = data['orig_betas'] + if 'detection.npz' in item: + det_idx = data['extra_info']['detection_npz_idx'] + mask = item['detection.npz']['masks'][det_idx] + else: + mask = np.ones_like(item['jpg'][:,:,0], dtype=bool) + yield { + '__key__': item['__key__'], + 'jpg': item['jpg'], + 'data.pyd': data, + 'mask': mask, + } + + def suppress_bad_kps(item, thresh=0.0): + if thresh > 0: + kp2d = item['data.pyd']['keypoints_2d'] + kp2d_conf = np.where(kp2d[:, 2] < thresh, 0.0, kp2d[:, 2]) + item['data.pyd']['keypoints_2d'] = np.concatenate([kp2d[:,:2], kp2d_conf[:,None]], axis=1) + return item + + def filter_numkp(item, numkp=4, thresh=0.0): + kp_conf = item['data.pyd']['keypoints_2d'][:, 2] + return (kp_conf > thresh).sum() > numkp + + def filter_reproj_error(item, thresh=10**4.5): + losses = item['data.pyd'].get('extra_info', {}).get('fitting_loss', np.array({})).item() + reproj_loss = losses.get('reprojection_loss', None) + return reproj_loss is None or reproj_loss < thresh + + def filter_bbox_size(item, thresh=1): + bbox_size_min = item['data.pyd']['scale'].min().item() * 200. + return bbox_size_min > thresh + + def filter_no_poses(item): + return (item['data.pyd']['has_body_pose'] > 0) + + def supress_bad_betas(item, thresh=3): + for side in ['orig', 'flip']: + has_betas = item['data.pyd'][f'{side}_has_betas'] + if thresh > 0 and has_betas: + betas_abs = np.abs(item['data.pyd'][f'{side}_betas']) + if (betas_abs > thresh).any(): + item['data.pyd'][f'{side}_has_betas'] = False + + return item + + amass_poses_hist100_smooth = load_amass_hist_smooth() + def supress_bad_poses(item): + for side in ['orig', 'flip']: + has_body_pose = item['data.pyd'][f'{side}_has_body_pose'] + if has_body_pose: + body_pose = item['data.pyd'][f'{side}_body_pose'] + pose_is_probable = poses_check_probable(torch.from_numpy(body_pose)[None, 3:], amass_poses_hist100_smooth).item() + if not pose_is_probable: + item['data.pyd'][f'{side}_has_body_pose'] = False + return item + + def poses_betas_simultaneous(item): + # We either have both body_pose and betas, or neither + for side in ['orig', 'flip']: + has_betas = item['data.pyd'][f'{side}_has_betas'] + has_body_pose = item['data.pyd'][f'{side}_has_body_pose'] + item['data.pyd'][f'{side}_has_betas'] = item['data.pyd'][f'{side}_has_body_pose'] = np.array(float((has_body_pose>0) and (has_betas>0))) + return item + + def set_betas_for_reg(item): + for side in ['orig', 'flip']: + # Always have betas set to true + has_betas = item['data.pyd'][f'{side}_has_betas'] + betas = item['data.pyd'][f'{side}_betas'] + + if not (has_betas>0): + item['data.pyd'][f'{side}_has_betas'] = np.array(float((True))) + item['data.pyd'][f'{side}_betas'] = betas * 0 + return item + + # Load the dataset + if epoch_size is not None: + resampled = True + corrupt_filter = lambda sample: (sample['__key__'] not in CORRUPT_KEYS) + import webdataset as wds + dataset = 
wds.WebDataset(expand_urls(urls), + nodesplitter=wds.split_by_node, + shardshuffle=True, + resampled=resampled, + cache_dir=cache_dir, + ).select(corrupt_filter) + if train: + dataset = dataset.shuffle(100) + dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png') + + # Process the dataset + dataset = dataset.compose(split_data) + + # Filter/clean the dataset + SUPPRESS_KP_CONF_THRESH = cfg.get('SUPPRESS_KP_CONF_THRESH', 0.0) + SUPPRESS_BETAS_THRESH = cfg.get('SUPPRESS_BETAS_THRESH', 0.0) + SUPPRESS_BAD_POSES = cfg.get('SUPPRESS_BAD_POSES', False) + POSES_BETAS_SIMULTANEOUS = cfg.get('POSES_BETAS_SIMULTANEOUS', False) + BETAS_REG = cfg.get('BETAS_REG', False) + FILTER_NO_POSES = cfg.get('FILTER_NO_POSES', False) + FILTER_NUM_KP = cfg.get('FILTER_NUM_KP', 4) + FILTER_NUM_KP_THRESH = cfg.get('FILTER_NUM_KP_THRESH', 0.0) + FILTER_REPROJ_THRESH = cfg.get('FILTER_REPROJ_THRESH', 0.0) + FILTER_MIN_BBOX_SIZE = cfg.get('FILTER_MIN_BBOX_SIZE', 0.0) + if SUPPRESS_KP_CONF_THRESH > 0: + dataset = dataset.map(lambda x: suppress_bad_kps(x, thresh=SUPPRESS_KP_CONF_THRESH)) + if SUPPRESS_BETAS_THRESH > 0: + dataset = dataset.map(lambda x: supress_bad_betas(x, thresh=SUPPRESS_BETAS_THRESH)) + if SUPPRESS_BAD_POSES: + dataset = dataset.map(lambda x: supress_bad_poses(x)) + if POSES_BETAS_SIMULTANEOUS: + dataset = dataset.map(lambda x: poses_betas_simultaneous(x)) + if FILTER_NO_POSES: + dataset = dataset.select(lambda x: filter_no_poses(x)) + if FILTER_NUM_KP > 0: + dataset = dataset.select(lambda x: filter_numkp(x, numkp=FILTER_NUM_KP, thresh=FILTER_NUM_KP_THRESH)) + if FILTER_REPROJ_THRESH > 0: + dataset = dataset.select(lambda x: filter_reproj_error(x, thresh=FILTER_REPROJ_THRESH)) + if FILTER_MIN_BBOX_SIZE > 0: + dataset = dataset.select(lambda x: filter_bbox_size(x, thresh=FILTER_MIN_BBOX_SIZE)) + if BETAS_REG: + dataset = dataset.map(lambda x: set_betas_for_reg(x)) # NOTE: Must be at the end + + use_skimage_antialias = cfg.get('USE_SKIMAGE_ANTIALIAS', False) + border_mode = { + 'constant': cv2.BORDER_CONSTANT, + 'replicate': cv2.BORDER_REPLICATE, + }[cfg.get('BORDER_MODE', 'constant')] + + # Process the dataset further + dataset = dataset.map(lambda x: ImageDataset.process_webdataset_tar_item(x, train, + augm_config=cfg['augm'], + MEAN=MEAN, STD=STD, IMG_SIZE=IMG_SIZE, + BBOX_SHAPE=BBOX_SHAPE, + use_skimage_antialias=use_skimage_antialias, + border_mode=border_mode, + )) + if epoch_size is not None: + dataset = dataset.with_epoch(epoch_size) + + return dataset + + @staticmethod + def process_webdataset_tar_item(item, train, + augm_config=None, + MEAN=DEFAULT_MEAN, + STD=DEFAULT_STD, + IMG_SIZE=DEFAULT_IMG_SIZE, + BBOX_SHAPE=None, + use_skimage_antialias=False, + border_mode=cv2.BORDER_CONSTANT, + ): + # Read data from item + key = item['__key__'] + image = item['jpg'] + data = item['data.pyd'] + mask = item['mask'] + pid = data['pid'] + + keypoints_2d = data['keypoints_2d'] + keypoints_3d = data['keypoints_3d'] + center = data['center'] + scale = data['scale'] + body_pose = (data['orig_poses'], data['flip_poses']) + betas = (data['orig_betas'], data['flip_betas']) + # trans = data['trans'] + has_body_pose = (data['orig_has_body_pose'], data['flip_has_body_pose']) + has_betas = (data['orig_has_betas'], data['flip_has_betas']) + # image_file = data['image_file'] + + # Process data + orig_keypoints_2d = keypoints_2d.copy() + center_x = center[0] + center_y = center[1] + bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max() + if bbox_size < 1: + breakpoint() + + + 
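+        # The SMPL parameters are carried as (orig, flip) pairs so that the augmentation can pick the side
+        # matching the horizontal flip instead of naively mirroring the original fit.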
smpl_params = {'global_orient': (body_pose[0][:3], body_pose[1][:3]), + 'body_pose': (body_pose[0][3:], body_pose[1][3:]), + 'betas': betas, + # 'trans': trans, + } + + has_smpl_params = {'global_orient': has_body_pose, + 'body_pose': has_body_pose, + 'betas': has_betas + } + + smpl_params_is_axis_angle = {'global_orient': True, + 'body_pose': True, + 'betas': False + } + + augm_config = copy.deepcopy(augm_config) + # Crop image and (possibly) perform data augmentation + img_rgba = np.concatenate([image, mask.astype(np.uint8)[:,:,None]*255], axis=2) + img_patch_rgba, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, trans, augm_record = get_example(img_rgba, + center_x, center_y, + bbox_size, bbox_size, + keypoints_2d, keypoints_3d, + smpl_params, has_smpl_params, + FLIP_KEYPOINT_PERMUTATION, + IMG_SIZE, IMG_SIZE, + MEAN, STD, train, augm_config, + is_bgr=False, return_trans=True, + use_skimage_antialias=use_skimage_antialias, + border_mode=border_mode, + ) + img_patch = img_patch_rgba[:3,:,:] + mask_patch = (img_patch_rgba[3,:,:] / 255.0).clip(0,1) + if (mask_patch < 0.5).all(): + mask_patch = np.ones_like(mask_patch) + + item = {} + + item['img'] = img_patch + item['mask'] = mask_patch + # item['img_og'] = image + # item['mask_og'] = mask + item['keypoints_2d'] = keypoints_2d.astype(np.float32) + item['keypoints_3d'] = keypoints_3d.astype(np.float32) + item['orig_keypoints_2d'] = orig_keypoints_2d + item['box_center'] = center.copy() + item['box_size'] = bbox_size + item['img_size'] = 1.0 * img_size[::-1].copy() + item['smpl_params'] = smpl_params + item['has_smpl_params'] = has_smpl_params + item['smpl_params_is_axis_angle'] = smpl_params_is_axis_angle + item['_scale'] = scale + item['_trans'] = trans + item['imgname'] = key + item['pid'] = pid + item['augm_record'] = augm_record # Augmentation record for recovery in self-improvement process. + return item diff --git a/lib/data/datasets/skel_hmr2_fashion/mocap_dataset.py b/lib/data/datasets/skel_hmr2_fashion/mocap_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1db9ed898ae4454d97ac0039ebd0a182f7ef5e --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/mocap_dataset.py @@ -0,0 +1,31 @@ +import numpy as np +from typing import Dict + +class MoCapDataset: + + def __init__(self, dataset_file: str, threshold: float = 0.01) -> None: + """ + Dataset class used for loading a dataset of unpaired SMPL parameter annotations + Args: + cfg (CfgNode): Model config file. + dataset_file (str): Path to npz file containing dataset info. + threshold (float): Threshold for PVE filtering. 
+ """ + data = np.load(dataset_file) + # pve = data['pve'] + pve = data['pve_max'] + # pve = data['pve_mean'] + mask = pve < threshold + self.pose = data['poses'].astype(np.float32)[mask, 3:] + self.betas = data['betas'].astype(np.float32)[mask] + self.length = len(self.pose) + print(f'Loaded {self.length} among {len(pve)} samples from {dataset_file} (using threshold = {threshold})') + + def __getitem__(self, idx: int) -> Dict: + pose = self.pose[idx].copy() + betas = self.betas[idx].copy() + item = {'body_pose': pose, 'betas': betas} + return item + + def __len__(self) -> int: + return self.length diff --git a/lib/data/datasets/skel_hmr2_fashion/preprocess/lspet_to_npz.py b/lib/data/datasets/skel_hmr2_fashion/preprocess/lspet_to_npz.py new file mode 100644 index 0000000000000000000000000000000000000000..228c3df502e6bd9643085e80b09013c683fedf2e --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/preprocess/lspet_to_npz.py @@ -0,0 +1,74 @@ +# Adapted from https://raw.githubusercontent.com/nkolot/SPIN/master/datasets/preprocess/hr_lspet.py +import os +import glob +import numpy as np +import scipy.io as sio +# from .read_openpose import read_openpose + +def hr_lspet_extract(dataset_path, out_path): + + # training mode + png_path = os.path.join(dataset_path, '*.png') + imgs = glob.glob(png_path) + imgs.sort() + + # structs we use + imgnames_, scales_, centers_, parts_, openposes_= [], [], [], [], [] + + # scale factor + scaleFactor = 1.2 + + # annotation files + annot_file = os.path.join(dataset_path, 'joints.mat') + joints = sio.loadmat(annot_file)['joints'] + + # main loop + for i, imgname in enumerate(imgs): + # image name + imgname = imgname.split('/')[-1] + # read keypoints + part14 = joints[:,:2,i] + # scale and center + bbox = [min(part14[:,0]), min(part14[:,1]), + max(part14[:,0]), max(part14[:,1])] + center = [(bbox[2]+bbox[0])/2, (bbox[3]+bbox[1])/2] + # scale = scaleFactor*max(bbox[2]-bbox[0], bbox[3]-bbox[1]) # Don't /200 + scale = scaleFactor*np.array([bbox[2]-bbox[0], bbox[3]-bbox[1]]) # Don't /200 + # update keypoints + part = np.zeros([24,3]) + part[:14] = np.hstack([part14, np.ones([14,1])]) + + # # read openpose detections + # json_file = os.path.join(openpose_path, 'hrlspet', + # imgname.replace('.png', '_keypoints.json')) + # openpose = read_openpose(json_file, part, 'hrlspet') + + # store the data + imgnames_.append(imgname) + centers_.append(center) + scales_.append(scale) + parts_.append(part) + # openposes_.append(openpose) + + # Populate extra_keypoints_2d: N,19,3 + # extra_keypoints_2d[:14] = parts[:14] + extra_keypoints_2d = np.zeros((len(parts_), 19, 3)) + extra_keypoints_2d[:,:14,:] = np.stack(parts_)[:,:14,:3] + + print(f'{extra_keypoints_2d.shape=}') + + # store the data struct + if not os.path.isdir(out_path): + os.makedirs(out_path) + out_file = os.path.join(out_path, 'hr-lspet_train.npz') + np.savez(out_file, imgname=imgnames_, + center=centers_, + scale=scales_, + part=parts_, + extra_keypoints_2d=extra_keypoints_2d, + # openpose=openposes_ + ) + + +if __name__ == '__main__': + hr_lspet_extract('/shared/pavlakos/datasets/hr-lspet/', 'hmr2_evaluation_data/') diff --git a/lib/data/datasets/skel_hmr2_fashion/preprocess/posetrack_to_npz.py b/lib/data/datasets/skel_hmr2_fashion/preprocess/posetrack_to_npz.py new file mode 100644 index 0000000000000000000000000000000000000000..7c999b5c80c249ec768aabe4a271cbfc86877817 --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/preprocess/posetrack_to_npz.py @@ -0,0 +1,92 @@ +# Adapted from 
https://raw.githubusercontent.com/nkolot/SPIN/master/datasets/preprocess/coco.py +import os +from os.path import join +import sys +import json +import numpy as np +from pathlib import Path +# from .read_openpose import read_openpose + +def coco_extract(dataset_path, out_path): + + # # convert joints to global order + # joints_idx = [19, 20, 21, 22, 23, 9, 8, 10, 7, 11, 6, 3, 2, 4, 1, 5, 0] + + # bbox expansion factor + scaleFactor = 1.2 + + # structs we need + imgnames_, scales_, centers_, parts_, openposes_ = [], [], [], [], [] + + # json annotation file + SPLIT='val' + json_paths = (Path(dataset_path)/'posetrack_data/annotations'/SPLIT).glob('*.json') + for json_path in json_paths: + json_data = json.load(open(json_path, 'r')) + + imgs = {} + for img in json_data['images']: + imgs[img['id']] = img + + for annot in json_data['annotations']: + # keypoints processing + keypoints = annot['keypoints'] + keypoints = np.reshape(keypoints, (17,3)) + keypoints[keypoints[:,2]>0,2] = 1 + # check if all major body joints are annotated + if sum(keypoints[5:,2]>0) < 12: + continue + # image name + image_id = annot['image_id'] + img_name = str(imgs[image_id]['file_name']) + # img_name_full = f'images/{SPLIT}/{json_path.stem}/{img_name}' + img_name_full = img_name + + # keypoints + part = np.zeros([17,3]) + # part[joints_idx] = keypoints + part = keypoints + + # scale and center + bbox = annot['bbox'] + center = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2] + # scale = scaleFactor*max(bbox[2], bbox[3]) # Don't do /200 + scale = scaleFactor*np.array([bbox[2], bbox[3]]) # Don't /200 + # # read openpose detections + # json_file = os.path.join(openpose_path, 'coco', + # img_name.replace('.jpg', '_keypoints.json')) + # openpose = read_openpose(json_file, part, 'coco') + + # store data + imgnames_.append(img_name_full) + centers_.append(center) + scales_.append(scale) + parts_.append(part) + # openposes_.append(openpose) + + + # NOTE: Posetrack val doesn't annotate ears (17,18) + # But Posetrack does annotate head, neck so that wil have to live in extra_kps. + posetrack_to_op_extra = [0, 37, 38, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11] # Will contain 15 keypoints. 
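+    # The PoseTrack keypoints are scattered into a 44-slot array (25 OpenPose body slots followed by 19
+    # extra slots); entries of `posetrack_to_op_extra` that are >= 25 therefore land in `extra_keypoints_2d`.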
+ all_keypoints_2d = np.zeros((len(parts_), 44, 3)) + all_keypoints_2d[:,posetrack_to_op_extra] = np.stack(parts_)[:,:len(posetrack_to_op_extra),:3] + body_keypoints_2d = all_keypoints_2d[:,:25,:] + extra_keypoints_2d = all_keypoints_2d[:,25:,:] + + print(f'{extra_keypoints_2d.shape=}') + + # store the data struct + if not os.path.isdir(out_path): + os.makedirs(out_path) + out_file = os.path.join(out_path, f'posetrack_2018_{SPLIT}.npz') + np.savez(out_file, imgname=imgnames_, + center=centers_, + scale=scales_, + part=parts_, + body_keypoints_2d=body_keypoints_2d, + extra_keypoints_2d=extra_keypoints_2d, + # openpose=openposes_ + ) + +if __name__ == '__main__': + coco_extract('/shared/pavlakos/datasets/posetrack/posetrack2018/', 'hmr2_evaluation_data/') diff --git a/lib/data/datasets/skel_hmr2_fashion/smplh_prob_filter.py b/lib/data/datasets/skel_hmr2_fashion/smplh_prob_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..b544b54999aa726378b0f8b058b1a90b166d7e30 --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/smplh_prob_filter.py @@ -0,0 +1,152 @@ +import os +import numpy as np +import torch +import torch.nn.functional as F +from lib.platform import PM + +JOINT_NAMES = [ + 'left_hip', + 'right_hip', + 'spine1', + 'left_knee', + 'right_knee', + 'spine2', + 'left_ankle', + 'right_ankle', + 'spine3', + 'left_foot', + 'right_foot', + 'neck', + 'left_collar', + 'right_collar', + 'head', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist' +] + +# Manually chosen probability density thresholds for each joint +# Probablities computed using SIGMA=2 gaussian blur on AMASS pose 3D histogram for range (-pi,pi) with 100x100x100 bins +JOINT_NAME_PROB_THRESHOLDS = { + 'left_hip': 5e-5, + 'right_hip': 5e-5, + 'spine1': 2e-3, + 'left_knee': 5e-6, + 'right_knee': 5e-6, + 'spine2': 0.01, + 'left_ankle': 5e-6, + 'right_ankle': 5e-6, + 'spine3': 0.025, + 'left_foot': 0, + 'right_foot': 0, + 'neck': 2e-4, + 'left_collar': 4.5e-4 , + 'right_collar': 4.5e-4, + 'head': 5e-4, + 'left_shoulder': 2e-4, + 'right_shoulder': 2e-4, + 'left_elbow': 4e-5, + 'right_elbow': 4e-5, + 'left_wrist': 1e-3, + 'right_wrist': 1e-3, +} + +JOINT_IDX_PROB_THRESHOLDS = torch.tensor([JOINT_NAME_PROB_THRESHOLDS[joint_name] for joint_name in JOINT_NAMES]) + +################################################################### +POSE_RANGE_MIN = -np.pi +POSE_RANGE_MAX = np.pi + +# Create 21x100x100x100 histogram of all 21 AMASS body pose joints using `create_pose_hist(amass_poses, nbins=100)` +AMASS_HIST100_PATH = PM.inputs / 'amass_poses_hist100_SMPL+H_G.npy' +if not os.path.exists(AMASS_HIST100_PATH): + assert False, f'AMASS_HIST100_PATH not found: {AMASS_HIST100_PATH}' + +def create_pose_hist(poses: np.ndarray, nbins: int = 100) -> np.ndarray: + N,K,C = poses.shape + assert C==3, poses.shape + poses_21x3 = normalize_axis_angle(torch.fromnumpy(poses).view(N*K,3)).numpy().reshape(N, K, 3) + assert (poses_21x3 > -np.pi).all() and (poses_21x3 < np.pi).all() + + Hs, Es = [], [] + for i in range(K): + H, edges = np.histogramdd(poses_21x3[:, i, :], bins=nbins, range=[(-np.pi, np.pi)]*3) + Hs.append(H) + Es.append(edges) + Hs = np.stack(Hs, axis=0) + return Hs + +def load_amass_hist_smooth(sigma=2) -> torch.Tensor: + amass_poses_hist100 = np.load(AMASS_HIST100_PATH) + amass_poses_hist100 = torch.from_numpy(amass_poses_hist100) + assert amass_poses_hist100.shape == (21,100,100,100) + + nbins = amass_poses_hist100.shape[1] + amass_poses_hist100 = 
amass_poses_hist100/amass_poses_hist100.sum() / (2*np.pi/nbins)**3 + + # Gaussian filter on amass_poses_hist100 + from scipy.ndimage import gaussian_filter + amass_poses_hist100_smooth = gaussian_filter(amass_poses_hist100.numpy(), sigma=sigma, mode='constant') + amass_poses_hist100_smooth = torch.from_numpy(amass_poses_hist100_smooth) + return amass_poses_hist100_smooth + +# Normalize axis angle representation s.t. angle is in [-pi, pi] +def normalize_axis_angle(poses: torch.Tensor) -> torch.Tensor: + # poses: N, 3 + # print(f'normalize_axis_angle ...') + assert poses.shape[1] == 3, poses.shape + angle = poses.norm(dim=1) + axis = F.normalize(poses, p=2, dim=1, eps=1e-8) + + angle_fixed = angle.clone() + axis_fixed = axis.clone() + + eps = 1e-6 + ii = 0 + while True: + # print(f'normalize_axis_angle iter {ii}') + ii += 1 + angle_too_big = (angle_fixed > np.pi + eps) + if not angle_too_big.any(): + break + + angle_fixed[angle_too_big] -= 2 * np.pi + angle_too_small = (angle_fixed < -eps) + axis_fixed[angle_too_small] *= -1 + angle_fixed[angle_too_small] *= -1 + + return axis_fixed * angle_fixed[:,None] + +def poses_to_joint_probs(poses: torch.Tensor, amass_poses_100_smooth: torch.Tensor) -> torch.Tensor: + # poses: Nx69 + # amass_poses_100_smooth: 21xBINSxBINSxBINS + # returns: poses_prob: Nx21 + N=poses.shape[0] + assert poses.shape == (N,69) + poses = poses[:,:63].reshape(N*21,3) + + nbins = amass_poses_100_smooth.shape[1] + assert amass_poses_100_smooth.shape == (21,nbins,nbins,nbins) + + poses_bin = (poses - POSE_RANGE_MIN) / (POSE_RANGE_MAX - POSE_RANGE_MIN) * (nbins - 1e-6) + poses_bin = poses_bin.long().clip(0, nbins-1) + joint_id = torch.arange(21, device=poses.device).view(1,21).expand(N,21).reshape(N*21) + poses_prob = amass_poses_100_smooth[joint_id, poses_bin[:,0], poses_bin[:,1], poses_bin[:,2]] + + poses_bad = ((poses < POSE_RANGE_MIN) | (poses >= POSE_RANGE_MAX)).any(dim=1) + poses_prob[poses_bad] = 0 + + return poses_prob.view(N,21) + +def poses_check_probable( + poses: torch.Tensor, + amass_poses_100_smooth: torch.Tensor, + prob_thresholds: torch.Tensor = JOINT_IDX_PROB_THRESHOLDS + ) -> torch.Tensor: + N,C=poses.shape + poses_norm = normalize_axis_angle(poses.reshape(N*(C//3),3)).reshape(N,C) + poses_prob = poses_to_joint_probs(poses_norm, amass_poses_100_smooth) + return (poses_prob > prob_thresholds).all(dim=1) diff --git a/lib/data/datasets/skel_hmr2_fashion/utils.py b/lib/data/datasets/skel_hmr2_fashion/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b22225485fc2e89c5f19a711e4817720969656ea --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/utils.py @@ -0,0 +1,1069 @@ +""" +Parts of the code are taken or adapted from +https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py +""" +import torch +import numpy as np +# from skimage.transform import rotate, resize +# from skimage.filters import gaussian +import random +import cv2 +from typing import List, Dict, Tuple, Union +from yacs.config import CfgNode + +def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None): + """Increase the size of the bounding box to match the target shape.""" + if target_aspect_ratio is None: + return input_shape + + try: + w , h = input_shape + except (ValueError, TypeError): + return input_shape + + w_t, h_t = target_aspect_ratio + if h / w < h_t / w_t: + h_new = max(w * h_t / w_t, h) + w_new = w + else: + h_new = h + w_new = max(h * w_t / h_t, w) + if h_new < h or w_new < w: + breakpoint() + return np.array([w_new, h_new]) + +def 
expand_bbox_to_aspect_ratio(bbox, target_aspect_ratio=None): + # bbox: np.ndarray: (N,4) detectron2 bbox format + # target_aspect_ratio: (width, height) + if target_aspect_ratio is None: + return bbox + + is_singleton = (bbox.ndim == 1) + if is_singleton: + bbox = bbox[None,:] + + if bbox.shape[0] > 0: + center = np.stack(((bbox[:,0] + bbox[:,2]) / 2, (bbox[:,1] + bbox[:,3]) / 2), axis=1) + scale_wh = np.stack((bbox[:,2] - bbox[:,0], bbox[:,3] - bbox[:,1]), axis=1) + scale_wh = np.stack([expand_to_aspect_ratio(wh, target_aspect_ratio) for wh in scale_wh], axis=0) + bbox = np.stack([ + center[:,0] - scale_wh[:,0] / 2, + center[:,1] - scale_wh[:,1] / 2, + center[:,0] + scale_wh[:,0] / 2, + center[:,1] + scale_wh[:,1] / 2, + ], axis=1) + + if is_singleton: + bbox = bbox[0,:] + + return bbox + +def do_augmentation(aug_config: CfgNode) -> Tuple: + """ + Compute random augmentation parameters. + Args: + aug_config (CfgNode): Config containing augmentation parameters. + Returns: + scale (float): Box rescaling factor. + rot (float): Random image rotation. + do_flip (bool): Whether to flip image or not. + do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT). + color_scale (List): Color rescaling factor + tx (float): Random translation along the x axis. + ty (float): Random translation along the y axis. + """ + + tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR + ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR + scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0 + rot = np.clip(np.random.randn(), -2.0, + 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0 + do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE + do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE + extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0) + # extreme_crop_lvl = 0 + c_up = 1.0 + aug_config.COLOR_SCALE + c_low = 1.0 - aug_config.COLOR_SCALE + color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)] + return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty + +def rotate_2d(pt_2d: np.ndarray, rot_rad: float) -> np.ndarray: + """ + Rotate a 2D point on the x-y plane. + Args: + pt_2d (np.ndarray): Input 2D point with shape (2,). + rot_rad (float): Rotation angle + Returns: + np.ndarray: Rotated 2D point. + """ + x = pt_2d[0] + y = pt_2d[1] + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + xx = x * cs - y * sn + yy = x * sn + y * cs + return np.array([xx, yy], dtype=np.float32) + + +def gen_trans_from_patch_cv(c_x: float, c_y: float, + src_width: float, src_height: float, + dst_width: float, dst_height: float, + scale: float, rot: float) -> np.ndarray: + """ + Create transformation matrix for the bounding box crop. + Args: + c_x (float): Bounding box center x coordinate in the original image. + c_y (float): Bounding box center y coordinate in the original image. + src_width (float): Bounding box width. + src_height (float): Bounding box height. + dst_width (float): Output box width. + dst_height (float): Output box height. + scale (float): Rescaling factor for the bounding box (augmentation). + rot (float): Random rotation applied to the box. + Returns: + trans (np.ndarray): Target geometric transformation. 
+ """ + # augment size with scale + src_w = src_width * scale + src_h = src_height * scale + src_center = np.zeros(2) + src_center[0] = c_x + src_center[1] = c_y + # augment rotation + rot_rad = np.pi * rot / 180 + src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad) + src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad) + + dst_w = dst_width + dst_h = dst_height + dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32) + dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32) + dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = src_center + src[1, :] = src_center + src_downdir + src[2, :] = src_center + src_rightdir + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = dst_center + dst[1, :] = dst_center + dst_downdir + dst[2, :] = dst_center + dst_rightdir + + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) # type: ignore + + return trans + + +def trans_point2d(pt_2d: np.ndarray, trans: np.ndarray): + """ + Transform a 2D point using translation matrix trans. + Args: + pt_2d (np.ndarray): Input 2D point with shape (2,). + trans (np.ndarray): Transformation matrix. + Returns: + np.ndarray: Transformed 2D point. + """ + src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T + dst_pt = np.dot(trans, src_pt) + return dst_pt[0:2] + +def get_transform(center, scale, res, rot=0): + """Generate transformation matrix.""" + """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py""" + h = 200 * scale + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3, 3)) + rot_rad = rot * np.pi / 180 + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0, :2] = [cs, -sn] + rot_mat[1, :2] = [sn, cs] + rot_mat[2, 2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0, 2] = -res[1] / 2 + t_mat[1, 2] = -res[0] / 2 + t_inv = t_mat.copy() + t_inv[:2, 2] *= -1 + t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) + return t + + +def transform(pt, center, scale, res, invert=0, rot=0, as_int=True): + """Transform pixel location to different reference.""" + """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py""" + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T + new_pt = np.dot(t, new_pt) + if as_int: + new_pt = new_pt.astype(int) + return new_pt[:2] + 1 + +def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0): + c_x = (ul[0] + br[0])/2 + c_y = (ul[1] + br[1])/2 + bb_width = patch_width = br[0] - ul[0] + bb_height = patch_height = br[1] - ul[1] + trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0) + img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=border_mode, + borderValue=border_value + ) # type: ignore + + # Force borderValue=cv2.BORDER_CONSTANT for alpha channel + if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT): + img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + 
borderMode=cv2.BORDER_CONSTANT, + ) + + return img_patch + +# def generate_image_patch_skimage(img: np.ndarray, c_x: float, c_y: float, +# bb_width: float, bb_height: float, +# patch_width: float, patch_height: float, +# do_flip: bool, scale: float, rot: float, +# border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.ndarray, np.ndarray]: +# """ +# Crop image according to the supplied bounding box. +# Args: +# img (np.ndarray): Input image of shape (H, W, 3) +# c_x (float): Bounding box center x coordinate in the original image. +# c_y (float): Bounding box center y coordinate in the original image. +# bb_width (float): Bounding box width. +# bb_height (float): Bounding box height. +# patch_width (float): Output box width. +# patch_height (float): Output box height. +# do_flip (bool): Whether to flip image or not. +# scale (float): Rescaling factor for the bounding box (augmentation). +# rot (float): Random rotation applied to the box. +# Returns: +# img_patch (np.ndarray): Cropped image patch of shape (patch_height, patch_height, 3) +# trans (np.ndarray): Transformation matrix. +# """ + +# img_height, img_width, img_channels = img.shape +# if do_flip: +# img = img[:, ::-1, :] +# c_x = img_width - c_x - 1 + +# trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot) + +# #img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR) + +# # skimage +# center = np.zeros(2) +# center[0] = c_x +# center[1] = c_y +# res = np.zeros(2) +# res[0] = patch_width +# res[1] = patch_height +# # assumes bb_width = bb_height +# # assumes patch_width = patch_height +# assert bb_width == bb_height, f'{bb_width=} != {bb_height=}' +# assert patch_width == patch_height, f'{patch_width=} != {patch_height=}' +# scale1 = scale*bb_width/200. 
+ +# # Upper left point +# ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1 +# # Bottom right point +# br = np.array(transform([res[0] + 1, +# res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1 + +# # Padding so that when rotated proper amount of context is included +# try: +# pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1 +# except: +# breakpoint() +# if not rot == 0: +# ul -= pad +# br += pad + + +# if False: +# # Old way of cropping image +# ul_int = ul.astype(int) +# br_int = br.astype(int) +# new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]] +# if len(img.shape) > 2: +# new_shape += [img.shape[2]] +# new_img = np.zeros(new_shape) + +# # Range to fill new array +# new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0] +# new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1] +# # Range to sample from original image +# old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0]) +# old_y = max(0, ul_int[1]), min(len(img), br_int[1]) +# new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], +# old_x[0]:old_x[1]] + +# # New way of cropping image +# new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32) + +# # print(f'{new_img.shape=}') +# # print(f'{new_img1.shape=}') +# # print(f'{np.allclose(new_img, new_img1)=}') +# # print(f'{img.dtype=}') + + +# if not rot == 0: +# # Remove padding + +# new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot) +# new_img = new_img[pad:-pad, pad:-pad] + +# if new_img.shape[0] < 1 or new_img.shape[1] < 1: +# print(f'{img.shape=}') +# print(f'{new_img.shape=}') +# print(f'{ul=}') +# print(f'{br=}') +# print(f'{pad=}') +# print(f'{rot=}') + +# breakpoint() + +# # resize image +# new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res) + +# new_img = np.clip(new_img, 0, 255).astype(np.uint8) + +# return new_img, trans + + +def generate_image_patch_cv2(img: np.ndarray, c_x: float, c_y: float, + bb_width: float, bb_height: float, + patch_width: float, patch_height: float, + do_flip: bool, scale: float, rot: float, + border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.ndarray, np.ndarray]: + """ + Crop the input image and return the crop and the corresponding transformation matrix. + Args: + img (np.ndarray): Input image of shape (H, W, 3) + c_x (float): Bounding box center x coordinate in the original image. + c_y (float): Bounding box center y coordinate in the original image. + bb_width (float): Bounding box width. + bb_height (float): Bounding box height. + patch_width (float): Output box width. + patch_height (float): Output box height. + do_flip (bool): Whether to flip image or not. + scale (float): Rescaling factor for the bounding box (augmentation). + rot (float): Random rotation applied to the box. + Returns: + img_patch (np.ndarray): Cropped image patch of shape (patch_height, patch_height, 3) + trans (np.ndarray): Transformation matrix. 
+ """ + + img_height, img_width, img_channels = img.shape + if do_flip: + img = img[:, ::-1, :] + c_x = img_width - c_x - 1 + + + trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot) + + img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=border_mode, + borderValue=border_value, + ) # type: ignore + # Force borderValue=cv2.BORDER_CONSTANT for alpha channel + if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT): + img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + ) + + return img_patch, trans + + +def convert_cvimg_to_tensor(cvimg: np.ndarray): + """ + Convert image from HWC to CHW format. + Args: + cvimg (np.ndarray): Image of shape (H, W, 3) as loaded by OpenCV. + Returns: + np.ndarray: Output image of shape (3, H, W). + """ + # from h,w,c(OpenCV) to c,h,w + img = cvimg.copy() + img = np.transpose(img, (2, 0, 1)) + # from int to float + img = img.astype(np.float32) + return img + +def fliplr_params(smpl_params: Dict, has_smpl_params: Dict) -> Tuple[Dict, Dict]: + """ + Flip SMPL parameters when flipping the image. + Args: + smpl_params (Dict): SMPL parameter annotations. + has_smpl_params (Dict): Whether SMPL annotations are valid. + Returns: + Dict, Dict: Flipped SMPL parameters and valid flags. + """ + global_orient = smpl_params['global_orient'].copy() + body_pose = smpl_params['body_pose'].copy() + betas = smpl_params['betas'].copy() + has_global_orient = has_smpl_params['global_orient'].copy() + has_body_pose = has_smpl_params['body_pose'].copy() + has_betas = has_smpl_params['betas'].copy() + + body_pose_permutation = [6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, + 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, + 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41, + 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55, + 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68] + body_pose_permutation = body_pose_permutation[:len(body_pose)] + body_pose_permutation = [i-3 for i in body_pose_permutation] + + body_pose = body_pose[body_pose_permutation] + + global_orient[1::3] *= -1 + global_orient[2::3] *= -1 + body_pose[1::3] *= -1 + body_pose[2::3] *= -1 + + smpl_params = {'global_orient': global_orient.astype(np.float32), + 'body_pose': body_pose.astype(np.float32), + 'betas': betas.astype(np.float32) + } + + has_smpl_params = {'global_orient': has_global_orient, + 'body_pose': has_body_pose, + 'betas': has_betas + } + + return smpl_params, has_smpl_params + + +def fliplr_keypoints(joints: np.ndarray, width: float, flip_permutation: List[int]) -> np.ndarray: + """ + Flip 2D or 3D keypoints. + Args: + joints (np.ndarray): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence. + flip_permutation (List): Permutation to apply after flipping. + Returns: + np.ndarray: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively. + """ + joints = joints.copy() + # Flip horizontal + joints[:, 0] = width - joints[:, 0] - 1 + joints = joints[flip_permutation, :] + + return joints + +def keypoint_3d_processing(keypoints_3d: np.ndarray, flip_permutation: List[int], rot: float, do_flip: float) -> np.ndarray: + """ + Process 3D keypoints (rotation/flipping). + Args: + keypoints_3d (np.ndarray): Input array of shape (N, 4) containing the 3D keypoints and confidence. 
+ flip_permutation (List): Permutation to apply after flipping. + rot (float): Random rotation applied to the keypoints. + do_flip (bool): Whether to flip keypoints or not. + Returns: + np.ndarray: Transformed 3D keypoints with shape (N, 4). + """ + if do_flip: + keypoints_3d = fliplr_keypoints(keypoints_3d, 1, flip_permutation) + # in-plane rotation + rot_mat = np.eye(3) + if not rot == 0: + rot_rad = -rot * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1]) + # flip the x coordinates + keypoints_3d = keypoints_3d.astype('float32') + return keypoints_3d + +def rot_q_orient(q: np.ndarray, rot: float) -> np.ndarray: + """ + Rotate SKEL orientation. + Args: + q (np.ndarray): SKEL style rotation representation (3,). + rot (np.ndarray): Rotation angle in degrees. + Returns: + np.ndarray: Rotated axis-angle vector. + """ + import torch + from lib.body_models.skel.osim_rot import CustomJoint + from lib.utils.geometry.rotation import ( + axis_angle_to_matrix, + matrix_to_euler_angles, + euler_angles_to_matrix, + ) + # q to mat + q = torch.from_numpy(q).unsqueeze(0) + q = q[:, [2, 1, 0]] + Rp = euler_angles_to_matrix(q, convention="YXZ") + # rotate around z + R = torch.Tensor([[np.deg2rad(-rot), 0, 0]]) + R = axis_angle_to_matrix(R) + R = torch.matmul(R, Rp) + # mat to q + q = matrix_to_euler_angles(R, convention="YXZ") + q = q[:, [2, 1, 0]] + q = q.numpy().squeeze() + + return q.astype(np.float32) + + +def rot_aa(aa: np.ndarray, rot: float) -> np.ndarray: + """ + Rotate axis angle parameters. + Args: + aa (np.ndarray): Axis-angle vector of shape (3,). + rot (np.ndarray): Rotation angle in degrees. + Returns: + np.ndarray: Rotated axis-angle vector. + """ + # pose parameters + R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], + [0, 0, 1]]) + # find the rotation of the body in camera frame + per_rdg, _ = cv2.Rodrigues(aa) + # apply the global rotation to the global orientation + resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg)) + aa = (resrot.T)[0] + return aa.astype(np.float32) + +def smpl_param_processing(smpl_params: Dict, has_smpl_params: Dict, rot: float, do_flip: bool, do_aug: bool) -> Tuple[Dict, Dict]: + """ + Apply random augmentations to the SMPL parameters. + Args: + smpl_params (Dict): SMPL parameter annotations. + has_smpl_params (Dict): Whether SMPL annotations are valid. + rot (float): Random rotation applied to the keypoints. + do_flip (bool): Whether to flip keypoints or not. + Returns: + Dict, Dict: Transformed SMPL parameters and valid flags. 
+ """ + if do_aug: + if do_flip: + smpl_params, has_smpl_params = {k:v[1] for k, v in smpl_params.items()}, {k:v[1] for k, v in has_smpl_params.items()} + else: + smpl_params, has_smpl_params = {k:v[0] for k, v in smpl_params.items()}, {k:v[0] for k, v in has_smpl_params.items()} + if do_aug: + smpl_params['global_orient'] = rot_q_orient(smpl_params['global_orient'], rot) + else: + smpl_params['global_orient'] = rot_aa(smpl_params['global_orient'], rot) + return smpl_params, has_smpl_params + + + +def get_example(img_path: Union[str, np.ndarray], center_x: float, center_y: float, + width: float, height: float, + keypoints_2d: np.ndarray, keypoints_3d: np.ndarray, + smpl_params: Dict, has_smpl_params: Dict, + flip_kp_permutation: List[int], + patch_width: int, patch_height: int, + mean: np.ndarray, std: np.ndarray, + do_augment: bool, augm_config: CfgNode, + is_bgr: bool = True, + use_skimage_antialias: bool = False, + border_mode: int = cv2.BORDER_CONSTANT, + return_trans: bool = False) -> Tuple: + """ + Get an example from the dataset and (possibly) apply random augmentations. + Args: + img_path (str): Image filename + center_x (float): Bounding box center x coordinate in the original image. + center_y (float): Bounding box center y coordinate in the original image. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array with shape (N,3) containing the 2D keypoints in the original image coordinates. + keypoints_3d (np.ndarray): Array with shape (N,4) containing the 3D keypoints. + smpl_params (Dict): SMPL parameter annotations. + has_smpl_params (Dict): Whether SMPL annotations are valid. + flip_kp_permutation (List): Permutation to apply to the keypoints after flipping. + patch_width (float): Output box width. + patch_height (float): Output box height. + mean (np.ndarray): Array of shape (3,) containing the mean for normalizing the input image. + std (np.ndarray): Array of shape (3,) containing the std for normalizing the input image. + do_augment (bool): Whether to apply data augmentation or not. + aug_config (CfgNode): Config containing augmentation parameters. + Returns: + return img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size + img_patch (np.ndarray): Cropped image patch of shape (3, patch_height, patch_height) + keypoints_2d (np.ndarray): Array with shape (N,3) containing the transformed 2D keypoints. + keypoints_3d (np.ndarray): Array with shape (N,4) containing the transformed 3D keypoints. + smpl_params (Dict): Transformed SMPL parameters. + has_smpl_params (Dict): Valid flag for transformed SMPL parameters. + img_size (np.ndarray): Image size of the original image. + """ + if isinstance(img_path, str): + # 1. load image + cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) + if not isinstance(cvimg, np.ndarray): + raise IOError("Fail to read %s" % img_path) + elif isinstance(img_path, np.ndarray): + cvimg = img_path + else: + raise TypeError('img_path must be either a string or a numpy array') + img_height, img_width, img_channels = cvimg.shape + + img_size = np.array([img_height, img_width]) + + # 2. get augmentation params + if do_augment: + scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config) + else: + scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0. 
+ + if width < 1 or height < 1: + breakpoint() + + if do_extreme_crop: + if extreme_crop_lvl == 0: + center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d) + elif extreme_crop_lvl == 1: + center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d) + + THRESH = 4 + if width1 < THRESH or height1 < THRESH: + # print(f'{do_extreme_crop=}') + # print(f'width: {width}, height: {height}') + # print(f'width1: {width1}, height1: {height1}') + # print(f'center_x: {center_x}, center_y: {center_y}') + # print(f'center_x1: {center_x1}, center_y1: {center_y1}') + # print(f'keypoints_2d: {keypoints_2d}') + # print(f'\n\n', flush=True) + # breakpoint() + pass + # print(f'skip ==> width1: {width1}, height1: {height1}, width: {width}, height: {height}') + else: + center_x, center_y, width, height = center_x1, center_y1, width1, height1 + + center_x += width * tx + center_y += height * ty + + # Process 3D keypoints + keypoints_3d = keypoint_3d_processing(keypoints_3d, flip_kp_permutation, rot, do_flip) + + # 3. generate image patch + # if use_skimage_antialias: + # # Blur image to avoid aliasing artifacts + # downsampling_factor = (patch_width / (width*scale)) + # if downsampling_factor > 1.1: + # cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True, truncate=3.0) + + img_patch_cv, trans = generate_image_patch_cv2(cvimg, + center_x, center_y, + width, height, + patch_width, patch_height, + do_flip, scale, rot, + border_mode=border_mode) + # img_patch_cv, trans = generate_image_patch_skimage(cvimg, + # center_x, center_y, + # width, height, + # patch_width, patch_height, + # do_flip, scale, rot, + # border_mode=border_mode) + + image = img_patch_cv.copy() + if is_bgr: + image = image[:, :, ::-1] + img_patch_cv = image.copy() + img_patch = convert_cvimg_to_tensor(image) + + + smpl_params, has_smpl_params = smpl_param_processing(smpl_params, has_smpl_params, rot, do_flip, do_augment) + + # apply normalization + for n_c in range(min(img_channels, 3)): + img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255) + if mean is not None and std is not None: + img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c] + if do_flip: + keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, flip_kp_permutation) + + + for n_jt in range(len(keypoints_2d)): + keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans) + keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5 + + + augm_record = { + 'do_flip' : do_flip, + 'rot' : rot, + } + + if not return_trans: + return img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, augm_record + else: + return img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, trans, augm_record + +def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple: + """ + Extreme cropping: Crop the box up to the hip locations. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. 
+ height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5] + keypoints_2d[lower_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + + +def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box up to the shoulder locations. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]] + keypoints_2d[lower_body_keypoints, :] = 0 + center, scale = get_bbox(keypoints_2d) + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.2 * scale[0] + height = 1.2 * scale[1] + return center_x, center_y, width, height + +def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the head. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]] + keypoints_2d[lower_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.3 * scale[0] + height = 1.3 * scale[1] + return center_x, center_y, width, height + +def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the torso. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. 
+ """ + keypoints_2d = keypoints_2d.copy() + nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]] + keypoints_2d[nontorso_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the right arm. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonrightarm_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the left arm. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonleftarm_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the legs. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. 
+ height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]] + keypoints_2d[nonlegs_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the right leg. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonrightleg_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray): + """ + Extreme cropping: Crop the box and keep on only the left leg. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonleftleg_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def full_body(keypoints_2d: np.ndarray) -> bool: + """ + Check if all main body joints are visible. + Args: + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + bool: True if all main body joints are visible. + """ + + body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14] + body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]] + return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints) + +def upper_body(keypoints_2d: np.ndarray): + """ + Check if all upper body joints are visible. 
+ Args: + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + bool: True if all main body joints are visible. + """ + lower_body_keypoints_openpose = [10, 11, 13, 14] + lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]] + upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18] + upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18] + return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\ + and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2) + +def get_bbox(keypoints_2d: np.ndarray, rescale: float = 1.2) -> Tuple: + """ + Get center and scale for bounding box from openpose detections. + Args: + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + rescale (float): Scale factor to rescale bounding boxes computed from the keypoints. + Returns: + center (np.ndarray): Array of shape (2,) containing the new bounding box center. + scale (float): New bounding box scale. + """ + valid = keypoints_2d[:,-1] > 0 + valid_keypoints = keypoints_2d[valid][:,:-1] + center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0)) + bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)) + # adjust bounding box tightness + scale = bbox_size + scale *= rescale + return center, scale + +def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple: + """ + Perform extreme cropping + Args: + center_x (float): x coordinate of bounding box center. + center_y (float): y coordinate of bounding box center. + width (float): bounding box width. + height (float): bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + rescale (float): Scale factor to rescale bounding boxes computed from the keypoints. + Returns: + center_x (float): x coordinate of bounding box center. + center_y (float): y coordinate of bounding box center. + width (float): bounding box width. + height (float): bounding box height. + """ + p = torch.rand(1).item() + if full_body(keypoints_2d): + if p < 0.7: + center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d) + elif p < 0.9: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + elif upper_body(keypoints_2d): + if p < 0.9: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + + return center_x, center_y, max(width, height), max(width, height) + +def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.ndarray) -> Tuple: + """ + Perform aggressive extreme cropping + Args: + center_x (float): x coordinate of bounding box center. + center_y (float): y coordinate of bounding box center. + width (float): bounding box width. + height (float): bounding box height. + keypoints_2d (np.ndarray): Array of shape (N, 3) containing 2D keypoint locations. + rescale (float): Scale factor to rescale bounding boxes computed from the keypoints. + Returns: + center_x (float): x coordinate of bounding box center. + center_y (float): y coordinate of bounding box center. + width (float): bounding box width. 
+ height (float): bounding box height. + """ + p = torch.rand(1).item() + if full_body(keypoints_2d): + if p < 0.2: + center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d) + elif p < 0.3: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + elif p < 0.4: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + elif p < 0.5: + center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.6: + center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.7: + center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.8: + center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.9: + center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d) + elif upper_body(keypoints_2d): + if p < 0.2: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + elif p < 0.4: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + elif p < 0.6: + center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.8: + center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d) + return center_x, center_y, max(width, height), max(width, height) diff --git a/lib/data/datasets/skel_hmr2_fashion/vitdet_dataset.py b/lib/data/datasets/skel_hmr2_fashion/vitdet_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8c9d9d4d01259884ae8438e859bb04a1d141a07d --- /dev/null +++ b/lib/data/datasets/skel_hmr2_fashion/vitdet_dataset.py @@ -0,0 +1,89 @@ +from typing import Dict + +import cv2 +import numpy as np +from skimage.filters import gaussian +from yacs.config import CfgNode +import torch + +from .utils import (convert_cvimg_to_tensor, + expand_to_aspect_ratio, + generate_image_patch_cv2) + +DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406]) +DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225]) + +class ViTDetDataset(torch.utils.data.Dataset): + + def __init__(self, + cfg: CfgNode, + img_cv2: np.ndarray, + boxes: np.ndarray, + train: bool = False, + **kwargs): + super().__init__() + self.cfg = cfg + self.img_cv2 = img_cv2 + # self.boxes = boxes + + assert train == False, "ViTDetDataset is only for inference" + self.train = train + self.img_size = cfg.MODEL.IMAGE_SIZE + self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN) + self.std = 255. 
* np.array(self.cfg.MODEL.IMAGE_STD) + + # Preprocess annotations + boxes = boxes.astype(np.float32) + self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0 + self.scale = (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0 + self.personid = np.arange(len(boxes), dtype=np.int32) + + def __len__(self) -> int: + return len(self.personid) + + def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: + + center = self.center[idx].copy() + center_x = center[0] + center_y = center[1] + + scale = self.scale[idx] + BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None) + bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max() + + patch_width = patch_height = self.img_size + + # 3. generate image patch + # if use_skimage_antialias: + cvimg = self.img_cv2.copy() + if True: + # Blur image to avoid aliasing artifacts + downsampling_factor = ((bbox_size*1.0) / patch_width) + print(f'{downsampling_factor=}') + downsampling_factor = downsampling_factor / 2.0 + if downsampling_factor > 1.1: + cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True) + + + img_patch_cv, trans = generate_image_patch_cv2(cvimg, + center_x, center_y, + bbox_size, bbox_size, + patch_width, patch_height, + False, 1.0, 0, + border_mode=cv2.BORDER_CONSTANT) + img_patch_cv = img_patch_cv[:, :, ::-1] + img_patch = convert_cvimg_to_tensor(img_patch_cv) + + # apply normalization + for n_c in range(min(self.img_cv2.shape[2], 3)): + img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c] + + item = { + 'img_full': cvimg, + 'img': img_patch, + 'personid': int(self.personid[idx]), + } + item['box_center'] = self.center[idx].copy() + item['box_size'] = bbox_size + item['img_size'] = 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]]) + return item diff --git a/lib/data/modules/hmr2_fashion/README.md b/lib/data/modules/hmr2_fashion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8c41d3b66efe7223d3f410db8e980e2033738ad7 --- /dev/null +++ b/lib/data/modules/hmr2_fashion/README.md @@ -0,0 +1,3 @@ +# HMR2 fashion dataset + +This folder contains customized implementation to adapt HMR2 style dataset. (e.g. the `ImageDataset` class). \ No newline at end of file diff --git a/lib/data/modules/hmr2_fashion/skel_wds.py b/lib/data/modules/hmr2_fashion/skel_wds.py new file mode 100644 index 0000000000000000000000000000000000000000..27155e2223323e7e8f877ae04865849e73726874 --- /dev/null +++ b/lib/data/modules/hmr2_fashion/skel_wds.py @@ -0,0 +1,106 @@ +from lib.kits.basic import * + +import webdataset as wds + + +from lib.data.datasets.skel_hmr2_fashion.image_dataset import ImageDataset + + +class MixedWebDataset(wds.WebDataset): + def __init__(self) -> None: + super(wds.WebDataset, self).__init__() + + +class DataModule(pl.LightningDataModule): + def __init__(self, name:str, cfg:DictConfig): + super().__init__() + self.name = name + self.cfg = cfg + self.cfg_eval = self.cfg.pop('eval', None) + self.cfg_train = self.cfg.pop('train', None) + + + def setup(self, stage=None): + if stage in ['test', None, '_debug_eval']: + self._setup_eval() + + if stage in ['fit', None, '_debug_train']: + self._setup_train() + + + def train_dataloader(self): + return torch.utils.data.DataLoader( + dataset = self.train_dataset, + **self.cfg_train.dataloader, + ) + + + def val_dataloader(self): + # Since we don't need validation here. 
+ return self.test_dataloader() + + + def test_dataloader(self): + # return torch.utils.data.DataLoader( + # dataset = self.eval_datasets['LSP-EXTENDED'], # TODO: Support multiple datasets through ConcatDataset (but to figure out how to mix with weights) + # **self.cfg_eval.dataloader, + # ) + return torch.utils.data.DataLoader( + dataset = self.eval_datasets, # TODO: Support multiple datasets through ConcatDataset (but to figure out how to mix with weights) + **self.cfg_eval.dataloader, + ) + + + # ========== Internal Modules to Setup Datasets ========== + + def _setup_train(self): + hack_cfg = { + 'IMAGE_SIZE': self.cfg.policy.img_patch_size, + 'IMAGE_MEAN': self.cfg.policy.img_mean, + 'IMAGE_STD' : self.cfg.policy.img_std, + 'BBOX_SHAPE': None, + 'augm': self.cfg.augm, + } + + self.train_datasets = [] # [(dataset:Dataset, weight:float), ...] + datasets, weights = [], [] + opt = self.cfg_train.get('shared_ds_opt', {}) + for dataset_cfg in self.cfg_train.datasets: + cur_cfg = {**hack_cfg, **opt} + dataset = ImageDataset.load_tars_as_webdataset( + cfg = cur_cfg, + urls = dataset_cfg.item.urls, + train = True, + epoch_size = dataset_cfg.item.epoch_size, + ) + weights.append(dataset_cfg.weight) + datasets.append(dataset) + weights = to_numpy(weights) + weights = weights / weights.sum() + self.train_dataset = MixedWebDataset() + self.train_dataset.append(wds.RandomMix(datasets, weights, longest=False)) + self.train_dataset = self.train_dataset.with_epoch(100_000).shuffle(4000) + + + def _setup_eval(self): + hack_cfg = { + 'IMAGE_SIZE' : self.cfg.policy.img_patch_size, + 'IMAGE_MEAN' : self.cfg.policy.img_mean, + 'IMAGE_STD' : self.cfg.policy.img_std, + 'BBOX_SHAPE' : [192, 256], + 'augm' : self.cfg.augm, + } + + self.eval_datasets = {} + opt = self.cfg_train.get('shared_ds_opt', {}) + for dataset_cfg in self.cfg_eval.datasets: + cur_cfg = {**hack_cfg, **opt} + dataset = ImageDataset( + cfg = hack_cfg, + dataset_file = dataset_cfg.item.dataset_file, + img_dir = dataset_cfg.item.img_root, + train = False, + ) + dataset._kp_list_ = dataset_cfg.item.kp_list + self.eval_datasets[dataset_cfg.name] = dataset + diff --git a/lib/data/modules/hsmr_v1/data_module.py b/lib/data/modules/hsmr_v1/data_module.py new file mode 100644 index 0000000000000000000000000000000000000000..6616683879052d01c77dc3d9b00424553c7cc0f8 --- /dev/null +++ b/lib/data/modules/hsmr_v1/data_module.py @@ -0,0 +1,117 @@ +from lib.kits.basic import * + +from lib.data.datasets.hsmr_v1.mocap_dataset import MoCapDataset +from lib.data.datasets.hsmr_v1.wds_loader import load_tars_as_wds +import webdataset as wds + +class MixedWebDataset(wds.WebDataset): + def __init__(self) -> None: + super(wds.WebDataset, self).__init__() + + +class DataModule(pl.LightningDataModule): + def __init__(self, name, cfg): + super().__init__() + self.name = name + self.cfg = cfg + self.cfg_eval = self.cfg.pop('eval', None) + self.cfg_train = self.cfg.pop('train', None) + self.cfg_mocap = self.cfg.pop('mocap', None) + + + def setup(self, stage=None): + if stage in ['test', None, '_debug_eval'] and self.cfg_eval is not None: + # get_logger().info('Evaluation dataset will be enabled.') + self._setup_eval() + + if stage in ['fit', None, '_debug_train'] and self.cfg_train is not None: + # get_logger().info('Training dataset will be enabled.') + self._setup_train() + + if stage in ['fit', None, '_debug_mocap'] and self.cfg_mocap is not None: + # get_logger().info('Mocap dataset will be enabled.') + self._setup_mocap() + + + def train_dataloader(self): + 
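+        # Return a dict of dataloaders: the image stream under 'img_ds' and,
+        # when a MoCap config is given, the MoCap stream under 'mocap_ds'.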
img_dataset = torch.utils.data.DataLoader( + dataset = self.train_dataset, + **self.cfg_train.dataloader, + ) + ret = {'img_ds' : img_dataset } + + if self.cfg_mocap is not None: + mocap_dataset = torch.utils.data.DataLoader( + dataset = self.mocap_dataset, + **self.cfg_mocap.dataloader, + ) + ret['mocap_ds'] = mocap_dataset + + return ret + + + + # ========== Internal Modules to Setup Datasets ========== + + + def _setup_train(self): + names, datasets, weights = [], [], [] + ld_cfg = self.cfg_train.cfg # cfg for initializing wds loading pipeline + + for ds_cfg in self.cfg_train.datasets: + dataset = load_tars_as_wds( + ld_cfg, + ds_cfg.item.urls, + ds_cfg.item.epoch_size + ) + + names.append(ds_cfg.name) + datasets.append(dataset) + weights.append(ds_cfg.weight) + + # get_logger().info(f"Dataset '{ds_cfg.name}' loaded.") + + # Normalize the weights and mix the datasets. + weights = to_numpy(weights) + weights = weights / weights.sum() + self.train_datasets = datasets + self.train_dataset = MixedWebDataset() + self.train_dataset.append(wds.RandomMix(datasets, weights)) + self.train_dataset = self.train_dataset.with_epoch(50_000).shuffle(1000, initial=1000) + + + def _setup_mocap(self): + self.mocap_dataset = MoCapDataset(**self.cfg_mocap.cfg) + + + def _setup_eval(self, selected_ds_names:Optional[List[str]]=None): + from lib.data.datasets.skel_hmr2_fashion.image_dataset import ImageDataset + hack_cfg = { + 'IMAGE_SIZE' : 256, + 'IMAGE_MEAN' : [0.485, 0.456, 0.406], + 'IMAGE_STD' : [0.229, 0.224, 0.225], + 'BBOX_SHAPE' : [192, 256], + 'augm' : self.cfg.image_augmentation, + 'SUPPRESS_KP_CONF_THRESH' : 0.3, + 'FILTER_NUM_KP' : 4, + 'FILTER_NUM_KP_THRESH' : 0.0, + 'FILTER_REPROJ_THRESH' : 31000, + 'SUPPRESS_BETAS_THRESH' : 3.0, + 'SUPPRESS_BAD_POSES' : False, + 'POSES_BETAS_SIMULTANEOUS': True, + 'FILTER_NO_POSES' : False, + 'BETAS_REG' : True, + } + + self.eval_datasets = {} + for dataset_cfg in self.cfg_eval.datasets: + if selected_ds_names is not None and dataset_cfg.name not in selected_ds_names: + continue + dataset = ImageDataset( + cfg = hack_cfg, + dataset_file = dataset_cfg.item.dataset_file, + img_dir = dataset_cfg.item.img_root, + train = False, + ) + dataset._kp_list_ = dataset_cfg.item.kp_list + self.eval_datasets[dataset_cfg.name] = dataset diff --git a/lib/evaluation/__init__.py b/lib/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb203b2c16637c3993b3e6cd2ad68164a2196370 --- /dev/null +++ b/lib/evaluation/__init__.py @@ -0,0 +1,93 @@ +from .metrics import * +from .evaluators import EvaluatorBase + +from lib.body_models.smpl_utils.reality import eval_rot_delta as smpl_eval_rot_delta + +class HSMR3DEvaluator(EvaluatorBase): + + def eval(self, **kwargs): + # Get the predictions and ground truths. + pd, gt = kwargs['pd'], kwargs['gt'] + v3d_pd, v3d_gt = pd['v3d_pose'], gt['v3d_pose'] + j3d_pd, j3d_gt = pd['j3d_pose'], gt['j3d_pose'] + + # Compute the metrics. + mpjpe = eval_MPxE(j3d_pd, j3d_gt) + pa_mpjpe = eval_PA_MPxE(j3d_pd, j3d_gt) + mpve = eval_MPxE(v3d_pd, v3d_gt) + pa_mpve = eval_PA_MPxE(v3d_pd, v3d_gt) + + # Append the results. + self.accumulator['MPJPE'].append(mpjpe.detach().cpu()) + self.accumulator['PA-MPJPE'].append(pa_mpjpe.detach().cpu()) + self.accumulator['MPVE'].append(mpve.detach().cpu()) + self.accumulator['PA-MPVE'].append(pa_mpve.detach().cpu()) + + + +class SMPLRealityEvaluator(EvaluatorBase): + + def eval(self, **kwargs): + # Get the predictions and ground truths. 
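+        # Only the predictions are needed for the reality metrics; no ground truth is used.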
+ pd = kwargs['pd'] + body_pose = pd['body_pose'] # (..., 23, 3) + + # Compute the metrics. + vio_d0 = smpl_eval_rot_delta(body_pose, tol_deg=0 ) # {k: (N, 3)} + vio_d5 = smpl_eval_rot_delta(body_pose, tol_deg=5 ) # {k: (N, 3)} + vio_d10 = smpl_eval_rot_delta(body_pose, tol_deg=10) # {k: (N, 3)} + vio_d20 = smpl_eval_rot_delta(body_pose, tol_deg=20) # {k: (N, 3)} + vio_d30 = smpl_eval_rot_delta(body_pose, tol_deg=30) # {k: (N, 3)} + + # Append the results. + parts = vio_d5.keys() + for part in parts: + self.accumulator[f'VD0_{part}' ].append(vio_d0[part].max(-1)[0].detach().cpu()) # (N,) + self.accumulator[f'VD5_{part}' ].append(vio_d5[part].max(-1)[0].detach().cpu()) # (N,) + self.accumulator[f'VD10_{part}'].append(vio_d10[part].max(-1)[0].detach().cpu()) # (N,) + self.accumulator[f'VD20_{part}'].append(vio_d20[part].max(-1)[0].detach().cpu()) # (N,) + self.accumulator[f'VD30_{part}'].append(vio_d30[part].max(-1)[0].detach().cpu()) # (N,) + + def get_results(self, chosen_metric=None): + ''' Get the current mean results. ''' + # Only chosen metrics will be compacted and returned. + compacted = self._compact_accumulator(chosen_metric) + ret = {} + for k, v in compacted.items(): + vio_max = v.max() + vio_mean = v.mean() + vio_median = v.median() + tot_cnt = len(v) + vio_cnt = (v > 0).float().sum() + vio_p = vio_cnt / tot_cnt + ret[f'{k}_max'] = vio_max.item() + ret[f'{k}_mean'] = vio_mean.item() + ret[f'{k}_median'] = vio_median.item() + ret[f'{k}_percentage'] = vio_p.item() + return ret + + +from lib.body_models.skel_utils.reality import eval_rot_delta as skel_eval_rot_delta + +class SKELRealityEvaluator(SMPLRealityEvaluator): + + def eval(self, **kwargs): + # Get the predictions and ground truths. + pd = kwargs['pd'] + poses = pd['poses'] # (..., 46) + + # Compute the metrics. + vio_d0 = skel_eval_rot_delta(poses, tol_deg=0 ) # {k: (N, 3)} + vio_d5 = skel_eval_rot_delta(poses, tol_deg=5 ) # {k: (N, 3)} + vio_d10 = skel_eval_rot_delta(poses, tol_deg=10) # {k: (N, 3)} + vio_d20 = skel_eval_rot_delta(poses, tol_deg=20) # {k: (N, 3)} + vio_d30 = skel_eval_rot_delta(poses, tol_deg=30) # {k: (N, 3)} + + # Append the results. + parts = vio_d5.keys() + for part in parts: + self.accumulator[f'VD0_{part}' ].append(vio_d0[part].max(-1)[0].detach().cpu()) # (N,) + self.accumulator[f'VD5_{part}' ].append(vio_d5[part].max(-1)[0].detach().cpu()) # (N,) + self.accumulator[f'VD10_{part}'].append(vio_d10[part].max(-1)[0].detach().cpu()) # (N,) + self.accumulator[f'VD20_{part}'].append(vio_d20[part].max(-1)[0].detach().cpu()) # (N,) + self.accumulator[f'VD30_{part}'].append(vio_d30[part].max(-1)[0].detach().cpu()) # (N,) diff --git a/lib/evaluation/evaluators.py b/lib/evaluation/evaluators.py new file mode 100644 index 0000000000000000000000000000000000000000..52477f39b26334931e1f3536785b6e3ee78e9c74 --- /dev/null +++ b/lib/evaluation/evaluators.py @@ -0,0 +1,34 @@ +import torch +from collections import defaultdict + +from .metrics import * + + +class EvaluatorBase(): + ''' To use this class, you should inherit it and implement the `eval` method. ''' + def __init__(self): + self.accumulator = defaultdict(list) + + def eval(self, **kwargs): + ''' Evaluate the metrics on the data. ''' + raise NotImplementedError + + def get_results(self, chosen_metric=None): + ''' Get the current mean results. ''' + # Only chosen metrics will be compacted and returned. 
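+        # Subclasses accumulate per-sample metric tensors in `eval()`; this method
+        # concatenates them via `_compact_accumulator` and returns the mean of each
+        # (chosen) metric as a plain float. Illustrative usage with the HSMR3DEvaluator
+        # defined above (`batches` is a hypothetical iterable of prediction/GT dicts):
+        #   >>> ev = HSMR3DEvaluator()
+        #   >>> for pd, gt in batches: ev.eval(pd=pd, gt=gt)
+        #   >>> ev.get_results()  # e.g. {'MPJPE': ..., 'PA-MPJPE': ..., ...}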
+ compacted = self._compact_accumulator(chosen_metric) + ret = {} + for k, v in compacted.items(): + ret[k] = v.mean(dim=0).item() + return ret + + def _compact_accumulator(self, chosen_metric=None): + ''' Compact the accumulator list and return the compacted results. ''' + ret = {} + for k, v in self.accumulator.items(): + # Only chosen metrics will be compacted. + if chosen_metric is None or k in chosen_metric: + ret[k] = torch.cat(v, dim=0) + self.accumulator[k] = [ret[k]] + return ret + diff --git a/lib/evaluation/hmr2_utils/__init__.py b/lib/evaluation/hmr2_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62a16b522f818b199b0146faff443752d9221093 --- /dev/null +++ b/lib/evaluation/hmr2_utils/__init__.py @@ -0,0 +1,314 @@ +""" +Code adapted from: https://github.com/akanazawa/hmr/blob/master/src/benchmark/eval_util.py +""" + +import torch +import numpy as np +from typing import Optional, Dict, List, Tuple + +def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor: + """ + Computes a similarity transform (sR, t) in a batched way that takes + a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3), + where R is a 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + Args: + S1 (torch.Tensor): First set of points of shape (B, N, 3). + S2 (torch.Tensor): Second set of points of shape (B, N, 3). + Returns: + (torch.Tensor): The first set of points after applying the similarity transformation. + """ + + batch_size = S1.shape[0] + S1 = S1.permute(0, 2, 1) + S2 = S2.permute(0, 2, 1) + # 1. Remove mean. + mu1 = S1.mean(dim=2, keepdim=True) + mu2 = S2.mean(dim=2, keepdim=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = (X1**2).sum(dim=(1,2)) + + # 3. The outer product of X1 and X2. + K = torch.matmul(X1, X2.permute(0, 2, 1)) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K. + U, s, V = torch.svd(K) + Vh = V.permute(0, 2, 1) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[1], device=U.device).unsqueeze(0).repeat(batch_size, 1, 1) + Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U, Vh))) + + # Construct R. + R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1)) + + # 5. Recover scale. + trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1) + scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1) + + # 6. Recover translation. + t = mu2 - scale*torch.matmul(R, mu1) + + # 7. Error: + S1_hat = scale*torch.matmul(R, S1) + t + + return S1_hat.permute(0, 2, 1) + +def reconstruction_error(S1, S2) -> np.array: + """ + Computes the mean Euclidean distance of 2 set of points S1, S2 after performing Procrustes alignment. + Args: + S1 (torch.Tensor): First set of points of shape (B, N, 3). + S2 (torch.Tensor): Second set of points of shape (B, N, 3). + Returns: + (np.array): Reconstruction error. + """ + S1_hat = compute_similarity_transform(S1, S2) + re = torch.sqrt( ((S1_hat - S2)** 2).sum(dim=-1)).mean(dim=-1) + return re + +def eval_pose(pred_joints, gt_joints) -> Tuple[np.array, np.array]: + """ + Compute joint errors in mm before and after Procrustes alignment. + Args: + pred_joints (torch.Tensor): Predicted 3D joints of shape (B, N, 3). + gt_joints (torch.Tensor): Ground truth 3D joints of shape (B, N, 3). + Returns: + Tuple[np.array, np.array]: Joint errors in mm before and after alignment. 
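+    Note:
+        Both errors are returned in millimeters, i.e. the meter-scale joint errors are
+        multiplied by 1000 before being returned.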
+ """ + # Absolute error (MPJPE) + mpjpe = torch.sqrt(((pred_joints - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() + + # Reconstruction_error + r_error = reconstruction_error(pred_joints, gt_joints).cpu().numpy() + return 1000 * mpjpe, 1000 * r_error + +class Evaluator: + + def __init__(self, + dataset_length: int, + keypoint_list: List, + pelvis_ind: int, + metrics: List = ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re'], + pck_thresholds: Optional[List] = None): + """ + Class used for evaluating trained models on different 3D pose datasets. + Args: + dataset_length (int): Total dataset length. + keypoint_list [List]: List of keypoints used for evaluation. + pelvis_ind (int): Index of pelvis keypoint; used for aligning the predictions and ground truth. + metrics [List]: List of evaluation metrics to record. + """ + self.dataset_length = dataset_length + self.keypoint_list = keypoint_list + self.pelvis_ind = pelvis_ind + self.metrics = metrics + for metric in self.metrics: + setattr(self, metric, np.zeros((dataset_length,))) + self.counter = 0 + if pck_thresholds is None: + self.pck_evaluator = None + else: + self.pck_evaluator = EvaluatorPCK(pck_thresholds) + + def log(self): + """ + Print current evaluation metrics + """ + if self.counter == 0: + print('Evaluation has not started') + return + print(f'{self.counter} / {self.dataset_length} samples') + if self.pck_evaluator is not None: + self.pck_evaluator.log() + for metric in self.metrics: + if metric in ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re']: + unit = 'mm' + else: + unit = '' + print(f'{metric}: {getattr(self, metric)[:self.counter].mean()} {unit}') + print('***') + + def get_metrics_dict(self) -> Dict: + """ + Returns: + Dict: Dictionary of evaluation metrics. + """ + d1 = {metric: getattr(self, metric)[:self.counter].mean().item() for metric in self.metrics} + if self.pck_evaluator is not None: + d2 = self.pck_evaluator.get_metrics_dict() + d1.update(d2) + return d1 + + def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None): + """ + Evaluate current batch. + Args: + output (Dict): Regression output. + batch (Dict): Dictionary containing images and their corresponding annotations. + opt_output (Dict): Optimization output. 
+ """ + if self.pck_evaluator is not None: + self.pck_evaluator(output, batch, opt_output) + + pred_keypoints_3d = output['pred_keypoints_3d'].detach() + pred_keypoints_3d = pred_keypoints_3d[:,None,:,:] + batch_size = pred_keypoints_3d.shape[0] + num_samples = pred_keypoints_3d.shape[1] + gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1) + + # Align predictions and ground truth such that the pelvis location is at the origin + pred_keypoints_3d -= pred_keypoints_3d[:, :, [self.pelvis_ind]] + gt_keypoints_3d -= gt_keypoints_3d[:, :, [self.pelvis_ind]] + + # Compute joint errors + mpjpe, re = eval_pose(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3)[:, self.keypoint_list], gt_keypoints_3d.reshape(batch_size * num_samples, -1 ,3)[:, self.keypoint_list]) + mpjpe = mpjpe.reshape(batch_size, num_samples) + re = re.reshape(batch_size, num_samples) + + # Compute 2d keypoint errors + pred_keypoints_2d = output['pred_keypoints_2d'].detach() + pred_keypoints_2d = pred_keypoints_2d[:,None,:,:] + gt_keypoints_2d = batch['keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1) + conf = gt_keypoints_2d[:, :, :, -1].clone() + kp_err = torch.nn.functional.mse_loss( + pred_keypoints_2d, + gt_keypoints_2d[:, :, :, :-1], + reduction='none' + ).sum(dim=3) + kp_l2_loss = (conf * kp_err).mean(dim=2) + kp_l2_loss = kp_l2_loss.detach().cpu().numpy() + + # Compute joint errors after optimization, if available. + if opt_output is not None: + opt_keypoints_3d = opt_output['model_joints'] + opt_keypoints_3d -= opt_keypoints_3d[:, [self.pelvis_ind]] + opt_mpjpe, opt_re = eval_pose(opt_keypoints_3d[:, self.keypoint_list], gt_keypoints_3d[:, 0, self.keypoint_list]) + + # The 0-th sample always corresponds to the mode + if hasattr(self, 'mode_mpjpe'): + mode_mpjpe = mpjpe[:, 0] + self.mode_mpjpe[self.counter:self.counter+batch_size] = mode_mpjpe + if hasattr(self, 'mode_re'): + mode_re = re[:, 0] + self.mode_re[self.counter:self.counter+batch_size] = mode_re + if hasattr(self, 'mode_kpl2'): + mode_kpl2 = kp_l2_loss[:, 0] + self.mode_kpl2[self.counter:self.counter+batch_size] = mode_kpl2 + if hasattr(self, 'min_mpjpe'): + min_mpjpe = mpjpe.min(axis=-1) + self.min_mpjpe[self.counter:self.counter+batch_size] = min_mpjpe + if hasattr(self, 'min_re'): + min_re = re.min(axis=-1) + self.min_re[self.counter:self.counter+batch_size] = min_re + if hasattr(self, 'min_kpl2'): + min_kpl2 = kp_l2_loss.min(axis=-1) + self.min_kpl2[self.counter:self.counter+batch_size] = min_kpl2 + if hasattr(self, 'opt_mpjpe'): + self.opt_mpjpe[self.counter:self.counter+batch_size] = opt_mpjpe + if hasattr(self, 'opt_re'): + self.opt_re[self.counter:self.counter+batch_size] = opt_re + + self.counter += batch_size + + if hasattr(self, 'mode_mpjpe') and hasattr(self, 'mode_re'): + return { + 'mode_mpjpe': mode_mpjpe, + 'mode_re': mode_re, + } + if hasattr(self, 'mode_kpl2'): + return { + 'mode_kpl2': mode_kpl2, + } + else: + return {} + + +class EvaluatorPCK: + + def __init__(self, thresholds: List = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5],): + """ + Class used for evaluating trained models on different 3D pose datasets. + Args: + thresholds [List]: List of PCK thresholds to evaluate. + metrics [List]: List of evaluation metrics to record. 
+ """ + self.thresholds = thresholds + self.pred_kp_2d = [] + self.gt_kp_2d = [] + self.gt_conf_2d = [] + self.counter = 0 + + def log(self): + """ + Print current evaluation metrics + """ + if self.counter == 0: + print('Evaluation has not started') + return + print(f'{self.counter} samples') + metrics_dict = self.get_metrics_dict() + for metric in metrics_dict: + print(f'{metric}: {metrics_dict[metric]}') + print('***') + + def get_metrics_dict(self) -> Dict: + """ + Returns: + Dict: Dictionary of evaluation metrics. + """ + pcks = self.compute_pcks() + metrics = {} + for thr, (acc,avg_acc,cnt) in zip(self.thresholds, pcks): + metrics.update({f'kp{i}_pck_{thr}': float(a) for i, a in enumerate(acc) if a>=0}) + metrics.update({f'kpAvg_pck_{thr}': float(avg_acc)}) + return metrics + + def compute_pcks(self): + pred_kp_2d = np.concatenate(self.pred_kp_2d, axis=0) + gt_kp_2d = np.concatenate(self.gt_kp_2d, axis=0) + gt_conf_2d = np.concatenate(self.gt_conf_2d, axis=0) + assert pred_kp_2d.shape == gt_kp_2d.shape + assert pred_kp_2d[..., 0].shape == gt_conf_2d.shape + assert pred_kp_2d.shape[1] == 1 # num_samples + + from .pck_accuracy import keypoint_pck_accuracy + pcks = [ + keypoint_pck_accuracy( + pred_kp_2d[:, 0, :, :], + gt_kp_2d[:, 0, :, :], + gt_conf_2d[:, 0, :]>0.5, + thr=thr, + normalize = np.ones((len(pred_kp_2d),2)) # Already in [-0.5,0.5] range. No need to normalize + ) + for thr in self.thresholds + ] + return pcks + + def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None): + """ + Evaluate current batch. + Args: + output (Dict): Regression output. + batch (Dict): Dictionary containing images and their corresponding annotations. + opt_output (Dict): Optimization output. + """ + pred_keypoints_2d = output['pred_keypoints_2d'].detach() + num_samples = 1 + batch_size = pred_keypoints_2d.shape[0] + + pred_keypoints_2d = pred_keypoints_2d[:,None,:,:] + gt_keypoints_2d = batch['keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1) + + gt_bbox_expand_factor = (batch['box_size']/(batch['_scale']*200).max(dim=-1).values) + gt_bbox_expand_factor = gt_bbox_expand_factor[:,None,None,None].repeat(1, num_samples, 1, 1) + gt_bbox_expand_factor = gt_bbox_expand_factor.detach().cpu().numpy() + + self.pred_kp_2d.append(pred_keypoints_2d[:, :, :, :2].detach().cpu().numpy() * gt_bbox_expand_factor) + self.gt_conf_2d.append(gt_keypoints_2d[:, :, :, -1].detach().cpu().numpy()) + self.gt_kp_2d.append(gt_keypoints_2d[:, :, :, :2].detach().cpu().numpy() * gt_bbox_expand_factor) + + self.counter += batch_size diff --git a/lib/evaluation/hmr2_utils/pck_accuracy.py b/lib/evaluation/hmr2_utils/pck_accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..5fea6c3859a930c805e0f5357b0fb71adac6764c --- /dev/null +++ b/lib/evaluation/hmr2_utils/pck_accuracy.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import numpy as np + +def _calc_distances(preds, targets, mask, normalize): + """Calculate the normalized distances between preds and target. + + Note: + batch_size: N + num_keypoints: K + dimension of keypoints: D (normally, D=2 or D=3) + + Args: + preds (np.ndarray[N, K, D]): Predicted keypoint location. + targets (np.ndarray[N, K, D]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. 
+ normalize (np.ndarray[N, D]): Typical value is heatmap_size + + Returns: + np.ndarray[K, N]: The normalized distances. \ + If target keypoints are missing, the distance is -1. + """ + N, K, _ = preds.shape + # set mask=0 when normalize==0 + _mask = mask.copy() + _mask[np.where((normalize == 0).sum(1))[0], :] = False + distances = np.full((N, K), -1, dtype=np.float32) + # handle invalid values + normalize[np.where(normalize <= 0)] = 1e6 + distances[_mask] = np.linalg.norm( + ((preds - targets) / normalize[:, None, :])[_mask], axis=-1) + return distances.T + + +def _distance_acc(distances, thr=0.5): + """Return the percentage below the distance threshold, while ignoring + distances values with -1. + + Note: + batch_size: N + Args: + distances (np.ndarray[N, ]): The normalized distances. + thr (float): Threshold of the distances. + + Returns: + float: Percentage of distances below the threshold. \ + If all target keypoints are missing, return -1. + """ + distance_valid = distances != -1 + num_distance_valid = distance_valid.sum() + if num_distance_valid > 0: + return (distances[distance_valid] < thr).sum() / num_distance_valid + return -1 + + +def keypoint_pck_accuracy(pred, gt, mask, thr, normalize): + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - batch_size: N + - num_keypoints: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - acc (np.ndarray[K]): Accuracy of each keypoint. + - avg_acc (float): Averaged accuracy across all keypoints. + - cnt (int): Number of valid keypoints. + """ + distances = _calc_distances(pred, gt, mask, normalize) + + acc = np.array([_distance_acc(d, thr) for d in distances]) + valid_acc = acc[acc >= 0] + cnt = len(valid_acc) + avg_acc = valid_acc.mean() if cnt > 0 else 0 + return acc, avg_acc, cnt diff --git a/lib/evaluation/metrics/__init__.py b/lib/evaluation/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5dbfa03320c4b86f2960a83ffb73917e4dd3d476 --- /dev/null +++ b/lib/evaluation/metrics/__init__.py @@ -0,0 +1,2 @@ +from .mpxe_like import * +from .utils import * \ No newline at end of file diff --git a/lib/evaluation/metrics/mpxe_like.py b/lib/evaluation/metrics/mpxe_like.py new file mode 100644 index 0000000000000000000000000000000000000000..6b656abd3b1834056eb91a241091375ce332b01a --- /dev/null +++ b/lib/evaluation/metrics/mpxe_like.py @@ -0,0 +1,147 @@ +from typing import Optional +import torch + +from .utils import * + + +''' +All MPxE-like metrics will be implements here. + +- Local Metrics: the inputs motion's translation should be removed (or may be automatically removed). + - MPxE: call `eval_MPxE()` + - PA-MPxE: cal `eval_PA_MPxE()` +- Global Metrics: the inputs motion's translation should be kept. 
+ - G-MPxE: call `eval_MPxE()` + - W2-MPxE: call `eval_Wk_MPxE()`, and set k = 2 + - WA-MPxE: call `eval_WA_MPxE()` +''' + + +def eval_MPxE( + pred : torch.Tensor, + gt : torch.Tensor, + scale : float = m2mm, +): + ''' + Calculate the Mean Per Error. might be joints position (MPJPE), or vertices (MPVE). + + The results will be the sequence of MPxE of each multi-dim batch. + + ### Args + - `pred`: torch.Tensor + - shape = (...B, N, 3), where B is the multi-dim batch size, N is points count in one batch + - the predicted joints/vertices position data + - `gt`: torch.Tensor + - shape = (...B, N, 3), where B is the multi-dim batch size, N is points count in one batch + - the ground truth joints/vertices position data + - `scale`: float, default = `m2mm` + + ### Returns + - torch.Tensor + - shape = (...B) + - shape = () + ''' + # Calculate the MPxE. + ret = L2_error(pred, gt).mean(dim=-1) * scale # (...B,) + return ret + + +def eval_PA_MPxE( + pred : torch.Tensor, + gt : torch.Tensor, + scale : float = m2mm, +): + ''' + Calculate the Procrustes-Aligned Mean Per Error. might be joints position (PA-MPJPE), or + vertices (PA-MPVE). Targets will be Procrustes-aligned and then calculate the per frame MPxE. + + The results will be the sequence of MPxE of each batch. + + ### Args + - `pred`: torch.Tensor + - shape = (...B, N, 3), where B is the multi-dim batch size, N is points count in one batch + - the predicted joints/vertices position data + - `gt`: torch.Tensor + - shape = (...B, N, 3), where B is the multi-dim batch size, N is points count in one batch + - the ground truth joints/vertices position data + - `scale`: float, default = `m2mm` + + ### Returns + - torch.Tensor + - shape = (...B) + - shape = () + ''' + # Perform Procrustes alignment. + pred_aligned = similarity_align_to(pred, gt) # (...B, N, 3) + # Calculate the PA-MPxE + return eval_MPxE(pred_aligned, gt, scale) # (...B,) + + +def eval_Wk_MPxE( + pred : torch.Tensor, + gt : torch.Tensor, + scale : float = m2mm, + k_f : int = 2, +): + ''' + Calculate the first k frames aligned (World aligned) Mean Per Error. might be joints + position (PA-MPJPE), or vertices (PA-MPVE). Targets will be aligned using the first k frames + and then calculate the per frame MPxE. + + The results will be the sequence of MPxE of each batch. + + ### Args + - `pred`: torch.Tensor + - shape = (..., L, N, 3), where L is the length of the sequence, N is points count in one batch + - the predicted joints/vertices position data + - `gt`: torch.Tensor + - shape = (..., L, N, 3), where L is the length of the sequence, N is points count in one batch + - the ground truth joints/vertices position data + - `scale`: float, default = `m2mm` + - `k_f`: int, default = 2 + - the number of frames to use for alignment + + ### Returns + - torch.Tensor + - shape = (..., L) + - shape = () + ''' + L = max(pred.shape[-3], gt.shape[-3]) + assert L >= 2, f'Length of the sequence should be at least 2, but got {L}.' + # Perform first two alignment. + pred_aligned = first_k_frames_align_to(pred, gt, k_f) # (..., L, N, 3) + # Calculate the PA-MPxE + return eval_MPxE(pred_aligned, gt, scale) # (..., L) + + +def eval_WA_MPxE( + pred : torch.Tensor, + gt : torch.Tensor, + scale : float = m2mm, +): + ''' + Calculate the all frames aligned (World All aligned) Mean Per Error. might be joints + position (PA-MPJPE), or vertices (PA-MPVE). Targets will be aligned using the first k frames + and then calculate the per frame MPxE. + + The results will be the sequence of MPxE of each batch. 
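+
+    (WA-MPxE is simply Wk-MPxE with k set to the full sequence length L, i.e. the
+    alignment here uses every frame rather than only the first k.)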
+ + ### Args + - `pred`: torch.Tensor + - shape = (..., L, N, 3), where L is the length of the sequence, N is points count in one batch + - the predicted joints/vertices position data + - `gt`: torch.Tensor + - shape = (..., L, N, 3), where L is the length of the sequence, N is points count in one batch + - the ground truth joints/vertices position data + - `scale`: float, default = `m2mm` + + ### Returns + - torch.Tensor + - shape = (..., L) + - shape = () + ''' + L_pred = pred.shape[-3] + L_gt = gt.shape[-3] + assert (L_pred == L_gt), f'Length of the sequence should be the same, but got {L_pred} and {L_gt}.' + # WA_MPxE is just Wk_MPxE when k = L. + return eval_Wk_MPxE(pred, gt, scale, L_gt) \ No newline at end of file diff --git a/lib/evaluation/metrics/utils.py b/lib/evaluation/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..787e5c70e924a486a089dacc520343ce54037d0c --- /dev/null +++ b/lib/evaluation/metrics/utils.py @@ -0,0 +1,195 @@ +import torch + + +m2mm = 1000.0 + + +def L2_error(x:torch.Tensor, y:torch.Tensor): + ''' + Calculate the L2 error across the last dim of the input tensors. + + ### Args + - `x`: torch.Tensor, shape (..., D) + - `y`: torch.Tensor, shape (..., D) + + ### Returns + - torch.Tensor, shape (...) + ''' + return (x - y).norm(dim=-1) + + +def similarity_align_to( + S1 : torch.Tensor, + S2 : torch.Tensor, +): + ''' + Computes a similarity transform (sR, t) that takes a set of 3D points S1 (3 x N) + closest to a set of 3D points S2, where R is an 3x3 rotation matrix, + t 3x1 translation, s scales. That is to solves the orthogonal Procrutes problem. + + The code was modified from [WHAM](https://github.com/yohanshin/WHAM/blob/d1ade93ae83a91855902fdb8246c129c4b3b8a40/lib/eval/eval_utils.py#L201-L252). + + ### Args + - `S1`: torch.Tensor, shape (...B, N, 3) + - `S2`: torch.Tensor, shape (...B, N, 3) + + ### Returns + - torch.Tensor, shape (...B, N, 3) + ''' + assert (S1.shape[-1] == 3 and S2.shape[-1] == 3), 'The last dimension of `S1` and `S2` must be 3.' + assert (S1.shape[:-2] == S2.shape[:-2]), 'The batch size of `S1` and `S2` must be the same.' + original_BN3 = S1.shape + N = original_BN3[-2] + S1 = S1.reshape(-1, N, 3) # (B', N, 3) <- (...B, N, 3) + S2 = S2.reshape(-1, N, 3) # (B', N, 3) <- (...B, N, 3) + B = S1.shape[0] + + S1 = S1.transpose(-1, -2) # (B', 3, N) <- (B', N, 3) + S2 = S2.transpose(-1, -2) # (B', 3, N) <- (B', N, 3) + _device = S2.device + S1 = S1.to(_device) + + # 1. Remove mean. + mu1 = S1.mean(axis=-1, keepdims=True) # (B', 3, 1) + mu2 = S2.mean(axis=-1, keepdims=True) # (B', 3, 1) + X1 = S1 - mu1 # (B', 3, N) + X2 = S2 - mu2 # (B', 3, N) + + # 2. Compute variance of X1 used for scales. + var1 = torch.einsum('...BDN->...B', X1**2) # (B',) + + # 3. The outer product of X1 and X2. + K = X1 @ X2.transpose(-1, -2) # (B', 3, 3) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K. + U, s, V = torch.svd(K) # (B', 3, 3), (B', 3), (B', 3, 3) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(3, device=_device)[None].repeat(B, 1, 1) # (B', 3, 3) + Z[:, -1, -1] *= (U @ V.transpose(-1, -2)).det().sign() + + # Construct R. + R = V @ (Z @ U.transpose(-1, -2)) # (B', 3, 3) + + # 5. Recover scales. + traces = [torch.trace(x)[None] for x in (R @ K)] + scales = torch.cat(traces) / var1 # (B',) + scales = scales[..., None, None] # (B', 1, 1) + + # 6. Recover translation. + t = mu2 - (scales * (R @ mu1)) # (B', 3, 1) + + # 7. 
Error: + S1_aligned = scales * (R @ S1) + t # (B', 3, N) + + S1_aligned = S1_aligned.transpose(-1, -2) # (B', N, 3) <- (B', 3, N) + S1_aligned = S1_aligned.reshape(original_BN3) # (...B, N, 3) + return S1_aligned # (...B, N, 3) + + +def align_pcl(Y: torch.Tensor, X: torch.Tensor, weight=None, fixed_scale=False): + ''' + Align similarity transform to align X with Y using umeyama method. X' = s * R * X + t is aligned with Y. + + The code was copied from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/geometry/pcl.py#L10-L60). + + ### Args + - `Y`: torch.Tensor, shape (*, N, 3) first trajectory + - `X`: torch.Tensor, shape (*, N, 3) second trajectory + - `weight`: torch.Tensor, shape (*, N, 1) optional weight of valid correspondences + - `fixed_scale`: bool, default = False + + ### Returns + - `s` (*, 1) + - `R` (*, 3, 3) + - `t` (*, 3) + ''' + *dims, N, _ = Y.shape + N = torch.ones(*dims, 1, 1) * N + + if weight is not None: + Y = Y * weight + X = X * weight + N = weight.sum(dim=-2, keepdim=True) # (*, 1, 1) + + # subtract mean + my = Y.sum(dim=-2) / N[..., 0] # (*, 3) + mx = X.sum(dim=-2) / N[..., 0] + y0 = Y - my[..., None, :] # (*, N, 3) + x0 = X - mx[..., None, :] + + if weight is not None: + y0 = y0 * weight + x0 = x0 * weight + + # correlation + C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3) + U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3) + + S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1) + neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0 + S[neg, 2, 2] = -1 + + R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3) + + D = torch.diag_embed(D) # (*, 3, 3) + if fixed_scale: + s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32) + else: + var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1) + s = ( + torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum( + dim=-1, keepdim=True + ) + / var[..., 0] + ) # (*, 1) + + t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3) + + return s, R, t + + +def first_k_frames_align_to( + S1 : torch.Tensor, + S2 : torch.Tensor, + k_f : int, +): + ''' + Compute the transformation between the first trajectory segment of S1 and S2, and use + the transformation to align S1 to S2. + + The code was modified from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/eval/tools.py#L68-L81). + + ### Args + - `S1`: torch.Tensor, shape (..., L, N, 3) + - `S2`: torch.Tensor, shape (..., L, N, 3) + - `k_f`: int + - The number of frames to use for alignment. + + ### Returns + - `S1_aligned`: torch.Tensor, shape (..., L, N, 3) + - The aligned S1. + ''' + assert (len(S1.shape) >= 3 and len(S2.shape) >= 3), 'The input tensors must have at least 3 dimensions.' + original_shape = S1.shape # (..., L, N, 3) + L, N, _ = original_shape[-3:] + S1 = S1.reshape(-1, L, N, 3) # (B, L, N, 3) + S2 = S2.reshape(-1, L, N, 3) # (B, L, N, 3) + B = S1.shape[0] + + # 1. Prepare the clouds to be aligned. + S1_first = S1[:, :k_f, :, :].reshape(B, -1, 3) # (B, 1, k_f * N, 3) + S2_first = S2[:, :k_f, :, :].reshape(B, -1, 3) # (B, 1, k_f * N, 3) + + # 2. Get the transformation to perform the alignment. + s_first, R_first, t_first = align_pcl( + X = S1_first, + Y = S2_first, + ) # (B, 1), (B, 3, 3), (B, 3) + s_first = s_first.reshape(B, 1, 1, 1) # (B, 1, 1, 1) + t_first = t_first.reshape(B, 1, 1, 3) # (B, 1, 1, 3) + + # 3. Perform the alignment on the whole sequence. 
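+    # `torch.einsum('Bij,BLNj->BLNi', R_first, S1)` applies each batch's 3x3 rotation to
+    # every point of every frame of S1 (B, L, N, 3); the scale and translation recovered
+    # from the first k frames are then broadcast over the L and N dimensions.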
+ S1_aligned = s_first * torch.einsum('Bij,BLNj->BLNi', R_first, S1) + t_first # (B, L, N, 3) + S1_aligned = S1_aligned.reshape(original_shape) # (..., L, N, 3) + return S1_aligned # (..., L, N, 3) \ No newline at end of file diff --git a/lib/info/log.py b/lib/info/log.py new file mode 100644 index 0000000000000000000000000000000000000000..74179d0f87d89cfe0831c622ff36698c17662968 --- /dev/null +++ b/lib/info/log.py @@ -0,0 +1,75 @@ +import inspect +import logging +import torch +from time import time +from colorlog import ColoredFormatter + + +def fold_path(fn:str): + ''' Fold a path like `from/to/file.py` to relative `f/t/file.py`. ''' + from lib.platform.proj_manager import ProjManager as PM + root_abs = str(PM.root.absolute()) + if fn.startswith(root_abs): + fn = fn[len(root_abs)+1:] + + return '/'.join([p[:1] for p in fn.split('/')[:-1]]) + '/' + fn.split('/')[-1] + + +def sync_time(): + torch.cuda.synchronize() + return time() + + +def get_logger(brief:bool=False, show_stack:bool=False): + # 1. Get the caller's file name and function name to identify the logging position. + caller_frame = inspect.currentframe().f_back + line_num = caller_frame.f_lineno + file_name = caller_frame.f_globals["__file__"] + file_name = fold_path(file_name) + func_name = caller_frame.f_code.co_name + frames_stack = inspect.stack() + + # 2. Add a trace method to the logger. + def trace_handler(self, message, *args, **kws): + if self.isEnabledFor(TRACE): + self._log(TRACE, message, args, **kws) + + TRACE = 15 # DEBUG is 10 and INFO is 20 + logging.addLevelName(TRACE, 'TRACE') + logging.Logger.trace = trace_handler + + # 3. Set up the logger. + logger = logging.getLogger() + logger.time = time + logger.sync_time = sync_time + + if logger.hasHandlers(): + logger.handlers.clear() + + ch = logging.StreamHandler() + + if brief: + prefix = f'[%(cyan)s%(asctime)s%(reset)s]' + else: + prefix = f'[%(cyan)s%(asctime)s%(reset)s @ %(cyan)s{func_name}%(reset)s @ %(cyan)s{file_name}%(reset)s:%(cyan)s{line_num}%(reset)s]' + if show_stack: + suffix = '\n STACK: ' + ' @ '.join([f'{fold_path(frame.filename)}:{frame.lineno}' for frame in frames_stack[1:]]) + else: + suffix = '' + formatstring = f'{prefix}[%(log_color)s%(levelname)s%(reset)s] %(message)s{suffix}' + datefmt = '%m/%d %H:%M:%S' + ch.setFormatter(ColoredFormatter(formatstring, datefmt=datefmt)) + logger.addHandler(ch) + + # Modify the logging level here. + logger.setLevel(TRACE) + ch.setLevel(TRACE) + + return logger + +if __name__ == '__main__': + get_logger().trace('Test TRACE') + get_logger().info('Test INFO') + get_logger().warning('Test WARN') + get_logger().error('Test ERROR') + get_logger().fatal('Test FATAL') \ No newline at end of file diff --git a/lib/info/look.py b/lib/info/look.py new file mode 100644 index 0000000000000000000000000000000000000000..37f124dcdb75245fabe4343e509001d3dd867d18 --- /dev/null +++ b/lib/info/look.py @@ -0,0 +1,108 @@ +# Provides methods to summarize the information of data, giving a brief overview in text. + +import torch +import numpy as np + +from typing import Optional + +from .log import get_logger + + +def look_tensor( + x : torch.Tensor, + prompt : Optional[str] = None, + silent : bool = False, +): + ''' + Summarize the information of a tensor, including its shape, value range (min, max, mean, std), and dtype. + Then return a string containing the information. + + ### Args + - x: torch.Tensor + - silent: bool, default `False` + - If not silent, the function will print the message itself. 
The information string will always be returned. + - prompt: Optional[str], default `None` + - If have prompt, it will be printed at the very beginning. + + ### Returns + - str + ''' + info_list = [] if prompt is None else [prompt] + # Convert to float to calculate the statistics. + x_num = x.float() + info_list.append(f'📐 [{x_num.min():06f} -> {x_num.max():06f}] ~ ({x_num.mean():06f}, {x_num.std():06f})') + info_list.append(f'📦 {tuple(x.shape)}') + info_list.append(f'🏷️ {x.dtype}') + info_list.append(f'🖥️ {x.device}') + # Generate the final information and print it if necessary. + ret = '\t'.join(info_list) + if not silent: + get_logger().info(ret) + return ret + + +def look_ndarray( + x : np.ndarray, + silent : bool = False, + prompt : Optional[str] = None, +): + ''' + Summarize the information of a numpy array, including its shape, value range (min, max, mean, std), and dtype. + Then return a string containing the information. + + ### Args + - x: np.ndarray + - silent: bool, default `False` + - If not silent, the function will print the message itself. The information string will always be returned. + - prompt: Optional[str], default `None` + - If have prompt, it will be printed at the very beginning. + + ### Returns + - str + ''' + info_list = [] if prompt is None else [prompt] + # Convert to float to calculate the statistics. + x_num = x.astype(np.float32) + info_list.append(f'📐 [ {x_num.min():06f} -> {x_num.max():06f} ] ~ ( {x_num.mean():06f}, {x_num.std():06f} )') + info_list.append(f'📦 {tuple(x.shape)}') + info_list.append(f'🏷️ {x.dtype}') + # Generate the final information and print it if necessary. + ret = '\t'.join(info_list) + if not silent: + get_logger().info(ret) + return ret + + +def look_dict( + d : dict, + silent : bool = False, +): + ''' + Summarize the information of a dictionary, including the keys and the information of the values. + Then return a string containing the information. + + ### Args + - d: dict + - silent: bool, default `False` + - If not silent, the function will print the message itself. The information string will always be returned. + + ### Returns + - str + ''' + info_list = ['{'] + + for k, v in d.items(): + if isinstance(v, torch.Tensor): + info_list.append(f'{k} : tensor: {look_tensor(v, silent=True)}') + elif isinstance(v, np.ndarray): + info_list.append(f'{k} : ndarray: {look_ndarray(v, silent=True)}') + elif isinstance(v, str): + info_list.append(f'{k} : {v[:32]}') + else: + info_list.append(f'{k} : {type(v)}') + + info_list.append('}') + ret = '\n'.join(info_list) + if not silent: + get_logger().info(ret) + return ret \ No newline at end of file diff --git a/lib/info/show.py b/lib/info/show.py new file mode 100644 index 0000000000000000000000000000000000000000..579bc9fa6e25a72b43b99f188c4f2ae91ed052e8 --- /dev/null +++ b/lib/info/show.py @@ -0,0 +1,93 @@ +# Provides methods to visualize the information of data, giving a brief overview in figure. + +import torch +import numpy as np +import matplotlib.pyplot as plt + +from typing import Optional, Union, List, Dict +from pathlib import Path + +from lib.utils.data import to_numpy + + +def show_distribution( + data : Dict, + fn : Union[str, Path], # File name of the saved figure. + bins : int = 100, # Number of bins in the histogram. + annotation : bool = False, + title : str = 'Data Distribution', + axis_names : List = ['Value', 'Frequency'], + bounds : Optional[List] = None, # Left and right bounds of the histogram. +): + ''' + Visualize the distribution of the data using histogram. 
+ The data should be a dictionary with keys as the labels and values as the data. + ''' + labels = list(data.keys()) + data = np.stack([ to_numpy(x) for x in data.values() ], axis=0) + assert data.ndim == 2, f"Data dimension should be 2, but got {data.ndim}." + assert bounds is None or len(bounds) == 2, f"Bounds should be a list of length 2, but got {bounds}." + # Preparation. + N, K = data.shape + data = data.transpose(1, 0) # (K, N) + # Plot. + plt.hist(data, bins=bins, alpha=0.7, label=labels) + if annotation: + for i in range(K): + for j in range(N): + plt.text(data[i, j], 0, f'{data[i, j]:.2f}', va='bottom', fontsize=6) + plt.title(title) + plt.xlabel(axis_names[0]) + plt.ylabel(axis_names[1]) + plt.legend() + if bounds: + plt.xlim(bounds) + # Save. + plt.savefig(fn) + plt.close() + + + +def show_history( + data : Dict, + fn : Union[str, Path], # file name of the saved figure + annotation : bool = False, + title : str = 'Data History', + axis_names : List = ['Time', 'Value'], + ex_starts : Dict[str, int] = {}, # starting points of the history if not starting from 0 +): + ''' + Visualize the value of changing across time. + The history should be a dictionary with keys as the metric names and values as the metric values. + ''' + # Make sure the fn's parent exists. + if isinstance(fn, str): + fn = Path(fn) + fn.parent.mkdir(parents=True, exist_ok=True) + + # Preparation. + history_name = list(data.keys()) + history_data = [ to_numpy(x) for x in data.values() ] + N = len(history_name) + Ls = [len(x) for x in history_data] + Ss = [ + ex_starts[history_name[i]] + if (history_name[i] in ex_starts.keys()) else 0 + for i in range(N) + ] + + # Plot. + for i in range(N): + plt.plot(range(Ss[i], Ss[i]+Ls[i]), history_data[i], label=history_name[i]) + if annotation: + for i in range(N): + for j in range(Ls[i]): + plt.text(Ss[i]+j, history_data[i][j], f'{history_data[i][j]:.2f}', fontsize=6) + + plt.title(title) + plt.xlabel(axis_names[0]) + plt.ylabel(axis_names[1]) + plt.legend() + # Save. + plt.savefig(fn) + plt.close() \ No newline at end of file diff --git a/lib/kits/README.md b/lib/kits/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8d1313f8d5fdba5e1e412258729567af5677f11f --- /dev/null +++ b/lib/kits/README.md @@ -0,0 +1,3 @@ +# `lib.kits.*` Using Instructions + +Submodules in this package are pre-prepared set of imports for common tasks. diff --git a/lib/kits/basic.py b/lib/kits/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3ad662a9640910b048d58edd2061082346a6ee --- /dev/null +++ b/lib/kits/basic.py @@ -0,0 +1,29 @@ +# Misc. +import os +import sys +from pathlib import Path +from typing import Union, Optional, List, Dict, Tuple, Any + + +# Machine Learning Related +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only + +# Framework Supports +import hydra +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +# Framework Supports - Customized Part +from lib.utils.data import * +from lib.platform import PM +from lib.info.log import get_logger + +try: + import oven +except ImportError: + get_logger(brief=True).warning('ExpOven is not installed. Will not be able to use the oven related functions. 
Check https://github.com/IsshikiHugh/ExpOven for more information.') \ No newline at end of file diff --git a/lib/kits/debug.py b/lib/kits/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6d20bebb171bfef3cdd3f9ccbfbe0015dda342 --- /dev/null +++ b/lib/kits/debug.py @@ -0,0 +1,27 @@ +from ipdb import set_trace +import os +import inspect as frame_inspect +from tqdm import tqdm +from rich import inspect, pretty, print + +from lib.platform import PM +from lib.platform.monitor import GPUMonitor +from lib.info.log import get_logger + +def who_imported_me(): + # Get the current stack frames. + stack = frame_inspect.stack() + + # Traverse the stack to find the first external caller. + for frame_info in stack: + # Filter out the internal importlib calls and the current file. + if 'importlib' not in frame_info.filename and frame_info.filename != __file__: + return os.path.abspath(frame_info.filename) + + # If no external file is found, it might be running as the main script. + return None + +get_logger(brief=True).warning(f'DEBUG kits are imported at {who_imported_me()}, remember to remove them.') + +from lib.info.look import * +from lib.info.show import * diff --git a/lib/kits/gradio/__init__.py b/lib/kits/gradio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a1135219bdd11a4041e63b5f3f63f63b9f37c0c --- /dev/null +++ b/lib/kits/gradio/__init__.py @@ -0,0 +1,2 @@ +from .backend import HSMRBackend +from .hsmr_service import HSMRService diff --git a/lib/kits/gradio/backend.py b/lib/kits/gradio/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5756d3bcab97cf4ce68258723e2014d6b834c3 --- /dev/null +++ b/lib/kits/gradio/backend.py @@ -0,0 +1,82 @@ +from lib.kits.hsmr_demo import * + +import gradio as gr + +from lib.modeling.pipelines import HSMRPipeline + +class HSMRBackend: + ''' + Backend class for maintaining HSMR model for inferencing. + Some gradio feature is included in this class. + ''' + def __init__(self, device:str='cpu') -> None: + self.max_img_w = 1920 + self.max_img_h = 1080 + self.device = device + self.pipeline = self._build_pipeline(self.device) + self.detector = build_detector( + batch_size = 1, + max_img_size = 512, + device = self.device, + ) + + + def _build_pipeline(self, device) -> HSMRPipeline: + return build_inference_pipeline( + model_root = DEFAULT_HSMR_ROOT, + device = device, + ) + + + def _load_limited_img(self, fn) -> List: + img, _ = load_img(fn) + if img.shape[0] > self.max_img_h: + img = flex_resize_img(img, (self.max_img_h, -1), kp_mod=4) + if img.shape[1] > self.max_img_w: + img = flex_resize_img(img, (-1, self.max_img_w), kp_mod=4) + return [img] + + + def __call__(self, input_path:Union[str, Path], args:Dict): + # 1. Initialization. + input_type = 'img' + if isinstance(input_path, str): input_path = Path(input_path) + outputs_root = input_path.parent / 'outputs' + outputs_root.mkdir(parents=True, exist_ok=True) + + # 2. Preprocess. + gr.Info(f'[1/3] Pre-processing...') + raw_imgs = self._load_limited_img(input_path) + detector_outputs = self.detector(raw_imgs) + patches, det_meta = imgs_det2patches(raw_imgs, *detector_outputs,args['max_instances']) # N * (256, 256, 3) + + # 3. Inference. 
+ gr.Info(f'[2/3] HSMR inferencing...') + pd_params, pd_cam_t = [], [] + for bw in bsb(total=len(patches), batch_size=args['rec_bs'], enable_tqdm=True): + patches_i = np.concatenate(patches[bw.sid:bw.eid], axis=0) # (N, 256, 256, 3) + patches_normalized_i = (patches_i - IMG_MEAN_255) / IMG_STD_255 # (N, 256, 256, 3) + patches_normalized_i = patches_normalized_i.transpose(0, 3, 1, 2) # (N, 3, 256, 256) + with torch.no_grad(): + outputs = self.pipeline(patches_normalized_i) + pd_params.append({k: v.detach().cpu().clone() for k, v in outputs['pd_params'].items()}) + pd_cam_t.append(outputs['pd_cam_t'].detach().cpu().clone()) + + pd_params = assemble_dict(pd_params, expand_dim=False) # [{k:[x]}, {k:[y]}] -> {k:[x, y]} + pd_cam_t = torch.cat(pd_cam_t, dim=0) + + # 4. Render. + gr.Info(f'[3/3] Rendering results...') + m_skin, m_skel = prepare_mesh(self.pipeline, pd_params) + results = visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel) + + outputs = {} + + if input_type == 'img': + for k, v in results.items(): + img_path = str(outputs_root / f'{k}.jpg') + outputs[k] = img_path + save_img(v, img_path) + outputs[k] = img_path + + return outputs \ No newline at end of file diff --git a/lib/kits/gradio/hsmr_service.py b/lib/kits/gradio/hsmr_service.py new file mode 100644 index 0000000000000000000000000000000000000000..bf269384bbbd442a942e9cb3168a79fb57bd2aef --- /dev/null +++ b/lib/kits/gradio/hsmr_service.py @@ -0,0 +1,101 @@ +import gradio as gr + +from typing import List, Union +from pathlib import Path + +from lib.platform import PM +from lib.info.log import get_logger +from .backend import HSMRBackend + +class HSMRService: + # ===== Initialization Part ===== + + def __init__(self, backend:HSMRBackend) -> None: + self.example_imgs_root = PM.inputs / 'example_imgs' + self.description = self._load_description() + self.backend:HSMRBackend = backend + + def _load_description(self) -> str: + description_fn = PM.inputs / 'description.md' + with open(description_fn, 'r') as f: + description = f.read() + return description + + + # ===== Funcitonal Private Part ===== + + def _inference_img( + self, + raw_img_path:Union[str, Path], + max_instances:int = 5, + ) -> List: + get_logger(brief=True).info(f'Image Path: {raw_img_path}') + get_logger(brief=True).info(f'max_instances: {max_instances}') + if raw_img_path is None: + gr.Warning('No image uploaded yet. Please upload an image first.') + return '' + if isinstance(raw_img_path, str): + raw_img_path = Path(raw_img_path) + + args = { + 'max_instances': max_instances, + 'rec_bs': 1, + } + output_paths = self.backend(raw_img_path, args) + + # bbx_img_path = output_paths['bbx_img_path'] + # mesh_img_path = output_paths['mesh_img_path'] + # skel_img_path = output_paths['skel_img_path'] + blend_img_path = output_paths['front_blend'] + + return blend_img_path + + + # ===== Service Part ===== + + def serve(self) -> None: + ''' Build UI and set up the service. ''' + with gr.Blocks() as demo: + # 1a. Setup UI. + gr.Markdown(self.description) + + with gr.Tab(label='HSMR-IMG-CPU'): + gr.Markdown('> **Pure CPU** demo for recoverying human mesh and skeleton from a single image. 
Each inference may take **about 3 minutes**.') + with gr.Row(equal_height=False): + with gr.Column(): + input_image = gr.Image( + label = 'Input', + type = 'filepath', + ) + with gr.Row(equal_height=True): + run_button_image = gr.Button( + value = 'Inference', + variant = 'primary', + ) + + with gr.Column(): + output_blend = gr.Image( + label = 'Output', + type = 'filepath', + interactive = False, + ) + + # 1b. Add examples sections after setting I/O policy. + example_fns = sorted(self.example_imgs_root.glob('*')) + gr.Examples( + examples = example_fns, + fn = self._inference_img, + inputs = input_image, + outputs = output_blend, + ) + + # 2b. Continue binding I/O logic. + run_button_image.click( + fn = self._inference_img, + inputs = input_image, + outputs = output_blend, + ) + + + # 3. Launch the service. + demo.queue(max_size=20).launch(server_name='0.0.0.0', server_port=7860) \ No newline at end of file diff --git a/lib/kits/hsmr_demo.py b/lib/kits/hsmr_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3930db891b43c86e9ed390e3480df8c807ee88 --- /dev/null +++ b/lib/kits/hsmr_demo.py @@ -0,0 +1,217 @@ +from lib.kits.basic import * + +import argparse +from tqdm import tqdm + +from lib.version import DEFAULT_HSMR_ROOT +from lib.utils.vis.py_renderer import render_meshes_overlay_img, render_mesh_overlay_img +from lib.utils.bbox import crop_with_lurb, fit_bbox_to_aspect_ratio, lurb_to_cs, cs_to_lurb +from lib.utils.media import * +from lib.platform.monitor import TimeMonitor +from lib.platform.sliding_batches import bsb +from lib.modeling.pipelines.hsmr import build_inference_pipeline +from lib.modeling.pipelines.vitdet import build_detector + +IMG_MEAN_255 = np.array([0.485, 0.456, 0.406], dtype=np.float32) * 255. +IMG_STD_255 = np.array([0.229, 0.224, 0.225], dtype=np.float32) * 255. + + +# ================== Command Line Supports ================== + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--input_type', type=str, default='auto', help='Specify the input type. auto: file~video, folder~imgs', choices=['auto', 'video', 'imgs']) + parser.add_argument('-i', '--input_path', type=str, required=True, help='The input images root or video file path.') + parser.add_argument('-o', '--output_path', type=str, default=PM.outputs/'demos', help='The output root.') + parser.add_argument('-m', '--model_root', type=str, default=DEFAULT_HSMR_ROOT, help='The model root which contains `.hydra/config.yaml`.') + parser.add_argument('-d', '--device', type=str, default='cuda:0', help='The device.') + parser.add_argument('--det_bs', type=int, default=10, help='The max batch size for detector.') + parser.add_argument('--det_mis', type=int, default=512, help='The max image size for detector.') + parser.add_argument('--rec_bs', type=int, default=300, help='The batch size for recovery.') + parser.add_argument('--max_instances', type=int, default=5, help='Max instances activated in one image.') + parser.add_argument('--ignore_skel', action='store_true', help='Do not render skeleton to boost the rendering.') + args = parser.parse_args() + return args + + +# ================== Data Process Tools ================== + +def load_inputs(args, MAX_IMG_W=1920, MAX_IMG_H=1080): + # 1. Inference inputs type. 
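+    # With the default `--input_type auto`, a file path is treated as a video and a
+    # directory as a folder of images. Illustrative invocation (the entry-point name
+    # `demo.py` is a placeholder for whichever script consumes `parse_args()` above):
+    #   python demo.py -i path/to/frames_dir --max_instances 5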
+ inputs_path = Path(args.input_path) + if args.input_type != 'auto': inputs_type = args.input_type + else: inputs_type = 'video' if Path(args.input_path).is_file() else 'imgs' + get_logger(brief=True).info(f'🚚 Loading inputs from: {inputs_path}, regarded as <{inputs_type}>.') + + # 2. Load inputs. + inputs_meta = {'type': inputs_type} + if inputs_type == 'video': + inputs_meta['seq_name'] = inputs_path.stem + frames, _ = load_video(inputs_path) + if frames.shape[1] > MAX_IMG_H: + frames = flex_resize_video(frames, (MAX_IMG_H, -1), kp_mod=4) + if frames.shape[2] > MAX_IMG_W: + frames = flex_resize_video(frames, (-1, MAX_IMG_W), kp_mod=4) + raw_imgs = [frame for frame in frames] + elif inputs_type == 'imgs': + img_fns = list(inputs_path.glob('*.*')) + img_fns = [fn for fn in img_fns if fn.suffix.lower() in ['.jpg', '.jpeg', '.png', '.webp']] + inputs_meta['seq_name'] = f'{inputs_path.stem}-img_cnt={len(img_fns)}' + raw_imgs = [] + for fn in img_fns: + img, _ = load_img(fn) + if img.shape[0] > MAX_IMG_H: + img = flex_resize_img(img, (MAX_IMG_H, -1), kp_mod=4) + if img.shape[1] > MAX_IMG_W: + img = flex_resize_img(img, (-1, MAX_IMG_W), kp_mod=4) + raw_imgs.append(img) + inputs_meta['img_fns'] = img_fns + else: + raise ValueError(f'Unsupported inputs type: {inputs_type}.') + get_logger(brief=True).info(f'📦 Totally {len(raw_imgs)} images are loaded.') + + return raw_imgs, inputs_meta + + +def imgs_det2patches(imgs, dets, downsample_ratios, max_instances_per_img): + ''' Given the raw images and the detection results, return the image patches of human instances. ''' + assert len(imgs) == len(dets), f'L_img = {len(imgs)}, L_det = {len(dets)}' + patches, n_patch_per_img, bbx_cs = [], [], [] + for i in tqdm(range(len(imgs))): + patches_i, bbx_cs_i = _img_det2patches(imgs[i], dets[i], downsample_ratios[i], max_instances_per_img) + n_patch_per_img.append(len(patches_i)) + if len(patches_i) > 0: + patches.append(patches_i.astype(np.float32)) + bbx_cs.append(bbx_cs_i) + else: + get_logger(brief=True).warning(f'No human detection results on image No.{i}.') + det_meta = { + 'n_patch_per_img' : n_patch_per_img, + 'bbx_cs' : bbx_cs, + } + return patches, det_meta + + +def _img_det2patches(imgs, det_instances, downsample_ratio:float, max_instances:int=5): + ''' + 1. Filter out the trusted human detections. + 2. Enlarge the bounding boxes to aspect ratio (ViT backbone only use 192*256 pixels, make sure these + pixels can capture main contents) and then to squares (to adapt the data module). + 3. Crop the image with the bounding boxes and resize them to 256x256. + 4. Normalize the cropped images. + ''' + if det_instances is None: # no human detected + return to_numpy([]), to_numpy([]) + CLASS_HUMAN_ID, DET_THRESHOLD_SCORE = 0, 0.5 + + # Filter out the trusted human detections. + is_human_mask = det_instances['pred_classes'] == CLASS_HUMAN_ID + reliable_mask = det_instances['scores'] > DET_THRESHOLD_SCORE + active_mask = is_human_mask & reliable_mask + + # Filter out the top-k human instances. + if active_mask.sum().item() > max_instances: + humans_scores = det_instances['scores'] * is_human_mask.float() + _, top_idx = humans_scores.topk(max_instances) + valid_mask = torch.zeros_like(active_mask).bool() + valid_mask[top_idx] = True + else: + valid_mask = active_mask + + # Process the bounding boxes and crop the images. 
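+    # Detector boxes (apparently left-up-right-bottom, 'lurb') are scaled back to the
+    # original resolution, padded to a 192:256 aspect ratio, converted to center-scale
+    # ('cs') and back so that they become square crops, then cropped from the image and
+    # resized to the 256x256 patches consumed by the model.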
+ lurb_all = det_instances['pred_boxes'][valid_mask].numpy() / downsample_ratio # (N, 4) + lurb_all = [fit_bbox_to_aspect_ratio(bbox=lurb, tgt_ratio=(192, 256)) for lurb in lurb_all] # regularize the bbox size + cs_all = [lurb_to_cs(lurb) for lurb in lurb_all] + lurb_all = [cs_to_lurb(cs) for cs in cs_all] + cropped_imgs = [crop_with_lurb(imgs, lurb) for lurb in lurb_all] + patches = to_numpy([flex_resize_img(cropped_img, (256, 256)) for cropped_img in cropped_imgs]) # (N, 256, 256, 3) + return patches, cs_all + + +# ================== Secondary Outputs Tools ================== + +def prepare_mesh(pipeline, pd_params): + B = 720 # full SKEL inference is memory consuming + L = pd_params['poses'].shape[0] + n_rounds = (L + B - 1) // B + v_skin_all, v_skel_all = [], [] + for rid in range(n_rounds): + sid, eid = rid * B, min((rid + 1) * B, L) + smpl_outputs = pipeline.skel_model( + poses = pd_params['poses'][sid:eid].to(pipeline.device), + betas = pd_params['betas'][sid:eid].to(pipeline.device), + ) + v_skin = smpl_outputs.skin_verts.detach().cpu() # (B, Vi, 3) + v_skel = smpl_outputs.skel_verts.detach().cpu() # (B, Ve, 3) + v_skin_all.append(v_skin) + v_skel_all.append(v_skel) + v_skel_all = torch.cat(v_skel_all, dim=0) + v_skin_all = torch.cat(v_skin_all, dim=0) + m_skin = {'v': v_skin_all, 'f': pipeline.skel_model.skin_f} + m_skel = {'v': v_skel_all, 'f': pipeline.skel_model.skel_f} + return m_skin, m_skel + + +# ================== Visualization Tools ================== + + +def visualize_img_results(pd_cam_t, raw_imgs, det_meta, m_skin, m_skel): + ''' Render the results to the patches. ''' + bbx_cs, n_patches_per_img = det_meta['bbx_cs'], det_meta['n_patch_per_img'] + bbx_cs = np.concatenate(bbx_cs, axis=0) + + results = [] + pp = 0 # patch pointer + results = [] + for i in tqdm(range(len(raw_imgs)), desc='Rendering'): + raw_h, raw_w = raw_imgs[i].shape[:2] + raw_cx, raw_cy = raw_w/2, raw_h/2 + spp, epp = pp, pp + n_patches_per_img[i] + + # Rescale the camera translation. + raw_cam_t = pd_cam_t.clone().float() + bbx_s = to_tensor(bbx_cs[spp:epp, 2], device=raw_cam_t.device) + bbx_cx = to_tensor(bbx_cs[spp:epp, 0], device=raw_cam_t.device) + bbx_cy = to_tensor(bbx_cs[spp:epp, 1], device=raw_cam_t.device) + + raw_cam_t[spp:epp, 2] = pd_cam_t[spp:epp, 2] * 256 / bbx_s + raw_cam_t[spp:epp, 1] += (bbx_cy - raw_cy) / 5000 * raw_cam_t[spp:epp, 2] + raw_cam_t[spp:epp, 0] += (bbx_cx - raw_cx) / 5000 * raw_cam_t[spp:epp, 2] + raw_cam_t = raw_cam_t[spp:epp] + + # Render overlays on the full image. 
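+        # Both meshes are rendered with a fixed focal length of 5000 px at the image
+        # centre: the skin mesh in blue and the skeleton mesh in yellow; when the skeleton
+        # is available the two renders are alpha-blended 0.7/0.3 into the 'front_blend'
+        # output.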
+ full_img_bg = raw_imgs[i].copy() + render_results = {} + for view in ['front']: + full_img_skin = render_meshes_overlay_img( + faces_all = m_skin['f'], + verts_all = m_skin['v'][spp:epp].float(), + cam_t_all = raw_cam_t, + mesh_color = 'blue', + img = full_img_bg, + K4 = [5000, 5000, raw_cx, raw_cy], + view = view, + ) + + if m_skel is not None: + full_img_skel = render_meshes_overlay_img( + faces_all = m_skel['f'], + verts_all = m_skel['v'][spp:epp].float(), + cam_t_all = raw_cam_t, + mesh_color = 'human_yellow', + img = full_img_bg, + K4 = [5000, 5000, raw_cx, raw_cy], + view = view, + ) + + if m_skel is not None: + full_img_blend = cv2.addWeighted(full_img_skin, 0.7, full_img_skel, 0.3, 0) + render_results[f'{view}_blend'] = full_img_blend + if view == 'front': + render_results[f'{view}_skel'] = full_img_skel + else: + render_results[f'{view}_skin'] = full_img_skin + + pp = epp + + return render_results \ No newline at end of file diff --git a/lib/modeling/callbacks/__init__.py b/lib/modeling/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0b0f421b2ccc4cca7e257c07fd4f45290d8e27 --- /dev/null +++ b/lib/modeling/callbacks/__init__.py @@ -0,0 +1 @@ +from .skelify_spin import SKELifySPIN \ No newline at end of file diff --git a/lib/modeling/callbacks/skelify_spin.py b/lib/modeling/callbacks/skelify_spin.py new file mode 100644 index 0000000000000000000000000000000000000000..ce6198f4ce45210e307a5dfe23b995dad03cd237 --- /dev/null +++ b/lib/modeling/callbacks/skelify_spin.py @@ -0,0 +1,347 @@ +from lib.kits.basic import * + +from concurrent.futures import ThreadPoolExecutor +from lightning_fabric.utilities.rank_zero import _get_rank + +from lib.data.augmentation.skel import rot_skel_on_plane +from lib.utils.data import to_tensor, to_numpy +from lib.utils.camera import perspective_projection, estimate_camera_trans +from lib.utils.vis import Wis3D +from lib.modeling.optim.skelify.skelify import SKELify +from lib.body_models.common import make_SKEL + +DEBUG = False +DEBUG_ROUND = False + + +class SKELifySPIN(pl.Callback): + ''' + Call SKELify to optimize the prediction results. + Here we have several concepts of data: gt, opgt, pd, (res), bpgt. + + 1. `gt`: Static Ground Truth: they are loaded from static training datasets, + they might be real ground truth (like 2D keypoints), or pseudo ground truth (like SKEL + parameters). They will be gradually replaced by the better pseudo ground truth through + iterations in anticipation. + 2. `opgt`: Old Pseudo-Ground Truth: they are the better ground truth among static datasets + (those called gt), and the dynamic datasets (maintained in the callbacks), and will serve + as the labels for training the network. + 3. `pd`: Predicted Results: they are from the network outputs and will be optimized later. + After being optimized, they will be called as `res`(Results from optimization). + 4. `bpgt`: Better Pseudo Ground Truth: they are the optimized results stored in extra file + and in the memory. These data are the highest quality data picked among the static + ground truth, or cached better pseudo ground truth, or the predicted & optimized data. + ''' + + # TODO: Now I only consider to use kp2d to evaluate the performance. (Because not all data provide kp3d.) + # TODO: But we need to consider the kp3d in the future, which is if we have, than use it. 
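+    # High-level flow per training step: `on_train_batch_start` swaps in any cached
+    # better pseudo-GT for the current samples (keyed by sequence key and flip state),
+    # while `on_train_batch_end` caches the network predictions and, every `interval`
+    # steps once the warm-up is over, runs SKELify on that cache to refresh the
+    # pseudo-GT.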
+ + def __init__( + self, + cfg : DictConfig, + skelify : DictConfig, + **kwargs, + ): + super().__init__() + self.interval = cfg.interval + self.B = cfg.batch_size + self.kb_pr = cfg.get('max_batches_per_round', None) # latest k batches per round are SPINed + self.better_pgt_fn = Path(cfg.better_pgt_fn) # load it before training + self.skip_warm_up_steps = cfg.skip_warm_up_steps + self.update_better_pgt = cfg.update_better_pgt + self.skelify_cfg = skelify + + # The threshold to determine if the result is valid. (In case some data + # don't have parameters at first but was updated to a bad parameters.) + self.valid_betas_threshold = cfg.valid_betas_threshold + + self._init_pd_dict() + + self.better_pgt = None + + + def on_train_batch_start(self, trainer, pl_module, raw_batch, batch_idx): + # Lazy initialization for better pgt. + if self.better_pgt is None: + self._init_better_pgt() + + # GPU_monitor.snapshot('GPU-Mem-Before-Train-Before-SPIN-Update') + device = pl_module.device + batch = raw_batch['img_ds'] + + if not self.update_better_pgt: + return + + # 1. Compose the data from batches. + seq_key_list = batch['__key__'] + batch_do_flip_list = batch['augm_args']['do_flip'] + sample_uid_list = [ + f'{seq_key}_flip' if do_flip else f'{seq_key}_orig' + for seq_key, do_flip in zip(seq_key_list, batch_do_flip_list) + ] + + # 2. Update the labels from better_pgt. + for i, sample_uid in enumerate(sample_uid_list): + if sample_uid in self.better_pgt['poses'].keys(): + batch['raw_skel_params']['poses'][i] = to_tensor(self.better_pgt['poses'][sample_uid], device=device) # (46,) + batch['raw_skel_params']['betas'][i] = to_tensor(self.better_pgt['betas'][sample_uid], device=device) # (10,) + batch['has_skel_params']['poses'][i] = self.better_pgt['has_poses'][sample_uid] # 0 or 1 + batch['has_skel_params']['betas'][i] = self.better_pgt['has_betas'][sample_uid] # 0 or 1 + batch['updated_by_spin'][i] = True # add information for inspection + # get_logger().trace(f'Update the pseudo-gt for {sample_uid}.') + + # GPU_monitor.snapshot('GPU-Mem-Before-Train-After-SPIN-Update') + + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + # GPU_monitor.snapshot('GPU-Mem-After-Train-Before-SPIN-Update') + + # Since the prediction from network might be far from the ground truth before well-trained, + # we can skip some steps to avoid meaningless optimization. + if trainer.global_step > self.skip_warm_up_steps or DEBUG_ROUND: + # Collect the prediction results. + self._save_pd(batch['img_ds'], outputs) + + if self.interval > 0 and trainer.global_step % self.interval == 0 or DEBUG_ROUND: + torch.cuda.empty_cache() + with PM.time_monitor('SPIN'): + self._spin(trainer.logger, pl_module.device) + torch.cuda.empty_cache() + + # GPU_monitor.snapshot('GPU-Mem-After-Train-After-SPIN-Update') + # GPU_monitor.report_latest(k=4) + + + def _init_pd_dict(self): + ''' Memory clean up for each SPIN. ''' + ''' Use numpy to store the value to save GPU memory. ''' + self.cache = { + # Things to identify one sample. + 'seq_key_list' : [], + # Things for comparison. + 'opgt_poses_list' : [], + 'opgt_betas_list' : [], + 'opgt_has_poses_list' : [], + 'opgt_has_betas_list' : [], + # Things for optimization and self-iteration. + 'gt_kp2d_list' : [], + 'pd_poses_list' : [], + 'pd_betas_list' : [], + 'pd_cam_t_list' : [], + 'do_flip_list' : [], + 'rot_deg_list' : [], + 'do_extreme_crop_list': [], # if the extreme crop is applied, we don't update the pseudo-gt + # Things for visualization. 
+ 'img_patch': [], + # No gt_cam_t_list. + } + + + def _format_pd(self): + ''' Format the cache to numpy. ''' + if self.kb_pr is None: + last_k = len(self.cache['seq_key_list']) + else: + last_k = self.kb_pr * self.B # the latest k samples to be optimized. + self.cache['seq_key_list'] = to_numpy(self.cache['seq_key_list'])[-last_k:] + self.cache['opgt_poses_list'] = to_numpy(self.cache['opgt_poses_list'])[-last_k:] + self.cache['opgt_betas_list'] = to_numpy(self.cache['opgt_betas_list'])[-last_k:] + self.cache['opgt_has_poses_list'] = to_numpy(self.cache['opgt_has_poses_list'])[-last_k:] + self.cache['opgt_has_betas_list'] = to_numpy(self.cache['opgt_has_betas_list'])[-last_k:] + self.cache['gt_kp2d_list'] = to_numpy(self.cache['gt_kp2d_list'])[-last_k:] + self.cache['pd_poses_list'] = to_numpy(self.cache['pd_poses_list'])[-last_k:] + self.cache['pd_betas_list'] = to_numpy(self.cache['pd_betas_list'])[-last_k:] + self.cache['pd_cam_t_list'] = to_numpy(self.cache['pd_cam_t_list'])[-last_k:] + self.cache['do_flip_list'] = to_numpy(self.cache['do_flip_list'])[-last_k:] + self.cache['rot_deg_list'] = to_numpy(self.cache['rot_deg_list'])[-last_k:] + self.cache['do_extreme_crop_list'] = to_numpy(self.cache['do_extreme_crop_list'])[-last_k:] + + if DEBUG: + self.cache['img_patch'] = to_numpy(self.cache['img_patch'])[-last_k:] + + + def _save_pd(self, batch, outputs): + ''' Save all the prediction results and related labels from the outputs. ''' + B = len(batch['__key__']) + + self.cache['seq_key_list'].extend(batch['__key__']) # (NS,) + + self.cache['opgt_poses_list'].extend(to_numpy(batch['raw_skel_params']['poses'])) # (NS, 46) + self.cache['opgt_betas_list'].extend(to_numpy(batch['raw_skel_params']['betas'])) # (NS, 10) + self.cache['opgt_has_poses_list'].extend(to_numpy(batch['has_skel_params']['poses'])) # (NS,) 0 or 1 + self.cache['opgt_has_betas_list'].extend(to_numpy(batch['has_skel_params']['betas'])) # (NS,) 0 or 1 + self.cache['gt_kp2d_list'].extend(to_numpy(batch['kp2d'])) # (NS, 44, 3) + + self.cache['pd_poses_list'].extend(to_numpy(outputs['pd_params']['poses'])) + self.cache['pd_betas_list'].extend(to_numpy(outputs['pd_params']['betas'])) + self.cache['pd_cam_t_list'].extend(to_numpy(outputs['pd_cam_t'])) + self.cache['do_flip_list'].extend(to_numpy(batch['augm_args']['do_flip'])) + self.cache['rot_deg_list'].extend(to_numpy(batch['augm_args']['rot_deg'])) + self.cache['do_extreme_crop_list'].extend(to_numpy(batch['augm_args']['do_extreme_crop'])) + + if DEBUG: + img_patch = batch['img_patch'].clone().permute(0, 2, 3, 1) # (NS, 256, 256, 3) + mean = torch.tensor([0.485, 0.456, 0.406], device=img_patch.device).reshape(1, 1, 1, 3) + std = torch.tensor([0.229, 0.224, 0.225], device=img_patch.device).reshape(1, 1, 1, 3) + img_patch = 255 * (img_patch * std + mean) + self.cache['img_patch'].extend(to_numpy(img_patch).astype(np.uint8)) # (NS, 256, 256, 3) + + + def _init_better_pgt(self): + ''' DDP adaptable initialization. 
''' + self.rank = _get_rank() + get_logger().info(f'Initializing better pgt cache @ rank {self.rank}') + + if self.rank is not None: + self.better_pgt_fn = Path(f'{self.better_pgt_fn}_r{self.rank}') + get_logger().info(f'Redirecting better pgt cache to {self.better_pgt_fn}') + + if self.better_pgt_fn.exists(): + better_pgt_z = np.load(self.better_pgt_fn, allow_pickle=True) + self.better_pgt = {k: better_pgt_z[k].item() for k in better_pgt_z.files} + else: + self.better_pgt = {'poses': {}, 'betas': {}, 'has_poses': {}, 'has_betas': {}} + + + def _spin(self, tb_logger, device): + skelify : SKELify = instantiate(self.skelify_cfg, tb_logger=tb_logger, device=device, _recursive_=False) + skel_model = skelify.skel_model + + self._format_pd() + + # 1. Make up the cache to run SKELify. + with PM.time_monitor('preparation'): + sample_uid_list = [ + f'{seq_key}_flip' if do_flip else f'{seq_key}_orig' + for seq_key, do_flip in zip(self.cache['seq_key_list'], self.cache['do_flip_list']) + ] + + all_gt_kp2d = self.cache['gt_kp2d_list'] # (NS, 44, 2) + all_init_poses = self.cache['pd_poses_list'] # (NS, 46) + all_init_betas = self.cache['pd_betas_list'] # (NS, 10) + all_init_cam_t = self.cache['pd_cam_t_list'] # (NS, 3) + all_do_extreme_crop = self.cache['do_extreme_crop_list'] # (NS,) + all_res_poses = [] + all_res_betas = [] + all_res_cam_t = [] + all_res_kp2d_err = [] # the evaluation of the keypoints 2D error + + # 2. Run SKELify optimization here to get better results. + with PM.time_monitor('SKELify') as tm: + get_logger().info(f'Start to run SKELify optimization. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.') + n_samples = len(self.cache['seq_key_list']) + n_round = (n_samples - 1) // self.B + 1 + get_logger().info(f'Running SKELify optimization for {n_samples} samples in {n_round} rounds.') + for rid in range(n_round): + sid = rid * self.B + eid = min(sid + self.B, n_samples) + + gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device) + init_poses = to_tensor(all_init_poses[sid:eid], device=device) + init_betas = to_tensor(all_init_betas[sid:eid], device=device) + init_cam_t = to_tensor(all_init_cam_t[sid:eid], device=device) + + # Run the SKELify optimization. + outputs = skelify( + gt_kp2d = gt_kp2d_with_conf, + init_poses = init_poses, + init_betas = init_betas, + init_cam_t = init_cam_t, + img_patch = self.cache['img_patch'][sid:eid] if DEBUG else None, + ) + + # Store the results. + all_res_poses.extend(to_numpy(outputs['poses'])) # (~NS, 46) + all_res_betas.extend(to_numpy(outputs['betas'])) # (~NS, 10) + all_res_cam_t.extend(to_numpy(outputs['cam_t'])) # (~NS, 3) + all_res_kp2d_err.extend(to_numpy(outputs['kp2d_err'])) # (~NS,) + + tm.tick(f'SKELify round {rid} finished.') + + # 3. Initialize the uninitialized better pseudo-gt with old ground truth. + with PM.time_monitor('init_bpgt'): + get_logger().info(f'Initializing bgbt. GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.') + for i in range(n_samples): + sample_uid = sample_uid_list[i] + if sample_uid not in self.better_pgt.keys(): + self.better_pgt['poses'][sample_uid] = self.cache['opgt_poses_list'][i] + self.better_pgt['betas'][sample_uid] = self.cache['opgt_betas_list'][i] + self.better_pgt['has_poses'][sample_uid] = self.cache['opgt_has_poses_list'][i] + self.better_pgt['has_betas'][sample_uid] = self.cache['opgt_has_betas_list'][i] + + # 4. Update the results. + with PM.time_monitor('upd_bpgt'): + upd_cnt = 0 # Count the number of updated samples. + get_logger().info(f'Update the results. 
GPU-Mem: {torch.cuda.memory_allocated() / 1e9:.2f}G.') + for rid in range(n_round): + torch.cuda.empty_cache() + sid = rid * self.B + eid = min(sid + self.B, n_samples) + + focal_length = np.ones(2) * 5000 / 256 # TODO: These data should be loaded from configuration files. + focal_length = focal_length.reshape(1, 2).repeat(eid - sid, 1) # (B, 2) + gt_kp2d_with_conf = to_tensor(all_gt_kp2d[sid:eid], device=device) # (B, 44, 3) + rot_deg = to_tensor(self.cache['rot_deg_list'][sid:eid], device=device) + + # 4.1. Prepare the better pseudo-gt and the results. + res_betas = to_tensor(all_res_betas[sid:eid], device=device) # (B, 10) + res_poses_after_augm = to_tensor(all_res_poses[sid:eid], device=device) # (B, 46) + res_poses_before_augm = rot_skel_on_plane(res_poses_after_augm, -rot_deg) # recover the augmentation rotation + res_kp2d_err = to_tensor(all_res_kp2d_err[sid:eid], device=device) # (B,) + cur_do_extreme_crop = all_do_extreme_crop[sid:eid] + + # 4.2. Evaluate the quality of the existing better pseudo-gt. + uids = sample_uid_list[sid:eid] # [sid ~ eid] -> sample_uids + bpgt_betas = to_tensor([self.better_pgt['betas'][uid] for uid in uids], device=device) + bpgt_poses_before_augm = to_tensor([self.better_pgt['poses'][uid] for uid in uids], device=device) + bpgt_poses_after_augm = rot_skel_on_plane(bpgt_poses_before_augm.clone(), rot_deg) # recover the augmentation rotation + + skel_outputs = skel_model(poses=bpgt_poses_after_augm, betas=bpgt_betas, skelmesh=False) + bpgt_kp3d = skel_outputs.joints.detach() # (B, 44, 3) + bpgt_est_cam_t = estimate_camera_trans( + S = bpgt_kp3d, + joints_2d = gt_kp2d_with_conf.clone(), + focal_length = 5000, + img_size = 256, + ) # estimate camera translation from inference 3D keypoints and GT 2D keypoints + bpgt_reproj_kp2d = perspective_projection( + points = to_tensor(bpgt_kp3d, device=device), + translation = to_tensor(bpgt_est_cam_t, device=device), + focal_length = to_tensor(focal_length, device=device), + ) + bpgt_kp2d_err = SKELify.eval_kp2d_err(gt_kp2d_with_conf, bpgt_reproj_kp2d) # (B, 44) + + valid_betas_mask = res_betas.abs().max(dim=-1)[0] < self.valid_betas_threshold # (B,) + better_mask = res_kp2d_err < bpgt_kp2d_err # (B,) + upd_mask = torch.logical_and(valid_betas_mask, better_mask) # (B,) + upd_ids = torch.arange(eid-sid, device=device)[upd_mask] # uids -> ids + + # Update one by one. + for upd_id in upd_ids: + # `uid` for dynamic dataset unique id, `id` for in-round batch data. + # Notes: id starts from zeros, it should be applied to [sid ~ eid] directly. + # Either `all_res_poses[upd_id]` or `res_poses[upd_id - sid]` is wrong. + if cur_do_extreme_crop[upd_id]: + # Skip the extreme crop data. + continue + sample_uid = uids[upd_id] + self.better_pgt['poses'][sample_uid] = to_numpy(res_poses_before_augm[upd_id]) + self.better_pgt['betas'][sample_uid] = to_numpy(res_betas[upd_id]) + self.better_pgt['has_poses'][sample_uid] = 1. # If updated, then must have. + self.better_pgt['has_betas'][sample_uid] = 1. # If updated, then must have. + upd_cnt += 1 + + get_logger().info(f'Update {upd_cnt} samples among all {n_samples} samples.') + + # 5. [Async] Save the results. + with PM.time_monitor('async_dumping'): + # TODO: Use lock and other techniques to achieve a better submission system. + # TODO: We need to design a better way to solve the synchronization problem. + if hasattr(self, 'dump_thread'): + self.dump_thread.result() # Wait for the previous dump to finish. 
+ with ThreadPoolExecutor() as executor: + self.dump_thread = executor.submit(lambda: np.savez(self.better_pgt_fn, **self.better_pgt)) + + # 5. Clean up the memory. + del skelify, skel_model + self._init_pd_dict() \ No newline at end of file diff --git a/lib/modeling/losses/__init__.py b/lib/modeling/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c791dce25e46daa4f728a28e76ac048cdf3f5945 --- /dev/null +++ b/lib/modeling/losses/__init__.py @@ -0,0 +1,2 @@ +from .prior import * +from .losses import * \ No newline at end of file diff --git a/lib/modeling/losses/kp.py b/lib/modeling/losses/kp.py new file mode 100644 index 0000000000000000000000000000000000000000..70a938f90c3d46d6a08ceec3c002f2f5ac66e70c --- /dev/null +++ b/lib/modeling/losses/kp.py @@ -0,0 +1,15 @@ +from lib.kits.basic import * + + +def compute_kp3d_loss(gt_kp3d, pd_kp3d, ref_jid=25+14): + conf = gt_kp3d[:, :, 3:].clone() # (B, 43, 1) + gt_kp3d_a = gt_kp3d[:, :, :3] - gt_kp3d[:, [ref_jid], :3] # aligned, (B, J=44, 3) + pd_kp3d_a = pd_kp3d[:, :, :3] - pd_kp3d[:, [ref_jid], :3] # aligned, (B, J=44, 3) + kp3d_loss = conf * F.l1_loss(pd_kp3d_a, gt_kp3d_a, reduction='none') # (B, J=44, 3) + return kp3d_loss.sum() # (,) + + +def compute_kp2d_loss(gt_kp2d, pd_kp2d): + conf = gt_kp2d[:, :, 2:].clone() # (B, 44, 1) + kp2d_loss = conf * F.l1_loss(pd_kp2d, gt_kp2d[:, :, :2], reduction='none') # (B, 44, 2) + return kp2d_loss.sum() # (,) \ No newline at end of file diff --git a/lib/modeling/losses/losses.py b/lib/modeling/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..87d2e30453da23e1c172c878444bc6415224e56e --- /dev/null +++ b/lib/modeling/losses/losses.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn + +from smplx.body_models import SMPLOutput +from lib.body_models.skel.skel_model import SKELOutput + + +class Keypoint2DLoss(nn.Module): + + def __init__(self, loss_type: str = 'l1'): + """ + 2D keypoint loss module. + Args: + loss_type (str): Choose between l1 and l2 losses. + """ + super(Keypoint2DLoss, self).__init__() + if loss_type == 'l1': + self.loss_fn = nn.L1Loss(reduction='none') + elif loss_type == 'l2': + self.loss_fn = nn.MSELoss(reduction='none') + else: + raise NotImplementedError('Unsupported loss function') + + def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor: + """ + Compute 2D reprojection loss on the keypoints. + Args: + pred_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 2] containing projected 2D keypoints (B: batch_size, S: num_samples, N: num_keypoints) + gt_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the ground truth 2D keypoints and confidence. + Returns: + torch.Tensor: 2D keypoint loss. + """ + conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() + batch_size = conf.shape[0] + loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1,2)) + return loss.sum() + + +class Keypoint3DLoss(nn.Module): + + def __init__(self, loss_type: str = 'l1'): + """ + 3D keypoint loss module. + Args: + loss_type (str): Choose between l1 and l2 losses. 
+ """ + super(Keypoint3DLoss, self).__init__() + if loss_type == 'l1': + self.loss_fn = nn.L1Loss(reduction='none') + elif loss_type == 'l2': + self.loss_fn = nn.MSELoss(reduction='none') + else: + raise NotImplementedError('Unsupported loss function') + + def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 39): + """ + Compute 3D keypoint loss. + Args: + pred_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the predicted 3D keypoints (B: batch_size, S: num_samples, N: num_keypoints) + gt_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 4] containing the ground truth 3D keypoints and confidence. + Returns: + torch.Tensor: 3D keypoint loss. + """ + batch_size = pred_keypoints_3d.shape[0] + gt_keypoints_3d = gt_keypoints_3d.clone() + pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1) + gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1) + conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() + gt_keypoints_3d = gt_keypoints_3d[:, :, :-1] + loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1,2)) + return loss.sum() + + +class ParameterLoss(nn.Module): + + def __init__(self): + """ + SMPL parameter loss module. + """ + super(ParameterLoss, self).__init__() + self.loss_fn = nn.MSELoss(reduction='none') + + def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor): + """ + Compute SMPL parameter loss. + Args: + pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas) + gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth SMPL parameters. + Returns: + torch.Tensor: L2 parameter loss loss. + """ + batch_size = pred_param.shape[0] + num_dims = len(pred_param.shape) + mask_dimension = [batch_size] + [1] * (num_dims-1) + has_param = has_param.type(pred_param.type()).view(*mask_dimension) + loss_param = (has_param * self.loss_fn(pred_param, gt_param)) + return loss_param.sum() diff --git a/lib/modeling/losses/prior.py b/lib/modeling/losses/prior.py new file mode 100644 index 0000000000000000000000000000000000000000..2269c756369296282d72b143644f2e713567bfc3 --- /dev/null +++ b/lib/modeling/losses/prior.py @@ -0,0 +1,108 @@ +from lib.kits.basic import * + +# ============= # +# Utils # +# ============= # + + +def soft_bound_loss(x, low, up): + ''' + Softly penalize the violation of the lower and upper bounds. + PROBLEMS: for joints like legs, whose normal pose is near the boundary (standing person tend to have zero rotation but the limitation is zero-bounded, which encourage the leg to bend somehow). + + ### Args: + - x: torch.tensor + - shape = (B, Q), where Q is the number of components. + - low: torch.tensor, lower bound. + - shape = (Q,) + - Lower bound. + - up: (Q,) + - shape = (Q,) + - Upper bound. + + ### Returns: + - loss: torch.tensor + - shape = (B,) + ''' + B = len(x) + loss = torch.exp(low[None] - x).pow(2) + torch.exp(x - up[None]).pow(2) # (B, Q) + return loss # (B,) + + +def softer_bound_loss(x, low, up): + ''' + Softly penalize the violation of the lower and upper bounds. This loss won't penalize so hard when the + value exceed the bound by a small margin (half of up - low), but it's friendly to the case when the common + case is not centered at the middle of the bound. 
(E.g., the rotation of knee is more likely to be at zero + when some one is standing straight, but zero is the lower bound.) + + ### Args: + - x: torch.tensor, (B, Q), where Q is the number of components. + - low: torch.tensor, (Q,) + - Lower bound. + - up: torch.tensor, (Q,) + - Upper bound. + + ### Returns: + - loss: torch.tensor, (B,) + ''' + B = len(x) + expand = (up - low) / 2 + loss = torch.exp((low[None] - expand) - x).pow(2) + torch.exp(x - (up[None] + expand)).pow(2) # (B, Q) + return loss # (B,) + + +def softest_bound_loss(x, low, up): + ''' + Softly penalize the violation of the lower and upper bounds. This loss won't penalize so hard when the + value exceed the bound by a small margin (half of up - low), but it's friendly to the case when the common + case is not centered at the middle of the bound. (E.g., the rotation of knee is more likely to be at zero + when some one is standing straight, but zero is the lower bound.) + + ### Args: + - x: torch.tensor, (B, Q), where Q is the number of components. + - low: torch.tensor, (Q,) + - Lower bound. + - up: torch.tensor, (Q,) + - Upper bound. + + ### Returns: + - loss: torch.tensor, (B,) + ''' + B = len(x) + expand = (up - low) / 2 + lower_loss = torch.exp((low[None] - expand) - x).pow(2) - 1 # (B, Q) + upper_loss = torch.exp(x - (up[None] + expand)).pow(2) - 1 # (B, Q) + lower_loss = torch.where(lower_loss < 0, 0, lower_loss) + upper_loss = torch.where(upper_loss < 0, 0, upper_loss) + loss = lower_loss + upper_loss + return loss # (B,) + + +# ============= # +# Loss # +# ============= # + + +def compute_poses_angle_prior_loss(poses): + ''' + Some components have upper and lower bound, use exponential loss to apply soft limitation. + + ### Args + - poses: torch.tensor, (B, 46) + + ### Returns + - loss: torch.tensor, (,) + ''' + from lib.body_models.skel_utils.limits import SKEL_LIM_QIDS, SKEL_LIM_BOUNDS + + device = poses.device + # loss = softer_bound_loss( + # loss = softest_bound_loss( + loss = soft_bound_loss( + x = poses[:, SKEL_LIM_QIDS], + low = SKEL_LIM_BOUNDS[:, 0].to(device), + up = SKEL_LIM_BOUNDS[:, 1].to(device), + ) # (,) + + return loss diff --git a/lib/modeling/networks/backbones/README.md b/lib/modeling/networks/backbones/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3f6b3321c7688f91ec95ac970af227207a94602d --- /dev/null +++ b/lib/modeling/networks/backbones/README.md @@ -0,0 +1,5 @@ +# ViT Backbone + +The implementation of ViT backbone comes from [ViTPose](https://github.com/ViTAE-Transformer/ViTPose/blob/d5216452796c90c6bc29f5c5ec0bdba94366768a/mmpose/models/backbones/vit.py#L103). + +Meanwhile, we used **the backbone part** of their released checkpoints. \ No newline at end of file diff --git a/lib/modeling/networks/backbones/__init__.py b/lib/modeling/networks/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e63a8408d15b5af56b2a02ce10a1f1185448160 --- /dev/null +++ b/lib/modeling/networks/backbones/__init__.py @@ -0,0 +1 @@ +from .vit import ViT \ No newline at end of file diff --git a/lib/modeling/networks/backbones/vit.py b/lib/modeling/networks/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..85569f06a8fa10c6aff6b25356aa57a1283124cc --- /dev/null +++ b/lib/modeling/networks/backbones/vit.py @@ -0,0 +1,335 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
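As the backbone README above notes, only the backbone part of the released ViTPose checkpoints is reused. A rough sketch of such a partial load (the checkpoint path and the `backbone.` key prefix are assumptions about the ViTPose release, not something defined in this repository):

```python
import torch

def load_backbone_weights(vit, ckpt_path, prefix='backbone.'):
    # Hypothetical helper: keep only the keys belonging to the backbone and load
    # them non-strictly into the ViT defined below.
    state = torch.load(ckpt_path, map_location='cpu')
    state = state.get('state_dict', state)  # many released checkpoints nest the weights under 'state_dict'
    backbone_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    return vit.load_state_dict(backbone_state, strict=False)
```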
+import math + +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + + +def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + cls_token = None + B, L, C = abs_pos.shape + if has_cls_token: + cls_token = abs_pos[:, 0:1] + abs_pos = abs_pos[:, 1:] + + if ori_h != h or ori_w != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).reshape(B, -1, C) + + else: + new_abs_pos = abs_pos + + if cls_token is not None: + new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1) + return new_abs_pos + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None,): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, attn_head_dim=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + 
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) + self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) + self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class ViT(nn.Module): + + def __init__(self, + img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, + frozen_stages=-1, ratio=1, last_norm=True, + patch_padding='pad', freeze_attn=False, freeze_ffn=False, + ): + # Protect mutable default arguments + super(ViT, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + 
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) + num_patches = self.patch_embed.num_patches + + # since the pretraining model has class token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + ) + for i in range(depth)]) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. 
+ """ + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.last_norm(x) + + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() + + return xp + + def forward(self, x): + x = self.forward_features(x) + return x + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() diff --git a/lib/modeling/networks/discriminators/__init__.py b/lib/modeling/networks/discriminators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d11b6d255ce76003c6d0f18ac2753f976be35679 --- /dev/null +++ b/lib/modeling/networks/discriminators/__init__.py @@ -0,0 +1 @@ +from .hsmr_disc import * \ No newline at end of file diff --git a/lib/modeling/networks/discriminators/hsmr_disc.py b/lib/modeling/networks/discriminators/hsmr_disc.py new file mode 100644 index 0000000000000000000000000000000000000000..84b4649ae0052e4ed2e2057c6c971491f822ad6b --- /dev/null +++ b/lib/modeling/networks/discriminators/hsmr_disc.py @@ -0,0 +1,95 @@ +from lib.kits.basic import * + + + +class HSMRDiscriminator(nn.Module): + + def __init__(self): + ''' + Pose + Shape discriminator proposed in HMR + ''' + super(HSMRDiscriminator, self).__init__() + + self.n_poses = 23 + # poses_alone + self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1) + nn.init.xavier_uniform_(self.D_conv1.weight) + nn.init.zeros_(self.D_conv1.bias) + self.relu = nn.ReLU(inplace=True) + self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1) + nn.init.xavier_uniform_(self.D_conv2.weight) + nn.init.zeros_(self.D_conv2.bias) + pose_out = [] + for i in range(self.n_poses): + pose_out_temp = nn.Linear(32, 1) + nn.init.xavier_uniform_(pose_out_temp.weight) + nn.init.zeros_(pose_out_temp.bias) + pose_out.append(pose_out_temp) + self.pose_out = nn.ModuleList(pose_out) + + # betas + self.betas_fc1 = nn.Linear(10, 10) + nn.init.xavier_uniform_(self.betas_fc1.weight) + nn.init.zeros_(self.betas_fc1.bias) + self.betas_fc2 = nn.Linear(10, 5) + nn.init.xavier_uniform_(self.betas_fc2.weight) + nn.init.zeros_(self.betas_fc2.bias) + self.betas_out = nn.Linear(5, 1) + nn.init.xavier_uniform_(self.betas_out.weight) + nn.init.zeros_(self.betas_out.bias) + + # poses_joint + self.D_alljoints_fc1 = nn.Linear(32*self.n_poses, 1024) + nn.init.xavier_uniform_(self.D_alljoints_fc1.weight) + nn.init.zeros_(self.D_alljoints_fc1.bias) + self.D_alljoints_fc2 = nn.Linear(1024, 1024) + nn.init.xavier_uniform_(self.D_alljoints_fc2.weight) + nn.init.zeros_(self.D_alljoints_fc2.bias) + self.D_alljoints_out = nn.Linear(1024, 1) + nn.init.xavier_uniform_(self.D_alljoints_out.weight) + nn.init.zeros_(self.D_alljoints_out.bias) + + + def forward(self, poses_body: torch.Tensor, betas: torch.Tensor) -> 
torch.Tensor: + ''' + Forward pass of the discriminator. + ### Args + - poses: torch.Tensor, shape (B, 23, 9) + - Matrix representation of the SKEL poses excluding the global orientation. + - betas: torch.Tensor, shape (B, 10) + ### Returns + - torch.Tensor, shape (B, 25) + ''' + poses_body = poses_body.reshape(-1, self.n_poses, 1, 9) # (B, n_poses, 1, 9) + B = poses_body.shape[0] + poses_body = poses_body.permute(0, 3, 1, 2).contiguous() # (B, 9, n_poses, 1) + + # poses_alone + poses_body = self.D_conv1(poses_body) + poses_body = self.relu(poses_body) + poses_body = self.D_conv2(poses_body) + poses_body = self.relu(poses_body) + + poses_out = [] + for i in range(self.n_poses): + poses_out_i = self.pose_out[i](poses_body[:, :, i, 0]) + poses_out.append(poses_out_i) + poses_out = torch.cat(poses_out, dim=1) + + # betas + betas = self.betas_fc1(betas) + betas = self.relu(betas) + betas = self.betas_fc2(betas) + betas = self.relu(betas) + betas_out = self.betas_out(betas) + + # poses_joint + poses_body = poses_body.reshape(B, -1) + poses_all = self.D_alljoints_fc1(poses_body) + poses_all = self.relu(poses_all) + poses_all = self.D_alljoints_fc2(poses_all) + poses_all = self.relu(poses_all) + poses_all_out = self.D_alljoints_out(poses_all) + + disc_out = torch.cat((poses_out, betas_out, poses_all_out), dim=1) + return disc_out diff --git a/lib/modeling/networks/heads/SKEL_mean.npz b/lib/modeling/networks/heads/SKEL_mean.npz new file mode 100644 index 0000000000000000000000000000000000000000..fb888dc2112a35aa94b17f3edbff75651848866f Binary files /dev/null and b/lib/modeling/networks/heads/SKEL_mean.npz differ diff --git a/lib/modeling/networks/heads/__init__.py b/lib/modeling/networks/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5751ba0de3cee06070292cd00ff149dcb240bcf5 --- /dev/null +++ b/lib/modeling/networks/heads/__init__.py @@ -0,0 +1 @@ +from .skel_head import SKELTransformerDecoderHead \ No newline at end of file diff --git a/lib/modeling/networks/heads/skel_head.py b/lib/modeling/networks/heads/skel_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c32c5cb76f1fbbaddf8b504f16a112f2ed33a9ce --- /dev/null +++ b/lib/modeling/networks/heads/skel_head.py @@ -0,0 +1,107 @@ +from lib.kits.basic import * + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import einops + +from omegaconf import OmegaConf + +from lib.platform import PM +from lib.body_models.skel_utils.transforms import params_q2rep, params_rep2q + +from .utils.pose_transformer import TransformerDecoder + + +class SKELTransformerDecoderHead(nn.Module): + """ Cross-attention based SKEL Transformer decoder + """ + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + if cfg.pd_poses_repr == 'rotation_6d': + n_poses = 24 * 6 + elif cfg.pd_poses_repr == 'euler_angle': + n_poses = 46 + else: + raise ValueError(f"Unknown pose representation: {cfg.pd_poses_repr}") + + n_betas = 10 + n_cam = 3 + self.input_is_mean_shape = False + + # Build transformer decoder. + transformer_args = { + 'num_tokens' : 1, + 'token_dim' : (n_poses + n_betas + n_cam) if self.input_is_mean_shape else 1, + 'dim' : 1024, + } + transformer_args.update(OmegaConf.to_container(cfg.transformer_decoder, resolve=True)) # type: ignore + self.transformer = TransformerDecoder(**transformer_args) + + # Build decoders for parameters. 
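The head predicts poses either as 46 Euler-angle ('q') components or as 24 six-dimensional rotations; `params_q2rep` / `params_rep2q` (imported at the top of this file) convert between the two. A small usage sketch, with shapes inferred from how they are called below:

```python
import torch
from lib.body_models.skel_utils.transforms import params_q2rep, params_rep2q

poses_q = torch.zeros(2, 46)        # SKEL q-space poses (Euler-angle components)
poses_6d = params_q2rep(poses_q)    # reshapeable to (2, 24*6): one 6D rotation per joint group
poses_back = params_rep2q(poses_6d.reshape(2, 24, 6))  # back to (2, 46) q-space poses
```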
+ dim = transformer_args['dim'] + self.poses_decoder = nn.Linear(dim, n_poses) + self.betas_decoder = nn.Linear(dim, n_betas) + self.cam_decoder = nn.Linear(dim, n_cam) + + # Load mean shape parameters as initial values. + skel_mean_path = Path(__file__).parent / 'SKEL_mean.npz' + skel_mean_params = np.load(skel_mean_path) + + init_poses = torch.from_numpy(skel_mean_params['poses'].astype(np.float32)).unsqueeze(0) # (1, 46) + if cfg.pd_poses_repr == 'rotation_6d': + init_poses = params_q2rep(init_poses).reshape(1, 24*6) # (1, 24*6) + init_betas = torch.from_numpy(skel_mean_params['betas'].astype(np.float32)).unsqueeze(0) + init_cam = torch.from_numpy(skel_mean_params['cam'].astype(np.float32)).unsqueeze(0) + + self.register_buffer('init_poses', init_poses) + self.register_buffer('init_betas', init_betas) + self.register_buffer('init_cam', init_cam) + + def forward(self, x, **kwargs): + B = x.shape[0] + # vit pretrained backbone is channel-first. Change to token-first + x = einops.rearrange(x, 'b c h w -> b (h w) c') + + # Initialize the parameters. + init_poses = self.init_poses.expand(B, -1) # (B, 46) + init_betas = self.init_betas.expand(B, -1) # (B, 10) + init_cam = self.init_cam.expand(B, -1) # (B, 3) + + # Input token to transformer is zero token. + with PM.time_monitor('init_token'): + if self.input_is_mean_shape: + token = torch.cat([init_poses, init_betas, init_cam], dim=1)[:, None, :] # (B, 1, C) + else: + token = x.new_zeros(B, 1, 1) + + # Pass through transformer. + with PM.time_monitor('transformer'): + token_out = self.transformer(token, context=x) + token_out = token_out.squeeze(1) # (B, C) + + # Parse the SKEL parameters out from token_out. + with PM.time_monitor('decode'): + pd_poses = self.poses_decoder(token_out) + init_poses + pd_betas = self.betas_decoder(token_out) + init_betas + pd_cam = self.cam_decoder(token_out) + init_cam + + with PM.time_monitor('rot_repr_transform'): + if self.cfg.pd_poses_repr == 'rotation_6d': + pd_poses = params_rep2q(pd_poses.reshape(-1, 24, 6)) # (B, 46) + elif self.cfg.pd_poses_repr == 'euler_angle': + pd_poses = pd_poses.reshape(-1, 46) # (B, 46) + else: + raise ValueError(f"Unknown pose representation: {self.cfg.pd_poses_repr}") + + pd_skel_params = { + 'poses' : pd_poses, + 'poses_orient' : pd_poses[:, :3], + 'poses_body' : pd_poses[:, 3:], + 'betas' : pd_betas + } + return pd_skel_params, pd_cam \ No newline at end of file diff --git a/lib/modeling/networks/heads/utils/__init__.py b/lib/modeling/networks/heads/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/modeling/networks/heads/utils/geometry.py b/lib/modeling/networks/heads/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..86648e9f22866a33b261880bb3953fddd93926a8 --- /dev/null +++ b/lib/modeling/networks/heads/utils/geometry.py @@ -0,0 +1,63 @@ +from typing import Optional +import torch +from torch.nn import functional as F + + +def aa_to_rotmat(theta: torch.Tensor): + """ + Convert axis-angle representation to rotation matrix. + Works by first converting it to a quaternion. + Args: + theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations. + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). 
+ """ + norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat_to_rotmat(quat) + + +def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor: + """ + Convert quaternion representation to rotation matrix. + Args: + quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z). + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + +def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor: + """ + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + x (torch.Tensor): (B,6) Batch of 6-D rotation representations. + Returns: + torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3). + """ + x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous() + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.linalg.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) \ No newline at end of file diff --git a/lib/modeling/networks/heads/utils/pose_transformer.py b/lib/modeling/networks/heads/utils/pose_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ac04971407cb59637490cc4842f048b9bc4758be --- /dev/null +++ b/lib/modeling/networks/heads/utils/pose_transformer.py @@ -0,0 +1,358 @@ +from inspect import isfunction +from typing import Callable, Optional + +import torch +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import nn + +from .t_cond_mlp import ( + AdaptiveLayerNorm1D, + FrequencyEmbedder, + normalization_layer, +) +# from .vit import Attention, FeedForward + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class PreNorm(nn.Module): + def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1): + super().__init__() + self.norm = normalization_layer(norm, dim, norm_cond_dim) + self.fn = fn + + def forward(self, x: torch.Tensor, *args, **kwargs): + if isinstance(self.norm, AdaptiveLayerNorm1D): + return self.fn(self.norm(x, *args), **kwargs) + else: + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = 
nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class CrossAttention(nn.Module): + def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + context_dim = default(context_dim, dim) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + self.to_q = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x, context=None): + context = default(context, x) + k, v = self.to_kv(context).chunk(2, dim=-1) + q = self.to_q(x) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v]) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + norm: str = "layer", + norm_cond_dim: int = -1, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + ff = FeedForward(dim, mlp_dim, dropout=dropout) + self.layers.append( + nn.ModuleList( + [ + PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim), + ] + ) + ) + + def forward(self, x: torch.Tensor, *args): + for attn, ff in self.layers: + x = attn(x, *args) + x + x = ff(x, *args) + x + return x + + +class TransformerCrossAttn(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + norm: str = "layer", + norm_cond_dim: int = -1, + context_dim: Optional[int] = None, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + ca = CrossAttention( + dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout + ) + ff = FeedForward(dim, mlp_dim, dropout=dropout) + self.layers.append( + nn.ModuleList( + [ + PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim), + ] + ) + ) + + def forward(self, x: torch.Tensor, *args, context=None, context_list=None): + if context_list is None: + context_list = [context] * len(self.layers) + if len(context_list) != len(self.layers): + raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})") + + for i, (self_attn, 
cross_attn, ff) in enumerate(self.layers): + x = self_attn(x, *args) + x + x = cross_attn(x, *args, context=context_list[i]) + x + x = ff(x, *args) + x + return x + + +class DropTokenDropout(nn.Module): + def __init__(self, p: float = 0.1): + super().__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + + def forward(self, x: torch.Tensor): + # x: (batch_size, seq_len, dim) + if self.training and self.p > 0: + zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool() + # TODO: permutation idx for each batch using torch.argsort + if zero_mask.any(): + x = x[:, ~zero_mask, :] + return x + + +class ZeroTokenDropout(nn.Module): + def __init__(self, p: float = 0.1): + super().__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + + def forward(self, x: torch.Tensor): + # x: (batch_size, seq_len, dim) + if self.training and self.p > 0: + zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool() + # Zero-out the masked tokens + x[zero_mask, :] = 0 + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + num_tokens: int, + token_dim: int, + dim: int, + depth: int, + heads: int, + mlp_dim: int, + dim_head: int = 64, + dropout: float = 0.0, + emb_dropout: float = 0.0, + emb_dropout_type: str = "drop", + emb_dropout_loc: str = "token", + norm: str = "layer", + norm_cond_dim: int = -1, + token_pe_numfreq: int = -1, + ): + super().__init__() + if token_pe_numfreq > 0: + token_dim_new = token_dim * (2 * token_pe_numfreq + 1) + self.to_token_embedding = nn.Sequential( + Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim), + FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1), + Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new), + nn.Linear(token_dim_new, dim), + ) + else: + self.to_token_embedding = nn.Linear(token_dim, dim) + self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) + if emb_dropout_type == "drop": + self.dropout = DropTokenDropout(emb_dropout) + elif emb_dropout_type == "zero": + self.dropout = ZeroTokenDropout(emb_dropout) + else: + raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}") + self.emb_dropout_loc = emb_dropout_loc + + self.transformer = Transformer( + dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim + ) + + def forward(self, inp: torch.Tensor, *args, **kwargs): + x = inp + + if self.emb_dropout_loc == "input": + x = self.dropout(x) + x = self.to_token_embedding(x) + + if self.emb_dropout_loc == "token": + x = self.dropout(x) + b, n, _ = x.shape + x += self.pos_embedding[:, :n] + + if self.emb_dropout_loc == "token_afterpos": + x = self.dropout(x) + x = self.transformer(x, *args) + return x + + +class TransformerDecoder(nn.Module): + def __init__( + self, + num_tokens: int, + token_dim: int, + dim: int, + depth: int, + heads: int, + mlp_dim: int, + dim_head: int = 64, + dropout: float = 0.0, + emb_dropout: float = 0.0, + emb_dropout_type: str = 'drop', + norm: str = "layer", + norm_cond_dim: int = -1, + context_dim: Optional[int] = None, + skip_token_embedding: bool = False, + ): + super().__init__() + if not skip_token_embedding: + self.to_token_embedding = nn.Linear(token_dim, dim) + else: + self.to_token_embedding = nn.Identity() + if token_dim != dim: + raise ValueError( + f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True" + ) + + 
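For reference, the two token-dropout variants defined above behave differently at training time: `DropTokenDropout` samples one mask shared across the batch and removes those tokens (the sequence gets shorter), while `ZeroTokenDropout` keeps the sequence length and zeroes masked tokens per sample, in place. A quick sketch assuming a `(B, N, D)` token tensor:

```python
import torch

drop = DropTokenDropout(p=0.5)   # defined above
zero = ZeroTokenDropout(p=0.5)   # defined above

tokens = torch.randn(2, 8, 16)            # (batch, seq_len, dim)
shorter = drop(tokens)                    # (2, <=8, 16): same tokens dropped for every sample
zeroed = zero(tokens.clone())             # (2, 8, 16): masked tokens set to zero per sample
```

In eval mode, or with `p=0`, both modules leave the input untouched.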
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) + if emb_dropout_type == "drop": + self.dropout = DropTokenDropout(emb_dropout) + elif emb_dropout_type == "zero": + self.dropout = ZeroTokenDropout(emb_dropout) + elif emb_dropout_type == "normal": + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = TransformerCrossAttn( + dim, + depth, + heads, + dim_head, + mlp_dim, + dropout, + norm=norm, + norm_cond_dim=norm_cond_dim, + context_dim=context_dim, + ) + + def forward(self, inp: torch.Tensor, *args, context=None, context_list=None): + x = self.to_token_embedding(inp) + b, n, _ = x.shape + + x = self.dropout(x) + x += self.pos_embedding[:, :n] + + x = self.transformer(x, *args, context=context, context_list=context_list) + return x + diff --git a/lib/modeling/networks/heads/utils/t_cond_mlp.py b/lib/modeling/networks/heads/utils/t_cond_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5a09bf54f67712a69953039b7b5af41c3f029 --- /dev/null +++ b/lib/modeling/networks/heads/utils/t_cond_mlp.py @@ -0,0 +1,199 @@ +import copy +from typing import List, Optional + +import torch + + +class AdaptiveLayerNorm1D(torch.nn.Module): + def __init__(self, data_dim: int, norm_cond_dim: int): + super().__init__() + if data_dim <= 0: + raise ValueError(f"data_dim must be positive, but got {data_dim}") + if norm_cond_dim <= 0: + raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}") + self.norm = torch.nn.LayerNorm( + data_dim + ) # TODO: Check if elementwise_affine=True is correct + self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim) + torch.nn.init.zeros_(self.linear.weight) + torch.nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + # x: (batch, ..., data_dim) + # t: (batch, norm_cond_dim) + # return: (batch, data_dim) + x = self.norm(x) + alpha, beta = self.linear(t).chunk(2, dim=-1) + + # Add singleton dimensions to alpha and beta + if x.dim() > 2: + alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1]) + beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1]) + + return x * (1 + alpha) + beta + + +class SequentialCond(torch.nn.Sequential): + def forward(self, input, *args, **kwargs): + for module in self: + if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)): + # print(f'Passing on args to {module}', [a.shape for a in args]) + input = module(input, *args, **kwargs) + else: + # print(f'Skipping passing args to {module}', [a.shape for a in args]) + input = module(input) + return input + + +def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1): + if norm == "batch": + return torch.nn.BatchNorm1d(dim) + elif norm == "layer": + return torch.nn.LayerNorm(dim) + elif norm == "ada": + assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}" + return AdaptiveLayerNorm1D(dim, norm_cond_dim) + elif norm is None: + return torch.nn.Identity() + else: + raise ValueError(f"Unknown norm: {norm}") + + +def linear_norm_activ_dropout( + input_dim: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, +) -> SequentialCond: + layers = [] + layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias)) + if norm is not None: + layers.append(normalization_layer(norm, output_dim, norm_cond_dim)) + 
layers.append(copy.deepcopy(activation)) + if dropout > 0.0: + layers.append(torch.nn.Dropout(dropout)) + return SequentialCond(*layers) + + +def create_simple_mlp( + input_dim: int, + hidden_dims: List[int], + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, +) -> SequentialCond: + layers = [] + prev_dim = input_dim + for hidden_dim in hidden_dims: + layers.extend( + linear_norm_activ_dropout( + prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ) + ) + prev_dim = hidden_dim + layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias)) + return SequentialCond(*layers) + + +class ResidualMLPBlock(torch.nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_hidden_layers: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, + ): + super().__init__() + if not (input_dim == output_dim == hidden_dim): + raise NotImplementedError( + f"input_dim {input_dim} != output_dim {output_dim} is not implemented" + ) + + layers = [] + prev_dim = input_dim + for i in range(num_hidden_layers): + layers.append( + linear_norm_activ_dropout( + prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ) + ) + prev_dim = hidden_dim + self.model = SequentialCond(*layers) + self.skip = torch.nn.Identity() + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x + self.model(x, *args, **kwargs) + + +class ResidualMLP(torch.nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_hidden_layers: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + num_blocks: int = 1, + norm_cond_dim: int = -1, + ): + super().__init__() + self.input_dim = input_dim + self.model = SequentialCond( + linear_norm_activ_dropout( + input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ), + *[ + ResidualMLPBlock( + hidden_dim, + hidden_dim, + num_hidden_layers, + hidden_dim, + activation, + bias, + norm, + dropout, + norm_cond_dim, + ) + for _ in range(num_blocks) + ], + torch.nn.Linear(hidden_dim, output_dim, bias=bias), + ) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return self.model(x, *args, **kwargs) + + +class FrequencyEmbedder(torch.nn.Module): + def __init__(self, num_frequencies, max_freq_log2): + super().__init__() + frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies) + self.register_buffer("frequencies", frequencies) + + def forward(self, x): + # x should be of size (N,) or (N, D) + N = x.size(0) + if x.dim() == 1: # (N,) + x = x.unsqueeze(1) # (N, D) where D=1 + x_unsqueezed = x.unsqueeze(-1) # (N, D, 1) + scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies) + s = torch.sin(scaled) + c = torch.cos(scaled) + embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view( + N, -1 + ) # (N, D * 2 * num_frequencies + D) + return embedded + diff --git a/lib/modeling/optim/__init__.py b/lib/modeling/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b9f9d290513239bc3f7c764d460e38b5dccb8f --- /dev/null +++ b/lib/modeling/optim/__init__.py @@ -0,0 +1,2 @@ +from .skelify.skelify 
import SKELify +from .skelify_refiner import SKELifyRefiner \ No newline at end of file diff --git a/lib/modeling/optim/skelify/closure.py b/lib/modeling/optim/skelify/closure.py new file mode 100644 index 0000000000000000000000000000000000000000..52bc733b82bd32c97abcf4bc780dee551edccb23 --- /dev/null +++ b/lib/modeling/optim/skelify/closure.py @@ -0,0 +1,202 @@ +from lib.kits.basic import * + +from lib.utils.camera import perspective_projection +from lib.modeling.losses import compute_poses_angle_prior_loss + +from .utils import ( + gmof, + guess_cam_z, + estimate_kp2d_scale, + get_kp_active_jids, + get_params_active_q_masks, +) + +def build_closure( + self, + cfg, + optimizer, + inputs, + focal_length : float, + gt_kp2d, + log_data, +): + B = len(gt_kp2d) + + act_parts = instantiate(cfg.parts) + act_q_masks = None + if not (act_parts == 'all' or 'all' in act_parts): + act_q_masks = get_params_active_q_masks(act_parts) + + # Shortcuts for the inference of the skeleton model. + def inference_skel(inputs): + poses_active = torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1) # (B, 46) + if act_q_masks is not None: + poses_hidden = poses_active.clone().detach() # (B, 46) + poses = poses_active * act_q_masks + poses_hidden * (1 - act_q_masks) # (B, 46) + else: + poses = poses_active + skel_params = { + 'poses' : poses, # (B, 46) + 'betas' : inputs['betas'], # (B, 10) + } + skel_output = self.skel_model(**skel_params, skelmesh=False) + return skel_params, skel_output + + # Estimate the camera depth as an initialization if depth loss is enabled. + gs_cam_z = None + if 'w_depth' in cfg.losses: + with torch.no_grad(): + _, skel_output = inference_skel(inputs) + gs_cam_z = guess_cam_z( + pd_kp3d = skel_output.joints, + gt_kp2d = gt_kp2d, + focal_length = focal_length, + ) + + # Prepare the focal length for the perspective projection. + focal_length_xy = np.ones((B, 2)) * focal_length # (B, 2) + + + def closure(): + optimizer.zero_grad() + + # 📦 Data preparation. + with PM.time_monitor('SKEL-forward'): + skel_params, skel_output = inference_skel(inputs) + + with PM.time_monitor('reproj'): + pd_kp2d = perspective_projection( + points = to_tensor(skel_output.joints, device=self.device), + translation = to_tensor(inputs['cam_t'], device=self.device), + focal_length = to_tensor(focal_length_xy, device=self.device), + ) + + with PM.time_monitor('compute_losses'): + loss, losses = compute_losses( + # Loss configuration. + loss_cfg = instantiate(cfg.losses), + parts = act_parts, + # Data inputs. + gt_kp2d = gt_kp2d, + pd_kp2d = pd_kp2d, + pd_params = skel_params, + pd_cam_z = inputs['cam_t'][:, 2], + gs_cam_z = gs_cam_z, + ) + + with PM.time_monitor('visualization'): + VISUALIZE = True + if VISUALIZE: + # For visualize the optimization process. + kp2d_err = torch.sum((pd_kp2d - gt_kp2d[..., :2]) ** 2, dim=-1) * gt_kp2d[..., 2] # (B, J) + kp2d_err = kp2d_err.sum(dim=-1) / (torch.sum(gt_kp2d[..., 2], dim=-1) + 1e-6) # (B,) + + # Store logging data. 
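+ # Only the first `n_samples` entries are kept, detached and cloned, so that
+ # TensorBoard logging stays lightweight and does not retain the graph.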
+ if self.tb_logger is not None: + log_data.update({ + 'losses' : losses, + 'pd_kp2d' : pd_kp2d[:self.n_samples].detach().clone(), + 'pd_verts' : skel_output.skin_verts[:self.n_samples].detach().clone(), + 'cam_t' : inputs['cam_t'][:self.n_samples].detach().clone(), + 'optim_betas' : inputs['betas'][:self.n_samples].detach().clone(), + 'kp2d_err' : kp2d_err[:self.n_samples].detach().clone(), + }) + + with PM.time_monitor('backwards'): + loss.backward() + return loss.item() + + return closure + + +def compute_losses( + loss_cfg : Dict[str, Union[bool, float]], + parts : List[str], + gt_kp2d : torch.Tensor, + pd_kp2d : torch.Tensor, + pd_params : Dict[str, torch.Tensor], + pd_cam_z : torch.Tensor, + gs_cam_z : Optional[torch.Tensor] = None, +): + ''' + ### Args + - loss_cfg: Dict[str, Union[bool, float]] + - Special option flags (`f_xxx`) or loss weights (`w_xxx`). + - parts: List[str] + - The list of the active joint parts groups. + - Among {'all', 'torso', 'torso-lite', 'limbs', 'head', 'limbs_proximal', 'limbs_distal'}. + - gt_kp2d: torch.Tensor (B, 44, 3) + - The ground-truth 2D keypoints with confidence. + - pd_kp2d: torch.Tensor (B, 44, 2) + - The predicted 2D keypoints. + - pd_params: Dict[str, torch.Tensor] + - poses: torch.Tensor (B, 46) + - betas: torch.Tensor (B, 10) + - pd_cam_z: torch.Tensor (B,) + - The predicted camera depth translation. + - gs_cam_z: Optional[torch.Tensor] (B,) + - The guessed camera depth translation. + + ### Returns + - loss: torch.Tensor (,) + - The weighted loss value for optimization. + - losses: Dict[str, float] + - The dictionary of the loss values for logging. + ''' + + losses = {} + loss = torch.tensor(0.0, device=gt_kp2d.device) + kp2d_conf = gt_kp2d[:, :, 2] # (B, J) + gt_kp2d = gt_kp2d[:, :, :2] # (B, J, 2) + + # Special option flags. + f_normalize_kp2d = loss_cfg.get('f_normalize_kp2d', False) + + if f_normalize_kp2d: + scale2mean = loss_cfg.get('f_normalize_kp2d_to_mean', False) + scale2d = estimate_kp2d_scale(gt_kp2d) # (B,) + pd_kp2d = pd_kp2d / (scale2d[:, None, None] + 1e-6) # (B, J, 2) + gt_kp2d = gt_kp2d / (scale2d[:, None, None] + 1e-6) # (B, J, 2) + if scale2mean: + scale2d_mean = scale2d.mean() + pd_kp2d = pd_kp2d * scale2d_mean # (B, J, 2) + gt_kp2d = gt_kp2d * scale2d_mean # (B, J, 2) + + # Mask the keypoints. + act_jids = get_kp_active_jids(parts) + kp2d_conf = kp2d_conf[:, act_jids] # (B, J) + gt_kp2d = gt_kp2d[:, act_jids, :] # (B, J, 2) + pd_kp2d = pd_kp2d[:, act_jids, :] # (B, J, 2) + + # Calculate weighted losses. + w_depth = loss_cfg.get('w_depth', None) + w_reprojection = loss_cfg.get('w_reprojection', None) + w_shape_prior = loss_cfg.get('w_shape_prior', None) + w_angle_prior = loss_cfg.get('w_angle_prior', None) + + if w_depth: + assert gs_cam_z is not None, 'The guessed camera depth is required for the depth loss.' 
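+ # Depth term: squared deviation of the predicted camera depth from the value
+ # guessed by guess_cam_z() from the ratio of 3D to 2D torso edge lengths.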
+ depth_loss = (gs_cam_z - pd_cam_z).pow(2) # (B,) + loss += (w_depth ** 2) * depth_loss.mean() # (,) + losses['depth'] = (w_depth ** 2) * depth_loss.mean().item() # float + + if w_reprojection: + reproj_err_j = gmof(pd_kp2d - gt_kp2d).sum(dim=-1) # (B, J) + reproj_err_j = kp2d_conf.pow(2) * reproj_err_j # (B, J) + reproj_loss = reproj_err_j.sum(-1) # (B,) + loss += (w_reprojection ** 2) * reproj_loss.mean() # (,) + losses['reprojection'] = (w_reprojection ** 2) * reproj_loss.mean().item() # float + + if w_shape_prior: + shape_prior_loss = pd_params['betas'].pow(2).sum(dim=-1) # (B,) + loss += (w_shape_prior ** 2) * shape_prior_loss.mean() # (,) + losses['shape_prior'] = (w_shape_prior ** 2) * shape_prior_loss.mean().item() # float + + if w_angle_prior: + w_angle_prior *= loss_cfg.get('w_angle_prior_scale', 1.0) + angle_prior_loss = compute_poses_angle_prior_loss(pd_params['poses']) # (B,) + loss += (w_angle_prior ** 2) * angle_prior_loss.mean() # (,) + losses['angle_prior'] = (w_angle_prior ** 2) * angle_prior_loss.mean().item() # float + + losses['weighted_sum'] = loss.item() # float + return loss, losses \ No newline at end of file diff --git a/lib/modeling/optim/skelify/skelify.py b/lib/modeling/optim/skelify/skelify.py new file mode 100644 index 0000000000000000000000000000000000000000..499d9c5b34e68dc37f378ab96591737d08b05199 --- /dev/null +++ b/lib/modeling/optim/skelify/skelify.py @@ -0,0 +1,296 @@ +from lib.kits.basic import * + +import cv2 +import traceback +from tqdm import tqdm + +from lib.body_models.common import make_SKEL +from lib.body_models.abstract_skeletons import Skeleton_OpenPose25 +from lib.utils.vis import render_mesh_overlay_img +from lib.utils.data import to_tensor +from lib.utils.media import draw_kp2d_on_img, annotate_img, splice_img +from lib.utils.camera import perspective_projection + +from .utils import ( + compute_rel_change, + gmof, +) + +from .closure import build_closure + +class SKELify(): + + def __init__(self, cfg, tb_logger=None, device='cuda:0', name='SKELify'): + self.cfg = cfg + self.name = name + self.eq_thre = cfg.early_quit_thresholds + + self.tb_logger = tb_logger + + self.device = device + # self.skel_model = make_SKEL(device=device) + self.skel_model = instantiate(cfg.skel_model).to(device) + + # Shortcuts. + self.n_samples = cfg.logger.samples_per_record + + # Dirty implementation for visualization. + self.render_frames = [] + + + def __call__( + self, + gt_kp2d : Union[torch.Tensor, np.ndarray], + init_poses : Union[torch.Tensor, np.ndarray], + init_betas : Union[torch.Tensor, np.ndarray], + init_cam_t : Union[torch.Tensor, np.ndarray], + img_patch : Optional[np.ndarray] = None, + **kwargs + ): + ''' + Use optimization to fit the SKEL parameters to the 2D keypoints. + + ### Args: + - gt_kp2d: torch.Tensor or np.ndarray, (B, J, 3) + - The last three dim means [x, y, conf]. + - The 2D keypoints to fit, they are defined in [-0.5, 0.5], zero-centered space. + - init_poses: torch.Tensor or np.ndarray, (B, 46) + - init_betas: torch.Tensor or np.ndarray, (B, 10) + - init_cam_t: torch.Tensor or np.ndarray, (B, 3) + - img_patch: np.ndarray or None, (B, H, W, 3) + - The image patch for visualization. H, W are defined in normalized bounding box space. + - If it is None, the visualization will simply use a black background. + + ### Returns: + - dict, containing the optimized parameters. 
+ - poses: torch.Tensor, (B, 46) + - betas: torch.Tensor, (B, 10) + - cam_t: torch.Tensor, (B, 3) + ''' + with PM.time_monitor('input preparation'): + gt_kp2d = to_tensor(gt_kp2d, device=self.device).detach().float().clone() # (B, J, 3) + init_poses = to_tensor(init_poses, device=self.device).detach().float().clone() # (B, 46) + init_betas = to_tensor(init_betas, device=self.device).detach().float().clone() # (B, 10) + init_cam_t = to_tensor(init_cam_t, device=self.device).detach().float().clone() # (B, 3) + inputs = { + 'poses_orient' : init_poses[:, :3], # (B, 3) + 'poses_body' : init_poses[:, 3:], # (B, 43) + 'betas' : init_betas, # (B, 10) + 'cam_t' : init_cam_t, # (B, 3) + } + focal_length = float(self.cfg.focal_length / self.cfg.img_patch_size) # float + + # ⛩️ Optimization phases, controlled by config file. + with PM.time_monitor('optim') as tm: + prev_steps = 0 # accumulate the steps are *supposed* to be done in the previous phases + n_phases = len(self.cfg.phases) + for phase_id, phase_name in enumerate(self.cfg.phases): + phase_cfg = self.cfg.phases[phase_name] + # 📦 Data preparation. + optim_params = [] + for k in inputs.keys(): + if k in phase_cfg.params_keys: + inputs[k].requires_grad = True + optim_params.append(inputs[k]) # (B, D) + else: + inputs[k].requires_grad = False + log_data = {} + tm.tick(f'Data preparation') + + # ⚙️ Optimization preparation. + optimizer = instantiate(phase_cfg.optimizer, optim_params, _recursive_=True) + closure = self._build_closure( + cfg=phase_cfg, optimizer=optimizer, # basic + inputs=inputs, focal_length=focal_length, gt_kp2d=gt_kp2d, # data reference + log_data=log_data, # monitoring + ) + tm.tick(f'Optimizer * closure prepared.') + + # 🚀 Optimization loop. + with tqdm(range(phase_cfg.max_loop)) as bar: + prev_loss = None + bar.set_description(f'[{phase_name}] Loss: ???') + for i in bar: + # 1. Main part of the optimization loop. + log_data.clear() + curr_loss = optimizer.step(closure) + + # 2. Log. + if self.tb_logger is not None: + log_data.update({ + 'img_patch' : img_patch[:self.n_samples] if img_patch is not None else None, + 'gt_kp2d' : gt_kp2d[:self.n_samples].detach().clone(), + }) + self._tb_log(prev_steps + i, phase_name, log_data) + + # 3. The end of one optimization loop. + bar.set_description(f'[{phase_id+1}/{n_phases}] @ {phase_name} - Loss: {curr_loss:.4f}') + if self._can_early_quit(optim_params, prev_loss, curr_loss): + break + prev_loss = curr_loss + + prev_steps += phase_cfg.max_loop + tm.tick(f'{phase_name} finished.') + + with PM.time_monitor('last infer'): + poses = torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1).detach().clone() # (B, 46) + betas = inputs['betas'].detach().clone() # (B, 10) + cam_t = inputs['cam_t'].detach().clone() # (B, 3) + skel_outputs = self.skel_model(poses=poses, betas=betas, skelmesh=False) # (B, 44, 3) + optim_kp3d = skel_outputs.joints # (B, 44, 3) + # Evaluate the confidence of the results. + focal_length_xy = np.ones((len(poses), 2)) * focal_length # (B, 2) + optim_kp2d = perspective_projection( + points = optim_kp3d, + translation = cam_t, + focal_length = to_tensor(focal_length_xy, device=self.device), + ) + kp2d_err = SKELify.eval_kp2d_err(gt_kp2d, optim_kp2d) # (B,) + + # ⛩️ Prepare the output data. 
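+ # kp2d_err is the confidence-weighted mean squared reprojection error per sample
+ # (see eval_kp2d_err); it can be used downstream to filter out poor fits.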
+ outputs = { + 'poses' : poses, # (B, 46) + 'betas' : betas, # (B, 10) + 'cam_t' : cam_t, # (B, 3) + 'kp2d_err' : kp2d_err, # (B,) + } + return outputs + + + def _can_early_quit(self, opt_params, prev_loss, curr_loss): + ''' Judge whether to early quit the optimization process. If yes, return True, otherwise False.''' + if self.cfg.early_quit_thresholds is None: + # Never early quit. + return False + + # Relative change test. + if prev_loss is not None: + loss_rel_change = compute_rel_change(prev_loss, curr_loss) + if loss_rel_change < self.cfg.early_quit_thresholds.rel: + get_logger().info(f'Early quit due to relative change: {loss_rel_change} = rel({prev_loss}, {curr_loss})') + return True + + # Absolute change test. + if all([ + torch.abs(param.grad.max()).item() < self.cfg.early_quit_thresholds.abs + for param in opt_params if param.grad is not None + ]): + get_logger().info(f'Early quit due to absolute change.') + return True + + return False + + + def _build_closure(self, *args, **kwargs): + # Using this way to hide the very details and simplify the code. + return build_closure(self, *args, **kwargs) + + + @staticmethod + def eval_kp2d_err(gt_kp2d_with_conf:torch.Tensor, pd_kp2d:torch.Tensor): + ''' Evaluate the mean 2D keypoints L2 error. The formula is: ∑(gt - pd)^2 * conf / ∑conf. ''' + assert len(gt_kp2d_with_conf.shape) == len(gt_kp2d_with_conf.shape), f'gt_kp2d_with_conf.shape={gt_kp2d_with_conf.shape}, pd_kp2d.shape={pd_kp2d.shape} but they should both be ((B,) J, D).' + if len(gt_kp2d_with_conf.shape) == 2: + gt_kp2d_with_conf, pd_kp2d = gt_kp2d_with_conf[None], pd_kp2d[None] + assert len(gt_kp2d_with_conf.shape) == 3, f'gt_kp2d_with_conf.shape={gt_kp2d_with_conf.shape}, pd_kp2d.shape={pd_kp2d.shape} but they should both be ((B,) J, D).' + B, J, _ = gt_kp2d_with_conf.shape + assert gt_kp2d_with_conf.shape == (B, J, 3), f'gt_kp2d_with_conf.shape={gt_kp2d_with_conf.shape} but it should be ((B,) J, 3).' + assert pd_kp2d.shape == (B, J, 2), f'pd_kp2d.shape={pd_kp2d.shape} but it should be ((B,) J, 2).' + + conf = gt_kp2d_with_conf[..., 2] # (B, J) + gt_kp2d = gt_kp2d_with_conf[..., :2] # (B, J, 2) + kp2d_err = torch.sum((gt_kp2d - pd_kp2d) ** 2, dim=-1) * conf # (B, J) + kp2d_err = kp2d_err.sum(dim=-1) / (torch.sum(conf, dim=-1) + 1e-6) # (B,) + return kp2d_err + + + @rank_zero_only + def _tb_log(self, step_cnt:int, phase_name:str, log_data:Dict, *args, **kwargs): + ''' Write the logging information to the TensorBoard. ''' + if step_cnt != 0 and (step_cnt + 1) % self.cfg.logger.interval_skelify != 0: + return + + summary_writer = self.tb_logger.experiment + + # Save losses. + for loss_name, loss_val in log_data['losses'].items(): + summary_writer.add_scalar(f'skelify/{loss_name}', loss_val, step_cnt) + + # Visualization of the optimization process. TODO: Maybe we can make this more elegant. + if log_data['img_patch'] is None: + log_data['img_patch'] = [np.zeros((self.cfg.img_patch_size, self.cfg.img_patch_size, 3), dtype=np.uint8)] \ + * len(log_data['gt_kp2d']) + + if len(self.render_frames) < 1: + self.init_v = log_data['pd_verts'] + self.init_kp2d_err = log_data['kp2d_err'] + self.init_ct = log_data['cam_t'] + + # Overlay the skin mesh of the results on the original image. 
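+ # Rendering may fail (e.g. in headless environments); failures are caught and
+ # logged below instead of interrupting the optimization.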
+ try: + imgs_spliced = [] + for i, img_patch in enumerate(log_data['img_patch']): + kp2d_err = log_data['kp2d_err'][i].item() + + img_with_init = render_mesh_overlay_img( + faces = self.skel_model.skin_f, + verts = self.init_v[i], + K4 = [self.cfg.focal_length, self.cfg.focal_length, 128, 128], + img = img_patch, + Rt = [torch.eye(3), self.init_ct[i]], + mesh_color = 'pink', + ) + img_with_init = annotate_img(img_with_init, 'init') + img_with_init = annotate_img(img_with_init, f'Quality: {self.init_kp2d_err[i].item()*1000:.3f}/1e3', pos='tl') + + img_with_mesh = render_mesh_overlay_img( + faces = self.skel_model.skin_f, + verts = log_data['pd_verts'][i], + K4 = [self.cfg.focal_length, self.cfg.focal_length, 128, 128], + img = img_patch, + Rt = [torch.eye(3), log_data['cam_t'][i]], + mesh_color = 'pink', + ) + betas_max = log_data['optim_betas'][i].abs().max().item() + img_patch_raw = annotate_img(img_patch, 'raw') + + log_data['gt_kp2d'][i][..., :2] = (log_data['gt_kp2d'][i][..., :2] + 0.5) * self.cfg.img_patch_size + img_with_gt = annotate_img(img_patch, 'gt_kp2d') + img_with_gt = draw_kp2d_on_img( + img_with_gt, + log_data['gt_kp2d'][i], + Skeleton_OpenPose25.bones, + Skeleton_OpenPose25.bone_colors, + ) + + log_data['pd_kp2d'][i] = (log_data['pd_kp2d'][i] + 0.5) * self.cfg.img_patch_size + img_with_pd = cv2.addWeighted(img_with_mesh, 0.7, img_patch, 0.3, 0) + img_with_pd = draw_kp2d_on_img( + img_with_pd, + log_data['pd_kp2d'][i], + Skeleton_OpenPose25.bones, + Skeleton_OpenPose25.bone_colors, + ) + + img_with_pd = annotate_img(img_with_pd, 'pd') + img_with_pd = annotate_img(img_with_pd, f'Quality: {kp2d_err*1000:.3f}/1e3\nbetas_max: {betas_max:.3f}', pos='tl') + img_with_mesh = annotate_img(img_with_mesh, f'Quality: {kp2d_err*1000:.3f}/1e3\nbetas_max: {betas_max:.3f}', pos='tl') + img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh') + + img_spliced = splice_img( + img_grids = [img_patch_raw, img_with_gt, img_with_pd, img_with_mesh, img_with_init], + # grid_ids = [[0, 1, 2, 3, 4]], + grid_ids = [[1, 2, 3, 4]], + ) + img_spliced = annotate_img(img_spliced, f'{phase_name}/{step_cnt}', pos='tl') + imgs_spliced.append(img_spliced) + + img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(log_data['img_patch']))]) + + img_final = to_tensor(img_final, device=None).permute(2, 0, 1) # (3, H, W) + summary_writer.add_image('skelify/visualization', img_final, step_cnt) + + self.render_frames.append(img_final) + except Exception as e: + get_logger().error(f'Failed to visualize the optimization process: {e}') + traceback.print_exc() \ No newline at end of file diff --git a/lib/modeling/optim/skelify/utils.py b/lib/modeling/optim/skelify/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..902bc9223fc14547d991db2df737e5832d274db0 --- /dev/null +++ b/lib/modeling/optim/skelify/utils.py @@ -0,0 +1,169 @@ +from lib.kits.basic import * + +from lib.body_models.skel_utils.definition import JID2QIDS + + +def gmof(x, sigma=100): + ''' + Geman-McClure error function, to be used as a robust loss function. + ''' + x_squared = x ** 2 + sigma_squared = sigma ** 2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + + +def compute_rel_change(prev_val: float, curr_val: float) -> float: + ''' + Compute the relative change between two values. 
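+ The value is |prev_val - curr_val| / max(|prev_val|, |curr_val|, 1).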
+ Copied from: + https://github.com/vchoutas/smplify-x + + ### Args: + - prev_val: float + - curr_val: float + + ### Returns: + - float + ''' + return np.abs(prev_val - curr_val) / max([np.abs(prev_val), np.abs(curr_val), 1]) + + +INVALID_JIDS = [37, 38] # These joints are not reliable. + +def get_kp_active_j_masks(parts:Union[str, List[str]], device='cuda'): + # Generate the masks performed on the keypoints to mask the loss. + act_jids = get_kp_active_jids(parts) + masks = torch.zeros(44).to(device) + masks[act_jids] = 1.0 + + return masks + + +def get_kp_active_jids(parts:Union[str, List[str]]): + if isinstance(parts, str): + if parts == 'all': + return get_kp_active_jids(['torso', 'limbs', 'head']) + elif parts == 'hips': + return [8, 9, 12, 27, 28, 39] + elif parts == 'torso-lite': + return [2, 5, 9, 12] + elif parts == 'torso': + return [1, 2, 5, 8, 9, 12, 27, 28, 33, 34, 37, 39, 40, 41] + elif parts == 'limbs': + return get_kp_active_jids(['limbs_proximal', 'limbs_distal']) + elif parts == 'head': + return [0, 15, 16, 17, 18, 38, 42, 43] + elif parts == 'limbs_proximal': + return [3, 6, 10, 13, 26, 29, 32, 35] + elif parts == 'limbs_distal': + return [4, 7, 11, 14, 19, 20, 21, 22, 23, 24, 25, 30, 31, 36] + else: + raise ValueError(f'Unsupported parts: {parts}') + else: + jids = [] + for part in parts: + jids.extend(get_kp_active_jids(part)) + jids = set(jids) - set(INVALID_JIDS) + return sorted(list(jids)) + + +def get_params_active_j_masks(parts:Union[str, List[str]], device='cuda'): + # Generate the masks performed on the keypoints to mask the loss. + act_jids = get_params_active_jids(parts) + masks = torch.zeros(24).to(device) + masks[act_jids] = 1.0 + + return masks + + +def get_params_active_jids(parts:Union[str, List[str]]): + if isinstance(parts, str): + if parts == 'all': + return get_params_active_jids(['torso', 'limbs', 'head']) + elif parts == 'torso-lite': + return get_params_active_jids('torso') + elif parts == 'hips': # Enable `hips` if `poses_orient` is enabled. + return [0] + elif parts == 'torso': + return [0, 11] + elif parts == 'limbs': + return get_params_active_jids(['limbs_proximal', 'limbs_distal']) + elif parts == 'head': + return [12, 13] + elif parts == 'limbs_proximal': + return [1, 6, 14, 15, 19, 20] + elif parts == 'limbs_distal': + return [2, 3, 4, 5, 7, 8, 9, 10, 16, 17, 18, 21, 22, 23] + else: + raise ValueError(f'Unsupported parts: {parts}') + else: + qids = [] + for part in parts: + qids.extend(get_params_active_jids(part)) + return sorted(list(set(qids))) + + +def get_params_active_q_masks(parts:Union[str, List[str]], device='cuda'): + # Generate the masks performed on the keypoints to mask the loss. 
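+ # (The mask here is over the 46 SKEL pose parameters `q`, obtained by mapping the
+ # active joints through JID2QIDS.)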
+ act_qids = get_params_active_qids(parts) + masks = torch.zeros(46).to(device) + masks[act_qids] = 1.0 + + return masks + + +def get_params_active_qids(parts:Union[str, List[str]]): + act_jids = get_params_active_jids(parts) + qids = [] + for act_jid in act_jids: + qids.extend(JID2QIDS[act_jid]) + return sorted(list(set(qids))) + + +def estimate_kp2d_scale( + kp2d : torch.Tensor, + edge_idxs : List[Tuple[int, int]] = [[5, 12], [2, 9]], # shoulders to hips +): + diff2d = [] + for edge in edge_idxs: + diff2d.append(kp2d[:, edge[0]] - kp2d[:, edge[1]]) # list of (B, 2) + scale2d = torch.stack(diff2d, dim=1).norm(dim=-1) # (B, E) + return scale2d.mean(dim=1) # (B,) + + +@torch.no_grad() +def guess_cam_z( + pd_kp3d : torch.Tensor, + gt_kp2d : torch.Tensor, + focal_length : float, + edge_idxs : List[Tuple[int, int]] = [[5, 12], [2, 9]], # shoulders to hips +): + ''' + Initializes the camera depth translation (i.e. z value) according to the ground truth 2D + keypoints and the predicted 3D keypoints. + Modified from: https://github.com/vchoutas/smplify-x/blob/68f8536707f43f4736cdd75a19b18ede886a4d53/smplifyx/fitting.py#L36-L110 + + ### Args + - pd_kp3d: torch.Tensor, (B, J, 3) + - gt_kp2d: torch.Tensor, (B, J, 2) + - Without confidence. + - focal_length: float + - edge_idxs: List[Tuple[int, int]], default=[[5, 12], [2, 9]], i.e. shoulders to hips + - Identify the edge to evaluate the scale of the entity. + ''' + diff3d, diff2d = [], [] + for edge in edge_idxs: + diff3d.append(pd_kp3d[:, edge[0]] - pd_kp3d[:, edge[1]]) # list of (B, 3) + diff2d.append(gt_kp2d[:, edge[0]] - gt_kp2d[:, edge[1]]) # list of (B, 2) + + diff3d = torch.stack(diff3d, dim=1) # (B, E, 3) + diff2d = torch.stack(diff2d, dim=1) # (B, E, 2) + + length3d = diff3d.norm(dim=-1) # (B, E) + length2d = diff2d.norm(dim=-1) # (B, E) + + height3d = length3d.mean(dim=1) # (B,) + height2d = length2d.mean(dim=1) # (B,) + + z_estim = focal_length * (height3d / height2d) # (B,) + return z_estim # (B,) \ No newline at end of file diff --git a/lib/modeling/optim/skelify_refiner.py b/lib/modeling/optim/skelify_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab5eb7376797563dcdb35e91a95b103560977db --- /dev/null +++ b/lib/modeling/optim/skelify_refiner.py @@ -0,0 +1,425 @@ +from lib.kits.basic import * + +import traceback +from tqdm import tqdm + +from lib.body_models.common import make_SKEL +from lib.body_models.skel_wrapper import SKELWrapper, SKELOutput +from lib.body_models.abstract_skeletons import Skeleton_OpenPose25 +from lib.utils.data import to_tensor, to_list +from lib.utils.camera import perspective_projection +from lib.utils.media import draw_kp2d_on_img, annotate_img, splice_img +from lib.utils.vis import render_mesh_overlay_img + +from lib.modeling.losses import compute_poses_angle_prior_loss + +from .skelify.utils import get_kp_active_j_masks + + +def compute_rel_change(prev_val: float, curr_val: float) -> float: + ''' + Compute the relative change between two values. + Copied: from https://github.com/vchoutas/smplify-x + + ### Args: + - prev_val: float + - curr_val: float + + ### Returns: + - float + ''' + return np.abs(prev_val - curr_val) / max([np.abs(prev_val), np.abs(curr_val), 1]) + + +def gmof(x, sigma): + ''' + Geman-McClure error function, to be used as a robust loss function. 
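+ The formula is rho(x) = sigma^2 * x^2 / (sigma^2 + x^2).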
+ ''' + x_squared = x ** 2 + sigma_squared = sigma ** 2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + + +class SKELifyRefiner(): + + def __init__(self, cfg, name='SKELify', tb_logger=None, device='cuda:0'): + self.cfg = cfg + self.name = name + self.eq_thre = cfg.early_quit_thresholds + + self.tb_logger = tb_logger + + self.device = device + self.skel_model = instantiate(cfg.skel_model).to(device) + + # Dirty implementation for visualization. + self.render_frames = [] + + + + def __call__( + self, + gt_kp2d : Union[torch.Tensor, np.ndarray], + init_poses : Union[torch.Tensor, np.ndarray], + init_betas : Union[torch.Tensor, np.ndarray], + init_cam_t : Union[torch.Tensor, np.ndarray], + img_patch : Optional[np.ndarray] = None, + **kwargs + ): + ''' + Use optimization to fit the SKEL parameters to the 2D keypoints. + + ### Args: + - gt_kp2d : torch.Tensor or np.ndarray, (B, J, 3) + - The last three dim means [x, y, conf]. + - The 2D keypoints to fit, they are defined in [-0.5, 0.5], zero-centered space. + - init_poses : torch.Tensor or np.ndarray, (B, 46) + - init_betas : torch.Tensor or np.ndarray, (B, 10) + - init_cam_t : torch.Tensor or np.ndarray, (B, 3) + - img_patch : np.ndarray or None, (B, H, W, 3) + - The image patch for visualization. H, W are defined in normalized bounding box space. + - If None, the visualization will simply use a black image. + + ### Returns: + - TODO: + ''' + # ⛩️ Prepare the input data. + gt_kp2d = to_tensor(gt_kp2d, device=self.device).detach().float().clone() # (B, J, 3) + init_poses = to_tensor(init_poses, device=self.device).detach().float().clone() # (B, 46) + init_betas = to_tensor(init_betas, device=self.device).detach().float().clone() # (B, 10) + init_cam_t = to_tensor(init_cam_t, device=self.device).detach().float().clone() # (B, 3) + inputs = { + 'poses_orient': init_poses[:, :3], # (B, 3) + 'poses_body' : init_poses[:, 3:], # (B, 43) + 'betas' : init_betas, # (B, 10) + 'cam_t' : init_cam_t, # (B, 3) + } + + focal_length = np.ones(2) * self.cfg.focal_length / self.cfg.img_patch_size + focal_length = focal_length.reshape(1, 2).repeat(inputs['cam_t'].shape[0], 1) + + # ⛩️ Optimization phases, controlled by config file. + prev_phase_steps = 0 # accumulate the steps are *supposed* to be done in the previous phases + for phase_id, phase_name in enumerate(self.cfg.phases): + phase_cfg = self.cfg.phases[phase_name] + # Preparation. + optim_params = [] + for k in inputs.keys(): + if k in phase_cfg.params_keys: + inputs[k].requires_grad = True + optim_params.append(inputs[k]) # (B, D) + else: + inputs[k].requires_grad = False + + optimizer = instantiate(phase_cfg.optimizer, optim_params, _recursive_=True) + + def closure(): + optimizer.zero_grad() + + # Data preparation. + cam_t = inputs['cam_t'] + skel_params = { + 'poses' : torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1), # (B, 46) + 'betas' : inputs['betas'], # (B, 10) + 'skelmesh' : False, + } + + # Optimize steps. + skel_output = self.skel_model(**skel_params) + + pd_kp2d = perspective_projection( + points = to_tensor(skel_output.joints, device=self.device), + translation = to_tensor(cam_t, device=self.device), + focal_length = to_tensor(focal_length, device=self.device), + ) + + loss, losses = self._compute_losses( + act_losses = phase_cfg.losses, + act_parts = phase_cfg.get('parts', 'all'), + gt_kp2d = gt_kp2d, + pd_kp2d = pd_kp2d, + pd_params = skel_params, + **phase_cfg.get('weights', {}), + ) + + # For visualize the optimization process. 
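+ # Confidence-weighted mean squared reprojection error per sample:
+ # metric_b = sum_j conf_bj * ||pd_bj - gt_bj||^2 / (sum_j conf_bj + 1e-6).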
+ _conf = gt_kp2d[..., 2] # (B, J) + metric = torch.sum((pd_kp2d - gt_kp2d[..., :2]) ** 2, dim=-1) * _conf # (B, J) + metric = metric.sum(dim=-1) / (torch.sum(_conf, dim=-1) + 1e-6) # (B,) + + # Store logging data. + if self.tb_logger is not None: + log_data.update({ + 'losses' : losses, + 'pd_kp2d' : pd_kp2d[:self.cfg.logger.samples_per_record].detach().clone(), + 'pd_verts' : skel_output.skin_verts[:self.cfg.logger.samples_per_record].detach().clone(), + 'cam_t' : cam_t[:self.cfg.logger.samples_per_record].detach().clone(), + 'metric' : metric[:self.cfg.logger.samples_per_record].detach().clone(), + 'optim_betas' : inputs['betas'][:self.cfg.logger.samples_per_record].detach().clone(), + }) + + loss.backward() + return loss.item() + + # Optimization loop. + prev_loss = None + with tqdm(range(phase_cfg.max_loop)) as bar: + bar.set_description(f'[{phase_name}] Loss: ???') + for i in bar: + log_data = {} + curr_loss = optimizer.step(closure) + + # Logging. + if self.tb_logger is not None: + log_data.update({ + 'img_patch' : img_patch[:self.cfg.logger.samples_per_record] if img_patch is not None else None, + 'gt_kp2d' : gt_kp2d[:self.cfg.logger.samples_per_record].detach().clone(), + }) + self._tb_log(prev_phase_steps + i, log_data) + # self._tb_log_for_report(prev_phase_steps + i, log_data) + + bar.set_description(f'[{phase_name}] Loss: {curr_loss:.4f}') + if self._can_early_quit(optim_params, prev_loss, curr_loss): + break + + prev_loss = curr_loss + + prev_phase_steps += phase_cfg.max_loop + + # ⛩️ Prepare the output data. + outputs = { + 'poses': torch.cat([inputs['poses_orient'], inputs['poses_body']], dim=-1).detach().clone(), # (B, 46) + 'betas': inputs['betas'].detach().clone(), # (B, 10) + 'cam_t': inputs['cam_t'].detach().clone(), # (B, 3) + } + return outputs + + + def _compute_losses( + self, + act_losses : List[str], + act_parts : List[str], + gt_kp2d : torch.Tensor, + pd_kp2d : torch.Tensor, + pd_params : Dict, + robust_sigma : float = 100, + shape_prior_weight : float = 5, + angle_prior_weight : float = 15.2, + *args, **kwargs, + ): + ''' + Compute the weighted losses according to the config file. + Follow: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/smplify/losses.py#L26-L58s + ''' + B = len(gt_kp2d) + act_j_masks = get_kp_active_j_masks(act_parts, device=gt_kp2d.device) # (44,) + + # Reproject the 3D keypoints to image and compare the L2 error with the g.t. 2D keypoints. + kp_conf = gt_kp2d[..., 2] # (B, J) + gt_kp2d = gt_kp2d[..., :2] # (B, J, 2) + reproj_err = gmof(pd_kp2d - gt_kp2d, robust_sigma) # (B, J, 2) + reproj_loss = ((kp_conf ** 2) * reproj_err.sum(dim=-1) * act_j_masks[None]).sum(-1) # (B,) + + # Regularize the shape parameters. + shape_prior_loss = (shape_prior_weight ** 2) * (pd_params['betas'] ** 2).sum(dim=-1) # (B,) + + # Use the SKEL angle prior knowledge (e.g., rotation limitation) to regularize the optimization process. + # TODO: Is that necessary? 
+ angle_prior_loss = (angle_prior_weight ** 2) * compute_poses_angle_prior_loss(pd_params['poses']).mean() # (,) + + losses = { + 'reprojection' : reproj_loss.mean(), # (,) + 'shape_prior' : shape_prior_loss.mean(), # (,) + 'angle_prior' : angle_prior_loss, # (,) + } + loss = torch.tensor(0., device=gt_kp2d.device) + for k in act_losses: + loss += losses[k] + losses = {k: v.detach() for k, v in losses.items()} + losses['sum'] = loss.detach() # (,) + return loss, losses + + + def _can_early_quit(self, opt_params, prev_loss, curr_loss): + ''' Judge whether to early quit the optimization process. If yes, return True, otherwise False.''' + if self.cfg.early_quit_thresholds is None: + # Never early quit. + return False + + # Relative change test. + if prev_loss is not None: + loss_rel_change = compute_rel_change(prev_loss, curr_loss) + if loss_rel_change < self.cfg.early_quit_thresholds.rel: + get_logger().info(f'Early quit due to relative change: {loss_rel_change:.4f} = rel({prev_loss}, {curr_loss})') + return True + + # Absolute change test. + if all([ + torch.abs(param.grad.max()).item() < self.cfg.early_quit_thresholds.abs + for param in opt_params if param.grad is not None + ]): + get_logger().info(f'Early quit due to absolute change.') + return True + + return False + + + @rank_zero_only + def _tb_log(self, step_cnt:int, log_data:Dict, *args, **kwargs): + ''' Write the logging information to the TensorBoard. ''' + if step_cnt != 0 and (step_cnt + 1) % self.cfg.logger.interval != 0: + return + + summary_writer = self.tb_logger.experiment + + # Save losses. + for loss_name, loss_val in log_data['losses'].items(): + summary_writer.add_scalar(f'skelify/{loss_name}', loss_val.detach().item(), step_cnt) + + # Visualization of the optimization process. TODO: Maybe we can make this more elegant. + if log_data['img_patch'] is None: + log_data['img_patch'] = [np.zeros((self.cfg.img_patch_size, self.cfg.img_patch_size, 3), dtype=np.uint8)] \ + * len(log_data['gt_kp2d']) + + if len(self.render_frames) < 1: + self.init_v = log_data['pd_verts'] + self.init_metric = log_data['metric'] + self.init_ct = log_data['cam_t'] + + # Overlay the skin mesh of the results on the original image. 
+ try: + imgs_spliced = [] + for i, img_patch in enumerate(log_data['img_patch']): + metric = log_data['metric'][i].item() + + img_with_init = render_mesh_overlay_img( + faces = self.skel_model.skin_f, + verts = self.init_v[i], + K4 = [self.cfg.focal_length, self.cfg.focal_length, 0, 0], + img = img_patch, + Rt = [torch.eye(3), self.init_ct[i]], + mesh_color = 'pink', + ) + img_with_init = annotate_img(img_with_init, 'init') + img_with_init = annotate_img(img_with_init, f'Quality: {self.init_metric[i].item()*1000:.3f}/1e3', pos='tl') + + img_with_mesh = render_mesh_overlay_img( + faces = self.skel_model.skin_f, + verts = log_data['pd_verts'][i], + K4 = [self.cfg.focal_length, self.cfg.focal_length, 0, 0], + img = img_patch, + Rt = [torch.eye(3), log_data['cam_t'][i]], + mesh_color = 'pink', + ) + img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh') + betas_max = log_data['optim_betas'][i].abs().max().item() + img_with_mesh = annotate_img(img_with_mesh, f'Quality: {metric*1000:.3f}/1e3\nbetas_max: {betas_max:.3f}', pos='tl') + img_patch_raw = annotate_img(img_patch, 'raw') + + log_data['gt_kp2d'][i][..., :2] = (log_data['gt_kp2d'][i][..., :2] + 0.5) * self.cfg.img_patch_size + img_with_gt = annotate_img(img_patch, 'gt_kp2d') + img_with_gt = draw_kp2d_on_img( + img_with_gt, + log_data['gt_kp2d'][i], + Skeleton_OpenPose25.bones, + Skeleton_OpenPose25.bone_colors, + ) + + log_data['pd_kp2d'][i] = (log_data['pd_kp2d'][i] + 0.5) * self.cfg.img_patch_size + img_with_pd = annotate_img(img_patch, 'pd_kp2d') + img_with_pd = draw_kp2d_on_img( + img_with_pd, + log_data['pd_kp2d'][i], + Skeleton_OpenPose25.bones, + Skeleton_OpenPose25.bone_colors, + ) + + img_spliced = splice_img( + img_grids = [img_patch_raw, img_with_gt, img_with_pd, img_with_init, img_with_mesh], + # grid_ids = [[0, 1, 2, 3, 4]], + grid_ids = [[1, 2, 3, 4]], + ) + imgs_spliced.append(img_spliced) + + img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(log_data['img_patch']))]) + + img_final = to_tensor(img_final, device=None).permute(2, 0, 1) # (3, H, W) + summary_writer.add_image('skelify/visualization', img_final, step_cnt) + + self.render_frames.append(img_final) + except Exception as e: + get_logger().error(f'Failed to visualize the optimization process: {e}') + # traceback.print_exc() + + + @rank_zero_only + def _tb_log_for_report(self, step_cnt:int, log_data:Dict, *args, **kwargs): + ''' Write the logging information to the TensorBoard. ''' + + get_logger().warning(f'This logging functions is just for presentation.') + + if len(self.render_frames) < 1: + self.init_v = log_data['pd_verts'] + self.init_ct = log_data['cam_t'] + + if step_cnt != 0 and (step_cnt + 1) % self.cfg.logger.interval != 0: + return + + summary_writer = self.tb_logger.experiment + + # Save losses. + for loss_name, loss_val in log_data['losses'].items(): + summary_writer.add_scalar(f'losses/{loss_name}', loss_val.detach().item(), step_cnt) + + # Visualization of the optimization process. TODO: Maybe we can make this more elegant. + if log_data['img_patch'] is None: + log_data['img_patch'] = [np.zeros((self.cfg.img_patch_size, self.cfg.img_patch_size, 3), dtype=np.uint8)] \ + * len(log_data['gt_kp2d']) + + # Overlay the skin mesh of the results on the original image. 
+ try: + imgs_spliced = [] + for i, img_patch in enumerate(log_data['img_patch']): + img_with_init = render_mesh_overlay_img( + faces = self.skel_model.skin_f, + verts = self.init_v[i], + K4 = [self.cfg.focal_length, self.cfg.focal_length, 0, 0], + img = img_patch, + Rt = [torch.eye(3), self.init_ct[i]], + mesh_color = 'pink', + ) + img_with_init = annotate_img(img_with_init, 'init') + + img_with_mesh = render_mesh_overlay_img( + faces = self.skel_model.skin_f, + verts = log_data['pd_verts'][i], + K4 = [self.cfg.focal_length, self.cfg.focal_length, 0, 0], + img = img_patch, + Rt = [torch.eye(3), log_data['cam_t'][i]], + mesh_color = 'pink', + ) + img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh') + + img_patch_raw = annotate_img(img_patch, 'raw') + + log_data['gt_kp2d'][i][..., :2] = (log_data['gt_kp2d'][i][..., :2] + 0.5) * self.cfg.img_patch_size + img_with_gt = annotate_img(img_patch, 'gt_kp2d') + img_with_gt = draw_kp2d_on_img( + img_with_gt, + log_data['gt_kp2d'][i], + Skeleton_OpenPose25.bones, + Skeleton_OpenPose25.bone_colors, + ) + + img_spliced = splice_img([img_patch_raw, img_with_gt, img_with_init, img_with_mesh], grid_ids=[[0, 1, 2, 3]]) + imgs_spliced.append(img_spliced) + + img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(log_data['img_patch']))]) + + img_final = to_tensor(img_final, device=None).permute(2, 0, 1) + summary_writer.add_image('visualization', img_final, step_cnt) + + self.render_frames.append(img_final) + except Exception as e: + get_logger().error(f'Failed to visualize the optimization process: {e}') + traceback.print_exc() diff --git a/lib/modeling/pipelines/__init__.py b/lib/modeling/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84111117cb52fb9475d6d06685bdd2f74f808120 --- /dev/null +++ b/lib/modeling/pipelines/__init__.py @@ -0,0 +1 @@ +from .hsmr import HSMRPipeline \ No newline at end of file diff --git a/lib/modeling/pipelines/hsmr.py b/lib/modeling/pipelines/hsmr.py new file mode 100644 index 0000000000000000000000000000000000000000..be3c0ac4bc78a5a94aacf0300b49d108cba97aa4 --- /dev/null +++ b/lib/modeling/pipelines/hsmr.py @@ -0,0 +1,649 @@ +from lib.kits.basic import * + +import traceback + +from lib.utils.vis import Wis3D +from lib.utils.vis.py_renderer import render_mesh_overlay_img +from lib.utils.data import to_tensor +from lib.utils.media import draw_kp2d_on_img, annotate_img, splice_img +from lib.utils.camera import perspective_projection +from lib.body_models.abstract_skeletons import Skeleton_OpenPose25 +from lib.modeling.losses import * +from lib.modeling.networks.discriminators import HSMRDiscriminator +from lib.platform.config_utils import get_PM_info_dict + + +def build_inference_pipeline( + model_root: Union[Path, str], + ckpt_fn : Optional[Union[Path, str]] = None, + tuned_bcb : bool = True, + device : str = 'cpu', +): + # 1.1. Load the config file. + if isinstance(model_root, str): + model_root = Path(model_root) + cfg_path = model_root / '.hydra' / 'config.yaml' + cfg = OmegaConf.load(cfg_path) + # 1.2. Override PM info dict. + PM_overrides = get_PM_info_dict()._pm_ + cfg._pm_ = PM_overrides + get_logger(brief=True).info(f'Building inference pipeline of {cfg.exp_name}') + + # 2.1. Instantiate the pipeline. + init_bcb = not tuned_bcb + pipeline = instantiate(cfg.pipeline, init_backbone=init_bcb, _recursive_=False) + pipeline.set_data_adaption(data_module_name='IMG_PATCHES') + # 2.2. Load the checkpoint. 
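+ # Default to the last checkpoint saved under the experiment directory.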
+ if ckpt_fn is None: + ckpt_fn = model_root / 'checkpoints' / 'last.ckpt' + pipeline.load_state_dict(torch.load(ckpt_fn, map_location=device)['state_dict']) + get_logger(brief=True).info(f'Load checkpoint from {ckpt_fn}.') + + pipeline.eval() + return pipeline.to(device) + + +class HSMRPipeline(pl.LightningModule): + + def __init__(self, cfg:DictConfig, name:str, init_backbone=True): + super(HSMRPipeline, self).__init__() + self.name = name + + self.skel_model = instantiate(cfg.SKEL) + self.backbone = instantiate(cfg.backbone) + self.head = instantiate(cfg.head) + self.cfg = cfg + + if init_backbone: + # For inference mode with tuned backbone checkpoints, we don't need to initialize the backbone here. + self._init_backbone() + + # Loss layers. + self.kp_3d_loss = Keypoint3DLoss(loss_type='l1') + self.kp_2d_loss = Keypoint2DLoss(loss_type='l1') + self.params_loss = ParameterLoss() + + # Discriminator. + self.enable_disc = self.cfg.loss_weights.get('adversarial', 0) > 0 + if self.enable_disc: + self.discriminator = HSMRDiscriminator() + get_logger().warning(f'Discriminator enabled, the global_steps will be doubled. Use the checkpoints carefully.') + else: + self.discriminator = None + self.cfg.loss_weights.pop('adversarial', None) # pop the adversarial term if not enabled + + # Manually control the optimization since we have an adversarial process. + self.automatic_optimization = False + self.set_data_adaption() + + # For visualization debug. + if False: + self.wis3d = Wis3D(seq_name=PM.cfg.exp_name) + else: + self.wis3d = None + + def set_data_adaption(self, data_module_name:Optional[str]=None): + if data_module_name is None: + # get_logger().warning('Data adapter schema is not defined. The input will be regarded as image patches.') + self.adapt_batch = self._adapt_img_inference + elif data_module_name == 'IMG_PATCHES': + self.adapt_batch = self._adapt_img_inference + elif data_module_name.startswith('SKEL_HSMR_V1'): + self.adapt_batch = self._adapt_hsmr_v1 + else: + raise ValueError(f'Unknown data module: {data_module_name}') + + def print_summary(self, max_depth=1): + from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary + print(ModelSummary(self, max_depth=max_depth)) + + def configure_optimizers(self): + optimizers = [] + + params_main = filter(lambda p: p.requires_grad, self._params_main()) + optimizer_main = instantiate(self.cfg.optimizer, params=params_main) + optimizers.append(optimizer_main) + + if len(self._params_disc()) > 0: + params_disc = filter(lambda p: p.requires_grad, self._params_disc()) + optimizer_disc = instantiate(self.cfg.optimizer, params=params_disc) + optimizers.append(optimizer_disc) + + return optimizers + + def training_step(self, raw_batch, batch_idx): + with PM.time_monitor('training_step'): + return self._training_step(raw_batch, batch_idx) + + def _training_step(self, raw_batch, batch_idx): + # GPU_monitor = GPUMonitor() + # GPU_monitor.snapshot('HSMR training start') + + batch = self.adapt_batch(raw_batch['img_ds']) + # GPU_monitor.snapshot('HSMR adapt batch') + + # Get the optimizer. + optimizers = self.optimizers(use_pl_optimizer=True) + if isinstance(optimizers, List): + optimizer_main, optimizer_disc = optimizers + else: + optimizer_main = optimizers + # GPU_monitor.snapshot('HSMR get optimizer') + + # 1. Main parts forward pass. 
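+ # The backbone and head regress SKEL parameters plus a 3-vector camera
+ # (scale, tx, ty), which forward_step() turns into the translation `pd_cam_t`.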
+ with PM.time_monitor('forward_step'): + img_patch = to_tensor(batch['img_patch'], self.device) # (B, C, H, W) + B = len(img_patch) + outputs = self.forward_step(img_patch) # {...} + # GPU_monitor.snapshot('HSMR forward') + pd_skel_params = HSMRPipeline._adapt_skel_params(outputs['pd_params']) + # GPU_monitor.snapshot('HSMR adapt SKEL params') + + # 2. [Optional] Discriminator forward pass in main training step. + if self.enable_disc: + with PM.time_monitor('disc_forward'): + pd_poses_mat, _ = self.skel_model.pose_params_to_rot(pd_skel_params['poses']) # (B, J=24, 3, 3) + pd_poses_body_mat = pd_poses_mat[:, 1:, :, :] # (B, J=23, 3, 3) + pd_betas = pd_skel_params['betas'] # (B, 10) + disc_out = self.discriminator( + poses_body = pd_poses_body_mat, # (B, J=23, 3, 3) + betas = pd_betas, # (B, 10) + ) + else: + disc_out = None + + # 3. Prepare the secondary products + with PM.time_monitor('Secondary Products Preparation'): + # 3.1. Body model outputs. + with PM.time_monitor('SKEL Forward'): + skel_outputs = self.skel_model(**pd_skel_params, skelmesh=False) + pd_kp3d = skel_outputs.joints # (B, Q=44, 3) + pd_skin = skel_outputs.skin_verts # (B, V=6890, 3) + # 3.2. Reproject the 3D joints to 2D plain. + with PM.time_monitor('Reprojection'): + pd_kp2d = perspective_projection( + points = pd_kp3d, # (B, K=Q=44, 3) + translation = outputs['pd_cam_t'], # (B, 3) + focal_length = outputs['focal_length'] / self.cfg.policy.img_patch_size, # (B, 2) + ) # (B, 44, 2) + # 3.3. Extract G.T. from inputs. + gt_kp2d_with_conf = batch['kp2d'].clone() # (B, 44, 3) + gt_kp3d_with_conf = batch['kp3d'].clone() # (B, 44, 4) + # 3.4. Extract G.T. skin mesh only for visualization. + gt_skel_params = HSMRPipeline._adapt_skel_params(batch['gt_params']) # {poses, betas} + gt_skel_params = {k: v[:self.cfg.logger.samples_per_record] for k, v in gt_skel_params.items()} + skel_outputs = self.skel_model(**gt_skel_params, skelmesh=False) + gt_skin = skel_outputs.skin_verts # (B', V=6890, 3) + gt_valid_body = batch['has_gt_params']['poses_body'][:self.cfg.logger.samples_per_record] # {poses_orient, poses_body, betas} + gt_valid_orient = batch['has_gt_params']['poses_orient'][:self.cfg.logger.samples_per_record] # {poses_orient, poses_body, betas} + gt_valid_betas = batch['has_gt_params']['betas'][:self.cfg.logger.samples_per_record] # {poses_orient, poses_body, betas} + gt_valid = torch.logical_and(torch.logical_and(gt_valid_body, gt_valid_orient), gt_valid_betas) + # GPU_monitor.snapshot('HSMR secondary products') + + # 4. Compute losses. + with PM.time_monitor('Compute Loss'): + loss_main, losses_main = self._compute_losses_main( + self.cfg.loss_weights, + pd_kp3d, # (B, 44, 3) + gt_kp3d_with_conf, # (B, 44, 4) + pd_kp2d, # (B, 44, 2) + gt_kp2d_with_conf, # (B, 44, 3) + outputs['pd_params'], # {'poses_orient':..., 'poses_body':..., 'betas':...} + batch['gt_params'], # {'poses_orient':..., 'poses_body':..., 'betas':...} + batch['has_gt_params'], + disc_out, + ) + # GPU_monitor.snapshot('HSMR compute losses') + if torch.isnan(loss_main): + get_logger().error(f'NaN detected in loss computation. Losses: {losses}') + + # 5. Main parts backward pass. + with PM.time_monitor('Backward Step'): + optimizer_main.zero_grad() + self.manual_backward(loss_main) + optimizer_main.step() + # GPU_monitor.snapshot('HSMR backwards') + + # 6. [Optional] Discriminator training part. 
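+ # Least-squares GAN objectives: the adversarial term in _compute_losses_main pushes
+ # D(pred) towards 1, while the step below pushes D(pred) to 0 and D(mocap) to 1.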
+ if self.enable_disc: + with PM.time_monitor('Train Discriminator'): + losses_disc = self._train_discriminator( + mocap_batch = raw_batch['mocap_ds'], + pd_poses_body_mat = pd_poses_body_mat, + pd_betas = pd_betas, + optimizer = optimizer_disc, + ) + else: + losses_disc = {} + + # 7. Logging. + with PM.time_monitor('Tensorboard Logging'): + vis_data = { + 'img_patch' : to_numpy(img_patch[:self.cfg.logger.samples_per_record]).transpose((0, 2, 3, 1)).copy(), + 'pd_kp2d' : pd_kp2d[:self.cfg.logger.samples_per_record].clone(), + 'pd_kp3d' : pd_kp3d[:self.cfg.logger.samples_per_record].clone(), + 'gt_kp2d_with_conf' : gt_kp2d_with_conf[:self.cfg.logger.samples_per_record].clone(), + 'gt_kp3d_with_conf' : gt_kp3d_with_conf[:self.cfg.logger.samples_per_record].clone(), + 'pd_skin' : pd_skin[:self.cfg.logger.samples_per_record].clone(), + 'gt_skin' : gt_skin.clone(), + 'gt_skin_valid' : gt_valid, + 'cam_t' : outputs['pd_cam_t'][:self.cfg.logger.samples_per_record].clone(), + 'img_key' : batch['__key__'][:self.cfg.logger.samples_per_record], + } + self._tb_log(losses_main=losses_main, losses_disc=losses_disc, vis_data=vis_data) + # GPU_monitor.snapshot('HSMR logging') + self.log('_/loss_main', losses_main['weighted_sum'], on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=B) + + # GPU_monitor.report_all() + return outputs + + def forward(self, batch): + ''' + ### Returns + - outputs: Dict + - pd_kp3d: torch.Tensor, shape (B, Q=44, 3) + - pd_kp2d: torch.Tensor, shape (B, Q=44, 2) + - pred_keypoints_2d: torch.Tensor, shape (B, Q=44, 2) + - pred_keypoints_3d: torch.Tensor, shape (B, Q=44, 3) + - pd_params: Dict + - poses: torch.Tensor, shape (B, 46) + - betas: torch.Tensor, shape (B, 10) + - pd_cam: torch.Tensor, shape (B, 3) + - pd_cam_t: torch.Tensor, shape (B, 3) + - focal_length: torch.Tensor, shape (B, 2) + ''' + batch = self.adapt_batch(batch) + + # 1. Main parts forward pass. + img_patch = to_tensor(batch['img_patch'], self.device) # (B, C, H, W) + outputs = self.forward_step(img_patch) # {...} + + # 2. Prepare the secondary products + # 2.1. Body model outputs. + pd_skel_params = HSMRPipeline._adapt_skel_params(outputs['pd_params']) + skel_outputs = self.skel_model(**pd_skel_params, skelmesh=False) + pd_kp3d = skel_outputs.joints # (B, Q=44, 3) + pd_skin_verts = skel_outputs.skin_verts.detach().cpu().clone() # (B, V=6890, 3) + # 2.2. Reproject the 3D joints to 2D plain. + pd_kp2d = perspective_projection( + points = to_tensor(pd_kp3d, device=self.device), # (B, K=Q=44, 3) + translation = to_tensor(outputs['pd_cam_t'], device=self.device), # (B, 3) + focal_length = to_tensor(outputs['focal_length'], device=self.device) / self.cfg.policy.img_patch_size, # (B, 2) + ) + + outputs['pd_kp3d'] = pd_kp3d + outputs['pd_kp2d'] = pd_kp2d + outputs['pred_keypoints_2d'] = pd_kp2d # adapt HMR2.0's script + outputs['pred_keypoints_3d'] = pd_kp3d # adapt HMR2.0's script + outputs['pd_params'] = pd_skel_params + outputs['pd_skin_verts'] = pd_skin_verts + + return outputs + + def forward_step(self, x:torch.Tensor): + ''' + Run an inference step on the model. + + ### Args + - x: torch.Tensor, shape (B, C, H, W) + - The input image patch. + + ### Returns + - outputs: Dict + - 'pd_cam': torch.Tensor, shape (B, 3) + - The predicted camera parameters. + - 'pd_params': Dict + - The predicted body model parameters. + - 'focal_length': float + ''' + # GPU_monitor = GPUMonitor() + B = len(x) + + # 1. Extract features from image. + # The input size is 256*256, but ViT needs 256*192. 
TODO: make this more elegant. + with PM.time_monitor('Backbone Forward'): + feats = self.backbone(x[:, :, :, 32:-32]) + # GPU_monitor.snapshot('HSMR forward backbone') + + + # 2. Run the head to predict the body model parameters. + with PM.time_monitor('Predict Head Forward'): + pd_params, pd_cam = self.head(feats) + # GPU_monitor.snapshot('HSMR forward head') + + # 3. Transform the camera parameters to camera translation. + focal_length = self.cfg.policy.focal_length * torch.ones(B, 2, device=self.device, dtype=pd_cam.dtype) # (B, 2) + pd_cam_t = torch.stack([ + pd_cam[:, 1], + pd_cam[:, 2], + 2 * focal_length[:, 0] / (self.cfg.policy.img_patch_size * pd_cam[:, 0] + 1e-9) + ], dim=-1) # (B, 3) + + # 4. Store the results. + outputs = { + 'pd_cam' : pd_cam, + 'pd_cam_t' : pd_cam_t, + 'pd_params' : pd_params, + # 'pd_params' : {k: v.clone() for k, v in pd_params.items()}, + 'focal_length' : focal_length, # (B, 2) + } + # GPU_monitor.report_all() + return outputs + + + # ========== Internal Functions ========== + + def _params_main(self): + return list(self.head.parameters()) + list(self.backbone.parameters()) + + def _params_disc(self): + if self.discriminator is None: + return [] + else: + return list(self.discriminator.parameters()) + + @staticmethod + def _adapt_skel_params(params:Dict): + ''' Change the parameters formed like [pose_orient, pose_body, betas, trans] to [poses, betas, trans]. ''' + adapted_params = {} + + if 'poses' in params.keys(): + adapted_params['poses'] = params['poses'] + elif 'poses_orient' in params.keys() and 'poses_body' in params.keys(): + poses_orient = params['poses_orient'] # (B, 3) + poses_body = params['poses_body'] # (B, 43) + adapted_params['poses'] = torch.cat([poses_orient, poses_body], dim=1) # (B, 46) + else: + raise ValueError(f'Cannot find the poses parameters among {list(params.keys())}.') + + if 'betas' in params.keys(): + adapted_params['betas'] = params['betas'] # (B, 10) + else: + raise ValueError(f'Cannot find the betas parameters among {list(params.keys())}.') + + return adapted_params + + def _init_backbone(self): + # 1. Loading the backbone weights. + get_logger().info(f'Loading backbone weights from {self.cfg.backbone_ckpt}') + state_dict = torch.load(self.cfg.backbone_ckpt, map_location='cpu')['state_dict'] + if 'backbone.cls_token' in state_dict.keys(): + state_dict = {k: v for k, v in state_dict.items() if 'backbone' in k and 'cls_token' not in k} + state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()} + missing, unexpected = self.backbone.load_state_dict(state_dict) + if len(missing) > 0: + get_logger().warning(f'Missing keys in backbone: {missing}') + if len(unexpected) > 0: + get_logger().warning(f'Unexpected keys in backbone: {unexpected}') + + # 2. Freeze the backbone if needed. + if self.cfg.get('freeze_backbone', False): + self.backbone.eval() + self.backbone.requires_grad_(False) + + def _compute_losses_main( + self, + loss_weights : Dict, + pd_kp3d : torch.Tensor, + gt_kp3d : torch.Tensor, + pd_kp2d : torch.Tensor, + gt_kp2d : torch.Tensor, + pd_params : Dict, + gt_params : Dict, + has_params : Dict, + disc_out : Optional[torch.Tensor]=None, + *args, **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + ''' Compute the weighted losses according to the config file. ''' + + # 1. Preparation. 
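+ # Preparation merges {poses_orient, poses_body} into the full 46-D `poses` vector
+ # and flattens poses/betas so the loss terms below operate on (B, D) tensors.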
+ with PM.time_monitor('Preparation'): + B = len(pd_kp3d) + gt_skel_params = HSMRPipeline._adapt_skel_params(gt_params) # {poses, betas} + pd_skel_params = HSMRPipeline._adapt_skel_params(pd_params) # {poses, betas} + + gt_betas = gt_skel_params['betas'].reshape(-1, 10) + pd_betas = pd_skel_params['betas'].reshape(-1, 10) + gt_poses = gt_skel_params['poses'].reshape(-1, 46) + pd_poses = pd_skel_params['poses'].reshape(-1, 46) + + # 2. Keypoints losses. + with PM.time_monitor('kp2d & kp3d Loss'): + kp2d_loss = self.kp_2d_loss(pd_kp2d, gt_kp2d) / B + kp3d_loss = self.kp_3d_loss(pd_kp3d, gt_kp3d) / B + + # 3. Prior losses. + with PM.time_monitor('Prior Loss'): + prior_loss = compute_poses_angle_prior_loss(pd_poses).mean() # (,) + + # 4. Parameters losses. + if self.cfg.sp_poses_repr == 'rotation_matrix': + with PM.time_monitor('q2mat'): + gt_poses_mat, _ = self.skel_model.pose_params_to_rot(gt_poses) # (B, J=24, 3, 3) + pd_poses_mat, _ = self.skel_model.pose_params_to_rot(pd_poses) # (B, J=24, 3, 3) + + gt_poses = gt_poses_mat.reshape(-1, 24*3*3) # (B, 24*3*3) + pd_poses = pd_poses_mat.reshape(-1, 24*3*3) # (B, 24*3*3) + + with PM.time_monitor('Parameters Loss'): + poses_orient_loss = self.params_loss(pd_poses[:, :9], gt_poses[:, :9], has_params['poses_orient']) / B + poses_body_loss = self.params_loss(pd_poses[:, 9:], gt_poses[:, 9:], has_params['poses_body']) / B + betas_loss = self.params_loss(pd_betas, gt_betas, has_params['betas']) / B + + # 5. Collect main losses. + with PM.time_monitor('Accumulate'): + losses = { + 'kp3d' : kp3d_loss, # (,) + 'kp2d' : kp2d_loss, # (,) + 'prior' : prior_loss, # (,) + 'poses_orient' : poses_orient_loss, # (,) + 'poses_body' : poses_body_loss, # (,) + 'betas' : betas_loss, # (,) + } + + # 6. Consider adversarial loss. + if disc_out is not None: + with PM.time_monitor('Adversarial Loss'): + adversarial_loss = ((disc_out - 1.0) ** 2).sum() / B # (,) + losses['adversarial'] = adversarial_loss + + with PM.time_monitor('Accumulate'): + loss = torch.tensor(0., device=self.device) + for k, v in losses.items(): + loss += v * loss_weights[k] + losses = {k: v.item() for k, v in losses.items()} + losses['weighted_sum'] = loss.item() + return loss, losses + + def _train_discriminator(self, mocap_batch, pd_poses_body_mat, pd_betas, optimizer): + ''' + Train the discriminator using the regressed body model parameters and the realistic MoCap data. + + ### Args + - mocap_batch: Dict + - 'poses_body': torch.Tensor, shape (B, 43) + - 'betas': torch.Tensor, shape (B, 10) + - pd_poses_body_mat: torch.Tensor, shape (B, J=23, 3, 3) + - pd_betas: torch.Tensor, shape (B, 10) + - optimizer: torch.optim.Optimizer + + ### Returns + - losses: Dict + - 'pd_disc': float + - 'mc_disc': float + ''' + pd_B = len(pd_poses_body_mat) + mc_B = len(mocap_batch['poses_body']) + get_logger().warning(f'pd_B: {pd_B} != mc_B: {mc_B}') + + # 1. Extract the realistic 3D MoCap label. + mc_poses_body = mocap_batch['poses_body'] # (B, 43) + padding_zeros = mc_poses_body.new_zeros(mc_B, 3) # (B, 3) + mc_poses = torch.cat([padding_zeros, mc_poses_body], dim=1) # (B, 46) + mc_poses_mat, _ = self.skel_model.pose_params_to_rot(mc_poses) # (B, J=24, 3, 3) + mc_poses_body_mat = mc_poses_mat[:, 1:, :, :] # (B, J=23, 3, 3) + mc_betas = mocap_batch['betas'] # (B, 10) + + # 2. Forward pass. + # Discriminator forward pass for the predicted data. 
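+ # Least-squares GAN targets: predicted (fake) samples are regressed towards 0 and MoCap (real) samples towards 1; inputs are detached so only the discriminator is updated here.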
+ pd_disc_out = self.discriminator(pd_poses_body_mat.detach(), pd_betas.detach()) + pd_disc_loss = ((pd_disc_out - 0.0) ** 2).sum() / pd_B # (,) + # Discriminator forward pass for the realistic MoCap data. + mc_disc_out = self.discriminator(mc_poses_body_mat, mc_betas) + mc_disc_loss = ((mc_disc_out - 1.0) ** 2).sum() / pd_B # (,) TODO: This 'pd_B' is from HMR2, not sure if it's a bug. + + # 3. Backward pass. + disc_loss = self.cfg.loss_weights.adversarial * (pd_disc_loss + mc_disc_loss) + optimizer.zero_grad() + self.manual_backward(disc_loss) + optimizer.step() + + return { + 'pd_disc': pd_disc_loss.item(), + 'mc_disc': mc_disc_loss.item(), + } + + @rank_zero_only + def _tb_log(self, losses_main:Dict, losses_disc:Dict, vis_data:Dict, mode:str='train'): + ''' Write the logging information to the TensorBoard. ''' + if self.logger is None: + return + + if self.global_step != 1 and self.global_step % self.cfg.logger.interval != 0: + return + + # 1. Losses. + summary_writer = self.logger.experiment + for loss_name, loss_val in losses_main.items(): + summary_writer.add_scalar(f'{mode}/losses_main/{loss_name}', loss_val, self.global_step) + for loss_name, loss_val in losses_disc.items(): + summary_writer.add_scalar(f'{mode}/losses_disc/{loss_name}', loss_val, self.global_step) + + # 2. Visualization. + try: + pelvis_id = 39 + # 2.1. Visualize 3D information. + self.wis3d.add_motion_mesh( + verts = vis_data['pd_skin'] - vis_data['pd_kp3d'][:, pelvis_id:pelvis_id+1], # center the mesh + faces = self.skel_model.skin_f, + name = 'pd_skin', + ) + self.wis3d.add_motion_mesh( + verts = vis_data['gt_skin'] - vis_data['gt_kp3d_with_conf'][:, pelvis_id:pelvis_id+1, :3], # center the mesh + faces = self.skel_model.skin_f, + name = 'gt_skin', + ) + self.wis3d.add_motion_skel( + joints = vis_data['pd_kp3d'] - vis_data['pd_kp3d'][:, pelvis_id:pelvis_id+1], + bones = Skeleton_OpenPose25.bones, + colors = Skeleton_OpenPose25.bone_colors, + name = 'pd_kp3d', + ) + + aligned_gt_kp3d = vis_data['gt_kp3d_with_conf'] + aligned_gt_kp3d[..., :3] -= vis_data['gt_kp3d_with_conf'][:, pelvis_id:pelvis_id+1, :3] + self.wis3d.add_motion_skel( + joints = aligned_gt_kp3d, + bones = Skeleton_OpenPose25.bones, + colors = Skeleton_OpenPose25.bone_colors, + name = 'gt_kp3d', + ) + except Exception as e: + get_logger().error(f'Failed to visualize the current performance on wis3d: {e}') + + try: + # 2.2. Visualize 2D information. + if vis_data['img_patch'] is not None: + # Overlay the skin mesh of the results on the original image. + imgs_spliced = [] + for i, img_patch in enumerate(vis_data['img_patch']): + # TODO: make this more elegant. 
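+ # Undo the input normalization: recover uint8 RGB as (mean + normalized * std) * 255.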
+ img_mean = to_numpy(OmegaConf.to_container(self.cfg.policy.img_mean))[None, None] # (1, 1, 3) + img_std = to_numpy(OmegaConf.to_container(self.cfg.policy.img_std))[None, None] # (1, 1, 3) + img_patch = ((img_mean + img_patch * img_std) * 255).astype(np.uint8) + + img_patch_raw = annotate_img(img_patch, 'raw') + + img_with_mesh = render_mesh_overlay_img( + faces = self.skel_model.skin_f, + verts = vis_data['pd_skin'][i].float(), + K4 = [self.cfg.policy.focal_length, self.cfg.policy.focal_length, 128, 128], + img = img_patch, + Rt = [torch.eye(3).float(), vis_data['cam_t'][i].float()], + mesh_color = 'pink', + ) + img_with_mesh = annotate_img(img_with_mesh, 'pd_mesh') + + img_with_gt_mesh = render_mesh_overlay_img( + faces = self.skel_model.skin_f, + verts = vis_data['gt_skin'][i].float(), + K4 = [self.cfg.policy.focal_length, self.cfg.policy.focal_length, 128, 128], + img = img_patch, + Rt = [torch.eye(3).float(), vis_data['cam_t'][i].float()], + mesh_color = 'pink', + ) + valid = 'valid' if vis_data['gt_skin_valid'][i] else 'invalid' + img_with_gt_mesh = annotate_img(img_with_gt_mesh, f'gt_mesh_{valid}') + + img_with_gt = annotate_img(img_patch, 'gt_kp2d') + gt_kp2d_with_conf = vis_data['gt_kp2d_with_conf'][i] + gt_kp2d_with_conf[:, :2] = (gt_kp2d_with_conf[:, :2] + 0.5) * self.cfg.policy.img_patch_size + img_with_gt = draw_kp2d_on_img( + img_with_gt, + gt_kp2d_with_conf, + Skeleton_OpenPose25.bones, + Skeleton_OpenPose25.bone_colors, + ) + + img_with_pd = annotate_img(img_patch, 'pd_kp2d') + pd_kp2d_vis = vis_data['pd_kp2d'][i] + pd_kp2d_vis = (pd_kp2d_vis + 0.5) * self.cfg.policy.img_patch_size + img_with_pd = draw_kp2d_on_img( + img_with_pd, + (vis_data['pd_kp2d'][i] + 0.5) * self.cfg.policy.img_patch_size, + Skeleton_OpenPose25.bones, + Skeleton_OpenPose25.bone_colors, + ) + + img_spliced = splice_img([img_patch_raw, img_with_gt, img_with_pd, img_with_mesh, img_with_gt_mesh], grid_ids=[[0, 1, 2, 3, 4]]) + img_spliced = annotate_img(img_spliced, vis_data['img_key'][i], pos='tl') + imgs_spliced.append(img_spliced) + + try: + self.wis3d.set_scene_id(i) + self.wis3d.add_image( + image = img_spliced, + name = 'image', + ) + except Exception as e: + get_logger().error(f'Failed to visualize the current performance on wis3d: {e}') + + img_final = splice_img(imgs_spliced, grid_ids=[[i] for i in range(len(vis_data['img_patch']))]) + + img_final = to_tensor(img_final, device=None).permute(2, 0, 1) + summary_writer.add_image(f'{mode}/visualization', img_final, self.global_step) + + except Exception as e: + get_logger().error(f'Failed to visualize the current performance: {e}') + # traceback.print_exc() + + + def _adapt_hsmr_v1(self, batch): + from lib.data.augmentation.skel import rot_skel_on_plane + rot_deg = batch['augm_args']['rot_deg'] # (B,) + + skel_params = rot_skel_on_plane(batch['raw_skel_params'], rot_deg) + batch['gt_params'] = {} + batch['gt_params']['poses_orient'] = skel_params['poses'][:, :3] + batch['gt_params']['poses_body'] = skel_params['poses'][:, 3:] + batch['gt_params']['betas'] = skel_params['betas'] + + has_skel_params = batch['has_skel_params'] + batch['has_gt_params'] = {} + batch['has_gt_params']['poses_orient'] = has_skel_params['poses'] + batch['has_gt_params']['poses_body'] = has_skel_params['poses'] + batch['has_gt_params']['betas'] = has_skel_params['betas'] + return batch + + def _adapt_img_inference(self, img_patches): + return {'img_patch': img_patches} \ No newline at end of file diff --git a/lib/modeling/pipelines/vitdet/__init__.py 
b/lib/modeling/pipelines/vitdet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..001fdc725e611574500b06c0c6024929db75968a --- /dev/null +++ b/lib/modeling/pipelines/vitdet/__init__.py @@ -0,0 +1,20 @@ +from lib.kits.basic import * + +from detectron2.config import LazyConfig +from .utils_detectron2 import DefaultPredictor_Lazy + + +def build_detector(batch_size, max_img_size, device): + cfg_path = Path(__file__).parent / 'cascade_mask_rcnn_vitdet_h_75ep.py' + detectron2_cfg = LazyConfig.load(str(cfg_path)) + detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl" + for i in range(3): + detectron2_cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25 + + detector = DefaultPredictor_Lazy( + cfg = detectron2_cfg, + batch_size = batch_size, + max_img_size = max_img_size, + device = device, + ) + return detector \ No newline at end of file diff --git a/lib/modeling/pipelines/vitdet/cascade_mask_rcnn_vitdet_h_75ep.py b/lib/modeling/pipelines/vitdet/cascade_mask_rcnn_vitdet_h_75ep.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6ae0eaf48c2c2d3b70529a0d2d915432e43db6 --- /dev/null +++ b/lib/modeling/pipelines/vitdet/cascade_mask_rcnn_vitdet_h_75ep.py @@ -0,0 +1,129 @@ +## coco_loader_lsj.py + +import detectron2.data.transforms as T +from detectron2 import model_zoo +from detectron2.config import LazyCall as L + +# Data using LSJ +image_size = 1024 +dataloader = model_zoo.get_config("common/data/coco.py").dataloader +dataloader.train.mapper.augmentations = [ + L(T.RandomFlip)(horizontal=True), # flip first + L(T.ResizeScale)( + min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size + ), + L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False), +] +dataloader.train.mapper.image_format = "RGB" +dataloader.train.total_batch_size = 64 +# recompute boxes due to cropping +dataloader.train.mapper.recompute_boxes = True + +dataloader.test.mapper.augmentations = [ + L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size), +] + +from functools import partial +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate + +# mask_rcnn_vitdet_b_100ep.py + +model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth" + + +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889, 177546], + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7) +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} + +# cascade_mask_rcnn_vitdet_b_100ep.py + +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from 
detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import ( + FastRCNNOutputLayers, + FastRCNNConvFCHead, + CascadeROIHeads, +) + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=CascadeROIHeads, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm="LN", + ) + for _ in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + cls_agnostic_bbox_reg=True, + num_classes="${...num_classes}", + ) + for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) + +# cascade_mask_rcnn_vitdet_h_75ep.py + +from functools import partial + +train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth" + +model.backbone.net.embed_dim = 1280 +model.backbone.net.depth = 32 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.5 +# 7, 15, 23, 31 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32) +optimizer.params.overrides = {} +optimizer.params.weight_decay_norm = None + +train.max_iter = train.max_iter * 3 // 4 # 100ep -> 75ep +lr_multiplier.scheduler.milestones = [ + milestone * 3 // 4 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/lib/modeling/pipelines/vitdet/utils_detectron2.py b/lib/modeling/pipelines/vitdet/utils_detectron2.py new file mode 100644 index 0000000000000000000000000000000000000000..299331b8c0a1056e03d856ca32efe9b2ef9b1e3a --- /dev/null +++ b/lib/modeling/pipelines/vitdet/utils_detectron2.py @@ -0,0 +1,117 @@ +from lib.kits.basic import * + +from tqdm import tqdm + +from lib.utils.media import flex_resize_img + +import detectron2.data.transforms as T +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import instantiate as instantiate_detectron2 +from detectron2.data import MetadataCatalog + + +class DefaultPredictor_Lazy: + ''' + Create a simple end-to-end predictor with the given config that runs on single device for a + several input images. + Compared to using the model directly, this class does the following additions: + + Modified from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/utils/utils_detectron2.py#L9-L93 + + 1. Load checkpoint from the weights specified in config (cfg.MODEL.WEIGHTS). + 2. Always take BGR image as the input and apply format conversion internally. + 3. Apply resizing defined by the parameter `max_img_size`. + 4. Take input images and produce outputs, and filter out only the `instances` data. + 5. Use an auto-tuned batch size to process the images in a batch. + - Start with the given batch size, if failed, reduce the batch size by half. + - If the batch size is reduced to 1 and still failed, skip the image. + - The implementation is abstracted to `lib.platform.sliding_batches`. 
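+ + A minimal usage sketch (argument values are illustrative): `predictor = DefaultPredictor_Lazy(cfg, batch_size=8, max_img_size=512, device='cuda:0')`, then `preds, downsample_ratios = predictor(imgs)` for a list of RGB images `imgs`.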
+ ''' + + def __init__(self, cfg, batch_size=20, max_img_size=512, device='cuda:0'): + self.batch_size = batch_size + self.max_img_size = max_img_size + self.device = device + self.model = instantiate_detectron2(cfg.model) + + test_dataset = OmegaConf.select(cfg, 'dataloader.test.dataset.names', default=None) + if isinstance(test_dataset, (List, Tuple)): + test_dataset = test_dataset[0] + + checkpointer = DetectionCheckpointer(self.model) + checkpointer.load(OmegaConf.select(cfg, 'train.init_checkpoint', default='')) + + mapper = instantiate_detectron2(cfg.dataloader.test.mapper) + self.aug = mapper.augmentations + self.input_format = mapper.image_format + + self.model.eval().to(self.device) + if test_dataset: + self.metadata = MetadataCatalog.get(test_dataset) + + assert self.input_format in ['RGB'], f'Invalid input format: {self.input_format}' + # assert self.input_format in ['RGB', 'BGR'], f'Invalid input format: {self.input_format}' + + def __call__(self, imgs): + ''' + ### Args + - `imgs`: List[np.ndarray], a list of image of shape (Hi, Wi, RGB). + - Shapes of each image may be different. + + ### Returns + - `predictions`: dict, + - the output of the model for one image only. + - See :doc:`/tutorials/models` for details about the format. + ''' + with torch.no_grad(): + inputs = [] + downsample_ratios = [] + for img in imgs: + img_size = max(img.shape[:2]) + if img_size > self.max_img_size: # exceed the max size, make it smaller + downsample_ratio = self.max_img_size / img_size + img = flex_resize_img(img, ratio=downsample_ratio) + downsample_ratios.append(downsample_ratio) + else: + downsample_ratios.append(1.0) + h, w, _ = img.shape + img = self.aug(T.AugInput(img)).apply_image(img) + img = to_tensor(img.astype('float32').transpose(2, 0, 1), 'cpu') + inputs.append({'image': img, 'height': h, 'width': w}) + + preds = [] + N_imgs = len(inputs) + prog_bar = tqdm(total=N_imgs, desc='Batch Detection') + sid, last_fail_id = 0, 0 + cur_bs = self.batch_size + while sid < N_imgs: + eid = min(sid + cur_bs, N_imgs) + try: + preds_round = self.model(inputs[sid:eid]) + except Exception as e: + get_logger(brief=True).error(f'Image No.{sid}: {e}. Try to fix it.') + if cur_bs > 1: + cur_bs = (cur_bs - 1) // 2 + 1 # reduce the batch size by half + assert cur_bs > 0, 'Invalid batch size.' + get_logger(brief=True).info(f'Adjust the batch size to {cur_bs}.') + else: + get_logger(brief=True).error(f'Can\'t afford image No.{sid} even with batch_size=1, skip.') + preds.append(None) # placeholder for the failed image + sid += 1 + last_fail_id = sid + continue + # Save the results. + preds.extend([{ + 'pred_classes' : pred['instances'].pred_classes.cpu(), + 'scores' : pred['instances'].scores.cpu(), + 'pred_boxes' : pred['instances'].pred_boxes.tensor.cpu(), + } for pred in preds_round]) + + prog_bar.update(eid - sid) + sid = eid + # # Adjust the batch size. 
+ # if last_fail_id < sid - cur_bs * 2: + # cur_bs = min(cur_bs * 2, self.batch_size) # gradually recover the batch size + prog_bar.close() + + return preds, downsample_ratios diff --git a/lib/platform/__init__.py b/lib/platform/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1a44439fbe9a87308e732a94a334e069e842fc --- /dev/null +++ b/lib/platform/__init__.py @@ -0,0 +1,6 @@ +from .proj_manager import ProjManager as PM +from .config_utils import ( + print_cfg, + entrypoint, + entrypoint_with_args, +) \ No newline at end of file diff --git a/lib/platform/config_utils.py b/lib/platform/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c327d42cd460bb36ca16a33aa0bc12e27de1dff9 --- /dev/null +++ b/lib/platform/config_utils.py @@ -0,0 +1,251 @@ +import sys +import json +import rich +import rich.text +import rich.tree +import rich.syntax +import hydra +from typing import List, Optional, Union, Any +from pathlib import Path +from omegaconf import OmegaConf, DictConfig, ListConfig +from pytorch_lightning.utilities import rank_zero_only + +from lib.info.log import get_logger + +from .proj_manager import ProjManager as PM + + +def get_PM_info_dict(): + ''' Get a OmegaConf object containing the information from the ProjManager. ''' + PM_info = OmegaConf.create({ + '_pm_': { + 'root' : str(PM.root), + 'inputs' : str(PM.inputs), + 'outputs': str(PM.outputs), + } + }) + return PM_info + + +def get_PM_info_list(): + ''' Get a list containing the information from the ProjManager. ''' + PM_info = [ + f'_pm_.root={str(PM.root)}', + f'_pm_.inputs={str(PM.inputs)}', + f'_pm_.outputs={str(PM.outputs)}', + ] + return PM_info + + +def entrypoint_with_args(*args, log_cfg=True, **kwargs): + ''' + This decorator extends the `hydra.main` decorator in these parts: + - Inject some runtime-known arguments, e.g., `proj_root`. + - Enable additional arguments that needn't to be specified in command line. + - Positional arguments are added to the command line arguments directly, so make sure they are valid. + - e.g., \'exp=<...>\', \'+extra=<...>\', etc. + - Key-specified arguments have the same effect as command line arguments {k}={v}. + - Check the validation of experiment name. + ''' + + overrides = get_PM_info_list() + + for arg in args: + overrides.append(arg) + + for k, v in kwargs.items(): + overrides.append(f'{k}={v}') + + overrides.extend(sys.argv[1:]) + + def entrypoint_wrapper(func): + # Import extra pre-specified arguments. + if len(overrides) > 0: + # The args from command line have higher priority, so put them in the back. + sys.argv = sys.argv[:1] + overrides + sys.argv[1:] + _log_exp_info(func.__name__, overrides) + + @hydra.main(version_base=None, config_path=str(PM.configs), config_name='base.yaml') + def entrypoint_preprocess(cfg:DictConfig): + # Resolve the references and make it editable. + cfg = unfold_cfg(cfg) + + # Print out the configuration files. + if log_cfg and cfg.get('show_cfg', True): + sum_keys = ['output_dir', 'pipeline.name', 'data.name', 'exp_name', 'exp_tag'] + print_cfg(cfg, sum_keys=sum_keys) + + # Check the validation of experiment name. + if cfg.get('exp_name') is None: + get_logger(brief=True).fatal(f'`exp_name` is not given! You may need to add `exp=` to the command line.') + raise ValueError('`exp_name` is not given!') + + # Bind config. + PM.init_with_cfg(cfg) + try: + with PM.time_monitor('exp', f'Main part of experiment `{cfg.exp_name}`.'): + # Enter the main function. 
+ func(cfg) + except Exception as e: + raise e + finally: + PM.time_monitor.report(level='global') + + # TODO: Wrap a notifier here. + + return entrypoint_preprocess + + + return entrypoint_wrapper + + #! This implementation can't dump the config files in default ways. In order to keep c + # def entrypoint_wrapper(func): + # def entrypoint_preprocess(): + # # Initialize the configuration module. + # with hydra.initialize_config_dir(version_base=None, config_dir=str(PM.configs)): + # get_logger(brief=True).info(f'Exp entry `{func.__name__}` is called with overrides: {overrides}') + # cfg = hydra.compose(config_name='base', overrides=overrides) + + # cfg4dump_raw = cfg.copy() # store the folded raw configuration files + # # Resolve the references and make it editable. + # cfg = unfold_cfg(cfg) + + # # Print out the configuration files. + # if log_cfg: + # sum_keys = ['pipeline.name', 'data.name', 'exp_name'] + # print_cfg(cfg, sum_keys=sum_keys) + # # Check the validation of experiment name. + # if cfg.get('exp_name') is None: + # get_logger().fatal(f'`exp_name` is not given! You may need to add `exp=` to the command line.') + # raise ValueError('`exp_name` is not given!') + # # Enter the main function. + # func(cfg) + # return entrypoint_preprocess + # return entrypoint_wrapper + +def entrypoint(func): + ''' + This decorator extends the `hydra.main` decorator in these parts: + - Inject some runtime-known arguments, e.g., `proj_root`. + - Check the validation of experiment name. + ''' + return entrypoint_with_args()(func) + + +def unfold_cfg( + cfg : Union[DictConfig, Any], +): + ''' + Unfold the configuration files, i.e. from structured mode to container mode and recreate the + configuration files. It will resolve all the references and make the config editable. + + ### Args + - cfg: DictConfig or None + + ### Returns + - cfg: DictConfig or None + ''' + if cfg is None: + return None + + cfg_container = OmegaConf.to_container(cfg, resolve=True) + cfg = OmegaConf.create(cfg_container) + return cfg + + +def recursively_simplify_cfg( + node : DictConfig, + hide_misc : bool = True, +): + if isinstance(node, DictConfig): + for k in list(node.keys()): + # We delete some terms that are not commonly concerned. + if hide_misc: + if k in ['_hub_', 'hydra', 'job_logging']: + node.__delattr__(k) + continue + node[k] = recursively_simplify_cfg(node[k], hide_misc) + elif isinstance(node, ListConfig): + if len(node) > 0 and all([ + not isinstance(x, DictConfig) \ + and not isinstance(x, ListConfig) \ + for x in node + ]): + # We fold all lists of basic elements (int, float, ...) into a single line if possible. + folded_list_str = '*' + str(list(node)) + node = folded_list_str if len(folded_list_str) < 320 else node + else: + for i in range(len(node)): + node[i] = recursively_simplify_cfg(node[i], hide_misc) + return node + + +@rank_zero_only +def print_cfg( + cfg : Optional[DictConfig], + title : str ='cfg', + sum_keys: List[str] = [], + show_all: bool = False +): + ''' + Print configuration files using rich. + + ### Args + - cfg: DictConfig or None + - If None, print nothing. + - sum_keys: List[str], default [] + - If keys given in the list exist in the first level of the configuration files, + they will be printed in the summary part. + - show_all: bool, default False + - If False, hide terms starts with `_` in the configuration files's first level + and some hydra supporting configs. 
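+ - Example (sketch): `print_cfg(cfg, sum_keys=['exp_name', 'pipeline.name'])` prints the simplified config tree plus a YAML summary of those two keys.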
+ ''' + + theme = 'coffee' + style = 'dim' + + tf_dict = { True: '◼', False: '◻' } + print_setting = f'<< {tf_dict[show_all]} SHOW_ALL >>' + tree = rich.tree.Tree(f'⌾ {title} - {print_setting}', style=style, guide_style=style) + + if cfg is None: + tree.add('None') + rich.print(tree) + return + + # Clone a new one to avoid changing the original configuration files. + cfg = cfg.copy() + cfg = unfold_cfg(cfg) + + if not show_all: + cfg = recursively_simplify_cfg(cfg) + + cfg_yaml = OmegaConf.to_yaml(cfg) + cfg_yaml = rich.syntax.Syntax(cfg_yaml, 'yaml', theme=theme, line_numbers=True) + tree.add(cfg_yaml) + + # Add a summary containing information only is commonly concerned. + if len(sum_keys) > 0: + concerned = {} + for k_str in sum_keys: + k_list = k_str.split('.') + tgt = cfg + for k in k_list: + if tgt is not None: + tgt = tgt.get(k) + if tgt is not None: + concerned[k_str] = tgt + else: + get_logger().warning(f'Key `{k_str}` is not found in the configuration files.') + + tree.add(rich.syntax.Syntax(OmegaConf.to_yaml(concerned), 'yaml', theme=theme)) + + rich.print(tree) + + +@rank_zero_only +def _log_exp_info( + func_name : str, + overrides : List[str], +): + get_logger(brief=True).info(f'Exp entry `{func_name}` is called with overrides: {overrides}') \ No newline at end of file diff --git a/lib/platform/monitor/__init__.py b/lib/platform/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8165bfcd2187bee08b1a0ec904bccfd3b306322d --- /dev/null +++ b/lib/platform/monitor/__init__.py @@ -0,0 +1,2 @@ +from .time import * +from .gpu import * \ No newline at end of file diff --git a/lib/platform/monitor/gpu.py b/lib/platform/monitor/gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..19e6aa1eb1d26890b3dae68bdbd8ba6fc90bb42a --- /dev/null +++ b/lib/platform/monitor/gpu.py @@ -0,0 +1,92 @@ +import time +import torch +import inspect + + +def fold_path(fn:str): + ''' Fold a path like `from/to/file.py` to relative `f/t/file.py`. ''' + return '/'.join([p[:1] for p in fn.split('/')[:-1]]) + '/' + fn.split('/')[-1] + + +def summary_frame_info(frame:inspect.FrameInfo): + ''' Convert a FrameInfo object to a summary string. ''' + return f'{frame.function} @ {fold_path(frame.filename)}:{frame.lineno}' + + +class GPUMonitor(): + ''' + This monitor is designed for GPU memory analysis. It records the peak memory usage in a period of time. + A snapshot will record the peak memory usage until the snapshot is taken. (After init / reset / previous snapshot.) 
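+ Usage sketch: construct `GPUMonitor()`, call `snapshot('some step')` after the code of interest, then inspect the records via `report_latest()` / `report_all()`; every snapshot is also appended to `gpu_monitor.log`.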
+ ''' + + def __init__(self): + self.reset() + self.clear() + self.log_fn = 'gpu_monitor.log' + + + def snapshot(self, desc:str='snapshot'): + timestamp = time.time() + caller_frame = inspect.stack()[1] + peak_MB = torch.cuda.max_memory_allocated() / 1024 / 1024 + free_mem, total_mem = torch.cuda.mem_get_info(0) + free_mem_MB, total_mem_MB = free_mem / 1024 / 1024, total_mem / 1024 / 1024 + + record = { + 'until' : time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp)), + 'until_raw' : timestamp, + 'position' : summary_frame_info(caller_frame), + 'peak' : peak_MB, + 'peak_msg' : f'{peak_MB:.2f} MB', + 'free' : free_mem_MB, + 'total' : total_mem_MB, + 'free_msg' : f'{free_mem_MB:.2f} MB', + 'total_msg' : f'{total_mem_MB:.2f} MB', + 'desc' : desc, + } + + self.max_peak_MB = max(self.max_peak_MB, peak_MB) + + self.records.append(record) + self._update_log(record) + + self.reset() + return record + + + def report_latest(self, k:int=1): + import rich + caller_frame = inspect.stack()[1] + caller_info = summary_frame_info(caller_frame) + rich.print(f'{caller_info} -> latest {k} records:') + for rid, record in enumerate(self.records[-k:]): + msg = self._generate_log_msg(record) + rich.print(msg) + + + def report_all(self): + self.report_latest(len(self.records)) + + + def reset(self): + torch.cuda.reset_peak_memory_stats() + return + + + def clear(self): + self.records = [] + self.max_peak_MB = 0 + + def _generate_log_msg(self, record): + time = record['until'] + peak = record['peak'] + desc = record['desc'] + position = record['position'] + msg = f'[{time}] ⛰️ {peak:>8.2f} MB 📌 {desc} 🌐 {position}' + return msg + + + def _update_log(self, record): + msg = self._generate_log_msg(record) + with open(self.log_fn, 'a') as f: + f.write(msg + '\n') diff --git a/lib/platform/monitor/time.py b/lib/platform/monitor/time.py new file mode 100644 index 0000000000000000000000000000000000000000..7682fa689a5d53f59714ddbf612a170fced2c28f --- /dev/null +++ b/lib/platform/monitor/time.py @@ -0,0 +1,233 @@ +import time +import atexit +import inspect +import torch + +from typing import Optional, Union, List +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor + + +def fold_path(fn:str): + ''' Fold a path like `from/to/file.py` to relative `f/t/file.py`. ''' + return '/'.join([p[:1] for p in fn.split('/')[:-1]]) + '/' + fn.split('/')[-1] + + +def summary_frame_info(frame:inspect.FrameInfo): + ''' Convert a FrameInfo object to a summary string. 
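+ e.g. a caller `my_func` at line 42 of `aaa/bbb/file.py` is summarized as `my_func @ a/b/file.py:42`.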
''' + return f'{frame.function} @ {fold_path(frame.filename)}:{frame.lineno}' + + +class TimeMonitorDisabled: + def foo(self, *args, **kwargs): + return + + def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False): + self.tick = self.foo + self.report = self.foo + self.clear = self.foo + self.dump_statistics = self.foo + + def __call__(self, *args, **kwargs): + return self + def __enter__(self): + return self + def __exit__(self, exc_type, exc_val, exc_tb): + return + + +class TimeMonitor: + ''' + It is supposed to be used like this: + + time_monitor = TimeMonitor() + + with time_monitor('test_block', 'Block that does something.') as tm: + do_something() + + time_monitor.report() + ''' + + def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False): + if log_folder is not None: + self.log_folder = Path(log_folder) if isinstance(log_folder, str) else log_folder + self.log_folder.mkdir(parents=True, exist_ok=True) + log_fn = self.log_folder / 'readable.log' + self.log_fh = open(log_fn, 'w') # Log file handler. + self.log_fh.write('=== New Exp ===\n') + else: + self.log_folder = None + self.clear() + self.current_block_uid_stack : List = [] # Unique block id stack for recording. + self.current_block_aid_stack : List = [] # Block id stack for accumulated cost analysis. + + # Specially add a global start and end block. + self.record_birth_block = record_birth_block and log_folder is not None + if self.record_birth_block: + self.__call__('monitor_birth', 'Since the monitor is constructed.') + self.__enter__() + + # Register the exit hook to dump the data safely. + atexit.register(self._die_hook) + + + def __call__(self, block_name:str, block_desc:Optional[str]=None): + ''' Set up the name of the context for a block. ''' + # 1. Format the block name. + block_name = block_name.replace('/', '-').replace(' ', '-') + block_name_recursive = '/'.join([s.split('/')[-1] for s in self.current_block_aid_stack] + [block_name]) # Tree structure block name. + # 2. Get a unique name for the block record. + block_postfixed = 0 + while f'{block_name_recursive}_{block_postfixed}' in self.block_info: + block_postfixed += 1 + # 3. Get the caller frame information. + caller_frame = inspect.stack()[1] + block_position = summary_frame_info(caller_frame) + # 4. Initialize the block information. + self.current_block_uid_stack.append(f'{block_name_recursive}_{block_postfixed}') + self.current_block_aid_stack.append(block_name) + self.block_info[self.current_block_uid_stack[-1]] = { + 'records' : [], + 'position' : block_position, + 'desc' : block_desc, + } + + return self + + + def __enter__(self): + caller_frame = inspect.stack()[1] + record = self._tick_record(caller_frame, 'Start of the block.') + self.block_info[self.current_block_uid_stack[-1]]['records'].append(record) + return self + + + def __exit__(self, exc_type, exc_val, exc_tb): + caller_frame = inspect.stack()[1] + record = self._tick_record(caller_frame, 'End of the block.') + self.block_info[self.current_block_uid_stack[-1]]['records'].append(record) + + # Finish one block. 
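+ # Pop this block from both id stacks, accumulate its elapsed time under its tree-structured name, and refresh the dumped statistics in a background thread.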
+ curr_block_uid = self.current_block_uid_stack.pop() + curr_block_aid = '/'.join(self.current_block_aid_stack) + self.current_block_aid_stack.pop() + + self.finished_blocks.append(curr_block_uid) + elapsed = self.block_info[curr_block_uid]['records'][-1]['timestamp'] \ + - self.block_info[curr_block_uid]['records'][0]['timestamp'] + self.block_cost[curr_block_aid] = self.block_cost.get(curr_block_aid, 0) + elapsed + + if hasattr(self, 'dump_thread'): + self.dump_thread.result() + with ThreadPoolExecutor() as executor: + self.dump_thread = executor.submit(self.dump_statistics) + + + + def tick(self, desc:str=''): + ''' + Record a intermediate timestamp. These records are only for in-block analysis, + and will be ignored when analyzing in global view. + ''' + caller_frame = inspect.stack()[1] + record = self._tick_record(caller_frame, desc) + self.block_info[self.current_block_uid_stack[-1]]['records'].append(record) + return + + + def report(self, level:Union[str, List[str]]='global'): + import rich + + caller_frame = inspect.stack()[1] + caller_info = summary_frame_info(caller_frame) + + if isinstance(level, str): + level = [level] + + for lv in level: # To make sure we can output in order. + if lv == 'block': + rich.print(f'[bold underline][EA-B][/bold underline] {caller_info} -> blocks level records:') + for block_name in self.finished_blocks: + msg = '\t' + self._generate_block_msg(block_name).replace('\n\t', '\n\t\t') + rich.print(msg) + elif lv == 'global': + rich.print(f'[bold underline][EA-G][/bold underline] {caller_info} -> global efficiency analysis:') + for block_name, cost in self.block_cost.items(): + rich.print(f'\t{block_name}: {cost:.2f} sec') + + + def clear(self): + self.finished_blocks = [] + self.block_info = {} + self.block_cost = {} + + + def dump_statistics(self): + ''' Dump the logging raw data for post analysis. ''' + if self.log_folder is None: + return + + dump_fn = self.log_folder / 'statistics.pkl' + with open(dump_fn, 'wb') as f: + import pickle + pickle.dump({ + 'finished_blocks' : self.finished_blocks, + 'block_info' : self.block_info, + 'block_cost' : self.block_cost, + 'curr_aid_stack' : self.current_block_aid_stack, # nonempty when when errors happen inside a block + }, f) + + + # TODO: Draw a graph to visualize the time consumption. + + + def _tick_record(self, caller_frame, desc:Optional[str]=''): + # 1. Generate the record. + torch.cuda.synchronize() + timestamp = time.time() + readable_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp)) + position = summary_frame_info(caller_frame) + record = { + 'time' : readable_time, + 'timestamp' : timestamp, + 'position' : position, + 'desc' : desc, + } + + # 2. Log the record. 
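+ # Each readable log line records the wall-clock time, the current block's unique id, the tick description and the caller position.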
+ if self.log_folder is not None: + block_uid = self.current_block_uid_stack[-1] + log_msg = f'[{readable_time}] 🗂️ {block_uid} 📌 {desc} 🌐 {position}' + self.log_fh.write(log_msg + '\n') + + return record + + + def _generate_block_msg(self, block_name): + block_info = self.block_info[block_name] + block_position = block_info['position'] + block_desc = block_info['desc'] + records = block_info['records'] + msg = f'🗂️ {block_name} 📌 {block_desc} 🌐 {block_position}' + for rid, record in enumerate(records): + readable_time = record['time'] + tick_desc = record['desc'] + tick_position = record['position'] + if rid > 0: + prev_record = records[rid-1] + tick_elapsed = record['timestamp'] - prev_record['timestamp'] + tick_elapsed = f'{tick_elapsed:.2f} s' + else: + tick_elapsed = 'N/A' + msg += f'\n\t[{readable_time}] ⏳ {tick_elapsed} 📌 {tick_desc} 🌐 {tick_position}' + return msg + + + def _die_hook(self): + if self.record_birth_block: + self.__exit__(None, None, None) + + self.dump_statistics() + + if self.log_folder is not None: + self.log_fh.close() \ No newline at end of file diff --git a/lib/platform/proj_manager.py b/lib/platform/proj_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..8145b5008126cd0c9e95094e3ee62396db983ce9 --- /dev/null +++ b/lib/platform/proj_manager.py @@ -0,0 +1,24 @@ +from pathlib import Path +from .monitor import TimeMonitor, TimeMonitorDisabled + +class ProjManager(): + root = Path(__file__).parent.parent.parent # root / lib / utils / path_manager.py + assert (root.exists()), 'Can\'t find the path of project root.' + + configs = root / 'configs' # Generally, you are not supposed to access deep config through path. + inputs = root / 'data_inputs' + outputs = root / 'data_outputs' + assert (configs.exists()), 'Make sure you have a \'configs\' folder in the root directory.' + assert (inputs.exists()), 'Make sure you have a \'data_inputs\' folder in the root directory.' + assert (outputs.exists()), 'Make sure you have a \'data_outputs\' folder in the root directory.' + + # Default values. + cfg = None + time_monitor = TimeMonitorDisabled() + + @staticmethod + def init_with_cfg(cfg): + ProjManager.cfg = cfg + ProjManager.exp_outputs = Path(cfg.output_dir) + if cfg.get('enable_time_monitor', False): + ProjManager.time_monitor = TimeMonitor(ProjManager.exp_outputs, record_birth_block=False) \ No newline at end of file diff --git a/lib/platform/sliding_batches/__init__.py b/lib/platform/sliding_batches/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9029de44049bfaec9c5048d43ef0f15903335e32 --- /dev/null +++ b/lib/platform/sliding_batches/__init__.py @@ -0,0 +1,42 @@ +from .basic import bsb +from .adaptable.v1 import asb + +# The batch manager is used to quickly initialize a batched task. +# The Usage of the batch manager is as follows: + + +def eg_bbm(): + ''' Basic version of the batch manager. ''' + task_len = 1e6 + task_things = [i for i in range(int(task_len))] + + for bw in bsb(total=task_len, batch_size=300, enable_tqdm=True): + sid = bw.sid + eid = bw.eid + round_things = task_things[sid:eid] + # Do something with `round_things`. + + +def eg_asb(): + ''' Basic version of the batch manager. ''' + task_len = 1024 + task_things = [i for i in range(int(task_len))] + + lb, ub = 1, 300 # lower & upper bound of batch size + for bw in asb(total=task_len, bs_scope=(lb, ub), enable_tqdm=True): + sid = bw.sid + eid = bw.eid + round_things = task_things[sid:eid] + # Do something with `round_things`. 
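+ # When the per-round work may fail (e.g. out of memory), guard it and shrink the window on failure: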
+ + try: + # Do something with `round_things`. + pass + except Exception as e: + if not bw.shrink(): + # In this case, it means task_things[sid:sid+lb] is still too large to handle. + # So you need to do something to handle this situation. + pass + continue #! DO NOT FORGET CONTINUE + + # Do something with `round_things` if no exception is raised. diff --git a/lib/platform/sliding_batches/adaptable/v1.py b/lib/platform/sliding_batches/adaptable/v1.py new file mode 100644 index 0000000000000000000000000000000000000000..969a6ed1b1fdb4b2ae215984f996e6331bc3ff19 --- /dev/null +++ b/lib/platform/sliding_batches/adaptable/v1.py @@ -0,0 +1,69 @@ +import math +from tqdm import tqdm +from contextlib import contextmanager +from typing import Tuple, Union +from ..basic import BasicBatchWindow, bsb + +class AdaptableBatchWindow(BasicBatchWindow): + def __init__(self, sid, eid, min_B): + self.start_id = sid + self.end_id = eid + self.min_B = min_B + self.shrinking = False + + def shrink(self): + if self.size <= self.min_B: + return False + else: + self.shrinking = True + return True + + +class asb(bsb): + def __init__( + self, + total : int, + bs_scope : Union[Tuple[int, int], int], + enable_tqdm : bool = False, + ): + ''' Simple binary strategy. ''' + # Static hyperparameters. + self.total = int(total) + if isinstance(bs_scope, int): + self.min_B = 1 + self.max_B = bs_scope + else: + self.min_B, self.max_B = bs_scope # lower & upper bound of batch size + # Dynamic state. + self.B = self.max_B # current batch size + self.tqdm = tqdm(total=self.total) if enable_tqdm else None + self.cur_window = AdaptableBatchWindow(sid=-1, eid=0, min_B=self.min_B) # starting window + self.last_shrink_id = None + + def __next__(self): + if self.cur_window.shrinking: + sid = self.cur_window.sid + self.shrink_B(sid) + else: + sid = self.cur_window.eid + self.recover_B(sid) + + if sid >= self.total: + if self.tqdm: self.tqdm.close() + raise StopIteration + + eid = min(sid + self.B, self.total) + self.cur_window = AdaptableBatchWindow(sid, eid, min_B=self.min_B) + if self.tqdm: self.tqdm.update(eid - sid) + return self.cur_window + + def shrink_B(self, cur_id:int): + self.last_shrink_id = cur_id + self.cur_window.shrinking = False + self.B = max(math.ceil(self.B/2), self.min_B) + + def recover_B(self, cur_id:int): + if self.last_shrink_id and self.B < self.max_B: + newer_B = min(self.B * 2, self.max_B) + if self.last_shrink_id < cur_id - newer_B: + self.B = newer_B \ No newline at end of file diff --git a/lib/platform/sliding_batches/basic.py b/lib/platform/sliding_batches/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..6b243720ea3d3bd2a4de9a0de4950ae34f577711 --- /dev/null +++ b/lib/platform/sliding_batches/basic.py @@ -0,0 +1,47 @@ +from tqdm import tqdm + +class BasicBatchWindow(): + def __init__(self, sid, eid): + self.start_id = sid + self.end_id = eid + + @property + def sid(self): + return self.start_id + + @property + def eid(self): + return self.end_id + + @property + def size(self): + return self.eid - self.sid + + +class bsb(): + def __init__( + self, + total : int, + batch_size : int, + enable_tqdm : bool = False, + ): + # Static hyperparameters. + self.total = int(total) + self.B = batch_size + # Dynamic state. 
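+ # The (-1, 0) sentinel window below makes the first __next__ call start at index 0 with a full-sized batch.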
+ self.tqdm = tqdm(total=self.total) if enable_tqdm else None + self.cur_window = BasicBatchWindow(-1, 0) # starting window + + def __iter__(self): + return self + + def __next__(self): + if self.cur_window.eid >= self.total: + if self.tqdm: self.tqdm.close() + raise StopIteration + if self.tqdm: self.tqdm.update(self.cur_window.eid - self.cur_window.sid) + + sid = self.cur_window.eid + eid = min(sid + self.B, self.total) + self.cur_window = BasicBatchWindow(sid, eid) + return self.cur_window \ No newline at end of file diff --git a/lib/utils/bbox.py b/lib/utils/bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..ca208d77240fcffda3ada04769d001ef0d297ef9 --- /dev/null +++ b/lib/utils/bbox.py @@ -0,0 +1,310 @@ +from lib.kits.basic import * + +from .data import to_tensor + + +def lurb_to_cwh( + lurb : Union[list, np.ndarray, torch.Tensor], +): + ''' + Convert the left-upper-right-bottom format to the center-width-height format. + + ### Args + - lurb: Union[list, np.ndarray, torch.Tensor], (..., 4) + - The left-upper-right-bottom format bounding box. + + ### Returns + - Union[list, np.ndarray, torch.Tensor], (..., 4) + - The center-width-height format bounding box. + ''' + lurb, recover_type_back = to_tensor(lurb, device=None, temporary=True) + assert lurb.shape[-1] == 4, f"Invalid shape: {lurb.shape}, should be (..., 4)" + + c = (lurb[..., :2] + lurb[..., 2:]) / 2 # (..., 2) + wh = lurb[..., 2:] - lurb[..., :2] # (..., 2) + + cwh = torch.cat([c, wh], dim=-1) # (..., 4) + return recover_type_back(cwh) + + +def cwh_to_lurb( + cwh : Union[list, np.ndarray, torch.Tensor], +): + ''' + Convert the center-width-height format to the left-upper-right-bottom format. + + ### Args + - cwh: Union[list, np.ndarray, torch.Tensor], (..., 4) + - The center-width-height format bounding box. + + ### Returns + - Union[list, np.ndarray, torch.Tensor], (..., 4) + - The left-upper-right-bottom format bounding box. + ''' + cwh, recover_type_back = to_tensor(cwh, device=None, temporary=True) + assert cwh.shape[-1] == 4, f"Invalid shape: {cwh.shape}, should be (..., 4)" + + l = cwh[..., :2] - cwh[..., 2:] / 2 # (..., 2) + r = cwh[..., :2] + cwh[..., 2:] / 2 # (..., 2) + + lurb = torch.cat([l, r], dim=-1) # (..., 4) + return recover_type_back(lurb) + + +def cwh_to_cs( + cwh : Union[list, np.ndarray, torch.Tensor], + reduce : Optional[str] = None, +): + ''' + Convert the center-width-height format to the center-scale format. + *Only works when width and height are the same.* + + ### Args + - cwh: Union[list, np.ndarray, torch.Tensor], (..., 4) + - The center-width-height format bounding box. + - reduce: Optional[str], default None, valid values: None, 'max' + - Determine how to reduce the width and height to a single scale. + + ### Returns + - Union[list, np.ndarray, torch.Tensor], (..., 3) + - The center-scale format bounding box. + ''' + cwh, recover_type_back = to_tensor(cwh, device=None, temporary=True) + assert cwh.shape[-1] == 4, f"Invalid shape: {cwh.shape}, should be (..., 4)" + + if reduce is None: + if (cwh[..., 2] != cwh[..., 3]).any(): + get_logger().warning(f"Width and height are supposed to be the same, but they're not. The larger one will be used.") + + c = cwh[..., :2] # (..., 2) + s = cwh[..., 2:].max(dim=-1)[0] # (...,) + + cs = torch.cat([c, s[..., None]], dim=-1) # (..., 3) + return recover_type_back(cs) + + +def cs_to_cwh( + cs : Union[list, np.ndarray, torch.Tensor], +): + ''' + Convert the center-scale format to the center-width-height format. 
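+ i.e. [cx, cy, s] becomes the square box [cx, cy, s, s].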
+ + ### Args + - cs: Union[list, np.ndarray, torch.Tensor], (..., 3) + - The center-scale format bounding box. + + ### Returns + - Union[list, np.ndarray, torch.Tensor], (..., 4) + - The center-width-height format bounding box. + ''' + cs, recover_type_back = to_tensor(cs, device=None, temporary=True) + assert cs.shape[-1] == 3, f"Invalid shape: {cs.shape}, should be (..., 3)" + + c = cs[..., :2] # (..., 2) + s = cs[..., 2] # (...,) + + cwh = torch.cat([c, s[..., None], s[..., None]], dim=-1) # (..., 4) + return recover_type_back(cwh) + + +def lurb_to_cs( + lurb : Union[list, np.ndarray, torch.Tensor], +): + ''' + Convert the left-upper-right-bottom format to the center-scale format. + *Only works when width and height are the same.* + + ### Args + - lurb: Union[list, np.ndarray, torch.Tensor], (..., 4) + - The left-upper-right-bottom format bounding box. + + ### Returns + - Union[list, np.ndarray, torch.Tensor], (..., 3) + - The center-scale format bounding box. + ''' + return cwh_to_cs(lurb_to_cwh(lurb), reduce='max') + + +def cs_to_lurb( + cs : Union[list, np.ndarray, torch.Tensor], +): + ''' + Convert the center-scale format to the left-upper-right-bottom format. + + ### Args + - cs: Union[list, np.ndarray, torch.Tensor], (..., 3) + - The center-scale format bounding box. + + ### Returns + - Union[list, np.ndarray, torch.Tensor], (..., 4) + - The left-upper-right-bottom format bounding box. + ''' + return cwh_to_lurb(cs_to_cwh(cs)) + + +def lurb_to_luwh( + lurb : Union[list, np.ndarray, torch.Tensor], +): + ''' + Convert the left-upper-right-bottom format to the left-upper-width-height format. + + ### Args + - lurb: Union[list, np.ndarray, torch.Tensor] + - The left-upper-right-bottom format bounding box. + + ### Returns + - Union[list, np.ndarray, torch.Tensor] + - The left-upper-width-height format bounding box. + ''' + lurb, recover_type_back = to_tensor(lurb, device=None, temporary=True) + assert lurb.shape[-1] == 4, f"Invalid shape: {lurb.shape}, should be (..., 4)" + + lu = lurb[..., :2] # (..., 2) + wh = lurb[..., 2:] - lurb[..., :2] # (..., 2) + + luwh = torch.cat([lu, wh], dim=-1) # (..., 4) + return recover_type_back(luwh) + + +def luwh_to_lurb( + luwh : Union[list, np.ndarray, torch.Tensor], +): + ''' + Convert the left-upper-width-height format to the left-upper-right-bottom format. + + ### Args + - luwh: Union[list, np.ndarray, torch.Tensor] + - The left-upper-width-height format bounding box. + + ### Returns + - Union[list, np.ndarray, torch.Tensor] + - The left-upper-right-bottom format bounding box. + ''' + luwh, recover_type_back = to_tensor(luwh, device=None, temporary=True) + assert luwh.shape[-1] == 4, f"Invalid shape: {luwh.shape}, should be (..., 4)" + + l = luwh[..., :2] # (..., 2) + r = luwh[..., :2] + luwh[..., 2:] # (..., 2) + + lurb = torch.cat([l, r], dim=-1) # (..., 4) + return recover_type_back(lurb) + + +def crop_with_lurb(data, lurb, padding=0): + """ + Crop the img-like data according to the lurb bounding box. + + ### Args + - data: Union[np.ndarray, torch.Tensor], shape (H, W, C) + - Data like image. + - lurb: Union[list, np.ndarray, torch.Tensor], shape (4,) + - Bounding box with [left, upper, right, bottom] coordinates. + - padding: int, default 0 + - Padding value for out-of-bound areas. + + ### Returns + - Union[np.ndarray, torch.Tensor], shape (H', W', C) + - Cropped image with padding if necessary. 
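+ + ### Example + - A hypothetical call on a (128, 128, 3) image: `crop_with_lurb(img, [-10, 0, 118, 128])` returns a (128, 128, 3) patch whose left 10-pixel strip is filled with the padding value.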
+ """ + data, recover_type_back = to_tensor(data, device=None, temporary=True) + + # Ensure lurb is in numpy array format for indexing + lurb = np.array(lurb).astype(np.int64) + l_, u_, r_, b_ = lurb + + # Determine the shape of the data. + H_raw, W_raw, C_raw = data.size() + + # Compute the cropped patch size. + H_patch = b_ - u_ + W_patch = r_ - l_ + + # Create an output buffer of the crop size, initialized to padding + if isinstance(data, np.ndarray): + output = np.full((H_patch, W_patch, C_raw), padding, dtype=data.dtype) + else: + output = torch.full((H_patch, W_patch, C_raw), padding, dtype=data.dtype) + + # Calculate the valid region in the original data + valid_l_ = max(0, l_) + valid_u_ = max(0, u_) + valid_r_ = min(W_raw, r_) + valid_b_ = min(H_raw, b_) + + # Calculate the corresponding valid region in the output + target_l_ = valid_l_ - l_ + target_u_ = valid_u_ - u_ + target_r_ = target_l_ + (valid_r_ - valid_l_) + target_b_ = target_u_ + (valid_b_ - valid_u_) + + # Copy the valid region into the output buffer + output[target_u_:target_b_, target_l_:target_r_, :] = data[valid_u_:valid_b_, valid_l_:valid_r_, :] + + return recover_type_back(output) + + +def fit_bbox_to_aspect_ratio( + bbox : np.ndarray, + tgt_ratio : Optional[Tuple[int, int]] = None, + bbox_type : str = 'lurb' +): + ''' + Fit a random bounding box to a target aspect ratio through enlarging the bounding box with least change. + + ### Args + - bbox: np.ndarray, shape is determined by `bbox_type`, e.g. for 'lurb', shape is (4,) + - The bounding box to be modified. The format is determined by `bbox_type`. + - tgt_ratio: Optional[Tuple[int, int]], default None + - The target aspect ratio to be matched. + - bbox_type: str, default 'lurb', valid values: 'lurb', 'cwh'. + + ### Returns + - np.ndarray, shape is determined by `bbox_type`, e.g. for 'lurb', shape is (4,) + - The modified bounding box. + ''' + bbox = bbox.copy() + if bbox_type == 'lurb': + bbx_cwh = lurb_to_cwh(bbox) + bbx_wh = bbx_cwh[2:] + elif bbox_type == 'cwh': + bbx_wh = bbox[2:] + else: + raise ValueError(f"Unsupported bbox type: {bbox_type}") + + new_bbx_wh = expand_wh_to_aspect_ratio(bbx_wh, tgt_ratio) + + if bbox_type == 'lurb': + bbx_cwh[2:] = new_bbx_wh + new_bbox = cwh_to_lurb(bbx_cwh) + elif bbox_type == 'cwh': + new_bbox = np.concatenate([bbox[:2], new_bbx_wh]) + else: + raise ValueError(f"Unsupported bbox type: {bbox_type}") + + return new_bbox + + +def expand_wh_to_aspect_ratio(bbx_wh:np.ndarray, tgt_aspect_ratio:Optional[Tuple[int, int]]=None): + ''' + Increase the size of the bounding box to match the target shape. 
+ Modified from https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/datasets/utils.py#L14-L33 + ''' + if tgt_aspect_ratio is None: + return bbx_wh + + try: + bbx_w , bbx_h = bbx_wh + except (ValueError, TypeError): + get_logger().warning(f"Invalid bbox_wh content: {bbx_wh}") + return bbx_wh + + tgt_w, tgt_h = tgt_aspect_ratio + if bbx_h / bbx_w < tgt_h / tgt_w: + new_h = max(bbx_w * tgt_h / tgt_w, bbx_h) + new_w = bbx_w + else: + new_h = bbx_h + new_w = max(bbx_h * tgt_w / tgt_h, bbx_w) + assert new_h >= bbx_h and new_w >= bbx_w + + return to_numpy([new_w, new_h]) \ No newline at end of file diff --git a/lib/utils/camera.py b/lib/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..bea7824246dd7bc48eb26aafa53c573ad769409c --- /dev/null +++ b/lib/utils/camera.py @@ -0,0 +1,251 @@ +from lib.kits.basic import * + + +def T_to_Rt( + T : Union[torch.Tensor, np.ndarray], +): + ''' Get (..., 3, 3) rotation matrix and (..., 3) translation vector from (..., 4, 4) transformation matrix. ''' + if isinstance(T, np.ndarray): + T = torch.from_numpy(T).float() + assert T.shape[-2:] == (4, 4), f'T.shape[-2:] = {T.shape[-2:]}' + + R = T[..., :3, :3] + t = T[..., :3, 3] + + return R, t + + +def Rt_to_T( + R : Union[torch.Tensor, np.ndarray], + t : Union[torch.Tensor, np.ndarray], +): + ''' Get (..., 4, 4) transformation matrix from (..., 3, 3) rotation matrix and (..., 3) translation vector. ''' + if isinstance(R, np.ndarray): + R = torch.from_numpy(R).float() + if isinstance(t, np.ndarray): + t = torch.from_numpy(t).float() + assert R.shape[-2:] == (3, 3), f'R should be a (..., 3, 3) matrix, but R.shape = {R.shape}' + assert t.shape[-1] == 3, f't should be a (..., 3) vector, but t.shape = {t.shape}' + assert R.shape[:-2] == t.shape[:-1], f'R and t should have the same shape prefix but {R.shape[:-2]} != {t.shape[:-1]}' + + T = torch.eye(4, device=R.device, dtype=R.dtype).repeat(R.shape[:-2] + (1, 1)) # (..., 4, 4) + T[..., :3, :3] = R + T[..., :3, 3] = t + + return T + + +def apply_Ts_on_pts(Ts:torch.Tensor, pts:torch.Tensor): + ''' + Apply transformation matrix `T` on the points `pts`. + + ### Args + - Ts: torch.Tensor, (...B, 4, 4) + - pts: torch.Tensor, (...B, N, 3) + ''' + + assert len(pts.shape) >= 3 and pts.shape[-1] == 3, f'Shape of pts should be (...B, N, 3) but {pts.shape}' + assert Ts.shape[-2:] == (4, 4), f'Shape of Ts should be (..., 4, 4) but {Ts.shape}' + assert Ts.device == pts.device, f'Device of Ts and pts should be the same but {Ts.device} != {pts.device}' + + ret_pts = torch.einsum('...ij,...nj->...ni', Ts[..., :3, :3], pts) + Ts[..., None, :3, 3] + ret_pts = ret_pts.squeeze(0) # (B, N, 3) + + return ret_pts + + +def apply_T_on_pts(T:torch.Tensor, pts:torch.Tensor): + ''' + Apply transformation matrix `T` on the points `pts`. + + ### Args + - T: torch.Tensor, (4, 4) + - pts: torch.Tensor, (B, N, 3) or (N, 3) + ''' + unbatched = len(pts.shape) == 2 + if unbatched: + pts = pts[None] + ret = apply_Ts_on_pts(T[None], pts) + return ret.squeeze(0) if unbatched else ret + + +def apply_Ks_on_pts(Ks:torch.Tensor, pts:torch.Tensor): + ''' + Apply intrinsic camera matrix `K` on the points `pts`, i.e. project the 3D points to 2D. 
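+ Assuming a standard pinhole K, each point is mapped to [u, v] = [fx * x / z + cx, fy * y / z + cy].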
+ + ### Args + - Ks: torch.Tensor, (...B, 3, 3) + - pts: torch.Tensor, (...B, N, 3) + ''' + + assert len(pts.shape) >= 3 and pts.shape[-1] == 3, f'Shape of pts should be (...B, N, 3) but {pts.shape}' + assert Ks.shape[-2:] == (3, 3), f'Shape of Ks should be (..., 3, 3) but {Ks.shape}' + assert Ks.device == pts.device, f'Device of Ks and pts should be the same but {Ks.device} != {pts.device}' + + pts_proj_homo = torch.einsum('...ij,...vj->...vi', Ks, pts) + pts_proj = pts_proj_homo[..., :2] / pts_proj_homo[..., 2:3] + return pts_proj + + +def apply_K_on_pts(K:torch.Tensor, pts:torch.Tensor): + ''' + Apply intrinsic camera matrix `K` on the points `pts`, i.e. project the 3D points to 2D. + + ### Args + - K: torch.Tensor, (3, 3) + - pts: torch.Tensor, (B, N, 3) or (N, 3) + ''' + unbatched = len(pts.shape) == 2 + if unbatched: + pts = pts[None] + ret = apply_Ks_on_pts(K[None], pts) + return ret.squeeze(0) if unbatched else ret + + +def perspective_projection( + points : torch.Tensor, + translation : torch.Tensor, + focal_length : torch.Tensor, + camera_center : Optional[torch.Tensor] = None, + rotation : Optional[torch.Tensor] = None, +) -> torch.Tensor: + ''' + Computes the perspective projection of a set of 3D points. + https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/utils/geometry.py#L64-L102 + + ### Args + - points: torch.Tensor, (B, N, 3) + - The input 3D points. + - translation: torch.Tensor, (B, 3) + - The 3D camera translation. + - focal_length: torch.Tensor, (B, 2) + - The focal length in pixels. + - camera_center: torch.Tensor, (B, 2) + - The camera center in pixels. + - rotation: torch.Tensor, (B, 3, 3) + - The camera rotation. + + ### Returns + - torch.Tensor, (B, N, 2) + - The projection of the input points. + ''' + B = points.shape[0] + if rotation is None: + rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(B, -1, -1) + if camera_center is None: + camera_center = torch.zeros(B, 2, device=points.device, dtype=points.dtype) + # Populate intrinsic camera matrix K. + K = torch.zeros([B, 3, 3], device=points.device, dtype=points.dtype) + K[:, 0, 0] = focal_length[:, 0] + K[:, 1, 1] = focal_length[:, 1] + K[:, 2, 2] = 1. + K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij, bkj -> bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij, bkj -> bki', K, projected_points) + + return projected_points[:, :, :-1] + + +def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_size=224): + ''' + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Copied from: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/utils/geometry.py#L94-L132 + + ### Args + - S: shape = (25, 3) + - 3D joint locations. + - joints: shape = (25, 3) + - 2D joint locations and confidence. + ### Returns + - shape = (3,) + - Camera translation vector. 
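+    The solution is a confidence-weighted linear least-squares fit of the perspective projection equations.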
+ ''' + + num_joints = S.shape[0] + # focal length + f = np.array([focal_length,focal_length]) + # optical center + center = np.array([img_size/2., img_size/2.]) + + # transformations + Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1) + XY = np.reshape(S[:,0:2],-1) + O = np.tile(center,num_joints) + F = np.tile(f,num_joints) + weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1) + + # least squares + Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T + c = (np.reshape(joints_2d,-1)-O)*Z - F*XY + + # weighted least squares + W = np.diagflat(weight2) + Q = np.dot(W,Q) + c = np.dot(W,c) + + # square matrix + A = np.dot(Q.T,Q) + b = np.dot(Q.T,c) + + # solution + trans = np.linalg.solve(A, b) + + return trans + + +def estimate_camera_trans( + S : torch.Tensor, + joints_2d : torch.Tensor, + focal_length : float = 5000., + img_size : float = 224., + conf_thre : float = 4., +): + ''' + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Modified from: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/utils/geometry.py#L135-L157 + + ### Args + - S: torch.Tensor, shape = (B, J, 3) + - 3D joint locations. + - joints: torch.Tensor, shape = (B, J, 3) + - Ground truth 2D joint locations and confidence. + - focal_length: float + - img_size: float + - conf_thre: float + - Confidence threshold to judge whether we use gt_kp2d or that from OpenPose. + + ### Returns + - torch.Tensor, shape = (B, 3) + - Camera translation vectors. + ''' + device = S.device + B = len(S) + + S = to_numpy(S) + joints_2d = to_numpy(joints_2d) + joints_conf = joints_2d[:, :, -1] # (B, J) + joints_2d = joints_2d[:, :, :-1] # (B, J, 2) + trans = np.zeros((S.shape[0], 3), dtype=np.float32) + # Find the translation for each example in the batch + for i in range(B): + conf_i = joints_conf[i] + # When the ground truth joints are not enough, use all the joints. + if conf_i[25:].sum() < conf_thre: + S_i = S[i] + joints_i = joints_2d[i] + else: + S_i = S[i, 25:] + conf_i = joints_conf[i, 25:] + joints_i = joints_2d[i, 25:] + + + trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size) + return torch.from_numpy(trans).to(device) \ No newline at end of file diff --git a/lib/utils/ckpt.py b/lib/utils/ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..bcf341f7b30a0f5ee9af44ffbd7ca0f2377aad34 --- /dev/null +++ b/lib/utils/ckpt.py @@ -0,0 +1,75 @@ +from typing import List, Dict + +def replace_state_dict_name_prefix(state_dict:Dict[str, object], old_prefix:str, new_prefix:str): + ''' Replace the prefix of the keys in the state_dict. ''' + for old_name in list(state_dict.keys()): + if old_name.startswith(old_prefix): + new_name = new_prefix + old_name[len(old_prefix):] + state_dict[new_name] = state_dict.pop(old_name) + + return state_dict + + +def match_prefix_and_remove_state_dict(state_dict:Dict[str, object], prefix:str): + ''' Remove the keys in the state_dict that start with the prefix. 
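+    Matching entries are popped in place and the (mutated) state_dict is returned.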
''' + for name in list(state_dict.keys()): + if name.startswith(prefix): + state_dict.pop(name) + return state_dict + + +class StateDictTree: + def __init__(self, keys:List[str]): + self.tree = {} + for key in keys: + parts = key.split('.') + self._recursively_add_leaf(self.tree, parts, key) + + + def rich_print(self, depth:int=-1): + from rich.tree import Tree + from rich import print + rich_tree = Tree('.') + self._recursively_build_rich_tree(rich_tree, self.tree, 0, depth) + print(rich_tree) + + def update_node_name(self, old_name:str, new_name:str): + ''' Input full node name and the whole node will be moved to the new name. ''' + old_parts = old_name.split('.') + # 1. Delete the old node. + try: + parent = None + node = self.tree + for part in old_parts: + parent = node + node = node[part] + parent.pop(old_parts[-1]) + except KeyError: + raise KeyError(f'Key {old_name} not found.') + # 2. Add the new node. + new_parts = new_name.split('.') + self._recursively_add_leaf(self.tree, new_parts, new_name) + + + def _recursively_add_leaf(self, node, parts, full_key): + cur_part, rest_parts = parts[0], parts[1:] + if len(rest_parts) == 0: + assert cur_part not in node, f'Key {full_key} already exists.' + node[cur_part] = full_key + else: + if cur_part not in node: + node[cur_part] = {} + self._recursively_add_leaf(node[cur_part], rest_parts, full_key) + + + def _recursively_build_rich_tree(self, rich_node, dict_node, depth, max_depth:int=-1): + if max_depth > 0 and depth >= max_depth: + rich_node.add(f'... {len(dict_node)} more') + return + + keys = sorted(dict_node.keys()) + for key in keys: + next_dict_node = dict_node[key] + next_rich_node = rich_node.add(key) + if isinstance(next_dict_node, Dict): + self._recursively_build_rich_tree(next_rich_node, next_dict_node, depth+1, max_depth) diff --git a/lib/utils/data/__init__.py b/lib/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc1ca57b54f8b17c26b58de681b5eefea345cc3 --- /dev/null +++ b/lib/utils/data/__init__.py @@ -0,0 +1,7 @@ +# Utils here are used to manipulate basic data structures, such as dict, list, tensor, etc. +# It has nothing to do with Machine Learning concepts (like dataset, dataloader, etc.). + +from .types import * +from .dict import * +from .mem import * +from .io import * \ No newline at end of file diff --git a/lib/utils/data/dict.py b/lib/utils/data/dict.py new file mode 100644 index 0000000000000000000000000000000000000000..95fc880ce9dc9464a26d2e7a67fe9a650b4d749b --- /dev/null +++ b/lib/utils/data/dict.py @@ -0,0 +1,79 @@ +import torch +import numpy as np +from typing import Dict, List + + +def disassemble_dict(d, keep_dim=False): + ''' + Unpack a dictionary into a list of dictionaries. The values should be in the same length. + If not keep dim: {k: [...] * N} -> [{k: [...]}] * N. + If keep dim: {k: [...] * N} -> [{k: [[...]]}] * N. + ''' + Ls = [len(v) for v in d.values()] + assert len(set(Ls)) == 1, 'The lengths of the values should be the same!' + + N = Ls[0] + if keep_dim: + return [{k: v[[i]] for k, v in d.items()} for i in range(N)] + else: + return [{k: v[i] for k, v in d.items()} for i in range(N)] + + +def assemble_dict(d, expand_dim=False, keys=None): + ''' + Pack a list of dictionaries into one dictionary. + If expand dim, perform stack, else, perform concat. 
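+    Eg. assemble_dict([{'a': np.zeros(2)}, {'a': np.zeros(2)}]) -> {'a': shape (4,)} by default, or {'a': shape (2, 2)} with expand_dim=True.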
+ ''' + keys = list(d[0].keys()) if keys is None else keys + if isinstance(d[0][keys[0]], np.ndarray): + if expand_dim: + return {k: np.stack([v[k] for v in d], axis=0) for k in keys} + else: + return {k: np.concatenate([v[k] for v in d], axis=0) for k in keys} + elif isinstance(d[0][keys[0]], torch.Tensor): + if expand_dim: + return {k: torch.stack([v[k] for v in d], dim=0) for k in keys} + else: + return {k: torch.cat([v[k] for v in d], dim=0) for k in keys} + + +def filter_dict(d:Dict, keys:List, full:bool=False, strict:bool=False): + ''' + Use path-like syntax to filter the embedded dictionary. + The `'*'` string is regarded as a wildcard, and will return the matched keys. + For control flags: + - If `full`, return the full path, otherwise, only return the matched values. + - If `strict`, raise error if the key is not found, otherwise, simply ignore. + + Eg. + - `x = {'fruit': {'yellow': 'banana', 'red': 'apple'}, 'recycle': {'yellow': 'trash', 'blue': 'recyclable'}}` + - `filter_dict(x, ['*', 'yellow'])` -> `{'fruit': 'banana', 'recycle': 'trash'}` + - `filter_dict(x, ['*', 'yellow'], full=True)` -> `{'fruit': {'yellow': 'banana'}, 'recycle': {'yellow': 'trash'}}` + - `filter_dict(x, ['*', 'blue'])` -> `{'recycle': 'recyclable'}` + - `filter_dict(x, ['*', 'blue'], strict=True)` -> `KeyError: 'blue'` + ''' + + ret = {} + if keys: + cur_key, rest_keys = keys[0], keys[1:] + if cur_key == '*': + for match in d.keys(): + try: + res = filter_dict(d[match], rest_keys, full=full, strict=strict) + if res: + ret[match] = res + except Exception as e: + if strict: + raise e + else: + try: + res = filter_dict(d[cur_key], rest_keys, full=full, strict=strict) + if res: + ret = { cur_key : res } if full else res + except Exception as e: + if strict: + raise e + else: + ret = d + + return ret \ No newline at end of file diff --git a/lib/utils/data/io.py b/lib/utils/data/io.py new file mode 100644 index 0000000000000000000000000000000000000000..830f2148c538daaf4eeebdc0c75d97547799f1bb --- /dev/null +++ b/lib/utils/data/io.py @@ -0,0 +1,11 @@ + +import pickle +from pathlib import Path + + +def load_pickle(fn, mode='rb', encoding=None, pickle_encoding='ASCII'): + if isinstance(fn, Path): + fn = str(fn) + with open(fn, mode=mode, encoding=encoding) as f: + data = pickle.load(f, encoding=pickle_encoding) + return data \ No newline at end of file diff --git a/lib/utils/data/mem.py b/lib/utils/data/mem.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d21854790d9f0d330733ca56473877207d1c6c --- /dev/null +++ b/lib/utils/data/mem.py @@ -0,0 +1,14 @@ +import torch +from typing import List, Dict, Tuple + + +def recursive_detach(x): + if isinstance(x, torch.Tensor): + return x.detach() + elif isinstance(x, Dict): + return {k: recursive_detach(v) for k, v in x.items()} + elif isinstance(x, (List, Tuple)): + return [recursive_detach(v) for v in x] + else: + return x + diff --git a/lib/utils/data/types.py b/lib/utils/data/types.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf92fdcead22c14fc83b174c99eb2a7b83db9c9 --- /dev/null +++ b/lib/utils/data/types.py @@ -0,0 +1,82 @@ +import torch +import numpy as np + +from typing import Any, List +from omegaconf import ListConfig + +def to_numpy(x, temporary:bool=False) -> Any: + if isinstance(x, torch.Tensor): + if temporary: + recover_type_back = lambda x_: torch.from_numpy(x_).type_as(x).to(x.device) + return x.detach().cpu().numpy(), recover_type_back + else: + return x.detach().cpu().numpy() + if isinstance(x, 
np.ndarray): + if temporary: + recover_type_back = lambda x_: x_ + return x.copy(), recover_type_back + else: + return x + if isinstance(x, List): + if temporary: + recover_type_back = lambda x_: x_.tolist() + return np.array(x), recover_type_back + else: + return np.array(x) + raise ValueError(f"Unsupported type: {type(x)}") + +def to_tensor(x, device, temporary:bool=False) -> Any: + ''' + Simply unify the type transformation to torch.Tensor. + If device is None, don't change the device if device is not CPU. + ''' + if isinstance(x, torch.Tensor): + device = x.device if device is None else device + if temporary: + recover_type_back = lambda x_: x_.to(x.device) # recover the device + return x.to(device), recover_type_back + else: + return x.to(device) + + device = 'cpu' if device is None else device + if isinstance(x, np.ndarray): + if temporary: + recover_type_back = lambda x_: x_.detach().cpu().numpy() + return torch.from_numpy(x).to(device), recover_type_back + else: + return torch.from_numpy(x).to(device) + if isinstance(x, List): + if temporary: + recover_type_back = lambda x_: x_.tolist() + return torch.from_numpy(np.array(x)).to(device), recover_type_back + else: + return torch.from_numpy(np.array(x)).to(device) + raise ValueError(f"Unsupported type: {type(x)}") + + +def to_list(x, temporary:bool=False) -> Any: + if isinstance(x, List): + if temporary: + recover_type_back = lambda x_: x_ + return x.copy(), recover_type_back + else: + return x + if isinstance(x, torch.Tensor): + if temporary: + recover_type_back = lambda x_: torch.tensor(x_, device=x.device, dtype=x.dtype) + return x.detach().cpu().numpy().tolist(), recover_type_back + else: + return x.detach().cpu().numpy().tolist() + if isinstance(x, np.ndarray): + if temporary: + recover_type_back = lambda x_: np.array(x_) + return x.tolist(), recover_type_back + else: + return x.tolist() + if isinstance(x, ListConfig): + if temporary: + recover_type_back = lambda x_: ListConfig(x_) + return list(x), recover_type_back + else: + return list(x) + raise ValueError(f"Unsupported type: {type(x)}") \ No newline at end of file diff --git a/lib/utils/device.py b/lib/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0a22f8812dd5fcd4cdba7f79b31e03a8ec417f --- /dev/null +++ b/lib/utils/device.py @@ -0,0 +1,24 @@ +import torch +from typing import Any, Dict, List + + +def recursive_to(x: Any, target: torch.device): + ''' + Recursively transfer data to the target device. + Modified from: https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/utils/__init__.py#L9-L25 + + ### Args + - x: Any + - target: torch.device + + ### Returns + - Data transferred to the target device. + ''' + if isinstance(x, Dict): + return {k: recursive_to(v, target) for k, v in x.items()} + elif isinstance(x, torch.Tensor): + return x.to(target) + elif isinstance(x, List): + return [recursive_to(i, target) for i in x] + else: + return x diff --git a/lib/utils/geometry/rotation.py b/lib/utils/geometry/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..de8b95831c6d1c8d556f400ffdb7001b35190245 --- /dev/null +++ b/lib/utils/geometry/rotation.py @@ -0,0 +1,517 @@ +# Edited by Yan Xia +# Modified from https://github.com/facebookresearch/pytorch3d/tree/main/pytorch3d/transforms +# +# # Copyright (c) Meta Platforms, Inc. and affiliates. +# # All rights reserved. 
+# # +# # This source code is licensed under the BSD-style license found in the +# # LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + + +''' +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +''' + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + ''' + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + ''' + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + ''' + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + ''' + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + ''' + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + ''' + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. 
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + return standardize_quaternion(out) + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + ''' + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + ''' + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + ''' + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + ''' + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + ''' + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). 
+ ''' + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + ''' + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + ''' + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + ''' + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + ''' + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + ''' + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + ''' + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + ''' + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). 
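+        The raw product is standardized so that the returned quaternion has a non-negative real part.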
+ ''' + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor: + ''' + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + ''' + + scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device) + return quaternion * scaling + + +def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor: + ''' + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + ''' + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, {point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + ''' + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + ''' + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: + ''' + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + ''' + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + ''' + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + ''' + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: + ''' + Convert rotations given as quaternions to axis/angle. 
+ + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + ''' + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + ''' + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + ''' + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + ''' + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + ''' + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) diff --git a/lib/utils/geometry/volume.py b/lib/utils/geometry/volume.py new file mode 100644 index 0000000000000000000000000000000000000000..dafbd03bbca636a8a5c9bdf34039060669fec088 --- /dev/null +++ b/lib/utils/geometry/volume.py @@ -0,0 +1,43 @@ +from lib.kits.basic import * + + +def compute_mesh_volume( + verts: Union[torch.Tensor, np.ndarray], + faces: Union[torch.Tensor, np.ndarray], +) -> torch.Tensor: + ''' + Computes the volume of a mesh object through triangles. + References: + 1. https://github.com/muelea/shapy/blob/a5daa70ce619cbd2a218cebbe63ae3a4c0b771fd/mesh-mesh-intersection/body_measurements/body_measurements.py#L201-L215 + 2. https://stackoverflow.com/questions/1406029/how-to-calculate-the-volume-of-a-3d-mesh-object-the-surface-of-which-is-made-up + + ### Args + - verts: `torch.Tensor` or `np.ndarray`, shape = ((...B,) V, C=3) + - faces: `torch.Tensor` or `np.ndarray`, shape = (T, K=3) where T = #triangles + + ### Returns + - volume: `torch.Tensor`, shape = (...B,) or (,), in m^3. + ''' + faces = to_numpy(faces) + + # Get triangles' xyz. 
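+    # Each triangle, together with the origin, spans a signed tetrahedron; the mesh volume is the absolute value of the summed signed tetrahedron volumes.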
+ batch_shape = verts.shape[:-2] + V = verts.shape[-2] + verts = verts.reshape(-1, V, 3) # (B', V, C=3) + + tris = verts[:, faces] # (B', T, K=3, C=3) + tris = tris.reshape(*batch_shape, -1, 3, 3) # (..., T, K=3, C=3) + + x = tris[..., 0] # (..., T, K=3) + y = tris[..., 1] # (..., T, K=3) + z = tris[..., 2] # (..., T, K=3) + + volume = ( + -x[..., 2] * y[..., 1] * z[..., 0] + + x[..., 1] * y[..., 2] * z[..., 0] + + x[..., 2] * y[..., 0] * z[..., 1] - + x[..., 0] * y[..., 2] * z[..., 1] - + x[..., 1] * y[..., 0] * z[..., 2] + + x[..., 0] * y[..., 1] * z[..., 2] + ).sum(dim=-1).abs() / 6.0 + return volume \ No newline at end of file diff --git a/lib/utils/media/__init__.py b/lib/utils/media/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40bafce9930369184b2ffe40cb551d9187de7bb7 --- /dev/null +++ b/lib/utils/media/__init__.py @@ -0,0 +1,3 @@ +from .io import * +from .edit import * +from .draw import * \ No newline at end of file diff --git a/lib/utils/media/draw.py b/lib/utils/media/draw.py new file mode 100644 index 0000000000000000000000000000000000000000..b8568609c648e9f62065bde86e341c629fc96ab7 --- /dev/null +++ b/lib/utils/media/draw.py @@ -0,0 +1,245 @@ +from lib.kits.basic import * + +import cv2 +import imageio + +from lib.utils.data import to_numpy +from lib.utils.vis import ColorPalette + + +def annotate_img( + img : np.ndarray, + text : str, + pos : Union[str, Tuple[int, int]] = 'bl', +): + ''' + Annotate the image with the given text. + + ### Args + - img: np.ndarray, (H, W, 3) + - text: str + - pos: str or tuple(int, int), default 'bl' + - If str, one of ['tl', 'bl']. + - If tuple, the position of the text. + + ### Returns + - np.ndarray, (H, W, 3) + - The annotated image. + ''' + assert len(img.shape) == 3, 'img must have 3 dimensions.' + return annotate_video(frames=img[None], text=text, pos=pos)[0] + + +def annotate_video( + frames : np.ndarray, + text : str, + pos : Union[str, Tuple[int, int]] = 'bl', + alpha : float = 0.75, +): + ''' + Annotate the video frames with the given text. + + ### Args + - frames: np.ndarray, (L, H, W, 3) + - text: str + - pos: str or tuple(int, int), default 'bl' + - If str, one of ['tl', 'bl']. + - If tuple, the position of the text. + - alpha: float, default 0.5 + - The transparency of the text. + + ### Returns + - np.ndarray, (L, H, W, 3) + - The annotated video. + ''' + assert len(frames.shape) == 4, 'frames must have 4 dimensions.' + frames = frames.copy() + L, H, W = frames.shape[:3] + + if isinstance(pos, str): + if pos == 'tl': + offset = (int(0.1 * W), int(0.1 * H)) + elif pos == 'bl': + offset = (int(0.1 * W), int(0.9 * H)) + else: + raise ValueError(f'Invalid position: {pos}') + else: + offset = pos + + for i, frame in enumerate(frames): + overlay = frame.copy() + _put_text(overlay, text, offset) + frames[i] = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0) + + return frames + + +def draw_bbx_on_img( + img : np.ndarray, + lurb : np.ndarray, + color : str = 'red', +): + ''' + Draw the bounding box on the image. + + ### Args + - img: np.ndarray, (H, W, 3) + - lurb: np.ndarray, (4,) + - The bounding box in the format of left, up, right, bottom. + - color: str, default 'red' + + ### Returns + - np.ndarray, (H, W, 3) + - The image with the bounding box. + ''' + assert len(img.shape) == 3, 'img must have 3 dimensions.' 
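+    # Draw on a copy so the caller's image buffer is left untouched.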
+ + img = img.copy() + l, u, r, b = lurb.astype(int) + color_rgb_int8 = ColorPalette.presets_int8[color] + cv2.rectangle(img, (l, u), (r, b), color_rgb_int8, 3) + + return img + + +def draw_bbx_on_video( + frames : np.ndarray, + lurbs : np.ndarray, + color : str = 'red', +): + ''' + Draw the bounding box on the video frames. + + ### Args + - frames: np.ndarray, (L, H, W, 3) + - lurbs: np.ndarray, (L, 4,) + - The bounding box in the format of left, up, right, bottom. + - color: str, default 'red' + + ### Returns + - np.ndarray, (L, H, W, 3) + - The video with the bounding box. + ''' + assert len(frames.shape) == 4, 'frames must have 4 dimensions.' + frames = frames.copy() + + for i, frame in enumerate(frames): + frames[i] = draw_bbx_on_img(frame, lurbs[i], color) + + return frames + + +def draw_kp2d_on_img( + img : np.ndarray, + kp2d : Union[np.ndarray, torch.Tensor], + links : list = [], + link_colors : list = [], + show_conf : bool = False, + show_idx : bool = False, +): + ''' + Draw the 2d keypoints (and connection lines if exists) on the image. + + ### Args + - img: np.ndarray, (H, W, 3) + - The image. + - kp2d: np.ndarray or torch.Tensor, (N, 2) or (N, 3) + - The 2d keypoints without/with confidence. + - links: list of [int, int] or (int, int), default [] + - The connections between keypoints. Each element is a tuple of two indices. + - If empty, only keypoints will be drawn. + - link_colors: list of [int, int, int] or (int, int, int), default [] + - The colors of the connections. + - If empty, the connections will be drawn in white. + - show_conf: bool, default False + - Whether to show the confidence of keypoints. + - show_idx: bool, default False + - Whether to show the index of keypoints. + + ### Returns + - img: np.ndarray, (H, W, 3) + - The image with skeleton. + ''' + img = img.copy() + kp2d = to_numpy(kp2d) # (N, 2) or (N, 3) + assert len(img.shape) == 3, f'`img`\'s shape should be (H, W, 3) but got {img.shape}' + assert len(kp2d.shape) == 2, f'`kp2d`\'s shape should be (N, 2) or (N, 3) but got {kp2d.shape}' + + if kp2d.shape[1] == 2: + kp2d = np.concatenate([kp2d, np.ones((kp2d.shape[0], 1))], axis=-1) # (N, 3) + + kp_has_drawn = [False] * kp2d.shape[0] + # Draw connections. + for lid, link in enumerate(links): + # Skip links related to impossible keypoints. + if kp2d[link[0], 2] < 0.5 or kp2d[link[1], 2] < 0.5: + continue + + pt1 = tuple(kp2d[link[0], :2].astype(int)) + pt2 = tuple(kp2d[link[1], :2].astype(int)) + color = (255, 255, 255) if len(link_colors) == 0 else tuple(link_colors[lid]) + cv2.line(img, pt1, pt2, color, 2) + if not kp_has_drawn[link[0]]: + cv2.circle(img, pt1, 3, color, -1) + if not kp_has_drawn[link[1]]: + cv2.circle(img, pt2, 3, color, -1) + kp_has_drawn[link[0]] = kp_has_drawn[link[1]] = True + + # Draw keypoints and annotate the confidence. 
+ for i, kp in enumerate(kp2d): + conf = kp[2] + pos = tuple(kp[:2].astype(int)) + + if not kp_has_drawn[i]: + cv2.circle(img, pos, 4, (255, 255, 255), -1) + cv2.circle(img, pos, 2, ( 0, 255, 0), -1) + kp_has_drawn[i] = True + + if show_conf: + _put_text(img, f'{conf:.2f}', pos) + if show_idx: + if i >= 40: + continue + _put_text(img, f'{i}', pos, scale=0.03) + + return img + + +# ====== Internal Utils ====== + +def _put_text( + img : np.ndarray, + text : str, + pos : Tuple[int, int], + scale : float = 0.05, + color_inside : Tuple[int, int, int] = ColorPalette.presets_int8['black'], + color_stroke : Tuple[int, int, int] = ColorPalette.presets_int8['white'], + **kwargs +): + fontFace = cv2.FONT_HERSHEY_SIMPLEX + if 'fontFace' in kwargs: + fontFace = kwargs['fontFace'] + kwargs.pop('fontFace') + + H, W = img.shape[:2] + # https://stackoverflow.com/a/55772676/22331129 + font_scale = scale * min(H, W) / 25 * 1.5 + thickness_inside = max(int(font_scale), 1) + thickness_stroke = max(int(font_scale * 6), 6) + + # Deal with the multi-line text. + + ((fw, fh), baseline) = cv2.getTextSize( + text = text, + fontFace = fontFace, + fontScale = font_scale, + thickness = thickness_stroke, + ) # https://stackoverflow.com/questions/73664883/opencv-python-draw-text-with-fontsize-in-pixels + + lines = text.split('\n') + line_height = baseline + fh + + for i, line in enumerate(lines): + pos_ = (pos[0], pos[1] + line_height * i) + cv2.putText(img, line, pos_, fontFace, font_scale, color_stroke, thickness_stroke) + cv2.putText(img, line, pos_, fontFace, font_scale, color_inside, thickness_inside) \ No newline at end of file diff --git a/lib/utils/media/edit.py b/lib/utils/media/edit.py new file mode 100644 index 0000000000000000000000000000000000000000..b64f8ca39d04aa7740951813e9aa20fb8f5a4e3f --- /dev/null +++ b/lib/utils/media/edit.py @@ -0,0 +1,288 @@ +import cv2 +import imageio +import numpy as np + +from typing import Union, Tuple, List +from pathlib import Path + + +def flex_resize_img( + img : np.ndarray, + tgt_wh : Union[Tuple[int, int], None] = None, + ratio : Union[float, None] = None, + kp_mod : int = 1, +): + ''' + Resize the image to the target width and height. Set one of width and height to -1 to keep the aspect ratio. + Only one of `tgt_wh` and `ratio` can be set, if both are set, `tgt_wh` will be used. + + ### Args + - img: np.ndarray, (H, W, 3) + - tgt_wh: Tuple[int, int], default=None + - The target width and height, set one of them to -1 to keep the aspect ratio. + - ratio: float, default=None + - The ratio to resize the frames. It will be used if `tgt_wh` is not set. + - kp_mod: int, default 1 + - Keep the width and height as multiples of `kp_mod`. + - For example, if `kp_mod=16`, the width and height will be rounded to the nearest multiple of 16. + + ### Returns + - np.ndarray, (H', W', 3) + - The resized iamges. + ''' + assert len(img.shape) == 3, 'img must have 3 dimensions.' + return flex_resize_video(img[None], tgt_wh, ratio, kp_mod)[0] + + +def flex_resize_video( + frames : np.ndarray, + tgt_wh : Union[Tuple[int, int], None] = None, + ratio : Union[float, None] = None, + kp_mod : int = 1, +): + ''' + Resize the frames to the target width and height. Set one of width and height to -1 to keep the aspect ratio. + Only one of `tgt_wh` and `ratio` can be set, if both are set, `tgt_wh` will be used. + + ### Args + - frames: np.ndarray, (L, H, W, 3) + - tgt_wh: Tuple[int, int], default=None + - The target width and height, set one of them to -1 to keep the aspect ratio. 
+ - ratio: float, default=None + - The ratio to resize the frames. It will be used if `tgt_wh` is not set. + - kp_mod: int, default 1 + - Keep the width and height as multiples of `kp_mod`. + - For example, if `kp_mod=16`, the width and height will be rounded to the nearest multiple of 16. + + ### Returns + - np.ndarray, (L, H', W', 3) + - The resized frames. + ''' + assert tgt_wh is not None or ratio is not None, 'At least one of tgt_wh and ratio must be set.' + if tgt_wh is not None: + assert len(tgt_wh) == 2, 'tgt_wh must be a tuple of 2 elements.' + assert tgt_wh[0] > 0 or tgt_wh[1] > 0, 'At least one of width and height must be positive.' + if ratio is not None: + assert ratio > 0, 'ratio must be positive.' + assert len(frames.shape) == 4, 'frames must have 3 or 4 dimensions.' + + def align_size(val:float): + ''' It will round the value to the nearest multiple of `kp_mod`. ''' + return int(round(val / kp_mod) * kp_mod) + + # Calculate the target width and height. + orig_h, orig_w = frames.shape[1], frames.shape[2] + tgt_wh = (int(orig_w * ratio), int(orig_h * ratio)) if tgt_wh is None else tgt_wh # Get wh from ratio if not given. # type: ignore + tgt_w, tgt_h = tgt_wh + tgt_w = align_size(orig_w * tgt_h / orig_h) if tgt_w == -1 else align_size(tgt_w) + tgt_h = align_size(orig_h * tgt_w / orig_w) if tgt_h == -1 else align_size(tgt_h) + # Resize the frames. + resized_frames = np.stack([cv2.resize(frame, (tgt_w, tgt_h)) for frame in frames]) + + return resized_frames + + +def splice_img( + img_grids : Union[List[np.ndarray], np.ndarray], + grid_ids : Union[List[int], np.ndarray], +): + ''' + Splice the images with the same size, according to the grid index. + For example, you have 3 images [i1, i2, i3], and a `grid_ids` matrix: + [[ 0, 1], |i1|i2| + [ 2, -1], , then the results will be |i3|ib| , where ib means a black place holder. + [-1, -1]] |ib|ib| + + ### Args + - img_grids: List[np.ndarray] or np.ndarray, (K, H, W, 3) + - The source images to splice. It indicates that all the images have the same size. + - grid_ids: List[int] or np.ndarray, (Y, X) + - The grid index of each image. It should be a 2D matrix with integers as the type of elements. + - The value in this matrix indexed the image in the `video_grids`, so it ranges from 0 to K-1. + - Specially, set the grid index to -1 to use a black place holder. + + ### Returns + - np.ndarray, (H*Y, W*X, 3) + - The spliced images. + ''' + if isinstance(img_grids, List): + img_grids = np.stack(img_grids) + if isinstance(grid_ids, List): + grid_ids = np.array(grid_ids) + + assert len(img_grids.shape) == 4, 'img_grids must be in shape (K, H, W, 3).' + return splice_video(img_grids[:, None], grid_ids)[0] + + +def splice_video( + video_grids : Union[List[np.ndarray], np.ndarray], + grid_ids : Union[List[int], np.ndarray], +): + ''' + Splice the videos with the same size, according to the grid index. + For example, you have 3 videos [v1, v2, v3], and a `grid_ids` matrix: + [[ 0, 1], |v1|v2| + [ 2, -1], , then the results will be |v3|vb| , wher vb means a black place holder. + [-1, -1]] |vb|vb| + + ### Args + - video_grids: List[np.ndarray] or np.ndarray, (K, L, H, W, C) + - The source videos to splice. It indicates that all the videos have the same size. + - grid_ids: List[int] or np.ndarray, (Y, X) + - The grid index of each video. It should be a 2D matrix with integers as the type of elements. + - The value in this matrix indexed the video in the `video_grids`, so it ranges from 0 to K-1. 
+ - Specially, set the grid index to -1 to use a black place holder. + + ### Returns + - np.ndarray, (L, H*Y, W*X, C) + - The spliced video. + ''' + if isinstance(video_grids, List): + video_grids = np.stack(video_grids) + if isinstance(grid_ids, List): + grid_ids = np.array(grid_ids) + + assert len(video_grids.shape) == 5, 'video_grids must be in shape (K, L, H, W, 3).' + assert len(grid_ids.shape) == 2, 'grid_ids must be a 2D matrix.' + assert isinstance(grid_ids[0, 0].item(), int), f'grid_ids must be an integer matrix, but got {grid_ids.dtype}.' + + # Splice the videos. + K, L, H, W, C = video_grids.shape + Y, X = grid_ids.shape + + # Initialize the spliced video. + spliced_video = np.zeros((L, H*Y, W*X, C), dtype=np.uint8) + for x in range(X): + for y in range(Y): + grid_id = grid_ids[y, x] + if grid_id == -1: + continue + spliced_video[:, y*H:(y+1)*H, x*W:(x+1)*W, :] = video_grids[grid_id] + + return spliced_video + + +def crop_img( + img : np.ndarray, + lurb : Union[np.ndarray, List], +): + ''' + Crop the image with the given bounding box. + The data should be represented in uint8. + If the bounding box is out of the image, pad the image with zeros. + + ### Args + - img: np.ndarray, (H, W, C) + - lurb: np.ndarray or list, (4,) + - The bounding box in the format of left, up, right, bottom. + + ### Returns + - np.ndarray, (H', W', C) + - The cropped image. + ''' + + return crop_video(img[None], lurb)[0] + + +def crop_video( + frames : np.ndarray, + lurb : Union[np.ndarray, List], +): + ''' + Crop the video with the given bounding box. + The data should be represented in uint8. + If the bounding box is out of the video, pad the frames with zeros. + + ### Args + - frames: np.ndarray, (L, H, W, C) + - lurb: np.ndarray or list, (4,) + - The bounding box in the format of left, up, right, bottom. + + ### Returns + - np.ndarray, (L, H', W', C) + - The cropped video. + ''' + assert len(frames.shape) == 4, 'framess must have 4 dimensions.' + if isinstance(lurb, List): + lurb = np.array(lurb) + + l, u, r, b = lurb.astype(int) + L, H, W = frames.shape[:3] + l_, u_, r_, b_ = max(0, l), max(0, u), min(W, r), min(H, b) + cropped_frames = np.zeros((L, b-u, r-l, 3), dtype=np.uint8) + cropped_frames[:, u_-u:b_-u, l_-l:r_-l] = frames[:, u_:b, l_:r] + + return cropped_frames + +def pad_img( + img : np.ndarray, + tgt_wh : Tuple[int, int], + pad_val : int = 0, + align : str = 'c-c', +): + ''' + Pad the image to the target width and height. + + ### Args + - img: np.ndarray, (H, W, 3) + - tgt_wh: Tuple[int, int] + - The target width and height. Use -1 to indicate the original scale. + - pad_value: int, default 0 + - The value to pad the image. + - align: str, default 'c-c' + - The alignment of the image. It should be in the format of 'h-v', + where 'h' and 'v' can be 'l', 'c', 'r' and 't', 'c', 'b' respectively. + + ### Returns + - np.ndarray, (H', W', 3) + - The padded image. + ''' + assert len(img.shape) == 3, 'img must have 3 dimensions.' + return pad_video(img[None], tgt_wh, pad_val, align)[0] + +def pad_video( + frames : np.ndarray, + tgt_wh : Tuple[int, int], + pad_val : int = 0, + align : str = 'c-c', +): + ''' + Pad the video to the target width and height. + + ### Args + - frames: np.ndarray, (L, H, W, 3) + - tgt_wh: Tuple[int, int] + - The target width and height. Use -1 to indicate the original scale. + - pad_value: int, default 0 + - The value to pad the frames. + + ### Returns + - np.ndarray, (L, H', W', 3) + - The padded frames. + ''' + # Check data validity. 
+ assert len(frames.shape) == 4, 'frames must have 4 dimensions.' + assert len(tgt_wh) == 2, 'tgt_wh must be a tuple of 2 elements.' + H, W = frames.shape[1], frames.shape[2] + if tgt_wh[0] == -1: tgt_wh = (W, tgt_wh[1]) + if tgt_wh[1] == -1: tgt_wh = (tgt_wh[0], H) + assert tgt_wh[0] >= frames.shape[2] and tgt_wh[1] >= frames.shape[1], 'The target size must be larger than the original size.' + assert pad_val >= 0 and pad_val <= 255, 'The pad value must be in the range of [0, 255].' + # Check align pattern. + align = align.split('-') + assert len(align) == 2, 'align must be in the format of "h-v".' + assert align[0] in ['l', 'c', 'r'] and align[1] in ['l', 'c', 'r'], 'align must be in ["l", "c", "r"].' + + tgt_w, tgt_h = tgt_wh + pad_pix = [tgt_w - W, tgt_h - H] # indicate how many pixels to be padded + pad_lu = [0, 0] # how many pixels to pad on the left and the up side + for direction in [0, 1]: + if align[direction] == 'c': + pad_lu[direction] = pad_pix[direction] // 2 + elif align[direction] == 'r': + pad_lu[direction] = pad_pix[direction] + pad_l, pad_r, pad_u, pad_b = pad_lu[0], pad_pix[0] - pad_lu[0], pad_lu[1], pad_pix[1] - pad_lu[1] + + padded_frames = np.pad(frames, ((0, 0), (pad_u, pad_b), (pad_l, pad_r), (0, 0)), 'constant', constant_values=pad_val) + + return padded_frames \ No newline at end of file diff --git a/lib/utils/media/io.py b/lib/utils/media/io.py new file mode 100644 index 0000000000000000000000000000000000000000..d2cbe9e21c33c96034c4acc37427360209f33e79 --- /dev/null +++ b/lib/utils/media/io.py @@ -0,0 +1,107 @@ +import imageio +import numpy as np + +from tqdm import tqdm +from typing import Union, List +from pathlib import Path +from glob import glob + +from .edit import flex_resize_img, flex_resize_video + + +def load_img_meta( + img_path : Union[str, Path], +): + ''' Read the image meta from the given path without opening image. ''' + assert Path(img_path).exists(), f'Image not found: {img_path}' + H, W = imageio.v3.improps(img_path).shape[:2] + meta = {'w': W, 'h': H} + return meta + + +def load_img( + img_path : Union[str, Path], + mode : str = 'RGB', +): + ''' Read the image from the given path. ''' + assert Path(img_path).exists(), f'Image not found: {img_path}' + + img = imageio.v3.imread(img_path, plugin='pillow', mode=mode) + + meta = { + 'w': img.shape[1], + 'h': img.shape[0], + } + return img, meta + + +def save_img( + img : np.ndarray, + output_path : Union[str, Path], + resize_ratio : Union[float, None] = None, + **kwargs, +): + ''' Save the image. ''' + assert img.ndim == 3, f'Invalid image shape: {img.shape}' + + if resize_ratio is not None: + img = flex_resize_img(img, ratio=resize_ratio) + + imageio.v3.imwrite(output_path, img, **kwargs) + + +def load_video( + video_path : Union[str, Path], +): + ''' Read the video from the given path. ''' + if isinstance(video_path, str): + video_path = Path(video_path) + + assert video_path.exists(), f'Video not found: {video_path}' + + if video_path.is_dir(): + print(f'Found {video_path} is a directory. It will be regarded as a image folder.') + imgs_path = sorted(glob(str(video_path / '*'))) + frames = [] + for img_path in tqdm(imgs_path): + frames.append(imageio.imread(img_path)) + fps = 30 # default fps + else: + print(f'Found {video_path} is a file. 
It will be regarded as a video file.') + reader = imageio.get_reader(video_path, format='FFMPEG') + frames = [] + for frame in tqdm(reader, total=reader.count_frames()): + frames.append(frame) + fps = reader.get_meta_data()['fps'] + frames = np.stack(frames, axis=0) # (L, H, W, 3) + meta = { + 'fps': fps, + 'w' : frames.shape[2], + 'h' : frames.shape[1], + 'L' : frames.shape[0], + } + + return frames, meta + + +def save_video( + frames : Union[np.ndarray, List[np.ndarray]], + output_path : Union[str, Path], + fps : float = 30, + resize_ratio : Union[float, None] = None, + quality : Union[int, None] = None, +): + ''' Save the frames as a video. ''' + if isinstance(frames, List): + frames = np.stack(frames, axis=0) + assert frames.ndim == 4, f'Invalid frames shape: {frames.shape}' + + if resize_ratio is not None: + frames = flex_resize_video(frames, ratio=resize_ratio) + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + writer = imageio.get_writer(output_path, fps=fps, quality=quality) + output_seq_name = str(output_path).split('/')[-1] + for frame in tqdm(frames, desc=f'Saving {output_seq_name}'): + writer.append_data(frame) + writer.close() \ No newline at end of file diff --git a/lib/utils/vis/__init__.py b/lib/utils/vis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a59b8fbac08376a6807bbf250a26728662bbd747 --- /dev/null +++ b/lib/utils/vis/__init__.py @@ -0,0 +1,8 @@ +from .colors import ColorPalette + +from .wis3d import HWis3D as Wis3D + +from .py_renderer import ( + render_mesh_overlay_img, + render_mesh_overlay_video, +) \ No newline at end of file diff --git a/lib/utils/vis/colors.py b/lib/utils/vis/colors.py new file mode 100644 index 0000000000000000000000000000000000000000..26b8d622720005991faf8f614191400d057883f4 --- /dev/null +++ b/lib/utils/vis/colors.py @@ -0,0 +1,45 @@ +from lib.kits.basic import * + +def float_to_int8(color: List[float]) -> List[int]: + return [int(c * 255) for c in color] + +def int8_to_float(color: List[int]) -> List[float]: + return [c / 255 for c in color] + +def int8_to_hex(color: List[int]) -> str: + return '#%02x%02x%02x' % tuple(color) + +def float_to_hex(color: List[float]) -> str: + return int8_to_hex(float_to_int8(color)) + +def hex_to_int8(color: str) -> List[int]: + return [int(color[i+1:i+3], 16) for i in (0, 2, 4)] + +def hex_to_float(color: str) -> List[float]: + return int8_to_float(hex_to_int8(color)) + + +# TODO: incorporate https://github.com/vye16/slahmr/blob/main/slahmr/vis/colors.txt + +class ColorPalette: + + # Picked from: https://colorsite.librian.net/ + presets = { + 'black' : '#2b2b2b', + 'white' : '#eaedf7', + 'pink' : '#e6cde3', + 'light_pink' : '#fdeff2', + 'blue' : '#89c3eb', + 'purple' : '#a6a5c4', + 'light_purple' : '#bbc8e6', + 'red' : '#d3381c', + 'orange' : '#f9c89b', + 'light_orange' : '#fddea5', + 'brown' : '#b48a76', + 'human_yellow' : '#f1bf99', + 'green' : '#a8c97f', + } + + presets_int8 = {k: hex_to_int8(v) for k, v in presets.items()} + presets_float = {k: int8_to_float(v) for k, v in presets_int8.items()} + presets_hex = {k: v for k, v in presets.items()} \ No newline at end of file diff --git a/lib/utils/vis/p3d_renderer/README.md b/lib/utils/vis/p3d_renderer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9e692f02bd6a6cc6ff9a06811e4c85c48f65f9cd --- /dev/null +++ b/lib/utils/vis/p3d_renderer/README.md @@ -0,0 +1,29 @@ +## Pytorch3D Renderer + +> The code was modified from GVHMR: 
https://github.com/zju3dv/GVHMR/tree/main/hmr4d/utils/vis + +### Dependency + +```shell +pip install "git+https://github.com/facebookresearch/pytorch3d.git@v0.7.6" +``` + +### Example + +```python +from lib.utils.vis import Renderer +import imageio + +fps = 30 +focal_length = data["cam_int"][0][0, 0] +width, height = img_hw +faces = smplh[data["gender"]].bm.faces +renderer = Renderer(width, height, focal_length, "cuda", faces) +writer = imageio.get_writer("tmp_debug.mp4", fps=fps, mode="I", format="FFMPEG", macro_block_size=1) + +for i in tqdm(range(length)): + img = np.zeros((height, width, 3), dtype=np.uint8) + img = renderer.render_mesh(smplh_out.vertices[i].cuda(), img) + writer.append_data(img) +writer.close() +``` \ No newline at end of file diff --git a/lib/utils/vis/p3d_renderer/__init__.py b/lib/utils/vis/p3d_renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be01f37fc2629ed50bc654130037ad86fea57f62 --- /dev/null +++ b/lib/utils/vis/p3d_renderer/__init__.py @@ -0,0 +1,125 @@ +from lib.kits.basic import * + +import cv2 +import imageio + +from tqdm import tqdm + +from lib.utils.vis import ColorPalette +from lib.utils.media import save_img + +from .renderer import * + + +def render_mesh_overlay_img( + faces : Union[torch.Tensor, np.ndarray], + verts : torch.Tensor, + K4 : List, + img : np.ndarray, + output_fn : Optional[Union[str, Path]] = None, + device : str = 'cuda', + resize : float = 1.0, + Rt : Optional[Tuple[torch.Tensor]] = None, + mesh_color : Optional[Union[List[float], str]] = 'blue', +) -> Any: + ''' + Render the mesh overlay on the input video frames. + + ### Args + - faces: Union[torch.Tensor, np.ndarray], (V, 3) + - verts: torch.Tensor, (V, 3) + - K4: List + - [fx, fy, cx, cy], the components of intrinsic camera matrix. + - img: np.ndarray, (H, W, 3) + - output_fn: Union[str, Path] or None + - The output file path, if None, return the rendered img. + - fps: int, default 30 + - device: str, default 'cuda' + - resize: float, default 1.0 + - The resize factor of the output video. + - Rt: Tuple of Tensor, default None + - The extrinsic camera matrix, in the form of (R, t). + ''' + frame = render_mesh_overlay_video( + faces = faces, + verts = verts[None], + K4 = K4, + frames = img[None], + device = device, + resize = resize, + Rt = Rt, + mesh_color = mesh_color, + )[0] + + if output_fn is None: + return frame + else: + save_img(frame, output_fn) + + +def render_mesh_overlay_video( + faces : Union[torch.Tensor, np.ndarray], + verts : torch.Tensor, + K4 : List, + frames : np.ndarray, + output_fn : Optional[Union[str, Path]] = None, + fps : int = 30, + device : str = 'cuda', + resize : float = 1.0, + Rt : Tuple = None, + mesh_color : Optional[Union[List[float], str]] = 'blue', +) -> Any: + ''' + Render the mesh overlay on the input video frames. + + ### Args + - faces: Union[torch.Tensor, np.ndarray], (V, 3) + - verts: torch.Tensor, (L, V, 3) + - K4: List + - [fx, fy, cx, cy], the components of intrinsic camera matrix. + - frames: np.ndarray, (L, H, W, 3) + - output_fn: Union[str, Path] or None + - The output file path, if None, return the rendered frames. + - fps: int, default 30 + - device: str, default 'cuda' + - resize: float, default 1.0 + - The resize factor of the output video. + - Rt: Tuple, default None + - The extrinsic camera matrix, in the form of (R, t). + ''' + if isinstance(faces, torch.Tensor): + faces = faces.cpu().numpy() + assert len(K4) == 4, 'K4 must be a list of 4 elements.' 
+ assert frames.shape[0] == verts.shape[0], 'The length of frames and verts must be the same.' + assert frames.shape[-1] == 3, 'The last dimension of frames must be 3.' + if isinstance(mesh_color, str): + mesh_color = ColorPalette.presets_float[mesh_color] + + # Prepare the data. + L = frames.shape[0] + focal_length = (K4[0] + K4[1]) / 2 # f = (fx + fy) / 2 + width, height = frames.shape[-2], frames.shape[-3] + cx2, cy2 = int(K4[2] * 2), int(K4[3] * 2) + # Prepare the renderer. + renderer = Renderer(cx2, cy2, focal_length, device, faces) + if Rt is not None: + Rt = (to_tensor(Rt[0], device), to_tensor(Rt[1], device)) + renderer.create_camera(*Rt) + + if output_fn is None: + result_frames = [] + for i in range(L): + img = renderer.render_mesh(verts[i].to(device), frames[i], mesh_color) + img = cv2.resize(img, (int(width * resize), int(height * resize))) + result_frames.append(img) + result_frames = np.stack(result_frames, axis=0) + return result_frames + else: + writer = imageio.get_writer(output_fn, fps=fps, mode='I', format='FFMPEG', macro_block_size=1) + # Render the video. + output_seq_name = str(output_fn).split('/')[-1] + for i in tqdm(range(L), desc=f'Rendering [{output_seq_name}]...'): + img = renderer.render_mesh(verts[i].to(device), frames[i]) + writer.append_data(img) + img = cv2.resize(img, (int(width * resize), int(height * resize))) + writer.close() \ No newline at end of file diff --git a/lib/utils/vis/p3d_renderer/renderer.py b/lib/utils/vis/p3d_renderer/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc58781b17d7bccb11a325f938c01cb05cc885f --- /dev/null +++ b/lib/utils/vis/p3d_renderer/renderer.py @@ -0,0 +1,340 @@ +import cv2 +import torch +import numpy as np + +from pytorch3d.renderer import ( + PerspectiveCameras, + TexturesVertex, + PointLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, +) +from pytorch3d.structures import Meshes +from pytorch3d.structures.meshes import join_meshes_as_scene +from pytorch3d.renderer.cameras import look_at_rotation +from pytorch3d.transforms import axis_angle_to_matrix + +from .utils import get_colors, checkerboard_geometry + + +colors_str_map = { + "gray": [0.8, 0.8, 0.8], + "green": [39, 194, 128], +} + + +def overlay_image_onto_background(image, mask, bbox, background): + if isinstance(image, torch.Tensor): + image = image.detach().cpu().numpy() + if isinstance(mask, torch.Tensor): + mask = mask.detach().cpu().numpy() + + out_image = background.copy() + bbox = bbox[0].int().cpu().numpy().copy() + roi_image = out_image[bbox[1] : bbox[3], bbox[0] : bbox[2]] + + roi_image[mask] = image[mask] + out_image[bbox[1] : bbox[3], bbox[0] : bbox[2]] = roi_image + + return out_image + + +def update_intrinsics_from_bbox(K_org, bbox): + device, dtype = K_org.device, K_org.dtype + + K = torch.zeros((K_org.shape[0], 4, 4)).to(device=device, dtype=dtype) + K[:, :3, :3] = K_org.clone() + K[:, 2, 2] = 0 + K[:, 2, -1] = 1 + K[:, -1, 2] = 1 + + image_sizes = [] + for idx, bbox in enumerate(bbox): + left, upper, right, lower = bbox + cx, cy = K[idx, 0, 2], K[idx, 1, 2] + + new_cx = cx - left + new_cy = cy - upper + new_height = max(lower - upper, 1) + new_width = max(right - left, 1) + new_cx = new_width - new_cx + new_cy = new_height - new_cy + + K[idx, 0, 2] = new_cx + K[idx, 1, 2] = new_cy + image_sizes.append((int(new_height), int(new_width))) + + return K, image_sizes + + +def perspective_projection(x3d, K, R=None, T=None): + if R != None: + x3d = torch.matmul(R, 
x3d.transpose(1, 2)).transpose(1, 2) + if T != None: + x3d = x3d + T.transpose(1, 2) + + x2d = torch.div(x3d, x3d[..., 2:]) + x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2] + return x2d + + +def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2): + left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w) + right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w) + top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h) + bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h) + + cx = (left + right) / 2 + cy = (top + bottom) / 2 + width = right - left + height = bottom - top + + new_left = torch.clamp(cx - width / 2 * scaleFactor, min=0, max=img_w - 1) + new_right = torch.clamp(cx + width / 2 * scaleFactor, min=1, max=img_w) + new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h - 1) + new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h) + + bbox = torch.stack((new_left.detach(), new_top.detach(), new_right.detach(), new_bottom.detach())).int().float().T + + return bbox + + +class Renderer: + def __init__(self, width, height, focal_length=None, device="cuda", faces=None, K=None): + self.width = width + self.height = height + assert (focal_length is not None) ^ (K is not None), "focal_length and K are mutually exclusive" + + self.device = device + if faces is not None: + if isinstance(faces, np.ndarray): + faces = torch.from_numpy((faces).astype("int")) + if len(faces.shape) == 2: + self.faces = faces.unsqueeze(0).to(self.device) + elif len(faces.shape) == 3: + self.faces = faces.to(self.device) + else: + raise ValueError("faces should have shape of (F, 3) or (N, F, 3)") + + self.initialize_camera_params(focal_length, K) + self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]]) + self.create_renderer() + + def create_renderer(self): + self.renderer = MeshRenderer( + rasterizer = MeshRasterizer( + raster_settings = RasterizationSettings( + image_size = self.image_sizes[0], + blur_radius = 1e-5, + bin_size = 0, + ), + ), + shader = SoftPhongShader( + device=self.device, + lights=self.lights, + ), + ) + + def create_camera(self, R=None, T=None): + if R is not None: + self.R = R.clone().view(1, 3, 3).to(self.device) + if T is not None: + self.T = T.clone().view(1, 3).to(self.device) + + return PerspectiveCameras( + device=self.device, R=self.R.mT, T=self.T, K=self.K_full, image_size=self.image_sizes, in_ndc=False + ) + + def initialize_camera_params(self, focal_length, K): + # Extrinsics + self.R = torch.diag(torch.tensor([1, 1, 1])).float().to(self.device).unsqueeze(0) + + self.T = torch.tensor([0, 0, 0]).unsqueeze(0).float().to(self.device) + + # Intrinsics + if K is not None: + self.K = K.float().reshape(1, 3, 3).to(self.device) + else: + assert focal_length is not None, "focal_length or K should be provided" + self.K = ( + torch.tensor([[focal_length, 0, self.width / 2], [0, focal_length, self.height / 2], [0, 0, 1]]) + .float() + .reshape(1, 3, 3) + .to(self.device) + ) + self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float() + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes) + self.cameras = self.create_camera() + + def set_intrinsic(self, K): + self.K = K.reshape(1, 3, 3) + + def set_ground(self, length, center_x, center_z): + device = self.device + length, center_x, center_z = map(float, (length, center_x, center_z)) + v, f, vc, fc = map(torch.from_numpy, checkerboard_geometry(length=length * 2, c1=center_x, c2=center_z, up="y")) + v, f, vc = v.to(device), 
f.to(device), vc.to(device) + self.ground_geometry = [v, f, vc] + + def update_bbox(self, x3d, scale=2.0, mask=None): + """Update bbox of cameras from the given 3d points + + x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3) + """ + + if x3d.size(-1) != 3: + x2d = x3d.unsqueeze(0) + else: + x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1)) + + if mask is not None: + x2d = x2d[:, ~mask] + bbox = compute_bbox_from_points(x2d, self.width, self.height, scale) + self.bboxes = bbox + + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox) + self.cameras = self.create_camera() + self.create_renderer() + + def reset_bbox( + self, + ): + bbox = torch.zeros((1, 4)).float().to(self.device) + bbox[0, 2] = self.width + bbox[0, 3] = self.height + self.bboxes = bbox + + self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox) + self.cameras = self.create_camera() + self.create_renderer() + + def render_mesh(self, vertices, background=None, colors=[0.8, 0.8, 0.8], VI=50): + if vertices.dim() == 2: + vertices = vertices.unsqueeze(0) # (V, 3) -> (1, V, 3) + elif vertices.dim() != 3: + raise ValueError("vertices should have shape of ((Nm,) V, 3)") + self.update_bbox(vertices.view(-1, 3)[::VI], scale=1.2) + + if isinstance(colors, torch.Tensor): + # per-vertex color + verts_features = colors.to(device=vertices.device, dtype=vertices.dtype) + colors = [0.8, 0.8, 0.8] + else: + if colors[0] > 1: + colors = [c / 255.0 for c in colors] + verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype) + verts_features = verts_features.repeat(vertices.shape[0], vertices.shape[1], 1) + textures = TexturesVertex(verts_features=verts_features) + + mesh = Meshes( + verts=vertices, + faces=self.faces, + textures=textures, + ) + + materials = Materials(device=self.device, specular_color=(colors,), shininess=0) + + results = torch.flip(self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights), [1, 2]) + image = results[0, ..., :3] * 255 + mask = results[0, ..., -1] > 1e-3 + + if background is None: + background = np.ones((self.height, self.width, 3)).astype(np.uint8) * 255 + + image = overlay_image_onto_background(image, mask, self.bboxes, background.copy()) + self.reset_bbox() + return image + + def render_with_ground(self, verts, colors, cameras, lights, faces=None): + """ + :param verts (N, V, 3), potential multiple people + :param colors (N, 3) or (N, V, 3) + :param faces (N, F, 3), optional, otherwise self.faces is used will be used + """ + # Sanity check of input verts, colors and faces: (B, V, 3), (B, F, 3), (B, V, 3) + N, V, _ = verts.shape + if faces is None: + faces = self.faces.clone().expand(N, -1, -1) + else: + assert len(faces.shape) == 3, "faces should have shape of (N, F, 3)" + + assert len(colors.shape) in [2, 3] + if len(colors.shape) == 2: + assert len(colors) == N, "colors of shape 2 should be (N, 3)" + colors = colors[:, None] + colors = colors.expand(N, V, -1)[..., :3] + + # (V, 3), (F, 3), (V, 3) + gv, gf, gc = self.ground_geometry + verts = list(torch.unbind(verts, dim=0)) + [gv] + faces = list(torch.unbind(faces, dim=0)) + [gf] + colors = list(torch.unbind(colors, dim=0)) + [gc[..., :3]] + mesh = create_meshes(verts, faces, colors) + + materials = Materials(device=self.device, shininess=0) + + results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials) + image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8) 
+ + return image + + +def create_meshes(verts, faces, colors): + """ + :param verts (B, V, 3) + :param faces (B, F, 3) + :param colors (B, V, 3) + """ + textures = TexturesVertex(verts_features=colors) + meshes = Meshes(verts=verts, faces=faces, textures=textures) + return join_meshes_as_scene(meshes) + + +def get_global_cameras(verts, device="cuda", distance=5, position=(-5.0, 5.0, 0.0)): + """This always put object at the center of view""" + positions = torch.tensor([position]).repeat(len(verts), 1) + targets = verts.mean(1) + + directions = targets - positions + directions = directions / torch.norm(directions, dim=-1).unsqueeze(-1) * distance + positions = targets - directions + + rotation = look_at_rotation(positions, targets).mT + translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1) + + lights = PointLights(device=device, location=[position]) + return rotation, translation, lights + + +def get_global_cameras_static(verts, beta=4.0, cam_height_degree=30, target_center_height=0.75, device="cuda"): + L, V, _ = verts.shape + + # Compute target trajectory, denote as center + scale + targets = verts.mean(1) # (L, 3) + targets[:, 1] = 0 # project to xz-plane + target_center = targets.mean(0) # (3,) + target_scale, target_idx = torch.norm(targets - target_center, dim=-1).max(0) + + # a 45 degree vec from longest axis + long_vec = targets[target_idx] - target_center # (x, 0, z) + long_vec = long_vec / torch.norm(long_vec) + R = axis_angle_to_matrix(torch.tensor([0, np.pi / 4, 0])).to(long_vec) + vec = R @ long_vec + + # Compute camera position (center + scale * vec * beta) + y=4 + target_scale = max(target_scale, 1.0) * beta + position = target_center + vec * target_scale + position[1] = target_scale * np.tan(np.pi * cam_height_degree / 180) + target_center_height + + # Compute camera rotation and translation + positions = position.unsqueeze(0).repeat(L, 1) + target_centers = target_center.unsqueeze(0).repeat(L, 1) + target_centers[:, 1] = target_center_height + rotation = look_at_rotation(positions, target_centers).mT + translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1) + + lights = PointLights(device=device, location=[position.tolist()]) + return rotation, translation, lights diff --git a/lib/utils/vis/p3d_renderer/utils.py b/lib/utils/vis/p3d_renderer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..646e1b21f4b0ad7924ddbcd822dc36808794e421 --- /dev/null +++ b/lib/utils/vis/p3d_renderer/utils.py @@ -0,0 +1,804 @@ +import os +import cv2 +import numpy as np +import torch +from PIL import Image + + +def read_image(path, scale=1): + im = Image.open(path) + if scale == 1: + return np.array(im) + W, H = im.size + w, h = int(scale * W), int(scale * H) + return np.array(im.resize((w, h), Image.ANTIALIAS)) + + +def transform_torch3d(T_c2w): + """ + :param T_c2w (*, 4, 4) + returns (*, 3, 3), (*, 3) + """ + R1 = torch.tensor( + [ + [-1.0, 0.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, 1.0], + ], + device=T_c2w.device, + ) + R2 = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, -1.0], + ], + device=T_c2w.device, + ) + cam_R, cam_t = T_c2w[..., :3, :3], T_c2w[..., :3, 3] + cam_R = torch.einsum("...ij,jk->...ik", cam_R, R1) + cam_t = torch.einsum("ij,...j->...i", R2, cam_t) + return cam_R, cam_t + + +def transform_pyrender(T_c2w): + """ + :param T_c2w (*, 4, 4) + """ + T_vis = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + device=T_c2w.device, + ) + return 
torch.einsum("...ij,jk->...ik", torch.einsum("ij,...jk->...ik", T_vis, T_c2w), T_vis) + + +def smpl_to_geometry(verts, faces, vis_mask=None, track_ids=None): + """ + :param verts (B, T, V, 3) + :param faces (F, 3) + :param vis_mask (optional) (B, T) visibility of each person + :param track_ids (optional) (B,) + returns list of T verts (B, V, 3), faces (F, 3), colors (B, 3) + where B is different depending on the visibility of the people + """ + B, T = verts.shape[:2] + device = verts.device + + # (B, 3) + colors = track_to_colors(track_ids) if track_ids is not None else torch.ones(B, 3, device) * 0.5 + + # list T (B, V, 3), T (B, 3), T (F, 3) + return filter_visible_meshes(verts, colors, faces, vis_mask) + + +def filter_visible_meshes(verts, colors, faces, vis_mask=None, vis_opacity=False): + """ + :param verts (B, T, V, 3) + :param colors (B, 3) + :param faces (F, 3) + :param vis_mask (optional tensor, default None) (B, T) ternary mask + -1 if not in frame + 0 if temporarily occluded + 1 if visible + :param vis_opacity (optional bool, default False) + if True, make occluded people alpha=0.5, otherwise alpha=1 + returns a list of T lists verts (Bi, V, 3), colors (Bi, 4), faces (F, 3) + """ + # import ipdb; ipdb.set_trace() + B, T = verts.shape[:2] + faces = [faces for t in range(T)] + if vis_mask is None: + verts = [verts[:, t] for t in range(T)] + colors = [colors for t in range(T)] + return verts, colors, faces + + # render occluded and visible, but not removed + vis_mask = vis_mask >= 0 + if vis_opacity: + alpha = 0.5 * (vis_mask[..., None] + 1) + else: + alpha = (vis_mask[..., None] >= 0).float() + vert_list = [verts[vis_mask[:, t], t] for t in range(T)] + colors = [torch.cat([colors[vis_mask[:, t]], alpha[vis_mask[:, t], t]], dim=-1) for t in range(T)] + bounds = get_bboxes(verts, vis_mask) + return vert_list, colors, faces, bounds + + +def get_bboxes(verts, vis_mask): + """ + return bb_min, bb_max, and mean for each track (B, 3) over entire trajectory + :param verts (B, T, V, 3) + :param vis_mask (B, T) + """ + B, T, *_ = verts.shape + bb_min, bb_max, mean = [], [], [] + for b in range(B): + v = verts[b, vis_mask[b, :T]] # (Tb, V, 3) + bb_min.append(v.amin(dim=(0, 1))) + bb_max.append(v.amax(dim=(0, 1))) + mean.append(v.mean(dim=(0, 1))) + bb_min = torch.stack(bb_min, dim=0) + bb_max = torch.stack(bb_max, dim=0) + mean = torch.stack(mean, dim=0) + # point to a track that's long and close to the camera + zs = mean[:, 2] + counts = vis_mask[:, :T].sum(dim=-1) # (B,) + mask = counts < 0.8 * T + zs[mask] = torch.inf + sel = torch.argmin(zs) + return bb_min.amin(dim=0), bb_max.amax(dim=0), mean[sel] + + +def track_to_colors(track_ids): + """ + :param track_ids (B) + """ + color_map = torch.from_numpy(get_colors()).to(track_ids) + return color_map[track_ids] / 255 # (B, 3) + + +def get_colors(): + # color_file = os.path.abspath(os.path.join(__file__, "../colors_phalp.txt")) + color_file = os.path.abspath(os.path.join(__file__, "../colors.txt")) + RGB_tuples = np.vstack( + [ + np.loadtxt(color_file, skiprows=0), + # np.loadtxt(color_file, skiprows=1), + np.random.uniform(0, 255, size=(10000, 3)), + [[0, 0, 0]], + ] + ) + b = np.where(RGB_tuples == 0) + RGB_tuples[b] = 1 + return RGB_tuples.astype(np.float32) + + +def checkerboard_geometry( + length=12.0, + color0=[0.8, 0.9, 0.9], + color1=[0.6, 0.7, 0.7], + tile_width=0.5, + alpha=1.0, + up="y", + c1=0.0, + c2=0.0, +): + assert up == "y" or up == "z" + color0 = np.array(color0 + [alpha]) + color1 = np.array(color1 + [alpha]) + radius = 
length / 2.0 + num_rows = num_cols = max(2, int(length / tile_width)) + vertices = [] + vert_colors = [] + faces = [] + face_colors = [] + for i in range(num_rows): + for j in range(num_cols): + u0, v0 = j * tile_width - radius, i * tile_width - radius + us = np.array([u0, u0, u0 + tile_width, u0 + tile_width]) + vs = np.array([v0, v0 + tile_width, v0 + tile_width, v0]) + zs = np.zeros(4) + if up == "y": + cur_verts = np.stack([us, zs, vs], axis=-1) # (4, 3) + cur_verts[:, 0] += c1 + cur_verts[:, 2] += c2 + else: + cur_verts = np.stack([us, vs, zs], axis=-1) # (4, 3) + cur_verts[:, 0] += c1 + cur_verts[:, 1] += c2 + + cur_faces = np.array([[0, 1, 3], [1, 2, 3], [0, 3, 1], [1, 3, 2]], dtype=np.int64) + cur_faces += 4 * (i * num_cols + j) # the number of previously added verts + use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1) + cur_color = color0 if use_color0 else color1 + cur_colors = np.array([cur_color, cur_color, cur_color, cur_color]) + + vertices.append(cur_verts) + faces.append(cur_faces) + vert_colors.append(cur_colors) + face_colors.append(cur_colors) + + vertices = np.concatenate(vertices, axis=0).astype(np.float32) + vert_colors = np.concatenate(vert_colors, axis=0).astype(np.float32) + faces = np.concatenate(faces, axis=0).astype(np.float32) + face_colors = np.concatenate(face_colors, axis=0).astype(np.float32) + + return vertices, faces, vert_colors, face_colors + + +def camera_marker_geometry(radius, height, up): + assert up == "y" or up == "z" + if up == "y": + vertices = np.array( + [ + [-radius, -radius, 0], + [radius, -radius, 0], + [radius, radius, 0], + [-radius, radius, 0], + [0, 0, height], + ] + ) + else: + vertices = np.array( + [ + [-radius, 0, -radius], + [radius, 0, -radius], + [radius, 0, radius], + [-radius, 0, radius], + [0, -height, 0], + ] + ) + + faces = np.array( + [ + [0, 3, 1], + [1, 3, 2], + [0, 1, 4], + [1, 2, 4], + [2, 3, 4], + [3, 0, 4], + ] + ) + + face_colors = np.array( + [ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + ] + ) + return vertices, faces, face_colors + + +def vis_keypoints( + keypts_list, + img_size, + radius=6, + thickness=3, + kpt_score_thr=0.3, + dataset="TopDownCocoDataset", +): + """ + Visualize keypoints + From ViTPose/mmpose/apis/inference.py + """ + palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ] + ) + + if dataset in ( + "TopDownCocoDataset", + "BottomUpCocoDataset", + "TopDownOCHumanDataset", + "AnimalMacaqueDataset", + ): + # show the results + skeleton = [ + [15, 13], + [13, 11], + [16, 14], + [14, 12], + [11, 12], + [5, 11], + [6, 12], + [5, 6], + [5, 7], + [6, 8], + [7, 9], + [8, 10], + [1, 2], + [0, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + ] + + pose_link_color = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]] + pose_kpt_color = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]] + + elif dataset == "TopDownCocoWholeBodyDataset": + # show the results + skeleton = [ + [15, 13], + [13, 11], + [16, 14], + [14, 12], + [11, 12], + [5, 11], + [6, 12], + [5, 6], + [5, 7], + [6, 8], + [7, 9], + [8, 10], + [1, 2], + [0, 1], 
+ [0, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + [15, 17], + [15, 18], + [15, 19], + [16, 20], + [16, 21], + [16, 22], + [91, 92], + [92, 93], + [93, 94], + [94, 95], + [91, 96], + [96, 97], + [97, 98], + [98, 99], + [91, 100], + [100, 101], + [101, 102], + [102, 103], + [91, 104], + [104, 105], + [105, 106], + [106, 107], + [91, 108], + [108, 109], + [109, 110], + [110, 111], + [112, 113], + [113, 114], + [114, 115], + [115, 116], + [112, 117], + [117, 118], + [118, 119], + [119, 120], + [112, 121], + [121, 122], + [122, 123], + [123, 124], + [112, 125], + [125, 126], + [126, 127], + [127, 128], + [112, 129], + [129, 130], + [130, 131], + [131, 132], + ] + + pose_link_color = palette[ + [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16] + + [16, 16, 16, 16, 16, 16] + + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16] + + [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16] + ] + pose_kpt_color = palette[ + [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0] + [19] * (68 + 42) + ] + + elif dataset == "TopDownAicDataset": + skeleton = [ + [2, 1], + [1, 0], + [0, 13], + [13, 3], + [3, 4], + [4, 5], + [8, 7], + [7, 6], + [6, 9], + [9, 10], + [10, 11], + [12, 13], + [0, 6], + [3, 9], + ] + + pose_link_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 0, 7, 7]] + pose_kpt_color = palette[[9, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 0, 0]] + + elif dataset == "TopDownMpiiDataset": + skeleton = [ + [0, 1], + [1, 2], + [2, 6], + [6, 3], + [3, 4], + [4, 5], + [6, 7], + [7, 8], + [8, 9], + [8, 12], + [12, 11], + [11, 10], + [8, 13], + [13, 14], + [14, 15], + ] + + pose_link_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 9, 9, 9, 9, 9, 9]] + pose_kpt_color = palette[[16, 16, 16, 16, 16, 16, 7, 7, 0, 0, 9, 9, 9, 9, 9, 9]] + + elif dataset == "TopDownMpiiTrbDataset": + skeleton = [ + [12, 13], + [13, 0], + [13, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [0, 6], + [1, 7], + [6, 7], + [6, 8], + [7, 9], + [8, 10], + [9, 11], + [14, 15], + [16, 17], + [18, 19], + [20, 21], + [22, 23], + [24, 25], + [26, 27], + [28, 29], + [30, 31], + [32, 33], + [34, 35], + [36, 37], + [38, 39], + ] + + pose_link_color = palette[[16] * 14 + [19] * 13] + pose_kpt_color = palette[[16] * 14 + [0] * 26] + + elif dataset in ("OneHand10KDataset", "FreiHandDataset", "PanopticDataset"): + skeleton = [ + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [0, 5], + [5, 6], + [6, 7], + [7, 8], + [0, 9], + [9, 10], + [10, 11], + [11, 12], + [0, 13], + [13, 14], + [14, 15], + [15, 16], + [0, 17], + [17, 18], + [18, 19], + [19, 20], + ] + + pose_link_color = palette[[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]] + pose_kpt_color = palette[[0, 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16]] + + elif dataset == "InterHand2DDataset": + skeleton = [ + [0, 1], + [1, 2], + [2, 3], + [4, 5], + [5, 6], + [6, 7], + [8, 9], + [9, 10], + [10, 11], + [12, 13], + [13, 14], + [14, 15], + [16, 17], + [17, 18], + [18, 19], + [3, 20], + [7, 20], + [11, 20], + [15, 20], + [19, 20], + ] + + pose_link_color = palette[[0, 0, 0, 4, 4, 4, 8, 8, 8, 12, 12, 12, 16, 16, 16, 0, 4, 8, 12, 16]] + pose_kpt_color = palette[[0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 0]] + + elif dataset == "Face300WDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 68] + kpt_score_thr = 0 + + elif dataset == "FaceAFLWDataset": + # show the results + skeleton = [] + + 
pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 19] + kpt_score_thr = 0 + + elif dataset == "FaceCOFWDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 29] + kpt_score_thr = 0 + + elif dataset == "FaceWFLWDataset": + # show the results + skeleton = [] + + pose_link_color = palette[[]] + pose_kpt_color = palette[[19] * 98] + kpt_score_thr = 0 + + elif dataset == "AnimalHorse10Dataset": + skeleton = [ + [0, 1], + [1, 12], + [12, 16], + [16, 21], + [21, 17], + [17, 11], + [11, 10], + [10, 8], + [8, 9], + [9, 12], + [2, 3], + [3, 4], + [5, 6], + [6, 7], + [13, 14], + [14, 15], + [18, 19], + [19, 20], + ] + + pose_link_color = palette[[4] * 10 + [6] * 2 + [6] * 2 + [7] * 2 + [7] * 2] + pose_kpt_color = palette[[4, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 4, 7, 7, 7, 4, 4, 7, 7, 7, 4]] + + elif dataset == "AnimalFlyDataset": + skeleton = [ + [1, 0], + [2, 0], + [3, 0], + [4, 3], + [5, 4], + [7, 6], + [8, 7], + [9, 8], + [11, 10], + [12, 11], + [13, 12], + [15, 14], + [16, 15], + [17, 16], + [19, 18], + [20, 19], + [21, 20], + [23, 22], + [24, 23], + [25, 24], + [27, 26], + [28, 27], + [29, 28], + [30, 3], + [31, 3], + ] + + pose_link_color = palette[[0] * 25] + pose_kpt_color = palette[[0] * 32] + + elif dataset == "AnimalLocustDataset": + skeleton = [ + [1, 0], + [2, 1], + [3, 2], + [4, 3], + [6, 5], + [7, 6], + [9, 8], + [10, 9], + [11, 10], + [13, 12], + [14, 13], + [15, 14], + [17, 16], + [18, 17], + [19, 18], + [21, 20], + [22, 21], + [24, 23], + [25, 24], + [26, 25], + [28, 27], + [29, 28], + [30, 29], + [32, 31], + [33, 32], + [34, 33], + ] + + pose_link_color = palette[[0] * 26] + pose_kpt_color = palette[[0] * 35] + + elif dataset == "AnimalZebraDataset": + skeleton = [[1, 0], [2, 1], [3, 2], [4, 2], [5, 7], [6, 7], [7, 2], [8, 7]] + + pose_link_color = palette[[0] * 8] + pose_kpt_color = palette[[0] * 9] + + elif dataset in "AnimalPoseDataset": + skeleton = [ + [0, 1], + [0, 2], + [1, 3], + [0, 4], + [1, 4], + [4, 5], + [5, 7], + [6, 7], + [5, 8], + [8, 12], + [12, 16], + [5, 9], + [9, 13], + [13, 17], + [6, 10], + [10, 14], + [14, 18], + [6, 11], + [11, 15], + [15, 19], + ] + + pose_link_color = palette[[0] * 20] + pose_kpt_color = palette[[0] * 20] + else: + NotImplementedError() + + img_w, img_h = img_size + img = 255 * np.ones((img_h, img_w, 3), dtype=np.uint8) + img = imshow_keypoints( + img, + keypts_list, + skeleton, + kpt_score_thr, + pose_kpt_color, + pose_link_color, + radius, + thickness, + ) + alpha = 255 * (img != 255).any(axis=-1, keepdims=True).astype(np.uint8) + return np.concatenate([img, alpha], axis=-1) + + +def imshow_keypoints( + img, + pose_result, + skeleton=None, + kpt_score_thr=0.3, + pose_kpt_color=None, + pose_link_color=None, + radius=4, + thickness=1, + show_keypoint_weight=False, +): + """Draw keypoints and links on an image. + From ViTPose/mmpose/core/visualization/image.py + + Args: + img (H, W, 3) array + pose_result (list[kpts]): The poses to draw. Each element kpts is + a set of K keypoints as an Kx3 numpy.ndarray, where each + keypoint is represented as x, y, score. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + pose_kpt_color (np.ndarray[Nx3]`): Color of N keypoints. If None, + the keypoint will not be drawn. + pose_link_color (np.ndarray[Mx3]): Color of M links. If None, the + links will not be drawn. + thickness (int): Thickness of lines. 
+        show_keypoint_weight (bool): If True, opacity indicates keypoint score
+    """
+    img_h, img_w, _ = img.shape
+    idcs = [0, 16, 15, 18, 17, 5, 2, 6, 3, 7, 4, 12, 9, 13, 10, 14, 11]
+    for kpts in pose_result:
+        kpts = np.array(kpts, copy=False)[idcs]
+
+        # draw each point on image
+        if pose_kpt_color is not None:
+            assert len(pose_kpt_color) == len(kpts)
+            for kid, kpt in enumerate(kpts):
+                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+                if kpt_score > kpt_score_thr:
+                    color = tuple(int(c) for c in pose_kpt_color[kid])
+                    if show_keypoint_weight:
+                        img_copy = img.copy()
+                        cv2.circle(img_copy, (int(x_coord), int(y_coord)), radius, color, -1)
+                        transparency = max(0, min(1, kpt_score))
+                        cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img)
+                    else:
+                        cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
+
+        # draw links
+        if skeleton is not None and pose_link_color is not None:
+            assert len(pose_link_color) == len(skeleton)
+            for sk_id, sk in enumerate(skeleton):
+                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+                if (
+                    pos1[0] > 0
+                    and pos1[0] < img_w
+                    and pos1[1] > 0
+                    and pos1[1] < img_h
+                    and pos2[0] > 0
+                    and pos2[0] < img_w
+                    and pos2[1] > 0
+                    and pos2[1] < img_h
+                    and kpts[sk[0], 2] > kpt_score_thr
+                    and kpts[sk[1], 2] > kpt_score_thr
+                ):
+                    color = tuple(int(c) for c in pose_link_color[sk_id])
+                    if show_keypoint_weight:
+                        img_copy = img.copy()
+                        X = (pos1[0], pos2[0])
+                        Y = (pos1[1], pos2[1])
+                        mX = np.mean(X)
+                        mY = np.mean(Y)
+                        length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
+                        angle = np.degrees(np.arctan2(Y[0] - Y[1], X[0] - X[1]))
+                        stickwidth = 2
+                        polygon = cv2.ellipse2Poly(
+                            (int(mX), int(mY)),
+                            (int(length / 2), int(stickwidth)),
+                            int(angle),
+                            0,
+                            360,
+                            1,
+                        )
+                        cv2.fillConvexPoly(img_copy, polygon, color)
+                        transparency = max(0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2])))
+                        cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img)
+                    else:
+                        cv2.line(img, pos1, pos2, color, thickness=thickness)
+
+    return img
diff --git a/lib/utils/vis/py_renderer/README.md b/lib/utils/vis/py_renderer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a0cbedec89aa3a1aa9b4c7fadb9ffbf08e835498
--- /dev/null
+++ b/lib/utils/vis/py_renderer/README.md
@@ -0,0 +1,7 @@
+## [Pyrender Renderer](https://github.com/mmatl/pyrender)
+
+> The code was modified from HMR2.0: https://github.com/shubham-goel/4D-Humans/blob/main/hmr2/utils/renderer.py
+
+This renderer is used to work around the compatibility issues of the [Pytorch3D Renderer](../p3d_renderer/README.md). The API is fully adapted from it, but not all functions are implemented (e.g., `output_fn` is still listed in the API even where it is not used).
+
+Although this renderer is fully torch-independent, I found the mesh colors look a bit strange, so I tend to use the Pytorch3D Renderer by default instead.
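+
+### Example
+
+A minimal call sketch with toy inputs (the single-triangle mesh, gray frames, and intrinsics below are illustrative placeholders, not data from this repo). It assumes an EGL-capable environment, since this module selects `PYOPENGL_PLATFORM=egl` when the variable is unset:
+
+```python
+import numpy as np
+from lib.utils.vis.py_renderer import render_mesh_overlay_video
+
+faces  = np.array([[0, 1, 2]])                            # (F, 3), toy single-triangle mesh
+verts  = np.tile([[[0.0, 0.0, 3.0],
+                   [0.1, 0.0, 3.0],
+                   [0.0, 0.1, 3.0]]], (4, 1, 1))          # (L, V, 3), camera-space vertices
+frames = np.full((4, 256, 256, 3), 128, dtype=np.uint8)   # (L, H, W, 3), background frames
+K4     = [500.0, 500.0, 128.0, 128.0]                     # [fx, fy, cx, cy]
+
+overlaid = render_mesh_overlay_video(faces, verts, K4, frames, mesh_color='green')
+# overlaid: (L, H, W, 3) uint8 frames with the mesh composited over the input.
+```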
\ No newline at end of file diff --git a/lib/utils/vis/py_renderer/__init__.py b/lib/utils/vis/py_renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c022e170b5a3e2aca7c0ae5bfbe2b43c26eebb3 --- /dev/null +++ b/lib/utils/vis/py_renderer/__init__.py @@ -0,0 +1,367 @@ +import os +if 'PYOPENGL_PLATFORM' not in os.environ: + os.environ['PYOPENGL_PLATFORM'] = 'egl' +import torch +import numpy as np +import trimesh +import pyrender + +from typing import List, Optional, Union, Tuple +from pathlib import Path + +from lib.utils.vis import ColorPalette +from lib.utils.data import to_numpy +from lib.utils.media import save_img + +from .utils import * + + +def render_mesh_overlay_img( + faces : Union[torch.Tensor, np.ndarray], + verts : torch.Tensor, + K4 : List, + img : np.ndarray, + output_fn : Optional[Union[str, Path]] = None, + device : str = 'cuda', + resize : float = 1.0, + Rt : Optional[Tuple[torch.Tensor]] = None, + mesh_color : Optional[Union[List[float], str]] = 'green', +): + ''' + Render the mesh overlay on the input video frames. + + ### Args + - faces: Union[torch.Tensor, np.ndarray], (V, 3) + - verts: torch.Tensor, (V, 3) + - K4: List + - [fx, fy, cx, cy], the components of intrinsic camera matrix. + - img: np.ndarray, (H, W, 3) + - output_fn: Union[str, Path] or None + - The output file path, if None, return the rendered img. + - fps: int, default 30 + - device: str, default 'cuda' + - resize: float, default 1.0 + - The resize factor of the output video. + - Rt: Tuple of Tensor, default None + - The extrinsic camera matrix, in the form of (R, t). + ''' + frame = render_mesh_overlay_video( + faces = faces, + verts = verts[None], + K4 = K4, + frames = img[None], + device = device, + resize = resize, + Rt = Rt, + mesh_color = mesh_color, + )[0] + + if output_fn is None: + return frame + else: + save_img(frame, output_fn) + + +def render_mesh_overlay_video( + faces : Union[torch.Tensor, np.ndarray], + verts : Union[torch.Tensor, np.ndarray], + K4 : List, + frames : np.ndarray, + output_fn : Optional[Union[str, Path]] = None, + fps : int = 30, + device : str = 'cuda', + resize : float = 1.0, + Rt : Tuple = None, + mesh_color : Optional[Union[List[float], str]] = 'green', +): + ''' + Render the mesh overlay on the input video frames. + + ### Args + - faces: Union[torch.Tensor, np.ndarray], (V, 3) + - verts: Union[torch.Tensor, np.ndarray], (L, V, 3) + - K4: List + - [fx, fy, cx, cy], the components of intrinsic camera matrix. + - frames: np.ndarray, (L, H, W, 3) + - output_fn: useless, only for compatibility. + - fps: useless, only for compatibility. + - device: useless, only for compatibility. + - resize: useless, only for compatibility. + - Rt: Tuple, default None + - The extrinsic camera matrix, in the form of (R, t). + ''' + faces, verts = to_numpy(faces), to_numpy(verts) + assert len(K4) == 4, 'K4 must be a list of 4 elements.' + assert frames.shape[0] == verts.shape[0], 'The length of frames and verts must be the same.' + assert frames.shape[-1] == 3, 'The last dimension of frames must be 3.' + if isinstance(mesh_color, str): + mesh_color = ColorPalette.presets_float[mesh_color] + + # Prepare the data. + L = len(frames) + frame_w, frame_h = frames.shape[-2], frames.shape[-3] + + renderer = pyrender.OffscreenRenderer( + viewport_width = frame_w, + viewport_height = frame_h, + point_size = 1.0 + ) + + # Camera + camera, cam_pose = create_camera(K4, Rt) + + # Scene. 
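+    # A single flat, non-metallic, opaque material is reused for every frame's mesh; `mesh_color` provides its base color.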
+ material = pyrender.MetallicRoughnessMaterial( + metallicFactor = 0.0, + alphaMode = 'OPAQUE', + baseColorFactor = (*mesh_color, 1.0) + ) + + # Light. + light_nodes = create_raymond_lights() + + results = [] + for i in range(L): + mesh = trimesh.Trimesh(verts[i].copy(), faces.copy()) + # if side_view: + # rot = trimesh.transformations.rotation_matrix( + # np.radians(rot_angle), [0, 1, 0]) + # mesh.apply_transform(rot) + # elif top_view: + # rot = trimesh.transformations.rotation_matrix( + # np.radians(rot_angle), [1, 0, 0]) + # mesh.apply_transform(rot) + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene( + bg_color = [0.0, 0.0, 0.0, 0.0], + ambient_light = (0.3, 0.3, 0.3), + ) + + scene.add(mesh, 'mesh') + scene.add(camera, pose=cam_pose) + + # Light. + for node in light_nodes: + scene.add_node(node) + + # Render. + result_rgba, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + + valid_mask = result_rgba.astype(np.float32)[:, :, [-1]] / 255.0 # (H, W, 1) + bg = frames[i] # (H, W, 3) + final = result_rgba[:, :, :3] * valid_mask + bg * (1 - valid_mask) + final = final.astype(np.uint8) # (H, W, 3) + results.append(final) + results = np.stack(results, axis=0) # (L, H, W, 3) + + renderer.delete() + return results + + + +def render_meshes_overlay_img( + faces_all : Union[torch.Tensor, np.ndarray], + verts_all : Union[torch.Tensor, np.ndarray], + cam_t_all : Union[torch.Tensor, np.ndarray], + K4 : List, + img : np.ndarray, + output_fn : Optional[Union[str, Path]] = None, + device : str = 'cuda', + resize : float = 1.0, + Rt : Optional[Tuple[torch.Tensor]] = None, + mesh_color : Optional[Union[List[float], str]] = 'green', + view : str = 'front', + ret_rgba : bool = False, +): + ''' + Render the mesh overlay on the input video frames. + + ### Args + - faces_all: Union[torch.Tensor, np.ndarray], ((Nm,) V, 3) + - verts_all: Union[torch.Tensor, np.ndarray], ((Nm,) V, 3) + - cam_t_all: Union[torch.Tensor, np.ndarray], ((Nm,) 3) + - K4: List + - [fx, fy, cx, cy], the components of intrinsic camera matrix. + - img: np.ndarray, (H, W, 3) + - output_fn: Union[str, Path] or None + - The output file path, if None, return the rendered img. + - fps: int, default 30 + - device: str, default 'cuda' + - resize: float, default 1.0 + - The resize factor of the output video. + - Rt: Tuple of Tensor, default None + - The extrinsic camera matrix, in the form of (R, t). + - view: str, default 'front', {'front', 'side90d', 'side60d', 'top90d'} + - ret_rgba: bool, default False + - If True, return rgba images, otherwise return rgb images. + - For view is not 'front', the background will become transparent. 
+ ''' + if len(verts_all.shape) == 2: + verts_all = verts_all[None] # (1, V, 3) + elif len(verts_all.shape) == 3: + verts_all = verts_all[:, None] # ((Nm,) 1, V, 3) + else: + raise ValueError('The shape of verts_all is not correct.') + if len(cam_t_all.shape) == 1: + cam_t_all = cam_t_all[None] # (1, 3) + elif len(cam_t_all.shape) == 2: + cam_t_all = cam_t_all[:, None] # ((Nm,) 1, 3) + else: + raise ValueError('The shape of verts_all is not correct.') + frame = render_meshes_overlay_video( + faces_all = faces_all, + verts_all = verts_all, + cam_t_all = cam_t_all, + K4 = K4, + frames = img[None], + device = device, + resize = resize, + Rt = Rt, + mesh_color = mesh_color, + view = view, + ret_rgba = ret_rgba, + )[0] + + if output_fn is None: + return frame + else: + save_img(frame, output_fn) + + +def render_meshes_overlay_video( + faces_all : Union[torch.Tensor, np.ndarray], + verts_all : Union[torch.Tensor, np.ndarray], + cam_t_all : Union[torch.Tensor, np.ndarray], + K4 : List, + frames : np.ndarray, + output_fn : Optional[Union[str, Path]] = None, + fps : int = 30, + device : str = 'cuda', + resize : float = 1.0, + Rt : Tuple = None, + mesh_color : Optional[Union[List[float], str]] = 'green', + view : str = 'front', + ret_rgba : bool = False, +): + ''' + Render the mesh overlay on the input video frames. + + ### Args + - faces_all: Union[torch.Tensor, np.ndarray], ((Nm,) V, 3) + - verts_all: Union[torch.Tensor, np.ndarray], ((Nm,) L, V, 3) + - cam_t_all: Union[torch.Tensor, np.ndarray], ((Nm,) L, 3) + - K4: List + - [fx, fy, cx, cy], the components of intrinsic camera matrix. + - frames: np.ndarray, (L, H, W, 3) + - output_fn: useless, only for compatibility. + - fps: useless, only for compatibility. + - device: useless, only for compatibility. + - resize: useless, only for compatibility. + - Rt: Tuple, default None + - The extrinsic camera matrix, in the form of (R, t). + - view: str, default 'front', {'front', 'side90d', 'side60d', 'top90d'} + - ret_rgba: bool, default False + - If True, return rgba images, otherwise return rgb images. + - For view is not 'front', the background will become transparent. + ''' + faces_all, verts_all = to_numpy(faces_all), to_numpy(verts_all) + if len(verts_all.shape) == 3: + verts_all = verts_all[None] # (1, L, V, 3) + if len(cam_t_all.shape) == 2: + cam_t_all = cam_t_all[None] # (1, L, 3) + Nm, L, _, _ = verts_all.shape + if len(faces_all.shape) == 2: + faces_all = faces_all[None].repeat(Nm, axis=0) # (Nm, V, 3) + + assert len(K4) == 4, 'K4 must be a list of 4 elements.' + assert frames.shape[0] == L, 'The length of frames and verts must be the same.' + assert frames.shape[-1] == 3, 'The last dimension of frames must be 3.' + assert len(verts_all.shape) == 4, 'The shape of verts_all must be (Nm, L, V, 3).' + assert len(faces_all.shape) == 3, 'The shape of faces_all must be (Nm, V, 3).' + if isinstance(mesh_color, str): + mesh_color = ColorPalette.presets_float[mesh_color] + + # Prepare the data. + frame_w, frame_h = frames.shape[-2], frames.shape[-3] + + renderer = pyrender.OffscreenRenderer( + viewport_width = frame_w, + viewport_height = frame_h, + point_size = 1.0 + ) + + # Camera + camera, cam_pose = create_camera(K4, Rt) + + # Scene. + material = pyrender.MetallicRoughnessMaterial( + metallicFactor = 0.0, + alphaMode = 'OPAQUE', + baseColorFactor = (*mesh_color, 1.0) + ) + + # Light. 
+ light_nodes = create_raymond_lights() + + results = [] + for i in range(L): + scene = pyrender.Scene( + bg_color = [0.0, 0.0, 0.0, 0.0], + ambient_light = (0.3, 0.3, 0.3), + ) + + for mid in range(Nm): + mesh = trimesh.Trimesh(verts_all[mid][i].copy(), faces_all[mid].copy()) + if view == 'front': + pass + elif view == 'side90d': + rot = trimesh.transformations.rotation_matrix(np.radians(-90), [0, 1, 0]) + mesh.apply_transform(rot) + elif view == 'side60d': + rot = trimesh.transformations.rotation_matrix(np.radians(-60), [0, 1, 0]) + mesh.apply_transform(rot) + elif view == 'top90d': + rot = trimesh.transformations.rotation_matrix(np.radians(90), [1, 0, 0]) + mesh.apply_transform(rot) + else: + raise ValueError('The view is not supported.') + trans = trimesh.transformations.translation_matrix(to_numpy(cam_t_all[mid][i])) + mesh.apply_transform(trans) + rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene.add(mesh, f'mesh_{mid}') + scene.add(camera, pose=cam_pose) + + # Light. + for node in light_nodes: + scene.add_node(node) + + # Render. + result_rgba, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + + valid_mask = result_rgba.astype(np.float32)[:, :, [-1]] / 255.0 # (H, W, 1) + if view == 'front': + bg = frames[i] # (H, W, 3) + else: + bg = np.ones_like(frames[i]) * 255 # (H, W, 3) + if ret_rgba: + if view == 'front': + bg_alpha = np.ones_like(bg[..., [0]]) * 255 # (H, W, 1) + else: + bg_alpha = np.zeros_like(bg[..., [0]]) * 255 # (H, W, 1) + bg = np.concatenate([bg, bg_alpha], axis=-1) # (H, W, 4) + final = result_rgba * valid_mask + bg * (1 - valid_mask) # (H, W, 4) + else: + final = result_rgba[:, :, :3] * valid_mask + bg * (1 - valid_mask) + final = final.astype(np.uint8) # (H, W, 3) + results.append(final) + results = np.stack(results, axis=0) # (L, H, W, 3) + + renderer.delete() + return results diff --git a/lib/utils/vis/py_renderer/utils.py b/lib/utils/vis/py_renderer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd977e3ce62922a392bac6a38b3ea5d90c1c7ec6 --- /dev/null +++ b/lib/utils/vis/py_renderer/utils.py @@ -0,0 +1,52 @@ +import numpy as np +import pyrender + +from typing import List, Optional, Union, Tuple +from lib.utils.data import to_numpy + + +def create_camera(K4:List, Rt:Optional[Tuple]): + if Rt is not None: + cam_R, cam_t = Rt + cam_R = to_numpy(cam_R).copy() + cam_t = to_numpy(cam_t).copy() + cam_t[0] *= -1 + else: + cam_R = np.eye(3) + cam_t = np.zeros(3) + fx, fy, cx, cy = K4 + cam_pose = np.eye(4) + cam_pose[:3, :3] = cam_R + cam_pose[:3, 3] = cam_t + camera = pyrender.IntrinsicsCamera( fx=fx, fy=fy, cx=cx, cy=cy, zfar=1e12 ) + return camera, cam_pose + + +def create_raymond_lights() -> List[pyrender.Node]: + ''' Return raymond light nodes for the scene. 
''' + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3,:3] = np.c_[x,y,z] + nodes.append(pyrender.Node( + light = pyrender.DirectionalLight(color=np.ones(3), intensity=0.75), + matrix = matrix + )) + + return nodes \ No newline at end of file diff --git a/lib/utils/vis/wis3d.py b/lib/utils/vis/wis3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c5aa5d4398d191091a2c47ebefb0ef2e659e36 --- /dev/null +++ b/lib/utils/vis/wis3d.py @@ -0,0 +1,493 @@ +import torch +import numpy as np + +from typing import Union, List, overload +from wis3d import Wis3D + +from lib.platform import PM +from lib.utils.geometry.rotation import axis_angle_to_matrix + + +class HWis3D(Wis3D): + ''' Abstraction of Wis3D for human motion. ''' + + def __init__( + self, + out_path : str = PM.outputs / 'wis3d', + seq_name : str = 'debug', + xyz_pattern : tuple = ('x', 'y', 'z'), + ): + seq_name = seq_name.replace('/', '-') + super().__init__(out_path, seq_name, xyz_pattern) + + + def add_text(self, text:str): + ''' + Add an item of vertices whose name is used to put the text message. *Dirty use!* + + ### Args + - text: str + ''' + fake_verts = np.array([[0, 0, 0]]) + self.add_point_cloud( + vertices = fake_verts, + colors = None, + name = text, + ) + + + def add_text_seq(self, texts:List[str], offset:int=0): + ''' + Add an item of vertices whose name is used to put the text message. *Dirty use!* + + ### Args + - texts: List[str] + - The list of text messages. + - offset: int, default = 0 + - The offset for the sequence index. + ''' + fake_verts = np.array([[0, 0, 0]]) + for i, text in enumerate(texts): + self.set_scene_id(i + offset) + self.add_point_cloud( + vertices = fake_verts, + colors = None, + name = text, + ) + + def add_image_seq(self, imgs:List[np.ndarray], name:str, offset:int=0): + ''' + Add an item of vertices whose name is used to put the image. *Dirty use!* + + ### Args + - imgs: List[np.ndarray] + - The list of images. + - offset: int, default = 0 + - The offset for the sequence index. + ''' + for i, img in enumerate(imgs): + self.set_scene_id(i + offset) + self.add_image( + image = img, + name = name, + ) + + def add_motion_mesh( + self, + verts : Union[torch.Tensor, np.ndarray], + faces : Union[torch.Tensor, np.ndarray], + name : str, + offset: int = 0, + ): + ''' + Add sequence of vertices and face(s) to the wis3d viewer. + + ### Args + - verts: torch.Tensor or np.ndarray, (L, V, 3), L ~ sequence length, V ~ number of vertices + - faces: torch.Tensor or np.ndarray, (F, 3) or (L, F, 3), F ~ number of faces, L ~ sequence length + - name: str + - The name of the point cloud. + - offset: int, default = 0 + - The offset for the sequence index. + ''' + assert (len(verts.shape) == 3), 'The input `verts` should have 3 dimensions: (L, V, 3).' + assert (verts.shape[-1] == 3), 'The last dimension of `verts` should be 3.' 
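+        # Normalize the input containers: `verts` becomes a torch.Tensor, and `faces` becomes a NumPy array broadcast to one (F, 3) face set per frame.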
+ if isinstance(verts, np.ndarray): + verts = torch.from_numpy(verts) + if isinstance(faces, torch.Tensor): + faces = faces.detach().cpu().numpy() + if len(faces.shape) == 2: + faces = faces[None].repeat(verts.shape[0], 0) + assert (len(faces.shape) == 3), 'The input `faces` should have 2 or 3 dimensions: (F, 3) or (L, F, 3).' + assert (faces.shape[-1] == 3), 'The last dimension of `faces` should be 3.' + assert (verts.shape[0] == faces.shape[0]), 'The first dimension of `verts` and `faces` should be the same.' + + L, _, _ = verts.shape + verts = verts.detach().cpu() + + # Add vertices frame by frame. + for i in range(L): + self.set_scene_id(i + offset) + self.add_mesh( + vertices = verts[i], + faces = faces[i], + name = name, + ) # type: ignore + + # Reset Wis3D scene id. + self.set_scene_id(0) + + + def add_motion_verts( + self, + verts : Union[torch.Tensor, np.ndarray], + name : str, + offset: int = 0, + ): + ''' + Add sequence of vertices to the wis3d viewer. + + ### Args + - verts: torch.Tensor or np.ndarray, (L, V, 3), L ~ sequence length, V ~ number of vertices + - name: str + - The name of the point cloud. + - offset: int, default = 0 + - The offset for the sequence index. + ''' + assert (len(verts.shape) == 3), 'The input `verts` should have 3 dimensions: (L, V, 3).' + assert (verts.shape[-1] == 3), 'The last dimension of `verts` should be 3.' + if isinstance(verts, np.ndarray): + verts = torch.from_numpy(verts) + + L, _, _ = verts.shape + verts = verts.detach().cpu() + + # Add vertices frame by frame. + for i in range(L): + self.set_scene_id(i + offset) + self.add_point_cloud( + vertices = verts[i], + colors = None, + name = name, + ) + + # Reset Wis3D scene id. + self.set_scene_id(0) + + + def add_motion_skel( + self, + joints : Union[torch.Tensor, np.ndarray], + bones : Union[list, torch.Tensor], + colors : Union[list, torch.Tensor], + name : str, + offset : int = 0, + threshold : float = 0.5, + ): + ''' + Add sequence of joints with specific skeleton to the wis3d viewer. + + ### Args + - joints: torch.Tensor or np.ndarray, shape = (L, J, 3) or (L, J, 4), L ~ sequence length, J ~ number of joints + - bones: list + - A list of bones of the skeleton, i.e. the edge in the kinematic trees. + - colors: list + - name: str + - The name of the point cloud. + - offset: int, default = 0 + - The offset for the sequence index. + - threshold: float, default = 0.5 + - Threshold to filter the confidence of the joints. It's useless when no confidence provided. + ''' + assert (len(joints.shape) == 3), 'The input `joints` should have 3 dimensions: (L, J, 3).' + assert (joints.shape[-1] == 3 or joints.shape[-1] == 4), 'The last dimension of `joints` should be 3 or 4.' + if isinstance(joints, np.ndarray): + joints = torch.from_numpy(joints) + if isinstance(bones, List): + bones = torch.tensor(bones) + if isinstance(colors, List): + colors = torch.tensor(colors) + + # Get the sequence length. + joints = joints.detach().cpu() # (L, J, 3) or (L, J, 4) + L, J, D = joints.shape + if D == 4: + conf = joints[:, :, 3] + joints = joints[:, :, :3] + else: + conf = None + + # Add vertices frame by frame. 
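+        # Per frame, draw the bone segments; when joint confidences are given, keep only bones whose two endpoints both exceed `threshold`.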
+ for i in range(L): + self.set_scene_id(i + offset) + bones_s = joints[i][bones[:, 0]] + bones_e = joints[i][bones[:, 1]] + if conf is not None: + mask = torch.logical_and(conf[i][bones[:, 0]] > threshold, conf[i][bones[:, 1]] > threshold) + bones_s, bones_e = bones_s[mask], bones_e[mask] + if len(bones_s) > 0: + self.add_lines( + start_points = bones_s, + end_points = bones_e, + colors = colors, + name = name, + ) + + # Reset Wis3D scene id. + self.set_scene_id(0) + + + def add_vec_seq( + self, + vecs : torch.Tensor, + name : str, + offset : int = 0, + seg_num : int = 16, + ): + ''' + Add directional line sequence to the wis3d viewer. + + The line will be gradient colored, and the direction of the vector is visualized as from dark to light. + + ### Args + - vecs: torch.Tensor, (L, 2, 3) or (L, N, 2, 3), L ~ sequence length, N ~ vectors counts in one frame, + then give the start 3D point and end 3D point. + - name: str + - The name of the vector. + - offset: int, default = 0 + - The offset for the sequence index. + - seg_num: int, default = 16 + - The number of segments for gradient color, will just change the visualization effect. + ''' + if len(vecs.shape) == 3: + vecs = vecs[:, None, :, :] # (L, 2, 3) -> (L, 1, 2, 3) + assert (len(vecs.shape) == 4), 'The input `vecs` should have 3 or 4 dimensions: (L, 2, 3) or (L, N, 2, 3).' + assert (vecs.shape[-2:] == (2, 3)), f'The last two dimension of `vecs` should be (2, 3), but got vecs.shape = {vecs.shape}.' + + # Get the sequence length. + L, N, _, _ = vecs.shape + vecs = vecs.detach().cpu() + + # Cut the line into segments. + steps_delta = (vecs[:, :, [1]] - vecs[:, :, [0]]) / (seg_num + 1) # (L, N, 1, 3) + steps_cnt = torch.arange(seg_num + 1).reshape((1, 1, seg_num + 1, 1)) # (1, 1, seg_num+1, 1) + segs = steps_delta * steps_cnt + vecs[:, :, [0]] # (L, N, seg_num+1, 3) + start_pts = segs[:, :, :-1] # (L, N, seg_num, 3) + end_pts = segs[:, :, 1:] # (L, N, seg_num, 3) + + # Prepare the gradient colors. + grad_colors = torch.linspace(0, 255, seg_num).reshape((1, seg_num, 1)).repeat(N, 1, 3) # (N, seg_num, 3) + + # Add vertices frame by frame. + for i in range(L): + self.set_scene_id(i + offset) + self.add_lines( + start_points = start_pts[i].reshape(-1, 3), + end_points = end_pts[i].reshape(-1, 3), + colors = grad_colors.reshape(-1, 3), + name = name, + ) + + # Reset Wis3D scene id. + self.set_scene_id(0) + + + def add_traj( + self, + positions : torch.Tensor, + name : str, + offset : int = 0, + ): + ''' + Visualize the the positions change across the time as trajectory. The newer position will be brighter. + + ### Args + - positions: torch.Tensor, (L, 3), L ~ sequence length + - name: str + - The name of the trajectory. + - offset: int, default = 0 + - The offset for the sequence index. + ''' + assert (len(positions.shape) == 2), 'The input `positions` should have 2 dimensions: (L, 3).' + assert (positions.shape[-1] == 3), 'The last dimension of `positions` should be 3.' + + # Get the sequence length. + L, _ = positions.shape + positions = positions.detach().cpu() + traj = positions[[0]] # (1, 3) + + # Prepare the gradient colors. + grad_colors = torch.linspace(208, 48, L).reshape((L, 1)).repeat(1, 3) # (L, 3) + + for i in range(L): + traj = torch.cat((traj, positions[[i]]), dim=0) # (i+2, 3) + self.set_scene_id(i + offset) + self.add_lines( + start_points = traj[:-1], + end_points = traj[1:], + colors = grad_colors[-(i+1):], + name = name, + ) + + # Reset Wis3D scene id. 
+
+
+    def add_sphere_sensors(
+        self,
+        positions  : torch.Tensor,
+        radius     : Union[torch.Tensor, float],
+        activities : torch.Tensor,
+        name       : str,
+    ):
+        '''
+        Draw the sphere sensors with different colors to represent the activities. The color goes from white to red.
+
+        ### Args
+        - positions: torch.Tensor, (N, 3), N ~ number of sensors
+        - radius: torch.Tensor or float, (N,), N ~ number of sensors
+        - activities: torch.Tensor, (N,)
+            - The activities of the sensors, from 0 to 1.
+        - name: str
+            - The name of the spheres.
+        '''
+        assert (len(positions.shape) == 2), 'The input `positions` should have 2 dimensions: (N, 3).'
+        assert (positions.shape[-1] == 3), 'The last dimension of `positions` should be 3.'
+        N, _ = positions.shape
+        if isinstance(radius, float):
+            radius = torch.tensor(radius).reshape(1).repeat(N)  # (N,)
+        elif len(radius.shape) == 0:
+            radius = radius.reshape(1).repeat(N)
+
+        colors = torch.ones(size=(N, 3)).float()
+        colors[:, 0] = 255
+        colors[:, 1] = (1 - activities) ** 2 * 255
+        colors[:, 2] = (1 - activities) ** 2 * 255
+        self.add_spheres(
+            centers = positions,
+            radius  = radius,
+            colors  = colors,
+            name    = name,
+        )
+
+
+    def add_sphere_sensors_seq(
+        self,
+        positions  : torch.Tensor,
+        radius     : Union[torch.Tensor, float],
+        activities : torch.Tensor,
+        name       : str,
+        offset     : int = 0,
+    ):
+        '''
+        Draw the sphere sensors with different colors to represent the activities. The color goes from white to red.
+
+        ### Args
+        - positions: torch.Tensor, (L, N, 3), L ~ sequence length, N ~ number of sensors
+        - radius: torch.Tensor or float, (N,), shared across all frames
+        - activities: torch.Tensor, (L, N)
+            - The activities of the sensors, from 0 to 1.
+        - name: str
+            - The name of the spheres.
+        - offset: int, default = 0
+            - The offset for the sequence index.
+        '''
+        assert (len(positions.shape) == 3), 'The input `positions` should have 3 dimensions: (L, N, 3).'
+        assert (positions.shape[-1] == 3), 'The last dimension of `positions` should be 3.'
+        L, N, _ = positions.shape
+
+        for i in range(L):
+            self.set_scene_id(i + offset)
+            self.add_sphere_sensors(
+                positions  = positions[i],
+                radius     = radius,
+                activities = activities[i],
+                name       = name,
+            )
+
+        # Reset Wis3D scene id.
+        self.set_scene_id(0)
+
+
+    # ===== Overriding methods from original Wis3D. =====
+
+
+    def add_lines(
+        self,
+        start_points: torch.Tensor,
+        end_points  : torch.Tensor,
+        colors      : Union[list, torch.Tensor] = None,
+        name        : str = None,
+        thickness   : float = 0.01,
+        resolution  : int = 4,
+    ):
+        '''
+        Add lines defined by start and end points. Overrides the original `add_lines` method and draws the
+        lines as meshes, which prevents the browser from crashing.
+
+        ### Args
+        - start_points: torch.Tensor, (N, 3), N ~ number of lines
+        - end_points: torch.Tensor, (N, 3), N ~ number of lines
+        - colors: list or torch.Tensor, (N, 3)
+            - The color of the lines, from 0 to 255.
+        - name: str
+            - The name of the lines.
+        - thickness: float, default = 0.01
+            - The thickness of the lines.
+        - resolution: int, default = 4
+            - Each 'line' is actually a polygonal cylinder; the resolution controls how closely it approximates a round cylinder.
+        '''
+        if isinstance(colors, List):
+            colors = torch.tensor(colors)
+
+        assert (len(start_points.shape) == 2), 'The input `start_points` should have 2 dimensions: (N, 3).'
+        assert (len(end_points.shape) == 2), 'The input `end_points` should have 2 dimensions: (N, 3).'
+        assert (start_points.shape == end_points.shape), 'The input `start_points` and `end_points` should have the same shape.'
+
+        # ===== Prepare the data. =====
+        N, _ = start_points.shape
+        device = start_points.device
+        dir = end_points - start_points  # (N, 3)
+        dis = torch.norm(dir, dim=-1, keepdim=True)  # (N, 1)
+        dir = dir / dis.clamp(min=1e-8)  # (N, 3), clamp to avoid NaN for zero-length lines
+        K = resolution + 1  # the first & the last point share the position
+        # Find out directions that are negative to the y-axis.
+        vec_y = torch.Tensor([[0, 1, 0]]).float().to(device)  # (1, 3)
+        neg_mask = (dir @ vec_y.transpose(-1, -2) < 0).squeeze(-1)  # (N,)
+
+        # ===== Get the ending surface vertices of the cylinder. =====
+        # 1. Get the surface vertices template in the x-z plane.
+        angles = torch.linspace(0, 2*torch.pi, K)  # (K,)
+        v_ending_temp = \
+            torch.stack(
+                [torch.cos(angles), torch.zeros_like(angles), torch.sin(angles)],
+                dim = -1
+            )  # (K, 3)
+        v_ending_temp *= thickness  # (K, 3)
+        v_ending_temp = v_ending_temp[None].repeat(N, 1, 1)  # (N, K, 3)
+
+        # 2. Rotate the template plane to the direction of the line.
+        rot_axis = torch.linalg.cross(vec_y, dir)  # (N, 3)
+        rot_axis[neg_mask] *= -1
+        rot_mat = axis_angle_to_matrix(rot_axis)  # (N, 3, 3)
+        v_ending_temp = v_ending_temp @ rot_mat.transpose(-1, -2)
+        v_ending_temp = v_ending_temp.to(device)
+
+        # 3. Move the template plane to the start and end points and get the cylinder vertices.
+        v_cylinder_start = v_ending_temp + start_points[:, None]  # (N, K, 3)
+        v_cylinder_end = v_ending_temp + end_points[:, None]  # (N, K, 3)
+        # Swap the start and end points for the negative direction to adjust the normal direction.
+        v_cylinder_start[neg_mask], v_cylinder_end[neg_mask] = v_cylinder_end[neg_mask], v_cylinder_start[neg_mask]
+        v_cylinder = torch.cat([v_cylinder_start, v_cylinder_end], dim=1)  # (N, 2*K, 3)
+
+        # ===== Calculate the face indices. =====
+        idx = torch.arange(0, 2*K, device=device)  # (2*K,)
+        idx_s, idx_e = idx[:K], idx[K:]
+        f_cylinder = torch.cat([
+            # The two end-cap surfaces.
+            torch.stack([idx_s[0].repeat(K-2), idx_s[1:-1], idx_s[2:]], dim=-1),
+            torch.stack([idx_e[0].repeat(K-2), idx_e[2:], idx_e[1:-1]], dim=-1),
+            # The side surface.
+            torch.stack([idx_e[:-1], idx_s[1:], idx_s[:-1]], dim=-1),
+            torch.stack([idx_e[:-1], idx_e[1:], idx_s[1:]], dim=-1),
+        ], dim=0)  # (4*K-6, 3)
+        f_cylinder = f_cylinder[None].repeat(N, 1, 1)  # (N, 4*K-6, 3)
+
+        # ===== Prepare the vertex colors. =====
+        if colors is not None:
+            c_cylinder = colors / 255.0  # (N, 3)
+            c_cylinder = c_cylinder[:, None].repeat(1, 2*K, 1)  # (N, 2*K, 3)
+        else:
+            c_cylinder = None
+
+        N, V = v_cylinder.shape[:2]
+        v_cylinder = v_cylinder.reshape(-1, 3)  # (N*(2*K), 3)
+
+        # ===== Manually offset the face indices before flattening. =====
+        f_cylinder = f_cylinder + torch.arange(0, N, device=device).unsqueeze(1).unsqueeze(1) * V
+        f_cylinder = f_cylinder.reshape(-1, 3)  # (N*(4*K-6), 3)
+        if c_cylinder is not None:
+            c_cylinder = c_cylinder.reshape(-1, 3)  # (N*(2*K), 3)
+
+        self.add_mesh(
+            vertices      = v_cylinder,
+            vertex_colors = c_cylinder,
+            faces         = f_cylinder,
+            name          = name,
+        )
\ No newline at end of file
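For reference, the per-line mesh size implied by the shape comments in the `add_lines` override above, with the default `resolution = 4` (illustrative arithmetic only):

```python
# Per-line size of the poly-cylinder mesh built in `add_lines`.
resolution = 4
K = resolution + 1                      # ring points; the first and the last coincide (angle 0 and 2*pi)
n_vertices = 2 * K                      # one ring at the start point + one ring at the end point
n_faces = 2 * (K - 2) + 2 * (K - 1)     # two triangle-fan caps + the side quads split into triangles
print(n_vertices, n_faces)              # -> 10 14, i.e. 2*K vertices and 4*K - 6 faces per line
```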
diff --git a/lib/version.py b/lib/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc37b911090d58edb9b4a5eba6557137a44832d4
--- /dev/null
+++ b/lib/version.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+__version__ = 'v1.0'
+DEFAULT_HSMR_ROOT = 'data_inputs/released_models/HSMR-ViTH-r1d0'
\ No newline at end of file
diff --git a/requirements_part1.txt b/requirements_part1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..162970ea39a1b5bea9737f6e4a4ebc0d8de87fe5
--- /dev/null
+++ b/requirements_part1.txt
@@ -0,0 +1,24 @@
+braceexpand
+colorlog
+einops
+hydra-core
+imageio[ffmpeg]
+ipdb
+matplotlib
+numpy>=1.24.4
+omegaconf
+opencv_python
+pyrender
+pytorch_lightning
+Requests
+rich
+setuptools
+smplx
+timm
+torch
+tqdm
+trimesh
+webdataset
+wis3d
+yacs
+gradio
\ No newline at end of file
diff --git a/requirements_part2.txt b/requirements_part2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3134d06ff9136f12c4daf3e77de8131cbedc7db6
--- /dev/null
+++ b/requirements_part2.txt
@@ -0,0 +1,2 @@
+git+https://github.com/facebookresearch/detectron2.git
+git+https://github.com/mattloper/chumpy.git
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..287ccc9f960d2e646a920ce5e1e07ef0096a8eaa
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,15 @@
+from setuptools import setup, find_packages
+
+
+# import `__version__`
+with open('lib/version.py') as f:
+    exec(f.read())
+
+setup(
+    name         = 'lib',
+    version      = __version__,  # type: ignore
+    author       = 'Yan Xia',
+    author_email = 'isshikihugh@gmail.com',
+    description  = 'Official implementation of HSMR.',
+    packages     = find_packages(),
+)
\ No newline at end of file
diff --git a/tools/.gitignore b/tools/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..58f8d2d185d1d3ea32760f85cdd1d87aeaa602c1
--- /dev/null
+++ b/tools/.gitignore
@@ -0,0 +1 @@
+*.flag
\ No newline at end of file
diff --git a/tools/prepare_data.sh b/tools/prepare_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..20a7c1d3864ea59afa03569d8180ad15ce746c0d
--- /dev/null
+++ b/tools/prepare_data.sh
@@ -0,0 +1,10 @@
+cd data_inputs
+echo "[SCRIPT-LOG] Downloading body models..." ; date
+wget -q -c 'https://huggingface.co/IsshikiHugh/HSMR-data_inputs/resolve/main/body_models.tar.gz' -O body_models.tar.gz
+tar -xzf body_models.tar.gz
+
+mkdir -p released_models
+cd released_models
+echo "[SCRIPT-LOG] Downloading checkpoints..." ; date
+wget -q -c 'https://huggingface.co/IsshikiHugh/HSMR-data_inputs/resolve/main/released_models/HSMR-ViTH-r1d0.tar.gz' -O HSMR-ViTH-r1d0.tar.gz
+tar -xzf HSMR-ViTH-r1d0.tar.gz
\ No newline at end of file
diff --git a/tools/service.py b/tools/service.py
new file mode 100644
index 0000000000000000000000000000000000000000..e806bbf93f15d84c22cd2618edbcebce83255c92
--- /dev/null
+++ b/tools/service.py
@@ -0,0 +1,12 @@
+from lib.kits.gradio import *
+
+import os
+
+os.environ['PYOPENGL_PLATFORM'] = 'egl'
+
+
+if __name__ == '__main__':
+    # Start serving.
+    backend = HSMRBackend(device='cpu')
+    hsmr_service = HSMRService(backend)
+    hsmr_service.serve()
\ No newline at end of file
diff --git a/tools/start.sh b/tools/start.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7de99574e0d84a456c2e0cbb7d54919a6efee446
--- /dev/null
+++ b/tools/start.sh
@@ -0,0 +1,7 @@
+bash tools/prepare_data.sh
+
+echo "[SCRIPT-LOG] Installing requirements..." ; date
+pip install -e .
+
+echo "[SCRIPT-LOG] Starting service..." ; date
+python tools/service.py
\ No newline at end of file
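For local experimentation, one might adapt `tools/service.py` to pick the device dynamically. This is a hedged sketch only: the explicit imports and the CUDA device string are unverified assumptions, since the script above relies on a star-import and pins `device='cpu'`.

```python
# Hypothetical local-run variant of tools/service.py (not part of the patch).
# Assumes `HSMRBackend` / `HSMRService` are importable from `lib.kits.gradio`
# and that the backend accepts a CUDA device string, which is not verified here.
import os

os.environ['PYOPENGL_PLATFORM'] = 'egl'

import torch
from lib.kits.gradio import HSMRBackend, HSMRService

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    backend = HSMRBackend(device=device)
    HSMRService(backend).serve()
```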