diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..32e0c12ec7fcceaf27417b9d03bc9381fdfc607a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo.gif filter=lfs diff=lfs merge=lfs -text
+assets/driving_audio/1.wav filter=lfs diff=lfs merge=lfs -text
+assets/driving_audio/2.wav filter=lfs diff=lfs merge=lfs -text
+assets/driving_audio/3.wav filter=lfs diff=lfs merge=lfs -text
+assets/driving_audio/4.wav filter=lfs diff=lfs merge=lfs -text
+assets/driving_audio/5.wav filter=lfs diff=lfs merge=lfs -text
+assets/driving_audio/6.wav filter=lfs diff=lfs merge=lfs -text
+assets/driving_video/1.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/driving_video/2.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/driving_video/3.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/driving_video/4.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/driving_video/5.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/driving_video/6.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/driving_video/7.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/driving_video/8.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/1.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/10.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/11.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/12.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/13.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/14.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/15.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/16.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/17.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/18.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/19.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/2.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/3.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/4.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/5.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/6.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/7.png filter=lfs diff=lfs merge=lfs -text
+assets/ref_images/8.png filter=lfs diff=lfs merge=lfs -text
+skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task filter=lfs diff=lfs merge=lfs -text
+skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7b84e00122f780a28245a696e8d92434bbdba387
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,38 @@
+---
+language:
+ - en
+ - zh
+license: other
+tasks:
+ - text-generation
+
+---
+
+
+
+
+# 声明与协议/Terms and Conditions
+
+## 声明
+
+我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
+
+我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
+
+We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment.
+
+We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility.
+
+## 协议
+
+社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。
+
+
+The community usage of Skywork model requires [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you must abide by terms and conditions within [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf).
+
+
+
+[《Skywork 模型社区许可协议》》]:https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf
+
+
+[skywork-opensource@kunlun-inc.com]: mailto:skywork-opensource@kunlun-inc.com
diff --git a/README.md b/README.md
index a9726b5d3abe3d0e4b3020ec09bc567996b86fd3..b514f643d6739fa8be3473eb6c1e805f3353a193 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,194 @@
----
-title: Skyreels Talking Head
-emoji: 😻
-colorFrom: yellow
-colorTo: green
-sdk: gradio
-sdk_version: 5.20.0
-app_file: app.py
-pinned: false
-license: mit
-short_description: audio to talking face
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+
+
+
+SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers
+
+
+
+
+
+ Skywork AI
+
+
+
+
+
+
+
+
+
+
+
+ 🔥 For more results, visit our homepage 🔥
+
+
+
+ 👋 Join our Discord
+
+
+
+This repo, named **SkyReels-A1**, contains the official PyTorch implementation of our paper [SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers](https://arxiv.org).
+
+
+## 🔥🔥🔥 News!!
+* Mar 4, 2025: 🔥 We release audio-driven portrait image animation pipeline.
+* Feb 18, 2025: 👋 We release the inference code and model weights of SkyReels-A1. [Download](https://huggingface.co/Skywork/SkyReels-A1)
+* Feb 18, 2025: 🎉 We have made our technical report available as open source. [Read](https://skyworkai.github.io/skyreels-a1.github.io/report.pdf)
+* Feb 18, 2025: 🔥 Our online demo of LipSync is available on SkyReels now! Try out [LipSync](https://www.skyreels.ai/home/tools/lip-sync?refer=navbar).
+* Feb 18, 2025: 🔥 We have open-sourced I2V video generation model [SkyReels-V1](https://github.com/SkyworkAI/SkyReels-V1). This is the first and most advanced open-source human-centric video foundation model.
+
+## 📑 TODO List
+- [x] Checkpoints
+- [x] Inference Code
+- [x] Web Demo (Gradio)
+- [x] Audio-driven Portrait Image Animation Pipeline
+- [ ] Inference Code for Long Videos
+- [ ] User-Level GPU Inference on RTX4090
+- [ ] ComfyUI
+
+
+## Getting Started 🏁
+
+### 1. Clone the code and prepare the environment 🛠️
+First git clone the repository with code:
+```bash
+git clone https://github.com/SkyworkAI/SkyReels-A1.git
+cd SkyReels-A1
+
+# create env using conda
+conda create -n skyreels-a1 python=3.10
+conda activate skyreels-a1
+```
+Then, install the remaining dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+
+### 2. Download pretrained weights 📥
+You can download the pretrained weights is from HuggingFace:
+```bash
+# !pip install -U "huggingface_hub[cli]"
+huggingface-cli download SkyReels-A1 --local-dir local_path --exclude "*.git*" "README.md" "docs"
+```
+
+The FLAME, mediapipe, and smirk models are located in the SkyReels-A1/extra_models folder.
+
+The directory structure of our SkyReels-A1 code is formulated as:
+```text
+pretrained_models
+├── FLAME
+├── SkyReels-A1-5B
+│ ├── pose_guider
+│ ├── scheduler
+│ ├── tokenizer
+│ ├── siglip-so400m-patch14-384
+│ ├── transformer
+│ ├── vae
+│ └── text_encoder
+├── mediapipe
+└── smirk
+
+```
+
+#### Download DiffposeTalk assets and pretrained weights (For Audio-driven)
+
+- We use [diffposetalk](https://github.com/DiffPoseTalk/DiffPoseTalk/tree/main) to generate flame coefficients from audio, thereby constructing motion signals.
+
+- Download the diffposetalk code and follow its README to download the weights and related data.
+
+- Then place them in the specified directory.
+
+```bash
+cp -r ${diffposetalk_root}/style pretrained_models/diffposetalk
+cp ${diffposetalk_root}/experiments/DPT/head-SA-hubert-WM/checkpoints/iter_0110000.pt pretrained_models/diffposetalk
+cp ${diffposetalk_root}/datasets/HDTF_TFHP/lmdb/stats_train.npz pretrained_models/diffposetalk
+```
+
+```text
+pretrained_models
+├── FLAME
+├── SkyReels-A1-5B
+├── mediapipe
+├── diffposetalk
+│ ├── style
+│ ├── iter_0110000.pt
+│ ├── states_train.npz
+└── smirk
+
+```
+
+
+### 3. Inference 🚀
+You can simply run the inference scripts as:
+```bash
+python inference.py
+
+# inference audio to video
+python inference_audio.py
+```
+
+If the script runs successfully, you will get an output mp4 file. This file includes the following results: driving video, input image or video, and generated result.
+
+
+## Gradio Interface 🤗
+
+We provide a [Gradio](https://huggingface.co/docs/hub/spaces-sdks-gradio) interface for a better experience, just run by:
+
+```bash
+python app.py
+```
+
+The graphical interactive interface is shown as below:
+
+
+
+
+## Metric Evaluation 👓
+
+We also provide all scripts for automatically calculating the metrics, including SimFace, FID, and L1 distance between expression and motion, reported in the paper.
+
+All codes can be found in the ```eval``` folder. After setting the video result path, run the following commands in sequence:
+
+```bash
+python arc_score.py
+python expression_score.py
+python pose_score.py
+```
+
+
+## Acknowledgements 💐
+We would like to thank the contributors of [CogvideoX](https://github.com/THUDM/CogVideo), [finetrainers](https://github.com/a-r-r-o-w/finetrainers) and [DiffPoseTalk](https://github.com/DiffPoseTalk/DiffPoseTalk)repositories, for their open research and contributions.
+
+## Citation 💖
+If you find SkyReels-A1 useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
+```bibtex
+@article{qiu2025skyreels,
+ title={SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers},
+ author={Qiu, Di and Fei, Zhengcong and Wang, Rui and Bai, Jialin and Yu, Changqian and Fan, Mingyuan and Chen, Guibin and Wen, Xiang},
+ journal={arXiv preprint arXiv:2502.10841},
+ year={2025}
+}
+```
+
+
+
diff --git a/assets/.DS_Store b/assets/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..6d0c4178f4d34a45ea150033d1c45b6b4bf93e4d
Binary files /dev/null and b/assets/.DS_Store differ
diff --git a/assets/demo.gif b/assets/demo.gif
new file mode 100644
index 0000000000000000000000000000000000000000..5ef3a1ee9335d30d587dc0cc74e57985f5007212
--- /dev/null
+++ b/assets/demo.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b8c13b7c718a9e2645dd4490dfe645a880781121b7d207ed43cc7cd3d0a35e4
+size 3084348
diff --git a/assets/driving_audio/1.wav b/assets/driving_audio/1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..78cce7171b8296782dc92f762481e77c2aef28c8
--- /dev/null
+++ b/assets/driving_audio/1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:38dab65a002455f4c186d4b0bde848c964415441d9636d94a48d5b32f23b0f6f
+size 575850
diff --git a/assets/driving_audio/2.wav b/assets/driving_audio/2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6c522e7f3451694be7961443509111eb28b61870
--- /dev/null
+++ b/assets/driving_audio/2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:15b998fd3fbabf22e9fde210f93df4f7b1647e19fe81b2d33b2b74470fea32b5
+size 3891278
diff --git a/assets/driving_audio/3.wav b/assets/driving_audio/3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6bc47d0786e84bd1ed62e91fb6e2b0e4fcdda186
--- /dev/null
+++ b/assets/driving_audio/3.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:448576f545e18c4cb33cb934a00c0bc3f331896eba4d1cb6a077c1b9382d0628
+size 910770
diff --git a/assets/driving_audio/4.wav b/assets/driving_audio/4.wav
new file mode 100644
index 0000000000000000000000000000000000000000..c675d5446621777c1495d89b8055b3a06972f938
--- /dev/null
+++ b/assets/driving_audio/4.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3e0d37fc6e235b5a09eb4b7e3b0b5d5f566e2204c5c815c23ca2215dcbf9c93
+size 553038
diff --git a/assets/driving_audio/5.wav b/assets/driving_audio/5.wav
new file mode 100644
index 0000000000000000000000000000000000000000..62bb2e8ce2e997f30acac4b0b062d40b85a7c595
--- /dev/null
+++ b/assets/driving_audio/5.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95275c9299919e38e52789cb3af17ddc4691b7afea82f26c7edc640addce057d
+size 856142
diff --git a/assets/driving_audio/6.wav b/assets/driving_audio/6.wav
new file mode 100644
index 0000000000000000000000000000000000000000..b43f4a9b495396fec9af1de78c431b0591fccc53
--- /dev/null
+++ b/assets/driving_audio/6.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7558a976b8b64d214f33c65e503c77271d3be0cd116a00ddadcb2b2fc53a6396
+size 2641742
diff --git a/assets/driving_video/.DS_Store b/assets/driving_video/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..11a8fe3e8e406c47c95ddc456de2da406f26c42f
Binary files /dev/null and b/assets/driving_video/.DS_Store differ
diff --git a/assets/driving_video/1.mp4 b/assets/driving_video/1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..94cdbe9fc81151fe2d1cf2a44811f9eb8c9d53cc
--- /dev/null
+++ b/assets/driving_video/1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b7da4f10cf9e692ba8c75848bacceb3c4d30ee8d3b07719435560c44a8da6544
+size 306996
diff --git a/assets/driving_video/2.mp4 b/assets/driving_video/2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e950b6baf054ee1babe8c14e558b621acb32464a
--- /dev/null
+++ b/assets/driving_video/2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e795b7be655c4b5ae8cac0733a32e8d321ccebd13f2cac07cc15dfc8f61a547
+size 2875843
diff --git a/assets/driving_video/3.mp4 b/assets/driving_video/3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3a6da66fbf20cecfbac6b444ca962db29d757af2
--- /dev/null
+++ b/assets/driving_video/3.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02f5ee85c1028c9673c70682b533a4f22e203173eddd40de42bad0cb57f18abb
+size 1020948
diff --git a/assets/driving_video/4.mp4 b/assets/driving_video/4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b214ba6f7d90f19676b8a11b51a1f6e82e30ab3b
--- /dev/null
+++ b/assets/driving_video/4.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f7ddbb17b198a580f658d57f4d83bee7489aa4d8a677f2c45b76b1ec01ae461
+size 215144
diff --git a/assets/driving_video/5.mp4 b/assets/driving_video/5.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..62399c48a6291aa29d9cbdf847098e145c78a8e7
--- /dev/null
+++ b/assets/driving_video/5.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9637fea5ef83b494a0aa8b7c526ae1efc6ec94d79dfa94381de8d6f38eec238e
+size 556047
diff --git a/assets/driving_video/6.mp4 b/assets/driving_video/6.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8008dacea2798956f250ff3ecc44e13f4c7a900a
--- /dev/null
+++ b/assets/driving_video/6.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac7ee3c2419046f11dc230b6db33c2391a98334eba2b1d773e7eb9627992622f
+size 1064930
diff --git a/assets/driving_video/7.mp4 b/assets/driving_video/7.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a91566ab36f2b2048413646a390a3c30d1973d13
--- /dev/null
+++ b/assets/driving_video/7.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1dc94c1fec7ef7dc831c8a49f0e1788ae568812cb68e62f6875d9070f573d02a
+size 187263
diff --git a/assets/driving_video/8.mp4 b/assets/driving_video/8.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2c71946b067bfba717c670eb89514a16d7108e45
--- /dev/null
+++ b/assets/driving_video/8.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3047ba66296d96b8a4584e412e61493d7bc0fa5149c77b130e7feea375e698bd
+size 232859
diff --git a/assets/logo.png b/assets/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..b2d39ed9924bce5f0c69910101ba6d8fcd8da579
Binary files /dev/null and b/assets/logo.png differ
diff --git a/assets/ref_images/1.png b/assets/ref_images/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..9a5b8fd2cfd7d72c49a3d01a17d12f6ef85fbea5
--- /dev/null
+++ b/assets/ref_images/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93429c6e7408723b04f3681cc06ac98072f8ce4fd69476ee612466a335ca152c
+size 1066069
diff --git a/assets/ref_images/10.png b/assets/ref_images/10.png
new file mode 100644
index 0000000000000000000000000000000000000000..ee87e3fb10bf346c6aad9f030e1a7a67e29fa368
--- /dev/null
+++ b/assets/ref_images/10.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef7456fd5eb3b31584f0933d1b71c25f92a8a9cb428466c1c4daf4eede2db9d3
+size 507817
diff --git a/assets/ref_images/11.png b/assets/ref_images/11.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0407d5b0fef1de726289c24355ff0abf0adf3ab
--- /dev/null
+++ b/assets/ref_images/11.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:99bfeeecbefa2bf408d0f15688ae89fa4c71f881d78baced1591bef128367efc
+size 633895
diff --git a/assets/ref_images/12.png b/assets/ref_images/12.png
new file mode 100644
index 0000000000000000000000000000000000000000..4674e0538b1aeae0a72374b628d2e286501441e1
--- /dev/null
+++ b/assets/ref_images/12.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c258cba0979585f3fac4d63d9ca0fc3e51604afde62a272f069146ae43d1a996
+size 793368
diff --git a/assets/ref_images/13.png b/assets/ref_images/13.png
new file mode 100644
index 0000000000000000000000000000000000000000..0717d4461d27bdb5c32b82c6ee1b3d8a6ad70ad1
--- /dev/null
+++ b/assets/ref_images/13.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c31191bc70144def9c0de388483d0a9257b0e4eb72128474232bbaa234f5a0a5
+size 632959
diff --git a/assets/ref_images/14.png b/assets/ref_images/14.png
new file mode 100644
index 0000000000000000000000000000000000000000..29bec6a4925481b5fb39f0c8eaef3f0573ca9ca7
--- /dev/null
+++ b/assets/ref_images/14.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8058fc784284c59f1954269638f1ad937ac35cf58563b935736d3f34e6355045
+size 517336
diff --git a/assets/ref_images/15.png b/assets/ref_images/15.png
new file mode 100644
index 0000000000000000000000000000000000000000..37cb45c7de1a6e4e33750c6c4ac4be1ee9862a97
--- /dev/null
+++ b/assets/ref_images/15.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c3e49512a2253b2a7291ad6b1636521e66b10050dba37a0b9d47c9a5666fb61
+size 641391
diff --git a/assets/ref_images/16.png b/assets/ref_images/16.png
new file mode 100644
index 0000000000000000000000000000000000000000..bc4eb676f61c02aa9f40cee5ed1b4bfc352056c3
--- /dev/null
+++ b/assets/ref_images/16.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e65a2f40f5f971b0e91023e774ce8aff56a1da723c1f8ffdfc5ec616690cde2
+size 391537
diff --git a/assets/ref_images/17.png b/assets/ref_images/17.png
new file mode 100644
index 0000000000000000000000000000000000000000..4a7bd72553bb60dce275188021e3ac5787a8c555
--- /dev/null
+++ b/assets/ref_images/17.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:202b2a66e87de425c55e223554942da71a4b0a27757bc2f90ec4c8d51133934b
+size 749924
diff --git a/assets/ref_images/18.png b/assets/ref_images/18.png
new file mode 100644
index 0000000000000000000000000000000000000000..240a83a00b04198c32225ba3bc727e5af04d757e
--- /dev/null
+++ b/assets/ref_images/18.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06a756c3e0a0b5d786428b0968126281c292e1df2c286cb683bac059821c0122
+size 182530
diff --git a/assets/ref_images/19.png b/assets/ref_images/19.png
new file mode 100644
index 0000000000000000000000000000000000000000..2b7613f3c9eed3de6ad293f2020adff8acf8bae5
--- /dev/null
+++ b/assets/ref_images/19.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:277fc3ecf3c0299f87cb59d056c5d484feb2fa7897c9d0f80ee0854eba2c3487
+size 283261
diff --git a/assets/ref_images/2.png b/assets/ref_images/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f705a503935b5eaa4772775349abd2895e7192b
--- /dev/null
+++ b/assets/ref_images/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c972790d52fc6adf7e5bcb4611720570e260f56c52f063acfea5e4d2f52c07f
+size 761861
diff --git a/assets/ref_images/20.png b/assets/ref_images/20.png
new file mode 100644
index 0000000000000000000000000000000000000000..854cd330826e1ff085264090f9f46f4bd1ceb223
Binary files /dev/null and b/assets/ref_images/20.png differ
diff --git a/assets/ref_images/3.png b/assets/ref_images/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..5f1d527021d64d647bdbf8542cd0fbd4865f2bd3
--- /dev/null
+++ b/assets/ref_images/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bce73675d41349d0792e9903d08ad12280d0e1b3af21e686720a7dac5dcaa649
+size 737200
diff --git a/assets/ref_images/4.png b/assets/ref_images/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..4b7a24775f1b93963cadc719a855ff07d1f595ba
--- /dev/null
+++ b/assets/ref_images/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03ff23c5be3ff225969ddd97a26971bab40af4cc6012f0f859971a12cd8e9003
+size 347775
diff --git a/assets/ref_images/5.png b/assets/ref_images/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..2028428e54a9634558a261503797441b8db965c5
--- /dev/null
+++ b/assets/ref_images/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b9c2279c99ef4f354fa9e2ea8f1751e8f35ed2ed937e5a2b0b3c918fb49f947
+size 375475
diff --git a/assets/ref_images/6.png b/assets/ref_images/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..55daeb4a5028c35376d73bd1f8f34c07cd07109c
--- /dev/null
+++ b/assets/ref_images/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d127961dece864d4000351c1c14a71d3c1bc54c51c2cce6d9dd1c74bdea0ec4c
+size 370310
diff --git a/assets/ref_images/7.png b/assets/ref_images/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..488fbfd17d57156905f31c5f19006faf32096f74
--- /dev/null
+++ b/assets/ref_images/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1e2c11b7f9832b2acbf454065b2beebf95f6817f623ee1fe56ff2fafc0caf1d
+size 542232
diff --git a/assets/ref_images/8.png b/assets/ref_images/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b4567bfae185e6c250983db8b5d4a3e1538c952
--- /dev/null
+++ b/assets/ref_images/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c8aa92c1bea3f5f0b1b3859b35ed801fc4022f064b3ebba09e621157a2ac4c6
+size 357691
diff --git a/diffposetalk/common.py b/diffposetalk/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..59ebbb510ec54d2a35790e0993156001f7e015a1
--- /dev/null
+++ b/diffposetalk/common.py
@@ -0,0 +1,46 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout=0.1, max_len=600):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ # vanilla sinusoidal encoding
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:, x.shape[1], :]
+ return self.dropout(x)
+
+
+def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'):
+ mask = torch.ones(T, S)
+ for i in range(T):
+ mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0
+ return (mask == 1).to(device=device)
+
+
+def pad_audio(audio, audio_unit=320, pad_threshold=80):
+ batch_size, audio_len = audio.shape
+ n_units = audio_len // audio_unit
+ side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
+ if side_len >= 0:
+ reflect_len = side_len // 2
+ replicate_len = side_len % 2
+ if reflect_len > 0:
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
+ if replicate_len > 0:
+ audio = F.pad(audio, (1, 1), mode='replicate')
+
+ return audio
diff --git a/diffposetalk/diff_talking_head.py b/diffposetalk/diff_talking_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9867bf79e0974c97f1e7e0468429ef1c2b8088d7
--- /dev/null
+++ b/diffposetalk/diff_talking_head.py
@@ -0,0 +1,536 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .common import PositionalEncoding, enc_dec_mask, pad_audio
+
+
+class DiffusionSchedule(nn.Module):
+ def __init__(self, num_steps, mode='linear', beta_1=1e-4, beta_T=0.02, s=0.008):
+ super().__init__()
+
+ if mode == 'linear':
+ betas = torch.linspace(beta_1, beta_T, num_steps)
+ elif mode == 'quadratic':
+ betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
+ elif mode == 'sigmoid':
+ betas = torch.sigmoid(torch.linspace(-5, 5, num_steps)) * (beta_T - beta_1) + beta_1
+ elif mode == 'cosine':
+ steps = num_steps + 1
+ x = torch.linspace(0, num_steps, steps)
+ alpha_bars = torch.cos(((x / num_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+ alpha_bars = alpha_bars / alpha_bars[0]
+ betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
+ betas = torch.clip(betas, 0.0001, 0.999)
+ else:
+ raise ValueError(f'Unknown diffusion schedule {mode}!')
+ betas = torch.cat([torch.zeros(1), betas], dim=0) # Padding beta_0 = 0
+
+ alphas = 1 - betas
+ log_alphas = torch.log(alphas)
+ for i in range(1, log_alphas.shape[0]): # 1 to T
+ log_alphas[i] += log_alphas[i - 1]
+ alpha_bars = log_alphas.exp()
+
+ sigmas_flex = torch.sqrt(betas)
+ sigmas_inflex = torch.zeros_like(sigmas_flex)
+ for i in range(1, sigmas_flex.shape[0]):
+ sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[i]
+ sigmas_inflex = torch.sqrt(sigmas_inflex)
+
+ self.num_steps = num_steps
+ self.register_buffer('betas', betas)
+ self.register_buffer('alphas', alphas)
+ self.register_buffer('alpha_bars', alpha_bars)
+ self.register_buffer('sigmas_flex', sigmas_flex)
+ self.register_buffer('sigmas_inflex', sigmas_inflex)
+
+ def uniform_sample_t(self, batch_size):
+ ts = torch.randint(1, self.num_steps + 1, (batch_size,))
+ return ts.tolist()
+
+ def get_sigmas(self, t, flexibility=0):
+ assert 0 <= flexibility <= 1
+ sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility)
+ return sigmas
+
+
+class DiffTalkingHead(nn.Module):
+ def __init__(self, args, device='cuda'):
+ super().__init__()
+
+ # Model parameters
+ self.target = args.target
+ self.architecture = args.architecture
+ self.use_style = args.style_enc_ckpt is not None
+
+ self.motion_feat_dim = 50
+ if args.rot_repr == 'aa':
+ self.motion_feat_dim += 1 if args.no_head_pose else 4
+ else:
+ raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
+
+ self.fps = args.fps
+ self.n_motions = args.n_motions
+ self.n_prev_motions = args.n_prev_motions
+ if self.use_style:
+ self.style_feat_dim = args.d_style
+
+ # Audio encoder
+ self.audio_model = args.audio_model
+ if self.audio_model == 'wav2vec2':
+ from .wav2vec2 import Wav2Vec2Model
+ self.audio_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
+ # wav2vec 2.0 weights initialization
+ self.audio_encoder.feature_extractor._freeze_parameters()
+ elif self.audio_model == 'hubert':
+ from .hubert import HubertModel
+ self.audio_encoder = HubertModel.from_pretrained('facebook/hubert-base-ls960')
+ self.audio_encoder.feature_extractor._freeze_parameters()
+
+ frozen_layers = [0, 1]
+ for name, param in self.audio_encoder.named_parameters():
+ if name.startswith("feature_projection"):
+ param.requires_grad = False
+ if name.startswith("encoder.layers"):
+ layer = int(name.split(".")[2])
+ if layer in frozen_layers:
+ param.requires_grad = False
+ else:
+ raise ValueError(f'Unknown audio model {self.audio_model}!')
+
+ if args.architecture == 'decoder':
+ self.audio_feature_map = nn.Linear(768, args.feature_dim)
+ self.start_audio_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, args.feature_dim))
+ else:
+ raise ValueError(f'Unknown architecture {args.architecture}!')
+
+ self.start_motion_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, self.motion_feat_dim))
+
+ # Diffusion model
+ self.denoising_net = DenoisingNetwork(args, device)
+ # diffusion schedule
+ self.diffusion_sched = DiffusionSchedule(args.n_diff_steps, args.diff_schedule)
+
+ # Classifier-free settings
+ self.cfg_mode = args.cfg_mode
+ guiding_conditions = args.guiding_conditions.split(',') if args.guiding_conditions else []
+ self.guiding_conditions = [cond for cond in guiding_conditions if cond in ['style', 'audio']]
+ if 'style' in self.guiding_conditions:
+ if not self.use_style:
+ raise ValueError('Cannot use style guiding without enabling it!')
+ self.null_style_feat = nn.Parameter(torch.randn(1, 1, self.style_feat_dim))
+ if 'audio' in self.guiding_conditions:
+ audio_feat_dim = args.feature_dim
+ self.null_audio_feat = nn.Parameter(torch.randn(1, 1, audio_feat_dim))
+
+ self.to(device)
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, motion_feat, audio_or_feat, shape_feat, style_feat=None,
+ prev_motion_feat=None, prev_audio_feat=None, time_step=None, indicator=None):
+ """
+ Args:
+ motion_feat: (N, L, d_coef) motion coefficients or features
+ audio_or_feat: (N, L_audio) raw audio or audio feature
+ shape_feat: (N, d_shape) or (N, 1, d_shape)
+ style_feat: (N, d_style)
+ prev_motion_feat: (N, n_prev_motions, d_motion) previous motion coefficients or feature
+ prev_audio_feat: (N, n_prev_motions, d_audio) previous audio features
+ time_step: (N,)
+ indicator: (N, L) 0/1 indicator of real (unpadded) motion coefficients
+
+ Returns:
+ motion_feat_noise: (N, L, d_motion)
+ """
+ if self.use_style:
+ assert style_feat is not None, 'Missing style features!'
+
+ batch_size = motion_feat.shape[0]
+
+ if audio_or_feat.ndim == 2:
+ # Extract audio features
+ assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
+ f'Incorrect audio length {audio_or_feat.shape[1]}'
+ audio_feat_saved = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
+ elif audio_or_feat.ndim == 3:
+ assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
+ audio_feat_saved = audio_or_feat
+ else:
+ raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
+ audio_feat = audio_feat_saved.clone()
+
+ if shape_feat.ndim == 2:
+ shape_feat = shape_feat.unsqueeze(1) # (N, 1, d_shape)
+ if style_feat is not None and style_feat.ndim == 2:
+ style_feat = style_feat.unsqueeze(1) # (N, 1, d_style)
+
+ if prev_motion_feat is None:
+ prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
+ if prev_audio_feat is None:
+ # (N, n_prev_motions, feature_dim)
+ prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
+
+ # Classifier-free guidance
+ if len(self.guiding_conditions) > 0:
+ assert len(self.guiding_conditions) <= 2, 'Only support 1 or 2 CFG conditions!'
+ if len(self.guiding_conditions) == 1 or self.cfg_mode == 'independent':
+ null_cond_prob = 0.5 if len(self.guiding_conditions) >= 2 else 0.1
+ if 'style' in self.guiding_conditions:
+ mask_style = torch.rand(batch_size, device=self.device) < null_cond_prob
+ style_feat = torch.where(mask_style.view(-1, 1, 1),
+ self.null_style_feat.expand(batch_size, -1, -1),
+ style_feat)
+ if 'audio' in self.guiding_conditions:
+ mask_audio = torch.rand(batch_size, device=self.device) < null_cond_prob
+ audio_feat = torch.where(mask_audio.view(-1, 1, 1),
+ self.null_audio_feat.expand(batch_size, self.n_motions, -1),
+ audio_feat)
+ else:
+ # len(self.guiding_conditions) > 1 and self.cfg_mode == 'incremental'
+ # full (0.45), w/o style (0.45), w/o style or audio (0.1)
+ mask_flag = torch.rand(batch_size, device=self.device)
+ if 'style' in self.guiding_conditions:
+ mask_style = mask_flag > 0.55
+ style_feat = torch.where(mask_style.view(-1, 1, 1),
+ self.null_style_feat.expand(batch_size, -1, -1),
+ style_feat)
+ if 'audio' in self.guiding_conditions:
+ mask_audio = mask_flag > 0.9
+ audio_feat = torch.where(mask_audio.view(-1, 1, 1),
+ self.null_audio_feat.expand(batch_size, self.n_motions, -1),
+ audio_feat)
+
+ if style_feat is None:
+ # The model only accepts audio and shape features, i.e., self.use_style = False
+ person_feat = shape_feat
+ else:
+ person_feat = torch.cat([shape_feat, style_feat], dim=-1)
+
+ if time_step is None:
+ # Sample time step
+ time_step = self.diffusion_sched.uniform_sample_t(batch_size) # (N,)
+
+ # The forward diffusion process
+ alpha_bar = self.diffusion_sched.alpha_bars[time_step] # (N,)
+ c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) # (N, 1, 1)
+ c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) # (N, 1, 1)
+
+ eps = torch.randn_like(motion_feat) # (N, L, d_motion)
+ motion_feat_noisy = c0 * motion_feat + c1 * eps
+
+ # The reverse diffusion process
+ motion_feat_target = self.denoising_net(motion_feat_noisy, audio_feat, person_feat,
+ prev_motion_feat, prev_audio_feat, time_step, indicator)
+
+ return eps, motion_feat_target, motion_feat.detach(), audio_feat_saved.detach()
+
+ def extract_audio_feature(self, audio, frame_num=None):
+ frame_num = frame_num or self.n_motions
+
+ # # Strategy 1: resample during audio feature extraction
+ # hidden_states = self.audio_encoder(pad_audio(audio), self.fps, frame_num=frame_num).last_hidden_state # (N, L, 768)
+
+ # Strategy 2: resample after audio feature extraction (BackResample)
+ hidden_states = self.audio_encoder(pad_audio(audio), self.fps,
+ frame_num=frame_num * 2).last_hidden_state # (N, 2L, 768)
+ hidden_states = hidden_states.transpose(1, 2) # (N, 768, 2L)
+ hidden_states = F.interpolate(hidden_states, size=frame_num, align_corners=False, mode='linear') # (N, 768, L)
+ hidden_states = hidden_states.transpose(1, 2) # (N, L, 768)
+
+ audio_feat = self.audio_feature_map(hidden_states) # (N, L, feature_dim)
+ return audio_feat
+
+ @torch.no_grad()
+ def sample(self, audio_or_feat, shape_feat, style_feat=None, prev_motion_feat=None, prev_audio_feat=None,
+ motion_at_T=None, indicator=None, cfg_mode=None, cfg_cond=None, cfg_scale=1.15, flexibility=0,
+ dynamic_threshold=None, ret_traj=False):
+ # Check and convert inputs
+ batch_size = audio_or_feat.shape[0]
+
+ # Check CFG conditions
+ if cfg_mode is None: # Use default CFG mode
+ cfg_mode = self.cfg_mode
+ if cfg_cond is None: # Use default CFG conditions
+ cfg_cond = self.guiding_conditions
+ cfg_cond = [c for c in cfg_cond if c in ['audio', 'style']]
+
+ if not isinstance(cfg_scale, list):
+ cfg_scale = [cfg_scale] * len(cfg_cond)
+
+ # sort cfg_cond and cfg_scale
+ if len(cfg_cond) > 0:
+ cfg_cond, cfg_scale = zip(*sorted(zip(cfg_cond, cfg_scale), key=lambda x: ['audio', 'style'].index(x[0])))
+ else:
+ cfg_cond, cfg_scale = [], []
+
+ if 'style' in cfg_cond:
+ assert self.use_style and style_feat is not None
+
+ if self.use_style:
+ if style_feat is None: # use null style feature
+ style_feat = self.null_style_feat.expand(batch_size, -1, -1)
+ else:
+ assert style_feat is None, 'This model does not support style feature input!'
+
+ if audio_or_feat.ndim == 2:
+ # Extract audio features
+ assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
+ f'Incorrect audio length {audio_or_feat.shape[1]}'
+ audio_feat = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
+ elif audio_or_feat.ndim == 3:
+ assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
+ audio_feat = audio_or_feat
+ else:
+ raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
+
+ if shape_feat.ndim == 2:
+ shape_feat = shape_feat.unsqueeze(1) # (N, 1, d_shape)
+ if style_feat is not None and style_feat.ndim == 2:
+ style_feat = style_feat.unsqueeze(1) # (N, 1, d_style)
+
+ if prev_motion_feat is None:
+ prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
+ if prev_audio_feat is None:
+ # (N, n_prev_motions, feature_dim)
+ prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
+
+ if motion_at_T is None:
+ motion_at_T = torch.randn((batch_size, self.n_motions, self.motion_feat_dim)).to(self.device)
+
+ # Prepare input for the reverse diffusion process (including optional classifier-free guidance)
+ if 'audio' in cfg_cond:
+ audio_feat_null = self.null_audio_feat.expand(batch_size, self.n_motions, -1)
+ else:
+ audio_feat_null = audio_feat
+
+ if 'style' in cfg_cond:
+ person_feat_null = torch.cat([shape_feat, self.null_style_feat.expand(batch_size, -1, -1)], dim=-1)
+ else:
+ if self.use_style:
+ person_feat_null = torch.cat([shape_feat, style_feat], dim=-1)
+ else:
+ person_feat_null = shape_feat
+
+ audio_feat_in = [audio_feat_null]
+ person_feat_in = [person_feat_null]
+ for cond in cfg_cond:
+ if cond == 'audio':
+ audio_feat_in.append(audio_feat)
+ person_feat_in.append(person_feat_null)
+ elif cond == 'style':
+ if cfg_mode == 'independent':
+ audio_feat_in.append(audio_feat_null)
+ elif cfg_mode == 'incremental':
+ audio_feat_in.append(audio_feat)
+ else:
+ raise NotImplementedError(f'Unknown cfg_mode {cfg_mode}')
+ person_feat_in.append(torch.cat([shape_feat, style_feat], dim=-1))
+
+ n_entries = len(audio_feat_in)
+ audio_feat_in = torch.cat(audio_feat_in, dim=0)
+ person_feat_in = torch.cat(person_feat_in, dim=0)
+ prev_motion_feat_in = torch.cat([prev_motion_feat] * n_entries, dim=0)
+ prev_audio_feat_in = torch.cat([prev_audio_feat] * n_entries, dim=0)
+ indicator_in = torch.cat([indicator] * n_entries, dim=0) if indicator is not None else None
+
+ traj = {self.diffusion_sched.num_steps: motion_at_T}
+ for t in range(self.diffusion_sched.num_steps, 0, -1):
+ if t > 1:
+ z = torch.randn_like(motion_at_T)
+ else:
+ z = torch.zeros_like(motion_at_T)
+
+ alpha = self.diffusion_sched.alphas[t]
+ alpha_bar = self.diffusion_sched.alpha_bars[t]
+ alpha_bar_prev = self.diffusion_sched.alpha_bars[t - 1]
+ sigma = self.diffusion_sched.get_sigmas(t, flexibility)
+
+ motion_at_t = traj[t]
+ motion_in = torch.cat([motion_at_t] * n_entries, dim=0)
+ step_in = torch.tensor([t] * batch_size, device=self.device)
+ step_in = torch.cat([step_in] * n_entries, dim=0)
+
+ results = self.denoising_net(motion_in, audio_feat_in, person_feat_in, prev_motion_feat_in,
+ prev_audio_feat_in, step_in, indicator_in)
+
+ # Apply thresholding if specified
+ if dynamic_threshold:
+ dt_ratio, dt_min, dt_max = dynamic_threshold
+ abs_results = results[:, -self.n_motions:].reshape(batch_size * n_entries, -1).abs()
+ s = torch.quantile(abs_results, dt_ratio, dim=1)
+ s = torch.clamp(s, min=dt_min, max=dt_max)
+ s = s[..., None, None]
+ results = torch.clamp(results, min=-s, max=s)
+
+ results = results.chunk(n_entries)
+
+ # Unconditional target (CFG) or the conditional target (non-CFG)
+ target_theta = results[0][:, -self.n_motions:]
+ # Classifier-free Guidance (optional)
+ for i in range(0, n_entries - 1):
+ if cfg_mode == 'independent':
+ target_theta += cfg_scale[i] * (
+ results[i + 1][:, -self.n_motions:] - results[0][:, -self.n_motions:])
+ elif cfg_mode == 'incremental':
+ target_theta += cfg_scale[i] * (
+ results[i + 1][:, -self.n_motions:] - results[i][:, -self.n_motions:])
+ else:
+ raise NotImplementedError(f'Unknown cfg_mode {cfg_mode}')
+
+ if self.target == 'noise':
+ c0 = 1 / torch.sqrt(alpha)
+ c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
+ motion_next = c0 * (motion_at_t - c1 * target_theta) + sigma * z
+ elif self.target == 'sample':
+ c0 = (1 - alpha_bar_prev) * torch.sqrt(alpha) / (1 - alpha_bar)
+ c1 = (1 - alpha) * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar)
+ motion_next = c0 * motion_at_t + c1 * target_theta + sigma * z
+ else:
+ raise ValueError('Unknown target type: {}'.format(self.target))
+
+ traj[t - 1] = motion_next.detach() # Stop gradient and save trajectory.
+ traj[t] = traj[t].cpu() # Move previous output to CPU memory.
+ if not ret_traj:
+ del traj[t]
+
+ if ret_traj:
+ return traj, motion_at_T, audio_feat
+ else:
+ return traj[0], motion_at_T, audio_feat
+
+
+class DenoisingNetwork(nn.Module):
+ def __init__(self, args, device='cuda'):
+ super().__init__()
+
+ # Model parameters
+ self.use_style = args.style_enc_ckpt is not None
+ self.motion_feat_dim = 50
+ if args.rot_repr == 'aa':
+ self.motion_feat_dim += 1 if args.no_head_pose else 4
+ else:
+ raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
+ self.shape_feat_dim = 100
+ if self.use_style:
+ self.style_feat_dim = args.d_style
+ self.person_feat_dim = self.shape_feat_dim + self.style_feat_dim
+ else:
+ self.person_feat_dim = self.shape_feat_dim
+ self.use_indicator = args.use_indicator
+
+ # Transformer
+ self.architecture = args.architecture
+ self.feature_dim = args.feature_dim
+ self.n_heads = args.n_heads
+ self.n_layers = args.n_layers
+ self.mlp_ratio = args.mlp_ratio
+ self.align_mask_width = args.align_mask_width
+ self.use_learnable_pe = not args.no_use_learnable_pe
+ # sequence length
+ self.n_prev_motions = args.n_prev_motions
+ self.n_motions = args.n_motions
+
+ # Temporal embedding for the diffusion time step
+ self.TE = PositionalEncoding(self.feature_dim, max_len=args.n_diff_steps + 1)
+ self.diff_step_map = nn.Sequential(
+ nn.Linear(self.feature_dim, self.feature_dim),
+ nn.GELU(),
+ nn.Linear(self.feature_dim, self.feature_dim)
+ )
+
+ if self.use_learnable_pe:
+ # Learnable positional encoding
+ self.PE = nn.Parameter(torch.randn(1, 1 + self.n_prev_motions + self.n_motions, self.feature_dim))
+ else:
+ self.PE = PositionalEncoding(self.feature_dim)
+
+ self.person_proj = nn.Linear(self.person_feat_dim, self.feature_dim)
+
+ # Transformer decoder
+ if self.architecture == 'decoder':
+ self.feature_proj = nn.Linear(self.motion_feat_dim + (1 if self.use_indicator else 0),
+ self.feature_dim)
+ decoder_layer = nn.TransformerDecoderLayer(
+ d_model=self.feature_dim, nhead=self.n_heads, dim_feedforward=self.mlp_ratio * self.feature_dim,
+ activation='gelu', batch_first=True
+ )
+ self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=self.n_layers)
+ if self.align_mask_width > 0:
+ motion_len = self.n_prev_motions + self.n_motions
+ alignment_mask = enc_dec_mask(motion_len, motion_len, 1, self.align_mask_width - 1)
+ alignment_mask = F.pad(alignment_mask, (0, 0, 1, 0), value=False)
+ self.register_buffer('alignment_mask', alignment_mask)
+ else:
+ self.alignment_mask = None
+ else:
+ raise ValueError(f'Unknown architecture: {self.architecture}')
+
+ # Motion decoder
+ self.motion_dec = nn.Sequential(
+ nn.Linear(self.feature_dim, self.feature_dim // 2),
+ nn.GELU(),
+ nn.Linear(self.feature_dim // 2, self.motion_feat_dim)
+ )
+
+ self.to(device)
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, motion_feat, audio_feat, person_feat, prev_motion_feat, prev_audio_feat, step, indicator=None):
+ """
+ Args:
+ motion_feat: (N, L, d_motion). Noisy motion feature
+ audio_feat: (N, L, feature_dim)
+ person_feat: (N, 1, d_person)
+ prev_motion_feat: (N, L_p, d_motion). Padded previous motion coefficients or feature
+ prev_audio_feat: (N, L_p, d_audio). Padded previous motion coefficients or feature
+ step: (N,)
+ indicator: (N, L). 0/1 indicator for the real (unpadded) motion feature
+
+ Returns:
+ motion_feat_target: (N, L_p + L, d_motion)
+ """
+ # Diffusion time step embedding
+ diff_step_embedding = self.diff_step_map(self.TE.pe[0, step]).unsqueeze(1) # (N, 1, diff_step_dim)
+
+ person_feat = self.person_proj(person_feat) # (N, 1, feature_dim)
+ person_feat = person_feat + diff_step_embedding
+
+ if indicator is not None:
+ indicator = torch.cat([torch.zeros((indicator.shape[0], self.n_prev_motions), device=indicator.device),
+ indicator], dim=1) # (N, L_p + L)
+ indicator = indicator.unsqueeze(-1) # (N, L_p + L, 1)
+
+ # Concat features and embeddings
+ if self.architecture == 'decoder':
+ feats_in = torch.cat([prev_motion_feat, motion_feat], dim=1) # (N, L_p + L, d_motion)
+ else:
+ raise ValueError(f'Unknown architecture: {self.architecture}')
+ if self.use_indicator:
+ feats_in = torch.cat([feats_in, indicator], dim=-1) # (N, L_p + L, d_motion + d_audio + 1)
+
+ feats_in = self.feature_proj(feats_in) # (N, L_p + L, feature_dim)
+ feats_in = torch.cat([person_feat, feats_in], dim=1) # (N, 1 + L_p + L, feature_dim)
+
+ if self.use_learnable_pe:
+ feats_in = feats_in + self.PE
+ else:
+ feats_in = self.PE(feats_in)
+
+ # Transformer
+ if self.architecture == 'decoder':
+ audio_feat_in = torch.cat([prev_audio_feat, audio_feat], dim=1) # (N, L_p + L, d_audio)
+ feat_out = self.transformer(feats_in, audio_feat_in, memory_mask=self.alignment_mask)
+ else:
+ raise ValueError(f'Unknown architecture: {self.architecture}')
+
+ # Decode predicted motion feature noise / sample
+ motion_feat_target = self.motion_dec(feat_out[:, 1:]) # (N, L_p + L, d_motion)
+
+ return motion_feat_target
diff --git a/diffposetalk/diffposetalk.py b/diffposetalk/diffposetalk.py
new file mode 100644
index 0000000000000000000000000000000000000000..b35f78edc484fcd088b1afbb2c6d00254977c45a
--- /dev/null
+++ b/diffposetalk/diffposetalk.py
@@ -0,0 +1,228 @@
+import math
+import tempfile
+import warnings
+from pathlib import Path
+
+import cv2
+import librosa
+import numpy as np
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+from pydantic import BaseModel
+
+from .diff_talking_head import DiffTalkingHead
+from .utils import NullableArgs, coef_dict_to_vertices, get_coef_dict
+from .utils.media import combine_video_and_audio, convert_video, reencode_audio
+
+warnings.filterwarnings('ignore', message='PySoundFile failed. Trying audioread instead.')
+
+class DiffPoseTalkConfig(BaseModel):
+ no_context_audio_feat: bool = False
+ model_path: str = "pretrained_models/diffposetalk/iter_0110000.pt" # DPT/head-SA-hubert-WM
+ coef_stats: str = "pretrained_models/diffposetalk/stats_train.npz"
+ style_path: str = "pretrained_models/diffposetalk/style/L4H4-T0.1-BS32/iter_0034000/normal.npy"
+ dynamic_threshold_ratio: float = 0.99
+ dynamic_threshold_min: float = 1.0
+ dynamic_threshold_max: float = 4.0
+ scale_audio: float = 1.15
+ scale_style: float = 3.0
+
+class DiffPoseTalk:
+ def __init__(self, config: DiffPoseTalkConfig = DiffPoseTalkConfig(), device="cuda"):
+ self.cfg = config
+ self.device = device
+
+ self.no_context_audio_feat = self.cfg.no_context_audio_feat
+ model_data = torch.load(self.cfg.model_path, map_location=self.device)
+
+ self.model_args = NullableArgs(model_data['args'])
+ self.model = DiffTalkingHead(self.model_args, self.device)
+ model_data['model'].pop('denoising_net.TE.pe')
+ self.model.load_state_dict(model_data['model'], strict=False)
+ self.model.to(self.device)
+ self.model.eval()
+
+ self.use_indicator = self.model_args.use_indicator
+ self.rot_repr = self.model_args.rot_repr
+ self.predict_head_pose = not self.model_args.no_head_pose
+ if self.model.use_style:
+ style_dir = Path(self.model_args.style_enc_ckpt)
+ style_dir = Path(*style_dir.with_suffix('').parts[-3::2])
+ self.style_dir = style_dir
+
+ # sequence
+ self.n_motions = self.model_args.n_motions
+ self.n_prev_motions = self.model_args.n_prev_motions
+ self.fps = self.model_args.fps
+ self.audio_unit = 16000. / self.fps # num of samples per frame
+ self.n_audio_samples = round(self.audio_unit * self.n_motions)
+ self.pad_mode = self.model_args.pad_mode
+
+ self.coef_stats = dict(np.load(self.cfg.coef_stats))
+ self.coef_stats = {k: torch.from_numpy(v).to(self.device) for k, v in self.coef_stats.items()}
+
+ if self.cfg.dynamic_threshold_ratio > 0:
+ self.dynamic_threshold = (self.cfg.dynamic_threshold_ratio, self.cfg.dynamic_threshold_min,
+ self.cfg.dynamic_threshold_max)
+ else:
+ self.dynamic_threshold = None
+
+
+ def infer_from_file(self, audio_path, shape_coef):
+ n_repetitions = 1
+ cfg_mode = None
+ cfg_cond = self.model.guiding_conditions
+ cfg_scale = []
+ for cond in cfg_cond:
+ if cond == 'audio':
+ cfg_scale.append(self.cfg.scale_audio)
+ elif cond == 'style':
+ cfg_scale.append(self.cfg.scale_style)
+
+ coef_dict = self.infer_coeffs(audio_path, shape_coef, self.cfg.style_path, n_repetitions,
+ cfg_mode, cfg_cond, cfg_scale, include_shape=True)
+ return coef_dict
+
+ @torch.no_grad()
+ def infer_coeffs(self, audio, shape_coef, style_feat=None, n_repetitions=1,
+ cfg_mode=None, cfg_cond=None, cfg_scale=1.15, include_shape=False):
+ # Returns dict[str, (n_repetitions, L, *)]
+ # Step 1: Preprocessing
+ # Preprocess audio
+ if isinstance(audio, (str, Path)):
+ audio, _ = librosa.load(audio, sr=16000, mono=True)
+ if isinstance(audio, np.ndarray):
+ audio = torch.from_numpy(audio).to(self.device)
+ assert audio.ndim == 1, 'Audio must be 1D tensor.'
+ audio_mean, audio_std = torch.mean(audio), torch.std(audio)
+ audio = (audio - audio_mean) / (audio_std + 1e-5)
+
+ # Preprocess shape coefficient
+ if isinstance(shape_coef, (str, Path)):
+ shape_coef = np.load(shape_coef)
+ if not isinstance(shape_coef, np.ndarray):
+ shape_coef = shape_coef['shape']
+ if isinstance(shape_coef, np.ndarray):
+ shape_coef = torch.from_numpy(shape_coef).float().to(self.device)
+ assert shape_coef.ndim <= 2, 'Shape coefficient must be 1D or 2D tensor.'
+ if shape_coef.ndim > 1:
+ # use the first frame as the shape coefficient
+ shape_coef = shape_coef[0]
+ original_shape_coef = shape_coef.clone()
+ if self.coef_stats is not None:
+ shape_coef = (shape_coef - self.coef_stats['shape_mean']) / self.coef_stats['shape_std']
+ shape_coef = shape_coef.unsqueeze(0).expand(n_repetitions, -1)
+
+ # Preprocess style feature if given
+ if style_feat is not None:
+ assert self.model.use_style
+ if isinstance(style_feat, (str, Path)):
+ style_feat = Path(style_feat)
+ if not style_feat.exists() and not style_feat.is_absolute():
+ style_feat = style_feat.parent / self.style_dir / style_feat.name
+ style_feat = np.load(style_feat)
+ if not isinstance(style_feat, np.ndarray):
+ style_feat = style_feat['style']
+ if isinstance(style_feat, np.ndarray):
+ style_feat = torch.from_numpy(style_feat).float().to(self.device)
+ assert style_feat.ndim == 1, 'Style feature must be 1D tensor.'
+ style_feat = style_feat.unsqueeze(0).expand(n_repetitions, -1)
+
+ # Step 2: Predict motion coef
+ # divide into synthesize units and do synthesize
+ clip_len = int(len(audio) / 16000 * self.fps)
+ stride = self.n_motions
+ if clip_len <= self.n_motions:
+ n_subdivision = 1
+ else:
+ n_subdivision = math.ceil(clip_len / stride)
+
+ # Prepare audio input
+ n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
+ n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
+ if n_padding_audio_samples > 0:
+ if self.pad_mode == 'zero':
+ padding_value = 0
+ elif self.pad_mode == 'replicate':
+ padding_value = audio[-1]
+ else:
+ raise ValueError(f'Unknown pad mode: {self.pad_mode}')
+ audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)
+
+ if not self.no_context_audio_feat:
+ audio_feat = self.model.extract_audio_feature(audio.unsqueeze(0), self.n_motions * n_subdivision)
+
+ # Generate `self.n_motions` new frames at one time, and use the last `self.n_prev_motions` frames
+ # from the previous generation as the initial motion condition
+ coef_list = []
+ for i in range(0, n_subdivision):
+ start_idx = i * stride
+ end_idx = start_idx + self.n_motions
+ indicator = torch.ones((n_repetitions, self.n_motions)).to(self.device) if self.use_indicator else None
+ if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0:
+ indicator[:, -n_padding_frames:] = 0
+ if not self.no_context_audio_feat:
+ audio_in = audio_feat[:, start_idx:end_idx].expand(n_repetitions, -1, -1)
+ else:
+ audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)
+
+ # generate motion coefficients
+ if i == 0:
+ # -> (N, L, d_motion=n_code_per_frame * code_dim)
+ motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
+ indicator=indicator, cfg_mode=cfg_mode,
+ cfg_cond=cfg_cond, cfg_scale=cfg_scale,
+ dynamic_threshold=self.dynamic_threshold)
+ else:
+ motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
+ prev_motion_feat, prev_audio_feat, noise,
+ indicator=indicator, cfg_mode=cfg_mode,
+ cfg_cond=cfg_cond, cfg_scale=cfg_scale,
+ dynamic_threshold=self.dynamic_threshold)
+ prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
+ prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]
+
+ motion_coef = motion_feat
+ if i == n_subdivision - 1 and n_padding_frames > 0:
+ motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames
+ coef_list.append(motion_coef)
+
+ motion_coef = torch.cat(coef_list, dim=1)
+
+ # Step 3: restore to coef dict
+ coef_dict = get_coef_dict(motion_coef, None, self.coef_stats, self.predict_head_pose, self.rot_repr)
+ if include_shape:
+ coef_dict['shape'] = original_shape_coef[None, None].expand(n_repetitions, motion_coef.shape[1], -1)
+ return self.coef_to_a1_format(coef_dict)
+
+ def coef_to_a1_format(self, coef_dict):
+ n_frames = coef_dict['exp'].shape[1]
+ new_coef_dict = []
+ for i in range(n_frames):
+
+ new_coef_dict.append({
+ "expression_params": coef_dict["exp"][0, i:i+1],
+ "jaw_params": coef_dict["pose"][0, i:i+1, 3:],
+ "eye_pose_params": torch.zeros(1, 6).type_as(coef_dict["pose"]),
+ "pose_params": coef_dict["pose"][0, i:i+1, :3],
+ "eyelid_params": None
+ })
+ return new_coef_dict
+
+
+
+
+
+ @staticmethod
+ def _pad_coef(coef, n_frames, elem_ndim=1):
+ if coef.ndim == elem_ndim:
+ coef = coef[None]
+ elem_shape = coef.shape[1:]
+ if coef.shape[0] >= n_frames:
+ new_coef = coef[:n_frames]
+ else:
+ # repeat the last coef frame
+ new_coef = torch.cat([coef, coef[[-1]].expand(n_frames - coef.shape[0], *elem_shape)], dim=0)
+ return new_coef # (n_frames, *elem_shape)
+
diff --git a/diffposetalk/hubert.py b/diffposetalk/hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..c98c8f040ae9905f8646c612bc63b5968f3737e5
--- /dev/null
+++ b/diffposetalk/hubert.py
@@ -0,0 +1,51 @@
+from transformers import HubertModel
+from transformers.modeling_outputs import BaseModelOutput
+
+from .wav2vec2 import linear_interpolation
+
+_CONFIG_FOR_DOC = 'HubertConfig'
+
+
+class HubertModel(HubertModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
+ output_hidden_states=None, return_dict=None, frame_num=None):
+ self.config.output_attentions = True
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values) # (N, C, L)
+ # Resample the audio feature @ 50 fps to `output_fps`.
+ if frame_num is not None:
+ extract_features_len = round(frame_num * 50 / output_fps)
+ extract_features = extract_features[:, :, :extract_features_len]
+ extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
+ extract_features = extract_features.transpose(1, 2) # (N, L, C)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+ hidden_states = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(hidden_states)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions, )
diff --git a/diffposetalk/utils/__init__.py b/diffposetalk/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55e5f844b4a2b2e06ed806c0747cbda88b29d084
--- /dev/null
+++ b/diffposetalk/utils/__init__.py
@@ -0,0 +1 @@
+from .common import *
diff --git a/diffposetalk/utils/common.py b/diffposetalk/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..90bbd46b3b66c61a64ec2b6373cac34c6b3eb63a
--- /dev/null
+++ b/diffposetalk/utils/common.py
@@ -0,0 +1,378 @@
+from functools import reduce
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+
+
+class NullableArgs:
+ def __init__(self, namespace):
+ for key, value in namespace.__dict__.items():
+ setattr(self, key, value)
+
+ def __getattr__(self, key):
+ # when an attribute lookup has not found the attribute
+ if key == 'align_mask_width':
+ if 'use_alignment_mask' in self.__dict__:
+ return 1 if self.use_alignment_mask else 0
+ else:
+ return 0
+ if key == 'no_head_pose':
+ return not self.predict_head_pose
+ if key == 'no_use_learnable_pe':
+ return not self.use_learnable_pe
+
+ return None
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def get_option_text(args, parser):
+ message = ''
+ for k, v in sorted(vars(args).items()):
+ comment = ''
+ default = parser.get_default(k)
+ if v != default:
+ comment = f'\t[default: {str(default)}]'
+ message += f'{str(k):>30}: {str(v):<30}{comment}\n'
+ return message
+
+
+def get_model_path(exp_name, iteration, model_type='DPT'):
+ exp_root_dir = Path(__file__).parent.parent / 'experiments' / model_type
+ exp_dir = exp_root_dir / exp_name
+ if not exp_dir.exists():
+ exp_dir = next(exp_root_dir.glob(f'{exp_name}*'))
+ model_path = exp_dir / f'checkpoints/iter_{iteration:07}.pt'
+ return model_path, exp_dir.relative_to(exp_root_dir)
+
+
+def get_pose_input(coef_dict, rot_repr, with_global_pose):
+ if rot_repr == 'aa':
+ pose_input = coef_dict['pose'] if with_global_pose else coef_dict['pose'][..., -3:]
+ # Remove mouth rotation round y, z axis
+ pose_input = pose_input[..., :-2]
+ else:
+ raise ValueError(f'Unknown rotation representation: {rot_repr}')
+ return pose_input
+
+
+def get_motion_coef(coef_dict, rot_repr, with_global_pose=False, norm_stats=None):
+ if norm_stats is not None:
+ if rot_repr == 'aa':
+ keys = ['exp', 'pose']
+ else:
+ raise ValueError(f'Unknown rotation representation {rot_repr}!')
+
+ coef_dict = {k: (coef_dict[k] - norm_stats[f'{k}_mean']) / norm_stats[f'{k}_std'] for k in keys}
+ pose_coef = get_pose_input(coef_dict, rot_repr, with_global_pose)
+ return torch.cat([coef_dict['exp'], pose_coef], dim=-1)
+
+
+def get_coef_dict(motion_coef, shape_coef=None, denorm_stats=None, with_global_pose=False, rot_repr='aa'):
+ coef_dict = {
+ 'exp': motion_coef[..., :50]
+ }
+ if rot_repr == 'aa':
+ if with_global_pose:
+ coef_dict['pose'] = motion_coef[..., 50:]
+ else:
+ placeholder = torch.zeros_like(motion_coef[..., :3])
+ coef_dict['pose'] = torch.cat([placeholder, motion_coef[..., -1:]], dim=-1)
+ # Add back rotation around y, z axis
+ coef_dict['pose'] = torch.cat([coef_dict['pose'], torch.zeros_like(motion_coef[..., :2])], dim=-1)
+ else:
+ raise ValueError(f'Unknown rotation representation {rot_repr}!')
+
+ if shape_coef is not None:
+ if motion_coef.ndim == 3:
+ if shape_coef.ndim == 2:
+ shape_coef = shape_coef.unsqueeze(1)
+ if shape_coef.shape[1] == 1:
+ shape_coef = shape_coef.expand(-1, motion_coef.shape[1], -1)
+
+ coef_dict['shape'] = shape_coef
+
+ if denorm_stats is not None:
+ coef_dict = {k: coef_dict[k] * denorm_stats[f'{k}_std'] + denorm_stats[f'{k}_mean'] for k in coef_dict}
+
+ if not with_global_pose:
+ if rot_repr == 'aa':
+ coef_dict['pose'][..., :3] = 0
+ else:
+ raise ValueError(f'Unknown rotation representation {rot_repr}!')
+
+ return coef_dict
+
+
+def coef_dict_to_vertices(coef_dict, flame, rot_repr='aa', ignore_global_rot=False, flame_batch_size=512):
+ shape = coef_dict['exp'].shape[:-1]
+ coef_dict = {k: v.view(-1, v.shape[-1]) for k, v in coef_dict.items()}
+ n_samples = reduce(lambda x, y: x * y, shape, 1)
+
+ # Convert to vertices
+ vert_list = []
+ for i in range(0, n_samples, flame_batch_size):
+ batch_coef_dict = {k: v[i:i + flame_batch_size] for k, v in coef_dict.items()}
+ if rot_repr == 'aa':
+ vert, _, _ = flame(
+ batch_coef_dict['shape'], batch_coef_dict['exp'], batch_coef_dict['pose'],
+ pose2rot=True, ignore_global_rot=ignore_global_rot, return_lm2d=False, return_lm3d=False)
+ else:
+ raise ValueError(f'Unknown rot_repr: {rot_repr}')
+ vert_list.append(vert)
+
+ vert_list = torch.cat(vert_list, dim=0) # (n_samples, 5023, 3)
+ vert_list = vert_list.view(*shape, -1, 3) # (..., 5023, 3)
+
+ return vert_list
+
+
+def compute_loss(args, is_starting_sample, shape_coef, motion_coef_gt, noise, target, prev_motion_coef, coef_stats,
+ flame, end_idx=None):
+ if args.criterion.lower() == 'l2':
+ criterion_func = F.mse_loss
+ elif args.criterion.lower() == 'l1':
+ criterion_func = F.l1_loss
+ else:
+ raise NotImplementedError(f'Criterion {args.criterion} not implemented.')
+
+ loss_vert = None
+ loss_vel = None
+ loss_smooth = None
+ loss_head_angle = None
+ loss_head_vel = None
+ loss_head_smooth = None
+ loss_head_trans_vel = None
+ loss_head_trans_accel = None
+ loss_head_trans = None
+ if args.target == 'noise':
+ loss_noise = criterion_func(noise, target[:, args.n_prev_motions:], reduction='none')
+ elif args.target == 'sample':
+ if is_starting_sample:
+ target = target[:, args.n_prev_motions:]
+ else:
+ motion_coef_gt = torch.cat([prev_motion_coef, motion_coef_gt], dim=1)
+ if args.no_constrain_prev:
+ target = torch.cat([prev_motion_coef, target[:, args.n_prev_motions:]], dim=1)
+
+ loss_noise = criterion_func(motion_coef_gt, target, reduction='none')
+
+ if args.l_vert > 0 or args.l_vel > 0:
+ coef_gt = get_coef_dict(motion_coef_gt, shape_coef, coef_stats, with_global_pose=False,
+ rot_repr=args.rot_repr)
+ coef_pred = get_coef_dict(target, shape_coef, coef_stats, with_global_pose=False,
+ rot_repr=args.rot_repr)
+ seq_len = target.shape[1]
+
+ if args.rot_repr == 'aa':
+ verts_gt, _, _ = flame(coef_gt['shape'].view(-1, 100), coef_gt['exp'].view(-1, 50),
+ coef_gt['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
+ verts_pred, _, _ = flame(coef_pred['shape'].view(-1, 100), coef_pred['exp'].view(-1, 50),
+ coef_pred['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
+ else:
+ raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
+ verts_gt = verts_gt.view(-1, seq_len, 5023, 3)
+ verts_pred = verts_pred.view(-1, seq_len, 5023, 3)
+
+ if args.l_vert > 0:
+ loss_vert = criterion_func(verts_gt, verts_pred, reduction='none')
+
+ if args.l_vel > 0:
+ vel_gt = verts_gt[:, 1:] - verts_gt[:, :-1]
+ vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
+ loss_vel = criterion_func(vel_gt, vel_pred, reduction='none')
+
+ if args.l_smooth > 0:
+ vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
+ loss_smooth = criterion_func(vel_pred[:, 1:], vel_pred[:, :-1], reduction='none')
+
+ # head pose
+ if not args.no_head_pose:
+ if args.rot_repr == 'aa':
+ head_pose_gt = motion_coef_gt[:, :, 50:53]
+ head_pose_pred = target[:, :, 50:53]
+ else:
+ raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
+
+ if args.l_head_angle > 0:
+ loss_head_angle = criterion_func(head_pose_gt, head_pose_pred, reduction='none')
+
+ if args.l_head_vel > 0:
+ head_vel_gt = head_pose_gt[:, 1:] - head_pose_gt[:, :-1]
+ head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
+ loss_head_vel = criterion_func(head_vel_gt, head_vel_pred, reduction='none')
+
+ if args.l_head_smooth > 0:
+ head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
+ loss_head_smooth = criterion_func(head_vel_pred[:, 1:], head_vel_pred[:, :-1], reduction='none')
+
+ if not is_starting_sample and args.l_head_trans > 0:
+ # # version 1: constrain both the predicted previous and current motions (x_{-3} ~ x_{2})
+ # head_pose_trans = head_pose_pred[:, args.n_prev_motions - 3:args.n_prev_motions + 3]
+ # head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
+ # head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]
+
+ # version 2: constrain only the predicted current motions (x_{0} ~ x_{2})
+ head_pose_trans = torch.cat([head_pose_gt[:, args.n_prev_motions - 3:args.n_prev_motions],
+ head_pose_pred[:, args.n_prev_motions:args.n_prev_motions + 3]], dim=1)
+ head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
+ head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]
+
+ # will constrain x_{-2|0} ~ x_{1}
+ loss_head_trans_vel = criterion_func(head_vel_pred[:, 2:4], head_vel_pred[:, 1:3], reduction='none')
+ # will constrain x_{-3|0} ~ x_{2}
+ loss_head_trans_accel = criterion_func(head_accel_pred[:, 1:], head_accel_pred[:, :-1],
+ reduction='none')
+ else:
+ raise ValueError(f'Unknown diffusion target: {args.target}')
+
+ if end_idx is None:
+ mask = torch.ones((target.shape[0], args.n_motions), dtype=torch.bool, device=target.device)
+ else:
+ mask = torch.arange(args.n_motions, device=target.device).expand(target.shape[0], -1) < end_idx.unsqueeze(1)
+
+ if args.target == 'sample' and not is_starting_sample:
+ if args.no_constrain_prev:
+ # Warning: this option will be deprecated in the future
+ mask = torch.cat([torch.zeros_like(mask[:, :args.n_prev_motions]), mask], dim=1)
+ else:
+ mask = torch.cat([torch.ones_like(mask[:, :args.n_prev_motions]), mask], dim=1)
+
+ loss_noise = loss_noise[mask].mean()
+ if loss_vert is not None:
+ loss_vert = loss_vert[mask].mean()
+ if loss_vel is not None:
+ loss_vel = loss_vel[mask[:, 1:]]
+ loss_vel = loss_vel.mean() if torch.numel(loss_vel) > 0 else None
+ if loss_smooth is not None:
+ loss_smooth = loss_smooth[mask[:, 2:]]
+ loss_smooth = loss_smooth.mean() if torch.numel(loss_smooth) > 0 else None
+ if loss_head_angle is not None:
+ loss_head_angle = loss_head_angle[mask].mean()
+ if loss_head_vel is not None:
+ loss_head_vel = loss_head_vel[mask[:, 1:]]
+ loss_head_vel = loss_head_vel.mean() if torch.numel(loss_head_vel) > 0 else None
+ if loss_head_smooth is not None:
+ loss_head_smooth = loss_head_smooth[mask[:, 2:]]
+ loss_head_smooth = loss_head_smooth.mean() if torch.numel(loss_head_smooth) > 0 else None
+ if loss_head_trans_vel is not None:
+ vel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 2]
+ accel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 3]
+ loss_head_trans_vel = loss_head_trans_vel[vel_mask].mean()
+ loss_head_trans_accel = loss_head_trans_accel[accel_mask].mean()
+ loss_head_trans = loss_head_trans_vel + loss_head_trans_accel
+
+ return loss_noise, loss_vert, loss_vel, loss_smooth, loss_head_angle, loss_head_vel, loss_head_smooth, \
+ loss_head_trans
+
+
+def _truncate_audio(audio, end_idx, pad_mode='zero'):
+ batch_size = audio.shape[0]
+ audio_trunc = audio.clone()
+ if pad_mode == 'replicate':
+ for i in range(batch_size):
+ audio_trunc[i, end_idx[i]:] = audio_trunc[i, end_idx[i] - 1]
+ elif pad_mode == 'zero':
+ for i in range(batch_size):
+ audio_trunc[i, end_idx[i]:] = 0
+ else:
+ raise ValueError(f'Unknown pad mode {pad_mode}!')
+
+ return audio_trunc
+
+
+def _truncate_coef_dict(coef_dict, end_idx, pad_mode='zero'):
+ batch_size = coef_dict['exp'].shape[0]
+ coef_dict_trunc = {k: v.clone() for k, v in coef_dict.items()}
+ if pad_mode == 'replicate':
+ for i in range(batch_size):
+ for k in coef_dict_trunc:
+ coef_dict_trunc[k][i, end_idx[i]:] = coef_dict_trunc[k][i, end_idx[i] - 1]
+ elif pad_mode == 'zero':
+ for i in range(batch_size):
+ for k in coef_dict:
+ coef_dict_trunc[k][i, end_idx[i]:] = 0
+ else:
+ raise ValueError(f'Unknown pad mode: {pad_mode}!')
+
+ return coef_dict_trunc
+
+
+def truncate_coef_dict_and_audio(audio, coef_dict, n_motions, audio_unit=640, pad_mode='zero'):
+ batch_size = audio.shape[0]
+ end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
+ audio_end_idx = (end_idx * audio_unit).long()
+ # mask = torch.arange(n_motions, device=audio.device).expand(batch_size, -1) < end_idx.unsqueeze(1)
+
+ # truncate audio
+ audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)
+
+ # truncate coef dict
+ coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
+
+ return audio_trunc, coef_dict_trunc, end_idx
+
+
+def truncate_motion_coef_and_audio(audio, motion_coef, n_motions, audio_unit=640, pad_mode='zero'):
+ batch_size = audio.shape[0]
+ end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
+ audio_end_idx = (end_idx * audio_unit).long()
+ # mask = torch.arange(n_motions, device=audio.device).expand(batch_size, -1) < end_idx.unsqueeze(1)
+
+ # truncate audio
+ audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)
+
+ # prepare coef dict and stats
+ coef_dict = {'exp': motion_coef[..., :50], 'pose_any': motion_coef[..., 50:]}
+
+ # truncate coef dict
+ coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
+ motion_coef_trunc = torch.cat([coef_dict_trunc['exp'], coef_dict_trunc['pose_any']], dim=-1)
+
+ return audio_trunc, motion_coef_trunc, end_idx
+
+
+def nt_xent_loss(feature_a, feature_b, temperature):
+ """
+ Normalized temperature-scaled cross entropy loss.
+
+ (Adapted from https://github.com/sthalles/SimCLR/blob/master/simclr.py)
+
+ Args:
+ feature_a (torch.Tensor): shape (batch_size, feature_dim)
+ feature_b (torch.Tensor): shape (batch_size, feature_dim)
+ temperature (float): temperature scaling factor
+
+ Returns:
+ torch.Tensor: scalar
+ """
+ batch_size = feature_a.shape[0]
+ device = feature_a.device
+
+ features = torch.cat([feature_a, feature_b], dim=0)
+
+ labels = torch.cat([torch.arange(batch_size), torch.arange(batch_size)], dim=0)
+ labels = (labels.unsqueeze(0) == labels.unsqueeze(1))
+ labels = labels.to(device)
+
+ features = F.normalize(features, dim=1)
+ similarity_matrix = torch.matmul(features, features.T)
+
+ # discard the main diagonal from both: labels and similarities matrix
+ mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
+ labels = labels[~mask].view(labels.shape[0], -1)
+ similarity_matrix = similarity_matrix[~mask].view(labels.shape[0], -1)
+
+ # select the positives and negatives
+ positives = similarity_matrix[labels].view(labels.shape[0], -1)
+ negatives = similarity_matrix[~labels].view(labels.shape[0], -1)
+
+ logits = torch.cat([positives, negatives], dim=1)
+ logits = logits / temperature
+ labels = torch.zeros(labels.shape[0], dtype=torch.long).to(device)
+
+ loss = F.cross_entropy(logits, labels)
+ return loss
diff --git a/diffposetalk/utils/media.py b/diffposetalk/utils/media.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e181671832137ae4d14d994d7c16763e25d1b21
--- /dev/null
+++ b/diffposetalk/utils/media.py
@@ -0,0 +1,35 @@
+import shlex
+import subprocess
+from pathlib import Path
+
+
+def combine_video_and_audio(video_file, audio_file, output, quality=17, copy_audio=True):
+ audio_codec = '-c:a copy' if copy_audio else ''
+ cmd = f'ffmpeg -i {video_file} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
+ f'{audio_codec} -fflags +shortest -y -hide_banner -loglevel error {output}'
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
+
+
+def combine_frames_and_audio(frame_files, audio_file, fps, output, quality=17):
+ cmd = f'ffmpeg -framerate {fps} -i {frame_files} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
+ f'-c:a copy -fflags +shortest -y -hide_banner -loglevel error {output}'
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
+
+
+def convert_video(video_file, output, quality=17):
+ cmd = f'ffmpeg -i {video_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
+ f'-fflags +shortest -y -hide_banner -loglevel error {output}'
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
+
+
+def reencode_audio(audio_file, output):
+ cmd = f'ffmpeg -i {audio_file} -y -hide_banner -loglevel error {output}'
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
+
+
+def extract_frames(filename, output_dir, quality=1):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ cmd = f'ffmpeg -i {filename} -qmin 1 -qscale:v {quality} -y -start_number 0 -hide_banner -loglevel error ' \
+ f'{output_dir / "%06d.jpg"}'
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
diff --git a/diffposetalk/utils/renderer.py b/diffposetalk/utils/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2f202d94aa282ff6e0fd4c6d40e2e1e6a487f49
--- /dev/null
+++ b/diffposetalk/utils/renderer.py
@@ -0,0 +1,147 @@
+import os
+import tempfile
+
+import cv2
+import kiui.mesh
+import numpy as np
+
+# os.environ['PYOPENGL_PLATFORM'] = 'osmesa' # osmesa or egl
+os.environ['PYOPENGL_PLATFORM'] = 'egl'
+import pyrender
+import trimesh
+# from psbody.mesh import Mesh
+
+
+class MeshRenderer:
+ def __init__(self, size, fov=16 / 180 * np.pi, camera_pose=None, light_pose=None, black_bg=False):
+ # Camera
+ self.frustum = {'near': 0.01, 'far': 3.0}
+ self.camera = pyrender.PerspectiveCamera(yfov=fov, znear=self.frustum['near'],
+ zfar=self.frustum['far'], aspectRatio=1.0)
+
+ # Material
+ self.primitive_material = pyrender.material.MetallicRoughnessMaterial(
+ alphaMode='BLEND',
+ baseColorFactor=[0.3, 0.3, 0.3, 1.0],
+ metallicFactor=0.8,
+ roughnessFactor=0.8
+ )
+
+ # Lighting
+ light_color = np.array([1., 1., 1.])
+ self.light = pyrender.DirectionalLight(color=light_color, intensity=2)
+ self.light_angle = np.pi / 6.0
+
+ # Scene
+ self.scene = None
+ self._init_scene(black_bg)
+
+ # add camera and lighting
+ self._init_camera(camera_pose)
+ self._init_lighting(light_pose)
+
+ # Renderer
+ self.renderer = pyrender.OffscreenRenderer(*size, point_size=1.0)
+
+ def _init_scene(self, black_bg=False):
+ if black_bg:
+ self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0])
+ else:
+ self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
+
+ def _init_camera(self, camera_pose=None):
+ if camera_pose is None:
+ camera_pose = np.eye(4)
+ camera_pose[:3, 3] = np.array([0, 0, 1])
+ self.camera_pose = camera_pose.copy()
+ self.camera_node = self.scene.add(self.camera, pose=camera_pose)
+
+ def _init_lighting(self, light_pose=None):
+ if light_pose is None:
+ light_pose = np.eye(4)
+ light_pose[:3, 3] = np.array([0, 0, 1])
+ self.light_pose = light_pose.copy()
+
+ light_poses = self._get_light_poses(self.light_angle, light_pose)
+ self.light_nodes = [self.scene.add(self.light, pose=light_pose) for light_pose in light_poses]
+
+ def set_camera_pose(self, camera_pose):
+ self.camera_pose = camera_pose.copy()
+ self.scene.set_pose(self.camera_node, pose=camera_pose)
+
+ def set_lighting_pose(self, light_pose):
+ self.light_pose = light_pose.copy()
+
+ light_poses = self._get_light_poses(self.light_angle, light_pose)
+ for light_node, light_pose in zip(self.light_nodes, light_poses):
+ self.scene.set_pose(light_node, pose=light_pose)
+
+ def render_mesh(self, v, f, t_center, rot=np.zeros(3), tex_img=None, tex_uv=None,
+ camera_pose=None, light_pose=None):
+ # Prepare mesh
+ v[:] = cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center
+ if tex_img is not None:
+ tex = pyrender.Texture(source=tex_img, source_channels='RGB')
+ tex_material = pyrender.material.MetallicRoughnessMaterial(baseColorTexture=tex)
+ from kiui.mesh import Mesh
+ import torch
+ mesh = Mesh(
+ v=torch.from_numpy(v),
+ f=torch.from_numpy(f),
+ vt=tex_uv['vt'],
+ ft=tex_uv['ft']
+ )
+ with tempfile.NamedTemporaryFile(suffix='.obj') as f:
+ mesh.write_obj(f.name)
+ tri_mesh = trimesh.load(f.name, process=False)
+ return tri_mesh
+ # tri_mesh = self._pyrender_mesh_workaround(mesh)
+ render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=tex_material)
+ else:
+ tri_mesh = trimesh.Trimesh(vertices=v, faces=f)
+ render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=self.primitive_material, smooth=True)
+ mesh_node = self.scene.add(render_mesh, pose=np.eye(4))
+
+ # Change camera and lighting pose if necessary
+ if camera_pose is not None:
+ self.set_camera_pose(camera_pose)
+ if light_pose is not None:
+ self.set_lighting_pose(light_pose)
+
+ # Render
+ flags = pyrender.RenderFlags.SKIP_CULL_FACES
+ color, depth = self.renderer.render(self.scene, flags=flags)
+
+ # Remove mesh
+ self.scene.remove_node(mesh_node)
+
+ return color, depth
+
+ @staticmethod
+ def _get_light_poses(light_angle, light_pose):
+ light_poses = []
+ init_pos = light_pose[:3, 3].copy()
+
+ light_poses.append(light_pose.copy())
+
+ light_pose[:3, 3] = cv2.Rodrigues(np.array([light_angle, 0, 0]))[0].dot(init_pos)
+ light_poses.append(light_pose.copy())
+
+ light_pose[:3, 3] = cv2.Rodrigues(np.array([-light_angle, 0, 0]))[0].dot(init_pos)
+ light_poses.append(light_pose.copy())
+
+ light_pose[:3, 3] = cv2.Rodrigues(np.array([0, -light_angle, 0]))[0].dot(init_pos)
+ light_poses.append(light_pose.copy())
+
+ light_pose[:3, 3] = cv2.Rodrigues(np.array([0, light_angle, 0]))[0].dot(init_pos)
+ light_poses.append(light_pose.copy())
+
+ return light_poses
+
+ @staticmethod
+ def _pyrender_mesh_workaround(mesh):
+ # Workaround as pyrender requires number of vertices and uv coordinates to be the same
+ with tempfile.NamedTemporaryFile(suffix='.obj') as f:
+ mesh.write_obj(f.name)
+ tri_mesh = trimesh.load(f.name, process=False)
+ return tri_mesh
diff --git a/diffposetalk/utils/rotation_conversions.py b/diffposetalk/utils/rotation_conversions.py
new file mode 100644
index 0000000000000000000000000000000000000000..41846705df7ec744308f02916d8e24ab1e8213cd
--- /dev/null
+++ b/diffposetalk/utils/rotation_conversions.py
@@ -0,0 +1,569 @@
+# This code is based on https://github.com/Mathux/ACTOR.git
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+# Check PYTORCH3D_LICENCE before use
+
+import functools
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+
+"""
+The transformation matrices returned from the functions in this file assume
+the points on which the transformation will be applied are column vectors.
+i.e. the R matrix is structured as
+
+ R = [
+ [Rxx, Rxy, Rxz],
+ [Ryx, Ryy, Ryz],
+ [Rzx, Rzy, Rzz],
+ ] # (3, 3)
+
+This matrix can be applied to column vectors by post multiplication
+by the points e.g.
+
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
+ transformed_points = R * points
+
+To apply the same matrix to points which are row vectors, the R matrix
+can be transposed and pre multiplied by the points:
+
+e.g.
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
+ transformed_points = points * R.transpose(1, 0)
+"""
+
+
+def quaternion_to_matrix(quaternions):
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def _copysign(a, b):
+ """
+ Return a tensor where each element has the absolute value taken from the,
+ corresponding element of a, with sign taken from the corresponding
+ element of b. This is like the standard copysign floating-point operation,
+ but is not careful about negative 0 and NaN.
+
+ Args:
+ a: source tensor.
+ b: tensor whose signs will be used, of the same shape as a.
+
+ Returns:
+ Tensor of the same shape as a with the signs of b.
+ """
+ signs_differ = (a < 0) != (b < 0)
+ return torch.where(signs_differ, -a, a)
+
+
+def _sqrt_positive_part(x):
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+
+def matrix_to_quaternion(matrix):
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
+ m00 = matrix[..., 0, 0]
+ m11 = matrix[..., 1, 1]
+ m22 = matrix[..., 2, 2]
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
+ return torch.stack((o0, o1, o2, o3), -1)
+
+
+def _axis_angle_rotation(axis: str, angle):
+ """
+ Return the rotation matrices for one of the rotations about an axis
+ of which Euler angles describe, for each value of the angle given.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z".
+ angle: any shape tensor of Euler angles in radians
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ if axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ if axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def euler_angles_to_matrix(euler_angles, convention: str):
+ """
+ Convert rotations given as Euler angles in radians to rotation matrices.
+
+ Args:
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
+ convention: Convention string of three uppercase letters from
+ {"X", "Y", and "Z"}.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+ raise ValueError("Invalid input euler angles.")
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
+ return functools.reduce(torch.matmul, matrices)
+
+
+def _angle_from_tan(
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
+):
+ """
+ Extract the first or third Euler angle from the two members of
+ the matrix which are positive constant times its sine and cosine.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
+ convention.
+ data: Rotation matrices as tensor of shape (..., 3, 3).
+ horizontal: Whether we are looking for the angle for the third axis,
+ which means the relevant entries are in the same row of the
+ rotation matrix. If not, they are in the same column.
+ tait_bryan: Whether the first and third axes in the convention differ.
+
+ Returns:
+ Euler Angles in radians for each matrix in dataset as a tensor
+ of shape (...).
+ """
+
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+ if horizontal:
+ i2, i1 = i1, i2
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
+ if horizontal == even:
+ return torch.atan2(data[..., i1], data[..., i2])
+ if tait_bryan:
+ return torch.atan2(-data[..., i2], data[..., i1])
+ return torch.atan2(data[..., i2], -data[..., i1])
+
+
+def _index_from_letter(letter: str):
+ if letter == "X":
+ return 0
+ if letter == "Y":
+ return 1
+ if letter == "Z":
+ return 2
+
+
+def matrix_to_euler_angles(matrix, convention: str):
+ """
+ Convert rotations given as rotation matrices to Euler angles in radians.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+ convention: Convention string of three uppercase letters.
+
+ Returns:
+ Euler angles in radians as tensor of shape (..., 3).
+ """
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
+ i0 = _index_from_letter(convention[0])
+ i2 = _index_from_letter(convention[2])
+ tait_bryan = i0 != i2
+ if tait_bryan:
+ central_angle = torch.asin(
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+ )
+ else:
+ central_angle = torch.acos(matrix[..., i0, i0])
+
+ o = (
+ _angle_from_tan(
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
+ ),
+ central_angle,
+ _angle_from_tan(
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
+ ),
+ )
+ return torch.stack(o, -1)
+
+
+def random_quaternions(
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
+):
+ """
+ Generate random quaternions representing rotations,
+ i.e. versors with nonnegative real part.
+
+ Args:
+ n: Number of quaternions in a batch to return.
+ dtype: Type to return.
+ device: Desired device of returned tensor. Default:
+ uses the current device for the default tensor type.
+ requires_grad: Whether the resulting tensor should have the gradient
+ flag set.
+
+ Returns:
+ Quaternions as tensor of shape (N, 4).
+ """
+ o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
+ s = (o * o).sum(1)
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
+ return o
+
+
+def random_rotations(
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
+):
+ """
+ Generate random rotations as 3x3 rotation matrices.
+
+ Args:
+ n: Number of rotation matrices in a batch to return.
+ dtype: Type to return.
+ device: Device of returned tensor. Default: if None,
+ uses the current device for the default tensor type.
+ requires_grad: Whether the resulting tensor should have the gradient
+ flag set.
+
+ Returns:
+ Rotation matrices as tensor of shape (n, 3, 3).
+ """
+ quaternions = random_quaternions(
+ n, dtype=dtype, device=device, requires_grad=requires_grad
+ )
+ return quaternion_to_matrix(quaternions)
+
+
+def random_rotation(
+ dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
+):
+ """
+ Generate a single random 3x3 rotation matrix.
+
+ Args:
+ dtype: Type to return
+ device: Device of returned tensor. Default: if None,
+ uses the current device for the default tensor type
+ requires_grad: Whether the resulting tensor should have the gradient
+ flag set
+
+ Returns:
+ Rotation matrix as tensor of shape (3, 3).
+ """
+ return random_rotations(1, dtype, device, requires_grad)[0]
+
+
+def standardize_quaternion(quaternions):
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+
+def quaternion_raw_multiply(a, b):
+ """
+ Multiply two quaternions.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ a: Quaternions as tensor of shape (..., 4), real part first.
+ b: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of a and b, a tensor of quaternions shape (..., 4).
+ """
+ aw, ax, ay, az = torch.unbind(a, -1)
+ bw, bx, by, bz = torch.unbind(b, -1)
+ ow = aw * bw - ax * bx - ay * by - az * bz
+ ox = aw * bx + ax * bw + ay * bz - az * by
+ oy = aw * by - ax * bz + ay * bw + az * bx
+ oz = aw * bz + ax * by - ay * bx + az * bw
+ return torch.stack((ow, ox, oy, oz), -1)
+
+
+def quaternion_multiply(a, b):
+ """
+ Multiply two quaternions representing rotations, returning the quaternion
+ representing their composition, i.e. the versor with nonnegative real part.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ a: Quaternions as tensor of shape (..., 4), real part first.
+ b: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of a and b, a tensor of quaternions of shape (..., 4).
+ """
+ ab = quaternion_raw_multiply(a, b)
+ return standardize_quaternion(ab)
+
+
+def quaternion_invert(quaternion):
+ """
+ Given a quaternion representing rotation, get the quaternion representing
+ its inverse.
+
+ Args:
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
+ first, which must be versors (unit quaternions).
+
+ Returns:
+ The inverse, a tensor of quaternions of shape (..., 4).
+ """
+
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
+
+
+def quaternion_apply(quaternion, point):
+ """
+ Apply the rotation given by a quaternion to a 3D point.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
+ point: Tensor of 3D points of shape (..., 3).
+
+ Returns:
+ Tensor of rotated points of shape (..., 3).
+ """
+ if point.size(-1) != 3:
+ raise ValueError(f"Points are not in 3D, f{point.shape}.")
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
+ point_as_quaternion = torch.cat((real_parts, point), -1)
+ out = quaternion_raw_multiply(
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
+ quaternion_invert(quaternion),
+ )
+ return out[..., 1:]
+
+
+def axis_angle_to_matrix(axis_angle):
+ """
+ Convert rotations given as axis/angle to rotation matrices.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def matrix_to_axis_angle(matrix):
+ """
+ Convert rotations given as rotation matrices to axis/angle.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ """
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
+
+
+def axis_angle_to_quaternion(axis_angle):
+ """
+ Convert rotations given as axis/angle to quaternions.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = 0.5 * angles
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+ )
+ return quaternions
+
+
+def quaternion_to_axis_angle(quaternions):
+ """
+ Convert rotations given as quaternions to axis/angle.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotations given as a vector in axis angle form, as a tensor
+ of shape (..., 3), where the magnitude is the angle
+ turned anticlockwise in radians around the vector's
+ direction.
+ """
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
+ half_angles = torch.atan2(norms, quaternions[..., :1])
+ angles = 2 * half_angles
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ return quaternions[..., 1:] / sin_half_angles_over_angles
+
+
+def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalisation per Section B of [1].
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+
+def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+ by dropping the last row. Note that 6D representation is not unique.
+ Args:
+ matrix: batch of rotation matrices of size (*, 3, 3)
+
+ Returns:
+ 6D rotation representation, of size (*, 6)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
+
+
+def axis_angle_to_rotation_6d(axis_angle):
+ """
+ Convert rotations given as axis/angle to 6D rotation representation by Zhou
+ et al. [1].
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ 6D rotation representation, of size (*, 6)
+ """
+ return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle))
diff --git a/diffposetalk/wav2vec2.py b/diffposetalk/wav2vec2.py
new file mode 100644
index 0000000000000000000000000000000000000000..499140bbe90d147d07ba180b261ec8ea6f752df2
--- /dev/null
+++ b/diffposetalk/wav2vec2.py
@@ -0,0 +1,119 @@
+from packaging import version
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import transformers
+from transformers import Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput
+
+_CONFIG_FOR_DOC = 'Wav2Vec2Config'
+
+
+# the implementation of Wav2Vec2Model is borrowed from
+# https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
+# initialize our encoder with the pre-trained wav2vec 2.0 weights.
+def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
+ attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
+ all_num_mask = max(min_masks, all_num_mask)
+ mask_idcs = []
+ padding_mask = attention_mask.ne(1) if attention_mask is not None else None
+ for i in range(bsz):
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = all_sz
+ num_mask = all_num_mask
+
+ lengths = np.full(num_mask, mask_length)
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+ min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ if len(mask_idc) > min_len:
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ mask[i, mask_idc] = True
+ return mask
+
+
+# linear interpolation layer
+def linear_interpolation(features, input_fps, output_fps, output_len=None):
+ # features: (N, C, L)
+ seq_len = features.shape[2] / float(input_fps)
+ if output_len is None:
+ output_len = int(seq_len * output_fps)
+ output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
+ return output_features
+
+
+class Wav2Vec2Model(Wav2Vec2Model):
+ def __init__(self, config):
+ super().__init__(config)
+ self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
+
+ def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
+ output_hidden_states=None, return_dict=None, frame_num=None):
+ self.config.output_attentions = True
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.feature_extractor(input_values) # (N, C, L)
+ # Resample the audio feature @ 50 fps to `output_fps`.
+ if frame_num is not None:
+ hidden_states_len = round(frame_num * 50 / output_fps)
+ hidden_states = hidden_states[:, :, :hidden_states_len]
+ hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
+ hidden_states = hidden_states.transpose(1, 2) # (N, L, C)
+
+ if attention_mask is not None:
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+ attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
+ device=hidden_states.device)
+ attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+
+ if self.is_old_version:
+ hidden_states = self.feature_projection(hidden_states)
+ else:
+ hidden_states = self.feature_projection(hidden_states)[0]
+
+ if self.config.apply_spec_augment and self.training:
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+ if self.config.mask_time_prob > 0:
+ mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
+ self.config.mask_time_length, attention_mask=attention_mask,
+ min_masks=2, )
+ hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
+ if self.config.mask_feature_prob > 0:
+ mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
+ self.config.mask_feature_length, )
+ mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
+ hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
+ encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
+ return_dict=return_dict, )
+ hidden_states = encoder_outputs[0]
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions, )
diff --git a/eval/arc_score.py b/eval/arc_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..98890088735ebf59bf08fd15d8c4e6a39a64e0cb
--- /dev/null
+++ b/eval/arc_score.py
@@ -0,0 +1,253 @@
+import os
+import torch
+from insightface.app import FaceAnalysis
+from insightface.utils import face_align
+from PIL import Image
+from torchvision import models, transforms
+from curricularface import get_model
+import cv2
+import numpy as np
+import numpy
+
+
+def matrix_sqrt(matrix):
+ eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
+ sqrt_eigenvalues = torch.sqrt(torch.clamp(eigenvalues, min=0))
+ sqrt_matrix = (eigenvectors * sqrt_eigenvalues).mm(eigenvectors.T)
+ return sqrt_matrix
+
+def sample_video_frames(video_path, num_frames=16):
+ cap = cv2.VideoCapture(video_path)
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+
+ frames = []
+ for idx in frame_indices:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+ ret, frame = cap.read()
+ if ret:
+ # print(frame.shape)
+ #if frame.shape[1] > 1024:
+ # frame = frame[:, 1440:, :]
+ # print(frame.shape)
+ frames.append(frame)
+ cap.release()
+ return frames
+
+
+def get_face_keypoints(face_model, image_bgr):
+ face_info = face_model.get(image_bgr)
+ if len(face_info) > 0:
+ return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
+ return None
+
+def load_image(image):
+ img = image.convert('RGB')
+ img = transforms.Resize((299, 299))(img) # Resize to Inception input size
+ img = transforms.ToTensor()(img)
+ return img.unsqueeze(0) # Add batch dimension
+
+def calculate_fid(real_activations, fake_activations, device="cuda"):
+ real_activations_tensor = torch.tensor(real_activations).to(device)
+ fake_activations_tensor = torch.tensor(fake_activations).to(device)
+
+ mu1 = real_activations_tensor.mean(dim=0)
+ sigma1 = torch.cov(real_activations_tensor.T)
+ mu2 = fake_activations_tensor.mean(dim=0)
+ sigma2 = torch.cov(fake_activations_tensor.T)
+
+ ssdiff = torch.sum((mu1 - mu2) ** 2)
+ covmean = matrix_sqrt(sigma1.mm(sigma2))
+ if torch.is_complex(covmean):
+ covmean = covmean.real
+ fid = ssdiff + torch.trace(sigma1 + sigma2 - 2 * covmean)
+ return fid.item()
+
+def batch_cosine_similarity(embedding_image, embedding_frames, device="cuda"):
+ embedding_image = torch.tensor(embedding_image).to(device)
+ embedding_frames = torch.tensor(embedding_frames).to(device)
+ return torch.nn.functional.cosine_similarity(embedding_image, embedding_frames, dim=-1).cpu().numpy()
+
+
+def get_activations(images, model, batch_size=16):
+ model.eval()
+ activations = []
+ with torch.no_grad():
+ for i in range(0, len(images), batch_size):
+ batch = images[i:i + batch_size]
+ pred = model(batch)
+ activations.append(pred)
+ activations = torch.cat(activations, dim=0).cpu().numpy()
+ if activations.shape[0] == 1:
+ activations = np.repeat(activations, 2, axis=0)
+ return activations
+
+def pad_np_bgr_image(np_image, scale=1.25):
+ assert scale >= 1.0, "scale should be >= 1.0"
+ pad_scale = scale - 1.0
+ h, w = np_image.shape[:2]
+ top = bottom = int(h * pad_scale)
+ left = right = int(w * pad_scale)
+ return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top)
+
+
+def process_image(face_model, image_path):
+ if isinstance(image_path, str):
+ np_faceid_image = np.array(Image.open(image_path).convert("RGB"))
+ elif isinstance(image_path, numpy.ndarray):
+ np_faceid_image = image_path
+ else:
+ raise TypeError("image_path should be a string or PIL.Image.Image object")
+
+ image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
+
+ face_info = get_face_keypoints(face_model, image_bgr)
+ if face_info is None:
+ padded_image, sub_coord = pad_np_bgr_image(image_bgr)
+ face_info = get_face_keypoints(face_model, padded_image)
+ if face_info is None:
+ print("Warning: No face detected in the image. Continuing processing...")
+ return None, None
+ face_kps = face_info['kps']
+ face_kps -= np.array(sub_coord)
+ else:
+ face_kps = face_info['kps']
+ arcface_embedding = face_info['embedding']
+ # print(face_kps)
+ norm_face = face_align.norm_crop(image_bgr, landmark=face_kps, image_size=224)
+ align_face = cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB)
+
+ return align_face, arcface_embedding
+
+@torch.no_grad()
+def inference(face_model, img, device):
+ img = cv2.resize(img, (112, 112))
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float().to(device)
+ img.div_(255).sub_(0.5).div_(0.5)
+ embedding = face_model(img).detach().cpu().numpy()[0]
+ return embedding / np.linalg.norm(embedding)
+
+
+def process_video(video_path, face_arc_model, face_cur_model, fid_model, arcface_image_embedding, cur_image_embedding, real_activations, device):
+ video_frames = sample_video_frames(video_path, num_frames=16)
+ #print(video_frames)
+ # Initialize lists to store the scores
+ cur_scores = []
+ arc_scores = []
+ fid_face = []
+
+ for frame in video_frames:
+ # Convert to RGB once at the beginning
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+ # Process the frame for ArcFace embeddings
+ align_face_frame, arcface_frame_embedding = process_image(face_arc_model, frame_rgb)
+
+ # Skip if alignment fails
+ if align_face_frame is None:
+ continue
+
+ # Perform inference for current face model
+ cur_embedding_frame = inference(face_cur_model, align_face_frame, device)
+
+ # Compute cosine similarity for cur_score and arc_score in a compact manner
+ cur_score = max(0.0, batch_cosine_similarity(cur_image_embedding, cur_embedding_frame, device=device).item())
+ arc_score = max(0.0, batch_cosine_similarity(arcface_image_embedding, arcface_frame_embedding, device=device).item())
+
+ # Process FID score
+ align_face_frame_pil = Image.fromarray(align_face_frame)
+ fake_image = load_image(align_face_frame_pil).to(device)
+ fake_activations = get_activations(fake_image, fid_model)
+ fid_score = calculate_fid(real_activations, fake_activations, device)
+
+ # Collect scores
+ fid_face.append(fid_score)
+ cur_scores.append(cur_score)
+ arc_scores.append(arc_score)
+
+ # Aggregate results with default values for empty lists
+ avg_cur_score = np.mean(cur_scores) if cur_scores else 0.0
+ avg_arc_score = np.mean(arc_scores) if arc_scores else 0.0
+ avg_fid_score = np.mean(fid_face) if fid_face else 0.0
+
+ return avg_cur_score, avg_arc_score, avg_fid_score
+
+
+
+def main():
+ device = "cuda"
+ # data_path = "data/SkyActor"
+ # data_path = "data/LivePotraits"
+ # data_path = "data/Actor-One"
+ data_path = "data/FollowYourEmoji"
+ img_path = "/maindata/data/shared/public/rui.wang/act_review/ref_images"
+ pre_tag = False
+ mp4_list = os.listdir(data_path)
+ print(mp4_list)
+
+ img_list = []
+ video_list = []
+ for mp4 in mp4_list:
+ if "mp4" not in mp4:
+ continue
+ if pre_tag:
+ png_path = mp4.split('.')[0].split('-')[0] + ".png"
+ else:
+ if "-" in mp4:
+ png_path = mp4.split('.')[0].split('-')[1] + ".png"
+ else:
+ png_path = mp4.split('.')[0].split('_')[1] + ".png"
+ img_list.append(os.path.join(img_path, png_path))
+ video_list.append(os.path.join(data_path, mp4))
+ print(img_list)
+ print(video_list[0])
+
+ model_path = "eval"
+ face_arc_path = os.path.join(model_path, "face_encoder")
+ face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin")
+
+ # Initialize FaceEncoder model for face detection and embedding extraction
+ face_arc_model = FaceAnalysis(root=face_arc_path, providers=['CUDAExecutionProvider'])
+ face_arc_model.prepare(ctx_id=0, det_size=(320, 320))
+
+ # Load face recognition model
+ face_cur_model = get_model('IR_101')([112, 112])
+ face_cur_model.load_state_dict(torch.load(face_cur_path, map_location="cpu"))
+ face_cur_model = face_cur_model.to(device)
+ face_cur_model.eval()
+
+ # Load InceptionV3 model for FID calculation
+ fid_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
+ fid_model.fc = torch.nn.Identity() # Remove final classification layer
+ fid_model.eval()
+ fid_model = fid_model.to(device)
+
+ # Process the single video and image pair
+ # Extract embeddings and features from the image
+ cur_list, arc_list, fid_list = [], [], []
+ for i in range(len(img_list)):
+ align_face_image, arcface_image_embedding = process_image(face_arc_model, img_list[i])
+
+ cur_image_embedding = inference(face_cur_model, align_face_image, device)
+ align_face_image_pil = Image.fromarray(align_face_image)
+ real_image = load_image(align_face_image_pil).to(device)
+ real_activations = get_activations(real_image, fid_model)
+
+ # Process the video and calculate scores
+ cur_score, arc_score, fid_score = process_video(
+ video_list[i], face_arc_model, face_cur_model, fid_model,
+ arcface_image_embedding, cur_image_embedding, real_activations, device
+ )
+ print(cur_score, arc_score, fid_score)
+ cur_list.append(cur_score)
+ arc_list.append(arc_score)
+ fid_list.append(fid_score)
+ # break
+ print("cur", sum(cur_list)/ len(cur_list))
+ print("arc", sum(arc_list)/ len(arc_list))
+ print("fid", sum(fid_list)/ len(fid_list))
+
+
+
+main()
diff --git a/eval/curricularface/__init__.py b/eval/curricularface/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f2f059c959e744d49bde7a3c6b15b72cd5e7ba9
--- /dev/null
+++ b/eval/curricularface/__init__.py
@@ -0,0 +1,33 @@
+# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at
+# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone
+from .model_irse import IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200
+from .model_resnet import ResNet_50, ResNet_101, ResNet_152
+
+
+_model_dict = {
+ 'ResNet_50': ResNet_50,
+ 'ResNet_101': ResNet_101,
+ 'ResNet_152': ResNet_152,
+ 'IR_18': IR_18,
+ 'IR_34': IR_34,
+ 'IR_50': IR_50,
+ 'IR_101': IR_101,
+ 'IR_152': IR_152,
+ 'IR_200': IR_200,
+ 'IR_SE_50': IR_SE_50,
+ 'IR_SE_101': IR_SE_101,
+ 'IR_SE_152': IR_SE_152,
+ 'IR_SE_200': IR_SE_200
+}
+
+
+def get_model(key):
+ """ Get different backbone network by key,
+ support ResNet50, ResNet_101, ResNet_152
+ IR_18, IR_34, IR_50, IR_101, IR_152, IR_200,
+ IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200.
+ """
+ if key in _model_dict.keys():
+ return _model_dict[key]
+ else:
+ raise KeyError('not support model {}'.format(key))
diff --git a/eval/curricularface/common.py b/eval/curricularface/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4a30335486e9bd502f31a0df17758a9e47d3350
--- /dev/null
+++ b/eval/curricularface/common.py
@@ -0,0 +1,68 @@
+# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at
+# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py
+import torch.nn as nn
+from torch.nn import Conv2d, Module, ReLU, Sigmoid
+
+
+def initialize_weights(modules):
+ """ Weight initilize, conv2d and linear is initialized with kaiming_normal
+ """
+ for m in modules:
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+
+class Flatten(Module):
+ """ Flat tensor
+ """
+
+ def forward(self, input):
+ return input.view(input.size(0), -1)
+
+
+class SEModule(Module):
+ """ SE block
+ """
+
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc1 = Conv2d(
+ channels,
+ channels // reduction,
+ kernel_size=1,
+ padding=0,
+ bias=False)
+
+ nn.init.xavier_uniform_(self.fc1.weight.data)
+
+ self.relu = ReLU(inplace=True)
+ self.fc2 = Conv2d(
+ channels // reduction,
+ channels,
+ kernel_size=1,
+ padding=0,
+ bias=False)
+
+ self.sigmoid = Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = self.avg_pool(x)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+
+ return module_input * x
diff --git a/eval/curricularface/model_irse.py b/eval/curricularface/model_irse.py
new file mode 100644
index 0000000000000000000000000000000000000000..843bb9da33eb7ab777ab7ca7746e19e4138ed33c
--- /dev/null
+++ b/eval/curricularface/model_irse.py
@@ -0,0 +1,299 @@
+# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at
+# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py
+from collections import namedtuple
+
+from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU, Sequential
+
+from .common import Flatten, SEModule, initialize_weights
+
+
+class BasicBlockIR(Module):
+ """ BasicBlock for IRNet
+ """
+
+ def __init__(self, in_channel, depth, stride):
+ super(BasicBlockIR, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth))
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ BatchNorm2d(depth), PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ BatchNorm2d(depth))
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+
+ return res + shortcut
+
+
+class BottleneckIR(Module):
+ """ BasicBlock with bottleneck for IRNet
+ """
+
+ def __init__(self, in_channel, depth, stride):
+ super(BottleneckIR, self).__init__()
+ reduction_channel = depth // 4
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth))
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(
+ in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False),
+ BatchNorm2d(reduction_channel), PReLU(reduction_channel),
+ Conv2d(
+ reduction_channel,
+ reduction_channel, (3, 3), (1, 1),
+ 1,
+ bias=False), BatchNorm2d(reduction_channel),
+ PReLU(reduction_channel),
+ Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),
+ BatchNorm2d(depth))
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+
+ return res + shortcut
+
+
+class BasicBlockIRSE(BasicBlockIR):
+
+ def __init__(self, in_channel, depth, stride):
+ super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
+ self.res_layer.add_module('se_block', SEModule(depth, 16))
+
+
+class BottleneckIRSE(BottleneckIR):
+
+ def __init__(self, in_channel, depth, stride):
+ super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
+ self.res_layer.add_module('se_block', SEModule(depth, 16))
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+ '''A named tuple describing a ResNet block.'''
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+ return [Bottleneck(in_channel, depth, stride)] + \
+ [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+ if num_layers == 18:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=2),
+ get_block(in_channel=64, depth=128, num_units=2),
+ get_block(in_channel=128, depth=256, num_units=2),
+ get_block(in_channel=256, depth=512, num_units=2)
+ ]
+ elif num_layers == 34:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=6),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 50:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=14),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 100:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=13),
+ get_block(in_channel=128, depth=256, num_units=30),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 152:
+ blocks = [
+ get_block(in_channel=64, depth=256, num_units=3),
+ get_block(in_channel=256, depth=512, num_units=8),
+ get_block(in_channel=512, depth=1024, num_units=36),
+ get_block(in_channel=1024, depth=2048, num_units=3)
+ ]
+ elif num_layers == 200:
+ blocks = [
+ get_block(in_channel=64, depth=256, num_units=3),
+ get_block(in_channel=256, depth=512, num_units=24),
+ get_block(in_channel=512, depth=1024, num_units=36),
+ get_block(in_channel=1024, depth=2048, num_units=3)
+ ]
+
+ return blocks
+
+
+class Backbone(Module):
+
+ def __init__(self, input_size, num_layers, mode='ir'):
+ """ Args:
+ input_size: input_size of backbone
+ num_layers: num_layers of backbone
+ mode: support ir or irse
+ """
+ super(Backbone, self).__init__()
+ assert input_size[0] in [112, 224], \
+ 'input_size should be [112, 112] or [224, 224]'
+ assert num_layers in [18, 34, 50, 100, 152, 200], \
+ 'num_layers should be 18, 34, 50, 100 or 152'
+ assert mode in ['ir', 'ir_se'], \
+ 'mode should be ir or ir_se'
+ self.input_layer = Sequential(
+ Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
+ PReLU(64))
+ blocks = get_blocks(num_layers)
+ if num_layers <= 100:
+ if mode == 'ir':
+ unit_module = BasicBlockIR
+ elif mode == 'ir_se':
+ unit_module = BasicBlockIRSE
+ output_channel = 512
+ else:
+ if mode == 'ir':
+ unit_module = BottleneckIR
+ elif mode == 'ir_se':
+ unit_module = BottleneckIRSE
+ output_channel = 2048
+
+ if input_size[0] == 112:
+ self.output_layer = Sequential(
+ BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
+ Linear(output_channel * 7 * 7, 512),
+ BatchNorm1d(512, affine=False))
+ else:
+ self.output_layer = Sequential(
+ BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
+ Linear(output_channel * 14 * 14, 512),
+ BatchNorm1d(512, affine=False))
+
+ modules = []
+ mid_layer_indices = [] # [2, 15, 45, 48], total 49 layers for IR101
+ for block in blocks:
+ if len(mid_layer_indices) == 0:
+ mid_layer_indices.append(len(block) - 1)
+ else:
+ mid_layer_indices.append(len(block) + mid_layer_indices[-1])
+ for bottleneck in block:
+ modules.append(
+ unit_module(bottleneck.in_channel, bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+ self.mid_layer_indices = mid_layer_indices[-4:]
+
+ # self.dtype = next(self.parameters()).dtype
+ initialize_weights(self.modules())
+
+ def device(self):
+ return next(self.parameters()).device
+
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def forward(self, x, return_mid_feats=False):
+ x = self.input_layer(x)
+ if not return_mid_feats:
+ x = self.body(x)
+ x = self.output_layer(x)
+ return x
+ else:
+ out_feats = []
+ for idx, module in enumerate(self.body):
+ x = module(x)
+ if idx in self.mid_layer_indices:
+ out_feats.append(x)
+ x = self.output_layer(x)
+ return x, out_feats
+
+
+def IR_18(input_size):
+ """ Constructs a ir-18 model.
+ """
+ model = Backbone(input_size, 18, 'ir')
+
+ return model
+
+
+def IR_34(input_size):
+ """ Constructs a ir-34 model.
+ """
+ model = Backbone(input_size, 34, 'ir')
+
+ return model
+
+
+def IR_50(input_size):
+ """ Constructs a ir-50 model.
+ """
+ model = Backbone(input_size, 50, 'ir')
+
+ return model
+
+
+def IR_101(input_size):
+ """ Constructs a ir-101 model.
+ """
+ model = Backbone(input_size, 100, 'ir')
+
+ return model
+
+
+def IR_152(input_size):
+ """ Constructs a ir-152 model.
+ """
+ model = Backbone(input_size, 152, 'ir')
+
+ return model
+
+
+def IR_200(input_size):
+ """ Constructs a ir-200 model.
+ """
+ model = Backbone(input_size, 200, 'ir')
+
+ return model
+
+
+def IR_SE_50(input_size):
+ """ Constructs a ir_se-50 model.
+ """
+ model = Backbone(input_size, 50, 'ir_se')
+
+ return model
+
+
+def IR_SE_101(input_size):
+ """ Constructs a ir_se-101 model.
+ """
+ model = Backbone(input_size, 100, 'ir_se')
+
+ return model
+
+
+def IR_SE_152(input_size):
+ """ Constructs a ir_se-152 model.
+ """
+ model = Backbone(input_size, 152, 'ir_se')
+
+ return model
+
+
+def IR_SE_200(input_size):
+ """ Constructs a ir_se-200 model.
+ """
+ model = Backbone(input_size, 200, 'ir_se')
+
+ return model
diff --git a/eval/curricularface/model_resnet.py b/eval/curricularface/model_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f41a2dffd0ab27c3b17dec81b5b753bb76701c
--- /dev/null
+++ b/eval/curricularface/model_resnet.py
@@ -0,0 +1,161 @@
+# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at
+# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py
+import torch.nn as nn
+from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, ReLU, Sequential
+
+from .common import initialize_weights
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """ 3x3 convolution with padding
+ """
+ return Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """ 1x1 convolution
+ """
+ return Conv2d(
+ in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class Bottleneck(Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = conv1x1(inplanes, planes)
+ self.bn1 = BatchNorm2d(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn2 = BatchNorm2d(planes)
+ self.conv3 = conv1x1(planes, planes * self.expansion)
+ self.bn3 = BatchNorm2d(planes * self.expansion)
+ self.relu = ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(Module):
+ """ ResNet backbone
+ """
+
+ def __init__(self, input_size, block, layers, zero_init_residual=True):
+ """ Args:
+ input_size: input_size of backbone
+ block: block function
+ layers: layers in each block
+ """
+ super(ResNet, self).__init__()
+ assert input_size[0] in [112, 224], \
+ 'input_size should be [112, 112] or [224, 224]'
+ self.inplanes = 64
+ self.conv1 = Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = BatchNorm2d(64)
+ self.relu = ReLU(inplace=True)
+ self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+ self.bn_o1 = BatchNorm2d(2048)
+ self.dropout = Dropout()
+ if input_size[0] == 112:
+ self.fc = Linear(2048 * 4 * 4, 512)
+ else:
+ self.fc = Linear(2048 * 7 * 7, 512)
+ self.bn_o2 = BatchNorm1d(512)
+
+ initialize_weights(self.modules)
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.bn_o1(x)
+ x = self.dropout(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+ x = self.bn_o2(x)
+
+ return x
+
+
+def ResNet_50(input_size, **kwargs):
+ """ Constructs a ResNet-50 model.
+ """
+ model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)
+
+ return model
+
+
+def ResNet_101(input_size, **kwargs):
+ """ Constructs a ResNet-101 model.
+ """
+ model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)
+
+ return model
+
+
+def ResNet_152(input_size, **kwargs):
+ """ Constructs a ResNet-152 model.
+ """
+ model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)
+
+ return model
diff --git a/eval/expression_score.py b/eval/expression_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f2ffe7f3121c7e2d54db47b572d21a90f10c120
--- /dev/null
+++ b/eval/expression_score.py
@@ -0,0 +1,167 @@
+import os
+import os
+import torch
+from insightface.app import FaceAnalysis
+from insightface.utils import face_align
+from PIL import Image
+from torchvision import models, transforms
+from curricularface import get_model
+import cv2
+import numpy as np
+import numpy
+
+def pad_np_bgr_image(np_image, scale=1.25):
+ assert scale >= 1.0, "scale should be >= 1.0"
+ pad_scale = scale - 1.0
+ h, w = np_image.shape[:2]
+ top = bottom = int(h * pad_scale)
+ left = right = int(w * pad_scale)
+ return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top)
+
+
+def sample_video_frames(video_path,):
+ cap = cv2.VideoCapture(video_path)
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ frame_indices = np.linspace(0, total_frames - 1, total_frames, dtype=int)
+
+ frames = []
+ for idx in frame_indices:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+ ret, frame = cap.read()
+ if ret:
+ # if frame.shape[1] > 1024:
+ # frame = frame[:, 1440:, :]
+ # print(frame.shape)
+ frame = cv2.resize(frame, (720, 480))
+ # print(frame.shape)
+ frames.append(frame)
+ cap.release()
+ return frames
+
+
+def get_face_keypoints(face_model, image_bgr):
+ face_info = face_model.get(image_bgr)
+ if len(face_info) > 0:
+ return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
+ return None
+
+def process_image(face_model, image_path):
+ if isinstance(image_path, str):
+ np_faceid_image = np.array(Image.open(image_path).convert("RGB"))
+ elif isinstance(image_path, numpy.ndarray):
+ np_faceid_image = image_path
+ else:
+ raise TypeError("image_path should be a string or PIL.Image.Image object")
+
+ image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
+
+ face_info = get_face_keypoints(face_model, image_bgr)
+ if face_info is None:
+ padded_image, sub_coord = pad_np_bgr_image(image_bgr)
+ face_info = get_face_keypoints(face_model, padded_image)
+ if face_info is None:
+ print("Warning: No face detected in the image. Continuing processing...")
+ return None
+ face_kps = face_info['kps']
+ face_kps -= np.array(sub_coord)
+ else:
+ face_kps = face_info['kps']
+ return face_kps
+
+def process_video(video_path, face_arc_model):
+ video_frames = sample_video_frames(video_path,)
+ print(len(video_frames))
+ kps_list = []
+ for frame in video_frames:
+ # Convert to RGB once at the beginning
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ kps = process_image(face_arc_model, frame_rgb)
+ if kps is None:
+ return None
+ # print(kps)
+ kps_list.append(kps)
+ return kps_list
+
+
+def calculate_l1_distance(list1, list2):
+ """
+ 计算两个列表的 L1 距离
+ :param list1: 第一个列表,形状为 (5, 2)
+ :param list2: 第二个列表,形状为 (5, 2)
+ :return: L1 距离
+ """
+ # 将列表转换为 NumPy 数组
+ list1 = np.array(list1)
+ list2 = np.array(list2)
+
+ # 计算每对点的 L1 距离
+ l1_distances = np.abs(list1 - list2).sum(axis=1)
+
+ # 返回所有点的 L1 距离之和
+ return l1_distances.sum()
+
+
+def calculate_kps(list1, list2):
+ distance_list = []
+ for kps1 in list1:
+ min_dis = (480 + 720) * 5 + 1
+ for kps2 in list2:
+ min_dis = min(min_dis, calculate_l1_distance(kps1, kps2))
+ distance_list.append(min_dis/(480+720)/10)
+ return sum(distance_list)/len(distance_list)
+
+
+def main():
+ device = "cuda"
+ # data_path = "data/SkyActor"
+ # data_path = "data/LivePotraits"
+ # data_path = "data/Actor-One"
+ data_path = "data/FollowYourEmoji"
+ img_path = "/maindata/data/shared/public/rui.wang/act_review/driving_video"
+ pre_tag = False
+ mp4_list = os.listdir(data_path)
+ print(mp4_list)
+
+ img_list = []
+ video_list = []
+ for mp4 in mp4_list:
+ if "mp4" not in mp4:
+ continue
+ if pre_tag:
+ png_path = mp4.split('.')[0].split('--')[1] + ".mp4"
+ else:
+ if "-" in mp4:
+ png_path = mp4.split('.')[0].split('-')[0] + ".mp4"
+ else:
+ png_path = mp4.split('.')[0].split('_')[0] + ".mp4"
+ img_list.append(os.path.join(img_path, png_path))
+ video_list.append(os.path.join(data_path, mp4))
+ print(img_list)
+ print(video_list[0])
+
+ model_path = "eval"
+ face_arc_path = os.path.join(model_path, "face_encoder")
+ face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin")
+
+ # Initialize FaceEncoder model for face detection and embedding extraction
+ face_arc_model = FaceAnalysis(root=face_arc_path, providers=['CUDAExecutionProvider'])
+ face_arc_model.prepare(ctx_id=0, det_size=(320, 320))
+
+ expression_list = []
+ for i in range(len(img_list)):
+ print("number: ", str(i), " total: ", len(img_list), data_path)
+ kps_1 = process_video(video_list[i], face_arc_model)
+ kps_2 = process_video(img_list[i], face_arc_model)
+ if kps_1 is None or kps_2 is None:
+ continue
+
+ dis = calculate_kps(kps_1, kps_2)
+ print(dis)
+ expression_list.append(dis)
+ # break
+
+ print("kps", sum(expression_list)/ len(expression_list))
+
+
+
+main()
diff --git a/eval/pose_score.py b/eval/pose_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..f437e2e61e8749f4e71ee5eb47d242274c2c4275
--- /dev/null
+++ b/eval/pose_score.py
@@ -0,0 +1,588 @@
+import torch
+from collections import OrderedDict
+import os
+import torch
+import torch.nn as nn
+import cv2
+import numpy
+import numpy as np
+import math
+import time
+from scipy.ndimage.filters import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from torchvision import transforms
+
+
+def transfer(model, model_weights):
+ transfered_model_weights = {}
+ for weights_name in model.state_dict().keys():
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+ return transfered_model_weights
+
+def padRightDownCorner(img, stride, padValue):
+ h = img.shape[0]
+ w = img.shape[1]
+
+ pad = 4 * [None]
+ pad[0] = 0 # up
+ pad[1] = 0 # left
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+ img_padded = img
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+ return img_padded, pad
+
+def make_layers(block, no_relu_layers):
+ layers = []
+ for layer_name, v in block.items():
+ if 'pool' in layer_name:
+ layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
+ padding=v[2])
+ layers.append((layer_name, layer))
+ else:
+ conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
+ kernel_size=v[2], stride=v[3],
+ padding=v[4])
+ layers.append((layer_name, conv2d))
+ if layer_name not in no_relu_layers:
+ layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
+
+ return nn.Sequential(OrderedDict(layers))
+
+class bodypose_model(nn.Module):
+ def __init__(self):
+ super(bodypose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+ 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
+ blocks = {}
+ block0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+ ('conv4_4_CPM', [256, 128, 3, 1, 1])
+ ])
+
+
+ # Stage 1
+ block1_1 = OrderedDict([
+ ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+ ])
+
+ block1_2 = OrderedDict([
+ ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+ ])
+ blocks['block1_1'] = block1_1
+ blocks['block1_2'] = block1_2
+
+ self.model0 = make_layers(block0, no_relu_layers)
+
+ # Stages 2 - 6
+ for i in range(2, 7):
+ blocks['block%d_1' % i] = OrderedDict([
+ ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+ ])
+
+ blocks['block%d_2' % i] = OrderedDict([
+ ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_1 = blocks['block1_1']
+ self.model2_1 = blocks['block2_1']
+ self.model3_1 = blocks['block3_1']
+ self.model4_1 = blocks['block4_1']
+ self.model5_1 = blocks['block5_1']
+ self.model6_1 = blocks['block6_1']
+
+ self.model1_2 = blocks['block1_2']
+ self.model2_2 = blocks['block2_2']
+ self.model3_2 = blocks['block3_2']
+ self.model4_2 = blocks['block4_2']
+ self.model5_2 = blocks['block5_2']
+ self.model6_2 = blocks['block6_2']
+
+
+ def forward(self, x):
+
+ out1 = self.model0(x)
+
+ out1_1 = self.model1_1(out1)
+ out1_2 = self.model1_2(out1)
+ out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+ out2_1 = self.model2_1(out2)
+ out2_2 = self.model2_2(out2)
+ out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+ out3_1 = self.model3_1(out3)
+ out3_2 = self.model3_2(out3)
+ out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+ out4_1 = self.model4_1(out4)
+ out4_2 = self.model4_2(out4)
+ out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+ out5_1 = self.model5_1(out5)
+ out5_2 = self.model5_2(out5)
+ out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+ out6_1 = self.model6_1(out6)
+ out6_2 = self.model6_2(out6)
+
+ return out6_1, out6_2
+
+class handpose_model(nn.Module):
+ def __init__(self):
+ super(handpose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+ 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+ # stage 1
+ block1_0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3', [512, 512, 3, 1, 1]),
+ ('conv4_4', [512, 512, 3, 1, 1]),
+ ('conv5_1', [512, 512, 3, 1, 1]),
+ ('conv5_2', [512, 512, 3, 1, 1]),
+ ('conv5_3_CPM', [512, 128, 3, 1, 1])
+ ])
+
+ block1_1 = OrderedDict([
+ ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+ ('conv6_2_CPM', [512, 22, 1, 1, 0])
+ ])
+
+ blocks = {}
+ blocks['block1_0'] = block1_0
+ blocks['block1_1'] = block1_1
+
+ # stage 2-6
+ for i in range(2, 7):
+ blocks['block%d' % i] = OrderedDict([
+ ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+ ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_0 = blocks['block1_0']
+ self.model1_1 = blocks['block1_1']
+ self.model2 = blocks['block2']
+ self.model3 = blocks['block3']
+ self.model4 = blocks['block4']
+ self.model5 = blocks['block5']
+ self.model6 = blocks['block6']
+
+ def forward(self, x):
+ out1_0 = self.model1_0(x)
+ out1_1 = self.model1_1(out1_0)
+ concat_stage2 = torch.cat([out1_1, out1_0], 1)
+ out_stage2 = self.model2(concat_stage2)
+ concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+ out_stage3 = self.model3(concat_stage3)
+ concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+ out_stage4 = self.model4(concat_stage4)
+ concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+ out_stage5 = self.model5(concat_stage5)
+ concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+ out_stage6 = self.model6(concat_stage6)
+ return out_stage6
+
+class Body(object):
+ def __init__(self, model_path):
+ self.model = bodypose_model()
+ if torch.cuda.is_available():
+ self.model = self.model.cuda()
+ print('cuda')
+ model_dict = transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImg):
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
+ scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre1 = 0.1
+ thre2 = 0.05
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+ imageToTest_padded, pad = padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+ with torch.no_grad():
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
+ paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+ paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ heatmap_avg += heatmap_avg + heatmap / len(multiplier)
+ paf_avg += + paf / len(multiplier)
+
+ all_peaks = []
+ peak_counter = 0
+
+ for part in range(18):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+ map_left = np.zeros(one_heatmap.shape)
+ map_left[1:, :] = one_heatmap[:-1, :]
+ map_right = np.zeros(one_heatmap.shape)
+ map_right[:-1, :] = one_heatmap[1:, :]
+ map_up = np.zeros(one_heatmap.shape)
+ map_up[:, 1:] = one_heatmap[:, :-1]
+ map_down = np.zeros(one_heatmap.shape)
+ map_down[:, :-1] = one_heatmap[:, 1:]
+
+ peaks_binary = np.logical_and.reduce(
+ (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+ peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
+ peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+ peak_id = range(peak_counter, peak_counter + len(peaks))
+ peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+ all_peaks.append(peaks_with_score_and_id)
+ peak_counter += len(peaks)
+
+ # find connection in the specified sequence, center 29 is in the position 15
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+ # the middle joints heatmap correpondence
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+ [55, 56], [37, 38], [45, 46]]
+
+ connection_all = []
+ special_k = []
+ mid_num = 10
+
+ for k in range(len(mapIdx)):
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+ candA = all_peaks[limbSeq[k][0] - 1]
+ candB = all_peaks[limbSeq[k][1] - 1]
+ nA = len(candA)
+ nB = len(candB)
+ indexA, indexB = limbSeq[k]
+ if (nA != 0 and nB != 0):
+ connection_candidate = []
+ for i in range(nA):
+ for j in range(nB):
+ vec = np.subtract(candB[j][:2], candA[i][:2])
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+ norm = max(0.001, norm)
+ vec = np.divide(vec, norm)
+
+ startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+ np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+ vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+ for I in range(len(startend))])
+ vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+ for I in range(len(startend))])
+
+ score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+ score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
+ criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+ criterion2 = score_with_dist_prior > 0
+ if criterion1 and criterion2:
+ connection_candidate.append(
+ [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+ connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+ connection = np.zeros((0, 5))
+ for c in range(len(connection_candidate)):
+ i, j, s = connection_candidate[c][0:3]
+ if (i not in connection[:, 3] and j not in connection[:, 4]):
+ connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+ if (len(connection) >= min(nA, nB)):
+ break
+
+ connection_all.append(connection)
+ else:
+ special_k.append(k)
+ connection_all.append([])
+
+ # last number in each row is the total parts number of that person
+ # the second last number in each row is the score of the overall configuration
+ subset = -1 * np.ones((0, 20))
+ candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+ for k in range(len(mapIdx)):
+ if k not in special_k:
+ partAs = connection_all[k][:, 0]
+ partBs = connection_all[k][:, 1]
+ indexA, indexB = np.array(limbSeq[k]) - 1
+
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
+ found = 0
+ subset_idx = [-1, -1]
+ for j in range(len(subset)): # 1:size(subset,1):
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+ subset_idx[found] = j
+ found += 1
+
+ if found == 1:
+ j = subset_idx[0]
+ if subset[j][indexB] != partBs[i]:
+ subset[j][indexB] = partBs[i]
+ subset[j][-1] += 1
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+ elif found == 2: # if found 2 and disjoint, merge them
+ j1, j2 = subset_idx
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
+ subset[j1][-2:] += subset[j2][-2:]
+ subset[j1][-2] += connection_all[k][i][2]
+ subset = np.delete(subset, j2, 0)
+ else: # as like found == 1
+ subset[j1][indexB] = partBs[i]
+ subset[j1][-1] += 1
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+ # if find no partA in the subset, create a new subset
+ elif not found and k < 17:
+ row = -1 * np.ones(20)
+ row[indexA] = partAs[i]
+ row[indexB] = partBs[i]
+ row[-1] = 2
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+ subset = np.vstack([subset, row])
+ # delete some rows of subset which has few parts occur
+ deleteIdx = []
+ for i in range(len(subset)):
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+ deleteIdx.append(i)
+ subset = np.delete(subset, deleteIdx, axis=0)
+
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+ # candidate: x, y, score, id
+ return candidate, subset
+
+
+
+def sample_video_frames(video_path,):
+ cap = cv2.VideoCapture(video_path)
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ frame_indices = np.linspace(0, total_frames - 1, total_frames, dtype=int)
+
+ frames = []
+ for idx in frame_indices:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+ ret, frame = cap.read()
+ if ret:
+ if frame.shape[1] > 1024:
+ frame = frame[:, 1440:, :]
+ frame = cv2.resize(frame, (720, 480))
+ frames.append(frame)
+ cap.release()
+ return frames
+
+
+def process_image(pose_model, image_path):
+ if isinstance(image_path, str):
+ np_faceid_image = np.array(Image.open(image_path).convert("RGB"))
+ elif isinstance(image_path, numpy.ndarray):
+ np_faceid_image = image_path
+ else:
+ raise TypeError("image_path should be a string or PIL.Image.Image object")
+
+ image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
+ candidate, subset = pose_model(image_bgr)
+
+ pose_list = []
+ for c in candidate:
+ pose_list.append([c[0], c[1]])
+ return pose_list
+
+
+def process_video(video_path, pose_model):
+ video_frames = sample_video_frames(video_path,)
+ print(len(video_frames))
+ pose_list = []
+ for frame in video_frames:
+ # Convert to RGB once at the beginning
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ pose = process_image(pose_model, frame_rgb)
+ pose_list.append(pose)
+ # break
+ return pose_list
+
+
+def calculate_l1_distance(list1, list2):
+ """
+ 计算两个列表的 L1 距离
+ :return: L1 距离
+ """
+ # 将列表转换为 NumPy 数组
+ list1 = np.array(list1)
+ list2 = np.array(list2)
+
+ min_d = min(list1.shape[0], list2.shape[0])
+ list1 = list1[:min_d, :]
+ list2 = list2[:min_d, :]
+ # 计算每对点的 L1 距离
+ l1_distances = np.abs(list1 - list2).sum(axis=1)
+
+ # 返回所有点的 L1 距离之和
+ return l1_distances.sum()
+
+
+def calculate_pose(list1, list2):
+ distance_list = []
+ for kps1 in list1:
+ min_dis = (480 + 720) * 17 + 1
+ for kps2 in list2:
+ try:
+ min_dis = min(min_dis, calculate_l1_distance(kps1, kps2))
+ except:
+ continue
+ min_dis = min_dis/(480+720)/16
+ if min_dis > 1:
+ continue
+ distance_list.append(min_dis)
+
+ if len(distance_list) > 0:
+ return sum(distance_list)/len(distance_list)
+ else:
+ return 0.
+
+def main():
+ body_estimation = Body('eval/pose/body_pose_model.pth')
+
+ device = "cuda"
+ data_path = "data/SkyActor"
+ # data_path = "data/LivePotraits"
+ # data_path = "data/Actor-One"
+ # data_path = "data/FollowYourEmoji"
+ img_path = "/maindata/data/shared/public/rui.wang/act_review/driving_video"
+ pre_tag = True
+ mp4_list = os.listdir(data_path)
+ print(mp4_list)
+
+ img_list = []
+ video_list = []
+ for mp4 in mp4_list:
+ if "mp4" not in mp4:
+ continue
+ if pre_tag:
+ png_path = mp4.split('.')[0].split('-')[1] + ".mp4"
+ else:
+ if "-" in mp4:
+ png_path = mp4.split('.')[0].split('-')[0] + ".mp4"
+ else:
+ png_path = mp4.split('.')[0].split('_')[0] + ".mp4"
+ img_list.append(os.path.join(img_path, png_path))
+ video_list.append(os.path.join(data_path, mp4))
+ print(img_list)
+ print(video_list[0])
+
+ pd_list = []
+ for i in range(len(img_list)):
+ print("number: ", str(i), " total: ", len(img_list), data_path)
+
+ pose_1 = process_video(video_list[i], body_estimation)
+ pose_2 = process_video(img_list[i], body_estimation)
+
+ dis = calculate_pose(pose_1, pose_2)
+ print(dis)
+ if dis > 0.0001:
+ pd_list.append(dis)
+
+ print("pose", sum(pd_list)/ len(pd_list))
+
+
+main()
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f77803ec9d0cad256a2a56495e1720badcc35a
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,241 @@
+import torch
+import os
+import numpy as np
+from PIL import Image
+import glob
+import insightface
+import cv2
+import subprocess
+import argparse
+from decord import VideoReader
+from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip
+from facexlib.parsing import init_parsing_model
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from insightface.app import FaceAnalysis
+
+from diffusers.models import AutoencoderKLCogVideoX
+from diffusers.utils import export_to_video, load_image
+from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
+from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+
+from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
+from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
+from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
+from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
+from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
+
+def crop_and_resize(image, height, width):
+ image = np.array(image)
+ image_height, image_width, _ = image.shape
+ if image_height / image_width < height / width:
+ croped_width = int(image_height / height * width)
+ left = (image_width - croped_width) // 2
+ image = image[:, left: left+croped_width]
+ image = Image.fromarray(image).resize((width, height))
+ else:
+ pad = int((((width / height) * image_height) - image_width) / 2.)
+ padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
+ padded_image[:, pad:pad+image_width] = image
+ image = Image.fromarray(padded_image).resize((width, height))
+ return image
+
+def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"):
+ clip = ImageSequenceClip(samples, fps=fps)
+ clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
+ ffmpeg_params=["-crf", "18", "-preset", "slow"])
+
+def parse_video(driving_video_path, max_frame_num):
+ vr = VideoReader(driving_video_path)
+ fps = vr.get_avg_fps()
+ video_length = len(vr)
+
+ duration = video_length / fps
+ target_times = np.arange(0, duration, 1/12)
+ frame_indices = (target_times * fps).astype(np.int32)
+
+ frame_indices = frame_indices[frame_indices < video_length]
+ control_frames = vr.get_batch(frame_indices).asnumpy()[:(max_frame_num-1)]
+
+ out_frames = len(control_frames) - 1
+ if len(control_frames) < max_frame_num - 1:
+ video_lenght_add = max_frame_num - len(control_frames) - 1
+ control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-1], [control_frames[-1]] * video_lenght_add), axis=0)
+ else:
+ control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-1]), axis=0)
+
+ return control_frames
+
+def exec_cmd(cmd):
+ return subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str):
+ cmd = [
+ 'ffmpeg',
+ '-y',
+ '-i', f'"{silent_video_path}"',
+ '-i', f'"{audio_video_path}"',
+ '-map', '0:v',
+ '-map', '1:a',
+ '-c:v', 'copy',
+ '-shortest',
+ f'"{output_video_path}"'
+ ]
+
+ try:
+ exec_cmd(' '.join(cmd))
+ print(f"Video with audio generated successfully: {output_video_path}")
+ except subprocess.CalledProcessError as e:
+ print(f"Error occurred: {e}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Process video and image for face animation.")
+ parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
+ parser.add_argument('--driving_video_path', type=str, default="assets/driving_video/1.mp4", help='Path to the driving video.')
+ parser.add_argument('--output_path', type=str, default="outputs", help='Path to save the output video.')
+ args = parser.parse_args()
+
+ guidance_scale = 3.0
+ seed = 43
+ num_inference_steps = 10
+ sample_size = [480, 720]
+ max_frame_num = 49
+ weight_dtype = torch.bfloat16
+ save_path = args.output_path
+ generator = torch.Generator(device="cuda").manual_seed(seed)
+ model_name = "pretrained_models/SkyReels-A1-5B/"
+ siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
+
+ lmk_extractor = LMKExtractor()
+ processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
+ vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,)
+ face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda",)
+
+ # siglip visual encoder
+ siglip = SiglipVisionModel.from_pretrained(siglip_name)
+ siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)
+
+ # skyreels a1 model
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
+ model_name,
+ subfolder="transformer"
+ ).to(weight_dtype)
+
+ vae = AutoencoderKLCogVideoX.from_pretrained(
+ model_name,
+ subfolder="vae"
+ ).to(weight_dtype)
+
+ lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
+ model_name,
+ subfolder="pose_guider",
+ ).to(weight_dtype)
+
+ pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
+ model_name,
+ transformer = transformer,
+ vae = vae,
+ lmk_encoder = lmk_encoder,
+ image_encoder = siglip,
+ feature_extractor = siglip_normalize,
+ torch_dtype=torch.bfloat16
+ )
+
+ pipe.to("cuda")
+ pipe.enable_model_cpu_offload()
+ pipe.vae.enable_tiling()
+
+ control_frames = parse_video(args.driving_video_path, max_frame_num)
+
+ # driving video crop face
+ driving_video_crop = []
+ for control_frame in control_frames:
+ frame, _, _ = processor.face_crop(control_frame)
+ driving_video_crop.append(frame)
+
+ image = load_image(image=args.image_path)
+ image = processor.crop_and_resize(image, sample_size[0], sample_size[1])
+
+ # ref image crop face
+ ref_image, x1, y1 = processor.face_crop(np.array(image))
+ face_h, face_w, _, = ref_image.shape
+ source_image = ref_image
+ driving_video = driving_video_crop
+ out_frames = processor.preprocess_lmk3d(source_image, driving_video)
+
+ rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0)
+ for ii in range(rescale_motions.shape[0]):
+ rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]
+ ref_image = cv2.resize(ref_image, (512, 512))
+ ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
+
+ ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
+
+ first_motion = np.zeros_like(np.array(image))
+ first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
+ first_motion = first_motion[np.newaxis, :]
+
+ motions = np.concatenate([first_motion, rescale_motions])
+ input_video = motions[:max_frame_num]
+
+ face_helper.clean_all()
+ face_helper.read_image(np.array(image)[:, :, ::-1])
+ face_helper.get_face_landmarks_5(only_center_face=True)
+ face_helper.align_warp_face()
+ align_face = face_helper.cropped_faces[0]
+ image_face = align_face[:, :, ::-1]
+
+ input_video = input_video[:max_frame_num]
+ motions = np.array(input_video)
+
+ # [F, H, W, C]
+ input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
+ input_video = input_video / 255
+
+ out_samples = []
+
+ with torch.no_grad():
+ sample = pipe(
+ image=image,
+ image_face=image_face,
+ control_video = input_video,
+ prompt = "",
+ negative_prompt = "",
+ height = sample_size[0],
+ width = sample_size[1],
+ num_frames = 49,
+ generator = generator,
+ guidance_scale = guidance_scale,
+ num_inference_steps = num_inference_steps,
+ )
+ out_samples.extend(sample.frames[0])
+ out_samples = out_samples[2:]
+
+ save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_video_path).split(".")[0]+ ".mp4"
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path, exist_ok=True)
+ video_path = os.path.join(save_path, save_path_name + "-output.mp4")
+ export_to_video(out_samples, video_path, fps=12)
+
+ target_h, target_w = sample_size[0], sample_size[1]
+ final_images = []
+ final_images2 =[]
+ rescale_motions = rescale_motions[1:]
+ control_frames = control_frames[1:]
+ for q in range(len(out_samples)):
+ frame1 = image
+ frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[q])).convert("RGB"), target_h, target_w)
+ frame3 = Image.fromarray(np.array(out_samples[q])).convert("RGB")
+
+ result = Image.new('RGB', (target_w * 3, target_h))
+ result.paste(frame1, (0, 0))
+ result.paste(frame2, (target_w, 0))
+ result.paste(frame3, (target_w * 2, 0))
+ final_images.append(np.array(result))
+
+ video_out_path = os.path.join(save_path, save_path_name)
+ write_mp4(video_out_path, final_images, fps=12)
+
+ add_audio_to_video(video_out_path, args.driving_video_path, video_out_path + ".audio.mp4")
+ add_audio_to_video(video_path, args.driving_video_path, video_path + ".audio.mp4")
diff --git a/inference_audio.py b/inference_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..68046dd9239a04dd2a066887a284b36251b77b5a
--- /dev/null
+++ b/inference_audio.py
@@ -0,0 +1,234 @@
+import torch
+import os
+import numpy as np
+from PIL import Image
+import glob
+import insightface
+import cv2
+import subprocess
+import argparse
+from decord import VideoReader
+from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip
+from facexlib.parsing import init_parsing_model
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from insightface.app import FaceAnalysis
+
+from diffusers.models import AutoencoderKLCogVideoX
+from diffusers.utils import export_to_video, load_image
+from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
+from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+
+from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
+from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
+from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
+from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
+from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
+
+import moviepy.editor as mp
+from diffposetalk.diffposetalk import DiffPoseTalk
+
+
+def crop_and_resize(image, height, width):
+ image = np.array(image)
+ image_height, image_width, _ = image.shape
+ if image_height / image_width < height / width:
+ croped_width = int(image_height / height * width)
+ left = (image_width - croped_width) // 2
+ image = image[:, left: left+croped_width]
+ image = Image.fromarray(image).resize((width, height))
+ else:
+ pad = int((((width / height) * image_height) - image_width) / 2.)
+ padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
+ padded_image[:, pad:pad+image_width] = image
+ image = Image.fromarray(padded_image).resize((width, height))
+ return image
+
+def write_mp4(video_path, samples, fps=12, audio_bitrate="192k"):
+ clip = ImageSequenceClip(samples, fps=fps)
+ clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
+ ffmpeg_params=["-crf", "18", "-preset", "slow"])
+
+
+def parse_video(driving_frames, max_frame_num, fps=25):
+
+ video_length = len(driving_frames)
+
+ duration = video_length / fps
+ target_times = np.arange(0, duration, 1/12)
+ frame_indices = (target_times * fps).astype(np.int32)
+
+ frame_indices = frame_indices[frame_indices < video_length]
+ new_driving_frames = []
+ for idx in frame_indices:
+ new_driving_frames.append(driving_frames[idx])
+ if len(new_driving_frames) >= max_frame_num - 1:
+ break
+
+ video_lenght_add = max_frame_num - len(new_driving_frames) - 1
+ new_driving_frames = [new_driving_frames[0]]*2 + new_driving_frames[1:len(new_driving_frames)-1] + [new_driving_frames[-1]] * video_lenght_add
+ return new_driving_frames
+
+
+def save_video_with_audio(video_path:str, audio_path: str, save_path: str):
+ video_clip = mp.VideoFileClip(video_path)
+ audio_clip = mp.AudioFileClip(audio_path)
+
+ if audio_clip.duration > video_clip.duration:
+ audio_clip = audio_clip.subclip(0, video_clip.duration)
+
+ video_with_audio = video_clip.set_audio(audio_clip)
+
+ video_with_audio.write_videofile(save_path, fps=12)
+
+ os.remove(video_path)
+
+ video_clip.close()
+ audio_clip.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Process video and image for face animation.")
+ parser.add_argument('--image_path', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
+ parser.add_argument('--driving_audio_path', type=str, default="assets/driving_audio/1.wav", help='Path to the driving video.')
+ parser.add_argument('--output_path', type=str, default="outputs_audio", help='Path to save the output video.')
+ args = parser.parse_args()
+
+ guidance_scale = 3.0
+ seed = 43
+ num_inference_steps = 10
+ sample_size = [480, 720]
+ max_frame_num = 49
+ weight_dtype = torch.bfloat16
+ save_path = args.output_path
+ generator = torch.Generator(device="cuda").manual_seed(seed)
+ model_name = "pretrained_models/SkyReels-A1-5B/"
+ siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
+
+ lmk_extractor = LMKExtractor()
+ processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
+ vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,)
+ face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device="cuda",)
+
+ # siglip visual encoder
+ siglip = SiglipVisionModel.from_pretrained(siglip_name)
+ siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)
+
+ # diffposetalk
+ diffposetalk = DiffPoseTalk()
+
+ # skyreels a1 model
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
+ model_name,
+ subfolder="transformer"
+ ).to(weight_dtype)
+
+ vae = AutoencoderKLCogVideoX.from_pretrained(
+ model_name,
+ subfolder="vae"
+ ).to(weight_dtype)
+
+ lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
+ model_name,
+ subfolder="pose_guider",
+ ).to(weight_dtype)
+
+ pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
+ model_name,
+ transformer = transformer,
+ vae = vae,
+ lmk_encoder = lmk_encoder,
+ image_encoder = siglip,
+ feature_extractor = siglip_normalize,
+ torch_dtype=torch.bfloat16
+ )
+
+ pipe.to("cuda")
+ pipe.enable_model_cpu_offload()
+ pipe.vae.enable_tiling()
+
+ image = load_image(image=args.image_path)
+ image = processor.crop_and_resize(image, sample_size[0], sample_size[1])
+
+ # ref image crop face
+ ref_image, x1, y1 = processor.face_crop(np.array(image))
+ face_h, face_w, _, = ref_image.shape
+ source_image = ref_image
+
+ source_outputs, source_tform, image_original = processor.process_source_image(source_image)
+ driving_outputs = diffposetalk.infer_from_file(args.driving_audio_path, source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy())
+ out_frames = processor.preprocess_lmk3d_from_coef(source_outputs, source_tform, image_original.shape, driving_outputs)
+ out_frames = parse_video(out_frames, max_frame_num)
+
+ rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0)
+ for ii in range(rescale_motions.shape[0]):
+ rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]
+ ref_image = cv2.resize(ref_image, (512, 512))
+ ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
+
+ ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
+
+ first_motion = np.zeros_like(np.array(image))
+ first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
+ first_motion = first_motion[np.newaxis, :]
+
+ motions = np.concatenate([first_motion, rescale_motions])
+ input_video = motions[:max_frame_num]
+
+ face_helper.clean_all()
+ face_helper.read_image(np.array(image)[:, :, ::-1])
+ face_helper.get_face_landmarks_5(only_center_face=True)
+ face_helper.align_warp_face()
+ align_face = face_helper.cropped_faces[0]
+ image_face = align_face[:, :, ::-1]
+
+ input_video = input_video[:max_frame_num]
+ motions = np.array(input_video)
+
+ # [F, H, W, C]
+ input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
+ input_video = input_video / 255
+
+ out_samples = []
+
+ with torch.no_grad():
+ sample = pipe(
+ image=image,
+ image_face=image_face,
+ control_video = input_video,
+ prompt = "",
+ negative_prompt = "",
+ height = sample_size[0],
+ width = sample_size[1],
+ num_frames = 49,
+ generator = generator,
+ guidance_scale = guidance_scale,
+ num_inference_steps = num_inference_steps,
+ )
+ out_samples.extend(sample.frames[0])
+ out_samples = out_samples[2:]
+
+ save_path_name = os.path.basename(args.image_path).split(".")[0] + "-" + os.path.basename(args.driving_audio_path).split(".")[0]+ ".mp4"
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path, exist_ok=True)
+ video_path = os.path.join(save_path, save_path_name + ".output.mp4")
+ export_to_video(out_samples, video_path, fps=12)
+ target_h, target_w = sample_size[0], sample_size[1]
+ final_images = []
+ final_images2 =[]
+ rescale_motions = rescale_motions[1:]
+ control_frames = out_frames[1:]
+ for q in range(len(out_samples)):
+ frame1 = image
+ frame2 = Image.fromarray(np.array(out_samples[q])).convert("RGB")
+
+ result = Image.new('RGB', (target_w * 2, target_h))
+ result.paste(frame1, (0, 0))
+ result.paste(frame2, (target_w, 0))
+ final_images.append(np.array(result))
+
+ video_out_path = os.path.join(save_path, save_path_name)
+ write_mp4(video_out_path, final_images, fps=12)
+
+ save_video_with_audio(video_out_path, args.driving_audio_path, video_out_path + ".audio.mp4")
+ save_video_with_audio(video_path, args.driving_audio_path, video_path + ".audio.mp4")
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..227ddbbfabcec175b3a9fb3f54f0eeabd5a330b5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+chumpy==0.70
+decord==0.6.0
+diffusers==0.32.2
+einops==0.8.1
+facexlib==0.3.0
+gradio==5.16.0
+insightface==0.7.3
+moviepy==1.0.3
+numpy==2.2.2
+opencv_contrib_python==4.10.0.84
+opencv_python==4.10.0.84
+opencv_python_headless==4.10.0.84
+Pillow==11.1.0
+pytorch3d==0.7.8
+safetensors==0.5.2
+scikit-image==0.24.0
+timm==0.6.13
+torch==2.2.2+cu118
+tqdm==4.66.2
+transformers==4.37.2
+mediapipe==0.10.21
+librosa==0.10.2.post1
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/demo.py b/scripts/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4743636be5526a52c68f7b9499b6f69451ef1830
--- /dev/null
+++ b/scripts/demo.py
@@ -0,0 +1,251 @@
+import torch
+import os
+import numpy as np
+from PIL import Image
+import glob
+import insightface
+import cv2
+import subprocess
+import argparse
+from decord import VideoReader
+from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip
+from facexlib.parsing import init_parsing_model
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from insightface.app import FaceAnalysis
+
+from diffusers.models import AutoencoderKLCogVideoX
+from diffusers.utils import export_to_video, load_image
+from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
+from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+
+from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
+from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
+from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
+from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
+from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d
+
+
+def crop_and_resize(image, height, width):
+ image = np.array(image)
+ image_height, image_width, _ = image.shape
+ if image_height / image_width < height / width:
+ croped_width = int(image_height / height * width)
+ left = (image_width - croped_width) // 2
+ image = image[:, left: left+croped_width]
+ image = Image.fromarray(image).resize((width, height))
+ else:
+ pad = int((((width / height) * image_height) - image_width) / 2.)
+ padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
+ padded_image[:, pad:pad+image_width] = image
+ image = Image.fromarray(padded_image).resize((width, height))
+ return image
+
+def write_mp4(video_path, samples, fps=14, audio_bitrate="192k"):
+ clip = ImageSequenceClip(samples, fps=fps)
+ clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
+ ffmpeg_params=["-crf", "18", "-preset", "slow"])
+
+def init_model(
+ model_name: str = "pretrained_models/SkyReels-A1-5B/",
+ subfolder: str = "outputs/",
+ siglip_path: str = "pretrained_models/siglip-so400m-patch14-384",
+ weight_dtype=torch.bfloat16,
+ ):
+
+ lmk_extractor = LMKExtractor()
+ vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False,)
+ processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
+
+ face_helper = FaceRestoreHelper(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ device="cuda",
+ )
+
+ siglip = SiglipVisionModel.from_pretrained(siglip_path)
+ siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_path)
+
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
+ model_name,
+ subfolder="transformer",
+ ).to(weight_dtype)
+
+ vae = AutoencoderKLCogVideoX.from_pretrained(
+ model_name,
+ subfolder="vae"
+ ).to(weight_dtype)
+
+ lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
+ model_name,
+ subfolder="pose_guider"
+ ).to(weight_dtype)
+
+ pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
+ model_name,
+ transformer = transformer,
+ vae = vae,
+ lmk_encoder = lmk_encoder,
+ image_encoder = siglip,
+ feature_extractor = siglip_normalize,
+ torch_dtype=weight_dtype)
+ pipe.to("cuda")
+ pipe.enable_model_cpu_offload()
+ pipe.vae.enable_tiling()
+
+ return pipe, face_helper, processor, lmk_extractor, vis
+
+
+
+def generate_video(
+ pipe,
+ face_helper,
+ processor,
+ lmk_extractor,
+ vis,
+ control_video_path: str = None,
+ image_path: str = None,
+ save_path: str = None,
+ guidance_scale=3.0,
+ seed=43,
+ num_inference_steps=10,
+ sample_size=[480, 720],
+ max_frame_num=49,
+ weight_dtype=torch.bfloat16,
+ ):
+
+ vr = VideoReader(control_video_path)
+ fps = vr.get_avg_fps()
+ video_length = len(vr)
+
+ duration = video_length / fps
+ target_times = np.arange(0, duration, 1/12)
+ frame_indices = (target_times * fps).astype(np.int32)
+
+ frame_indices = frame_indices[frame_indices < video_length]
+ control_frames = vr.get_batch(frame_indices).asnumpy()[:(max_frame_num-1)]
+
+ out_frames = len(control_frames) - 1
+ if len(control_frames) < max_frame_num:
+ video_lenght_add = max_frame_num - len(control_frames)
+ control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-2], [control_frames[-1]] * video_lenght_add), axis=0)
+
+ # driving video crop face
+ driving_video_crop = []
+ for control_frame in control_frames:
+ frame, _, _ = processor.face_crop(control_frame)
+ driving_video_crop.append(frame)
+
+ image = load_image(image=image_path)
+ image = crop_and_resize(image, sample_size[0], sample_size[1])
+
+ with torch.no_grad():
+ face_helper.clean_all()
+ face_helper.read_image(np.array(image)[:, :, ::-1])
+ face_helper.get_face_landmarks_5(only_center_face=True)
+ face_helper.align_warp_face()
+ if len(face_helper.cropped_faces) == 0:
+ return
+ align_face = face_helper.cropped_faces[0]
+ image_face = align_face[:, :, ::-1]
+
+ # ref image crop face
+ ref_image, x1, y1 = processor.face_crop(np.array(image))
+ face_h, face_w, _, = ref_image.shape
+ source_image = ref_image
+ driving_video = driving_video_crop
+ out_frames = processor.preprocess_lmk3d(source_image, driving_video)
+
+ rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(48, axis=0)
+ for ii in range(rescale_motions.shape[0]):
+ rescale_motions[ii][y1:y1+face_h, x1:x1+face_w] = out_frames[ii]
+ ref_image = cv2.resize(ref_image, (512, 512))
+ ref_lmk = lmk_extractor(ref_image[:, :, ::-1])
+
+ ref_img = vis.draw_landmarks_v3((512, 512), (face_w, face_h), ref_lmk['lmks'].astype(np.float32), normed=True)
+
+ first_motion = np.zeros_like(np.array(image))
+ first_motion[y1:y1+face_h, x1:x1+face_w] = ref_img
+ first_motion = first_motion[np.newaxis, :]
+
+ motions = np.concatenate([first_motion, rescale_motions])
+ input_video = motions[:max_frame_num]
+
+ input_video = input_video[:max_frame_num]
+ motions = np.array(input_video)
+
+ # [F, H, W, C]
+ input_video = torch.from_numpy(np.array(input_video)).permute([3, 0, 1, 2]).unsqueeze(0)
+ input_video = input_video / 255
+
+ out_samples = []
+
+ generator = torch.Generator(device="cuda").manual_seed(seed)
+ with torch.no_grad():
+ sample = pipe(
+ image=image,
+ image_face=image_face,
+ control_video = input_video,
+ height = sample_size[0],
+ width = sample_size[1],
+ num_frames = 49,
+ generator = generator,
+ guidance_scale = guidance_scale,
+ num_inference_steps = num_inference_steps,
+ )
+ out_samples.extend(sample.frames[0][2:])
+
+ # export_to_video(out_samples, save_path, fps=12)
+ control_frames = control_frames[1:]
+ target_h, target_w = sample_size
+ final_images = []
+ for i in range(len(out_samples)):
+ frame1 = image
+ frame2 = crop_and_resize(Image.fromarray(np.array(control_frames[i])).convert("RGB"), target_h, target_w)
+ frame3 = Image.fromarray(np.array(out_samples[i])).convert("RGB")
+ result = Image.new('RGB', (target_w * 3, target_h))
+ result.paste(frame1, (0, 0))
+ result.paste(frame2, (target_w, 0))
+ result.paste(frame3, (target_w * 2, 0))
+ final_images.append(np.array(result))
+
+ write_mp4(save_path, final_images, fps=12)
+
+
+
+if __name__ == "__main__":
+ control_video_zip = glob.glob("assets/driving_video/*.mp4")
+ image_path_zip = glob.glob("assets/ref_images/*.png")
+
+ guidance_scale = 3.0
+ seed = 43
+ num_inference_steps = 10
+ sample_size = [480, 720]
+ max_frame_num = 49
+ weight_dtype = torch.bfloat16
+
+ save_path = "outputs"
+
+ # init model
+ pipe, face_helper, processor, lmk_extractor, vis = init_model()
+
+ for i in range(len(control_video_zip)):
+ for j in range(len(image_path_zip)):
+ generate_video(
+ pipe,
+ face_helper,
+ processor,
+ lmk_extractor,
+ vis,
+ control_video_path=control_video_zip[i],
+ image_path=image_path_zip[j],
+ save_path=save_path,
+ guidance_scale=guidance_scale,
+ seed=seed,
+ num_inference_steps=num_inference_steps,
+ sample_size=sample_size,
+ max_frame_num=max_frame_num,
+ weight_dtype=weight_dtype,
+ )
diff --git a/skyreels_a1/__init__.py b/skyreels_a1/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/skyreels_a1/ddim_solver.py b/skyreels_a1/ddim_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..68ef825d452993ef0ad9519d7b14b6dee095fbd5
--- /dev/null
+++ b/skyreels_a1/ddim_solver.py
@@ -0,0 +1,52 @@
+import numpy as np
+import torch
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
\ No newline at end of file
diff --git a/skyreels_a1/models/__init__.py b/skyreels_a1/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/skyreels_a1/models/transformer3d.py b/skyreels_a1/models/transformer3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eb16efc8744d7a7e1f1d27183edd89d0f092d48
--- /dev/null
+++ b/skyreels_a1/models/transformer3d.py
@@ -0,0 +1,783 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+import os
+import torch.nn.functional as F
+import glob
+import json
+
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+# from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CogVideoXPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ embed_dim: int = 1920,
+ text_embed_dim: int = 4096,
+ bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ temporal_compression_ratio: int = 4,
+ # max_text_seq_length: int = 226,
+ max_text_seq_length: int = 729,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_positional_embeddings: bool = True,
+ use_learned_positional_embeddings: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.embed_dim = embed_dim
+ self.sample_height = sample_height
+ self.sample_width = sample_width
+ self.sample_frames = sample_frames
+ self.temporal_compression_ratio = temporal_compression_ratio
+ # self.max_text_seq_length = max_text_seq_length
+ self.max_text_seq_length = 729
+ self.spatial_interpolation_scale = spatial_interpolation_scale
+ self.temporal_interpolation_scale = temporal_interpolation_scale
+ self.use_positional_embeddings = use_positional_embeddings
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+ self.pose_proj = torch.nn.Sequential(
+ torch.nn.Linear(1152, 1152*3),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1152*3, text_embed_dim),
+ # torch.nn.LayerNorm(text_embed_dim)
+ )
+ self.pose_proj = zero_module(self.pose_proj)
+
+ if use_positional_embeddings or use_learned_positional_embeddings:
+ persistent = use_learned_positional_embeddings
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+ post_patch_height = sample_height // self.patch_size
+ post_patch_width = sample_width // self.patch_size
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+ pos_embedding = get_3d_sincos_pos_embed(
+ self.embed_dim,
+ (post_patch_width, post_patch_height),
+ post_time_compression_frames,
+ self.spatial_interpolation_scale,
+ self.temporal_interpolation_scale,
+ )
+ pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+ joint_pos_embedding = torch.zeros(
+ 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+ )
+ joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+ return joint_pos_embedding
+
+ def forward(self, pose_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ r"""
+ Args:
+ text_embeds (`torch.Tensor`):
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+ image_embeds (`torch.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ pose_embeds = self.pose_proj(pose_embeds)
+ pose_embeds = self.text_proj(pose_embeds)
+
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+
+ embeds = torch.cat(
+ [pose_embeds, image_embeds], dim=1
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
+
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+ if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+ raise ValueError(
+ "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
+ "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+ if (
+ self.sample_height != height
+ or self.sample_width != width
+ or self.sample_frames != pre_time_compression_frames
+ ):
+ pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+ pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+ else:
+ pos_embedding = self.pos_embedding
+
+ embeds = embeds + pos_embedding
+
+ return embeds
+
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ # max_text_seq_length: int = 226,
+ max_text_seq_length: int = 729,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ # max_text_seq_length=max_text_seq_length,
+ max_text_seq_length=729,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+ # 3. Define spatio-temporal transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 4. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ # text_seq_length = 226
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ # Note: we use `-1` instead of `channels`:
+ # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
+ # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config, **transformer_additional_kwargs)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
+ if os.path.exists(model_file):
+ state_dict = torch.load(model_file, map_location="cpu")
+ elif os.path.exists(model_file_safetensors):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(model_file_safetensors)
+ else:
+ from safetensors.torch import load_file, safe_open
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+ state_dict = {}
+ for model_file_safetensors in model_files_safetensors:
+ _state_dict = load_file(model_file_safetensors)
+ for key in _state_dict:
+ state_dict[key] = _state_dict[key]
+
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
+ if len(new_shape) == 5:
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
+ else:
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+ else:
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+
+ tmp_state_dict = {}
+ for key in state_dict:
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+ tmp_state_dict[key] = state_dict[key]
+ else:
+ print(key, "Size don't match, skip")
+ state_dict = tmp_state_dict
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+
+ params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
+ print(f"### Mamba Parameters: {sum(params) / 1e6} M")
+
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
+
+ return model
+
+
+ @classmethod
+ def from_pretrained_2model(cls, pretrained_model_path, config_path, subfolder=None, state1_path=None, state2_path=None, transformer_additional_kwargs={}):
+ config_file = config_path
+ print("config_file: ", config_file)
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config, **transformer_additional_kwargs)
+
+ from safetensors.torch import load_file, safe_open
+ file_paths = [
+ "diffusion_pytorch_model-00001-of-00003.safetensors",
+ "diffusion_pytorch_model-00002-of-00003.safetensors",
+ "diffusion_pytorch_model-00003-of-00003.safetensors"
+ ]
+ state1_dict = {}
+ for file_path in file_paths:
+ model_path = os.path.join(pretrained_model_path, state1_path, file_path)
+ state1_dict.update(load_file(model_path))
+
+ model2_file = os.path.join(pretrained_model_path, state2_path, WEIGHTS_NAME)
+ model2_file_safetensors = model2_file.replace(".bin", ".safetensors")
+ state2_dict = load_file(model2_file_safetensors)
+
+ NUM_MODELS = 2
+ tmp_state_dict = {}
+
+ if model.state_dict()['patch_embed.proj.weight'].size() != state1_dict['patch_embed.proj.weight'].size():
+ model.state_dict()['patch_embed.proj.weight'][:, :32, :, :] = state1_dict['patch_embed.proj.weight'][:, :, :, :]
+ model.state_dict()['patch_embed.proj.weight'][:, 32:, :, :] = state2_dict['patch_embed.proj.weight'][:, 16:, :, :]
+ tmp_state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+
+ #patch_embed.pos_embedding
+
+ for key in state1_dict:
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state1_dict[key].size():
+ # UniformSoup
+ tmp_state_dict[key] = state1_dict[key] * (1. / NUM_MODELS)
+ if key in state2_dict:
+ tmp_state_dict[key] += state2_dict[key] * (1. / NUM_MODELS)
+ else:
+ tmp_state_dict[key] = state1_dict[key]
+
+ # LayerSoup
+ # if "transformer_blocks.1" <= key <= "transformer_blocks.12":
+ # tmp_state_dict[key] = state2_dict[key]
+ # # elif "transformer_blocks.30" <= key <= "transformer_blocks.36":
+ # # tmp_state_dict[key] = state1_dict[key] * (WEIGHT_STATE1_30_36 / NUM_MODELS)
+ # # tmp_state_dict[key] += state2_dict[key] * (WEIGHT_STATE2_30_36 / NUM_MODELS)
+ # else:
+ # tmp_state_dict[key] = state1_dict[key] * (1. / NUM_MODELS)
+ # tmp_state_dict[key] += state2_dict[key] * (1. / NUM_MODELS)
+
+ # print(f"the key is {key} and the difference between state1 and state2 is {(state2_dict[key] - state1_dict[key]).mean()}")
+ else:
+ print(key, "Size don't match, skip")
+
+ state_dict = tmp_state_dict
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ print(m)
+
+ params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
+ print(f"### Mamba Parameters: {sum(params) / 1e6} M")
+
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
+
+ return model
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
\ No newline at end of file
diff --git a/skyreels_a1/pipeline_output.py b/skyreels_a1/pipeline_output.py
new file mode 100644
index 0000000000000000000000000000000000000000..2432171b4b24ad449639ab0a4af1553739d4c72f
--- /dev/null
+++ b/skyreels_a1/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class CogVideoXPipelineOutput(BaseOutput):
+ r"""
+ Output class for CogVideo pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
\ No newline at end of file
diff --git a/skyreels_a1/pre_process_lmk3d.py b/skyreels_a1/pre_process_lmk3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3c940984fd0b71f84bc8bd0297043a22f6c5eae
--- /dev/null
+++ b/skyreels_a1/pre_process_lmk3d.py
@@ -0,0 +1,351 @@
+import torch
+import cv2
+import os
+import sys
+import numpy as np
+import argparse
+import math
+from PIL import Image
+from decord import VideoReader
+from skimage.transform import estimate_transform, warp
+from insightface.app import FaceAnalysis
+from diffusers.utils import load_image
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from skyreels_a1.src.utils.mediapipe_utils import MediaPipeUtils
+from skyreels_a1.src.smirk_encoder import SmirkEncoder
+from skyreels_a1.src.FLAME.FLAME import FLAME
+from skyreels_a1.src.renderer import Renderer
+from moviepy.editor import ImageSequenceClip
+
+class FaceAnimationProcessor:
+ def __init__(self, device='cuda', checkpoint="pretrained_models/smirk/smirk_encoder.pt"):
+ self.device = device
+ self.app = FaceAnalysis(allowed_modules=['detection'])
+ self.app.prepare(ctx_id=0, det_size=(640, 640))
+ self.smirk_encoder = SmirkEncoder().to(device)
+ self.flame = FLAME(n_shape=300, n_exp=50).to(device)
+ self.renderer = Renderer().to(device)
+ self.load_checkpoint(checkpoint)
+
+ def load_checkpoint(self, checkpoint):
+ checkpoint_data = torch.load(checkpoint)
+ checkpoint_encoder = {k.replace('smirk_encoder.', ''): v for k, v in checkpoint_data.items() if 'smirk_encoder' in k}
+ self.smirk_encoder.load_state_dict(checkpoint_encoder)
+ self.smirk_encoder.eval()
+
+ def face_crop(self, image):
+ height, width, _ = image.shape
+ faces = self.app.get(image)
+ bbox = faces[0]['bbox']
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ x1 = max(0, int(bbox[0] - w/2))
+ x2 = min(width - 1, int(bbox[2] + w/2))
+ w_new = x2 - x1
+ y_offset = (w_new - h) / 2.
+ y1 = max(0, int(bbox[1] - y_offset))
+ y2 = min(height, int(bbox[3] + y_offset))
+ x_comp = int(((x2 - x1) - (y2 - y1)) / 2) if (x2 - x1) > (y2 - y1) else 0
+ x1 += x_comp
+ x2 -= x_comp
+ image_crop = image[y1:y2, x1:x2]
+ return image_crop, x1, y1
+
+ def crop_and_resize(self, image, height, width):
+ image = np.array(image)
+ image_height, image_width, _ = image.shape
+ if image_height / image_width < height / width:
+ croped_width = int(image_height / height * width)
+ left = (image_width - croped_width) // 2
+ image = image[:, left: left+croped_width]
+ else:
+ pad = int((((width / height) * image_height) - image_width) / 2.)
+ padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
+ padded_image[:, pad:pad+image_width] = image
+ image = padded_image
+ return Image.fromarray(image).resize((width, height))
+
+ def rodrigues_to_matrix(self, pose_params):
+ theta = torch.norm(pose_params, dim=-1, keepdim=True)
+ r = pose_params / (theta + 1e-8)
+ cos_theta = torch.cos(theta)
+ sin_theta = torch.sin(theta)
+ r_x = torch.zeros((pose_params.shape[0], 3, 3), device=pose_params.device)
+ r_x[:, 0, 1] = -r[:, 2]
+ r_x[:, 0, 2] = r[:, 1]
+ r_x[:, 1, 0] = r[:, 2]
+ r_x[:, 1, 2] = -r[:, 0]
+ r_x[:, 2, 0] = -r[:, 1]
+ r_x[:, 2, 1] = r[:, 0]
+ R = cos_theta * torch.eye(3, device=pose_params.device).unsqueeze(0) + \
+ sin_theta * r_x + \
+ (1 - cos_theta) * r.unsqueeze(-1) @ r.unsqueeze(-2)
+ return R
+
+ def matrix_to_rodrigues(self, R):
+ cos_theta = (torch.trace(R[0]) - 1) / 2
+ cos_theta = torch.clamp(cos_theta, -1, 1)
+ theta = torch.acos(cos_theta)
+ if abs(theta) < 1e-4:
+ return torch.zeros(1, 3, device=R.device)
+ elif abs(theta - math.pi) < 1e-4:
+ R_plus_I = R[0] + torch.eye(3, device=R.device)
+ col_norms = torch.norm(R_plus_I, dim=0)
+ max_col_idx = torch.argmax(col_norms)
+ v = R_plus_I[:, max_col_idx]
+ v = v / torch.norm(v)
+ return (v * math.pi).unsqueeze(0)
+ sin_theta = torch.sin(theta)
+ r = torch.zeros(1, 3, device=R.device)
+ r[0, 0] = R[0, 2, 1] - R[0, 1, 2]
+ r[0, 1] = R[0, 0, 2] - R[0, 2, 0]
+ r[0, 2] = R[0, 1, 0] - R[0, 0, 1]
+ r = r / (2 * sin_theta)
+ return r * theta
+
+ def crop_face(self, frame, landmarks, scale=1.0, image_size=224):
+ left = np.min(landmarks[:, 0])
+ right = np.max(landmarks[:, 0])
+ top = np.min(landmarks[:, 1])
+ bottom = np.max(landmarks[:, 1])
+ h, w, _ = frame.shape
+ old_size = (right - left + bottom - top) / 2
+ center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
+ size = int(old_size * scale)
+ src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2],
+ [center[0] + size / 2, center[1] - size / 2]])
+ DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
+ tform = estimate_transform('similarity', src_pts, DST_PTS)
+ tform_original = estimate_transform('similarity', src_pts, src_pts)
+ return tform, tform_original
+
+ def compute_landmark_relation(self, kpt_mediapipe, target_idx=473, ref_indices=[362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398]):
+ target_point = kpt_mediapipe[target_idx]
+ ref_points = kpt_mediapipe[ref_indices]
+ left_corner = ref_points[0]
+ right_corner = ref_points[8]
+ eye_center = (left_corner + right_corner) / 2
+ eye_width_vector = right_corner - left_corner
+ eye_width = np.linalg.norm(eye_width_vector)
+ eye_direction = eye_width_vector / eye_width
+ eye_vertical = np.array([-eye_direction[1], eye_direction[0]])
+ target_vector = target_point - eye_center
+ x_relative = np.dot(target_vector, eye_direction) / (eye_width/2)
+ y_relative = np.dot(target_vector, eye_vertical) / (eye_width/2)
+ return [np.array([x_relative, y_relative]),target_point,ref_points,ref_indices]
+
+ def process_source_image(self, image_rgb, input_size=224):
+ image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
+ mediapipe_utils = MediaPipeUtils()
+ kpt_mediapipe, _, _, mediapipe_eye_pose = mediapipe_utils.run_mediapipe(image_bgr)
+ if kpt_mediapipe is None:
+ raise ValueError('Cannot find facial landmarks in the source image')
+ kpt_mediapipe = kpt_mediapipe[..., :2]
+ tform, _ = self.crop_face(image_rgb, kpt_mediapipe, scale=1.4, image_size=input_size)
+ cropped_image = warp(image_rgb, tform.inverse, output_shape=(input_size, input_size), preserve_range=True).astype(np.uint8)
+ cropped_image = torch.tensor(cropped_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+ cropped_image = cropped_image.to(self.device)
+ with torch.no_grad():
+ source_outputs = self.smirk_encoder(cropped_image)
+ source_outputs['eye_pose_params'] = torch.tensor(mediapipe_eye_pose).to(self.device)
+ return source_outputs, tform, image_rgb
+
+ def smooth_params(self, data, alpha=0.7):
+ smoothed_data = [data[0]]
+ for i in range(1, len(data)):
+ smoothed_value = alpha * data[i] + (1 - alpha) * smoothed_data[i-1]
+ smoothed_data.append(smoothed_value)
+ return smoothed_data
+
+ def process_driving_img_list(self, img_list, input_size=224):
+ driving_frames = []
+ driving_outputs = []
+ driving_tforms = []
+ weights_473 = []
+ weights_468 = []
+ mediapipe_utils = MediaPipeUtils()
+ for i, frame in enumerate(img_list):
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ try:
+ kpt_mediapipe, mediapipe_exp, mediapipe_pose, mediapipe_eye_pose = mediapipe_utils.run_mediapipe(frame)
+ except:
+ print('Warning: No face detected in a frame, skipping this frame')
+ continue
+ if kpt_mediapipe is None:
+ print('Warning: No face detected in a frame, skipping this frame')
+ continue
+ kpt_mediapipe = kpt_mediapipe[..., :2]
+ weights_473.append(self.compute_landmark_relation(kpt_mediapipe))
+ weights_468.append(self.compute_landmark_relation(kpt_mediapipe, target_idx=468, ref_indices=[33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246]))
+
+
+ tform, _ = self.crop_face(frame, kpt_mediapipe, scale=1.4, image_size=input_size)
+ cropped_frame = warp(frame, tform.inverse, output_shape=(input_size, input_size), preserve_range=True).astype(np.uint8)
+ cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
+ cropped_frame = torch.tensor(cropped_frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+ cropped_frame = cropped_frame.to(self.device)
+ with torch.no_grad():
+ outputs = self.smirk_encoder(cropped_frame)
+ outputs['eye_pose_params'] = torch.tensor(mediapipe_eye_pose).to(self.device)
+ outputs['mediapipe_exp'] = torch.tensor(mediapipe_exp).to(self.device)
+ outputs['mediapipe_pose'] = torch.tensor(mediapipe_pose).to(self.device)
+ driving_frames.append(frame)
+ driving_outputs.append(outputs)
+ driving_tforms.append(tform)
+ return driving_frames, driving_outputs, driving_tforms, weights_473, weights_468
+
+ def preprocess_lmk3d(self, source_image=None, driving_image_list=None):
+ source_outputs, source_tform, image_original = self.process_source_image(source_image)
+ _, driving_outputs, driving_video_tform, weights_473, weights_468 = self.process_driving_img_list(driving_image_list)
+ driving_outputs_list = []
+ source_pose_init = source_outputs['pose_params'].clone()
+ driving_outputs_pose = [outputs['pose_params'] for outputs in driving_outputs]
+ driving_outputs_pose = self.smooth_params(driving_outputs_pose)
+ for i, outputs in enumerate(driving_outputs):
+ outputs['pose_params'] = driving_outputs_pose[i]
+ source_outputs['expression_params'] = outputs['expression_params']
+ source_outputs['jaw_params'] = outputs['jaw_params']
+ source_outputs['eye_pose_params'] = outputs['eye_pose_params']
+ source_matrix = self.rodrigues_to_matrix(source_pose_init)
+ driving_matrix_0 = self.rodrigues_to_matrix(driving_outputs[0]['pose_params'])
+ driving_matrix_i = self.rodrigues_to_matrix(driving_outputs[i]['pose_params'])
+ relative_rotation = torch.inverse(driving_matrix_0) @ driving_matrix_i
+ new_rotation = source_matrix @ relative_rotation
+ source_outputs['pose_params'] = self.matrix_to_rodrigues(new_rotation)
+ source_outputs['eyelid_params'] = outputs['eyelid_params']
+ flame_output = self.flame.forward(source_outputs)
+ renderer_output = self.renderer.forward(
+ flame_output['vertices'],
+ source_outputs['cam'],
+ landmarks_fan=flame_output['landmarks_fan'], source_tform=source_tform,
+ tform_512=None, weights_468=weights_468[i], weights_473=weights_473[i],
+ landmarks_mp=flame_output['landmarks_mp'], shape=image_original.shape)
+ rendered_img = renderer_output['rendered_img']
+ driving_outputs_list.extend(np.copy(rendered_img)[np.newaxis, :])
+ return driving_outputs_list
+
+ def preprocess_lmk3d_from_coef(self, source_outputs, source_tform, render_shape, driving_outputs):
+ driving_outputs_list = []
+ source_pose_init = source_outputs['pose_params'].clone()
+ driving_outputs_pose = [outputs['pose_params'] for outputs in driving_outputs]
+ driving_outputs_pose = self.smooth_params(driving_outputs_pose)
+ for i, outputs in enumerate(driving_outputs):
+ outputs['pose_params'] = driving_outputs_pose[i]
+ source_outputs['expression_params'] = outputs['expression_params']
+ source_outputs['jaw_params'] = outputs['jaw_params']
+ source_outputs['eye_pose_params'] = outputs['eye_pose_params']
+ source_matrix = self.rodrigues_to_matrix(source_pose_init)
+ driving_matrix_0 = self.rodrigues_to_matrix(driving_outputs[0]['pose_params'])
+ driving_matrix_i = self.rodrigues_to_matrix(driving_outputs[i]['pose_params'])
+ relative_rotation = torch.inverse(driving_matrix_0) @ driving_matrix_i
+ new_rotation = source_matrix @ relative_rotation
+ source_outputs['pose_params'] = self.matrix_to_rodrigues(new_rotation)
+ source_outputs['eyelid_params'] = outputs['eyelid_params']
+ flame_output = self.flame.forward(source_outputs)
+ renderer_output = self.renderer.forward(
+ flame_output['vertices'],
+ source_outputs['cam'],
+ landmarks_fan=flame_output['landmarks_fan'], source_tform=source_tform,
+ tform_512=None, weights_468=None, weights_473=None,
+ landmarks_mp=flame_output['landmarks_mp'], shape=render_shape)
+ rendered_img = renderer_output['rendered_img']
+ driving_outputs_list.extend(np.copy(rendered_img)[np.newaxis, :])
+ return driving_outputs_list
+
+
+ def ensure_even_dimensions(self, frame):
+ height, width = frame.shape[:2]
+ new_width = width - (width % 2)
+ new_height = height - (height % 2)
+ if new_width != width or new_height != height:
+ frame = cv2.resize(frame, (new_width, new_height))
+ return frame
+
+ def get_global_bbox(self, frames):
+ max_x1, max_y1, max_x2, max_y2 = float('inf'), float('inf'), 0, 0
+ for frame in frames:
+ faces = self.app.get(frame)
+ if not faces:
+ continue
+ bbox = faces[0]['bbox']
+ x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
+ max_x1 = min(max_x1, x1)
+ max_y1 = min(max_y1, y1)
+ max_x2 = max(max_x2, x2)
+ max_y2 = max(max_y2, y2)
+ w = max_x2 - max_x1
+ h = max_y2 - max_y1
+ x1 = int(max_x1 - w / 2)
+ x2 = int(max_x2 + w / 2)
+ y_offset = (x2 - x1 - h) / 2
+ y1 = int(max_y1 - y_offset)
+ y2 = int(max_y2 + y_offset)
+ if (x2 - x1) > (y2 - y1):
+ x_comp = int(((x2 - x1) - (y2 - y1)) / 2)
+ x1 += x_comp
+ x2 -= x_comp
+ else:
+ y_comp = int(((y2 - y1) - (x2 - x1)) / 2)
+ y1 += y_comp
+ y2 -= y_comp
+ x1 = max(0, x1)
+ y1 = max(0, y1)
+ x2 = min(frames[0].shape[1], x2)
+ y2 = min(frames[0].shape[0], y2)
+ return int(x1), int(y1), int(x2), int(y2)
+
+ def face_crop_with_global_box(self, image, global_box):
+ x1, y1, x2, y2 = global_box
+ return image[y1:y2, x1:x2]
+
+ def process_video(self, source_image_path, driving_video_path, output_path, sample_size=[480, 720]):
+ image = load_image(source_image_path)
+ image = self.crop_and_resize(image, sample_size[0], sample_size[1])
+ ref_image = np.array(image)
+ ref_image, x1, y1 = self.face_crop(ref_image)
+ face_h, face_w, _ = ref_image.shape
+
+ vr = VideoReader(driving_video_path)
+ fps = vr.get_avg_fps()
+ video_length = len(vr)
+ duration = video_length / fps
+ target_times = np.arange(0, duration, 1/12)
+ frame_indices = (target_times * fps).astype(np.int32)
+ frame_indices = frame_indices[frame_indices < video_length]
+ control_frames = vr.get_batch(frame_indices).asnumpy()[:48]
+ if len(control_frames) < 49:
+ video_lenght_add = 49 - len(control_frames)
+ control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-2], [control_frames[-1]] * video_lenght_add), axis=0)
+
+ control_frames_crop = []
+ global_box = self.get_global_bbox(control_frames)
+ for control_frame in control_frames:
+ frame = self.face_crop_with_global_box(control_frame, global_box)
+ control_frames_crop.append(frame)
+
+ out_frames = self.preprocess_lmk3d(source_image=ref_image, driving_image_list=control_frames_crop)
+
+ def write_mp4(video_path, samples, fps=14):
+ clip = ImageSequenceClip(samples, fps=fps)
+ clip.write_videofile(video_path, codec='libx264', ffmpeg_params=["-pix_fmt", "yuv420p", "-crf", "23", "-preset", "medium"])
+
+ concat_frames = []
+ for i in range(len(out_frames)):
+ ref_image_concat = ref_image.copy()
+ driving_frame = cv2.resize(control_frames_crop[i], (face_w, face_h))
+ out_frame = cv2.resize(out_frames[i], (face_w, face_h))
+ concat_frame = np.concatenate([ref_image_concat, driving_frame, out_frame], axis=1)
+ concat_frame = self.ensure_even_dimensions(concat_frame)
+ concat_frames.append(concat_frame)
+ write_mp4(output_path, concat_frames, fps=12)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Process video and image for face animation.")
+ parser.add_argument('--source_image', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
+ parser.add_argument('--driving_video', type=str, default="assets/driving_video/1.mp4", help='Path to the driving video.')
+ parser.add_argument('--output_path', type=str, default="./output.mp4", help='Path to save the output video.')
+ args = parser.parse_args()
+
+ processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
+ processor.process_video(args.source_image, args.driving_video, args.output_path)
\ No newline at end of file
diff --git a/skyreels_a1/skyreels_a1_i2v_pipeline.py b/skyreels_a1/skyreels_a1_i2v_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf13c30a03f8a00316a3e61ae7c33f8621350c0
--- /dev/null
+++ b/skyreels_a1/skyreels_a1_i2v_pipeline.py
@@ -0,0 +1,941 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import PipelineImageInput
+from diffusers.image_processor import VaeImageProcessor
+
+# from diffusers.loaders import CogVideoXLoraLoaderMixin
+# from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from diffusers.models import AutoencoderKLCogVideoX
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.utils import (
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from .pipeline_output import CogVideoXPipelineOutput
+from einops import rearrange
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing
+from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
+from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
+
+from facexlib.parsing import init_parsing_model
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from insightface.app import FaceAnalysis
+import insightface
+import cv2
+import numpy as np
+from PIL import Image
+
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import CogVideoXImageToVideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+ >>> video = pipe(image, prompt, use_dynamic_cfg=True)
+ >>> export_to_video(video.frames[0], "output.mp4", fps=8)
+ ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsA1ImagePoseToVideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-to-video generation using CogVideoX.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. CogVideoX uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CogVideoXTransformer3DModel`]):
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ lmk_encoder: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ image_encoder: SiglipVisionModel,
+ feature_extractor: SiglipImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ lmk_encoder=lmk_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ )
+ self.vae_scaling_factor_image = (
+ self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.face_helper = FaceRestoreHelper(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ )
+
+ # def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
+ # dtype = next(self.image_encoder.parameters()).dtype
+
+ # if not isinstance(image, torch.Tensor):
+ # image = self.image_processor.pil_to_numpy(image)
+ # image = self.image_processor.numpy_to_pt(image)
+
+ # # We normalize the image before resizing to match with the original implementation.
+ # # Then we unnormalize it after resizing.
+ # image = image * 2.0 - 1.0
+ # image = _resize_with_antialiasing(image, (224, 224))
+ # image = (image + 1.0) / 2.0
+
+ # # Normalize the image with for CLIP input
+ # image = self.feature_extractor(
+ # images=image,
+ # do_normalize=True,
+ # do_center_crop=False,
+ # do_resize=False,
+ # do_rescale=False,
+ # return_tensors="pt",
+ # ).pixel_values
+
+ # image = image.to(device=device, dtype=dtype)
+ # image_embeddings = self.image_encoder(image).image_embeds
+ # image_embeddings = image_embeddings.unsqueeze(1).repeat(1, 226, 1)
+ # # image_embeddings = image_embeddings.unsqueeze(1)
+
+ # # duplicate image embeddings for each generation per prompt, using mps friendly method
+ # # bs_embed, seq_len, _ = image_embeddings.shape
+ # # image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
+ # # image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ # if do_classifier_free_guidance:
+ # negative_image_embeddings = torch.zeros_like(image_embeddings)
+
+ # # For classifier free guidance, we need to do two forward passes.
+ # # Here we concatenate the unconditional and text embeddings into a single batch
+ # # to avoid doing two forward passes
+ # image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
+
+ # return image_embeddings
+
+ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ imgs = self.feature_extractor.preprocess(images=[image], do_resize=True, return_tensors="pt", do_convert_rgb=True)
+ image_embeddings = self.image_encoder(**imgs.to(device=device, dtype=dtype)).last_hidden_state # torch.Size([2, 729, 1152])
+
+
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings], dim=0)
+
+ return image_embeddings
+
+
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self,
+ image: torch.Tensor,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ num_frames: int = 13,
+ height: int = 60,
+ width: int = 90,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_frames,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ image = image.unsqueeze(2) # [B, C, F, H, W]
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+ ]
+ else:
+ image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
+
+ image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
+ image_latents = self.vae_scaling_factor_image * image_latents
+
+ padding_shape = (
+ batch_size,
+ num_frames - 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
+
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
+
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents, image_latents
+
+ def prepare_control_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+
+ if mask is not None:
+ mask = mask.to(device=device, dtype=self.lmk_encoder.dtype)
+ bs = 1
+ new_mask = []
+ for i in range(0, mask.shape[0], bs):
+ mask_bs = mask[i : i + bs]
+ mask_bs = self.lmk_encoder.encode(mask_bs)[0]
+ mask_bs = mask_bs.mode()
+ new_mask.append(mask_bs)
+ mask = torch.cat(new_mask, dim = 0)
+ mask = mask * self.lmk_encoder.config.scaling_factor
+
+ if masked_image is not None:
+ masked_image = masked_image.to(device=device, dtype=self.lmk_encoder.dtype)
+ bs = 1
+ new_mask_pixel_values = []
+ for i in range(0, masked_image.shape[0], bs):
+ mask_pixel_values_bs = masked_image[i : i + bs]
+ mask_pixel_values_bs = self.lmk_encoder.encode(mask_pixel_values_bs)[0]
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
+ new_mask_pixel_values.append(mask_pixel_values_bs)
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
+ masked_image_latents = masked_image_latents * self.lmk_encoder.config.scaling_factor
+ else:
+ masked_image_latents = None
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae_scaling_factor_image * latents
+
+ frames = self.vae.decode(latents).sample
+ return frames
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ image,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ latents=None,
+ ):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ image_face: PipelineImageInput,
+ video: Union[torch.FloatTensor] = None,
+ control_video: Union[torch.FloatTensor] = None,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
+ num_frames (`int`, defaults to `48`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if num_frames > 49:
+ raise ValueError(
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ image=image,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ latents=latents,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ batch_size = 1
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode image prompt
+ image_embeddings = self._encode_image(Image.fromarray(np.array(image_face)), device, num_videos_per_prompt, do_classifier_free_guidance)
+ image_embeddings = image_embeddings.type(torch.bfloat16)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
+ device, dtype=image_embeddings.dtype
+ )
+
+ latent_channels = self.transformer.config.in_channels // 3
+ latents, image_latents = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ if control_video is not None:
+ video_length = control_video.shape[2]
+ control_video = self.video_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
+ control_video = control_video.to(dtype=torch.float32)
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
+ else:
+ control_video = None
+
+ control_video_latents = self.prepare_control_latents(
+ None,
+ control_video,
+ batch_size,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance
+ )[1]
+
+ control_video_latents_input = (
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
+ )
+ control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+ latent_model_input = torch.cat([latent_model_input, control_latents, latent_image_input], dim=2)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=image_embeddings,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ # attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(image_embeddings.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(frames=video)
diff --git a/skyreels_a1/src/FLAME/FLAME.py b/skyreels_a1/src/FLAME/FLAME.py
new file mode 100644
index 0000000000000000000000000000000000000000..bac2915b172dcc3cecdf88ca4ea4c7dfb9b1da9e
--- /dev/null
+++ b/skyreels_a1/src/FLAME/FLAME.py
@@ -0,0 +1,314 @@
+# -*- coding: utf-8 -*-
+#
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# Using this computer program means that you agree to the terms
+# in the LICENSE file included with this software distribution.
+# Any use not explicitly granted by the LICENSE is prohibited.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# For comments or questions, please email us at deca@tue.mpg.de
+# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
+
+import torch
+import torch.nn as nn
+import numpy as np
+np.bool = np.bool_
+np.int = np.int_
+np.float = np.float_
+np.complex = np.complex_
+np.object = np.object_
+np.unicode = np.unicode_
+np.str = np.str_
+import pickle
+import torch.nn.functional as F
+
+from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler
+
+def to_tensor(array, dtype=torch.float32):
+ if 'torch.tensor' not in str(type(array)):
+ return torch.tensor(array, dtype=dtype)
+def to_np(array, dtype=np.float32):
+ if 'scipy.sparse' in str(type(array)):
+ array = array.todense()
+ return np.array(array, dtype=dtype)
+
+class Struct(object):
+ def __init__(self, **kwargs):
+ for key, val in kwargs.items():
+ setattr(self, key, val)
+
+class FLAME(nn.Module):
+ """
+ borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py
+ Given flame parameters this class generates a differentiable FLAME function
+ which outputs the a mesh and 2D/3D facial landmarks
+ """
+ def __init__(self, flame_model_path='pretrained_models/FLAME/generic_model.pkl',
+ flame_lmk_embedding_path='pretrained_models/FLAME/landmark_embedding.npy', n_shape=300, n_exp=50):
+ super(FLAME, self).__init__()
+
+ with open(flame_model_path, 'rb') as f:
+ ss = pickle.load(f, encoding='latin1')
+ flame_model = Struct(**ss)
+
+ self.n_shape = n_shape
+ self.n_exp = n_exp
+ self.dtype = torch.float32
+ self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
+ # The vertices of the template model
+ print('Using generic FLAME model')
+ self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
+
+ # The shape components and expression
+ shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
+ shapedirs = torch.cat([shapedirs[:,:,:n_shape], shapedirs[:,:,300:300+n_exp]], 2)
+ self.register_buffer('shapedirs', shapedirs)
+ # The pose components
+ num_pose_basis = flame_model.posedirs.shape[-1]
+ posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
+ self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
+ #
+ self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
+ parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1
+ self.register_buffer('parents', parents)
+ self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))
+
+
+ self.register_buffer('l_eyelid', torch.from_numpy(np.load(f'pretrained_models/smirk/l_eyelid.npy')).to(self.dtype)[None])
+ self.register_buffer('r_eyelid', torch.from_numpy(np.load(f'pretrained_models/smirk/r_eyelid.npy')).to(self.dtype)[None])
+ # import pdb;pdb.set_trace()
+
+ # Fixing Eyeball and neck rotation
+ default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False)
+ self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose,
+ requires_grad=False))
+ default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False)
+ self.register_parameter('neck_pose', nn.Parameter(default_neck_pose,
+ requires_grad=False))
+
+ # Static and Dynamic Landmark embeddings for FLAME
+ lmk_embeddings = np.load(flame_lmk_embedding_path, allow_pickle=True, encoding='latin1')
+ lmk_embeddings = lmk_embeddings[()]
+ self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long())
+ self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype))
+ self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long())
+ self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype))
+ self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long())
+ self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype))
+
+ neck_kin_chain = []; NECK_IDX=1
+ curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
+ while curr_idx != -1:
+ neck_kin_chain.append(curr_idx)
+ curr_idx = self.parents[curr_idx]
+ self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))
+
+ lmk_embeddings_mp = np.load("pretrained_models/smirk/mediapipe_landmark_embedding.npz")
+ self.register_buffer('mp_lmk_faces_idx', torch.from_numpy(lmk_embeddings_mp['lmk_face_idx'].astype('int32')).long())
+ self.register_buffer('mp_lmk_bary_coords', torch.from_numpy(lmk_embeddings_mp['lmk_b_coords']).to(self.dtype))
+
+ def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx,
+ dynamic_lmk_b_coords,
+ neck_kin_chain, dtype=torch.float32):
+ """
+ Selects the face contour depending on the reletive position of the head
+ Input:
+ vertices: N X num_of_vertices X 3
+ pose: N X full pose
+ dynamic_lmk_faces_idx: The list of contour face indexes
+ dynamic_lmk_b_coords: The list of contour barycentric weights
+ neck_kin_chain: The tree to consider for the relative rotation
+ dtype: Data type
+ return:
+ The contour face indexes and the corresponding barycentric weights
+ """
+
+ batch_size = pose.shape[0]
+
+ aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
+ neck_kin_chain)
+ rot_mats = batch_rodrigues(
+ aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
+
+ rel_rot_mat = torch.eye(3, device=pose.device,
+ dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
+ for idx in range(len(neck_kin_chain)):
+ rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
+
+ y_rot_angle = torch.round(
+ torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
+ max=39)).to(dtype=torch.long)
+
+ neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
+ mask = y_rot_angle.lt(-39).to(dtype=torch.long)
+ neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
+ y_rot_angle = (neg_mask * neg_vals +
+ (1 - neg_mask) * y_rot_angle)
+
+ dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
+ 0, y_rot_angle)
+ dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
+ 0, y_rot_angle)
+ return dyn_lmk_faces_idx, dyn_lmk_b_coords
+
+ def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
+ """
+ Calculates landmarks by barycentric interpolation
+ Input:
+ vertices: torch.tensor NxVx3, dtype = torch.float32
+ The tensor of input vertices
+ faces: torch.tensor (N*F)x3, dtype = torch.long
+ The faces of the mesh
+ lmk_faces_idx: torch.tensor N X L, dtype = torch.long
+ The tensor with the indices of the faces used to calculate the
+ landmarks.
+ lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
+ The tensor of barycentric coordinates that are used to interpolate
+ the landmarks
+
+ Returns:
+ landmarks: torch.tensor NxLx3, dtype = torch.float32
+ The coordinates of the landmarks for each mesh in the batch
+ """
+ # Extract the indices of the vertices for each face
+ # NxLx3
+ batch_size, num_verts = vertices.shape[:dd2]
+ lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
+ 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1)
+
+ lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to(
+ device=vertices.device) * num_verts
+
+ lmk_vertices = vertices.view(-1, 3)[lmk_faces]
+ landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
+ return landmarks
+
+ def seletec_3d68(self, vertices):
+ landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
+ self.full_lmk_faces_idx.repeat(vertices.shape[0], 1),
+ self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
+ return landmarks3d
+
+ def get_landmarks(self, vertices):
+ """
+ Input:
+ shape_params: N X number of shape parameters
+ expression_params: N X number of expression parameters
+ pose_params: N X number of pose parameters (6)
+ return:d
+ vertices: N X V X 3
+ landmarks: N X number of landmarks X 3
+ """
+ batch_size = vertices.shape[0]
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
+ lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
+
+ dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
+ full_pose, self.dynamic_lmk_faces_idx,
+ self.dynamic_lmk_bary_coords,
+ self.neck_kin_chain, dtype=self.dtype)
+ lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
+ lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
+
+ landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
+ lmk_faces_idx,
+ lmk_bary_coords)
+ bz = vertices.shape[0]
+ landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1))
+ return vertices, landmarks2d, landmarks3d
+
+
+ def forward(self, param_dictionary, zero_expression=False, zero_shape=False, zero_pose=False):
+ shape_params = param_dictionary['shape_params']
+ expression_params = param_dictionary['expression_params']
+ pose_params = param_dictionary.get('pose_params', None)
+ jaw_params = param_dictionary.get('jaw_params', None)
+ eye_pose_params = param_dictionary.get('eye_pose_params', None)
+ neck_pose_params = param_dictionary.get('neck_pose_params', None)
+ eyelid_params = param_dictionary.get('eyelid_params', None)
+
+ batch_size = shape_params.shape[0]
+
+ # Adjust expression params size if needed
+ if expression_params.shape[1] < self.n_exp:
+ expression_params = torch.cat([expression_params, torch.zeros(expression_params.shape[0], self.n_exp - expression_params.shape[1]).to(shape_params.device)], dim=1)
+
+ if shape_params.shape[1] < self.n_shape:
+ shape_params = torch.cat([shape_params, torch.zeros(shape_params.shape[0], self.n_shape - shape_params.shape[1]).to(shape_params.device)], dim=1)
+
+ # Zero out the expression and pose parameters if needed
+ if zero_expression:
+ expression_params = torch.zeros_like(expression_params).to(shape_params.device)
+ jaw_params = torch.zeros_like(jaw_params).to(shape_params.device)
+
+ if zero_shape:
+ shape_params = torch.zeros_like(shape_params).to(shape_params.device)
+
+
+ if zero_pose:
+ pose_params = torch.zeros_like(pose_params).to(shape_params.device)
+ pose_params[...,0] = 0.2
+ pose_params[...,1] = -0.7
+
+ if pose_params is None:
+ pose_params = self.pose_params.expand(batch_size, -1)
+
+ if eye_pose_params is None:
+ eye_pose_params = self.eye_pose.expand(batch_size, -1)
+
+ if neck_pose_params is None:
+ neck_pose_params = self.neck_pose.expand(batch_size, -1)
+
+
+ betas = torch.cat([shape_params, expression_params], dim=1)
+ full_pose = torch.cat([pose_params, neck_pose_params, jaw_params, eye_pose_params], dim=1)
+ # import pdb;pdb.set_trace()
+ template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
+
+ vertices, _ = lbs(betas, full_pose, template_vertices,
+ self.shapedirs, self.posedirs,
+ self.J_regressor, self.parents,
+ self.lbs_weights, dtype=self.dtype)
+ # import pdb;pdb.set_trace()
+ if eyelid_params is not None:
+ vertices = vertices + self.r_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 1:2, None]
+ vertices = vertices + self.l_eyelid.expand(batch_size, -1, -1) * eyelid_params[:, 0:1, None]
+
+ lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
+ lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
+
+ dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
+ full_pose, self.dynamic_lmk_faces_idx,
+ self.dynamic_lmk_bary_coords,
+ self.neck_kin_chain, dtype=self.dtype)
+ lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
+ lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
+
+ landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
+ lmk_faces_idx,
+ lmk_bary_coords)
+ bz = vertices.shape[0]
+ landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
+ self.full_lmk_faces_idx.repeat(bz, 1),
+ self.full_lmk_bary_coords.repeat(bz, 1, 1))
+
+ landmarksmp = vertices2landmarks(vertices, self.faces_tensor,
+ self.mp_lmk_faces_idx.repeat(vertices.shape[0], 1),
+ self.mp_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
+
+ return {
+ 'vertices': vertices,
+ 'landmarks_fan': landmarks2d,
+ 'landmarks_fan_3d': landmarks3d,
+ 'landmarks_mp': landmarksmp
+ }
+
+
diff --git a/skyreels_a1/src/FLAME/lbs.py b/skyreels_a1/src/FLAME/lbs.py
new file mode 100644
index 0000000000000000000000000000000000000000..df55ab37bf3a644ddd6fb7e5a2d7f48c3c793c31
--- /dev/null
+++ b/skyreels_a1/src/FLAME/lbs.py
@@ -0,0 +1,378 @@
+# -*- coding: utf-8 -*-
+
+# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+# holder of all proprietary rights on this computer program.
+# You can only use this computer program if you have closed
+# a license agreement with MPG or you get the right to use the computer
+# program from someone who is authorized to grant you that right.
+# Any use of the computer program without a valid license is prohibited and
+# liable to prosecution.
+#
+# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+# for Intelligent Systems. All rights reserved.
+#
+# Contact: ps-license@tuebingen.mpg.de
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+
+def rot_mat_to_euler(rot_mats):
+ # Calculates rotation matrix to euler angles
+ # Careful for extreme cases of eular angles like [0.0, pi, 0.0]
+
+ sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
+ rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
+ return torch.atan2(-rot_mats[:, 2, 0], sy)
+
+def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
+ dynamic_lmk_b_coords,
+ neck_kin_chain, dtype=torch.float32):
+ ''' Compute the faces, barycentric coordinates for the dynamic landmarks
+
+
+ To do so, we first compute the rotation of the neck around the y-axis
+ and then use a pre-computed look-up table to find the faces and the
+ barycentric coordinates that will be used.
+
+ Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
+ for providing the original TensorFlow implementation and for the LUT.
+
+ Parameters
+ ----------
+ vertices: torch.tensor BxVx3, dtype = torch.float32
+ The tensor of input vertices
+ pose: torch.tensor Bx(Jx3), dtype = torch.float32
+ The current pose of the body model
+ dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
+ The look-up table from neck rotation to faces
+ dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
+ The look-up table from neck rotation to barycentric coordinates
+ neck_kin_chain: list
+ A python list that contains the indices of the joints that form the
+ kinematic chain of the neck.
+ dtype: torch.dtype, optional
+
+ Returns
+ -------
+ dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
+ A tensor of size BxL that contains the indices of the faces that
+ will be used to compute the current dynamic landmarks.
+ dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
+ A tensor of size BxL that contains the indices of the faces that
+ will be used to compute the current dynamic landmarks.
+ '''
+
+ batch_size = vertices.shape[0]
+
+ aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
+ neck_kin_chain)
+ rot_mats = batch_rodrigues(
+ aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
+
+ rel_rot_mat = torch.eye(3, device=vertices.device,
+ dtype=dtype).unsqueeze_(dim=0)
+ for idx in range(len(neck_kin_chain)):
+ rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
+
+ y_rot_angle = torch.round(
+ torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
+ max=39)).to(dtype=torch.long)
+ neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
+ mask = y_rot_angle.lt(-39).to(dtype=torch.long)
+ neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
+ y_rot_angle = (neg_mask * neg_vals +
+ (1 - neg_mask) * y_rot_angle)
+
+ dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
+ 0, y_rot_angle)
+ dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
+ 0, y_rot_angle)
+
+ return dyn_lmk_faces_idx, dyn_lmk_b_coords
+
+
+def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
+ ''' Calculates landmarks by barycentric interpolation
+
+ Parameters
+ ----------
+ vertices: torch.tensor BxVx3, dtype = torch.float32
+ The tensor of input vertices
+ faces: torch.tensor Fx3, dtype = torch.long
+ The faces of the mesh
+ lmk_faces_idx: torch.tensor L, dtype = torch.long
+ The tensor with the indices of the faces used to calculate the
+ landmarks.
+ lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
+ The tensor of barycentric coordinates that are used to interpolate
+ the landmarks
+
+ Returns
+ -------
+ landmarks: torch.tensor BxLx3, dtype = torch.float32
+ The coordinates of the landmarks for each mesh in the batch
+ '''
+ # Extract the indices of the vertices for each face
+ # BxLx3
+ batch_size, num_verts = vertices.shape[:2]
+ device = vertices.device
+
+ lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
+ batch_size, -1, 3)
+
+ lmk_faces += torch.arange(
+ batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
+
+ lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
+ batch_size, -1, 3, 3)
+
+ landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
+ return landmarks
+
+
+def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
+ lbs_weights, pose2rot=True, dtype=torch.float32):
+ ''' Performs Linear Blend Skinning with the given shape and pose parameters
+
+ Parameters
+ ----------
+ betas : torch.tensor BxNB
+ The tensor of shape parameters
+ pose : torch.tensor Bx(J + 1) * 3
+ The pose parameters in axis-angle format
+ v_template torch.tensor BxVx3
+ The template mesh that will be deformed
+ shapedirs : torch.tensor 1xNB
+ The tensor of PCA shape displacements
+ posedirs : torch.tensor Px(V * 3)
+ The pose PCA coefficients
+ J_regressor : torch.tensor JxV
+ The regressor array that is used to calculate the joints from
+ the position of the vertices
+ parents: torch.tensor J
+ The array that describes the kinematic tree for the model
+ lbs_weights: torch.tensor N x V x (J + 1)
+ The linear blend skinning weights that represent how much the
+ rotation matrix of each part affects each vertex
+ pose2rot: bool, optional
+ Flag on whether to convert the input pose tensor to rotation
+ matrices. The default value is True. If False, then the pose tensor
+ should already contain rotation matrices and have a size of
+ Bx(J + 1)x9
+ dtype: torch.dtype, optional
+
+ Returns
+ -------
+ verts: torch.tensor BxVx3
+ The vertices of the mesh after applying the shape and pose
+ displacements.
+ joints: torch.tensor BxJx3
+ The joints of the model
+ '''
+
+ batch_size = max(betas.shape[0], pose.shape[0])
+ device = betas.device
+
+ # Add shape contribution
+ v_shaped = v_template + blend_shapes(betas, shapedirs)
+
+ # Get the joints
+ # NxJx3 array
+ J = vertices2joints(J_regressor, v_shaped)
+
+ # 3. Add pose blend shapes
+ # N x J x 3 x 3
+ ident = torch.eye(3, dtype=dtype, device=device)
+ if pose2rot:
+ rot_mats = batch_rodrigues(
+ pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
+
+ pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
+ # (N x P) x (P, V * 3) -> N x V x 3
+ pose_offsets = torch.matmul(pose_feature, posedirs) \
+ .view(batch_size, -1, 3)
+ else:
+ pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
+ rot_mats = pose.view(batch_size, -1, 3, 3)
+
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
+ posedirs).view(batch_size, -1, 3)
+
+ v_posed = pose_offsets + v_shaped
+ # 4. Get the global joint location
+ J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
+
+ # 5. Do skinning:
+ # W is N x V x (J + 1)
+ W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
+ num_joints = J_regressor.shape[0]
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
+ .view(batch_size, -1, 4, 4)
+
+ homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
+ dtype=dtype, device=device)
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
+ v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
+
+ verts = v_homo[:, :, :3, 0]
+
+ return verts, J_transformed
+
+
+def vertices2joints(J_regressor, vertices):
+ ''' Calculates the 3D joint locations from the vertices
+
+ Parameters
+ ----------
+ J_regressor : torch.tensor JxV
+ The regressor array that is used to calculate the joints from the
+ position of the vertices
+ vertices : torch.tensor BxVx3
+ The tensor of mesh vertices
+
+ Returns
+ -------
+ torch.tensor BxJx3
+ The location of the joints
+ '''
+
+ return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
+
+
+def blend_shapes(betas, shape_disps):
+ ''' Calculates the per vertex displacement due to the blend shapes
+
+
+ Parameters
+ ----------
+ betas : torch.tensor Bx(num_betas)
+ Blend shape coefficients
+ shape_disps: torch.tensor Vx3x(num_betas)
+ Blend shapes
+
+ Returns
+ -------
+ torch.tensor BxVx3
+ The per-vertex displacement due to shape deformation
+ '''
+
+ # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
+ # i.e. Multiply each shape displacement by its corresponding beta and
+ # then sum them.
+ blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
+ return blend_shape
+
+
+def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
+ ''' Calculates the rotation matrices for a batch of rotation vectors
+ Parameters
+ ----------
+ rot_vecs: torch.tensor Nx3
+ array of N axis-angle vectors
+ Returns
+ -------
+ R: torch.tensor Nx3x3
+ The rotation matrices for the given axis-angle parameters
+ '''
+
+ batch_size = rot_vecs.shape[0]
+ device = rot_vecs.device
+
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
+ rot_dir = rot_vecs / angle
+
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
+
+ # Bx1 arrays
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
+
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
+ .view((batch_size, 3, 3))
+
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+ return rot_mat
+
+
+def transform_mat(R, t):
+ ''' Creates a batch of transformation matrices
+ Args:
+ - R: Bx3x3 array of a batch of rotation matrices
+ - t: Bx3x1 array of a batch of translation vectors
+ Returns:
+ - T: Bx4x4 Transformation matrix
+ '''
+ # No padding left or right, only add an extra row
+ return torch.cat([F.pad(R, [0, 0, 0, 1]),
+ F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
+
+
+def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
+ """
+ Applies a batch of rigid transformations to the joints
+
+ Parameters
+ ----------
+ rot_mats : torch.tensor BxNx3x3
+ Tensor of rotation matrices
+ joints : torch.tensor BxNx3
+ Locations of joints
+ parents : torch.tensor BxN
+ The kinematic tree of each object
+ dtype : torch.dtype, optional:
+ The data type of the created tensors, the default is torch.float32
+
+ Returns
+ -------
+ posed_joints : torch.tensor BxNx3
+ The locations of the joints after applying the pose rotations
+ rel_transforms : torch.tensor BxNx4x4
+ The relative (with respect to the root joint) rigid transformations
+ for all the joints
+ """
+
+ joints = torch.unsqueeze(joints, dim=-1)
+
+ rel_joints = joints.clone()
+ rel_joints[:, 1:] -= joints[:, parents[1:]]
+
+ # transforms_mat = transform_mat(
+ # rot_mats.view(-1, 3, 3),
+ # rel_joints.view(-1, 3, 1)).view(-1, joints.shape[1], 4, 4)
+ transforms_mat = transform_mat(
+ rot_mats.view(-1, 3, 3),
+ rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
+
+ transform_chain = [transforms_mat[:, 0]]
+ for i in range(1, parents.shape[0]):
+ # Subtract the joint location at the rest pose
+ # No need for rotation, since it's identity when at rest
+ curr_res = torch.matmul(transform_chain[parents[i]],
+ transforms_mat[:, i])
+ transform_chain.append(curr_res)
+
+ transforms = torch.stack(transform_chain, dim=1)
+
+ # The last column of the transformations contains the posed joints
+ posed_joints = transforms[:, :, :3, 3]
+
+ # The last column of the transformations contains the posed joints
+ posed_joints = transforms[:, :, :3, 3]
+
+ joints_homogen = F.pad(joints, [0, 0, 0, 1])
+
+ rel_transforms = transforms - F.pad(
+ torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
+
+ return posed_joints, rel_transforms
\ No newline at end of file
diff --git a/skyreels_a1/src/__init__.py b/skyreels_a1/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/skyreels_a1/src/lmk3d_test.py b/skyreels_a1/src/lmk3d_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8cdf0915f390792b9fd91d1d1ff1e457a069c80
--- /dev/null
+++ b/skyreels_a1/src/lmk3d_test.py
@@ -0,0 +1,323 @@
+import torch
+import cv2
+import os
+import sys
+import numpy as np
+import argparse
+import math
+from PIL import Image
+from decord import VideoReader
+from skimage.transform import estimate_transform, warp
+from insightface.app import FaceAnalysis
+from diffusers.utils import load_image
+from utils.mediapipe_utils import MediaPipeUtils
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from smirk_encoder import SmirkEncoder
+from FLAME.FLAME import FLAME
+from renderer import Renderer
+from moviepy.editor import ImageSequenceClip
+
+class FaceAnimationProcessor:
+ def __init__(self, device='cuda', checkpoint="pretrained_models/smirk/smirk_encoder.pt"):
+ self.device = device
+ self.app = FaceAnalysis(allowed_modules=['detection'])
+ self.app.prepare(ctx_id=0, det_size=(640, 640))
+ self.smirk_encoder = SmirkEncoder().to(device)
+ self.flame = FLAME(n_shape=300, n_exp=50).to(device)
+ self.renderer = Renderer().to(device)
+ self.load_checkpoint(checkpoint)
+
+ def load_checkpoint(self, checkpoint):
+ checkpoint_data = torch.load(checkpoint)
+ checkpoint_encoder = {k.replace('smirk_encoder.', ''): v for k, v in checkpoint_data.items() if 'smirk_encoder' in k}
+ self.smirk_encoder.load_state_dict(checkpoint_encoder)
+ self.smirk_encoder.eval()
+
+ def face_crop(self, image):
+ height, width, _ = image.shape
+ faces = self.app.get(image)
+ bbox = faces[0]['bbox']
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ x1 = max(0, int(bbox[0] - w/2))
+ x2 = min(width - 1, int(bbox[2] + w/2))
+ w_new = x2 - x1
+ y_offset = (w_new - h) / 2.
+ y1 = max(0, int(bbox[1] - y_offset))
+ y2 = min(height, int(bbox[3] + y_offset))
+ x_comp = int(((x2 - x1) - (y2 - y1)) / 2) if (x2 - x1) > (y2 - y1) else 0
+ x1 += x_comp
+ x2 -= x_comp
+ image_crop = image[y1:y2, x1:x2]
+ return image_crop, x1, y1
+
+ def crop_and_resize(self, image, height, width):
+ image = np.array(image)
+ image_height, image_width, _ = image.shape
+ if image_height / image_width < height / width:
+ croped_width = int(image_height / height * width)
+ left = (image_width - croped_width) // 2
+ image = image[:, left: left+croped_width]
+ else:
+ pad = int((((width / height) * image_height) - image_width) / 2.)
+ padded_image = np.zeros((image_height, image_width + pad * 2, 3), dtype=np.uint8)
+ padded_image[:, pad:pad+image_width] = image
+ image = padded_image
+ return Image.fromarray(image).resize((width, height))
+
+ def rodrigues_to_matrix(self, pose_params):
+ theta = torch.norm(pose_params, dim=-1, keepdim=True)
+ r = pose_params / (theta + 1e-8)
+ cos_theta = torch.cos(theta)
+ sin_theta = torch.sin(theta)
+ r_x = torch.zeros((pose_params.shape[0], 3, 3), device=pose_params.device)
+ r_x[:, 0, 1] = -r[:, 2]
+ r_x[:, 0, 2] = r[:, 1]
+ r_x[:, 1, 0] = r[:, 2]
+ r_x[:, 1, 2] = -r[:, 0]
+ r_x[:, 2, 0] = -r[:, 1]
+ r_x[:, 2, 1] = r[:, 0]
+ R = cos_theta * torch.eye(3, device=pose_params.device).unsqueeze(0) + \
+ sin_theta * r_x + \
+ (1 - cos_theta) * r.unsqueeze(-1) @ r.unsqueeze(-2)
+ return R
+
+ def matrix_to_rodrigues(self, R):
+ cos_theta = (torch.trace(R[0]) - 1) / 2
+ cos_theta = torch.clamp(cos_theta, -1, 1)
+ theta = torch.acos(cos_theta)
+ if abs(theta) < 1e-4:
+ return torch.zeros(1, 3, device=R.device)
+ elif abs(theta - math.pi) < 1e-4:
+ R_plus_I = R[0] + torch.eye(3, device=R.device)
+ col_norms = torch.norm(R_plus_I, dim=0)
+ max_col_idx = torch.argmax(col_norms)
+ v = R_plus_I[:, max_col_idx]
+ v = v / torch.norm(v)
+ return (v * math.pi).unsqueeze(0)
+ sin_theta = torch.sin(theta)
+ r = torch.zeros(1, 3, device=R.device)
+ r[0, 0] = R[0, 2, 1] - R[0, 1, 2]
+ r[0, 1] = R[0, 0, 2] - R[0, 2, 0]
+ r[0, 2] = R[0, 1, 0] - R[0, 0, 1]
+ r = r / (2 * sin_theta)
+ return r * theta
+
+ def crop_face(self, frame, landmarks, scale=1.0, image_size=224):
+ left = np.min(landmarks[:, 0])
+ right = np.max(landmarks[:, 0])
+ top = np.min(landmarks[:, 1])
+ bottom = np.max(landmarks[:, 1])
+ h, w, _ = frame.shape
+ old_size = (right - left + bottom - top) / 2
+ center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
+ size = int(old_size * scale)
+ src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2],
+ [center[0] + size / 2, center[1] - size / 2]])
+ DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
+ tform = estimate_transform('similarity', src_pts, DST_PTS)
+ tform_original = estimate_transform('similarity', src_pts, src_pts)
+ return tform, tform_original
+
+ def compute_landmark_relation(self, kpt_mediapipe, target_idx=473, ref_indices=[362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398]):
+ target_point = kpt_mediapipe[target_idx]
+ ref_points = kpt_mediapipe[ref_indices]
+ left_corner = ref_points[0]
+ right_corner = ref_points[8]
+ eye_center = (left_corner + right_corner) / 2
+ eye_width_vector = right_corner - left_corner
+ eye_width = np.linalg.norm(eye_width_vector)
+ eye_direction = eye_width_vector / eye_width
+ eye_vertical = np.array([-eye_direction[1], eye_direction[0]])
+ target_vector = target_point - eye_center
+ x_relative = np.dot(target_vector, eye_direction) / (eye_width/2)
+ y_relative = np.dot(target_vector, eye_vertical) / (eye_width/2)
+ return [np.array([x_relative, y_relative]),target_point,ref_points,ref_indices]
+
+ def process_source_image(self, image_rgb, input_size=224):
+ image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
+ mediapipe_utils = MediaPipeUtils()
+ kpt_mediapipe, _, _, mediapipe_eye_pose = mediapipe_utils.run_mediapipe(image_bgr)
+ if kpt_mediapipe is None:
+ raise ValueError('Cannot find facial landmarks in the source image')
+ kpt_mediapipe = kpt_mediapipe[..., :2]
+ tform, _ = self.crop_face(image_rgb, kpt_mediapipe, scale=1.4, image_size=input_size)
+ cropped_image = warp(image_rgb, tform.inverse, output_shape=(input_size, input_size), preserve_range=True).astype(np.uint8)
+ cropped_image = torch.tensor(cropped_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+ cropped_image = cropped_image.to(self.device)
+ with torch.no_grad():
+ source_outputs = self.smirk_encoder(cropped_image)
+ source_outputs['eye_pose_params'] = torch.tensor(mediapipe_eye_pose).to(self.device)
+ return source_outputs, tform, image_rgb
+
+ def smooth_params(self, data, alpha=0.7):
+ smoothed_data = [data[0]]
+ for i in range(1, len(data)):
+ smoothed_value = alpha * data[i] + (1 - alpha) * smoothed_data[i-1]
+ smoothed_data.append(smoothed_value)
+ return smoothed_data
+
+ def process_driving_img_list(self, img_list, input_size=224):
+ driving_frames = []
+ driving_outputs = []
+ driving_tforms = []
+ weights_473 = []
+ weights_468 = []
+ mediapipe_utils = MediaPipeUtils()
+ for i, frame in enumerate(img_list):
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ try:
+ kpt_mediapipe, mediapipe_exp, mediapipe_pose, mediapipe_eye_pose = mediapipe_utils.run_mediapipe(frame)
+ except:
+ print('Warning: No face detected in a frame, skipping this frame')
+ continue
+ if kpt_mediapipe is None:
+ print('Warning: No face detected in a frame, skipping this frame')
+ continue
+ kpt_mediapipe = kpt_mediapipe[..., :2]
+ weights_473.append(self.compute_landmark_relation(kpt_mediapipe))
+ weights_468.append(self.compute_landmark_relation(kpt_mediapipe, target_idx=468, ref_indices=[33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246]))
+
+
+ tform, _ = self.crop_face(frame, kpt_mediapipe, scale=1.4, image_size=input_size)
+ cropped_frame = warp(frame, tform.inverse, output_shape=(input_size, input_size), preserve_range=True).astype(np.uint8)
+ cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
+ cropped_frame = torch.tensor(cropped_frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+ cropped_frame = cropped_frame.to(self.device)
+ with torch.no_grad():
+ outputs = self.smirk_encoder(cropped_frame)
+ outputs['eye_pose_params'] = torch.tensor(mediapipe_eye_pose).to(self.device)
+ outputs['mediapipe_exp'] = torch.tensor(mediapipe_exp).to(self.device)
+ outputs['mediapipe_pose'] = torch.tensor(mediapipe_pose).to(self.device)
+ driving_frames.append(frame)
+ driving_outputs.append(outputs)
+ driving_tforms.append(tform)
+ return driving_frames, driving_outputs, driving_tforms, weights_473, weights_468
+
+ def preprocess_lmk3d(self, source_image=None, driving_image_list=None):
+ source_outputs, source_tform, image_original = self.process_source_image(source_image)
+ _, driving_outputs, driving_video_tform, weights_473, weights_468 = self.process_driving_img_list(driving_image_list)
+ driving_outputs_list = []
+ source_pose_init = source_outputs['pose_params'].clone()
+ driving_outputs_pose = [outputs['pose_params'] for outputs in driving_outputs]
+ driving_outputs_pose = self.smooth_params(driving_outputs_pose)
+ for i, outputs in enumerate(driving_outputs):
+ outputs['pose_params'] = driving_outputs_pose[i]
+ source_outputs['expression_params'] = outputs['expression_params']
+ source_outputs['jaw_params'] = outputs['jaw_params']
+ source_outputs['eye_pose_params'] = outputs['eye_pose_params']
+ source_matrix = self.rodrigues_to_matrix(source_pose_init)
+ driving_matrix_0 = self.rodrigues_to_matrix(driving_outputs[0]['pose_params'])
+ driving_matrix_i = self.rodrigues_to_matrix(driving_outputs[i]['pose_params'])
+ relative_rotation = torch.inverse(driving_matrix_0) @ driving_matrix_i
+ new_rotation = source_matrix @ relative_rotation
+ source_outputs['pose_params'] = self.matrix_to_rodrigues(new_rotation)
+ source_outputs['eyelid_params'] = outputs['eyelid_params']
+ flame_output = self.flame.forward(source_outputs)
+ renderer_output = self.renderer.forward(
+ flame_output['vertices'],
+ source_outputs['cam'],
+ landmarks_fan=flame_output['landmarks_fan'], source_tform=source_tform,
+ tform_512=None, weights_468=weights_468[i], weights_473=weights_473[i],
+ landmarks_mp=flame_output['landmarks_mp'], shape=image_original.shape)
+ rendered_img = renderer_output['rendered_img']
+ driving_outputs_list.extend(np.copy(rendered_img)[np.newaxis, :])
+ return driving_outputs_list
+
+ def ensure_even_dimensions(self, frame):
+ height, width = frame.shape[:2]
+ new_width = width - (width % 2)
+ new_height = height - (height % 2)
+ if new_width != width or new_height != height:
+ frame = cv2.resize(frame, (new_width, new_height))
+ return frame
+
+ def get_global_bbox(self, frames):
+ max_x1, max_y1, max_x2, max_y2 = float('inf'), float('inf'), 0, 0
+ for frame in frames:
+ faces = self.app.get(frame)
+ if not faces:
+ continue
+ bbox = faces[0]['bbox']
+ x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
+ max_x1 = min(max_x1, x1)
+ max_y1 = min(max_y1, y1)
+ max_x2 = max(max_x2, x2)
+ max_y2 = max(max_y2, y2)
+ w = max_x2 - max_x1
+ h = max_y2 - max_y1
+ x1 = int(max_x1 - w / 2)
+ x2 = int(max_x2 + w / 2)
+ y_offset = (x2 - x1 - h) / 2
+ y1 = int(max_y1 - y_offset)
+ y2 = int(max_y2 + y_offset)
+ if (x2 - x1) > (y2 - y1):
+ x_comp = int(((x2 - x1) - (y2 - y1)) / 2)
+ x1 += x_comp
+ x2 -= x_comp
+ else:
+ y_comp = int(((y2 - y1) - (x2 - x1)) / 2)
+ y1 += y_comp
+ y2 -= y_comp
+ x1 = max(0, x1)
+ y1 = max(0, y1)
+ x2 = min(frames[0].shape[1], x2)
+ y2 = min(frames[0].shape[0], y2)
+ return int(x1), int(y1), int(x2), int(y2)
+
+ def face_crop_with_global_box(self, image, global_box):
+ x1, y1, x2, y2 = global_box
+ return image[y1:y2, x1:x2]
+
+ def process_video(self, source_image_path, driving_video_path, output_path, sample_size=[480, 720]):
+ image = load_image(source_image_path)
+ image = self.crop_and_resize(image, sample_size[0], sample_size[1])
+ ref_image = np.array(image)
+ ref_image, x1, y1 = self.face_crop(ref_image)
+ face_h, face_w, _ = ref_image.shape
+
+ vr = VideoReader(driving_video_path)
+ fps = vr.get_avg_fps()
+ video_length = len(vr)
+ duration = video_length / fps
+ target_times = np.arange(0, duration, 1/12)
+ frame_indices = (target_times * fps).astype(np.int32)
+ frame_indices = frame_indices[frame_indices < video_length]
+ control_frames = vr.get_batch(frame_indices).asnumpy()[:48]
+ if len(control_frames) < 49:
+ video_lenght_add = 49 - len(control_frames)
+ control_frames = np.concatenate(([control_frames[0]]*2, control_frames[1:len(control_frames)-2], [control_frames[-1]] * video_lenght_add), axis=0)
+
+ control_frames_crop = []
+ global_box = self.get_global_bbox(control_frames)
+ for control_frame in control_frames:
+ frame = self.face_crop_with_global_box(control_frame, global_box)
+ control_frames_crop.append(frame)
+
+ out_frames = self.preprocess_lmk3d(source_image=ref_image, driving_image_list=control_frames_crop)
+
+
+ def write_mp4(video_path, samples, fps=14):
+ clip = ImageSequenceClip(samples, fps=fps)
+ clip.write_videofile(video_path, codec='libx264', ffmpeg_params=["-pix_fmt", "yuv420p", "-crf", "23", "-preset", "medium"])
+
+ concat_frames = []
+ for i in range(len(out_frames)):
+ ref_image_concat = ref_image.copy()
+ driving_frame = cv2.resize(control_frames_crop[i], (face_w, face_h))
+ out_frame = cv2.resize(out_frames[i], (face_w, face_h))
+ concat_frame = np.concatenate([ref_image_concat, driving_frame, out_frame], axis=1)
+ concat_frame = self.ensure_even_dimensions(concat_frame)
+ concat_frames.append(concat_frame)
+ write_mp4(output_path, concat_frames, fps=12)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Process video and image for face animation.")
+ parser.add_argument('--source_image', type=str, default="assets/ref_images/1.png", help='Path to the source image.')
+ parser.add_argument('--driving_video', type=str, default="assets/driving_video/1.mp4", help='Path to the driving video.')
+ parser.add_argument('--output_path', type=str, default="./output.mp4", help='Path to save the output video.')
+ args = parser.parse_args()
+
+ processor = FaceAnimationProcessor(checkpoint='./pretrained_models/smirk/SMIRK_em1.pt')
+ processor.process_video(args.source_image, args.driving_video, args.output_path)
\ No newline at end of file
diff --git a/skyreels_a1/src/media_pipe/draw_util.py b/skyreels_a1/src/media_pipe/draw_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c0d53efe776c2d1aadb2427892c9dcd6ac2ee9e
--- /dev/null
+++ b/skyreels_a1/src/media_pipe/draw_util.py
@@ -0,0 +1,480 @@
+import cv2
+import mediapipe as mp
+import numpy as np
+from mediapipe.framework.formats import landmark_pb2
+
+class FaceMeshVisualizer:
+ def __init__(self, forehead_edge=False, iris_edge=False, iris_point=False):
+ self.mp_drawing = mp.solutions.drawing_utils
+ mp_face_mesh = mp.solutions.face_mesh
+ self.mp_face_mesh = mp_face_mesh
+ self.forehead_edge = forehead_edge
+
+ DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
+ f_thick = 1
+ f_rad = 1
+ right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
+ right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
+ right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
+ left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
+ # head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ head_draw = DrawingSpec(color=(0, 0, 0), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad)
+
+ self.pupil_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
+
+ # FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)]
+ # FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)]
+
+ # FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)]
+ # FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)]
+
+ # FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)]
+ # FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)]
+
+ # FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)]
+ # FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)]
+
+ # FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)]
+
+ index_mapping = [276, 282, 283, 285, 293, 295, 296, 300, 334, 336, 46, 52, 53,
+ 55, 63, 65, 66, 70, 105, 107, 249, 263, 362, 373, 374, 380,
+ 381, 382, 384, 385, 386, 387, 388, 390, 398, 466, 7, 33, 133,
+ 144, 145, 153, 154, 155, 157, 158, 159, 160, 161, 163, 173, 246,
+ 168, 6, 197, 195, 5, 4, 129, 98, 97, 2, 326, 327, 358,
+ 0, 13, 14, 17, 37, 39, 40, 61, 78, 80, 81, 82, 84,
+ 87, 88, 91, 95, 146, 178, 181, 185, 191, 267, 269, 270, 291,
+ 308, 310, 311, 312, 314, 317, 318, 321, 324, 375, 402, 405, 409,
+ 415]#, 469, 470, 471, 472, 474, 475, 476, 477]
+
+ self.index_mapping = index_mapping
+
+ def safe_index(mapping, value):
+ try:
+ return mapping.index(value)
+ except ValueError:
+ return None
+
+ # 使用新的landmark索引映射
+ FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 61), safe_index(index_mapping, 146)),
+ (safe_index(index_mapping, 146), safe_index(index_mapping, 91)),
+ (safe_index(index_mapping, 91), safe_index(index_mapping, 181)),
+ (safe_index(index_mapping, 181), safe_index(index_mapping, 84)),
+ (safe_index(index_mapping, 84), safe_index(index_mapping, 17))
+ ] if a is not None and b is not None
+ ]
+
+ FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 17), safe_index(index_mapping, 314)),
+ (safe_index(index_mapping, 314), safe_index(index_mapping, 405)),
+ (safe_index(index_mapping, 405), safe_index(index_mapping, 321)),
+ (safe_index(index_mapping, 321), safe_index(index_mapping, 375)),
+ (safe_index(index_mapping, 375), safe_index(index_mapping, 291))
+ ] if a is not None and b is not None
+ ]
+
+ FACEMESH_LIPS_INNER_BOTTOM_LEFT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 78), safe_index(index_mapping, 95)),
+ (safe_index(index_mapping, 95), safe_index(index_mapping, 88)),
+ (safe_index(index_mapping, 88), safe_index(index_mapping, 178)),
+ (safe_index(index_mapping, 178), safe_index(index_mapping, 87)),
+ (safe_index(index_mapping, 87), safe_index(index_mapping, 14))
+ ] if a is not None and b is not None
+ ]
+
+ FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 14), safe_index(index_mapping, 317)),
+ (safe_index(index_mapping, 317), safe_index(index_mapping, 402)),
+ (safe_index(index_mapping, 402), safe_index(index_mapping, 318)),
+ (safe_index(index_mapping, 318), safe_index(index_mapping, 324)),
+ (safe_index(index_mapping, 324), safe_index(index_mapping, 308))
+ ] if a is not None and b is not None
+ ]
+
+ FACEMESH_LIPS_OUTER_TOP_LEFT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 61), safe_index(index_mapping, 185)),
+ (safe_index(index_mapping, 185), safe_index(index_mapping, 40)),
+ (safe_index(index_mapping, 40), safe_index(index_mapping, 39)),
+ (safe_index(index_mapping, 39), safe_index(index_mapping, 37)),
+ (safe_index(index_mapping, 37), safe_index(index_mapping, 0))
+ ] if a is not None and b is not None
+ ]
+
+ FACEMESH_LIPS_OUTER_TOP_RIGHT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 0), safe_index(index_mapping, 267)),
+ (safe_index(index_mapping, 267), safe_index(index_mapping, 269)),
+ (safe_index(index_mapping, 269), safe_index(index_mapping, 270)),
+ (safe_index(index_mapping, 270), safe_index(index_mapping, 409)),
+ (safe_index(index_mapping, 409), safe_index(index_mapping, 291))
+ ] if a is not None and b is not None
+ ]
+
+ FACEMESH_LIPS_INNER_TOP_LEFT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 78), safe_index(index_mapping, 191)),
+ (safe_index(index_mapping, 191), safe_index(index_mapping, 80)),
+ (safe_index(index_mapping, 80), safe_index(index_mapping, 81)),
+ (safe_index(index_mapping, 81), safe_index(index_mapping, 82)),
+ (safe_index(index_mapping, 82), safe_index(index_mapping, 13))
+ ] if a is not None and b is not None
+ ]
+
+ FACEMESH_LIPS_INNER_TOP_RIGHT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 13), safe_index(index_mapping, 312)),
+ (safe_index(index_mapping, 312), safe_index(index_mapping, 311)),
+ (safe_index(index_mapping, 311), safe_index(index_mapping, 310)),
+ (safe_index(index_mapping, 310), safe_index(index_mapping, 415)),
+ (safe_index(index_mapping, 415), safe_index(index_mapping, 308))
+ ] if a is not None and b is not None
+ ]
+
+
+ FACEMESH_EYE_RIGHT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 144), safe_index(index_mapping, 145)),
+ (safe_index(index_mapping, 145), safe_index(index_mapping, 153)),
+ (safe_index(index_mapping, 153), safe_index(index_mapping, 154)),
+ (safe_index(index_mapping, 154), safe_index(index_mapping, 155)),
+ (safe_index(index_mapping, 155), safe_index(index_mapping, 157)),
+ (safe_index(index_mapping, 157), safe_index(index_mapping, 158)),
+ (safe_index(index_mapping, 158), safe_index(index_mapping, 159)),
+ (safe_index(index_mapping, 159), safe_index(index_mapping, 160)),
+ (safe_index(index_mapping, 160), safe_index(index_mapping, 161)),
+ (safe_index(index_mapping, 161), safe_index(index_mapping, 33)),
+ (safe_index(index_mapping, 33), safe_index(index_mapping, 7)),
+ (safe_index(index_mapping, 7), safe_index(index_mapping, 163)),
+
+ (safe_index(index_mapping, 46), safe_index(index_mapping, 53)),
+ (safe_index(index_mapping, 53), safe_index(index_mapping, 52)),
+ (safe_index(index_mapping, 52), safe_index(index_mapping, 65)),
+ (safe_index(index_mapping, 65), safe_index(index_mapping, 55)),
+
+ (safe_index(index_mapping, 107), safe_index(index_mapping, 66)),
+ (safe_index(index_mapping, 66), safe_index(index_mapping, 105)),
+ (safe_index(index_mapping, 105), safe_index(index_mapping, 63)),
+ (safe_index(index_mapping, 63), safe_index(index_mapping, 70)),
+
+ ] if a is not None and b is not None
+ ]
+ # index_mapping = [276, 282, 283, 285, 293, 295, 296, 300, 334, 336, 46, 52, 53,
+ # 55, 63, 65, 66, 70, 105, 107, 249, 263, 362, 373, 374, 380,
+ # 381, 382, 384, 385, 386, 387, 388, 390, 398, 466, 7, 33, 133,
+ # 144, 145, 153, 154, 155, 157, 158, 159, 160, 161, 163, 173, 246,
+ # 168, 6, 197, 195, 5, 4, 129, 98, 97, 2, 326, 327, 358,
+ # 0, 13, 14, 17, 37, 39, 40, 61, 78, 80, 81, 82, 84,
+ # 87, 88, 91, 95, 146, 178, 181, 185, 191, 267, 269, 270, 291,
+ # 308, 310, 311, 312, 314, 317, 318, 321, 324, 375, 402, 405, 409,
+ # 415]
+
+ FACEMESH_EYE_LEFT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 362), safe_index(index_mapping, 382)),
+ (safe_index(index_mapping, 382), safe_index(index_mapping, 381)),
+ (safe_index(index_mapping, 381), safe_index(index_mapping, 380)),
+ (safe_index(index_mapping, 380), safe_index(index_mapping, 374)),
+ (safe_index(index_mapping, 374), safe_index(index_mapping, 373)),
+ (safe_index(index_mapping, 373), safe_index(index_mapping, 390)),
+ (safe_index(index_mapping, 390), safe_index(index_mapping, 249)),
+ (safe_index(index_mapping, 249), safe_index(index_mapping, 263)),
+ (safe_index(index_mapping, 263), safe_index(index_mapping, 466)),
+ (safe_index(index_mapping, 466), safe_index(index_mapping, 388)),
+ (safe_index(index_mapping, 388), safe_index(index_mapping, 387)),
+ (safe_index(index_mapping, 387), safe_index(index_mapping, 386)),
+ (safe_index(index_mapping, 386), safe_index(index_mapping, 385)),
+ (safe_index(index_mapping, 385), safe_index(index_mapping, 384)),
+ (safe_index(index_mapping, 384), safe_index(index_mapping, 398)),
+ (safe_index(index_mapping, 398), safe_index(index_mapping, 362)),
+
+ # (safe_index(index_mapping, 285), safe_index(index_mapping, 295)),
+ # (safe_index(index_mapping, 295), safe_index(index_mapping, 282)),
+ # (safe_index(index_mapping, 282), safe_index(index_mapping, 283)),
+ # (safe_index(index_mapping, 283), safe_index(index_mapping, 276)),
+
+ # (safe_index(index_mapping, 336), safe_index(index_mapping, 296)),
+ # (safe_index(index_mapping, 296), safe_index(index_mapping, 334)),
+ # (safe_index(index_mapping, 334), safe_index(index_mapping, 293)),
+ # (safe_index(index_mapping, 293), safe_index(index_mapping, 300)),
+
+ ] if a is not None and b is not None
+ ]
+
+
+ FACEMESH_EYE_LEFT_new = [(0,267),(267,269),(269,270),(270,409),(409,291)]
+
+
+
+ FACEMESH_EYEBROW_LEFT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 285), safe_index(index_mapping, 295)),
+ (safe_index(index_mapping, 295), safe_index(index_mapping, 282)),
+ (safe_index(index_mapping, 282), safe_index(index_mapping, 283)),
+ (safe_index(index_mapping, 283), safe_index(index_mapping, 276)),
+
+ (safe_index(index_mapping, 336), safe_index(index_mapping, 296)),
+ (safe_index(index_mapping, 296), safe_index(index_mapping, 334)),
+ (safe_index(index_mapping, 334), safe_index(index_mapping, 293)),
+ (safe_index(index_mapping, 293), safe_index(index_mapping, 300)),
+
+ ] if a is not None and b is not None
+ ]
+
+
+ FACEMESH_EYE_RIGHT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 144), safe_index(index_mapping, 145)),
+ (safe_index(index_mapping, 145), safe_index(index_mapping, 153)),
+ (safe_index(index_mapping, 153), safe_index(index_mapping, 154)),
+ (safe_index(index_mapping, 154), safe_index(index_mapping, 155)),
+ (safe_index(index_mapping, 155), safe_index(index_mapping, 133)),
+ (safe_index(index_mapping, 133), safe_index(index_mapping, 173)),
+ (safe_index(index_mapping, 173), safe_index(index_mapping, 157)),
+ (safe_index(index_mapping, 157), safe_index(index_mapping, 158)),
+ (safe_index(index_mapping, 158), safe_index(index_mapping, 159)),
+ (safe_index(index_mapping, 159), safe_index(index_mapping, 160)),
+ (safe_index(index_mapping, 160), safe_index(index_mapping, 161)),
+ (safe_index(index_mapping, 161), safe_index(index_mapping, 246)),
+ (safe_index(index_mapping, 246), safe_index(index_mapping, 33)),
+ (safe_index(index_mapping, 33), safe_index(index_mapping, 7)),
+ (safe_index(index_mapping, 7), safe_index(index_mapping, 163)),
+ (safe_index(index_mapping, 163), safe_index(index_mapping, 144)),
+
+
+
+ ] if a is not None and b is not None
+ ]
+
+ FACEMESH_EYEBROW_RIGHT = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 46), safe_index(index_mapping, 53)),
+ (safe_index(index_mapping, 53), safe_index(index_mapping, 52)),
+ (safe_index(index_mapping, 52), safe_index(index_mapping, 65)),
+ (safe_index(index_mapping, 65), safe_index(index_mapping, 55)),
+
+ (safe_index(index_mapping, 70), safe_index(index_mapping, 63)),
+ (safe_index(index_mapping, 63), safe_index(index_mapping, 105)),
+ (safe_index(index_mapping, 105), safe_index(index_mapping, 66)),
+ (safe_index(index_mapping, 66), safe_index(index_mapping, 107)),
+
+ ] if a is not None and b is not None
+ ]
+
+
+ FACE_LANDMARKS_LEFT_IRIS = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 469), safe_index(index_mapping, 470)),
+ (safe_index(index_mapping, 470), safe_index(index_mapping, 471)),
+ (safe_index(index_mapping, 471), safe_index(index_mapping, 472)),
+ (safe_index(index_mapping, 472), safe_index(index_mapping, 469)),
+ ] if a is not None and b is not None
+ ]
+
+ FACE_LANDMARKS_RIGHT_IRIS = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 474), safe_index(index_mapping, 475)),
+ (safe_index(index_mapping, 475), safe_index(index_mapping, 476)),
+ (safe_index(index_mapping, 476), safe_index(index_mapping, 477)),
+ (safe_index(index_mapping, 477), safe_index(index_mapping, 474)),
+
+ ] if a is not None and b is not None
+ ]
+
+
+ FACEMESH_CUSTOM_FACE_OVAL = [
+ (a, b) for a, b in [
+ (safe_index(index_mapping, 144), safe_index(index_mapping, 145)),
+ (safe_index(index_mapping, 145), safe_index(index_mapping, 153)),
+ (safe_index(index_mapping, 153), safe_index(index_mapping, 154)),
+ (safe_index(index_mapping, 154), safe_index(index_mapping, 155)),
+ (safe_index(index_mapping, 155), safe_index(index_mapping, 157)),
+ (safe_index(index_mapping, 157), safe_index(index_mapping, 158)),
+ (safe_index(index_mapping, 158), safe_index(index_mapping, 159)),
+ (safe_index(index_mapping, 159), safe_index(index_mapping, 160)),
+ (safe_index(index_mapping, 160), safe_index(index_mapping, 161)),
+ (safe_index(index_mapping, 161), safe_index(index_mapping, 33)),
+ (safe_index(index_mapping, 33), safe_index(index_mapping, 7)),
+ (safe_index(index_mapping, 7), safe_index(index_mapping, 163)),
+
+ (safe_index(index_mapping, 163), safe_index(index_mapping, 144)),
+ (safe_index(index_mapping, 172), safe_index(index_mapping, 58)),
+ (safe_index(index_mapping, 454), safe_index(index_mapping, 323)),
+ (safe_index(index_mapping, 365), safe_index(index_mapping, 379)),
+ (safe_index(index_mapping, 379), safe_index(index_mapping, 378)),
+ (safe_index(index_mapping, 148), safe_index(index_mapping, 176)),
+ (safe_index(index_mapping, 93), safe_index(index_mapping, 234)),
+ (safe_index(index_mapping, 397), safe_index(index_mapping, 365)),
+ (safe_index(index_mapping, 149), safe_index(index_mapping, 150)),
+ (safe_index(index_mapping, 288), safe_index(index_mapping, 397)),
+ (safe_index(index_mapping, 234), safe_index(index_mapping, 127)),
+ (safe_index(index_mapping, 378), safe_index(index_mapping, 400)),
+ (safe_index(index_mapping, 127), safe_index(index_mapping, 162)),
+ (safe_index(index_mapping, 162), safe_index(index_mapping, 21))
+ ] if a is not None and b is not None
+ ]
+
+ # import pdb;pdb.set_trace()
+
+
+
+ # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
+ face_connection_spec = {}
+ # if self.forehead_edge:
+ # for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
+ # face_connection_spec[edge] = head_draw
+ # else:
+ for edge in FACEMESH_CUSTOM_FACE_OVAL:
+ face_connection_spec[edge] = head_draw
+ # for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
+ # face_connection_spec[edge] = left_eye_draw
+ # for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
+ # face_connection_spec[edge] = left_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
+ # face_connection_spec[edge] = right_eye_draw
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+ # face_connection_spec[edge] = right_eyebrow_draw
+ # if iris_edge:
+ # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
+ # face_connection_spec[edge] = left_iris_draw
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+ # face_connection_spec[edge] = right_iris_draw
+ # for edge in mp_face_mesh.FACEMESH_LIPS:
+ # face_connection_spec[edge] = mouth_draw
+
+
+ for edge in FACEMESH_EYE_LEFT:
+ face_connection_spec[edge] = left_eye_draw
+
+ for edge in FACEMESH_EYEBROW_LEFT:
+ face_connection_spec[edge] = left_eyebrow_draw
+
+ for edge in FACEMESH_EYE_RIGHT:
+ face_connection_spec[edge] = right_eye_draw
+
+ for edge in FACEMESH_EYEBROW_RIGHT:
+ face_connection_spec[edge] = right_eyebrow_draw
+
+ for edge in FACE_LANDMARKS_LEFT_IRIS:
+ face_connection_spec[edge] = left_iris_draw
+
+ for edge in FACE_LANDMARKS_RIGHT_IRIS:
+ face_connection_spec[edge] = right_iris_draw
+
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT:
+ face_connection_spec[edge] = mouth_draw_obl
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT:
+ face_connection_spec[edge] = mouth_draw_obr
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT:
+ face_connection_spec[edge] = mouth_draw_ibl
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT:
+ face_connection_spec[edge] = mouth_draw_ibr
+ for edge in FACEMESH_LIPS_OUTER_TOP_LEFT:
+ face_connection_spec[edge] = mouth_draw_otl
+ for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT:
+ face_connection_spec[edge] = mouth_draw_otr
+ for edge in FACEMESH_LIPS_INNER_TOP_LEFT:
+ face_connection_spec[edge] = mouth_draw_itl
+ for edge in FACEMESH_LIPS_INNER_TOP_RIGHT:
+ face_connection_spec[edge] = mouth_draw_itr
+
+ self.iris_point = iris_point
+
+ self.face_connection_spec = face_connection_spec
+
+ def draw_pupils(self, image, landmark_list, drawing_spec, halfwidth: int = 2):
+ """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
+ landmarks. Until our PR is merged into mediapipe, we need this separate method."""
+ if len(image.shape) != 3:
+ raise ValueError("Input image must be H,W,C.")
+ image_rows, image_cols, image_channels = image.shape
+ if image_channels != 3: # BGR channels
+ raise ValueError('Input image must contain three channel bgr data.')
+ for idx, landmark in enumerate(landmark_list.landmark):
+ if (
+ (landmark.HasField('visibility') and landmark.visibility < 0.9) or
+ (landmark.HasField('presence') and landmark.presence < 0.5)
+ ):
+ continue
+ if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
+ continue
+ image_x = int(image_cols*landmark.x)
+ image_y = int(image_rows*landmark.y)
+ draw_color = None
+ if isinstance(drawing_spec, Mapping):
+ if drawing_spec.get(idx) is None:
+ continue
+ else:
+ draw_color = drawing_spec[idx].color
+ elif isinstance(drawing_spec, DrawingSpec):
+ draw_color = drawing_spec.color
+ image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color
+
+ def draw_iris_points(self, image, point_list, halfwidth=2, normed=False):
+ color = (255, 0, 0)
+ for idx, point in enumerate(point_list):
+ if normed:
+ x, y = int(point[0] * image.shape[1]), int(point[1] * image.shape[0])
+ else:
+ x, y = int(point[0]), int(point[1])
+ image[y-halfwidth:y+halfwidth, x-halfwidth:x+halfwidth, :] = color
+ return image
+
+ def draw_landmarks(self, image_size, keypoints, normed=False):
+ ini_size = image_size #[512, 512]
+ image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
+ new_landmarks = landmark_pb2.NormalizedLandmarkList()
+ for i in range(keypoints.shape[0]):
+ landmark = new_landmarks.landmark.add()
+ if normed:
+ landmark.x = keypoints[i, 0]
+ landmark.y = keypoints[i, 1]
+ else:
+ landmark.x = keypoints[i, 0] / image_size[0]
+ landmark.y = keypoints[i, 1] / image_size[1]
+ landmark.z = 1.0
+
+ self.mp_drawing.draw_landmarks(
+ image=image,
+ landmark_list=new_landmarks,
+ connections=self.face_connection_spec.keys(),
+ landmark_drawing_spec=None,
+ connection_drawing_spec=self.face_connection_spec
+ )
+
+ if self.iris_point:
+ image = self.draw_iris_points(image, [keypoints[468], keypoints[473]], halfwidth=3, normed=normed)
+
+ return image
+
+ def draw_mask(self, image_size, keypoints, normed=False):
+ mask = np.zeros([image_size[1], image_size[0], 3], dtype=np.uint8)
+ if normed:
+ keypoints[:, 0] *= image_size[0]
+ keypoints[:, 1] *= image_size[1]
+
+ head_idxs = [21, 162, 127, 234, 93, 132, 58, 172, 136, 150, 149, 176, 148, 152, 377, 400, 378, 379, 365, 397, 288, 361, 323, 454, 356, 389]
+ head_points = np.array(keypoints[head_idxs, :2], np.int32)
+
+ mask = cv2.fillPoly(mask, [head_points], (255, 255, 255))
+ mask = np.array(mask) / 255.0
+
+ return mask
\ No newline at end of file
diff --git a/skyreels_a1/src/media_pipe/draw_util_2d.py b/skyreels_a1/src/media_pipe/draw_util_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..32fe9f05f6da012c2474007f1ac8cf418a27e988
--- /dev/null
+++ b/skyreels_a1/src/media_pipe/draw_util_2d.py
@@ -0,0 +1,289 @@
+import cv2
+import mediapipe as mp
+import numpy as np
+from mediapipe.framework.formats import landmark_pb2
+
+class FaceMeshVisualizer2d:
+ # def __init__(self,
+ # forehead_edge=False,
+ # upface_only=False,
+ # draw_eye=True,
+ # draw_head=False,
+ # draw_iris=True,
+ # draw_eyebrow=True,
+ # draw_mouse=True,
+ # draw_nose=True,
+ # draw_pupil=True
+ # ):
+ def __init__(self,
+ forehead_edge=True,
+ upface_only=True,
+ draw_eye=True,
+ draw_head=True,
+ draw_iris=True,
+ draw_eyebrow=True,
+ draw_mouse=True,
+ draw_nose=True,
+ draw_pupil=True
+ ):
+ self.mp_drawing = mp.solutions.drawing_utils
+ mp_face_mesh = mp.solutions.face_mesh
+ self.mp_face_mesh = mp_face_mesh
+ self.forehead_edge = forehead_edge
+
+ DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
+ f_thick = 1
+ f_rad = 1
+ right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
+ right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
+ right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
+ left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
+ head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ nose_draw = DrawingSpec(color=(200, 200, 200), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad)
+
+ FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)]
+ FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)]
+
+ FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)]
+ FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)]
+
+ FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)]
+ FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)]
+
+ FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)]
+ FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)]
+
+ FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)]
+
+ # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
+ face_connection_spec = {}
+
+ #from IPython import embed
+ #embed()
+ if self.forehead_edge:
+ for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
+ face_connection_spec[edge] = head_draw
+ else:
+ if draw_head:
+ FACEMESH_CUSTOM_FACE_OVAL_sorted = sorted(FACEMESH_CUSTOM_FACE_OVAL)
+ if upface_only:
+ for edge in [FACEMESH_CUSTOM_FACE_OVAL_sorted[edge_idx] for edge_idx in [1,2,9,12,13,16,22,25]]:
+ face_connection_spec[edge] = head_draw
+ else:
+ for edge in FACEMESH_CUSTOM_FACE_OVAL_sorted:
+ face_connection_spec[edge] = head_draw
+
+ if draw_eye:
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
+ face_connection_spec[edge] = left_eye_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
+ face_connection_spec[edge] = right_eye_draw
+
+ if draw_eyebrow:
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
+ face_connection_spec[edge] = left_eyebrow_draw
+
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+ face_connection_spec[edge] = right_eyebrow_draw
+
+ if draw_iris:
+ for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
+ face_connection_spec[edge] = left_iris_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+ face_connection_spec[edge] = right_iris_draw
+
+ #for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+ # face_connection_spec[edge] = right_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+ # face_connection_spec[edge] = right_iris_draw
+
+ # for edge in mp_face_mesh.FACEMESH_LIPS:
+ # face_connection_spec[edge] = mouth_draw
+
+ if draw_mouse:
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT:
+ face_connection_spec[edge] = mouth_draw_obl
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT:
+ face_connection_spec[edge] = mouth_draw_obr
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT:
+ face_connection_spec[edge] = mouth_draw_ibl
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT:
+ face_connection_spec[edge] = mouth_draw_ibr
+ for edge in FACEMESH_LIPS_OUTER_TOP_LEFT:
+ face_connection_spec[edge] = mouth_draw_otl
+ for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT:
+ face_connection_spec[edge] = mouth_draw_otr
+ for edge in FACEMESH_LIPS_INNER_TOP_LEFT:
+ face_connection_spec[edge] = mouth_draw_itl
+ for edge in FACEMESH_LIPS_INNER_TOP_RIGHT:
+ face_connection_spec[edge] = mouth_draw_itr
+
+ self.face_connection_spec = face_connection_spec
+
+ self.pupil_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
+ self.nose_landmark_spec = {4: nose_draw}
+
+ self.draw_pupil = draw_pupil
+ self.draw_nose = draw_nose
+
+ def draw_points(self, image, landmark_list, drawing_spec, halfwidth: int = 2):
+ """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
+ landmarks. Until our PR is merged into mediapipe, we need this separate method."""
+ if len(image.shape) != 3:
+ raise ValueError("Input image must be H,W,C.")
+ image_rows, image_cols, image_channels = image.shape
+ if image_channels != 3: # BGR channels
+ raise ValueError('Input image must contain three channel bgr data.')
+ for idx, landmark in enumerate(landmark_list.landmark):
+ if idx not in drawing_spec:
+ continue
+
+ if (
+ (landmark.HasField('visibility') and landmark.visibility < 0.9) or
+ (landmark.HasField('presence') and landmark.presence < 0.5)
+ ):
+ continue
+ if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
+ continue
+
+ image_x = int(image_cols * landmark.x)
+ image_y = int(image_rows * landmark.y)
+
+ draw_color = drawing_spec[idx].color
+ image[image_y - halfwidth : image_y + halfwidth, image_x - halfwidth : image_x + halfwidth, :] = draw_color
+
+
+ def draw_landmarks(self, image_size, keypoints, normed=False):
+ ini_size = [512, 512]
+ image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
+ if keypoints is not None:
+ new_landmarks = landmark_pb2.NormalizedLandmarkList()
+ for i in range(keypoints.shape[0]):
+ landmark = new_landmarks.landmark.add()
+ if normed:
+ landmark.x = keypoints[i, 0]
+ landmark.y = keypoints[i, 1]
+ else:
+ landmark.x = keypoints[i, 0] / image_size[0]
+ landmark.y = keypoints[i, 1] / image_size[1]
+ landmark.z = 1.0
+
+ self.mp_drawing.draw_landmarks(
+ image=image,
+ landmark_list=new_landmarks,
+ connections=self.face_connection_spec.keys(),
+ landmark_drawing_spec=None,
+ connection_drawing_spec=self.face_connection_spec
+ )
+
+ # if self.draw_pupil:
+ # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2)
+
+ # if self.draw_nose:
+ # self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
+
+ if self.draw_pupil:
+ self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1)
+
+ if self.draw_nose:
+ self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
+
+ image = cv2.resize(image, (image_size[0], image_size[1]))
+
+ return image
+
+ def draw_landmarks_v2(self, image_size, keypoints, normed=False):
+ # ini_size = [512, 512]
+ ini_size = [image_size[0], image_size[1]]
+ image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
+ if keypoints is not None:
+ new_landmarks = landmark_pb2.NormalizedLandmarkList()
+ for i in range(keypoints.shape[0]):
+ landmark = new_landmarks.landmark.add()
+ if normed:
+ landmark.x = keypoints[i, 0]
+ landmark.y = keypoints[i, 1]
+ else:
+ landmark.x = keypoints[i, 0] / image_size[0]
+ landmark.y = keypoints[i, 1] / image_size[1]
+ landmark.z = 1.0
+
+ self.mp_drawing.draw_landmarks(
+ image=image,
+ landmark_list=new_landmarks,
+ connections=self.face_connection_spec.keys(),
+ landmark_drawing_spec=None,
+ connection_drawing_spec=self.face_connection_spec
+ )
+
+ # if self.draw_pupil:
+ # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2)
+
+ # if self.draw_nose:
+ # self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
+
+ if self.draw_pupil:
+ self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1)
+
+ if self.draw_nose:
+ self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
+
+ image = cv2.resize(image, (image_size[0], image_size[1]))
+
+ return image
+
+ def draw_landmarks_v3(self, image_size, resize_size, keypoints, normed=False):
+ # ini_size = [512, 512]
+ ini_size = [resize_size[0], resize_size[1]]
+ image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
+ if keypoints is not None:
+ new_landmarks = landmark_pb2.NormalizedLandmarkList()
+ for i in range(keypoints.shape[0]):
+ landmark = new_landmarks.landmark.add()
+ if normed:
+ # landmark.x = keypoints[i, 0] * resize_size[0] / image_size[0]
+ # landmark.y = keypoints[i, 1] * resize_size[1] / image_size[1]
+ landmark.x = keypoints[i, 0]
+ landmark.y = keypoints[i, 1]
+ else:
+ landmark.x = keypoints[i, 0] / image_size[0]
+ landmark.y = keypoints[i, 1] / image_size[1]
+ landmark.z = 1.0
+
+ self.mp_drawing.draw_landmarks(
+ image=image,
+ landmark_list=new_landmarks,
+ connections=self.face_connection_spec.keys(),
+ landmark_drawing_spec=None,
+ connection_drawing_spec=self.face_connection_spec
+ )
+
+ # if self.draw_pupil:
+ # self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 2)
+
+ # if self.draw_nose:
+ # self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
+
+ if self.draw_pupil:
+ self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 1)
+
+ if self.draw_nose:
+ self.draw_points(image, new_landmarks, self.nose_landmark_spec, 2)
+
+ image = cv2.resize(image, (resize_size[0], resize_size[1]))
+
+ return image
diff --git a/skyreels_a1/src/media_pipe/face_landmark.py b/skyreels_a1/src/media_pipe/face_landmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6580cb2cded9dcfeab46b0d50c8931ed6256669
--- /dev/null
+++ b/skyreels_a1/src/media_pipe/face_landmark.py
@@ -0,0 +1,3305 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe face landmarker task."""
+
+import dataclasses
+import enum
+from typing import Callable, Mapping, Optional, List
+
+import numpy as np
+
+from mediapipe.framework.formats import classification_pb2
+from mediapipe.framework.formats import landmark_pb2
+from mediapipe.framework.formats import matrix_data_pb2
+from mediapipe.python import packet_creator
+from mediapipe.python import packet_getter
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.python._framework_bindings import packet as packet_module
+# pylint: disable=unused-import
+from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
+# pylint: enable=unused-import
+from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_graph_options_pb2
+from mediapipe.tasks.python.components.containers import category as category_module
+from mediapipe.tasks.python.components.containers import landmark as landmark_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.core import task_info as task_info_module
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+from mediapipe.tasks.python.vision.core import base_vision_task_api
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
+from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
+
+_BaseOptions = base_options_module.BaseOptions
+_FaceLandmarkerGraphOptionsProto = (
+ face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions
+)
+_LayoutEnum = matrix_data_pb2.MatrixData.Layout
+_RunningMode = running_mode_module.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
+_TaskInfo = task_info_module.TaskInfo
+
+_IMAGE_IN_STREAM_NAME = 'image_in'
+_IMAGE_OUT_STREAM_NAME = 'image_out'
+_IMAGE_TAG = 'IMAGE'
+_NORM_RECT_STREAM_NAME = 'norm_rect_in'
+_NORM_RECT_TAG = 'NORM_RECT'
+_NORM_LANDMARKS_STREAM_NAME = 'norm_landmarks'
+_NORM_LANDMARKS_TAG = 'NORM_LANDMARKS'
+_BLENDSHAPES_STREAM_NAME = 'blendshapes'
+_BLENDSHAPES_TAG = 'BLENDSHAPES'
+_FACE_GEOMETRY_STREAM_NAME = 'face_geometry'
+_FACE_GEOMETRY_TAG = 'FACE_GEOMETRY'
+_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph'
+_MICRO_SECONDS_PER_MILLISECOND = 1000
+
+
+class Blendshapes(enum.IntEnum):
+ """The 52 blendshape coefficients."""
+
+ NEUTRAL = 0
+ BROW_DOWN_LEFT = 1
+ BROW_DOWN_RIGHT = 2
+ BROW_INNER_UP = 3
+ BROW_OUTER_UP_LEFT = 4
+ BROW_OUTER_UP_RIGHT = 5
+ CHEEK_PUFF = 6
+ CHEEK_SQUINT_LEFT = 7
+ CHEEK_SQUINT_RIGHT = 8
+ EYE_BLINK_LEFT = 9
+ EYE_BLINK_RIGHT = 10
+ EYE_LOOK_DOWN_LEFT = 11
+ EYE_LOOK_DOWN_RIGHT = 12
+ EYE_LOOK_IN_LEFT = 13
+ EYE_LOOK_IN_RIGHT = 14
+ EYE_LOOK_OUT_LEFT = 15
+ EYE_LOOK_OUT_RIGHT = 16
+ EYE_LOOK_UP_LEFT = 17
+ EYE_LOOK_UP_RIGHT = 18
+ EYE_SQUINT_LEFT = 19
+ EYE_SQUINT_RIGHT = 20
+ EYE_WIDE_LEFT = 21
+ EYE_WIDE_RIGHT = 22
+ JAW_FORWARD = 23
+ JAW_LEFT = 24
+ JAW_OPEN = 25
+ JAW_RIGHT = 26
+ MOUTH_CLOSE = 27
+ MOUTH_DIMPLE_LEFT = 28
+ MOUTH_DIMPLE_RIGHT = 29
+ MOUTH_FROWN_LEFT = 30
+ MOUTH_FROWN_RIGHT = 31
+ MOUTH_FUNNEL = 32
+ MOUTH_LEFT = 33
+ MOUTH_LOWER_DOWN_LEFT = 34
+ MOUTH_LOWER_DOWN_RIGHT = 35
+ MOUTH_PRESS_LEFT = 36
+ MOUTH_PRESS_RIGHT = 37
+ MOUTH_PUCKER = 38
+ MOUTH_RIGHT = 39
+ MOUTH_ROLL_LOWER = 40
+ MOUTH_ROLL_UPPER = 41
+ MOUTH_SHRUG_LOWER = 42
+ MOUTH_SHRUG_UPPER = 43
+ MOUTH_SMILE_LEFT = 44
+ MOUTH_SMILE_RIGHT = 45
+ MOUTH_STRETCH_LEFT = 46
+ MOUTH_STRETCH_RIGHT = 47
+ MOUTH_UPPER_UP_LEFT = 48
+ MOUTH_UPPER_UP_RIGHT = 49
+ NOSE_SNEER_LEFT = 50
+ NOSE_SNEER_RIGHT = 51
+
+
+class FaceLandmarksConnections:
+ """The connections between face landmarks."""
+
+ @dataclasses.dataclass
+ class Connection:
+ """The connection class for face landmarks."""
+
+ start: int
+ end: int
+
+ FACE_LANDMARKS_LIPS: List[Connection] = [
+ Connection(61, 146),
+ Connection(146, 91),
+ Connection(91, 181),
+ Connection(181, 84),
+ Connection(84, 17),
+ Connection(17, 314),
+ Connection(314, 405),
+ Connection(405, 321),
+ Connection(321, 375),
+ Connection(375, 291),
+ Connection(61, 185),
+ Connection(185, 40),
+ Connection(40, 39),
+ Connection(39, 37),
+ Connection(37, 0),
+ Connection(0, 267),
+ Connection(267, 269),
+ Connection(269, 270),
+ Connection(270, 409),
+ Connection(409, 291),
+ Connection(78, 95),
+ Connection(95, 88),
+ Connection(88, 178),
+ Connection(178, 87),
+ Connection(87, 14),
+ Connection(14, 317),
+ Connection(317, 402),
+ Connection(402, 318),
+ Connection(318, 324),
+ Connection(324, 308),
+ Connection(78, 191),
+ Connection(191, 80),
+ Connection(80, 81),
+ Connection(81, 82),
+ Connection(82, 13),
+ Connection(13, 312),
+ Connection(312, 311),
+ Connection(311, 310),
+ Connection(310, 415),
+ Connection(415, 308),
+ ]
+
+ FACE_LANDMARKS_LEFT_EYE: List[Connection] = [
+ Connection(263, 249),
+ Connection(249, 390),
+ Connection(390, 373),
+ Connection(373, 374),
+ Connection(374, 380),
+ Connection(380, 381),
+ Connection(381, 382),
+ Connection(382, 362),
+ Connection(263, 466),
+ Connection(466, 388),
+ Connection(388, 387),
+ Connection(387, 386),
+ Connection(386, 385),
+ Connection(385, 384),
+ Connection(384, 398),
+ Connection(398, 362),
+ ]
+
+ FACE_LANDMARKS_LEFT_EYEBROW: List[Connection] = [
+ Connection(276, 283),
+ Connection(283, 282),
+ Connection(282, 295),
+ Connection(295, 285),
+ Connection(300, 293),
+ Connection(293, 334),
+ Connection(334, 296),
+ Connection(296, 336),
+ ]
+
+ FACE_LANDMARKS_LEFT_IRIS: List[Connection] = [
+ Connection(474, 475),
+ Connection(475, 476),
+ Connection(476, 477),
+ Connection(477, 474),
+ ]
+
+ FACE_LANDMARKS_RIGHT_EYE: List[Connection] = [
+ Connection(33, 7),
+ Connection(7, 163),
+ Connection(163, 144),
+ Connection(144, 145),
+ Connection(145, 153),
+ Connection(153, 154),
+ Connection(154, 155),
+ Connection(155, 133),
+ Connection(33, 246),
+ Connection(246, 161),
+ Connection(161, 160),
+ Connection(160, 159),
+ Connection(159, 158),
+ Connection(158, 157),
+ Connection(157, 173),
+ Connection(173, 133),
+ ]
+
+ FACE_LANDMARKS_RIGHT_EYEBROW: List[Connection] = [
+ Connection(46, 53),
+ Connection(53, 52),
+ Connection(52, 65),
+ Connection(65, 55),
+ Connection(70, 63),
+ Connection(63, 105),
+ Connection(105, 66),
+ Connection(66, 107),
+ ]
+
+ FACE_LANDMARKS_RIGHT_IRIS: List[Connection] = [
+ Connection(469, 470),
+ Connection(470, 471),
+ Connection(471, 472),
+ Connection(472, 469),
+ ]
+
+ FACE_LANDMARKS_FACE_OVAL: List[Connection] = [
+ Connection(10, 338),
+ Connection(338, 297),
+ Connection(297, 332),
+ Connection(332, 284),
+ Connection(284, 251),
+ Connection(251, 389),
+ Connection(389, 356),
+ Connection(356, 454),
+ Connection(454, 323),
+ Connection(323, 361),
+ Connection(361, 288),
+ Connection(288, 397),
+ Connection(397, 365),
+ Connection(365, 379),
+ Connection(379, 378),
+ Connection(378, 400),
+ Connection(400, 377),
+ Connection(377, 152),
+ Connection(152, 148),
+ Connection(148, 176),
+ Connection(176, 149),
+ Connection(149, 150),
+ Connection(150, 136),
+ Connection(136, 172),
+ Connection(172, 58),
+ Connection(58, 132),
+ Connection(132, 93),
+ Connection(93, 234),
+ Connection(234, 127),
+ Connection(127, 162),
+ Connection(162, 21),
+ Connection(21, 54),
+ Connection(54, 103),
+ Connection(103, 67),
+ Connection(67, 109),
+ Connection(109, 10),
+ ]
+
+ FACE_LANDMARKS_CONTOURS: List[Connection] = (
+ FACE_LANDMARKS_LIPS
+ + FACE_LANDMARKS_LEFT_EYE
+ + FACE_LANDMARKS_LEFT_EYEBROW
+ + FACE_LANDMARKS_RIGHT_EYE
+ + FACE_LANDMARKS_RIGHT_EYEBROW
+ + FACE_LANDMARKS_FACE_OVAL
+ )
+
+ FACE_LANDMARKS_TESSELATION: List[Connection] = [
+ Connection(127, 34),
+ Connection(34, 139),
+ Connection(139, 127),
+ Connection(11, 0),
+ Connection(0, 37),
+ Connection(37, 11),
+ Connection(232, 231),
+ Connection(231, 120),
+ Connection(120, 232),
+ Connection(72, 37),
+ Connection(37, 39),
+ Connection(39, 72),
+ Connection(128, 121),
+ Connection(121, 47),
+ Connection(47, 128),
+ Connection(232, 121),
+ Connection(121, 128),
+ Connection(128, 232),
+ Connection(104, 69),
+ Connection(69, 67),
+ Connection(67, 104),
+ Connection(175, 171),
+ Connection(171, 148),
+ Connection(148, 175),
+ Connection(118, 50),
+ Connection(50, 101),
+ Connection(101, 118),
+ Connection(73, 39),
+ Connection(39, 40),
+ Connection(40, 73),
+ Connection(9, 151),
+ Connection(151, 108),
+ Connection(108, 9),
+ Connection(48, 115),
+ Connection(115, 131),
+ Connection(131, 48),
+ Connection(194, 204),
+ Connection(204, 211),
+ Connection(211, 194),
+ Connection(74, 40),
+ Connection(40, 185),
+ Connection(185, 74),
+ Connection(80, 42),
+ Connection(42, 183),
+ Connection(183, 80),
+ Connection(40, 92),
+ Connection(92, 186),
+ Connection(186, 40),
+ Connection(230, 229),
+ Connection(229, 118),
+ Connection(118, 230),
+ Connection(202, 212),
+ Connection(212, 214),
+ Connection(214, 202),
+ Connection(83, 18),
+ Connection(18, 17),
+ Connection(17, 83),
+ Connection(76, 61),
+ Connection(61, 146),
+ Connection(146, 76),
+ Connection(160, 29),
+ Connection(29, 30),
+ Connection(30, 160),
+ Connection(56, 157),
+ Connection(157, 173),
+ Connection(173, 56),
+ Connection(106, 204),
+ Connection(204, 194),
+ Connection(194, 106),
+ Connection(135, 214),
+ Connection(214, 192),
+ Connection(192, 135),
+ Connection(203, 165),
+ Connection(165, 98),
+ Connection(98, 203),
+ Connection(21, 71),
+ Connection(71, 68),
+ Connection(68, 21),
+ Connection(51, 45),
+ Connection(45, 4),
+ Connection(4, 51),
+ Connection(144, 24),
+ Connection(24, 23),
+ Connection(23, 144),
+ Connection(77, 146),
+ Connection(146, 91),
+ Connection(91, 77),
+ Connection(205, 50),
+ Connection(50, 187),
+ Connection(187, 205),
+ Connection(201, 200),
+ Connection(200, 18),
+ Connection(18, 201),
+ Connection(91, 106),
+ Connection(106, 182),
+ Connection(182, 91),
+ Connection(90, 91),
+ Connection(91, 181),
+ Connection(181, 90),
+ Connection(85, 84),
+ Connection(84, 17),
+ Connection(17, 85),
+ Connection(206, 203),
+ Connection(203, 36),
+ Connection(36, 206),
+ Connection(148, 171),
+ Connection(171, 140),
+ Connection(140, 148),
+ Connection(92, 40),
+ Connection(40, 39),
+ Connection(39, 92),
+ Connection(193, 189),
+ Connection(189, 244),
+ Connection(244, 193),
+ Connection(159, 158),
+ Connection(158, 28),
+ Connection(28, 159),
+ Connection(247, 246),
+ Connection(246, 161),
+ Connection(161, 247),
+ Connection(236, 3),
+ Connection(3, 196),
+ Connection(196, 236),
+ Connection(54, 68),
+ Connection(68, 104),
+ Connection(104, 54),
+ Connection(193, 168),
+ Connection(168, 8),
+ Connection(8, 193),
+ Connection(117, 228),
+ Connection(228, 31),
+ Connection(31, 117),
+ Connection(189, 193),
+ Connection(193, 55),
+ Connection(55, 189),
+ Connection(98, 97),
+ Connection(97, 99),
+ Connection(99, 98),
+ Connection(126, 47),
+ Connection(47, 100),
+ Connection(100, 126),
+ Connection(166, 79),
+ Connection(79, 218),
+ Connection(218, 166),
+ Connection(155, 154),
+ Connection(154, 26),
+ Connection(26, 155),
+ Connection(209, 49),
+ Connection(49, 131),
+ Connection(131, 209),
+ Connection(135, 136),
+ Connection(136, 150),
+ Connection(150, 135),
+ Connection(47, 126),
+ Connection(126, 217),
+ Connection(217, 47),
+ Connection(223, 52),
+ Connection(52, 53),
+ Connection(53, 223),
+ Connection(45, 51),
+ Connection(51, 134),
+ Connection(134, 45),
+ Connection(211, 170),
+ Connection(170, 140),
+ Connection(140, 211),
+ Connection(67, 69),
+ Connection(69, 108),
+ Connection(108, 67),
+ Connection(43, 106),
+ Connection(106, 91),
+ Connection(91, 43),
+ Connection(230, 119),
+ Connection(119, 120),
+ Connection(120, 230),
+ Connection(226, 130),
+ Connection(130, 247),
+ Connection(247, 226),
+ Connection(63, 53),
+ Connection(53, 52),
+ Connection(52, 63),
+ Connection(238, 20),
+ Connection(20, 242),
+ Connection(242, 238),
+ Connection(46, 70),
+ Connection(70, 156),
+ Connection(156, 46),
+ Connection(78, 62),
+ Connection(62, 96),
+ Connection(96, 78),
+ Connection(46, 53),
+ Connection(53, 63),
+ Connection(63, 46),
+ Connection(143, 34),
+ Connection(34, 227),
+ Connection(227, 143),
+ Connection(123, 117),
+ Connection(117, 111),
+ Connection(111, 123),
+ Connection(44, 125),
+ Connection(125, 19),
+ Connection(19, 44),
+ Connection(236, 134),
+ Connection(134, 51),
+ Connection(51, 236),
+ Connection(216, 206),
+ Connection(206, 205),
+ Connection(205, 216),
+ Connection(154, 153),
+ Connection(153, 22),
+ Connection(22, 154),
+ Connection(39, 37),
+ Connection(37, 167),
+ Connection(167, 39),
+ Connection(200, 201),
+ Connection(201, 208),
+ Connection(208, 200),
+ Connection(36, 142),
+ Connection(142, 100),
+ Connection(100, 36),
+ Connection(57, 212),
+ Connection(212, 202),
+ Connection(202, 57),
+ Connection(20, 60),
+ Connection(60, 99),
+ Connection(99, 20),
+ Connection(28, 158),
+ Connection(158, 157),
+ Connection(157, 28),
+ Connection(35, 226),
+ Connection(226, 113),
+ Connection(113, 35),
+ Connection(160, 159),
+ Connection(159, 27),
+ Connection(27, 160),
+ Connection(204, 202),
+ Connection(202, 210),
+ Connection(210, 204),
+ Connection(113, 225),
+ Connection(225, 46),
+ Connection(46, 113),
+ Connection(43, 202),
+ Connection(202, 204),
+ Connection(204, 43),
+ Connection(62, 76),
+ Connection(76, 77),
+ Connection(77, 62),
+ Connection(137, 123),
+ Connection(123, 116),
+ Connection(116, 137),
+ Connection(41, 38),
+ Connection(38, 72),
+ Connection(72, 41),
+ Connection(203, 129),
+ Connection(129, 142),
+ Connection(142, 203),
+ Connection(64, 98),
+ Connection(98, 240),
+ Connection(240, 64),
+ Connection(49, 102),
+ Connection(102, 64),
+ Connection(64, 49),
+ Connection(41, 73),
+ Connection(73, 74),
+ Connection(74, 41),
+ Connection(212, 216),
+ Connection(216, 207),
+ Connection(207, 212),
+ Connection(42, 74),
+ Connection(74, 184),
+ Connection(184, 42),
+ Connection(169, 170),
+ Connection(170, 211),
+ Connection(211, 169),
+ Connection(170, 149),
+ Connection(149, 176),
+ Connection(176, 170),
+ Connection(105, 66),
+ Connection(66, 69),
+ Connection(69, 105),
+ Connection(122, 6),
+ Connection(6, 168),
+ Connection(168, 122),
+ Connection(123, 147),
+ Connection(147, 187),
+ Connection(187, 123),
+ Connection(96, 77),
+ Connection(77, 90),
+ Connection(90, 96),
+ Connection(65, 55),
+ Connection(55, 107),
+ Connection(107, 65),
+ Connection(89, 90),
+ Connection(90, 180),
+ Connection(180, 89),
+ Connection(101, 100),
+ Connection(100, 120),
+ Connection(120, 101),
+ Connection(63, 105),
+ Connection(105, 104),
+ Connection(104, 63),
+ Connection(93, 137),
+ Connection(137, 227),
+ Connection(227, 93),
+ Connection(15, 86),
+ Connection(86, 85),
+ Connection(85, 15),
+ Connection(129, 102),
+ Connection(102, 49),
+ Connection(49, 129),
+ Connection(14, 87),
+ Connection(87, 86),
+ Connection(86, 14),
+ Connection(55, 8),
+ Connection(8, 9),
+ Connection(9, 55),
+ Connection(100, 47),
+ Connection(47, 121),
+ Connection(121, 100),
+ Connection(145, 23),
+ Connection(23, 22),
+ Connection(22, 145),
+ Connection(88, 89),
+ Connection(89, 179),
+ Connection(179, 88),
+ Connection(6, 122),
+ Connection(122, 196),
+ Connection(196, 6),
+ Connection(88, 95),
+ Connection(95, 96),
+ Connection(96, 88),
+ Connection(138, 172),
+ Connection(172, 136),
+ Connection(136, 138),
+ Connection(215, 58),
+ Connection(58, 172),
+ Connection(172, 215),
+ Connection(115, 48),
+ Connection(48, 219),
+ Connection(219, 115),
+ Connection(42, 80),
+ Connection(80, 81),
+ Connection(81, 42),
+ Connection(195, 3),
+ Connection(3, 51),
+ Connection(51, 195),
+ Connection(43, 146),
+ Connection(146, 61),
+ Connection(61, 43),
+ Connection(171, 175),
+ Connection(175, 199),
+ Connection(199, 171),
+ Connection(81, 82),
+ Connection(82, 38),
+ Connection(38, 81),
+ Connection(53, 46),
+ Connection(46, 225),
+ Connection(225, 53),
+ Connection(144, 163),
+ Connection(163, 110),
+ Connection(110, 144),
+ Connection(52, 65),
+ Connection(65, 66),
+ Connection(66, 52),
+ Connection(229, 228),
+ Connection(228, 117),
+ Connection(117, 229),
+ Connection(34, 127),
+ Connection(127, 234),
+ Connection(234, 34),
+ Connection(107, 108),
+ Connection(108, 69),
+ Connection(69, 107),
+ Connection(109, 108),
+ Connection(108, 151),
+ Connection(151, 109),
+ Connection(48, 64),
+ Connection(64, 235),
+ Connection(235, 48),
+ Connection(62, 78),
+ Connection(78, 191),
+ Connection(191, 62),
+ Connection(129, 209),
+ Connection(209, 126),
+ Connection(126, 129),
+ Connection(111, 35),
+ Connection(35, 143),
+ Connection(143, 111),
+ Connection(117, 123),
+ Connection(123, 50),
+ Connection(50, 117),
+ Connection(222, 65),
+ Connection(65, 52),
+ Connection(52, 222),
+ Connection(19, 125),
+ Connection(125, 141),
+ Connection(141, 19),
+ Connection(221, 55),
+ Connection(55, 65),
+ Connection(65, 221),
+ Connection(3, 195),
+ Connection(195, 197),
+ Connection(197, 3),
+ Connection(25, 7),
+ Connection(7, 33),
+ Connection(33, 25),
+ Connection(220, 237),
+ Connection(237, 44),
+ Connection(44, 220),
+ Connection(70, 71),
+ Connection(71, 139),
+ Connection(139, 70),
+ Connection(122, 193),
+ Connection(193, 245),
+ Connection(245, 122),
+ Connection(247, 130),
+ Connection(130, 33),
+ Connection(33, 247),
+ Connection(71, 21),
+ Connection(21, 162),
+ Connection(162, 71),
+ Connection(170, 169),
+ Connection(169, 150),
+ Connection(150, 170),
+ Connection(188, 174),
+ Connection(174, 196),
+ Connection(196, 188),
+ Connection(216, 186),
+ Connection(186, 92),
+ Connection(92, 216),
+ Connection(2, 97),
+ Connection(97, 167),
+ Connection(167, 2),
+ Connection(141, 125),
+ Connection(125, 241),
+ Connection(241, 141),
+ Connection(164, 167),
+ Connection(167, 37),
+ Connection(37, 164),
+ Connection(72, 38),
+ Connection(38, 12),
+ Connection(12, 72),
+ Connection(38, 82),
+ Connection(82, 13),
+ Connection(13, 38),
+ Connection(63, 68),
+ Connection(68, 71),
+ Connection(71, 63),
+ Connection(226, 35),
+ Connection(35, 111),
+ Connection(111, 226),
+ Connection(101, 50),
+ Connection(50, 205),
+ Connection(205, 101),
+ Connection(206, 92),
+ Connection(92, 165),
+ Connection(165, 206),
+ Connection(209, 198),
+ Connection(198, 217),
+ Connection(217, 209),
+ Connection(165, 167),
+ Connection(167, 97),
+ Connection(97, 165),
+ Connection(220, 115),
+ Connection(115, 218),
+ Connection(218, 220),
+ Connection(133, 112),
+ Connection(112, 243),
+ Connection(243, 133),
+ Connection(239, 238),
+ Connection(238, 241),
+ Connection(241, 239),
+ Connection(214, 135),
+ Connection(135, 169),
+ Connection(169, 214),
+ Connection(190, 173),
+ Connection(173, 133),
+ Connection(133, 190),
+ Connection(171, 208),
+ Connection(208, 32),
+ Connection(32, 171),
+ Connection(125, 44),
+ Connection(44, 237),
+ Connection(237, 125),
+ Connection(86, 87),
+ Connection(87, 178),
+ Connection(178, 86),
+ Connection(85, 86),
+ Connection(86, 179),
+ Connection(179, 85),
+ Connection(84, 85),
+ Connection(85, 180),
+ Connection(180, 84),
+ Connection(83, 84),
+ Connection(84, 181),
+ Connection(181, 83),
+ Connection(201, 83),
+ Connection(83, 182),
+ Connection(182, 201),
+ Connection(137, 93),
+ Connection(93, 132),
+ Connection(132, 137),
+ Connection(76, 62),
+ Connection(62, 183),
+ Connection(183, 76),
+ Connection(61, 76),
+ Connection(76, 184),
+ Connection(184, 61),
+ Connection(57, 61),
+ Connection(61, 185),
+ Connection(185, 57),
+ Connection(212, 57),
+ Connection(57, 186),
+ Connection(186, 212),
+ Connection(214, 207),
+ Connection(207, 187),
+ Connection(187, 214),
+ Connection(34, 143),
+ Connection(143, 156),
+ Connection(156, 34),
+ Connection(79, 239),
+ Connection(239, 237),
+ Connection(237, 79),
+ Connection(123, 137),
+ Connection(137, 177),
+ Connection(177, 123),
+ Connection(44, 1),
+ Connection(1, 4),
+ Connection(4, 44),
+ Connection(201, 194),
+ Connection(194, 32),
+ Connection(32, 201),
+ Connection(64, 102),
+ Connection(102, 129),
+ Connection(129, 64),
+ Connection(213, 215),
+ Connection(215, 138),
+ Connection(138, 213),
+ Connection(59, 166),
+ Connection(166, 219),
+ Connection(219, 59),
+ Connection(242, 99),
+ Connection(99, 97),
+ Connection(97, 242),
+ Connection(2, 94),
+ Connection(94, 141),
+ Connection(141, 2),
+ Connection(75, 59),
+ Connection(59, 235),
+ Connection(235, 75),
+ Connection(24, 110),
+ Connection(110, 228),
+ Connection(228, 24),
+ Connection(25, 130),
+ Connection(130, 226),
+ Connection(226, 25),
+ Connection(23, 24),
+ Connection(24, 229),
+ Connection(229, 23),
+ Connection(22, 23),
+ Connection(23, 230),
+ Connection(230, 22),
+ Connection(26, 22),
+ Connection(22, 231),
+ Connection(231, 26),
+ Connection(112, 26),
+ Connection(26, 232),
+ Connection(232, 112),
+ Connection(189, 190),
+ Connection(190, 243),
+ Connection(243, 189),
+ Connection(221, 56),
+ Connection(56, 190),
+ Connection(190, 221),
+ Connection(28, 56),
+ Connection(56, 221),
+ Connection(221, 28),
+ Connection(27, 28),
+ Connection(28, 222),
+ Connection(222, 27),
+ Connection(29, 27),
+ Connection(27, 223),
+ Connection(223, 29),
+ Connection(30, 29),
+ Connection(29, 224),
+ Connection(224, 30),
+ Connection(247, 30),
+ Connection(30, 225),
+ Connection(225, 247),
+ Connection(238, 79),
+ Connection(79, 20),
+ Connection(20, 238),
+ Connection(166, 59),
+ Connection(59, 75),
+ Connection(75, 166),
+ Connection(60, 75),
+ Connection(75, 240),
+ Connection(240, 60),
+ Connection(147, 177),
+ Connection(177, 215),
+ Connection(215, 147),
+ Connection(20, 79),
+ Connection(79, 166),
+ Connection(166, 20),
+ Connection(187, 147),
+ Connection(147, 213),
+ Connection(213, 187),
+ Connection(112, 233),
+ Connection(233, 244),
+ Connection(244, 112),
+ Connection(233, 128),
+ Connection(128, 245),
+ Connection(245, 233),
+ Connection(128, 114),
+ Connection(114, 188),
+ Connection(188, 128),
+ Connection(114, 217),
+ Connection(217, 174),
+ Connection(174, 114),
+ Connection(131, 115),
+ Connection(115, 220),
+ Connection(220, 131),
+ Connection(217, 198),
+ Connection(198, 236),
+ Connection(236, 217),
+ Connection(198, 131),
+ Connection(131, 134),
+ Connection(134, 198),
+ Connection(177, 132),
+ Connection(132, 58),
+ Connection(58, 177),
+ Connection(143, 35),
+ Connection(35, 124),
+ Connection(124, 143),
+ Connection(110, 163),
+ Connection(163, 7),
+ Connection(7, 110),
+ Connection(228, 110),
+ Connection(110, 25),
+ Connection(25, 228),
+ Connection(356, 389),
+ Connection(389, 368),
+ Connection(368, 356),
+ Connection(11, 302),
+ Connection(302, 267),
+ Connection(267, 11),
+ Connection(452, 350),
+ Connection(350, 349),
+ Connection(349, 452),
+ Connection(302, 303),
+ Connection(303, 269),
+ Connection(269, 302),
+ Connection(357, 343),
+ Connection(343, 277),
+ Connection(277, 357),
+ Connection(452, 453),
+ Connection(453, 357),
+ Connection(357, 452),
+ Connection(333, 332),
+ Connection(332, 297),
+ Connection(297, 333),
+ Connection(175, 152),
+ Connection(152, 377),
+ Connection(377, 175),
+ Connection(347, 348),
+ Connection(348, 330),
+ Connection(330, 347),
+ Connection(303, 304),
+ Connection(304, 270),
+ Connection(270, 303),
+ Connection(9, 336),
+ Connection(336, 337),
+ Connection(337, 9),
+ Connection(278, 279),
+ Connection(279, 360),
+ Connection(360, 278),
+ Connection(418, 262),
+ Connection(262, 431),
+ Connection(431, 418),
+ Connection(304, 408),
+ Connection(408, 409),
+ Connection(409, 304),
+ Connection(310, 415),
+ Connection(415, 407),
+ Connection(407, 310),
+ Connection(270, 409),
+ Connection(409, 410),
+ Connection(410, 270),
+ Connection(450, 348),
+ Connection(348, 347),
+ Connection(347, 450),
+ Connection(422, 430),
+ Connection(430, 434),
+ Connection(434, 422),
+ Connection(313, 314),
+ Connection(314, 17),
+ Connection(17, 313),
+ Connection(306, 307),
+ Connection(307, 375),
+ Connection(375, 306),
+ Connection(387, 388),
+ Connection(388, 260),
+ Connection(260, 387),
+ Connection(286, 414),
+ Connection(414, 398),
+ Connection(398, 286),
+ Connection(335, 406),
+ Connection(406, 418),
+ Connection(418, 335),
+ Connection(364, 367),
+ Connection(367, 416),
+ Connection(416, 364),
+ Connection(423, 358),
+ Connection(358, 327),
+ Connection(327, 423),
+ Connection(251, 284),
+ Connection(284, 298),
+ Connection(298, 251),
+ Connection(281, 5),
+ Connection(5, 4),
+ Connection(4, 281),
+ Connection(373, 374),
+ Connection(374, 253),
+ Connection(253, 373),
+ Connection(307, 320),
+ Connection(320, 321),
+ Connection(321, 307),
+ Connection(425, 427),
+ Connection(427, 411),
+ Connection(411, 425),
+ Connection(421, 313),
+ Connection(313, 18),
+ Connection(18, 421),
+ Connection(321, 405),
+ Connection(405, 406),
+ Connection(406, 321),
+ Connection(320, 404),
+ Connection(404, 405),
+ Connection(405, 320),
+ Connection(315, 16),
+ Connection(16, 17),
+ Connection(17, 315),
+ Connection(426, 425),
+ Connection(425, 266),
+ Connection(266, 426),
+ Connection(377, 400),
+ Connection(400, 369),
+ Connection(369, 377),
+ Connection(322, 391),
+ Connection(391, 269),
+ Connection(269, 322),
+ Connection(417, 465),
+ Connection(465, 464),
+ Connection(464, 417),
+ Connection(386, 257),
+ Connection(257, 258),
+ Connection(258, 386),
+ Connection(466, 260),
+ Connection(260, 388),
+ Connection(388, 466),
+ Connection(456, 399),
+ Connection(399, 419),
+ Connection(419, 456),
+ Connection(284, 332),
+ Connection(332, 333),
+ Connection(333, 284),
+ Connection(417, 285),
+ Connection(285, 8),
+ Connection(8, 417),
+ Connection(346, 340),
+ Connection(340, 261),
+ Connection(261, 346),
+ Connection(413, 441),
+ Connection(441, 285),
+ Connection(285, 413),
+ Connection(327, 460),
+ Connection(460, 328),
+ Connection(328, 327),
+ Connection(355, 371),
+ Connection(371, 329),
+ Connection(329, 355),
+ Connection(392, 439),
+ Connection(439, 438),
+ Connection(438, 392),
+ Connection(382, 341),
+ Connection(341, 256),
+ Connection(256, 382),
+ Connection(429, 420),
+ Connection(420, 360),
+ Connection(360, 429),
+ Connection(364, 394),
+ Connection(394, 379),
+ Connection(379, 364),
+ Connection(277, 343),
+ Connection(343, 437),
+ Connection(437, 277),
+ Connection(443, 444),
+ Connection(444, 283),
+ Connection(283, 443),
+ Connection(275, 440),
+ Connection(440, 363),
+ Connection(363, 275),
+ Connection(431, 262),
+ Connection(262, 369),
+ Connection(369, 431),
+ Connection(297, 338),
+ Connection(338, 337),
+ Connection(337, 297),
+ Connection(273, 375),
+ Connection(375, 321),
+ Connection(321, 273),
+ Connection(450, 451),
+ Connection(451, 349),
+ Connection(349, 450),
+ Connection(446, 342),
+ Connection(342, 467),
+ Connection(467, 446),
+ Connection(293, 334),
+ Connection(334, 282),
+ Connection(282, 293),
+ Connection(458, 461),
+ Connection(461, 462),
+ Connection(462, 458),
+ Connection(276, 353),
+ Connection(353, 383),
+ Connection(383, 276),
+ Connection(308, 324),
+ Connection(324, 325),
+ Connection(325, 308),
+ Connection(276, 300),
+ Connection(300, 293),
+ Connection(293, 276),
+ Connection(372, 345),
+ Connection(345, 447),
+ Connection(447, 372),
+ Connection(352, 345),
+ Connection(345, 340),
+ Connection(340, 352),
+ Connection(274, 1),
+ Connection(1, 19),
+ Connection(19, 274),
+ Connection(456, 248),
+ Connection(248, 281),
+ Connection(281, 456),
+ Connection(436, 427),
+ Connection(427, 425),
+ Connection(425, 436),
+ Connection(381, 256),
+ Connection(256, 252),
+ Connection(252, 381),
+ Connection(269, 391),
+ Connection(391, 393),
+ Connection(393, 269),
+ Connection(200, 199),
+ Connection(199, 428),
+ Connection(428, 200),
+ Connection(266, 330),
+ Connection(330, 329),
+ Connection(329, 266),
+ Connection(287, 273),
+ Connection(273, 422),
+ Connection(422, 287),
+ Connection(250, 462),
+ Connection(462, 328),
+ Connection(328, 250),
+ Connection(258, 286),
+ Connection(286, 384),
+ Connection(384, 258),
+ Connection(265, 353),
+ Connection(353, 342),
+ Connection(342, 265),
+ Connection(387, 259),
+ Connection(259, 257),
+ Connection(257, 387),
+ Connection(424, 431),
+ Connection(431, 430),
+ Connection(430, 424),
+ Connection(342, 353),
+ Connection(353, 276),
+ Connection(276, 342),
+ Connection(273, 335),
+ Connection(335, 424),
+ Connection(424, 273),
+ Connection(292, 325),
+ Connection(325, 307),
+ Connection(307, 292),
+ Connection(366, 447),
+ Connection(447, 345),
+ Connection(345, 366),
+ Connection(271, 303),
+ Connection(303, 302),
+ Connection(302, 271),
+ Connection(423, 266),
+ Connection(266, 371),
+ Connection(371, 423),
+ Connection(294, 455),
+ Connection(455, 460),
+ Connection(460, 294),
+ Connection(279, 278),
+ Connection(278, 294),
+ Connection(294, 279),
+ Connection(271, 272),
+ Connection(272, 304),
+ Connection(304, 271),
+ Connection(432, 434),
+ Connection(434, 427),
+ Connection(427, 432),
+ Connection(272, 407),
+ Connection(407, 408),
+ Connection(408, 272),
+ Connection(394, 430),
+ Connection(430, 431),
+ Connection(431, 394),
+ Connection(395, 369),
+ Connection(369, 400),
+ Connection(400, 395),
+ Connection(334, 333),
+ Connection(333, 299),
+ Connection(299, 334),
+ Connection(351, 417),
+ Connection(417, 168),
+ Connection(168, 351),
+ Connection(352, 280),
+ Connection(280, 411),
+ Connection(411, 352),
+ Connection(325, 319),
+ Connection(319, 320),
+ Connection(320, 325),
+ Connection(295, 296),
+ Connection(296, 336),
+ Connection(336, 295),
+ Connection(319, 403),
+ Connection(403, 404),
+ Connection(404, 319),
+ Connection(330, 348),
+ Connection(348, 349),
+ Connection(349, 330),
+ Connection(293, 298),
+ Connection(298, 333),
+ Connection(333, 293),
+ Connection(323, 454),
+ Connection(454, 447),
+ Connection(447, 323),
+ Connection(15, 16),
+ Connection(16, 315),
+ Connection(315, 15),
+ Connection(358, 429),
+ Connection(429, 279),
+ Connection(279, 358),
+ Connection(14, 15),
+ Connection(15, 316),
+ Connection(316, 14),
+ Connection(285, 336),
+ Connection(336, 9),
+ Connection(9, 285),
+ Connection(329, 349),
+ Connection(349, 350),
+ Connection(350, 329),
+ Connection(374, 380),
+ Connection(380, 252),
+ Connection(252, 374),
+ Connection(318, 402),
+ Connection(402, 403),
+ Connection(403, 318),
+ Connection(6, 197),
+ Connection(197, 419),
+ Connection(419, 6),
+ Connection(318, 319),
+ Connection(319, 325),
+ Connection(325, 318),
+ Connection(367, 364),
+ Connection(364, 365),
+ Connection(365, 367),
+ Connection(435, 367),
+ Connection(367, 397),
+ Connection(397, 435),
+ Connection(344, 438),
+ Connection(438, 439),
+ Connection(439, 344),
+ Connection(272, 271),
+ Connection(271, 311),
+ Connection(311, 272),
+ Connection(195, 5),
+ Connection(5, 281),
+ Connection(281, 195),
+ Connection(273, 287),
+ Connection(287, 291),
+ Connection(291, 273),
+ Connection(396, 428),
+ Connection(428, 199),
+ Connection(199, 396),
+ Connection(311, 271),
+ Connection(271, 268),
+ Connection(268, 311),
+ Connection(283, 444),
+ Connection(444, 445),
+ Connection(445, 283),
+ Connection(373, 254),
+ Connection(254, 339),
+ Connection(339, 373),
+ Connection(282, 334),
+ Connection(334, 296),
+ Connection(296, 282),
+ Connection(449, 347),
+ Connection(347, 346),
+ Connection(346, 449),
+ Connection(264, 447),
+ Connection(447, 454),
+ Connection(454, 264),
+ Connection(336, 296),
+ Connection(296, 299),
+ Connection(299, 336),
+ Connection(338, 10),
+ Connection(10, 151),
+ Connection(151, 338),
+ Connection(278, 439),
+ Connection(439, 455),
+ Connection(455, 278),
+ Connection(292, 407),
+ Connection(407, 415),
+ Connection(415, 292),
+ Connection(358, 371),
+ Connection(371, 355),
+ Connection(355, 358),
+ Connection(340, 345),
+ Connection(345, 372),
+ Connection(372, 340),
+ Connection(346, 347),
+ Connection(347, 280),
+ Connection(280, 346),
+ Connection(442, 443),
+ Connection(443, 282),
+ Connection(282, 442),
+ Connection(19, 94),
+ Connection(94, 370),
+ Connection(370, 19),
+ Connection(441, 442),
+ Connection(442, 295),
+ Connection(295, 441),
+ Connection(248, 419),
+ Connection(419, 197),
+ Connection(197, 248),
+ Connection(263, 255),
+ Connection(255, 359),
+ Connection(359, 263),
+ Connection(440, 275),
+ Connection(275, 274),
+ Connection(274, 440),
+ Connection(300, 383),
+ Connection(383, 368),
+ Connection(368, 300),
+ Connection(351, 412),
+ Connection(412, 465),
+ Connection(465, 351),
+ Connection(263, 467),
+ Connection(467, 466),
+ Connection(466, 263),
+ Connection(301, 368),
+ Connection(368, 389),
+ Connection(389, 301),
+ Connection(395, 378),
+ Connection(378, 379),
+ Connection(379, 395),
+ Connection(412, 351),
+ Connection(351, 419),
+ Connection(419, 412),
+ Connection(436, 426),
+ Connection(426, 322),
+ Connection(322, 436),
+ Connection(2, 164),
+ Connection(164, 393),
+ Connection(393, 2),
+ Connection(370, 462),
+ Connection(462, 461),
+ Connection(461, 370),
+ Connection(164, 0),
+ Connection(0, 267),
+ Connection(267, 164),
+ Connection(302, 11),
+ Connection(11, 12),
+ Connection(12, 302),
+ Connection(268, 12),
+ Connection(12, 13),
+ Connection(13, 268),
+ Connection(293, 300),
+ Connection(300, 301),
+ Connection(301, 293),
+ Connection(446, 261),
+ Connection(261, 340),
+ Connection(340, 446),
+ Connection(330, 266),
+ Connection(266, 425),
+ Connection(425, 330),
+ Connection(426, 423),
+ Connection(423, 391),
+ Connection(391, 426),
+ Connection(429, 355),
+ Connection(355, 437),
+ Connection(437, 429),
+ Connection(391, 327),
+ Connection(327, 326),
+ Connection(326, 391),
+ Connection(440, 457),
+ Connection(457, 438),
+ Connection(438, 440),
+ Connection(341, 382),
+ Connection(382, 362),
+ Connection(362, 341),
+ Connection(459, 457),
+ Connection(457, 461),
+ Connection(461, 459),
+ Connection(434, 430),
+ Connection(430, 394),
+ Connection(394, 434),
+ Connection(414, 463),
+ Connection(463, 362),
+ Connection(362, 414),
+ Connection(396, 369),
+ Connection(369, 262),
+ Connection(262, 396),
+ Connection(354, 461),
+ Connection(461, 457),
+ Connection(457, 354),
+ Connection(316, 403),
+ Connection(403, 402),
+ Connection(402, 316),
+ Connection(315, 404),
+ Connection(404, 403),
+ Connection(403, 315),
+ Connection(314, 405),
+ Connection(405, 404),
+ Connection(404, 314),
+ Connection(313, 406),
+ Connection(406, 405),
+ Connection(405, 313),
+ Connection(421, 418),
+ Connection(418, 406),
+ Connection(406, 421),
+ Connection(366, 401),
+ Connection(401, 361),
+ Connection(361, 366),
+ Connection(306, 408),
+ Connection(408, 407),
+ Connection(407, 306),
+ Connection(291, 409),
+ Connection(409, 408),
+ Connection(408, 291),
+ Connection(287, 410),
+ Connection(410, 409),
+ Connection(409, 287),
+ Connection(432, 436),
+ Connection(436, 410),
+ Connection(410, 432),
+ Connection(434, 416),
+ Connection(416, 411),
+ Connection(411, 434),
+ Connection(264, 368),
+ Connection(368, 383),
+ Connection(383, 264),
+ Connection(309, 438),
+ Connection(438, 457),
+ Connection(457, 309),
+ Connection(352, 376),
+ Connection(376, 401),
+ Connection(401, 352),
+ Connection(274, 275),
+ Connection(275, 4),
+ Connection(4, 274),
+ Connection(421, 428),
+ Connection(428, 262),
+ Connection(262, 421),
+ Connection(294, 327),
+ Connection(327, 358),
+ Connection(358, 294),
+ Connection(433, 416),
+ Connection(416, 367),
+ Connection(367, 433),
+ Connection(289, 455),
+ Connection(455, 439),
+ Connection(439, 289),
+ Connection(462, 370),
+ Connection(370, 326),
+ Connection(326, 462),
+ Connection(2, 326),
+ Connection(326, 370),
+ Connection(370, 2),
+ Connection(305, 460),
+ Connection(460, 455),
+ Connection(455, 305),
+ Connection(254, 449),
+ Connection(449, 448),
+ Connection(448, 254),
+ Connection(255, 261),
+ Connection(261, 446),
+ Connection(446, 255),
+ Connection(253, 450),
+ Connection(450, 449),
+ Connection(449, 253),
+ Connection(252, 451),
+ Connection(451, 450),
+ Connection(450, 252),
+ Connection(256, 452),
+ Connection(452, 451),
+ Connection(451, 256),
+ Connection(341, 453),
+ Connection(453, 452),
+ Connection(452, 341),
+ Connection(413, 464),
+ Connection(464, 463),
+ Connection(463, 413),
+ Connection(441, 413),
+ Connection(413, 414),
+ Connection(414, 441),
+ Connection(258, 442),
+ Connection(442, 441),
+ Connection(441, 258),
+ Connection(257, 443),
+ Connection(443, 442),
+ Connection(442, 257),
+ Connection(259, 444),
+ Connection(444, 443),
+ Connection(443, 259),
+ Connection(260, 445),
+ Connection(445, 444),
+ Connection(444, 260),
+ Connection(467, 342),
+ Connection(342, 445),
+ Connection(445, 467),
+ Connection(459, 458),
+ Connection(458, 250),
+ Connection(250, 459),
+ Connection(289, 392),
+ Connection(392, 290),
+ Connection(290, 289),
+ Connection(290, 328),
+ Connection(328, 460),
+ Connection(460, 290),
+ Connection(376, 433),
+ Connection(433, 435),
+ Connection(435, 376),
+ Connection(250, 290),
+ Connection(290, 392),
+ Connection(392, 250),
+ Connection(411, 416),
+ Connection(416, 433),
+ Connection(433, 411),
+ Connection(341, 463),
+ Connection(463, 464),
+ Connection(464, 341),
+ Connection(453, 464),
+ Connection(464, 465),
+ Connection(465, 453),
+ Connection(357, 465),
+ Connection(465, 412),
+ Connection(412, 357),
+ Connection(343, 412),
+ Connection(412, 399),
+ Connection(399, 343),
+ Connection(360, 363),
+ Connection(363, 440),
+ Connection(440, 360),
+ Connection(437, 399),
+ Connection(399, 456),
+ Connection(456, 437),
+ Connection(420, 456),
+ Connection(456, 363),
+ Connection(363, 420),
+ Connection(401, 435),
+ Connection(435, 288),
+ Connection(288, 401),
+ Connection(372, 383),
+ Connection(383, 353),
+ Connection(353, 372),
+ Connection(339, 255),
+ Connection(255, 249),
+ Connection(249, 339),
+ Connection(448, 261),
+ Connection(261, 255),
+ Connection(255, 448),
+ Connection(133, 243),
+ Connection(243, 190),
+ Connection(190, 133),
+ Connection(133, 155),
+ Connection(155, 112),
+ Connection(112, 133),
+ Connection(33, 246),
+ Connection(246, 247),
+ Connection(247, 33),
+ Connection(33, 130),
+ Connection(130, 25),
+ Connection(25, 33),
+ Connection(398, 384),
+ Connection(384, 286),
+ Connection(286, 398),
+ Connection(362, 398),
+ Connection(398, 414),
+ Connection(414, 362),
+ Connection(362, 463),
+ Connection(463, 341),
+ Connection(341, 362),
+ Connection(263, 359),
+ Connection(359, 467),
+ Connection(467, 263),
+ Connection(263, 249),
+ Connection(249, 255),
+ Connection(255, 263),
+ Connection(466, 467),
+ Connection(467, 260),
+ Connection(260, 466),
+ Connection(75, 60),
+ Connection(60, 166),
+ Connection(166, 75),
+ Connection(238, 239),
+ Connection(239, 79),
+ Connection(79, 238),
+ Connection(162, 127),
+ Connection(127, 139),
+ Connection(139, 162),
+ Connection(72, 11),
+ Connection(11, 37),
+ Connection(37, 72),
+ Connection(121, 232),
+ Connection(232, 120),
+ Connection(120, 121),
+ Connection(73, 72),
+ Connection(72, 39),
+ Connection(39, 73),
+ Connection(114, 128),
+ Connection(128, 47),
+ Connection(47, 114),
+ Connection(233, 232),
+ Connection(232, 128),
+ Connection(128, 233),
+ Connection(103, 104),
+ Connection(104, 67),
+ Connection(67, 103),
+ Connection(152, 175),
+ Connection(175, 148),
+ Connection(148, 152),
+ Connection(119, 118),
+ Connection(118, 101),
+ Connection(101, 119),
+ Connection(74, 73),
+ Connection(73, 40),
+ Connection(40, 74),
+ Connection(107, 9),
+ Connection(9, 108),
+ Connection(108, 107),
+ Connection(49, 48),
+ Connection(48, 131),
+ Connection(131, 49),
+ Connection(32, 194),
+ Connection(194, 211),
+ Connection(211, 32),
+ Connection(184, 74),
+ Connection(74, 185),
+ Connection(185, 184),
+ Connection(191, 80),
+ Connection(80, 183),
+ Connection(183, 191),
+ Connection(185, 40),
+ Connection(40, 186),
+ Connection(186, 185),
+ Connection(119, 230),
+ Connection(230, 118),
+ Connection(118, 119),
+ Connection(210, 202),
+ Connection(202, 214),
+ Connection(214, 210),
+ Connection(84, 83),
+ Connection(83, 17),
+ Connection(17, 84),
+ Connection(77, 76),
+ Connection(76, 146),
+ Connection(146, 77),
+ Connection(161, 160),
+ Connection(160, 30),
+ Connection(30, 161),
+ Connection(190, 56),
+ Connection(56, 173),
+ Connection(173, 190),
+ Connection(182, 106),
+ Connection(106, 194),
+ Connection(194, 182),
+ Connection(138, 135),
+ Connection(135, 192),
+ Connection(192, 138),
+ Connection(129, 203),
+ Connection(203, 98),
+ Connection(98, 129),
+ Connection(54, 21),
+ Connection(21, 68),
+ Connection(68, 54),
+ Connection(5, 51),
+ Connection(51, 4),
+ Connection(4, 5),
+ Connection(145, 144),
+ Connection(144, 23),
+ Connection(23, 145),
+ Connection(90, 77),
+ Connection(77, 91),
+ Connection(91, 90),
+ Connection(207, 205),
+ Connection(205, 187),
+ Connection(187, 207),
+ Connection(83, 201),
+ Connection(201, 18),
+ Connection(18, 83),
+ Connection(181, 91),
+ Connection(91, 182),
+ Connection(182, 181),
+ Connection(180, 90),
+ Connection(90, 181),
+ Connection(181, 180),
+ Connection(16, 85),
+ Connection(85, 17),
+ Connection(17, 16),
+ Connection(205, 206),
+ Connection(206, 36),
+ Connection(36, 205),
+ Connection(176, 148),
+ Connection(148, 140),
+ Connection(140, 176),
+ Connection(165, 92),
+ Connection(92, 39),
+ Connection(39, 165),
+ Connection(245, 193),
+ Connection(193, 244),
+ Connection(244, 245),
+ Connection(27, 159),
+ Connection(159, 28),
+ Connection(28, 27),
+ Connection(30, 247),
+ Connection(247, 161),
+ Connection(161, 30),
+ Connection(174, 236),
+ Connection(236, 196),
+ Connection(196, 174),
+ Connection(103, 54),
+ Connection(54, 104),
+ Connection(104, 103),
+ Connection(55, 193),
+ Connection(193, 8),
+ Connection(8, 55),
+ Connection(111, 117),
+ Connection(117, 31),
+ Connection(31, 111),
+ Connection(221, 189),
+ Connection(189, 55),
+ Connection(55, 221),
+ Connection(240, 98),
+ Connection(98, 99),
+ Connection(99, 240),
+ Connection(142, 126),
+ Connection(126, 100),
+ Connection(100, 142),
+ Connection(219, 166),
+ Connection(166, 218),
+ Connection(218, 219),
+ Connection(112, 155),
+ Connection(155, 26),
+ Connection(26, 112),
+ Connection(198, 209),
+ Connection(209, 131),
+ Connection(131, 198),
+ Connection(169, 135),
+ Connection(135, 150),
+ Connection(150, 169),
+ Connection(114, 47),
+ Connection(47, 217),
+ Connection(217, 114),
+ Connection(224, 223),
+ Connection(223, 53),
+ Connection(53, 224),
+ Connection(220, 45),
+ Connection(45, 134),
+ Connection(134, 220),
+ Connection(32, 211),
+ Connection(211, 140),
+ Connection(140, 32),
+ Connection(109, 67),
+ Connection(67, 108),
+ Connection(108, 109),
+ Connection(146, 43),
+ Connection(43, 91),
+ Connection(91, 146),
+ Connection(231, 230),
+ Connection(230, 120),
+ Connection(120, 231),
+ Connection(113, 226),
+ Connection(226, 247),
+ Connection(247, 113),
+ Connection(105, 63),
+ Connection(63, 52),
+ Connection(52, 105),
+ Connection(241, 238),
+ Connection(238, 242),
+ Connection(242, 241),
+ Connection(124, 46),
+ Connection(46, 156),
+ Connection(156, 124),
+ Connection(95, 78),
+ Connection(78, 96),
+ Connection(96, 95),
+ Connection(70, 46),
+ Connection(46, 63),
+ Connection(63, 70),
+ Connection(116, 143),
+ Connection(143, 227),
+ Connection(227, 116),
+ Connection(116, 123),
+ Connection(123, 111),
+ Connection(111, 116),
+ Connection(1, 44),
+ Connection(44, 19),
+ Connection(19, 1),
+ Connection(3, 236),
+ Connection(236, 51),
+ Connection(51, 3),
+ Connection(207, 216),
+ Connection(216, 205),
+ Connection(205, 207),
+ Connection(26, 154),
+ Connection(154, 22),
+ Connection(22, 26),
+ Connection(165, 39),
+ Connection(39, 167),
+ Connection(167, 165),
+ Connection(199, 200),
+ Connection(200, 208),
+ Connection(208, 199),
+ Connection(101, 36),
+ Connection(36, 100),
+ Connection(100, 101),
+ Connection(43, 57),
+ Connection(57, 202),
+ Connection(202, 43),
+ Connection(242, 20),
+ Connection(20, 99),
+ Connection(99, 242),
+ Connection(56, 28),
+ Connection(28, 157),
+ Connection(157, 56),
+ Connection(124, 35),
+ Connection(35, 113),
+ Connection(113, 124),
+ Connection(29, 160),
+ Connection(160, 27),
+ Connection(27, 29),
+ Connection(211, 204),
+ Connection(204, 210),
+ Connection(210, 211),
+ Connection(124, 113),
+ Connection(113, 46),
+ Connection(46, 124),
+ Connection(106, 43),
+ Connection(43, 204),
+ Connection(204, 106),
+ Connection(96, 62),
+ Connection(62, 77),
+ Connection(77, 96),
+ Connection(227, 137),
+ Connection(137, 116),
+ Connection(116, 227),
+ Connection(73, 41),
+ Connection(41, 72),
+ Connection(72, 73),
+ Connection(36, 203),
+ Connection(203, 142),
+ Connection(142, 36),
+ Connection(235, 64),
+ Connection(64, 240),
+ Connection(240, 235),
+ Connection(48, 49),
+ Connection(49, 64),
+ Connection(64, 48),
+ Connection(42, 41),
+ Connection(41, 74),
+ Connection(74, 42),
+ Connection(214, 212),
+ Connection(212, 207),
+ Connection(207, 214),
+ Connection(183, 42),
+ Connection(42, 184),
+ Connection(184, 183),
+ Connection(210, 169),
+ Connection(169, 211),
+ Connection(211, 210),
+ Connection(140, 170),
+ Connection(170, 176),
+ Connection(176, 140),
+ Connection(104, 105),
+ Connection(105, 69),
+ Connection(69, 104),
+ Connection(193, 122),
+ Connection(122, 168),
+ Connection(168, 193),
+ Connection(50, 123),
+ Connection(123, 187),
+ Connection(187, 50),
+ Connection(89, 96),
+ Connection(96, 90),
+ Connection(90, 89),
+ Connection(66, 65),
+ Connection(65, 107),
+ Connection(107, 66),
+ Connection(179, 89),
+ Connection(89, 180),
+ Connection(180, 179),
+ Connection(119, 101),
+ Connection(101, 120),
+ Connection(120, 119),
+ Connection(68, 63),
+ Connection(63, 104),
+ Connection(104, 68),
+ Connection(234, 93),
+ Connection(93, 227),
+ Connection(227, 234),
+ Connection(16, 15),
+ Connection(15, 85),
+ Connection(85, 16),
+ Connection(209, 129),
+ Connection(129, 49),
+ Connection(49, 209),
+ Connection(15, 14),
+ Connection(14, 86),
+ Connection(86, 15),
+ Connection(107, 55),
+ Connection(55, 9),
+ Connection(9, 107),
+ Connection(120, 100),
+ Connection(100, 121),
+ Connection(121, 120),
+ Connection(153, 145),
+ Connection(145, 22),
+ Connection(22, 153),
+ Connection(178, 88),
+ Connection(88, 179),
+ Connection(179, 178),
+ Connection(197, 6),
+ Connection(6, 196),
+ Connection(196, 197),
+ Connection(89, 88),
+ Connection(88, 96),
+ Connection(96, 89),
+ Connection(135, 138),
+ Connection(138, 136),
+ Connection(136, 135),
+ Connection(138, 215),
+ Connection(215, 172),
+ Connection(172, 138),
+ Connection(218, 115),
+ Connection(115, 219),
+ Connection(219, 218),
+ Connection(41, 42),
+ Connection(42, 81),
+ Connection(81, 41),
+ Connection(5, 195),
+ Connection(195, 51),
+ Connection(51, 5),
+ Connection(57, 43),
+ Connection(43, 61),
+ Connection(61, 57),
+ Connection(208, 171),
+ Connection(171, 199),
+ Connection(199, 208),
+ Connection(41, 81),
+ Connection(81, 38),
+ Connection(38, 41),
+ Connection(224, 53),
+ Connection(53, 225),
+ Connection(225, 224),
+ Connection(24, 144),
+ Connection(144, 110),
+ Connection(110, 24),
+ Connection(105, 52),
+ Connection(52, 66),
+ Connection(66, 105),
+ Connection(118, 229),
+ Connection(229, 117),
+ Connection(117, 118),
+ Connection(227, 34),
+ Connection(34, 234),
+ Connection(234, 227),
+ Connection(66, 107),
+ Connection(107, 69),
+ Connection(69, 66),
+ Connection(10, 109),
+ Connection(109, 151),
+ Connection(151, 10),
+ Connection(219, 48),
+ Connection(48, 235),
+ Connection(235, 219),
+ Connection(183, 62),
+ Connection(62, 191),
+ Connection(191, 183),
+ Connection(142, 129),
+ Connection(129, 126),
+ Connection(126, 142),
+ Connection(116, 111),
+ Connection(111, 143),
+ Connection(143, 116),
+ Connection(118, 117),
+ Connection(117, 50),
+ Connection(50, 118),
+ Connection(223, 222),
+ Connection(222, 52),
+ Connection(52, 223),
+ Connection(94, 19),
+ Connection(19, 141),
+ Connection(141, 94),
+ Connection(222, 221),
+ Connection(221, 65),
+ Connection(65, 222),
+ Connection(196, 3),
+ Connection(3, 197),
+ Connection(197, 196),
+ Connection(45, 220),
+ Connection(220, 44),
+ Connection(44, 45),
+ Connection(156, 70),
+ Connection(70, 139),
+ Connection(139, 156),
+ Connection(188, 122),
+ Connection(122, 245),
+ Connection(245, 188),
+ Connection(139, 71),
+ Connection(71, 162),
+ Connection(162, 139),
+ Connection(149, 170),
+ Connection(170, 150),
+ Connection(150, 149),
+ Connection(122, 188),
+ Connection(188, 196),
+ Connection(196, 122),
+ Connection(206, 216),
+ Connection(216, 92),
+ Connection(92, 206),
+ Connection(164, 2),
+ Connection(2, 167),
+ Connection(167, 164),
+ Connection(242, 141),
+ Connection(141, 241),
+ Connection(241, 242),
+ Connection(0, 164),
+ Connection(164, 37),
+ Connection(37, 0),
+ Connection(11, 72),
+ Connection(72, 12),
+ Connection(12, 11),
+ Connection(12, 38),
+ Connection(38, 13),
+ Connection(13, 12),
+ Connection(70, 63),
+ Connection(63, 71),
+ Connection(71, 70),
+ Connection(31, 226),
+ Connection(226, 111),
+ Connection(111, 31),
+ Connection(36, 101),
+ Connection(101, 205),
+ Connection(205, 36),
+ Connection(203, 206),
+ Connection(206, 165),
+ Connection(165, 203),
+ Connection(126, 209),
+ Connection(209, 217),
+ Connection(217, 126),
+ Connection(98, 165),
+ Connection(165, 97),
+ Connection(97, 98),
+ Connection(237, 220),
+ Connection(220, 218),
+ Connection(218, 237),
+ Connection(237, 239),
+ Connection(239, 241),
+ Connection(241, 237),
+ Connection(210, 214),
+ Connection(214, 169),
+ Connection(169, 210),
+ Connection(140, 171),
+ Connection(171, 32),
+ Connection(32, 140),
+ Connection(241, 125),
+ Connection(125, 237),
+ Connection(237, 241),
+ Connection(179, 86),
+ Connection(86, 178),
+ Connection(178, 179),
+ Connection(180, 85),
+ Connection(85, 179),
+ Connection(179, 180),
+ Connection(181, 84),
+ Connection(84, 180),
+ Connection(180, 181),
+ Connection(182, 83),
+ Connection(83, 181),
+ Connection(181, 182),
+ Connection(194, 201),
+ Connection(201, 182),
+ Connection(182, 194),
+ Connection(177, 137),
+ Connection(137, 132),
+ Connection(132, 177),
+ Connection(184, 76),
+ Connection(76, 183),
+ Connection(183, 184),
+ Connection(185, 61),
+ Connection(61, 184),
+ Connection(184, 185),
+ Connection(186, 57),
+ Connection(57, 185),
+ Connection(185, 186),
+ Connection(216, 212),
+ Connection(212, 186),
+ Connection(186, 216),
+ Connection(192, 214),
+ Connection(214, 187),
+ Connection(187, 192),
+ Connection(139, 34),
+ Connection(34, 156),
+ Connection(156, 139),
+ Connection(218, 79),
+ Connection(79, 237),
+ Connection(237, 218),
+ Connection(147, 123),
+ Connection(123, 177),
+ Connection(177, 147),
+ Connection(45, 44),
+ Connection(44, 4),
+ Connection(4, 45),
+ Connection(208, 201),
+ Connection(201, 32),
+ Connection(32, 208),
+ Connection(98, 64),
+ Connection(64, 129),
+ Connection(129, 98),
+ Connection(192, 213),
+ Connection(213, 138),
+ Connection(138, 192),
+ Connection(235, 59),
+ Connection(59, 219),
+ Connection(219, 235),
+ Connection(141, 242),
+ Connection(242, 97),
+ Connection(97, 141),
+ Connection(97, 2),
+ Connection(2, 141),
+ Connection(141, 97),
+ Connection(240, 75),
+ Connection(75, 235),
+ Connection(235, 240),
+ Connection(229, 24),
+ Connection(24, 228),
+ Connection(228, 229),
+ Connection(31, 25),
+ Connection(25, 226),
+ Connection(226, 31),
+ Connection(230, 23),
+ Connection(23, 229),
+ Connection(229, 230),
+ Connection(231, 22),
+ Connection(22, 230),
+ Connection(230, 231),
+ Connection(232, 26),
+ Connection(26, 231),
+ Connection(231, 232),
+ Connection(233, 112),
+ Connection(112, 232),
+ Connection(232, 233),
+ Connection(244, 189),
+ Connection(189, 243),
+ Connection(243, 244),
+ Connection(189, 221),
+ Connection(221, 190),
+ Connection(190, 189),
+ Connection(222, 28),
+ Connection(28, 221),
+ Connection(221, 222),
+ Connection(223, 27),
+ Connection(27, 222),
+ Connection(222, 223),
+ Connection(224, 29),
+ Connection(29, 223),
+ Connection(223, 224),
+ Connection(225, 30),
+ Connection(30, 224),
+ Connection(224, 225),
+ Connection(113, 247),
+ Connection(247, 225),
+ Connection(225, 113),
+ Connection(99, 60),
+ Connection(60, 240),
+ Connection(240, 99),
+ Connection(213, 147),
+ Connection(147, 215),
+ Connection(215, 213),
+ Connection(60, 20),
+ Connection(20, 166),
+ Connection(166, 60),
+ Connection(192, 187),
+ Connection(187, 213),
+ Connection(213, 192),
+ Connection(243, 112),
+ Connection(112, 244),
+ Connection(244, 243),
+ Connection(244, 233),
+ Connection(233, 245),
+ Connection(245, 244),
+ Connection(245, 128),
+ Connection(128, 188),
+ Connection(188, 245),
+ Connection(188, 114),
+ Connection(114, 174),
+ Connection(174, 188),
+ Connection(134, 131),
+ Connection(131, 220),
+ Connection(220, 134),
+ Connection(174, 217),
+ Connection(217, 236),
+ Connection(236, 174),
+ Connection(236, 198),
+ Connection(198, 134),
+ Connection(134, 236),
+ Connection(215, 177),
+ Connection(177, 58),
+ Connection(58, 215),
+ Connection(156, 143),
+ Connection(143, 124),
+ Connection(124, 156),
+ Connection(25, 110),
+ Connection(110, 7),
+ Connection(7, 25),
+ Connection(31, 228),
+ Connection(228, 25),
+ Connection(25, 31),
+ Connection(264, 356),
+ Connection(356, 368),
+ Connection(368, 264),
+ Connection(0, 11),
+ Connection(11, 267),
+ Connection(267, 0),
+ Connection(451, 452),
+ Connection(452, 349),
+ Connection(349, 451),
+ Connection(267, 302),
+ Connection(302, 269),
+ Connection(269, 267),
+ Connection(350, 357),
+ Connection(357, 277),
+ Connection(277, 350),
+ Connection(350, 452),
+ Connection(452, 357),
+ Connection(357, 350),
+ Connection(299, 333),
+ Connection(333, 297),
+ Connection(297, 299),
+ Connection(396, 175),
+ Connection(175, 377),
+ Connection(377, 396),
+ Connection(280, 347),
+ Connection(347, 330),
+ Connection(330, 280),
+ Connection(269, 303),
+ Connection(303, 270),
+ Connection(270, 269),
+ Connection(151, 9),
+ Connection(9, 337),
+ Connection(337, 151),
+ Connection(344, 278),
+ Connection(278, 360),
+ Connection(360, 344),
+ Connection(424, 418),
+ Connection(418, 431),
+ Connection(431, 424),
+ Connection(270, 304),
+ Connection(304, 409),
+ Connection(409, 270),
+ Connection(272, 310),
+ Connection(310, 407),
+ Connection(407, 272),
+ Connection(322, 270),
+ Connection(270, 410),
+ Connection(410, 322),
+ Connection(449, 450),
+ Connection(450, 347),
+ Connection(347, 449),
+ Connection(432, 422),
+ Connection(422, 434),
+ Connection(434, 432),
+ Connection(18, 313),
+ Connection(313, 17),
+ Connection(17, 18),
+ Connection(291, 306),
+ Connection(306, 375),
+ Connection(375, 291),
+ Connection(259, 387),
+ Connection(387, 260),
+ Connection(260, 259),
+ Connection(424, 335),
+ Connection(335, 418),
+ Connection(418, 424),
+ Connection(434, 364),
+ Connection(364, 416),
+ Connection(416, 434),
+ Connection(391, 423),
+ Connection(423, 327),
+ Connection(327, 391),
+ Connection(301, 251),
+ Connection(251, 298),
+ Connection(298, 301),
+ Connection(275, 281),
+ Connection(281, 4),
+ Connection(4, 275),
+ Connection(254, 373),
+ Connection(373, 253),
+ Connection(253, 254),
+ Connection(375, 307),
+ Connection(307, 321),
+ Connection(321, 375),
+ Connection(280, 425),
+ Connection(425, 411),
+ Connection(411, 280),
+ Connection(200, 421),
+ Connection(421, 18),
+ Connection(18, 200),
+ Connection(335, 321),
+ Connection(321, 406),
+ Connection(406, 335),
+ Connection(321, 320),
+ Connection(320, 405),
+ Connection(405, 321),
+ Connection(314, 315),
+ Connection(315, 17),
+ Connection(17, 314),
+ Connection(423, 426),
+ Connection(426, 266),
+ Connection(266, 423),
+ Connection(396, 377),
+ Connection(377, 369),
+ Connection(369, 396),
+ Connection(270, 322),
+ Connection(322, 269),
+ Connection(269, 270),
+ Connection(413, 417),
+ Connection(417, 464),
+ Connection(464, 413),
+ Connection(385, 386),
+ Connection(386, 258),
+ Connection(258, 385),
+ Connection(248, 456),
+ Connection(456, 419),
+ Connection(419, 248),
+ Connection(298, 284),
+ Connection(284, 333),
+ Connection(333, 298),
+ Connection(168, 417),
+ Connection(417, 8),
+ Connection(8, 168),
+ Connection(448, 346),
+ Connection(346, 261),
+ Connection(261, 448),
+ Connection(417, 413),
+ Connection(413, 285),
+ Connection(285, 417),
+ Connection(326, 327),
+ Connection(327, 328),
+ Connection(328, 326),
+ Connection(277, 355),
+ Connection(355, 329),
+ Connection(329, 277),
+ Connection(309, 392),
+ Connection(392, 438),
+ Connection(438, 309),
+ Connection(381, 382),
+ Connection(382, 256),
+ Connection(256, 381),
+ Connection(279, 429),
+ Connection(429, 360),
+ Connection(360, 279),
+ Connection(365, 364),
+ Connection(364, 379),
+ Connection(379, 365),
+ Connection(355, 277),
+ Connection(277, 437),
+ Connection(437, 355),
+ Connection(282, 443),
+ Connection(443, 283),
+ Connection(283, 282),
+ Connection(281, 275),
+ Connection(275, 363),
+ Connection(363, 281),
+ Connection(395, 431),
+ Connection(431, 369),
+ Connection(369, 395),
+ Connection(299, 297),
+ Connection(297, 337),
+ Connection(337, 299),
+ Connection(335, 273),
+ Connection(273, 321),
+ Connection(321, 335),
+ Connection(348, 450),
+ Connection(450, 349),
+ Connection(349, 348),
+ Connection(359, 446),
+ Connection(446, 467),
+ Connection(467, 359),
+ Connection(283, 293),
+ Connection(293, 282),
+ Connection(282, 283),
+ Connection(250, 458),
+ Connection(458, 462),
+ Connection(462, 250),
+ Connection(300, 276),
+ Connection(276, 383),
+ Connection(383, 300),
+ Connection(292, 308),
+ Connection(308, 325),
+ Connection(325, 292),
+ Connection(283, 276),
+ Connection(276, 293),
+ Connection(293, 283),
+ Connection(264, 372),
+ Connection(372, 447),
+ Connection(447, 264),
+ Connection(346, 352),
+ Connection(352, 340),
+ Connection(340, 346),
+ Connection(354, 274),
+ Connection(274, 19),
+ Connection(19, 354),
+ Connection(363, 456),
+ Connection(456, 281),
+ Connection(281, 363),
+ Connection(426, 436),
+ Connection(436, 425),
+ Connection(425, 426),
+ Connection(380, 381),
+ Connection(381, 252),
+ Connection(252, 380),
+ Connection(267, 269),
+ Connection(269, 393),
+ Connection(393, 267),
+ Connection(421, 200),
+ Connection(200, 428),
+ Connection(428, 421),
+ Connection(371, 266),
+ Connection(266, 329),
+ Connection(329, 371),
+ Connection(432, 287),
+ Connection(287, 422),
+ Connection(422, 432),
+ Connection(290, 250),
+ Connection(250, 328),
+ Connection(328, 290),
+ Connection(385, 258),
+ Connection(258, 384),
+ Connection(384, 385),
+ Connection(446, 265),
+ Connection(265, 342),
+ Connection(342, 446),
+ Connection(386, 387),
+ Connection(387, 257),
+ Connection(257, 386),
+ Connection(422, 424),
+ Connection(424, 430),
+ Connection(430, 422),
+ Connection(445, 342),
+ Connection(342, 276),
+ Connection(276, 445),
+ Connection(422, 273),
+ Connection(273, 424),
+ Connection(424, 422),
+ Connection(306, 292),
+ Connection(292, 307),
+ Connection(307, 306),
+ Connection(352, 366),
+ Connection(366, 345),
+ Connection(345, 352),
+ Connection(268, 271),
+ Connection(271, 302),
+ Connection(302, 268),
+ Connection(358, 423),
+ Connection(423, 371),
+ Connection(371, 358),
+ Connection(327, 294),
+ Connection(294, 460),
+ Connection(460, 327),
+ Connection(331, 279),
+ Connection(279, 294),
+ Connection(294, 331),
+ Connection(303, 271),
+ Connection(271, 304),
+ Connection(304, 303),
+ Connection(436, 432),
+ Connection(432, 427),
+ Connection(427, 436),
+ Connection(304, 272),
+ Connection(272, 408),
+ Connection(408, 304),
+ Connection(395, 394),
+ Connection(394, 431),
+ Connection(431, 395),
+ Connection(378, 395),
+ Connection(395, 400),
+ Connection(400, 378),
+ Connection(296, 334),
+ Connection(334, 299),
+ Connection(299, 296),
+ Connection(6, 351),
+ Connection(351, 168),
+ Connection(168, 6),
+ Connection(376, 352),
+ Connection(352, 411),
+ Connection(411, 376),
+ Connection(307, 325),
+ Connection(325, 320),
+ Connection(320, 307),
+ Connection(285, 295),
+ Connection(295, 336),
+ Connection(336, 285),
+ Connection(320, 319),
+ Connection(319, 404),
+ Connection(404, 320),
+ Connection(329, 330),
+ Connection(330, 349),
+ Connection(349, 329),
+ Connection(334, 293),
+ Connection(293, 333),
+ Connection(333, 334),
+ Connection(366, 323),
+ Connection(323, 447),
+ Connection(447, 366),
+ Connection(316, 15),
+ Connection(15, 315),
+ Connection(315, 316),
+ Connection(331, 358),
+ Connection(358, 279),
+ Connection(279, 331),
+ Connection(317, 14),
+ Connection(14, 316),
+ Connection(316, 317),
+ Connection(8, 285),
+ Connection(285, 9),
+ Connection(9, 8),
+ Connection(277, 329),
+ Connection(329, 350),
+ Connection(350, 277),
+ Connection(253, 374),
+ Connection(374, 252),
+ Connection(252, 253),
+ Connection(319, 318),
+ Connection(318, 403),
+ Connection(403, 319),
+ Connection(351, 6),
+ Connection(6, 419),
+ Connection(419, 351),
+ Connection(324, 318),
+ Connection(318, 325),
+ Connection(325, 324),
+ Connection(397, 367),
+ Connection(367, 365),
+ Connection(365, 397),
+ Connection(288, 435),
+ Connection(435, 397),
+ Connection(397, 288),
+ Connection(278, 344),
+ Connection(344, 439),
+ Connection(439, 278),
+ Connection(310, 272),
+ Connection(272, 311),
+ Connection(311, 310),
+ Connection(248, 195),
+ Connection(195, 281),
+ Connection(281, 248),
+ Connection(375, 273),
+ Connection(273, 291),
+ Connection(291, 375),
+ Connection(175, 396),
+ Connection(396, 199),
+ Connection(199, 175),
+ Connection(312, 311),
+ Connection(311, 268),
+ Connection(268, 312),
+ Connection(276, 283),
+ Connection(283, 445),
+ Connection(445, 276),
+ Connection(390, 373),
+ Connection(373, 339),
+ Connection(339, 390),
+ Connection(295, 282),
+ Connection(282, 296),
+ Connection(296, 295),
+ Connection(448, 449),
+ Connection(449, 346),
+ Connection(346, 448),
+ Connection(356, 264),
+ Connection(264, 454),
+ Connection(454, 356),
+ Connection(337, 336),
+ Connection(336, 299),
+ Connection(299, 337),
+ Connection(337, 338),
+ Connection(338, 151),
+ Connection(151, 337),
+ Connection(294, 278),
+ Connection(278, 455),
+ Connection(455, 294),
+ Connection(308, 292),
+ Connection(292, 415),
+ Connection(415, 308),
+ Connection(429, 358),
+ Connection(358, 355),
+ Connection(355, 429),
+ Connection(265, 340),
+ Connection(340, 372),
+ Connection(372, 265),
+ Connection(352, 346),
+ Connection(346, 280),
+ Connection(280, 352),
+ Connection(295, 442),
+ Connection(442, 282),
+ Connection(282, 295),
+ Connection(354, 19),
+ Connection(19, 370),
+ Connection(370, 354),
+ Connection(285, 441),
+ Connection(441, 295),
+ Connection(295, 285),
+ Connection(195, 248),
+ Connection(248, 197),
+ Connection(197, 195),
+ Connection(457, 440),
+ Connection(440, 274),
+ Connection(274, 457),
+ Connection(301, 300),
+ Connection(300, 368),
+ Connection(368, 301),
+ Connection(417, 351),
+ Connection(351, 465),
+ Connection(465, 417),
+ Connection(251, 301),
+ Connection(301, 389),
+ Connection(389, 251),
+ Connection(394, 395),
+ Connection(395, 379),
+ Connection(379, 394),
+ Connection(399, 412),
+ Connection(412, 419),
+ Connection(419, 399),
+ Connection(410, 436),
+ Connection(436, 322),
+ Connection(322, 410),
+ Connection(326, 2),
+ Connection(2, 393),
+ Connection(393, 326),
+ Connection(354, 370),
+ Connection(370, 461),
+ Connection(461, 354),
+ Connection(393, 164),
+ Connection(164, 267),
+ Connection(267, 393),
+ Connection(268, 302),
+ Connection(302, 12),
+ Connection(12, 268),
+ Connection(312, 268),
+ Connection(268, 13),
+ Connection(13, 312),
+ Connection(298, 293),
+ Connection(293, 301),
+ Connection(301, 298),
+ Connection(265, 446),
+ Connection(446, 340),
+ Connection(340, 265),
+ Connection(280, 330),
+ Connection(330, 425),
+ Connection(425, 280),
+ Connection(322, 426),
+ Connection(426, 391),
+ Connection(391, 322),
+ Connection(420, 429),
+ Connection(429, 437),
+ Connection(437, 420),
+ Connection(393, 391),
+ Connection(391, 326),
+ Connection(326, 393),
+ Connection(344, 440),
+ Connection(440, 438),
+ Connection(438, 344),
+ Connection(458, 459),
+ Connection(459, 461),
+ Connection(461, 458),
+ Connection(364, 434),
+ Connection(434, 394),
+ Connection(394, 364),
+ Connection(428, 396),
+ Connection(396, 262),
+ Connection(262, 428),
+ Connection(274, 354),
+ Connection(354, 457),
+ Connection(457, 274),
+ Connection(317, 316),
+ Connection(316, 402),
+ Connection(402, 317),
+ Connection(316, 315),
+ Connection(315, 403),
+ Connection(403, 316),
+ Connection(315, 314),
+ Connection(314, 404),
+ Connection(404, 315),
+ Connection(314, 313),
+ Connection(313, 405),
+ Connection(405, 314),
+ Connection(313, 421),
+ Connection(421, 406),
+ Connection(406, 313),
+ Connection(323, 366),
+ Connection(366, 361),
+ Connection(361, 323),
+ Connection(292, 306),
+ Connection(306, 407),
+ Connection(407, 292),
+ Connection(306, 291),
+ Connection(291, 408),
+ Connection(408, 306),
+ Connection(291, 287),
+ Connection(287, 409),
+ Connection(409, 291),
+ Connection(287, 432),
+ Connection(432, 410),
+ Connection(410, 287),
+ Connection(427, 434),
+ Connection(434, 411),
+ Connection(411, 427),
+ Connection(372, 264),
+ Connection(264, 383),
+ Connection(383, 372),
+ Connection(459, 309),
+ Connection(309, 457),
+ Connection(457, 459),
+ Connection(366, 352),
+ Connection(352, 401),
+ Connection(401, 366),
+ Connection(1, 274),
+ Connection(274, 4),
+ Connection(4, 1),
+ Connection(418, 421),
+ Connection(421, 262),
+ Connection(262, 418),
+ Connection(331, 294),
+ Connection(294, 358),
+ Connection(358, 331),
+ Connection(435, 433),
+ Connection(433, 367),
+ Connection(367, 435),
+ Connection(392, 289),
+ Connection(289, 439),
+ Connection(439, 392),
+ Connection(328, 462),
+ Connection(462, 326),
+ Connection(326, 328),
+ Connection(94, 2),
+ Connection(2, 370),
+ Connection(370, 94),
+ Connection(289, 305),
+ Connection(305, 455),
+ Connection(455, 289),
+ Connection(339, 254),
+ Connection(254, 448),
+ Connection(448, 339),
+ Connection(359, 255),
+ Connection(255, 446),
+ Connection(446, 359),
+ Connection(254, 253),
+ Connection(253, 449),
+ Connection(449, 254),
+ Connection(253, 252),
+ Connection(252, 450),
+ Connection(450, 253),
+ Connection(252, 256),
+ Connection(256, 451),
+ Connection(451, 252),
+ Connection(256, 341),
+ Connection(341, 452),
+ Connection(452, 256),
+ Connection(414, 413),
+ Connection(413, 463),
+ Connection(463, 414),
+ Connection(286, 441),
+ Connection(441, 414),
+ Connection(414, 286),
+ Connection(286, 258),
+ Connection(258, 441),
+ Connection(441, 286),
+ Connection(258, 257),
+ Connection(257, 442),
+ Connection(442, 258),
+ Connection(257, 259),
+ Connection(259, 443),
+ Connection(443, 257),
+ Connection(259, 260),
+ Connection(260, 444),
+ Connection(444, 259),
+ Connection(260, 467),
+ Connection(467, 445),
+ Connection(445, 260),
+ Connection(309, 459),
+ Connection(459, 250),
+ Connection(250, 309),
+ Connection(305, 289),
+ Connection(289, 290),
+ Connection(290, 305),
+ Connection(305, 290),
+ Connection(290, 460),
+ Connection(460, 305),
+ Connection(401, 376),
+ Connection(376, 435),
+ Connection(435, 401),
+ Connection(309, 250),
+ Connection(250, 392),
+ Connection(392, 309),
+ Connection(376, 411),
+ Connection(411, 433),
+ Connection(433, 376),
+ Connection(453, 341),
+ Connection(341, 464),
+ Connection(464, 453),
+ Connection(357, 453),
+ Connection(453, 465),
+ Connection(465, 357),
+ Connection(343, 357),
+ Connection(357, 412),
+ Connection(412, 343),
+ Connection(437, 343),
+ Connection(343, 399),
+ Connection(399, 437),
+ Connection(344, 360),
+ Connection(360, 440),
+ Connection(440, 344),
+ Connection(420, 437),
+ Connection(437, 456),
+ Connection(456, 420),
+ Connection(360, 420),
+ Connection(420, 363),
+ Connection(363, 360),
+ Connection(361, 401),
+ Connection(401, 288),
+ Connection(288, 361),
+ Connection(265, 372),
+ Connection(372, 353),
+ Connection(353, 265),
+ Connection(390, 339),
+ Connection(339, 249),
+ Connection(249, 390),
+ Connection(339, 448),
+ Connection(448, 255),
+ Connection(255, 339),
+ ]
+
+
+@dataclasses.dataclass
+class FaceLandmarkerResult:
+ """The face landmarks detection result from FaceLandmarker, where each vector element represents a single face detected in the image.
+
+ Attributes:
+ face_landmarks: Detected face landmarks in normalized image coordinates.
+ face_blendshapes: Optional face blendshapes results.
+ facial_transformation_matrixes: Optional facial transformation matrix.
+ """
+
+ face_landmarks: List[List[landmark_module.NormalizedLandmark]]
+ face_blendshapes: List[List[category_module.Category]]
+ facial_transformation_matrixes: List[np.ndarray]
+
+
+def _build_landmarker_result(
+ output_packets: Mapping[str, packet_module.Packet]
+) -> FaceLandmarkerResult:
+ """Constructs a `FaceLandmarkerResult` from output packets."""
+ face_landmarks_proto_list = packet_getter.get_proto_list(
+ output_packets[_NORM_LANDMARKS_STREAM_NAME]
+ )
+
+ face_landmarks_results = []
+ for proto in face_landmarks_proto_list:
+ face_landmarks = landmark_pb2.NormalizedLandmarkList()
+ face_landmarks.MergeFrom(proto)
+ face_landmarks_list = []
+ for face_landmark in face_landmarks.landmark:
+ face_landmarks_list.append(
+ landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
+ )
+ face_landmarks_results.append(face_landmarks_list)
+
+ face_blendshapes_results = []
+ if _BLENDSHAPES_STREAM_NAME in output_packets:
+ face_blendshapes_proto_list = packet_getter.get_proto_list(
+ output_packets[_BLENDSHAPES_STREAM_NAME]
+ )
+ for proto in face_blendshapes_proto_list:
+ face_blendshapes_categories = []
+ face_blendshapes_classifications = classification_pb2.ClassificationList()
+ face_blendshapes_classifications.MergeFrom(proto)
+ for face_blendshapes in face_blendshapes_classifications.classification:
+ face_blendshapes_categories.append(
+ category_module.Category(
+ index=face_blendshapes.index,
+ score=face_blendshapes.score,
+ display_name=face_blendshapes.display_name,
+ category_name=face_blendshapes.label,
+ )
+ )
+ face_blendshapes_results.append(face_blendshapes_categories)
+
+ facial_transformation_matrixes_results = []
+ if _FACE_GEOMETRY_STREAM_NAME in output_packets:
+ facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
+ output_packets[_FACE_GEOMETRY_STREAM_NAME]
+ )
+ for proto in facial_transformation_matrixes_proto_list:
+ if hasattr(proto, 'pose_transform_matrix'):
+ matrix_data = matrix_data_pb2.MatrixData()
+ matrix_data.MergeFrom(proto.pose_transform_matrix)
+ matrix = np.array(matrix_data.packed_data)
+ matrix = matrix.reshape((matrix_data.rows, matrix_data.cols))
+ matrix = (
+ matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T
+ )
+ facial_transformation_matrixes_results.append(matrix)
+
+ return FaceLandmarkerResult(
+ face_landmarks_results,
+ face_blendshapes_results,
+ facial_transformation_matrixes_results,
+ )
+
+def _build_landmarker_result2(
+ output_packets: Mapping[str, packet_module.Packet]
+) -> FaceLandmarkerResult:
+ """Constructs a `FaceLandmarkerResult` from output packets."""
+ face_landmarks_proto_list = packet_getter.get_proto_list(
+ output_packets[_NORM_LANDMARKS_STREAM_NAME]
+ )
+
+ face_landmarks_results = []
+ for proto in face_landmarks_proto_list:
+ face_landmarks = landmark_pb2.NormalizedLandmarkList()
+ face_landmarks.MergeFrom(proto)
+ face_landmarks_list = []
+ for face_landmark in face_landmarks.landmark:
+ face_landmarks_list.append(
+ landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
+ )
+ face_landmarks_results.append(face_landmarks_list)
+
+ face_blendshapes_results = []
+ if _BLENDSHAPES_STREAM_NAME in output_packets:
+ face_blendshapes_proto_list = packet_getter.get_proto_list(
+ output_packets[_BLENDSHAPES_STREAM_NAME]
+ )
+ for proto in face_blendshapes_proto_list:
+ face_blendshapes_categories = []
+ face_blendshapes_classifications = classification_pb2.ClassificationList()
+ face_blendshapes_classifications.MergeFrom(proto)
+ for face_blendshapes in face_blendshapes_classifications.classification:
+ face_blendshapes_categories.append(
+ category_module.Category(
+ index=face_blendshapes.index,
+ score=face_blendshapes.score,
+ display_name=face_blendshapes.display_name,
+ category_name=face_blendshapes.label,
+ )
+ )
+ face_blendshapes_results.append(face_blendshapes_categories)
+
+ facial_transformation_matrixes_results = []
+ if _FACE_GEOMETRY_STREAM_NAME in output_packets:
+ facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
+ output_packets[_FACE_GEOMETRY_STREAM_NAME]
+ )
+ for proto in facial_transformation_matrixes_proto_list:
+ if hasattr(proto, 'pose_transform_matrix'):
+ matrix_data = matrix_data_pb2.MatrixData()
+ matrix_data.MergeFrom(proto.pose_transform_matrix)
+ matrix = np.array(matrix_data.packed_data)
+ matrix = matrix.reshape((matrix_data.rows, matrix_data.cols))
+ matrix = (
+ matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T
+ )
+ facial_transformation_matrixes_results.append(matrix)
+
+ return FaceLandmarkerResult(
+ face_landmarks_results,
+ face_blendshapes_results,
+ facial_transformation_matrixes_results,
+ ), facial_transformation_matrixes_proto_list[0].mesh
+
+@dataclasses.dataclass
+class FaceLandmarkerOptions:
+ """Options for the face landmarker task.
+
+ Attributes:
+ base_options: Base options for the face landmarker task.
+ running_mode: The running mode of the task. Default to the image mode.
+ FaceLandmarker has three running modes: 1) The image mode for detecting
+ face landmarks on single image inputs. 2) The video mode for detecting
+ face landmarks on the decoded frames of a video. 3) The live stream mode
+ for detecting face landmarks on the live stream of input data, such as
+ from camera. In this mode, the "result_callback" below must be specified
+ to receive the detection results asynchronously.
+ num_faces: The maximum number of faces that can be detected by the
+ FaceLandmarker.
+ min_face_detection_confidence: The minimum confidence score for the face
+ detection to be considered successful.
+ min_face_presence_confidence: The minimum confidence score of face presence
+ score in the face landmark detection.
+ min_tracking_confidence: The minimum confidence score for the face tracking
+ to be considered successful.
+ output_face_blendshapes: Whether FaceLandmarker outputs face blendshapes
+ classification. Face blendshapes are used for rendering the 3D face model.
+ output_facial_transformation_matrixes: Whether FaceLandmarker outputs facial
+ transformation_matrix. Facial transformation matrix is used to transform
+ the face landmarks in canonical face to the detected face, so that users
+ can apply face effects on the detected landmarks.
+ result_callback: The user-defined result callback for processing live stream
+ data. The result callback should only be specified when the running mode
+ is set to the live stream mode.
+ """
+
+ base_options: _BaseOptions
+ running_mode: _RunningMode = _RunningMode.IMAGE
+ num_faces: int = 1
+ min_face_detection_confidence: float = 0.5
+ min_face_presence_confidence: float = 0.5
+ min_tracking_confidence: float = 0.5
+ output_face_blendshapes: bool = False
+ output_facial_transformation_matrixes: bool = False
+ result_callback: Optional[
+ Callable[[FaceLandmarkerResult, image_module.Image, int], None]
+ ] = None
+
+ @doc_controls.do_not_generate_docs
+ def to_pb2(self) -> _FaceLandmarkerGraphOptionsProto:
+ """Generates an FaceLandmarkerGraphOptions protobuf object."""
+ base_options_proto = self.base_options.to_pb2()
+ base_options_proto.use_stream_mode = (
+ False if self.running_mode == _RunningMode.IMAGE else True
+ )
+
+ # Initialize the face landmarker options from base options.
+ face_landmarker_options_proto = _FaceLandmarkerGraphOptionsProto(
+ base_options=base_options_proto
+ )
+
+ # Configure face detector options.
+ face_landmarker_options_proto.face_detector_graph_options.num_faces = (
+ self.num_faces
+ )
+ face_landmarker_options_proto.face_detector_graph_options.min_detection_confidence = (
+ self.min_face_detection_confidence
+ )
+
+ # Configure face landmark detector options.
+ face_landmarker_options_proto.min_tracking_confidence = (
+ self.min_tracking_confidence
+ )
+ face_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = (
+ self.min_face_detection_confidence
+ )
+ return face_landmarker_options_proto
+
+
+class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
+ """Class that performs face landmarks detection on images."""
+
+ @classmethod
+ def create_from_model_path(cls, model_path: str) -> 'FaceLandmarker':
+ """Creates an `FaceLandmarker` object from a TensorFlow Lite model and the default `FaceLandmarkerOptions`.
+
+ Note that the created `FaceLandmarker` instance is in image mode, for
+ detecting face landmarks on single image inputs.
+
+ Args:
+ model_path: Path to the model.
+
+ Returns:
+ `FaceLandmarker` object that's created from the model file and the
+ default `FaceLandmarkerOptions`.
+
+ Raises:
+ ValueError: If failed to create `FaceLandmarker` object from the
+ provided file such as invalid file path.
+ RuntimeError: If other types of error occurred.
+ """
+ base_options = _BaseOptions(model_asset_path=model_path)
+ options = FaceLandmarkerOptions(
+ base_options=base_options, running_mode=_RunningMode.IMAGE
+ )
+ return cls.create_from_options(options)
+
+ @classmethod
+ def create_from_options(
+ cls, options: FaceLandmarkerOptions
+ ) -> 'FaceLandmarker':
+ """Creates the `FaceLandmarker` object from face landmarker options.
+
+ Args:
+ options: Options for the face landmarker task.
+
+ Returns:
+ `FaceLandmarker` object that's created from `options`.
+
+ Raises:
+ ValueError: If failed to create `FaceLandmarker` object from
+ `FaceLandmarkerOptions` such as missing the model.
+ RuntimeError: If other types of error occurred.
+ """
+
+ def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
+ if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
+ return
+
+ image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
+ if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
+ return
+
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
+ empty_packet = output_packets[_NORM_LANDMARKS_STREAM_NAME]
+ options.result_callback(
+ FaceLandmarkerResult([], [], []),
+ image,
+ empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
+ )
+ return
+
+ face_landmarks_result = _build_landmarker_result(output_packets)
+ timestamp = output_packets[_NORM_LANDMARKS_STREAM_NAME].timestamp
+ options.result_callback(
+ face_landmarks_result,
+ image,
+ timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
+ )
+
+ output_streams = [
+ ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
+ ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
+ ]
+
+ if options.output_face_blendshapes:
+ output_streams.append(
+ ':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME])
+ )
+ if options.output_facial_transformation_matrixes:
+ output_streams.append(
+ ':'.join([_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME])
+ )
+
+ task_info = _TaskInfo(
+ task_graph=_TASK_GRAPH_NAME,
+ input_streams=[
+ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
+ ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
+ ],
+ output_streams=output_streams,
+ task_options=options,
+ )
+ return cls(
+ task_info.generate_graph_config(
+ enable_flow_limiting=options.running_mode
+ == _RunningMode.LIVE_STREAM
+ ),
+ options.running_mode,
+ packets_callback if options.result_callback else None,
+ )
+
+ def detect(
+ self,
+ image: image_module.Image,
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
+ ) -> FaceLandmarkerResult:
+ """Performs face landmarks detection on the given image.
+
+ Only use this method when the FaceLandmarker is created with the image
+ running mode.
+
+ The image can be of any size with format RGB or RGBA.
+ TODO: Describes how the input image will be preprocessed after the yuv
+ support is implemented.
+
+ Args:
+ image: MediaPipe Image.
+ image_processing_options: Options for image processing.
+
+ Returns:
+ The face landmarks detection results.
+
+ Raises:
+ ValueError: If any of the input arguments is invalid.
+ RuntimeError: If face landmarker detection failed to run.
+ """
+
+ normalized_rect = self.convert_to_normalized_rect(
+ image_processing_options, image, roi_allowed=False
+ )
+ output_packets = self._process_image_data({
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+ normalized_rect.to_pb2()
+ ),
+ })
+
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
+ return FaceLandmarkerResult([], [], [])
+
+ return _build_landmarker_result2(output_packets)
+
+ def detect_for_video(
+ self,
+ image: image_module.Image,
+ timestamp_ms: int,
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
+ ):
+ """Performs face landmarks detection on the provided video frame.
+
+ Only use this method when the FaceLandmarker is created with the video
+ running mode.
+
+ Only use this method when the FaceLandmarker is created with the video
+ running mode. It's required to provide the video frame's timestamp (in
+ milliseconds) along with the video frame. The input timestamps should be
+ monotonically increasing for adjacent calls of this method.
+
+ Args:
+ image: MediaPipe Image.
+ timestamp_ms: The timestamp of the input video frame in milliseconds.
+ image_processing_options: Options for image processing.
+
+ Returns:
+ The face landmarks detection results.
+
+ Raises:
+ ValueError: If any of the input arguments is invalid.
+ RuntimeError: If face landmarker detection failed to run.
+ """
+ normalized_rect = self.convert_to_normalized_rect(
+ image_processing_options, image, roi_allowed=False
+ )
+ output_packets = self._process_video_data({
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
+ timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
+ ),
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+ normalized_rect.to_pb2()
+ ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
+ })
+
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
+ return FaceLandmarkerResult([], [], [])
+
+ return _build_landmarker_result2(output_packets)
+
+ def detect_async(
+ self,
+ image: image_module.Image,
+ timestamp_ms: int,
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
+ ) -> None:
+ """Sends live image data to perform face landmarks detection.
+
+ The results will be available via the "result_callback" provided in the
+ FaceLandmarkerOptions. Only use this method when the FaceLandmarker is
+ created with the live stream running mode.
+
+ Only use this method when the FaceLandmarker is created with the live
+ stream running mode. The input timestamps should be monotonically increasing
+ for adjacent calls of this method. This method will return immediately after
+ the input image is accepted. The results will be available via the
+ `result_callback` provided in the `FaceLandmarkerOptions`. The
+ `detect_async` method is designed to process live stream data such as
+ camera input. To lower the overall latency, face landmarker may drop the
+ input images if needed. In other words, it's not guaranteed to have output
+ per input image.
+
+ The `result_callback` provides:
+ - The face landmarks detection results.
+ - The input image that the face landmarker runs on.
+ - The input timestamp in milliseconds.
+
+ Args:
+ image: MediaPipe Image.
+ timestamp_ms: The timestamp of the input image in milliseconds.
+ image_processing_options: Options for image processing.
+
+ Raises:
+ ValueError: If the current input timestamp is smaller than what the
+ face landmarker has already processed.
+ """
+ normalized_rect = self.convert_to_normalized_rect(
+ image_processing_options, image, roi_allowed=False
+ )
+ self._send_live_stream_data({
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
+ timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
+ ),
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+ normalized_rect.to_pb2()
+ ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
+ })
\ No newline at end of file
diff --git a/skyreels_a1/src/media_pipe/mp_models/blaze_face_short_range.tflite b/skyreels_a1/src/media_pipe/mp_models/blaze_face_short_range.tflite
new file mode 100644
index 0000000000000000000000000000000000000000..2645898ee18d8bf53746df830303779c9deabc7d
--- /dev/null
+++ b/skyreels_a1/src/media_pipe/mp_models/blaze_face_short_range.tflite
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f
+size 229746
diff --git a/skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task b/skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task
new file mode 100644
index 0000000000000000000000000000000000000000..fedb14de6d2b6708a56c04ae259783e23404c1aa
--- /dev/null
+++ b/skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
+size 3758596
diff --git a/skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task b/skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task
new file mode 100644
index 0000000000000000000000000000000000000000..5f2c1e254fe2d104606a9031b20b266863d014a6
--- /dev/null
+++ b/skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64437af838a65d18e5ba7a0d39b465540069bc8aae8308de3e318aad31fcbc7b
+size 30664242
diff --git a/skyreels_a1/src/media_pipe/mp_utils.py b/skyreels_a1/src/media_pipe/mp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..23677d2cf3f8bfe8c5663d2e0b6d5d57e9350635
--- /dev/null
+++ b/skyreels_a1/src/media_pipe/mp_utils.py
@@ -0,0 +1,98 @@
+import os
+import numpy as np
+import cv2
+import time
+from tqdm import tqdm
+import multiprocessing
+import glob
+
+import mediapipe as mp
+from mediapipe import solutions
+from mediapipe.framework.formats import landmark_pb2
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+from . import face_landmark
+
+CUR_DIR = os.path.dirname(__file__)
+
+
+class LMKExtractor():
+ def __init__(self, FPS=25):
+ # Create an FaceLandmarker object.
+ self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE
+ base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task'))
+ base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU
+ options = vision.FaceLandmarkerOptions(base_options=base_options,
+ running_mode=self.mode,
+ # min_face_detection_confidence=0.3,
+ # min_face_presence_confidence=0.3,
+ # min_tracking_confidence=0.3,
+ output_face_blendshapes=True,
+ output_facial_transformation_matrixes=True,
+ num_faces=1)
+ self.detector = face_landmark.FaceLandmarker.create_from_options(options)
+ self.last_ts = 0
+ self.frame_ms = int(1000 / FPS)
+
+ det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite'))
+ det_options = vision.FaceDetectorOptions(base_options=det_base_options)
+ self.det_detector = vision.FaceDetector.create_from_options(det_options)
+
+
+ def __call__(self, img):
+ frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
+ t0 = time.time()
+ if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO:
+ det_result = self.det_detector.detect(image)
+ if len(det_result.detections) != 1:
+ return None
+ self.last_ts += self.frame_ms
+ try:
+ detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts)
+ except:
+ return None
+ elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE:
+ # det_result = self.det_detector.detect(image)
+
+ # if len(det_result.detections) != 1:
+ # return None
+ try:
+ detection_result, mesh3d = self.detector.detect(image)
+ except:
+ return None
+
+
+ bs_list = detection_result.face_blendshapes
+ if len(bs_list) == 1:
+ bs = bs_list[0]
+ bs_values = []
+ for index in range(len(bs)):
+ bs_values.append(bs[index].score)
+ bs_values = bs_values[1:] # remove neutral
+ trans_mat = detection_result.facial_transformation_matrixes[0]
+ face_landmarks_list = detection_result.face_landmarks
+ face_landmarks = face_landmarks_list[0]
+ lmks = []
+ for index in range(len(face_landmarks)):
+ x = face_landmarks[index].x
+ y = face_landmarks[index].y
+ z = face_landmarks[index].z
+ lmks.append([x, y, z])
+ lmks = np.array(lmks)
+
+ lmks3d = np.array(mesh3d.vertex_buffer)
+ lmks3d = lmks3d.reshape(-1, 5)[:, :3]
+ mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1
+
+ return {
+ "lmks": lmks,
+ 'lmks3d': lmks3d,
+ "trans_mat": trans_mat,
+ 'faces': mp_tris,
+ "bs": bs_values
+ }
+ else:
+ # print('multiple faces in the image: {}'.format(img_path))
+ return None
+
\ No newline at end of file
diff --git a/skyreels_a1/src/media_pipe/readme b/skyreels_a1/src/media_pipe/readme
new file mode 100644
index 0000000000000000000000000000000000000000..569741c6412bb015fac63125ac31c110bb20ddd3
--- /dev/null
+++ b/skyreels_a1/src/media_pipe/readme
@@ -0,0 +1,5 @@
+The landmark file defines the barycentric embedding of 105 points of the Mediapipe mesh in the surface of FLAME.
+In consists of three arrays: lmk_face_idx, lmk_b_coords, and landmark_indices.
+- lmk_face_idx contains for every landmark the index of the FLAME triangle which each landmark is embedded into
+- lmk_b_coords are the barycentric weights for each vertex of the triangles
+- landmark_indices are the indices of the vertices of the Mediapipe mesh
diff --git a/skyreels_a1/src/renderer.py b/skyreels_a1/src/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd32360bd3a0e31cb4f4b2cfb7711342783fce85
--- /dev/null
+++ b/skyreels_a1/src/renderer.py
@@ -0,0 +1,436 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorch3d.structures import Meshes
+from pytorch3d.io import load_obj
+from pytorch3d.renderer.mesh import rasterize_meshes
+import pickle
+import chumpy as ch
+import cv2
+import sys, os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from skyreels_a1.src.utils.mediapipe_utils import face_vertices, vertex_normals, batch_orth_proj
+from skyreels_a1.src.media_pipe.draw_util import FaceMeshVisualizer
+from mediapipe.framework.formats import landmark_pb2
+
+def keep_vertices_and_update_faces(faces, vertices_to_keep):
+ """
+ Keep specified vertices in the mesh and update the faces.
+ """
+ if isinstance(vertices_to_keep, list) or isinstance(vertices_to_keep, np.ndarray):
+ vertices_to_keep = torch.tensor(vertices_to_keep, dtype=torch.long)
+
+ vertices_to_keep = torch.unique(vertices_to_keep)
+ max_vertex_index = faces.max().long().item() + 1
+
+ mask = torch.zeros(max_vertex_index, dtype=torch.bool)
+ mask[vertices_to_keep] = True
+
+ new_vertex_indices = torch.full((max_vertex_index,), -1, dtype=torch.long)
+ new_vertex_indices[mask] = torch.arange(len(vertices_to_keep))
+
+ valid_faces_mask = (new_vertex_indices[faces] != -1).all(dim=1)
+ filtered_faces = faces[valid_faces_mask]
+ updated_faces = new_vertex_indices[filtered_faces]
+
+ return updated_faces
+
+def predict_landmark_position(ref_points, relative_coords):
+ """
+ Predict the new position of the eyeball based on reference points and relative coordinates.
+ """
+ left_corner = ref_points[0]
+ right_corner = ref_points[8]
+
+ eye_center = (left_corner + right_corner) / 2
+ eye_width_vector = right_corner - left_corner
+ eye_width = np.linalg.norm(eye_width_vector)
+ eye_direction = eye_width_vector / eye_width
+ eye_vertical = np.array([-eye_direction[1], eye_direction[0]])
+
+ predicted_pos = eye_center + \
+ (eye_width/2) * relative_coords[0] * eye_direction + \
+ (eye_width/2) * relative_coords[1] * eye_vertical
+
+ return predicted_pos
+
+def mesh_points_by_barycentric_coordinates(mesh_verts, mesh_faces, lmk_face_idx, lmk_b_coords):
+ """
+ Evaluation 3d points given mesh and landmark embedding
+ """
+ dif1 = ch.vstack([
+ (mesh_verts[mesh_faces[lmk_face_idx], 0] * lmk_b_coords).sum(axis=1),
+ (mesh_verts[mesh_faces[lmk_face_idx], 1] * lmk_b_coords).sum(axis=1),
+ (mesh_verts[mesh_faces[lmk_face_idx], 2] * lmk_b_coords).sum(axis=1)
+ ]).T
+ return dif1
+
+class Renderer(nn.Module):
+ def __init__(self, render_full_head=False, obj_filename='pretrained_models/FLAME/head_template.obj'):
+ super(Renderer, self).__init__()
+ self.image_size = 224
+ self.mediapipe_landmark_embedding = np.load("pretrained_models/smirk/mediapipe_landmark_embedding.npz")
+ self.vis = FaceMeshVisualizer(forehead_edge=False)
+
+ verts, faces, aux = load_obj(obj_filename)
+ uvcoords = aux.verts_uvs[None, ...] # (N, V, 2)
+ uvfaces = faces.textures_idx[None, ...] # (N, F, 3)
+ faces = faces.verts_idx[None,...]
+
+ self.render_full_head = render_full_head
+
+ red_color = torch.tensor([255, 0, 0])[None, None, :].float() / 255.
+ transparent_color = torch.tensor([0, 0, 0])[None, None, :].float()
+ colors = transparent_color.repeat(1, 5023, 1)
+
+ flame_masks = pickle.load(
+ open('pretrained_models/FLAME/FLAME_masks.pkl', 'rb'),
+ encoding='latin1')
+ self.flame_masks = flame_masks
+
+ self.register_buffer('faces', faces)
+
+ face_colors = face_vertices(colors, faces)
+ self.register_buffer('face_colors', face_colors)
+
+ self.register_buffer('raw_uvcoords', uvcoords)
+
+ uvcoords = torch.cat([uvcoords, uvcoords[:,:,0:1]*0.+1.], -1) #[bz, ntv, 3]
+ uvcoords = uvcoords*2 - 1; uvcoords[...,1] = -uvcoords[...,1]
+ face_uvcoords = face_vertices(uvcoords, uvfaces)
+ self.register_buffer('uvcoords', uvcoords)
+ self.register_buffer('uvfaces', uvfaces)
+ self.register_buffer('face_uvcoords', face_uvcoords)
+
+ pi = np.pi
+ constant_factor = torch.tensor([1/np.sqrt(4*pi), ((2*pi)/3)*(np.sqrt(3/(4*pi))), ((2*pi)/3)*(np.sqrt(3/(4*pi))),\
+ ((2*pi)/3)*(np.sqrt(3/(4*pi))), (pi/4)*(3)*(np.sqrt(5/(12*pi))), (pi/4)*(3)*(np.sqrt(5/(12*pi))),\
+ (pi/4)*(3)*(np.sqrt(5/(12*pi))), (pi/4)*(3/2)*(np.sqrt(5/(12*pi))), (pi/4)*(1/2)*(np.sqrt(5/(4*pi)))]).float()
+ self.register_buffer('constant_factor', constant_factor)
+
+ def forward(self, vertices, cam_params, source_tform=None, tform_512=None, weights_468=None, weights_473=None,shape = None, **landmarks):
+ transformed_vertices = batch_orth_proj(vertices, cam_params)
+ transformed_vertices[:, :, 1:] = -transformed_vertices[:, :, 1:]
+
+ transformed_landmarks = {}
+ for key in landmarks.keys():
+ transformed_landmarks[key] = batch_orth_proj(landmarks[key], cam_params)
+ transformed_landmarks[key][:, :, 1:] = - transformed_landmarks[key][:, :, 1:]
+ transformed_landmarks[key] = transformed_landmarks[key][...,:2]
+
+ # rendered_img = self.render(vertices, transformed_vertices, source_tform, tform_512, weights_468, weights_473,shape)
+ if weights_468 is None:
+ rendered_img = self.render_with_pulid_in_vertices(vertices, transformed_vertices, source_tform, tform_512, shape)
+ else:
+ rendered_img = self.render(vertices, transformed_vertices, source_tform, tform_512, weights_468, weights_473,shape)
+
+ outputs = {
+ 'rendered_img': rendered_img,
+ 'transformed_vertices': transformed_vertices
+ }
+ outputs.update(transformed_landmarks)
+
+ return outputs
+
+ def _calculate_eye_landmarks(self, landmark_list_pixlevel, weights_468, weights_473, source_tform):
+ # [np.array([x_relative, y_relative]),target_point,ref_points] 根据当前的new_landmarks,根据target_point, 利用映射变换,计算出眼部landmarks
+
+ import pdb; pdb.set_trace()
+ pass
+
+ def render_with_pulid_in_vertices(self, vertices, transformed_vertices, source_tform, tform_512, shape):
+ batch_size = vertices.shape[0]
+ ## rasterizer near 0 far 100. move mesh so minz larger than 0
+ transformed_vertices[:,:,2] = transformed_vertices[:,:,2] + 10
+ # import pdb;pdb.set_trace()
+ # 只使用颜色作为attributes
+ colors = self.face_colors.expand(batch_size, -1, -1, -1)
+
+ # # 加载 mediapipe_landmark_embedding 数据
+ lmk_b_coords = self.mediapipe_landmark_embedding['lmk_b_coords']
+ lmk_face_idx = self.mediapipe_landmark_embedding['lmk_face_idx']
+ # import pdb;pdb.set_trace()
+ # 计算 v_selected
+ v_selected = mesh_points_by_barycentric_coordinates(transformed_vertices.detach().cpu().numpy()[0], self.faces.detach().cpu().numpy()[0], lmk_face_idx, lmk_b_coords)
+ # v_selected 增加对应左眼和右眼的8个位置,序号分别是:[4051, 3997, 3965, 3933, 4020],[4597, 4543, 4511, 4479, 4575],得根据transformed_vertices.detach().cpu().numpy()[0]来获取
+ v_selected = np.concatenate([v_selected, transformed_vertices.detach().cpu().numpy()[0][[4543, 4511, 4479, 4575]], transformed_vertices.detach().cpu().numpy()[0][[3997, 3965, 3933, 4020]]], axis=0)
+
+ v_selected_tensor = torch.tensor( np.array(v_selected), dtype=torch.float32).to(transformed_vertices.device)
+ new_landmarks = landmark_pb2.NormalizedLandmarkList()
+ for v in v_selected_tensor:
+ # 将 v 映射到图像坐标
+ img_x = (v[0] + 1) * 0.5 * self.image_size
+ img_y = ((v[1] + 1) * 0.5) * self.image_size
+ # import pdb;pdb.set_trace()
+ point = np.array([img_x.cpu().numpy(), img_y.cpu().numpy(), 1.0])
+ croped_point = np.dot(source_tform.inverse.params, point)
+
+ # original_point = np.dot(tform_512.inverse.params, point)
+ landmark = new_landmarks.landmark.add()
+ landmark.x = croped_point[0]/shape[1]
+ landmark.y = croped_point[1]/shape[0]
+ landmark.z = 1.0
+ # 将 v 映射到图像坐标
+ right_eye_x = (transformed_vertices[0,4597,0] + 1) * 0.5 * self.image_size
+ right_eye_y = (transformed_vertices[0,4597,1] + 1) * 0.5 * self.image_size
+ right_eye_point = np.array([right_eye_x.cpu().numpy(), right_eye_y.cpu().numpy(), 1.0])
+ right_eye_original = np.dot(source_tform.inverse.params, right_eye_point)
+ right_eye_landmarks = right_eye_original[:2]
+
+ left_eye_x = (transformed_vertices[0,4051,0] + 1) * 0.5 * self.image_size
+ left_eye_y = (transformed_vertices[0,4051,1] + 1) * 0.5 * self.image_size
+ left_eye_point = np.array([left_eye_x.cpu().numpy(), left_eye_y.cpu().numpy(), 1.0])
+ left_eye_original = np.dot(source_tform.inverse.params, left_eye_point)
+ left_eye_landmarks = left_eye_original[:2]
+
+ image_new = np.zeros([shape[0],shape[1],3], dtype=np.uint8)
+ self.vis.mp_drawing.draw_landmarks(image=image_new,landmark_list=new_landmarks,connections=self.vis.face_connection_spec.keys(),landmark_drawing_spec=None,connection_drawing_spec=self.vis.face_connection_spec)
+
+ # 直接设置单个像素点的颜色
+ left_point = (int(left_eye_landmarks[0]), int(left_eye_landmarks[1]))
+ right_point = (int(right_eye_landmarks[0]), int(right_eye_landmarks[1]))
+ # import pdb;pdb.set_trace()
+ # 左眼点 - 3x3 区域
+ image_new[left_point[1]-1:left_point[1]+2, left_point[0]-1:left_point[0]+2] = [180, 200, 10] # RGB格式
+ # 右眼点 - 3x3 区域
+ image_new[right_point[1]-1:right_point[1]+2, right_point[0]-1:right_point[0]+2] = [10, 200, 180]
+
+ landmark_58 = new_landmarks.landmark[57] # 因为索引从0开始,所以57表示第58个点
+ x = int(landmark_58.x * shape[1])
+ y = int(landmark_58.y * shape[0])
+ image_new[y-2:y+3, x-2:x+3] = [255, 255, 255] # 设置3x3的白色区域
+
+ return np.copy(image_new)
+
+ def render(self, vertices, transformed_vertices, source_tform, tform_512, weights_468, weights_473, shape):
+ # batch_size = vertices.shape[0]
+ transformed_vertices[:,:,2] += 10 # Z-axis offset
+
+ # colors = self.face_colors.expand(batch_size, -1, -1, -1)
+ # rendering = self.rasterize(transformed_vertices, self.faces.expand(batch_size, -1, -1), colors)
+
+ v_selected = self._calculate_landmark_points(transformed_vertices)
+ v_selected_tensor = torch.tensor(v_selected, dtype=torch.float32, device=transformed_vertices.device) #torch.Size([113, 3])
+ # import pdb; pdb.set_trace()
+ new_landmarks, landmark_list_pixlevel = self._create_landmark_list(v_selected_tensor, source_tform, shape)
+ # 基于weights_468和weights_473,计算眼部landmarks
+ left_eye_point_indices = weights_468[3]
+ right_eye_point_indices = weights_473[3]
+ # 遍历每个索引以找到其在 index_mapping 中的位置
+ left_eye_point_indices = [self.vis.index_mapping.index(idx) for idx in left_eye_point_indices]
+ right_eye_point_indices = [self.vis.index_mapping.index(idx) for idx in right_eye_point_indices]
+
+ left_eye_point = [landmark_list_pixlevel[idx] for idx in left_eye_point_indices]
+ right_eye_point = [landmark_list_pixlevel[idx] for idx in right_eye_point_indices]
+ # import pdb; pdb.set_trace()
+ # weights_468[2].shape = (16, 2)
+ M_affine_left, _ = cv2.estimateAffine2D(np.array(weights_468[2], dtype=np.float32), np.array(left_eye_point, dtype=np.float32))
+ M_affine_right, _ = cv2.estimateAffine2D(np.array(weights_473[2], dtype=np.float32), np.array(right_eye_point, dtype=np.float32))
+
+ # 计算瞳孔点
+ pupil_left_eye = cv2.transform(weights_468[1].reshape(1, 1, 2), M_affine_left).reshape(-1)
+ pupil_right_eye = cv2.transform(weights_473[1].reshape(1, 1, 2), M_affine_right).reshape(-1)
+
+ # left_eye_point, right_eye_point = self._calculate_eye_landmarks(landmark_list_pixlevel, weights_468, weights_473, source_tform)
+ # left_eye_point, right_eye_point = self._process_eye_landmarks(transformed_vertices, source_tform)
+ # import pdb; pdb.set_trace()
+ return self._generate_final_image(new_landmarks, pupil_left_eye, pupil_right_eye, shape)
+ # return self._generate_final_image(new_landmarks, left_eye_point, right_eye_point, shape)
+
+ def _calculate_landmark_points(self, transformed_vertices):
+ lmk_b_coords = self.mediapipe_landmark_embedding['lmk_b_coords']
+ lmk_face_idx = self.mediapipe_landmark_embedding['lmk_face_idx']
+
+ base_points = mesh_points_by_barycentric_coordinates(
+ transformed_vertices.detach().cpu().numpy()[0],
+ self.faces.detach().cpu().numpy()[0],
+ lmk_face_idx, lmk_b_coords
+ )
+
+ RIGHT_EYE_INDICES = [4543, 4511, 4479, 4575]
+ LEFT_EYE_INDICES = [3997, 3965, 3933, 4020]
+ return np.concatenate([
+ base_points,
+ transformed_vertices.detach().cpu().numpy()[0][RIGHT_EYE_INDICES],
+ transformed_vertices.detach().cpu().numpy()[0][LEFT_EYE_INDICES]
+ ], axis=0)
+
+ def _create_landmark_list(self, vertices, transform, shape):
+ landmark_list = landmark_pb2.NormalizedLandmarkList()
+ landmark_list_pixlevel = []
+ for v in vertices:
+ img_x = (v[0] + 1) * 0.5 * self.image_size
+ img_y = (v[1] + 1) * 0.5 * self.image_size
+ projected = np.dot(transform.inverse.params, [img_x.cpu().numpy(), img_y.cpu().numpy(), 1.0])
+ landmark_list_pixlevel.append((projected[0], projected[1]))
+ landmark = landmark_list.landmark.add()
+ landmark.x = projected[0] / shape[1]
+ landmark.y = projected[1] / shape[0]
+ landmark.z = 1.0
+ return landmark_list, landmark_list_pixlevel
+
+ def _process_eye_landmarks(self, vertices, transform):
+ def project_eye_point(vertex_idx):
+ x = (vertices[0, vertex_idx, 0] + 1) * 0.5 * self.image_size
+ y = (vertices[0, vertex_idx, 1] + 1) * 0.5 * self.image_size
+ # import pdb; pdb.set_trace()
+ projected = np.dot(transform.inverse.params, [x.cpu().numpy(), y.cpu().numpy(), 1.0])
+ return (int(projected[0]), int(projected[1]))
+
+ return (
+ project_eye_point(4051), # Left eye index
+ project_eye_point(4597) # Right eye index
+ )
+
+ def _generate_final_image(self, landmarks, left_eye, right_eye, shape):
+ image = np.zeros([shape[0], shape[1], 3], dtype=np.uint8)
+
+ self.vis.mp_drawing.draw_landmarks(
+ image=image,
+ landmark_list=landmarks,
+ connections=self.vis.face_connection_spec.keys(),
+ landmark_drawing_spec=None,
+ connection_drawing_spec=self.vis.face_connection_spec
+ )
+
+ self._draw_eye_markers(image, np.array(left_eye, dtype=np.int32), np.array(right_eye, dtype=np.int32))
+ self._draw_landmark_58(image, landmarks, shape)
+ return np.copy(image)
+
+ def _draw_eye_markers(self, image, left_eye, right_eye):
+ y, x = left_eye[1]-1, left_eye[0]-1
+ image[y:y+3, x:x+3] = [10, 200, 250]
+
+ y, x = right_eye[1]-1, right_eye[0]-1
+ image[y:y+3, x:x+3] = [250, 200, 10]
+
+ def _draw_landmark_58(self, image, landmarks, shape):
+ if len(landmarks.landmark) > 57:
+ point = landmarks.landmark[57]
+ x = int(point.x * shape[1])
+ y = int(point.y * shape[0])
+ image[y-2:y+3, x-2:x+3] = [255, 255, 255]
+
+ def rasterize(self, vertices, faces, attributes=None, h=None, w=None):
+ fixed_vertices = vertices.clone()
+ fixed_vertices[...,:2] = -fixed_vertices[...,:2]
+
+ if h is None and w is None:
+ image_size = self.image_size
+ else:
+ image_size = [h, w]
+ if h>w:
+ fixed_vertices[..., 1] = fixed_vertices[..., 1]*h/w
+ else:
+ fixed_vertices[..., 0] = fixed_vertices[..., 0]*w/h
+ meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
+ meshes_screen,
+ image_size=image_size,
+ blur_radius=0.0,
+ faces_per_pixel=1,
+ bin_size=None,
+ max_faces_per_bin=None,
+ perspective_correct=False,
+ )
+ vismask = (pix_to_face > -1).float()
+ D = attributes.shape[-1]
+ attributes = attributes.clone(); attributes = attributes.view(attributes.shape[0]*attributes.shape[1], 3, attributes.shape[-1])
+ N, H, W, K, _ = bary_coords.shape
+ mask = pix_to_face == -1
+ pix_to_face = pix_to_face.clone()
+ pix_to_face[mask] = 0
+ idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
+ pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
+ pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
+ pixel_vals[mask] = 0 # Replace masked values in output.
+ pixel_vals = pixel_vals[:,:,:,0].permute(0,3,1,2)
+ pixel_vals = torch.cat([pixel_vals, vismask[:,:,:,0][:,None,:,:]], dim=1)
+ return pixel_vals
+
+ def add_SHlight(self, normal_images, sh_coeff):
+ '''
+ sh_coeff: [bz, 9, 3]
+ '''
+ N = normal_images
+ sh = torch.stack([
+ N[:,0]*0.+1., N[:,0], N[:,1], \
+ N[:,2], N[:,0]*N[:,1], N[:,0]*N[:,2],
+ N[:,1]*N[:,2], N[:,0]**2 - N[:,1]**2, 3*(N[:,2]**2) - 1
+ ],
+ 1) # [bz, 9, h, w]
+ sh = sh*self.constant_factor[None,:,None,None]
+ shading = torch.sum(sh_coeff[:,:,:,None,None]*sh[:,:,None,:,:], 1) # [bz, 9, 3, h, w]
+ return shading
+
+ def add_pointlight(self, vertices, normals, lights):
+ '''
+ vertices: [bz, nv, 3]
+ lights: [bz, nlight, 6]
+ returns:
+ shading: [bz, nv, 3]
+ '''
+ light_positions = lights[:,:,:3]; light_intensities = lights[:,:,3:]
+ directions_to_lights = F.normalize(light_positions[:,:,None,:] - vertices[:,None,:,:], dim=3)
+ normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3)
+ shading = normals_dot_lights[:,:,:,None]*light_intensities[:,:,None,:]
+ return shading.mean(1)
+
+
+ def add_directionlight(self, normals, lights):
+ '''
+ normals: [bz, nv, 3]
+ lights: [bz, nlight, 6]
+ returns:
+ shading: [bz, nv, 3]
+ '''
+ light_direction = lights[:,:,:3]; light_intensities = lights[:,:,3:]
+ directions_to_lights = F.normalize(light_direction[:,:,None,:].expand(-1,-1,normals.shape[1],-1), dim=3)
+ normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.)
+ shading = normals_dot_lights[:,:,:,None]*light_intensities[:,:,None,:]
+ return shading.mean(1)
+
+
+
+ def render_multiface(self, vertices, transformed_vertices, faces):
+
+ batch_size = vertices.shape[0]
+
+ light_positions = torch.tensor(
+ [
+ [-1,-1,-1],
+ [1,-1,-1],
+ [-1,+1,-1],
+ [1,+1,-1],
+ [0,0,-1]
+ ]
+ )[None,:,:].expand(batch_size, -1, -1).float()
+
+ light_intensities = torch.ones_like(light_positions).float()*1.7
+ lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device)
+
+ transformed_vertices[:,:,2] = transformed_vertices[:,:,2] + 10
+ normals = vertex_normals(vertices, faces)
+ face_normals = face_vertices(normals, faces)
+
+ colors = torch.tensor([180, 180, 180])[None, None, :].repeat(1, transformed_vertices.shape[1]+1, 1).float()/255.
+ colors = colors.cuda()
+ face_colors = face_vertices(colors, faces[0].unsqueeze(0))
+
+ colors = face_colors.expand(batch_size, -1, -1, -1)
+
+ attributes = torch.cat([colors,
+ face_normals],
+ -1)
+ rendering = self.rasterize(transformed_vertices, faces, attributes)
+
+ albedo_images = rendering[:, :3, :, :]
+
+ normal_images = rendering[:, 3:6, :, :]
+
+ shading = self.add_directionlight(normal_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), lights)
+ shading_images = shading.reshape([batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0,3,1,2).contiguous()
+ shaded_images = albedo_images*shading_images
+
+ return shaded_images
diff --git a/skyreels_a1/src/smirk_encoder.py b/skyreels_a1/src/smirk_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..97a8355e37cf5ee7e29e3950b57ec5faf442c450
--- /dev/null
+++ b/skyreels_a1/src/smirk_encoder.py
@@ -0,0 +1,134 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+import timm
+
+
+def create_backbone(backbone_name, pretrained=True):
+ backbone = timm.create_model(backbone_name,
+ pretrained=pretrained,
+ features_only=True)
+ feature_dim = backbone.feature_info[-1]['num_chs']
+ return backbone, feature_dim
+
+class PoseEncoder(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.encoder, feature_dim = create_backbone('tf_mobilenetv3_small_minimal_100')
+
+ self.pose_cam_layers = nn.Sequential(
+ nn.Linear(feature_dim, 6)
+ )
+
+ self.init_weights()
+
+ def init_weights(self):
+ self.pose_cam_layers[-1].weight.data *= 0.001
+ self.pose_cam_layers[-1].bias.data *= 0.001
+
+ self.pose_cam_layers[-1].weight.data[3] = 0
+ self.pose_cam_layers[-1].bias.data[3] = 7
+
+
+ def forward(self, img):
+ features = self.encoder(img)[-1]
+
+ features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
+
+ outputs = {}
+
+ pose_cam = self.pose_cam_layers(features).reshape(img.size(0), -1)
+ outputs['pose_params'] = pose_cam[...,:3]
+ # import pdb;pdb.set_trace()
+ outputs['cam'] = pose_cam[...,3:]
+
+ return outputs
+
+
+class ShapeEncoder(nn.Module):
+ def __init__(self, n_shape=300) -> None:
+ super().__init__()
+
+ self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')
+
+ self.shape_layers = nn.Sequential(
+ nn.Linear(feature_dim, n_shape)
+ )
+
+ self.init_weights()
+
+
+ def init_weights(self):
+ self.shape_layers[-1].weight.data *= 0
+ self.shape_layers[-1].bias.data *= 0
+
+
+ def forward(self, img):
+ features = self.encoder(img)[-1]
+
+ features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
+
+ parameters = self.shape_layers(features).reshape(img.size(0), -1)
+
+ return {'shape_params': parameters}
+
+
+class ExpressionEncoder(nn.Module):
+ def __init__(self, n_exp=50) -> None:
+ super().__init__()
+
+ self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')
+
+ self.expression_layers = nn.Sequential(
+ nn.Linear(feature_dim, n_exp+2+3) # num expressions + jaw + eyelid
+ )
+
+ self.n_exp = n_exp
+ self.init_weights()
+
+
+ def init_weights(self):
+ self.expression_layers[-1].weight.data *= 0.1
+ self.expression_layers[-1].bias.data *= 0.1
+
+
+ def forward(self, img):
+ features = self.encoder(img)[-1]
+
+ features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
+
+
+ parameters = self.expression_layers(features).reshape(img.size(0), -1)
+
+ outputs = {}
+
+ outputs['expression_params'] = parameters[...,:self.n_exp]
+ outputs['eyelid_params'] = torch.clamp(parameters[...,self.n_exp:self.n_exp+2], 0, 1)
+ outputs['jaw_params'] = torch.cat([F.relu(parameters[...,self.n_exp+2].unsqueeze(-1)),
+ torch.clamp(parameters[...,self.n_exp+3:self.n_exp+5], -.2, .2)], dim=-1)
+
+ return outputs
+
+
+class SmirkEncoder(nn.Module):
+ def __init__(self, n_exp=50, n_shape=300) -> None:
+ super().__init__()
+
+ self.pose_encoder = PoseEncoder()
+
+ self.shape_encoder = ShapeEncoder(n_shape=n_shape)
+
+ self.expression_encoder = ExpressionEncoder(n_exp=n_exp)
+
+ def forward(self, img):
+ pose_outputs = self.pose_encoder(img)
+ shape_outputs = self.shape_encoder(img)
+ expression_outputs = self.expression_encoder(img)
+
+ outputs = {}
+ outputs.update(pose_outputs)
+ outputs.update(shape_outputs)
+ outputs.update(expression_outputs)
+
+ return outputs
diff --git a/skyreels_a1/src/utils/mediapipe_utils.py b/skyreels_a1/src/utils/mediapipe_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b79aebfebf8bcc9b1f7522e4ed5eaacf3be7057
--- /dev/null
+++ b/skyreels_a1/src/utils/mediapipe_utils.py
@@ -0,0 +1,137 @@
+import mediapipe as mp
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+import cv2
+import numpy as np
+import os
+
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import os
+import cv2
+
+# borrowed from https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/vertices_to_faces.py
+def face_vertices(vertices, faces):
+ """
+ :param vertices: [batch size, number of vertices, 3]
+ :param faces: [batch size, number of faces, 3]
+ :return: [batch size, number of faces, 3, 3]
+ """
+ assert (vertices.ndimension() == 3)
+ assert (faces.ndimension() == 3)
+ assert (vertices.shape[0] == faces.shape[0])
+ assert (vertices.shape[2] == 3)
+ assert (faces.shape[2] == 3)
+
+ bs, nv = vertices.shape[:2]
+ bs, nf = faces.shape[:2]
+ device = vertices.device
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+ vertices = vertices.reshape((bs * nv, 3))
+ # pytorch only supports long and byte tensors for indexing
+ return vertices[faces.long()]
+
+def vertex_normals(vertices, faces):
+ """
+ :param vertices: [batch size, number of vertices, 3]
+ :param faces: [batch size, number of faces, 3]
+ :return: [batch size, number of vertices, 3]
+ """
+ assert (vertices.ndimension() == 3)
+ assert (faces.ndimension() == 3)
+ assert (vertices.shape[0] == faces.shape[0])
+ assert (vertices.shape[2] == 3)
+ assert (faces.shape[2] == 3)
+ bs, nv = vertices.shape[:2]
+ bs, nf = faces.shape[:2]
+ device = vertices.device
+ normals = torch.zeros(bs * nv, 3).to(device)
+
+ faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] # expanded faces
+ vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()]
+
+ faces = faces.reshape(-1, 3)
+ vertices_faces = vertices_faces.reshape(-1, 3, 3)
+
+ normals.index_add_(0, faces[:, 1].long(),
+ torch.cross(vertices_faces[:, 2] - vertices_faces[:, 1], vertices_faces[:, 0] - vertices_faces[:, 1]))
+ normals.index_add_(0, faces[:, 2].long(),
+ torch.cross(vertices_faces[:, 0] - vertices_faces[:, 2], vertices_faces[:, 1] - vertices_faces[:, 2]))
+ normals.index_add_(0, faces[:, 0].long(),
+ torch.cross(vertices_faces[:, 1] - vertices_faces[:, 0], vertices_faces[:, 2] - vertices_faces[:, 0]))
+
+ normals = F.normalize(normals, eps=1e-6, dim=1)
+ normals = normals.reshape((bs, nv, 3))
+ # pytorch only supports long and byte tensors for indexing
+ return normals
+
+def batch_orth_proj(X, camera):
+ ''' orthgraphic projection
+ X: 3d vertices, [bz, n_point, 3]
+ camera: scale and translation, [bz, 3], [scale, tx, ty]
+ '''
+ #print('--------')
+ #print(camera[0, 1:].abs())
+ #print(X[0].abs().mean(0))
+
+ camera = camera.clone().view(-1, 1, 3)
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
+ #print(X_trans[0].abs().mean(0))
+ X_trans = torch.cat([X_trans, X[:,:,2:]], 2)
+ Xn = (camera[:, :, 0:1] * X_trans)
+ return Xn
+
+class MP_2_FLAME():
+ """
+ Convert Mediapipe 52 blendshape scores to FLAME's coefficients
+ """
+ def __init__(self, mappings_path):
+ self.bs2exp = np.load(os.path.join(mappings_path, 'bs2exp.npy'))
+ self.bs2pose = np.load(os.path.join(mappings_path, 'bs2pose.npy'))
+ self.bs2eye = np.load(os.path.join(mappings_path, 'bs2eye.npy'))
+
+ def convert(self, blendshape_scores : np.array):
+ # blendshape_scores: [N, 52]
+
+ # Calculate expression, pose, and eye_pose using the mappings
+ exp = blendshape_scores @ self.bs2exp
+ pose = blendshape_scores @ self.bs2pose
+ pose[0, :3] = 0 # we do not support head rotation yet
+ eye_pose = blendshape_scores @ self.bs2eye
+
+ return exp, pose, eye_pose
+
+class MediaPipeUtils:
+ def __init__(self, model_asset_path='pretrained_models/mediapipe/face_landmarker.task', mappings_path='pretrained_models/mediapipe/'):
+ base_options = python.BaseOptions(model_asset_path=model_asset_path)
+ options = vision.FaceLandmarkerOptions(base_options=base_options,
+ output_face_blendshapes=True,
+ output_facial_transformation_matrixes=True,
+ num_faces=1,
+ min_face_detection_confidence=0.1,
+ min_face_presence_confidence=0.1)
+ self.detector = vision.FaceLandmarker.create_from_options(options)
+ self.mp2flame = MP_2_FLAME(mappings_path=mappings_path)
+
+ def run_mediapipe(self, image):
+ image_numpy = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_numpy)
+ detection_result = self.detector.detect(image)
+
+ if len(detection_result.face_landmarks) == 0:
+ print('No face detected')
+ return None
+
+ blend_scores = detection_result.face_blendshapes[0]
+ blend_scores = np.array(list(map(lambda l: l.score, blend_scores)), dtype=np.float32).reshape(1, 52)
+ exp, pose, eye_pose = self.mp2flame.convert(blendshape_scores=blend_scores)
+
+ face_landmarks = detection_result.face_landmarks[0]
+ face_landmarks_numpy = np.zeros((478, 3))
+
+ for i, landmark in enumerate(face_landmarks):
+ face_landmarks_numpy[i] = [landmark.x * image.width, landmark.y * image.height, landmark.z]
+
+ return face_landmarks_numpy, exp, pose, eye_pose