diff --git a/.gitattributes b/.gitattributes index 005b600aa0ca51527676a414cc4da5fc025203d4..9e3a9029db520d01eb517befc51214c7a374298e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text +*.ply filter=lfs diff=lfs merge=lfs -text +*.webp filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/assets/T.ply b/assets/T.ply new file mode 100644 index 0000000000000000000000000000000000000000..56c12d1c09d79d5557eb8421b442ec39820c35c7 --- /dev/null +++ b/assets/T.ply @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:163e3efe355f4c7fe36eb3b55563d1897ac1384c5ab2eb1acfc68700de2dc31b +size 2089367 diff --git a/assets/example_image/T.png b/assets/example_image/T.png index 79af51bc9d711951fbc63be16b7d07c84294355b..861ee434cb123a74d50b843631db47d0646675ed 100644 Binary files a/assets/example_image/T.png and b/assets/example_image/T.png differ diff --git a/assets/example_image/typical_building_building.png b/assets/example_image/typical_building_building.png index 4f9adcf79d4297bb9b906608c23c311f9b8f23d2..515be4e5bb423b92933f1dc11438b570c5f4db95 100644 Binary files a/assets/example_image/typical_building_building.png and b/assets/example_image/typical_building_building.png differ diff --git a/assets/example_image/typical_building_castle.png b/assets/example_image/typical_building_castle.png index 5f5f50733f3b8ed168026340ed679df357ccb9ec..8b4f705e8f3e37eb96cf948eb822b193e47c3bf8 100644 Binary files a/assets/example_image/typical_building_castle.png and b/assets/example_image/typical_building_castle.png differ diff --git a/assets/example_image/typical_building_colorful_cottage.png b/assets/example_image/typical_building_colorful_cottage.png index 94616b19be6a896413c287b3168b3b7886d64a56..d9f451150d723f39a402590c520702e4e7fd8e44 100644 Binary files a/assets/example_image/typical_building_colorful_cottage.png and b/assets/example_image/typical_building_colorful_cottage.png differ diff --git a/assets/example_image/typical_building_maya_pyramid.png b/assets/example_image/typical_building_maya_pyramid.png index 1d87f3c3980f80eee878b4f8ab69e32279a5ea50..d5db764b32e08a93ac098c06aac9329dba743ea9 100644 Binary files a/assets/example_image/typical_building_maya_pyramid.png and b/assets/example_image/typical_building_maya_pyramid.png differ diff --git a/assets/example_image/typical_building_mushroom.png b/assets/example_image/typical_building_mushroom.png index c4db49d4284e6fec83b2ea548e7e85d489711b84..fcc169f1f0eef1725af0ba60b429a5d6c550dfa5 100644 Binary files a/assets/example_image/typical_building_mushroom.png and b/assets/example_image/typical_building_mushroom.png differ diff --git a/assets/example_image/typical_building_space_station.png b/assets/example_image/typical_building_space_station.png index e37a5806d403dccc098d82492d687d71afa36850..bffc6286658add1b747af63e29d42c2735ed37a3 100644 Binary files a/assets/example_image/typical_building_space_station.png and b/assets/example_image/typical_building_space_station.png differ diff --git a/assets/example_image/typical_creature_dragon.png b/assets/example_image/typical_creature_dragon.png index c3fb92ff0400451c69f44fb75156f50db93da6ec..d62c7e18b52a7f1ac603c143abfc092b321e2734 100644 Binary files a/assets/example_image/typical_creature_dragon.png and 
b/assets/example_image/typical_creature_dragon.png differ diff --git a/assets/example_image/typical_creature_elephant.png b/assets/example_image/typical_creature_elephant.png index 6fc3cf1776c66b91e739cad8e839b52074a57e4c..d4f189675bc888238896eab93f4aefa6e2e32a9b 100644 Binary files a/assets/example_image/typical_creature_elephant.png and b/assets/example_image/typical_creature_elephant.png differ diff --git a/assets/example_image/typical_creature_furry.png b/assets/example_image/typical_creature_furry.png index eb4e8d6c6cac1e03a206429eaf7de261c6f14072..58033811f29f4757759eb66a351bbb82228c0f07 100644 Binary files a/assets/example_image/typical_creature_furry.png and b/assets/example_image/typical_creature_furry.png differ diff --git a/assets/example_image/typical_creature_quadruped.png b/assets/example_image/typical_creature_quadruped.png index b246e08e05702051fb22cced1366ab765cd6fbb0..503ff7d93aed42ef8b8ccaa741559007f4627e7e 100644 Binary files a/assets/example_image/typical_creature_quadruped.png and b/assets/example_image/typical_creature_quadruped.png differ diff --git a/assets/example_image/typical_creature_robot_crab.png b/assets/example_image/typical_creature_robot_crab.png index 8b4e10b353e0e9b60634ea272ff8fd9135fdd640..1546f322acc31caeb524a518979afbffe8de197f 100644 Binary files a/assets/example_image/typical_creature_robot_crab.png and b/assets/example_image/typical_creature_robot_crab.png differ diff --git a/assets/example_image/typical_creature_robot_dinosour.png b/assets/example_image/typical_creature_robot_dinosour.png index 7f8f51728fe1fecb0532673756b1601ef46edc2c..b8a802f1f64424db980dc09e35c7be98e526d9dd 100644 Binary files a/assets/example_image/typical_creature_robot_dinosour.png and b/assets/example_image/typical_creature_robot_dinosour.png differ diff --git a/assets/example_image/typical_creature_rock_monster.png b/assets/example_image/typical_creature_rock_monster.png index 29dc243b197d9b3ee4df9355a5f08752ef0b9b9e..8a6987064fddd7162ff55bd9a98c47c32a6b2397 100644 Binary files a/assets/example_image/typical_creature_rock_monster.png and b/assets/example_image/typical_creature_rock_monster.png differ diff --git a/assets/example_image/typical_humanoid_block_robot.png b/assets/example_image/typical_humanoid_block_robot.png index 195212e38e6a8e331b02c2d58728ba41dba429a1..d509f6d257ff3851eae6f12820ec2cf605bec85f 100644 Binary files a/assets/example_image/typical_humanoid_block_robot.png and b/assets/example_image/typical_humanoid_block_robot.png differ diff --git a/assets/example_image/typical_humanoid_dragonborn.png b/assets/example_image/typical_humanoid_dragonborn.png index 61ca2d9e69634c12ee9ae6f7e77f84839df83fdb..88f2d9070bb76eb31cd2b9e3c9677f70998e5d3b 100644 Binary files a/assets/example_image/typical_humanoid_dragonborn.png and b/assets/example_image/typical_humanoid_dragonborn.png differ diff --git a/assets/example_image/typical_humanoid_dwarf.png b/assets/example_image/typical_humanoid_dwarf.png index 16de1631fff3cc42a3a5d6a8b0f638da75ad7b2f..6bcc3945be7038c12c61d72950bab8bcb8475f10 100644 Binary files a/assets/example_image/typical_humanoid_dwarf.png and b/assets/example_image/typical_humanoid_dwarf.png differ diff --git a/assets/example_image/typical_humanoid_goblin.png b/assets/example_image/typical_humanoid_goblin.png index 4e4fe04517801d5722817e8dfaed2af83b31d67e..6f4a8142d9f452b233b6d5b0bd0d6dac5e41f94e 100644 Binary files a/assets/example_image/typical_humanoid_goblin.png and b/assets/example_image/typical_humanoid_goblin.png differ diff --git 
a/assets/example_image/typical_humanoid_mech.png b/assets/example_image/typical_humanoid_mech.png index f0fbbdf6cda5636f517b6e2fa3f20e15e56e3777..e0e07443d760b405979253c85bd28747d686976c 100644 Binary files a/assets/example_image/typical_humanoid_mech.png and b/assets/example_image/typical_humanoid_mech.png differ diff --git a/assets/example_image/typical_misc_crate.png b/assets/example_image/typical_misc_crate.png index c3086f885bf9fc27c398b5bacfb04a65bd7dfbd9..b66a57feffa7117f1f3c2d614d24001fb2611467 100644 Binary files a/assets/example_image/typical_misc_crate.png and b/assets/example_image/typical_misc_crate.png differ diff --git a/assets/example_image/typical_misc_fireplace.png b/assets/example_image/typical_misc_fireplace.png index 82d79bc10346604a8b8b9cc8e2c317e8dc6d8c47..c8f352fbb6b256d56d43febb052b91943b6d4e3d 100644 Binary files a/assets/example_image/typical_misc_fireplace.png and b/assets/example_image/typical_misc_fireplace.png differ diff --git a/assets/example_image/typical_misc_gate.png b/assets/example_image/typical_misc_gate.png index fa77919f9d9faabc26b9287b35c4dd3b4006163e..58f96079232d47f7615a491b95d2fc7cd635ef1d 100644 Binary files a/assets/example_image/typical_misc_gate.png and b/assets/example_image/typical_misc_gate.png differ diff --git a/assets/example_image/typical_misc_lantern.png b/assets/example_image/typical_misc_lantern.png index 4c93f5dea2638a5a169dd1557a36f6d34b57144d..0917aa4be24bec12ba961475efd5bef1c6f8fd2a 100644 Binary files a/assets/example_image/typical_misc_lantern.png and b/assets/example_image/typical_misc_lantern.png differ diff --git a/assets/example_image/typical_misc_magicbook.png b/assets/example_image/typical_misc_magicbook.png index 7dc521a10fda176694c30170811809050f478a66..f268ca7e37851578d888cb26f8057846f9e12e78 100644 Binary files a/assets/example_image/typical_misc_magicbook.png and b/assets/example_image/typical_misc_magicbook.png differ diff --git a/assets/example_image/typical_misc_mailbox.png b/assets/example_image/typical_misc_mailbox.png index b6e8bc50cd270bb7462eee2af7a6d5649ef54cf2..31a5b45556a0da3b7e36301230c7ffa1f15adad1 100644 Binary files a/assets/example_image/typical_misc_mailbox.png and b/assets/example_image/typical_misc_mailbox.png differ diff --git a/assets/example_image/typical_misc_monster_chest.png b/assets/example_image/typical_misc_monster_chest.png index 6d544370fa306138e7dbab3e548d6e05b8ef2317..660a8a9bbec77a967b0263b98900411d69588d71 100644 Binary files a/assets/example_image/typical_misc_monster_chest.png and b/assets/example_image/typical_misc_monster_chest.png differ diff --git a/assets/example_image/typical_misc_paper_machine.png b/assets/example_image/typical_misc_paper_machine.png index a630074dbfe32c53f52f2f27e5b6b3eff8469a9e..db27c6c49446d85e081133a195be4aff90e66f33 100644 Binary files a/assets/example_image/typical_misc_paper_machine.png and b/assets/example_image/typical_misc_paper_machine.png differ diff --git a/assets/example_image/typical_misc_phonograph.png b/assets/example_image/typical_misc_phonograph.png index 668662d741344ac16427259fc966186ef8ca97a9..d2f4c143fa51c16c4a3ab4a0a625319d95d4a8be 100644 Binary files a/assets/example_image/typical_misc_phonograph.png and b/assets/example_image/typical_misc_phonograph.png differ diff --git a/assets/example_image/typical_misc_portal2.png b/assets/example_image/typical_misc_portal2.png index 666daa75fbaf7df55585f7143906d158175be6be..6ab0c8b0784eaa370d566854974a7b4f1e0f4ff7 100644 Binary files a/assets/example_image/typical_misc_portal2.png and 
b/assets/example_image/typical_misc_portal2.png differ diff --git a/assets/example_image/typical_misc_storage_chest.png b/assets/example_image/typical_misc_storage_chest.png index 38f4bd31f8eb62badcc5e1a51d4612e528b4069e..2fcd82551d835903a914c8737b5650373776d363 100644 Binary files a/assets/example_image/typical_misc_storage_chest.png and b/assets/example_image/typical_misc_storage_chest.png differ diff --git a/assets/example_image/typical_misc_telephone.png b/assets/example_image/typical_misc_telephone.png index a0a7d65a300d9f1adc55b2fc36951731f5abb355..58e7a3a434274b58ef0ad567a19ec46e16bd4ba4 100644 Binary files a/assets/example_image/typical_misc_telephone.png and b/assets/example_image/typical_misc_telephone.png differ diff --git a/assets/example_image/typical_misc_television.png b/assets/example_image/typical_misc_television.png index 1d6b5882b42ce532f6a60080ad55bda7053530c0..79042fd48c953ccabc2f41333dfa12942ebbb5cd 100644 Binary files a/assets/example_image/typical_misc_television.png and b/assets/example_image/typical_misc_television.png differ diff --git a/assets/example_image/typical_misc_workbench.png b/assets/example_image/typical_misc_workbench.png index 88024f960ff56aa619b0c496f85de390076bbf5a..c964138ac20d2378c9ba5ecd35eca70f1c728ed2 100644 Binary files a/assets/example_image/typical_misc_workbench.png and b/assets/example_image/typical_misc_workbench.png differ diff --git a/assets/example_image/typical_vehicle_biplane.png b/assets/example_image/typical_vehicle_biplane.png index 7427cad3270d8ed33dad05c7a2ae1b0092b4beb2..4bca4e16ec6453e008d21b5a82dd5667a6ef453a 100644 Binary files a/assets/example_image/typical_vehicle_biplane.png and b/assets/example_image/typical_vehicle_biplane.png differ diff --git a/assets/example_image/typical_vehicle_bulldozer.png b/assets/example_image/typical_vehicle_bulldozer.png index 17ffe389498d9561ef92766de654ef17b5755f60..9fa27c21bb6beca459d19d2cbd536f0a41fdce0d 100644 Binary files a/assets/example_image/typical_vehicle_bulldozer.png and b/assets/example_image/typical_vehicle_bulldozer.png differ diff --git a/assets/example_image/typical_vehicle_cart.png b/assets/example_image/typical_vehicle_cart.png index 137bb4887f3879691ffff21227951790eb1840b4..d8848f3473355f2040c67a803847947e086dc969 100644 Binary files a/assets/example_image/typical_vehicle_cart.png and b/assets/example_image/typical_vehicle_cart.png differ diff --git a/assets/example_image/typical_vehicle_excavator.png b/assets/example_image/typical_vehicle_excavator.png index c434e8b0ab142ecc35caf91d42df5f4541825b8f..3dddee1020e052155d2e8f404d982f45786c5d07 100644 Binary files a/assets/example_image/typical_vehicle_excavator.png and b/assets/example_image/typical_vehicle_excavator.png differ diff --git a/assets/example_image/typical_vehicle_helicopter.png b/assets/example_image/typical_vehicle_helicopter.png index 39c2497d22ea519cddf576c6504954c338f943e4..499cc8e87a468582d3c64b96c30ddb11dbb04f1f 100644 Binary files a/assets/example_image/typical_vehicle_helicopter.png and b/assets/example_image/typical_vehicle_helicopter.png differ diff --git a/assets/example_image/typical_vehicle_locomotive.png b/assets/example_image/typical_vehicle_locomotive.png index dac6a2a2de9e8830bac53d3893aa1d3741916b1a..658e79d82868ae6e86ac108914169deafe02a53b 100644 Binary files a/assets/example_image/typical_vehicle_locomotive.png and b/assets/example_image/typical_vehicle_locomotive.png differ diff --git a/assets/example_image/typical_vehicle_pirate_ship.png 
b/assets/example_image/typical_vehicle_pirate_ship.png index 9eed1529f309c64fc6237caba97631cc1f2bab53..a3baba4b951316dc7e0e940c22ba6f537bd5db14 100644 Binary files a/assets/example_image/typical_vehicle_pirate_ship.png and b/assets/example_image/typical_vehicle_pirate_ship.png differ diff --git a/assets/example_image/weatherworn_misc_paper_machine3.png b/assets/example_image/weatherworn_misc_paper_machine3.png index 46e8a9dc123aaf71a41e994a8e50eabb4e53e721..253a9c5dd331fc1a3d0d66cb5e9afb0b5b82ac46 100644 Binary files a/assets/example_image/weatherworn_misc_paper_machine3.png and b/assets/example_image/weatherworn_misc_paper_machine3.png differ diff --git a/assets/example_multi_image/character_1.png b/assets/example_multi_image/character_1.png index 743117c4458af4a9db717c7c7c6b05b6b08037dc..9f56066845cfd18504ace4924c9a94544d55280c 100644 Binary files a/assets/example_multi_image/character_1.png and b/assets/example_multi_image/character_1.png differ diff --git a/assets/example_multi_image/character_2.png b/assets/example_multi_image/character_2.png index 5fc37f61179b2dc4293a40bec2959ac82fd7503c..32de7843e9458be83c0adbd5ddbd526d3db59c94 100644 Binary files a/assets/example_multi_image/character_2.png and b/assets/example_multi_image/character_2.png differ diff --git a/assets/example_multi_image/character_3.png b/assets/example_multi_image/character_3.png index c6e8cb9fb2ab3e86e749405f96ee026273e1e99b..036b8f12cc2e2744fa3ebf3fba41114b07b340be 100644 Binary files a/assets/example_multi_image/character_3.png and b/assets/example_multi_image/character_3.png differ diff --git a/assets/example_multi_image/mushroom_1.png b/assets/example_multi_image/mushroom_1.png index 982645f790b3c374ccab05f33371b2979b5a4031..ac2f8342feadfab20e91fbbc101a2c4db48a8556 100644 Binary files a/assets/example_multi_image/mushroom_1.png and b/assets/example_multi_image/mushroom_1.png differ diff --git a/assets/example_multi_image/mushroom_2.png b/assets/example_multi_image/mushroom_2.png index 48e961e28e6996bce3259b114da8c6aced35e777..c81fb2dbf291c1898afbbfb61cce8e781ac7c42e 100644 Binary files a/assets/example_multi_image/mushroom_2.png and b/assets/example_multi_image/mushroom_2.png differ diff --git a/assets/example_multi_image/mushroom_3.png b/assets/example_multi_image/mushroom_3.png index 16f2022458343da9a084013d8dba92faad0c2103..67c68bcfccbf331fa132c3c41d8c40b8c0548a31 100644 Binary files a/assets/example_multi_image/mushroom_3.png and b/assets/example_multi_image/mushroom_3.png differ diff --git a/assets/example_multi_image/orangeguy_1.png b/assets/example_multi_image/orangeguy_1.png index 5221021f6bde7d3fcb25fe0002ca3de528e4443d..a89f44e74fd8ed4d7d140b5447315d60266081be 100644 Binary files a/assets/example_multi_image/orangeguy_1.png and b/assets/example_multi_image/orangeguy_1.png differ diff --git a/assets/example_multi_image/orangeguy_2.png b/assets/example_multi_image/orangeguy_2.png index 156aea1f404d7748b73694c34037b1b13d87db66..b2554f8502d33981926afcfa8456dab03571ae63 100644 Binary files a/assets/example_multi_image/orangeguy_2.png and b/assets/example_multi_image/orangeguy_2.png differ diff --git a/assets/example_multi_image/orangeguy_3.png b/assets/example_multi_image/orangeguy_3.png index 0c15598b97acb46301876c4cdde34aacc8580a68..207e70ebc17912569261d5796b942b1d64dd39d7 100644 Binary files a/assets/example_multi_image/orangeguy_3.png and b/assets/example_multi_image/orangeguy_3.png differ diff --git a/assets/example_multi_image/popmart_1.png b/assets/example_multi_image/popmart_1.png index 
81f1838f3a3698441d3d33bf04f0fce5dcdf05f3..f4437187e4b20d0d35fcee7718bdea4aaebe5aa9 100644 Binary files a/assets/example_multi_image/popmart_1.png and b/assets/example_multi_image/popmart_1.png differ diff --git a/assets/example_multi_image/popmart_2.png b/assets/example_multi_image/popmart_2.png index ac6fdf3c53aa95dd0763a4e865ba7583768f2ac8..3747f78e72c5d3199ab077511b1e1897356a7f85 100644 Binary files a/assets/example_multi_image/popmart_2.png and b/assets/example_multi_image/popmart_2.png differ diff --git a/assets/example_multi_image/popmart_3.png b/assets/example_multi_image/popmart_3.png index c83ea960e3aa151151260427d10fc5671619cbee..230775ef5b47d221e106aa2395cb70af93befa49 100644 Binary files a/assets/example_multi_image/popmart_3.png and b/assets/example_multi_image/popmart_3.png differ diff --git a/assets/example_multi_image/rabbit_1.png b/assets/example_multi_image/rabbit_1.png index 0cd5708a752cb3951d5edb41165d23f1246955e1..f8389134e05238573f19962c2051a4a287d5fe81 100644 Binary files a/assets/example_multi_image/rabbit_1.png and b/assets/example_multi_image/rabbit_1.png differ diff --git a/assets/example_multi_image/rabbit_2.png b/assets/example_multi_image/rabbit_2.png index 95492498199904f299a88ba06a80fc0742f874f9..03f5d8d502fc6a10a1a915462272cc71075bc51e 100644 Binary files a/assets/example_multi_image/rabbit_2.png and b/assets/example_multi_image/rabbit_2.png differ diff --git a/assets/example_multi_image/rabbit_3.png b/assets/example_multi_image/rabbit_3.png index a83285e29702f9680cd5530fa2ab9526e7812352..4e0f07e8a188e456451da5961b866b74d98d104f 100644 Binary files a/assets/example_multi_image/rabbit_3.png and b/assets/example_multi_image/rabbit_3.png differ diff --git a/assets/example_multi_image/tiger_1.png b/assets/example_multi_image/tiger_1.png index c4f87f93b63873a81c1e2bda18937a165d49a773..bfd8a2cb78ac39ff4fbd46555e0af273d2319a27 100644 Binary files a/assets/example_multi_image/tiger_1.png and b/assets/example_multi_image/tiger_1.png differ diff --git a/assets/example_multi_image/tiger_2.png b/assets/example_multi_image/tiger_2.png index 8fb9818ab2c6920be720811c04babb4024372c40..23ba80ad8bdd22f161e59b6d7b0a17e2f9082168 100644 Binary files a/assets/example_multi_image/tiger_2.png and b/assets/example_multi_image/tiger_2.png differ diff --git a/assets/example_multi_image/tiger_3.png b/assets/example_multi_image/tiger_3.png index 53689b9b2e3deeeb968628f6fcb636cf6d1223a4..0e0831948ddbc2ecb108a141ad7345ea55e408b4 100644 Binary files a/assets/example_multi_image/tiger_3.png and b/assets/example_multi_image/tiger_3.png differ diff --git a/assets/example_multi_image/yoimiya_1.png b/assets/example_multi_image/yoimiya_1.png index da323f970a288542665e316a8447b1cccf54998d..bb6519735180acb52c952cfa99be6f57d859941c 100644 Binary files a/assets/example_multi_image/yoimiya_1.png and b/assets/example_multi_image/yoimiya_1.png differ diff --git a/assets/example_multi_image/yoimiya_2.png b/assets/example_multi_image/yoimiya_2.png index d38d854fc264025034ded35f3363e95cf509a0b4..e2f03461b9f43ab58ad7810371a256744c215d70 100644 Binary files a/assets/example_multi_image/yoimiya_2.png and b/assets/example_multi_image/yoimiya_2.png differ diff --git a/assets/example_multi_image/yoimiya_3.png b/assets/example_multi_image/yoimiya_3.png index f2c8a7ca4085badda0fede3e4ca92dac1565ed24..54b268dfe229dcebc192e458b0f5c30ef34458d1 100644 Binary files a/assets/example_multi_image/yoimiya_3.png and b/assets/example_multi_image/yoimiya_3.png differ diff --git a/assets/logo.webp b/assets/logo.webp new file 
mode 100644 index 0000000000000000000000000000000000000000..aaf832024623414cb6336b46d5b8b27b7b7b039a --- /dev/null +++ b/assets/logo.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1548a7b7f6b0fb3c06091529bb5052f0ee9a119eb4e1a014325d6561e9b9f2d1 +size 1403066 diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..5dae838acc6418b48af635c9bceefaa84e5ea446 --- /dev/null +++ b/assets/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a83608f6f4ae71eb7b96965a16978b5f5fb0d594a34bf1066c105f93af71c4d5 +size 2098442 diff --git a/configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json b/configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..9af591f33cbed5d6f44bd99ea174da7f97f4c741 --- /dev/null +++ b/configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json @@ -0,0 +1,102 @@ +{ + "models": { + "denoiser": { + "name": "ElasticSLatFlowModel", + "args": { + "resolution": 64, + "in_channels": 8, + "out_channels": 8, + "model_channels": 1024, + "cond_channels": 1024, + "num_blocks": 24, + "num_heads": 16, + "mlp_ratio": 4, + "patch_size": 2, + "num_io_res_blocks": 2, + "io_block_channels": [128], + "pe_mode": "ape", + "qk_rms_norm": true, + "use_fp16": true + } + } + }, + "dataset": { + "name": "ImageConditionedSLat", + "args": { + "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16", + "min_aesthetic_score": 4.5, + "max_num_voxels": 32768, + "image_size": 518, + "normalization": { + "mean": [ + -2.1687545776367188, + -0.004347046371549368, + -0.13352349400520325, + -0.08418072760105133, + -0.5271206498146057, + 0.7238689064979553, + -1.1414450407028198, + 1.2039363384246826 + ], + "std": [ + 2.377650737762451, + 2.386378288269043, + 2.124418020248413, + 2.1748552322387695, + 2.663944721221924, + 2.371192216873169, + 2.6217446327209473, + 2.684523105621338 + ] + }, + "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16" + } + }, + "trainer": { + "name": "ImageConditionedSparseFlowMatchingCFGTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 8, + "batch_split": 4, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 0.0001, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "elastic": { + "name": "LinearMemoryController", + "args": { + "target_ratio": 0.75, + "max_mem_ratio_start": 0.5 + } + }, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "p_uncond": 0.1, + "t_schedule": { + "name": "logitNormal", + "args": { + "mean": 1.0, + "std": 1.0 + } + }, + "sigma_min": 1e-5, + "image_cond_model": "dinov2_vitl14_reg" + } + } +} \ No newline at end of file diff --git a/configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json b/configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..5f780999aa8e880cbd79d5e1c6eeeca39de1de55 --- /dev/null +++ b/configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json @@ -0,0 +1,101 @@ +{ + "models": { + "denoiser": { + "name": "ElasticSLatFlowModel", + "args": { + "resolution": 64, + "in_channels": 8, + "out_channels": 8, + "model_channels": 768, + "cond_channels": 768, + "num_blocks": 12, + "num_heads": 12, + "mlp_ratio": 4, + "patch_size": 2, + 
"num_io_res_blocks": 2, + "io_block_channels": [128], + "pe_mode": "ape", + "qk_rms_norm": true, + "use_fp16": true + } + } + }, + "dataset": { + "name": "TextConditionedSLat", + "args": { + "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16", + "min_aesthetic_score": 4.5, + "max_num_voxels": 32768, + "normalization": { + "mean": [ + -2.1687545776367188, + -0.004347046371549368, + -0.13352349400520325, + -0.08418072760105133, + -0.5271206498146057, + 0.7238689064979553, + -1.1414450407028198, + 1.2039363384246826 + ], + "std": [ + 2.377650737762451, + 2.386378288269043, + 2.124418020248413, + 2.1748552322387695, + 2.663944721221924, + 2.371192216873169, + 2.6217446327209473, + 2.684523105621338 + ] + }, + "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16" + } + }, + "trainer": { + "name": "TextConditionedSparseFlowMatchingCFGTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 16, + "batch_split": 4, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 0.0001, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "elastic": { + "name": "LinearMemoryController", + "args": { + "target_ratio": 0.75, + "max_mem_ratio_start": 0.5 + } + }, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "p_uncond": 0.1, + "t_schedule": { + "name": "logitNormal", + "args": { + "mean": 1.0, + "std": 1.0 + } + }, + "sigma_min": 1e-5, + "text_cond_model": "openai/clip-vit-large-patch14" + } + } +} \ No newline at end of file diff --git a/configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json b/configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..7e970a41146f40e7dfe75bb2b522104671d93486 --- /dev/null +++ b/configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json @@ -0,0 +1,101 @@ +{ + "models": { + "denoiser": { + "name": "ElasticSLatFlowModel", + "args": { + "resolution": 64, + "in_channels": 8, + "out_channels": 8, + "model_channels": 1024, + "cond_channels": 768, + "num_blocks": 24, + "num_heads": 16, + "mlp_ratio": 4, + "patch_size": 2, + "num_io_res_blocks": 2, + "io_block_channels": [128], + "pe_mode": "ape", + "qk_rms_norm": true, + "use_fp16": true + } + } + }, + "dataset": { + "name": "TextConditionedSLat", + "args": { + "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16", + "min_aesthetic_score": 4.5, + "max_num_voxels": 32768, + "normalization": { + "mean": [ + -2.1687545776367188, + -0.004347046371549368, + -0.13352349400520325, + -0.08418072760105133, + -0.5271206498146057, + 0.7238689064979553, + -1.1414450407028198, + 1.2039363384246826 + ], + "std": [ + 2.377650737762451, + 2.386378288269043, + 2.124418020248413, + 2.1748552322387695, + 2.663944721221924, + 2.371192216873169, + 2.6217446327209473, + 2.684523105621338 + ] + }, + "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16" + } + }, + "trainer": { + "name": "TextConditionedSparseFlowMatchingCFGTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 8, + "batch_split": 4, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 0.0001, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "elastic": { + "name": "LinearMemoryController", + "args": { + "target_ratio": 0.75, + 
"max_mem_ratio_start": 0.5 + } + }, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "p_uncond": 0.1, + "t_schedule": { + "name": "logitNormal", + "args": { + "mean": 1.0, + "std": 1.0 + } + }, + "sigma_min": 1e-5, + "text_cond_model": "openai/clip-vit-large-patch14" + } + } +} \ No newline at end of file diff --git a/configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json b/configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..2aabe40000fcf4c38d2f1942282356a0fd07fbef --- /dev/null +++ b/configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json @@ -0,0 +1,101 @@ +{ + "models": { + "denoiser": { + "name": "ElasticSLatFlowModel", + "args": { + "resolution": 64, + "in_channels": 8, + "out_channels": 8, + "model_channels": 1280, + "cond_channels": 768, + "num_blocks": 28, + "num_heads": 16, + "mlp_ratio": 4, + "patch_size": 2, + "num_io_res_blocks": 3, + "io_block_channels": [256], + "pe_mode": "ape", + "qk_rms_norm": true, + "use_fp16": true + } + } + }, + "dataset": { + "name": "TextConditionedSLat", + "args": { + "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16", + "min_aesthetic_score": 4.5, + "max_num_voxels": 32768, + "normalization": { + "mean": [ + -2.1687545776367188, + -0.004347046371549368, + -0.13352349400520325, + -0.08418072760105133, + -0.5271206498146057, + 0.7238689064979553, + -1.1414450407028198, + 1.2039363384246826 + ], + "std": [ + 2.377650737762451, + 2.386378288269043, + 2.124418020248413, + 2.1748552322387695, + 2.663944721221924, + 2.371192216873169, + 2.6217446327209473, + 2.684523105621338 + ] + }, + "pretrained_slat_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16" + } + }, + "trainer": { + "name": "TextConditionedSparseFlowMatchingCFGTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 4, + "batch_split": 4, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 0.0001, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "elastic": { + "name": "LinearMemoryController", + "args": { + "target_ratio": 0.75, + "max_mem_ratio_start": 0.5 + } + }, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "p_uncond": 0.1, + "t_schedule": { + "name": "logitNormal", + "args": { + "mean": 1.0, + "std": 1.0 + } + }, + "sigma_min": 1e-5, + "text_cond_model": "openai/clip-vit-large-patch14" + } + } +} \ No newline at end of file diff --git a/configs/generation/ss_flow_img_dit_L_16l8_fp16.json b/configs/generation/ss_flow_img_dit_L_16l8_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..8fcda2d87895260fbb26f1a1f073a9e9cfdb5ab6 --- /dev/null +++ b/configs/generation/ss_flow_img_dit_L_16l8_fp16.json @@ -0,0 +1,70 @@ +{ + "models": { + "denoiser": { + "name": "SparseStructureFlowModel", + "args": { + "resolution": 16, + "in_channels": 8, + "out_channels": 8, + "model_channels": 1024, + "cond_channels": 1024, + "num_blocks": 24, + "num_heads": 16, + "mlp_ratio": 4, + "patch_size": 1, + "pe_mode": "ape", + "qk_rms_norm": true, + "use_fp16": true + } + } + }, + "dataset": { + "name": "ImageConditionedSparseStructureLatent", + "args": { + "latent_model": "ss_enc_conv3d_16l8_fp16", + "min_aesthetic_score": 4.5, + 
"image_size": 518, + "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16" + } + }, + "trainer": { + "name": "ImageConditionedFlowMatchingCFGTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 8, + "batch_split": 1, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 0.0001, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "p_uncond": 0.1, + "t_schedule": { + "name": "logitNormal", + "args": { + "mean": 1.0, + "std": 1.0 + } + }, + "sigma_min": 1e-5, + "image_cond_model": "dinov2_vitl14_reg" + } + } +} \ No newline at end of file diff --git a/configs/generation/ss_flow_txt_dit_B_16l8_fp16.json b/configs/generation/ss_flow_txt_dit_B_16l8_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..be8040e8301de2fb1442be4cd3f6381f579d3871 --- /dev/null +++ b/configs/generation/ss_flow_txt_dit_B_16l8_fp16.json @@ -0,0 +1,69 @@ +{ + "models": { + "denoiser": { + "name": "SparseStructureFlowModel", + "args": { + "resolution": 16, + "in_channels": 8, + "out_channels": 8, + "model_channels": 768, + "cond_channels": 768, + "num_blocks": 12, + "num_heads": 12, + "mlp_ratio": 4, + "patch_size": 1, + "pe_mode": "ape", + "qk_rms_norm": true, + "use_fp16": true + } + } + }, + "dataset": { + "name": "TextConditionedSparseStructureLatent", + "args": { + "latent_model": "ss_enc_conv3d_16l8_fp16", + "min_aesthetic_score": 4.5, + "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16" + } + }, + "trainer": { + "name": "TextConditionedFlowMatchingCFGTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 16, + "batch_split": 1, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 0.0001, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "p_uncond": 0.1, + "t_schedule": { + "name": "logitNormal", + "args": { + "mean": 1.0, + "std": 1.0 + } + }, + "sigma_min": 1e-5, + "text_cond_model": "openai/clip-vit-large-patch14" + } + } +} \ No newline at end of file diff --git a/configs/generation/ss_flow_txt_dit_L_16l8_fp16.json b/configs/generation/ss_flow_txt_dit_L_16l8_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..1f11a3803f1503966b07b2056852a13ebc79eaff --- /dev/null +++ b/configs/generation/ss_flow_txt_dit_L_16l8_fp16.json @@ -0,0 +1,69 @@ +{ + "models": { + "denoiser": { + "name": "SparseStructureFlowModel", + "args": { + "resolution": 16, + "in_channels": 8, + "out_channels": 8, + "model_channels": 1024, + "cond_channels": 768, + "num_blocks": 24, + "num_heads": 16, + "mlp_ratio": 4, + "patch_size": 1, + "pe_mode": "ape", + "qk_rms_norm": true, + "use_fp16": true + } + } + }, + "dataset": { + "name": "TextConditionedSparseStructureLatent", + "args": { + "latent_model": "ss_enc_conv3d_16l8_fp16", + "min_aesthetic_score": 4.5, + "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16" + } + }, + "trainer": { + "name": "TextConditionedFlowMatchingCFGTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 8, + "batch_split": 1, + "optimizer": { + 
"name": "AdamW", + "args": { + "lr": 0.0001, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "p_uncond": 0.1, + "t_schedule": { + "name": "logitNormal", + "args": { + "mean": 1.0, + "std": 1.0 + } + }, + "sigma_min": 1e-5, + "text_cond_model": "openai/clip-vit-large-patch14" + } + } +} \ No newline at end of file diff --git a/configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json b/configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..0d4bebea08c3fcd2bae7a733bbfaaecc992fc3de --- /dev/null +++ b/configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json @@ -0,0 +1,70 @@ +{ + "models": { + "denoiser": { + "name": "SparseStructureFlowModel", + "args": { + "resolution": 16, + "in_channels": 8, + "out_channels": 8, + "model_channels": 1280, + "cond_channels": 768, + "num_blocks": 28, + "num_heads": 16, + "mlp_ratio": 4, + "patch_size": 1, + "pe_mode": "ape", + "qk_rms_norm": true, + "qk_rms_norm_cross": true, + "use_fp16": true + } + } + }, + "dataset": { + "name": "TextConditionedSparseStructureLatent", + "args": { + "latent_model": "ss_enc_conv3d_16l8_fp16", + "min_aesthetic_score": 4.5, + "pretrained_ss_dec": "JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16" + } + }, + "trainer": { + "name": "TextConditionedFlowMatchingCFGTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 4, + "batch_split": 1, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 0.0001, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "p_uncond": 0.1, + "t_schedule": { + "name": "logitNormal", + "args": { + "mean": 1.0, + "std": 1.0 + } + }, + "sigma_min": 1e-5, + "text_cond_model": "openai/clip-vit-large-patch14" + } + } +} \ No newline at end of file diff --git a/configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json b/configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..c86b42cbe94e6acda72c0436d57dd50189ef927d --- /dev/null +++ b/configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json @@ -0,0 +1,73 @@ +{ + "models": { + "decoder": { + "name": "ElasticSLatMeshDecoder", + "args": { + "resolution": 64, + "model_channels": 768, + "latent_channels": 8, + "num_blocks": 12, + "num_heads": 12, + "mlp_ratio": 4, + "attn_mode": "swin", + "window_size": 8, + "use_fp16": true, + "representation_config": { + "use_color": true + } + } + } + }, + "dataset": { + "name": "Slat2RenderGeo", + "args": { + "image_size": 512, + "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16", + "min_aesthetic_score": 4.5, + "max_num_voxels": 32768 + } + }, + "trainer": { + "name": "SLatVaeMeshDecoderTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 4, + "batch_split": 4, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 1e-4, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "elastic": { + "name": "LinearMemoryController", + "args": { + "target_ratio": 0.75, + "max_mem_ratio_start": 0.5 + } + }, + "grad_clip": { 
+ "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "lambda_ssim": 0.2, + "lambda_lpips": 0.2, + "lambda_tsdf": 0.01, + "lambda_depth": 10.0, + "lambda_color": 0.1, + "depth_loss_type": "smooth_l1" + } + } +} \ No newline at end of file diff --git a/configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json b/configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..5a7aed900f86e946ef671abe8ecde047827b4f6c --- /dev/null +++ b/configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json @@ -0,0 +1,71 @@ +{ + "models": { + "decoder": { + "name": "ElasticSLatRadianceFieldDecoder", + "args": { + "resolution": 64, + "model_channels": 768, + "latent_channels": 8, + "num_blocks": 12, + "num_heads": 12, + "mlp_ratio": 4, + "attn_mode": "swin", + "window_size": 8, + "use_fp16": true, + "representation_config": { + "rank": 16, + "dim": 8 + } + } + } + }, + "dataset": { + "name": "SLat2Render", + "args": { + "image_size": 512, + "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16", + "min_aesthetic_score": 4.5, + "max_num_voxels": 32768 + } + }, + "trainer": { + "name": "SLatVaeRadianceFieldDecoderTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 4, + "batch_split": 2, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 1e-4, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "elastic": { + "name": "LinearMemoryController", + "args": { + "target_ratio": 0.75, + "max_mem_ratio_start": 0.5 + } + }, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "loss_type": "l1", + "lambda_ssim": 0.2, + "lambda_lpips": 0.2 + } + } +} \ No newline at end of file diff --git a/configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json b/configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..d012b50bc0243acf22655ad1b4c79e99754f9ac3 --- /dev/null +++ b/configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json @@ -0,0 +1,105 @@ +{ + "models": { + "encoder": { + "name": "ElasticSLatEncoder", + "args": { + "resolution": 64, + "in_channels": 1024, + "model_channels": 768, + "latent_channels": 8, + "num_blocks": 12, + "num_heads": 12, + "mlp_ratio": 4, + "attn_mode": "swin", + "window_size": 8, + "use_fp16": true + } + }, + "decoder": { + "name": "ElasticSLatGaussianDecoder", + "args": { + "resolution": 64, + "model_channels": 768, + "latent_channels": 8, + "num_blocks": 12, + "num_heads": 12, + "mlp_ratio": 4, + "attn_mode": "swin", + "window_size": 8, + "use_fp16": true, + "representation_config": { + "lr": { + "_xyz": 1.0, + "_features_dc": 1.0, + "_opacity": 1.0, + "_scaling": 1.0, + "_rotation": 0.1 + }, + "perturb_offset": true, + "voxel_size": 1.5, + "num_gaussians": 32, + "2d_filter_kernel_size": 0.1, + "3d_filter_kernel_size": 9e-4, + "scaling_bias": 4e-3, + "opacity_bias": 0.1, + "scaling_activation": "softplus" + } + } + } + }, + "dataset": { + "name": "SparseFeat2Render", + "args": { + "image_size": 512, + "model": "dinov2_vitl14_reg", + "resolution": 64, + "min_aesthetic_score": 4.5, + "max_num_voxels": 32768 + } + }, + "trainer": { + "name": "SLatVaeGaussianTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 4, + "batch_split": 2, + "optimizer": { + "name": "AdamW", 
+ "args": { + "lr": 1e-4, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "elastic": { + "name": "LinearMemoryController", + "args": { + "target_ratio": 0.75, + "max_mem_ratio_start": 0.5 + } + }, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "loss_type": "l1", + "lambda_ssim": 0.2, + "lambda_lpips": 0.2, + "lambda_kl": 1e-06, + "regularizations": { + "lambda_vol": 10000.0, + "lambda_opacity": 0.001 + } + } + } +} \ No newline at end of file diff --git a/configs/vae/ss_vae_conv3d_16l8_fp16.json b/configs/vae/ss_vae_conv3d_16l8_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..3847cb8cc87000c31f5f29855c958cd92ee4eac8 --- /dev/null +++ b/configs/vae/ss_vae_conv3d_16l8_fp16.json @@ -0,0 +1,65 @@ +{ + "models": { + "encoder": { + "name": "SparseStructureEncoder", + "args": { + "in_channels": 1, + "latent_channels": 8, + "num_res_blocks": 2, + "num_res_blocks_middle": 2, + "channels": [32, 128, 512], + "use_fp16": true + } + }, + "decoder": { + "name": "SparseStructureDecoder", + "args": { + "out_channels": 1, + "latent_channels": 8, + "num_res_blocks": 2, + "num_res_blocks_middle": 2, + "channels": [512, 128, 32], + "use_fp16": true + } + } + }, + "dataset": { + "name": "SparseStructure", + "args": { + "resolution": 64, + "min_aesthetic_score": 4.5 + } + }, + "trainer": { + "name": "SparseStructureVaeTrainer", + "args": { + "max_steps": 1000000, + "batch_size_per_gpu": 4, + "batch_split": 1, + "optimizer": { + "name": "AdamW", + "args": { + "lr": 1e-4, + "weight_decay": 0.0 + } + }, + "ema_rate": [ + 0.9999 + ], + "fp16_mode": "inflat_all", + "fp16_scale_growth": 0.001, + "grad_clip": { + "name": "AdaptiveGradClipper", + "args": { + "max_norm": 1.0, + "clip_percentile": 95 + } + }, + "i_log": 500, + "i_sample": 10000, + "i_save": 10000, + "loss_type": "dice", + "lambda_kl": 0.001 + } + } +} \ No newline at end of file diff --git a/dataset_toolkits/blender_script/io_scene_usdz.zip b/dataset_toolkits/blender_script/io_scene_usdz.zip new file mode 100644 index 0000000000000000000000000000000000000000..153a9d97245cfbf98c2c53869987d82f1396922e --- /dev/null +++ b/dataset_toolkits/blender_script/io_scene_usdz.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec07ab6125fe0a021ed08c64169eceda126330401aba3d494d5203d26ac4b093 +size 34685 diff --git a/dataset_toolkits/blender_script/render.py b/dataset_toolkits/blender_script/render.py new file mode 100644 index 0000000000000000000000000000000000000000..1fbd5867186d6ebd86cff6e1f98f1a2e1d375e3c --- /dev/null +++ b/dataset_toolkits/blender_script/render.py @@ -0,0 +1,528 @@ +import argparse, sys, os, math, re, glob +from typing import * +import bpy +from mathutils import Vector, Matrix +import numpy as np +import json +import glob + + +"""=============== BLENDER ===============""" + +IMPORT_FUNCTIONS: Dict[str, Callable] = { + "obj": bpy.ops.import_scene.obj, + "glb": bpy.ops.import_scene.gltf, + "gltf": bpy.ops.import_scene.gltf, + "usd": bpy.ops.import_scene.usd, + "fbx": bpy.ops.import_scene.fbx, + "stl": bpy.ops.import_mesh.stl, + "usda": bpy.ops.import_scene.usda, + "dae": bpy.ops.wm.collada_import, + "ply": bpy.ops.import_mesh.ply, + "abc": bpy.ops.wm.alembic_import, + "blend": bpy.ops.wm.append, +} + +EXT = { + 'PNG': 'png', + 'JPEG': 'jpg', + 'OPEN_EXR': 'exr', + 'TIFF': 'tiff', + 'BMP': 'bmp', + 
'HDR': 'hdr', + 'TARGA': 'tga' +} + +def init_render(engine='CYCLES', resolution=512, geo_mode=False): + bpy.context.scene.render.engine = engine + bpy.context.scene.render.resolution_x = resolution + bpy.context.scene.render.resolution_y = resolution + bpy.context.scene.render.resolution_percentage = 100 + bpy.context.scene.render.image_settings.file_format = 'PNG' + bpy.context.scene.render.image_settings.color_mode = 'RGBA' + bpy.context.scene.render.film_transparent = True + + bpy.context.scene.cycles.device = 'GPU' + bpy.context.scene.cycles.samples = 128 if not geo_mode else 1 + bpy.context.scene.cycles.filter_type = 'BOX' + bpy.context.scene.cycles.filter_width = 1 + bpy.context.scene.cycles.diffuse_bounces = 1 + bpy.context.scene.cycles.glossy_bounces = 1 + bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0 + bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1 + bpy.context.scene.cycles.use_denoising = True + + bpy.context.preferences.addons['cycles'].preferences.get_devices() + bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA' + +def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False): + if not any([save_depth, save_normal, save_albedo, save_mist]): + return {}, {} + outputs = {} + spec_nodes = {} + + bpy.context.scene.use_nodes = True + bpy.context.scene.view_layers['View Layer'].use_pass_z = save_depth + bpy.context.scene.view_layers['View Layer'].use_pass_normal = save_normal + bpy.context.scene.view_layers['View Layer'].use_pass_diffuse_color = save_albedo + bpy.context.scene.view_layers['View Layer'].use_pass_mist = save_mist + + nodes = bpy.context.scene.node_tree.nodes + links = bpy.context.scene.node_tree.links + for n in nodes: + nodes.remove(n) + + render_layers = nodes.new('CompositorNodeRLayers') + + if save_depth: + depth_file_output = nodes.new('CompositorNodeOutputFile') + depth_file_output.base_path = '' + depth_file_output.file_slots[0].use_node_format = True + depth_file_output.format.file_format = 'PNG' + depth_file_output.format.color_depth = '16' + depth_file_output.format.color_mode = 'BW' + # Remap to 0-1 + map = nodes.new(type="CompositorNodeMapRange") + map.inputs[1].default_value = 0 # (min value you will be getting) + map.inputs[2].default_value = 10 # (max value you will be getting) + map.inputs[3].default_value = 0 # (min value you will map to) + map.inputs[4].default_value = 1 # (max value you will map to) + + links.new(render_layers.outputs['Depth'], map.inputs[0]) + links.new(map.outputs[0], depth_file_output.inputs[0]) + + outputs['depth'] = depth_file_output + spec_nodes['depth_map'] = map + + if save_normal: + normal_file_output = nodes.new('CompositorNodeOutputFile') + normal_file_output.base_path = '' + normal_file_output.file_slots[0].use_node_format = True + normal_file_output.format.file_format = 'OPEN_EXR' + normal_file_output.format.color_mode = 'RGB' + normal_file_output.format.color_depth = '16' + + links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0]) + + outputs['normal'] = normal_file_output + + if save_albedo: + albedo_file_output = nodes.new('CompositorNodeOutputFile') + albedo_file_output.base_path = '' + albedo_file_output.file_slots[0].use_node_format = True + albedo_file_output.format.file_format = 'PNG' + albedo_file_output.format.color_mode = 'RGBA' + albedo_file_output.format.color_depth = '8' + + alpha_albedo = nodes.new('CompositorNodeSetAlpha') + + links.new(render_layers.outputs['DiffCol'], 
alpha_albedo.inputs['Image']) + links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha']) + links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0]) + + outputs['albedo'] = albedo_file_output + + if save_mist: + bpy.data.worlds['World'].mist_settings.start = 0 + bpy.data.worlds['World'].mist_settings.depth = 10 + + mist_file_output = nodes.new('CompositorNodeOutputFile') + mist_file_output.base_path = '' + mist_file_output.file_slots[0].use_node_format = True + mist_file_output.format.file_format = 'PNG' + mist_file_output.format.color_mode = 'BW' + mist_file_output.format.color_depth = '16' + + links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0]) + + outputs['mist'] = mist_file_output + + return outputs, spec_nodes + +def init_scene() -> None: + """Resets the scene to a clean state. + + Returns: + None + """ + # delete everything + for obj in bpy.data.objects: + bpy.data.objects.remove(obj, do_unlink=True) + + # delete all the materials + for material in bpy.data.materials: + bpy.data.materials.remove(material, do_unlink=True) + + # delete all the textures + for texture in bpy.data.textures: + bpy.data.textures.remove(texture, do_unlink=True) + + # delete all the images + for image in bpy.data.images: + bpy.data.images.remove(image, do_unlink=True) + +def init_camera(): + cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera')) + bpy.context.collection.objects.link(cam) + bpy.context.scene.camera = cam + cam.data.sensor_height = cam.data.sensor_width = 32 + cam_constraint = cam.constraints.new(type='TRACK_TO') + cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' + cam_constraint.up_axis = 'UP_Y' + cam_empty = bpy.data.objects.new("Empty", None) + cam_empty.location = (0, 0, 0) + bpy.context.scene.collection.objects.link(cam_empty) + cam_constraint.target = cam_empty + return cam + +def init_lighting(): + # Clear existing lights + bpy.ops.object.select_all(action="DESELECT") + bpy.ops.object.select_by_type(type="LIGHT") + bpy.ops.object.delete() + + # Create key light + default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT")) + bpy.context.collection.objects.link(default_light) + default_light.data.energy = 1000 + default_light.location = (4, 1, 6) + default_light.rotation_euler = (0, 0, 0) + + # create top light + top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA")) + bpy.context.collection.objects.link(top_light) + top_light.data.energy = 10000 + top_light.location = (0, 0, 10) + top_light.scale = (100, 100, 100) + + # create bottom light + bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA")) + bpy.context.collection.objects.link(bottom_light) + bottom_light.data.energy = 1000 + bottom_light.location = (0, 0, -10) + bottom_light.rotation_euler = (0, 0, 0) + + return { + "default_light": default_light, + "top_light": top_light, + "bottom_light": bottom_light + } + + +def load_object(object_path: str) -> None: + """Loads a model with a supported file extension into the scene. + + Args: + object_path (str): Path to the model file. + + Raises: + ValueError: If the file extension is not supported. 
+ + Returns: + None + """ + file_extension = object_path.split(".")[-1].lower() + if file_extension is None: + raise ValueError(f"Unsupported file type: {object_path}") + + if file_extension == "usdz": + # install usdz io package + dirname = os.path.dirname(os.path.realpath(__file__)) + usdz_package = os.path.join(dirname, "io_scene_usdz.zip") + bpy.ops.preferences.addon_install(filepath=usdz_package) + # enable it + addon_name = "io_scene_usdz" + bpy.ops.preferences.addon_enable(module=addon_name) + # import the usdz + from io_scene_usdz.import_usdz import import_usdz + + import_usdz(bpy.context, filepath=object_path, materials=True, animations=True) + return None + + # load from existing import functions + import_function = IMPORT_FUNCTIONS[file_extension] + + print(f"Loading object from {object_path}") + if file_extension == "blend": + import_function(directory=object_path, link=False) + elif file_extension in {"glb", "gltf"}: + import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS') + else: + import_function(filepath=object_path) + +def delete_invisible_objects() -> None: + """Deletes all invisible objects in the scene. + + Returns: + None + """ + # bpy.ops.object.mode_set(mode="OBJECT") + bpy.ops.object.select_all(action="DESELECT") + for obj in bpy.context.scene.objects: + if obj.hide_viewport or obj.hide_render: + obj.hide_viewport = False + obj.hide_render = False + obj.hide_select = False + obj.select_set(True) + bpy.ops.object.delete() + + # Delete invisible collections + invisible_collections = [col for col in bpy.data.collections if col.hide_viewport] + for col in invisible_collections: + bpy.data.collections.remove(col) + +def split_mesh_normal(): + bpy.ops.object.select_all(action="DESELECT") + objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"] + bpy.context.view_layer.objects.active = objs[0] + for obj in objs: + obj.select_set(True) + bpy.ops.object.mode_set(mode="EDIT") + bpy.ops.mesh.select_all(action='SELECT') + bpy.ops.mesh.split_normals() + bpy.ops.object.mode_set(mode='OBJECT') + bpy.ops.object.select_all(action="DESELECT") + +def delete_custom_normals(): + for this_obj in bpy.data.objects: + if this_obj.type == "MESH": + bpy.context.view_layer.objects.active = this_obj + bpy.ops.mesh.customdata_custom_splitnormals_clear() + +def override_material(): + new_mat = bpy.data.materials.new(name="Override0123456789") + new_mat.use_nodes = True + new_mat.node_tree.nodes.clear() + bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse') + bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1) + bsdf.inputs[1].default_value = 1 + output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial') + new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface']) + bpy.context.scene.view_layers['View Layer'].material_override = new_mat + +def unhide_all_objects() -> None: + """Unhides all objects in the scene. + + Returns: + None + """ + for obj in bpy.context.scene.objects: + obj.hide_set(False) + +def convert_to_meshes() -> None: + """Converts all objects in the scene to meshes. + + Returns: + None + """ + bpy.ops.object.select_all(action="DESELECT") + bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0] + for obj in bpy.context.scene.objects: + obj.select_set(True) + bpy.ops.object.convert(target="MESH") + +def triangulate_meshes() -> None: + """Triangulates all meshes in the scene.
+ + Returns: + None + """ + bpy.ops.object.select_all(action="DESELECT") + objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"] + bpy.context.view_layer.objects.active = objs[0] + for obj in objs: + obj.select_set(True) + bpy.ops.object.mode_set(mode="EDIT") + bpy.ops.mesh.reveal() + bpy.ops.mesh.select_all(action="SELECT") + bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY") + bpy.ops.object.mode_set(mode="OBJECT") + bpy.ops.object.select_all(action="DESELECT") + +def scene_bbox() -> Tuple[Vector, Vector]: + """Returns the bounding box of the scene. + + Taken from Shap-E rendering script + (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82) + + Returns: + Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box. + """ + bbox_min = (math.inf,) * 3 + bbox_max = (-math.inf,) * 3 + found = False + scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)] + for obj in scene_meshes: + found = True + for coord in obj.bound_box: + coord = Vector(coord) + coord = obj.matrix_world @ coord + bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) + bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) + if not found: + raise RuntimeError("no objects in scene to compute bounding box for") + return Vector(bbox_min), Vector(bbox_max) + +def normalize_scene() -> Tuple[float, Vector]: + """Normalizes the scene by scaling and translating it to fit in a unit cube centered + at the origin. + + Mostly taken from the Point-E / Shap-E rendering script + (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112), + but fix for multiple root objects: (see bug report here: + https://github.com/openai/shap-e/pull/60). + + Returns: + Tuple[float, Vector]: The scale factor and the offset applied to the scene. + """ + scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent] + if len(scene_root_objects) > 1: + # create an empty object to be used as a parent for all root objects + scene = bpy.data.objects.new("ParentEmpty", None) + bpy.context.scene.collection.objects.link(scene) + + # parent all root objects to the empty object + for obj in scene_root_objects: + obj.parent = scene + else: + scene = scene_root_objects[0] + + bbox_min, bbox_max = scene_bbox() + scale = 1 / max(bbox_max - bbox_min) + scene.scale = scene.scale * scale + + # Apply scale to matrix_world. 
+ bpy.context.view_layer.update() + bbox_min, bbox_max = scene_bbox() + offset = -(bbox_min + bbox_max) / 2 + scene.matrix_world.translation += offset + bpy.ops.object.select_all(action="DESELECT") + + return scale, offset + +def get_transform_matrix(obj: bpy.types.Object) -> list: + pos, rt, _ = obj.matrix_world.decompose() + rt = rt.to_matrix() + matrix = [] + for ii in range(3): + a = [] + for jj in range(3): + a.append(rt[ii][jj]) + a.append(pos[ii]) + matrix.append(a) + matrix.append([0, 0, 0, 1]) + return matrix + +def main(arg): + os.makedirs(arg.output_folder, exist_ok=True) + + # Initialize context + init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode) + outputs, spec_nodes = init_nodes( + save_depth=arg.save_depth, + save_normal=arg.save_normal, + save_albedo=arg.save_albedo, + save_mist=arg.save_mist + ) + if arg.object.endswith(".blend"): + delete_invisible_objects() + else: + init_scene() + load_object(arg.object) + if arg.split_normal: + split_mesh_normal() + # delete_custom_normals() + print('[INFO] Scene initialized.') + + # normalize scene + scale, offset = normalize_scene() + print('[INFO] Scene normalized.') + + # Initialize camera and lighting + cam = init_camera() + init_lighting() + print('[INFO] Camera and lighting initialized.') + + # Override material + if arg.geo_mode: + override_material() + + # Create a list of views + to_export = { + "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + "scale": scale, + "offset": [offset.x, offset.y, offset.z], + "frames": [] + } + views = json.loads(arg.views) + for i, view in enumerate(views): + cam.location = ( + view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']), + view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']), + view['radius'] * np.sin(view['pitch']) + ) + cam.data.lens = 16 / np.tan(view['fov'] / 2) + + if arg.save_depth: + spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3) + spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3) + + bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png') + for name, output in outputs.items(): + output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}') + + # Render the scene + bpy.ops.render.render(write_still=True) + bpy.context.view_layer.update() + for name, output in outputs.items(): + ext = EXT[output.format.file_format] + path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0] + os.rename(path, f'{output.file_slots[0].path}.{ext}') + + # Save camera parameters + metadata = { + "file_path": f'{i:03d}.png', + "camera_angle_x": view['fov'], + "transform_matrix": get_transform_matrix(cam) + } + if arg.save_depth: + metadata['depth'] = { + 'min': view['radius'] - 0.5 * np.sqrt(3), + 'max': view['radius'] + 0.5 * np.sqrt(3) + } + to_export["frames"].append(metadata) + + # Save the camera parameters + with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f: + json.dump(to_export, f, indent=4) + + if arg.save_mesh: + # triangulate meshes + unhide_all_objects() + convert_to_meshes() + triangulate_meshes() + print('[INFO] Meshes triangulated.') + + # export ply mesh + bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.') + parser.add_argument('--views', type=str, help='JSON string of views. 
Contains a list of {yaw, pitch, radius, fov} objects.')
+    parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
+    parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
+    parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.')
+    parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
+    parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.')
+    parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.')
+    parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.')
+    parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.')
+    parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.')
+    parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.')
+    parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.')
+    argv = sys.argv[sys.argv.index("--") + 1:]
+    args = parser.parse_args(argv)
+
+    main(args)
+
\ No newline at end of file
diff --git a/dataset_toolkits/build_metadata.py b/dataset_toolkits/build_metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1f6396a9d11ca99be8eb8ffae22e64472f4ec8
--- /dev/null
+++ b/dataset_toolkits/build_metadata.py
@@ -0,0 +1,270 @@
+import os
+import shutil
+import sys
+import time
+import importlib
+import argparse
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from easydict import EasyDict as edict
+from concurrent.futures import ThreadPoolExecutor
+import utils3d
+
+def get_first_directory(path):
+    with os.scandir(path) as it:
+        for entry in it:
+            if entry.is_dir():
+                return entry.name
+    return None
+
+def need_process(key):
+    return key in opt.field or opt.field == ['all']
+
+if __name__ == '__main__':
+    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--output_dir', type=str, required=True,
+                        help='Directory to save the metadata')
+    parser.add_argument('--field', type=str, default='all',
+                        help='Fields to process, separated by commas')
+    parser.add_argument('--from_file', action='store_true',
+                        help='Build metadata from file instead of from records of processing.'
+ + 'Useful when some processing fail to generate records but file already exists.') + dataset_utils.add_args(parser) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(opt.output_dir, exist_ok=True) + os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True) + + opt.field = opt.field.split(',') + + timestamp = str(int(time.time())) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + print('Loading previous metadata...') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + metadata = dataset_utils.get_metadata(**opt) + metadata.set_index('sha256', inplace=True) + + # merge downloaded + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + if 'local_path' in metadata.columns: + metadata.update(df, overwrite=True) + else: + metadata = metadata.join(df, on='sha256', how='left') + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # detect models + image_models = [] + if os.path.exists(os.path.join(opt.output_dir, 'features')): + image_models = os.listdir(os.path.join(opt.output_dir, 'features')) + latent_models = [] + if os.path.exists(os.path.join(opt.output_dir, 'latents')): + latent_models = os.listdir(os.path.join(opt.output_dir, 'latents')) + ss_latent_models = [] + if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')): + ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents')) + print(f'Image models: {image_models}') + print(f'Latent models: {latent_models}') + print(f'Sparse Structure latent models: {ss_latent_models}') + + if 'rendered' not in metadata.columns: + metadata['rendered'] = [False] * len(metadata) + if 'voxelized' not in metadata.columns: + metadata['voxelized'] = [False] * len(metadata) + if 'num_voxels' not in metadata.columns: + metadata['num_voxels'] = [0] * len(metadata) + if 'cond_rendered' not in metadata.columns: + metadata['cond_rendered'] = [False] * len(metadata) + for model in image_models: + if f'feature_{model}' not in metadata.columns: + metadata[f'feature_{model}'] = [False] * len(metadata) + for model in latent_models: + if f'latent_{model}' not in metadata.columns: + metadata[f'latent_{model}'] = [False] * len(metadata) + for model in ss_latent_models: + if f'ss_latent_{model}' not in metadata.columns: + metadata[f'ss_latent_{model}'] = [False] * len(metadata) + + # merge rendered + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('rendered_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge voxelized + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('voxelized_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + 
df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge cond_rendered + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('cond_rendered_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge features + for model in image_models: + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'feature_{model}_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge latents + for model in latent_models: + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'latent_{model}_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge sparse structure latents + for model in ss_latent_models: + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'ss_latent_{model}_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # build metadata from files + if opt.from_file: + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ + tqdm(total=len(metadata), desc="Building metadata") as pbar: + def worker(sha256): + try: + if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \ + os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')): + metadata.loc[sha256, 'rendered'] = True + if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \ + os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')): + try: + pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0] + metadata.loc[sha256, 'voxelized'] = True + metadata.loc[sha256, 'num_voxels'] = len(pts) + except Exception as e: + pass + if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \ + os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')): + metadata.loc[sha256, 'cond_rendered'] = True + for model in image_models: + if need_process(f'feature_{model}') 
and \ + metadata.loc[sha256, f'feature_{model}'] == False and \ + metadata.loc[sha256, 'rendered'] == True and \ + metadata.loc[sha256, 'voxelized'] == True and \ + os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')): + metadata.loc[sha256, f'feature_{model}'] = True + for model in latent_models: + if need_process(f'latent_{model}') and \ + metadata.loc[sha256, f'latent_{model}'] == False and \ + metadata.loc[sha256, 'rendered'] == True and \ + metadata.loc[sha256, 'voxelized'] == True and \ + os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')): + metadata.loc[sha256, f'latent_{model}'] = True + for model in ss_latent_models: + if need_process(f'ss_latent_{model}') and \ + metadata.loc[sha256, f'ss_latent_{model}'] == False and \ + metadata.loc[sha256, 'voxelized'] == True and \ + os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')): + metadata.loc[sha256, f'ss_latent_{model}'] = True + pbar.update() + except Exception as e: + print(f'Error processing {sha256}: {e}') + pbar.update() + + executor.map(worker, metadata.index) + executor.shutdown(wait=True) + + # statistics + metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv')) + num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0 + with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f: + f.write('Statistics:\n') + f.write(f' - Number of assets: {len(metadata)}\n') + f.write(f' - Number of assets downloaded: {num_downloaded}\n') + f.write(f' - Number of assets rendered: {metadata["rendered"].sum()}\n') + f.write(f' - Number of assets voxelized: {metadata["voxelized"].sum()}\n') + if len(image_models) != 0: + f.write(f' - Number of assets with image features extracted:\n') + for model in image_models: + f.write(f' - {model}: {metadata[f"feature_{model}"].sum()}\n') + if len(latent_models) != 0: + f.write(f' - Number of assets with latents extracted:\n') + for model in latent_models: + f.write(f' - {model}: {metadata[f"latent_{model}"].sum()}\n') + if len(ss_latent_models) != 0: + f.write(f' - Number of assets with sparse structure latents extracted:\n') + for model in ss_latent_models: + f.write(f' - {model}: {metadata[f"ss_latent_{model}"].sum()}\n') + f.write(f' - Number of assets with captions: {metadata["captions"].count()}\n') + f.write(f' - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n') + + with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f: + print(f.read()) \ No newline at end of file diff --git a/dataset_toolkits/datasets/3D-FUTURE.py b/dataset_toolkits/datasets/3D-FUTURE.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ccc632de8e5a1d4f4a3b2b3024b8a3361e12c4 --- /dev/null +++ b/dataset_toolkits/datasets/3D-FUTURE.py @@ -0,0 +1,97 @@ +import os +import re +import argparse +import zipfile +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + pass + + +def get_metadata(**kwargs): + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(output_dir, exist_ok=True) + + if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')): + print("\033[93m") + print("3D-FUTURE have to be downloaded manually") + print(f"Please download the 3D-FUTURE-model.zip file and place it in the 
{output_dir}/raw directory") + print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information") + print("\033[0m") + raise FileNotFoundError("3D-FUTURE-model.zip not found") + + downloaded = {} + metadata = metadata.set_index("file_identifier") + with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref: + all_names = zip_ref.namelist() + instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)] + instances = list(filter(lambda x: x in metadata.index, instances)) + + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ + tqdm(total=len(instances), desc="Extracting") as pbar: + def worker(instance: str) -> str: + try: + instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names)) + zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files) + sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg")) + pbar.update() + return sha256 + except Exception as e: + pbar.update() + print(f"Error extracting for {instance}: {e}") + return None + + sha256s = executor.map(worker, instances) + executor.shutdown(wait=True) + + for k, sha256 in zip(instances, sha256s): + if sha256 is not None: + if sha256 == metadata.loc[k, "sha256"]: + downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj") + else: + print(f"Error downloading {k}: sha256s do not match") + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/dataset_toolkits/datasets/ABO.py b/dataset_toolkits/datasets/ABO.py new file mode 100644 index 0000000000000000000000000000000000000000..b0aba22c2c4980b024091058a458811ab93805dd --- /dev/null +++ b/dataset_toolkits/datasets/ABO.py @@ -0,0 +1,96 @@ +import os +import re +import argparse +import tarfile +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + pass + + +def get_metadata(**kwargs): + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + + if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')): + try: + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar") + 
except: + print("\033[93m") + print("Error downloading ABO dataset. Please check your internet connection and try again.") + print("Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory") + print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information") + print("\033[0m") + raise FileNotFoundError("Error downloading ABO dataset") + + downloaded = {} + metadata = metadata.set_index("file_identifier") + with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar: + with ThreadPoolExecutor(max_workers=1) as executor, \ + tqdm(total=len(metadata), desc="Extracting") as pbar: + def worker(instance: str) -> str: + try: + tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw')) + sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance)) + pbar.update() + return sha256 + except Exception as e: + pbar.update() + print(f"Error extracting for {instance}: {e}") + return None + + sha256s = executor.map(worker, metadata.index) + executor.shutdown(wait=True) + + for k, sha256 in zip(metadata.index, sha256s): + if sha256 is not None: + if sha256 == metadata.loc[k, "sha256"]: + downloaded[sha256] = os.path.join('raw/3dmodels/original', k) + else: + print(f"Error downloading {k}: sha256s do not match") + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/dataset_toolkits/datasets/HSSD.py b/dataset_toolkits/datasets/HSSD.py new file mode 100644 index 0000000000000000000000000000000000000000..465e6a140010d0b33ba6435e8129825599dda5db --- /dev/null +++ b/dataset_toolkits/datasets/HSSD.py @@ -0,0 +1,103 @@ +import os +import re +import argparse +import tarfile +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +import huggingface_hub +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + pass + + +def get_metadata(**kwargs): + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + + # check login + try: + huggingface_hub.whoami() + except: + print("\033[93m") + print("Haven't logged in to the Hugging Face Hub.") + print("Visit https://huggingface.co/settings/tokens to get a token.") + print("\033[0m") + huggingface_hub.login() + + try: + huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", 
repo_type="dataset") + except: + print("\033[93m") + print("Error downloading HSSD dataset.") + print("Check if you have access to the HSSD dataset.") + print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information") + print("\033[0m") + + downloaded = {} + metadata = metadata.set_index("file_identifier") + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ + tqdm(total=len(metadata), desc="Downloading") as pbar: + def worker(instance: str) -> str: + try: + huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw')) + sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance)) + pbar.update() + return sha256 + except Exception as e: + pbar.update() + print(f"Error extracting for {instance}: {e}") + return None + + sha256s = executor.map(worker, metadata.index) + executor.shutdown(wait=True) + + for k, sha256 in zip(metadata.index, sha256s): + if sha256 is not None: + if sha256 == metadata.loc[k, "sha256"]: + downloaded[sha256] = os.path.join('raw', k) + else: + print(f"Error downloading {k}: sha256s do not match") + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/dataset_toolkits/datasets/ObjaverseXL.py b/dataset_toolkits/datasets/ObjaverseXL.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f5c76c07701b198a20af32fc87c81614ae7f79 --- /dev/null +++ b/dataset_toolkits/datasets/ObjaverseXL.py @@ -0,0 +1,92 @@ +import os +import argparse +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +import objaverse.xl as oxl +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + parser.add_argument('--source', type=str, default='sketchfab', + help='Data source to download annotations from (github, sketchfab)') + + +def get_metadata(source, **kwargs): + if source == 'sketchfab': + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv") + elif source == 'github': + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv") + else: + raise ValueError(f"Invalid source: {source}") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + + # download annotations + annotations = oxl.get_annotations() + annotations = annotations[annotations['sha256'].isin(metadata['sha256'].values)] + + # download and render objects + file_paths = 
oxl.download_objects( + annotations, + download_dir=os.path.join(output_dir, "raw"), + save_repo_format="zip", + ) + + downloaded = {} + metadata = metadata.set_index("file_identifier") + for k, v in file_paths.items(): + sha256 = metadata.loc[k, "sha256"] + downloaded[sha256] = os.path.relpath(v, output_dir) + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + import tempfile + import zipfile + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + if local_path.startswith('raw/github/repos/'): + path_parts = local_path.split('/') + file_name = os.path.join(*path_parts[5:]) + zip_file = os.path.join(output_dir, *path_parts[:5]) + with tempfile.TemporaryDirectory() as tmp_dir: + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(tmp_dir) + file = os.path.join(tmp_dir, file_name) + record = func(file, sha256) + else: + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/dataset_toolkits/datasets/Toys4k.py b/dataset_toolkits/datasets/Toys4k.py new file mode 100644 index 0000000000000000000000000000000000000000..378afdaa87b1b2a5962a33e7e55efa9540f28436 --- /dev/null +++ b/dataset_toolkits/datasets/Toys4k.py @@ -0,0 +1,92 @@ +import os +import re +import argparse +import zipfile +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + pass + + +def get_metadata(**kwargs): + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/Toys4k.csv") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(output_dir, exist_ok=True) + + if not os.path.exists(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')): + print("\033[93m") + print("Toys4k have to be downloaded manually") + print(f"Please download the toys4k_blend_files.zip file and place it in the {output_dir}/raw directory") + print("Visit https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k for more information") + print("\033[0m") + raise FileNotFoundError("toys4k_blend_files.zip not found") + + downloaded = {} + metadata = metadata.set_index("file_identifier") + with zipfile.ZipFile(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')) as zip_ref: + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ + tqdm(total=len(metadata), desc="Extracting") as pbar: + def worker(instance: str) -> str: + try: + zip_ref.extract(os.path.join('toys4k_blend_files', instance), os.path.join(output_dir, 'raw')) + sha256 = get_file_hash(os.path.join(output_dir, 'raw/toys4k_blend_files', instance)) + pbar.update() + return sha256 + except Exception as e: + 
pbar.update() + print(f"Error extracting for {instance}: {e}") + return None + + sha256s = executor.map(worker, metadata.index) + executor.shutdown(wait=True) + + for k, sha256 in zip(metadata.index, sha256s): + if sha256 is not None: + if sha256 == metadata.loc[k, "sha256"]: + downloaded[sha256] = os.path.join("raw/toys4k_blend_files", k) + else: + print(f"Error downloading {k}: sha256s do not match") + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/dataset_toolkits/download.py b/dataset_toolkits/download.py new file mode 100644 index 0000000000000000000000000000000000000000..36e684ff5e61105d8c69c101291dc8fa4415af3d --- /dev/null +++ b/dataset_toolkits/download.py @@ -0,0 +1,52 @@ +import os +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(opt.output_dir, exist_ok=True) + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'local_path' in metadata.columns: + metadata = metadata[metadata['local_path'].isna()] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + + print(f'Processing {len(metadata)} objects...') + + # process objects + downloaded = dataset_utils.download(metadata, **opt) + 
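+    # Each dataset module's download() returns a DataFrame with the columns
+    # ['sha256', 'local_path']; build_metadata.py later merges these
+    # downloaded_{rank}.csv shards back into metadata.csv, keyed on sha256.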
downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False) diff --git a/dataset_toolkits/encode_latent.py b/dataset_toolkits/encode_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..d8c4770909d56ca3820f7aacd09bcf9b8988b9ac --- /dev/null +++ b/dataset_toolkits/encode_latent.py @@ -0,0 +1,127 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +import copy +import json +import argparse +import torch +import numpy as np +import pandas as pd +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor +from queue import Queue + +import trellis.models as models +import trellis.modules.sparse as sp + + +torch.set_grad_enabled(False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg', + help='Feature model') + parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16', + help='Pretrained encoder model') + parser.add_argument('--model_root', type=str, default='results', + help='Root directory of models') + parser.add_argument('--enc_model', type=str, default=None, + help='Encoder model. if specified, use this model instead of pretrained model') + parser.add_argument('--ckpt', type=str, default=None, + help='Checkpoint to load') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + opt = parser.parse_args() + opt = edict(vars(opt)) + + if opt.enc_model is None: + latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}' + encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda() + else: + latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}' + cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r'))) + encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda() + ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt') + encoder.load_state_dict(torch.load(ckpt_path), strict=False) + encoder.eval() + print(f'Loaded model from {ckpt_path}') + + os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.instances is not None: + with open(opt.instances, 'r') as f: + sha256s = [line.strip() for line in f] + metadata = metadata[metadata['sha256'].isin(sha256s)] + else: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True] + if f'latent_{latent_name}' in metadata.columns: + metadata = metadata[metadata[f'latent_{latent_name}'] == False] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + 
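Every toolkit script shards its work list with the same `start = len(metadata) * rank // world_size` arithmetic computed just above; a minimal standalone sketch of that contract (the helper name is illustrative, not part of the patch):

    def shard(items, rank: int, world_size: int):
        # Contiguous, near-equal split: worker `rank` gets items[start:end],
        # and the union over all ranks covers every item exactly once.
        start = len(items) * rank // world_size
        end = len(items) * (rank + 1) // world_size
        return items[start:end]

    # e.g. shard(list(range(10)), rank=1, world_size=3) -> [3, 4, 5]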
# filter out objects that are already processed + sha256s = list(metadata['sha256'].values) + for sha256 in copy.copy(sha256s): + if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')): + records.append({'sha256': sha256, f'latent_{latent_name}': True}) + sha256s.remove(sha256) + + # encode latents + load_queue = Queue(maxsize=4) + try: + with ThreadPoolExecutor(max_workers=32) as loader_executor, \ + ThreadPoolExecutor(max_workers=32) as saver_executor: + def loader(sha256): + try: + feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz')) + load_queue.put((sha256, feats)) + except Exception as e: + print(f"Error loading features for {sha256}: {e}") + loader_executor.map(loader, sha256s) + + def saver(sha256, pack): + save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz') + np.savez_compressed(save_path, **pack) + records.append({'sha256': sha256, f'latent_{latent_name}': True}) + + for _ in tqdm(range(len(sha256s)), desc="Extracting latents"): + sha256, feats = load_queue.get() + feats = sp.SparseTensor( + feats = torch.from_numpy(feats['patchtokens']).float(), + coords = torch.cat([ + torch.zeros(feats['patchtokens'].shape[0], 1).int(), + torch.from_numpy(feats['indices']).int(), + ], dim=1), + ).cuda() + latent = encoder(feats, sample_posterior=False) + assert torch.isfinite(latent.feats).all(), "Non-finite latent" + pack = { + 'feats': latent.feats.cpu().numpy().astype(np.float32), + 'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8), + } + saver_executor.submit(saver, sha256, pack) + + saver_executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + records = pd.DataFrame.from_records(records) + records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False) diff --git a/dataset_toolkits/encode_ss_latent.py b/dataset_toolkits/encode_ss_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..c5af5df048524218308dbb4e3a29580fc3b41c20 --- /dev/null +++ b/dataset_toolkits/encode_ss_latent.py @@ -0,0 +1,128 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +import copy +import json +import argparse +import torch +import numpy as np +import pandas as pd +import utils3d +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor +from queue import Queue + +import trellis.models as models + + +torch.set_grad_enabled(False) + + +def get_voxels(instance): + position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0] + coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous() + ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long) + ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1 + return ss + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16', + help='Pretrained encoder model') + parser.add_argument('--model_root', type=str, default='results', + help='Root directory of models') + parser.add_argument('--enc_model', type=str, default=None, + 
help='Encoder model. if specified, use this model instead of pretrained model') + parser.add_argument('--ckpt', type=str, default=None, + help='Checkpoint to load') + parser.add_argument('--resolution', type=int, default=64, + help='Resolution') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + opt = parser.parse_args() + opt = edict(vars(opt)) + + if opt.enc_model is None: + latent_name = f'{opt.enc_pretrained.split("/")[-1]}' + encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda() + else: + latent_name = f'{opt.enc_model}_{opt.ckpt}' + cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r'))) + encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda() + ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt') + encoder.load_state_dict(torch.load(ckpt_path), strict=False) + encoder.eval() + print(f'Loaded model from {ckpt_path}') + + os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.instances is not None: + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + metadata = metadata[metadata['sha256'].isin(instances)] + else: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + metadata = metadata[metadata['voxelized'] == True] + if f'ss_latent_{latent_name}' in metadata.columns: + metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + sha256s = list(metadata['sha256'].values) + for sha256 in copy.copy(sha256s): + if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')): + records.append({'sha256': sha256, f'ss_latent_{latent_name}': True}) + sha256s.remove(sha256) + + # encode latents + load_queue = Queue(maxsize=4) + try: + with ThreadPoolExecutor(max_workers=32) as loader_executor, \ + ThreadPoolExecutor(max_workers=32) as saver_executor: + def loader(sha256): + try: + ss = get_voxels(sha256)[None].float() + load_queue.put((sha256, ss)) + except Exception as e: + print(f"Error loading features for {sha256}: {e}") + loader_executor.map(loader, sha256s) + + def saver(sha256, pack): + save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz') + np.savez_compressed(save_path, **pack) + records.append({'sha256': sha256, f'ss_latent_{latent_name}': True}) + + for _ in tqdm(range(len(sha256s)), desc="Extracting latents"): + sha256, ss = load_queue.get() + ss = ss.cuda().float() + latent = encoder(ss, sample_posterior=False) + assert torch.isfinite(latent).all(), "Non-finite latent" + pack = { + 'mean': latent[0].cpu().numpy(), + } + saver_executor.submit(saver, sha256, pack) + + saver_executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + records = pd.DataFrame.from_records(records) + records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), 
index=False) diff --git a/dataset_toolkits/extract_feature.py b/dataset_toolkits/extract_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..1ded2fe03b3e4355da00af429c9439a68069308a --- /dev/null +++ b/dataset_toolkits/extract_feature.py @@ -0,0 +1,179 @@ +import os +import copy +import sys +import json +import importlib +import argparse +import torch +import torch.nn.functional as F +import numpy as np +import pandas as pd +import utils3d +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor +from queue import Queue +from torchvision import transforms +from PIL import Image + + +torch.set_grad_enabled(False) + + +def get_data(frames, sha256): + with ThreadPoolExecutor(max_workers=16) as executor: + def worker(view): + image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path']) + try: + image = Image.open(image_path) + except: + print(f"Error loading image {image_path}") + return None + image = image.resize((518, 518), Image.Resampling.LANCZOS) + image = np.array(image).astype(np.float32) / 255 + image = image[:, :, :3] * image[:, :, 3:] + image = torch.from_numpy(image).permute(2, 0, 1).float() + + c2w = torch.tensor(view['transform_matrix']) + c2w[:3, 1:3] *= -1 + extrinsics = torch.inverse(c2w) + fov = view['camera_angle_x'] + intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov)) + + return { + 'image': image, + 'extrinsics': extrinsics, + 'intrinsics': intrinsics + } + + datas = executor.map(worker, frames) + for data in datas: + if data is not None: + yield data + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--model', type=str, default='dinov2_vitl14_reg', + help='Feature extraction model') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + opt = parser.parse_args() + opt = edict(vars(opt)) + + feature_name = opt.model + os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True) + + # load model + dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model) + dinov2_model.eval().cuda() + transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + n_patch = 518 // 14 + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.instances is not None: + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + metadata = metadata[metadata['sha256'].isin(instances)] + else: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if f'feature_{feature_name}' in metadata.columns: + metadata = metadata[metadata[f'feature_{feature_name}'] == False] + metadata = metadata[metadata['voxelized'] == True] + metadata = metadata[metadata['rendered'] == True] + + start = len(metadata) * opt.rank // opt.world_size + end = 
len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + sha256s = list(metadata['sha256'].values) + for sha256 in copy.copy(sha256s): + if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')): + records.append({'sha256': sha256, f'feature_{feature_name}' : True}) + sha256s.remove(sha256) + + # extract features + load_queue = Queue(maxsize=4) + try: + with ThreadPoolExecutor(max_workers=8) as loader_executor, \ + ThreadPoolExecutor(max_workers=8) as saver_executor: + def loader(sha256): + try: + with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f: + metadata = json.load(f) + frames = metadata['frames'] + data = [] + for datum in get_data(frames, sha256): + datum['image'] = transform(datum['image']) + data.append(datum) + positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0] + load_queue.put((sha256, data, positions)) + except Exception as e: + print(f"Error loading data for {sha256}: {e}") + + loader_executor.map(loader, sha256s) + + def saver(sha256, pack, patchtokens, uv): + pack['patchtokens'] = F.grid_sample( + patchtokens, + uv.unsqueeze(1), + mode='bilinear', + align_corners=False, + ).squeeze(2).permute(0, 2, 1).cpu().numpy() + pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16) + save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz') + np.savez_compressed(save_path, **pack) + records.append({'sha256': sha256, f'feature_{feature_name}' : True}) + + for _ in tqdm(range(len(sha256s)), desc="Extracting features"): + sha256, data, positions = load_queue.get() + positions = torch.from_numpy(positions).float().cuda() + indices = ((positions + 0.5) * 64).long() + assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds" + n_views = len(data) + N = positions.shape[0] + pack = { + 'indices': indices.cpu().numpy().astype(np.uint8), + } + patchtokens_lst = [] + uv_lst = [] + for i in range(0, n_views, opt.batch_size): + batch_data = data[i:i+opt.batch_size] + bs = len(batch_data) + batch_images = torch.stack([d['image'] for d in batch_data]).cuda() + batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda() + batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda() + features = dinov2_model(batch_images, is_training=True) + uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1 + patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch) + patchtokens_lst.append(patchtokens) + uv_lst.append(uv) + patchtokens = torch.cat(patchtokens_lst, dim=0) + uv = torch.cat(uv_lst, dim=0) + + # save features + saver_executor.submit(saver, sha256, pack, patchtokens, uv) + + saver_executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + records = pd.DataFrame.from_records(records) + records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False) + \ No newline at end of file diff --git a/dataset_toolkits/render.py b/dataset_toolkits/render.py new file mode 100644 index 0000000000000000000000000000000000000000..636f3b308fa9a33e1304b6b64370221167885118 --- /dev/null +++ b/dataset_toolkits/render.py @@ -0,0 +1,121 @@ +import os +import json +import copy +import sys +import importlib +import argparse +import pandas as 
pd +from easydict import EasyDict as edict +from functools import partial +from subprocess import DEVNULL, call +import numpy as np +from utils import sphere_hammersley_sequence + + +BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz' +BLENDER_INSTALLATION_PATH = '/tmp' +BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender' + +def _install_blender(): + if not os.path.exists(BLENDER_PATH): + os.system('sudo apt-get update') + os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6') + os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}') + os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}') + + +def _render(file_path, sha256, output_dir, num_views): + output_folder = os.path.join(output_dir, 'renders', sha256) + + # Build camera {yaw, pitch, radius, fov} + yaws = [] + pitchs = [] + offset = (np.random.rand(), np.random.rand()) + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views, offset) + yaws.append(y) + pitchs.append(p) + radius = [2] * num_views + fov = [40 / 180 * np.pi] * num_views + views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)] + + args = [ + BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'), + '--', + '--views', json.dumps(views), + '--object', os.path.expanduser(file_path), + '--resolution', '512', + '--output_folder', output_folder, + '--engine', 'CYCLES', + '--save_mesh', + ] + if file_path.endswith('.blend'): + args.insert(1, file_path) + + call(args, stdout=DEVNULL, stderr=DEVNULL) + + if os.path.exists(os.path.join(output_folder, 'transforms.json')): + return {'sha256': sha256, 'rendered': True} + + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--num_views', type=int, default=150, + help='Number of views to render') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + parser.add_argument('--max_workers', type=int, default=8) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True) + + # install blender + print('Checking blender...', flush=True) + _install_blender() + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + metadata = metadata[metadata['local_path'].notna()] + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'rendered' in metadata.columns: + metadata = metadata[metadata['rendered'] == False] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = 
metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + for sha256 in copy.copy(metadata['sha256'].values): + if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')): + records.append({'sha256': sha256, 'rendered': True}) + metadata = metadata[metadata['sha256'] != sha256] + + print(f'Processing {len(metadata)} objects...') + + # process objects + func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views) + rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects') + rendered = pd.concat([rendered, pd.DataFrame.from_records(records)]) + rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False) diff --git a/dataset_toolkits/render_cond.py b/dataset_toolkits/render_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a40e6d9974a9e7256571d9ed636f4a5758b5c7 --- /dev/null +++ b/dataset_toolkits/render_cond.py @@ -0,0 +1,125 @@ +import os +import json +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict +from functools import partial +from subprocess import DEVNULL, call +import numpy as np +from utils import sphere_hammersley_sequence + + +BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz' +BLENDER_INSTALLATION_PATH = '/tmp' +BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender' + +def _install_blender(): + if not os.path.exists(BLENDER_PATH): + os.system('sudo apt-get update') + os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6') + os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}') + os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}') + + +def _render_cond(file_path, sha256, output_dir, num_views): + output_folder = os.path.join(output_dir, 'renders_cond', sha256) + + # Build camera {yaw, pitch, radius, fov} + yaws = [] + pitchs = [] + offset = (np.random.rand(), np.random.rand()) + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views, offset) + yaws.append(y) + pitchs.append(p) + fov_min, fov_max = 10, 70 + radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi) + radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi) + k_min = 1 / radius_max**2 + k_max = 1 / radius_min**2 + ks = np.random.uniform(k_min, k_max, (1000000,)) + radius = [1 / np.sqrt(k) for k in ks] + fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius] + views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)] + + args = [ + BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'), + '--', + '--views', json.dumps(views), + '--object', os.path.expanduser(file_path), + '--output_folder', os.path.expanduser(output_folder), + '--resolution', '1024', + ] + if file_path.endswith('.blend'): + args.insert(1, file_path) + + call(args, stdout=DEVNULL) + + if os.path.exists(os.path.join(output_folder, 'transforms.json')): + return {'sha256': sha256, 'cond_rendered': True} + + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + 
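The radius/FOV sampling in `_render_cond` above keeps the normalized asset filling the frame: the unit cube's bounding sphere has radius sqrt(3)/2, so camera distance r and field of view fov are tied by sin(fov/2) = sqrt(3)/2 / r. A rough standalone sketch of that relationship (the helper below is illustrative, not part of the patch):

    import numpy as np

    R = np.sqrt(3) / 2  # bounding sphere radius of the normalized unit cube

    def sample_fitting_view(fov_min_deg=10, fov_max_deg=70):
        # Draw 1/r^2 uniformly between the bounds implied by the FOV range,
        # then solve the FOV back from sin(fov/2) = R / r.
        r_min = R / np.sin(np.deg2rad(fov_max_deg) / 2)
        r_max = R / np.sin(np.deg2rad(fov_min_deg) / 2)
        k = np.random.uniform(1 / r_max ** 2, 1 / r_min ** 2)
        radius = 1 / np.sqrt(k)
        fov = 2 * np.arcsin(R / radius)  # radians, within [fov_min, fov_max]
        return radius, fov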
parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--num_views', type=int, default=24, + help='Number of views to render') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + parser.add_argument('--max_workers', type=int, default=8) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True) + + # install blender + print('Checking blender...', flush=True) + _install_blender() + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + metadata = metadata[metadata['local_path'].notna()] + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'cond_rendered' in metadata.columns: + metadata = metadata[metadata['cond_rendered'] == False] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + for sha256 in copy.copy(metadata['sha256'].values): + if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')): + records.append({'sha256': sha256, 'cond_rendered': True}) + metadata = metadata[metadata['sha256'] != sha256] + + print(f'Processing {len(metadata)} objects...') + + # process objects + func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views) + cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects') + cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)]) + cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False) diff --git a/dataset_toolkits/setup.sh b/dataset_toolkits/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..e387ea4fe1f7d6b68ee7f5ccf07043f80bd624f4 --- /dev/null +++ b/dataset_toolkits/setup.sh @@ -0,0 +1 @@ +pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub diff --git a/dataset_toolkits/stat_latent.py b/dataset_toolkits/stat_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..7f27a062bb8be898ef6c705a3f3d9488f94299fe --- /dev/null +++ b/dataset_toolkits/stat_latent.py @@ -0,0 +1,66 @@ +import os +import json +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + 
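Both rendering scripts shard the metadata across workers with `start = len(metadata) * rank // world_size` and `end = len(metadata) * (rank + 1) // world_size`; the floor divisions produce contiguous, non-overlapping slices that cover every row even when the length is not divisible by the world size. A small worked example with illustrative values:

def shard_bounds(n: int, rank: int, world_size: int) -> tuple[int, int]:
    # Same arithmetic as the dataset scripts.
    return n * rank // world_size, n * (rank + 1) // world_size

# Ten objects over four workers:
# [shard_bounds(10, r, 4) for r in range(4)] == [(0, 2), (2, 5), (5, 7), (7, 10)]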
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16', + help='Latent model to use') + parser.add_argument('--num_samples', type=int, default=50000, + help='Number of samples to use for calculating stats') + opt = parser.parse_args() + opt = edict(vars(opt)) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + metadata = metadata[metadata[f'latent_{opt.model}'] == True] + sha256s = metadata['sha256'].values + sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False) + + # stats + means = [] + mean2s = [] + with ThreadPoolExecutor(max_workers=16) as executor, \ + tqdm(total=len(sha256s), desc="Extracting features") as pbar: + def worker(sha256): + try: + feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz')) + feats = feats['feats'] + means.append(feats.mean(axis=0)) + mean2s.append((feats ** 2).mean(axis=0)) + pbar.update() + except Exception as e: + print(f"Error extracting features for {sha256}: {e}") + pbar.update() + + executor.map(worker, sha256s) + executor.shutdown(wait=True) + + mean = np.array(means).mean(axis=0) + mean2 = np.array(mean2s).mean(axis=0) + std = np.sqrt(mean2 - mean ** 2) + + print('mean:', mean) + print('std:', std) + + with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f: + json.dump({ + 'mean': mean.tolist(), + 'std': std.tolist(), + }, f, indent=4) + \ No newline at end of file diff --git a/dataset_toolkits/utils.py b/dataset_toolkits/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69830845d40697bd2c4e7c68e64e54b4d4091a81 --- /dev/null +++ b/dataset_toolkits/utils.py @@ -0,0 +1,43 @@ +from typing import * +import hashlib +import numpy as np + + +def get_file_hash(file: str) -> str: + sha256 = hashlib.sha256() + # Read the file from the path + with open(file, "rb") as f: + # Update the hash with the file content + for byte_block in iter(lambda: f.read(4096), b""): + sha256.update(byte_block) + return sha256.hexdigest() + +# ===============LOW DISCREPANCY SEQUENCES================ + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] diff --git a/dataset_toolkits/voxelize.py b/dataset_toolkits/voxelize.py new file mode 100644 index 0000000000000000000000000000000000000000..390575ab9e26e73b467151dadc8252eae3c96424 --- /dev/null +++ b/dataset_toolkits/voxelize.py 
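`stat_latent.py` above derives the normalization statistics from two running moments: it averages per-object feature means and per-object means of the squared features, then takes std = sqrt(E[x^2] - E[x]^2). A compact NumPy sketch of the same reduction (each object weighted equally, as in the script; the function name is illustrative):

import numpy as np

def latent_stats(feature_batches):
    """feature_batches: iterable of [N_i, C] arrays of latent features."""
    means, mean2s = [], []
    for feats in feature_batches:
        means.append(feats.mean(axis=0))           # E[x] per object
        mean2s.append((feats ** 2).mean(axis=0))   # E[x^2] per object
    mean = np.mean(means, axis=0)
    mean2 = np.mean(mean2s, axis=0)
    std = np.sqrt(mean2 - mean ** 2)               # Var[x] = E[x^2] - E[x]^2
    return mean, std

# e.g. mean, std = latent_stats(np.random.randn(4, 128, 8))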
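The `sphere_hammersley_sequence` helper in `dataset_toolkits/utils.py` turns the 2D Hammersley point set into camera directions: the first coordinate is remapped piecewise (so that roughly three quarters of the views end up at or above the equator) and converted to a pitch via arccos, while the second coordinate becomes the yaw. A quick usage sketch, assuming it is run from the `dataset_toolkits` directory so the import resolves:

import numpy as np
from utils import sphere_hammersley_sequence

num_views = 8
offset = (np.random.rand(), np.random.rand())  # per-object jitter, as in the render scripts
views = [sphere_hammersley_sequence(i, num_views, offset) for i in range(num_views)]
for yaw, pitch in views:
    print(f"yaw={yaw:6.3f} rad  pitch={pitch:6.3f} rad")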
@@ -0,0 +1,86 @@ +import os +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict +from functools import partial +import numpy as np +import open3d as o3d +import utils3d + + +def _voxelize(file, sha256, output_dir): + mesh = o3d.io.read_triangle_mesh(os.path.join(output_dir, 'renders', sha256, 'mesh.ply')) + # clamp vertices to the range [-0.5, 0.5] + vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6) + mesh.vertices = o3d.utility.Vector3dVector(vertices) + voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5)) + vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()]) + assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds" + vertices = (vertices + 0.5) / 64 - 0.5 + utils3d.io.write_ply(os.path.join(output_dir, 'voxels', f'{sha256}.ply'), vertices) + return {'sha256': sha256, 'voxelized': True, 'num_voxels': len(vertices)} + + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--num_views', type=int, default=150, + help='Number of views to render') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + parser.add_argument('--max_workers', type=int, default=None) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(os.path.join(opt.output_dir, 'voxels'), exist_ok=True) + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'rendered' not in metadata.columns: + raise ValueError('metadata.csv does not have "rendered" column, please run "build_metadata.py" first') + metadata = metadata[metadata['rendered'] == True] + if 'voxelized' in metadata.columns: + metadata = metadata[metadata['voxelized'] == False] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + for sha256 in copy.copy(metadata['sha256'].values): + if os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')): + pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0] + records.append({'sha256': sha256, 'voxelized': True, 'num_voxels': len(pts)}) + metadata = metadata[metadata['sha256'] != sha256] + + print(f'Processing {len(metadata)} objects...') + + # process objects + func = partial(_voxelize, 
output_dir=opt.output_dir) + voxelized = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Voxelizing') + voxelized = pd.concat([voxelized, pd.DataFrame.from_records(records)]) + voxelized.to_csv(os.path.join(opt.output_dir, f'voxelized_{opt.rank}.csv'), index=False) diff --git a/extensions/vox2seq/benchmark.py b/extensions/vox2seq/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..30351e0251cfed4db82d286af1b53654f8cdce8b --- /dev/null +++ b/extensions/vox2seq/benchmark.py @@ -0,0 +1,45 @@ +import time +import torch +import vox2seq + + +if __name__ == "__main__": + stats = { + 'z_order_cuda': [], + 'z_order_pytorch': [], + 'hilbert_cuda': [], + 'hilbert_pytorch': [], + } + RES = [16, 32, 64, 128, 256] + for res in RES: + coords = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res)) + coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda() + + start = time.time() + for _ in range(100): + code_z_cuda = vox2seq.encode(coords, mode='z_order').cuda() + torch.cuda.synchronize() + stats['z_order_cuda'].append((time.time() - start) / 100) + + start = time.time() + for _ in range(100): + code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order').cuda() + torch.cuda.synchronize() + stats['z_order_pytorch'].append((time.time() - start) / 100) + + start = time.time() + for _ in range(100): + code_h_cuda = vox2seq.encode(coords, mode='hilbert').cuda() + torch.cuda.synchronize() + stats['hilbert_cuda'].append((time.time() - start) / 100) + + start = time.time() + for _ in range(100): + code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert').cuda() + torch.cuda.synchronize() + stats['hilbert_pytorch'].append((time.time() - start) / 100) + + print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}") + for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']): + print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}") + diff --git a/extensions/vox2seq/setup.py b/extensions/vox2seq/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..500d97b7c2c69e2ced48a9286b1b030976f2fb47 --- /dev/null +++ b/extensions/vox2seq/setup.py @@ -0,0 +1,34 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
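Returning to `_voxelize` in `voxelize.py` above: it loads the mesh written by the rendering stage, clamps its vertices into the unit cube, rasterizes it into a 64^3 occupancy grid with Open3D, and stores the occupied voxel centres mapped back to [-0.5, 0.5]. A condensed sketch of the same conversion for an arbitrary mesh path (the helper and its arguments are illustrative):

import numpy as np
import open3d as o3d

def mesh_to_voxel_centers(mesh_path: str, resolution: int = 64) -> np.ndarray:
    mesh = o3d.io.read_triangle_mesh(mesh_path)
    # Keep every vertex strictly inside the unit cube so no triangle escapes the grid.
    verts = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
    mesh.vertices = o3d.utility.Vector3dVector(verts)
    grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
        mesh, voxel_size=1 / resolution,
        min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    idx = np.array([v.grid_index for v in grid.get_voxels()])  # integer indices in [0, resolution)
    return (idx + 0.5) / resolution - 0.5                      # voxel centres in [-0.5, 0.5]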
+# +# For inquiries contact george.drettakis@inria.fr +# + +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension +import os +os.path.dirname(os.path.abspath(__file__)) + +setup( + name="vox2seq", + packages=['vox2seq', 'vox2seq.pytorch'], + ext_modules=[ + CUDAExtension( + name="vox2seq._C", + sources=[ + "src/api.cu", + "src/z_order.cu", + "src/hilbert.cu", + "src/ext.cpp", + ], + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/extensions/vox2seq/src/api.cu b/extensions/vox2seq/src/api.cu new file mode 100644 index 0000000000000000000000000000000000000000..072e930f90278f2f407b45750220d7d98c37b91e --- /dev/null +++ b/extensions/vox2seq/src/api.cu @@ -0,0 +1,92 @@ +#include +#include "api.h" +#include "z_order.h" +#include "hilbert.h" + + +torch::Tensor +z_order_encode( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +) { + // Allocate output tensor + torch::Tensor codes = torch::empty_like(x); + + // Call CUDA kernel + z_order_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + x.size(0), + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(y.contiguous().data_ptr()), + reinterpret_cast(z.contiguous().data_ptr()), + reinterpret_cast(codes.data_ptr()) + ); + + return codes; +} + + +std::tuple +z_order_decode( + const torch::Tensor& codes +) { + // Allocate output tensors + torch::Tensor x = torch::empty_like(codes); + torch::Tensor y = torch::empty_like(codes); + torch::Tensor z = torch::empty_like(codes); + + // Call CUDA kernel + z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + codes.size(0), + reinterpret_cast(codes.contiguous().data_ptr()), + reinterpret_cast(x.data_ptr()), + reinterpret_cast(y.data_ptr()), + reinterpret_cast(z.data_ptr()) + ); + + return std::make_tuple(x, y, z); +} + + +torch::Tensor +hilbert_encode( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +) { + // Allocate output tensor + torch::Tensor codes = torch::empty_like(x); + + // Call CUDA kernel + hilbert_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + x.size(0), + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(y.contiguous().data_ptr()), + reinterpret_cast(z.contiguous().data_ptr()), + reinterpret_cast(codes.data_ptr()) + ); + + return codes; +} + + +std::tuple +hilbert_decode( + const torch::Tensor& codes +) { + // Allocate output tensors + torch::Tensor x = torch::empty_like(codes); + torch::Tensor y = torch::empty_like(codes); + torch::Tensor z = torch::empty_like(codes); + + // Call CUDA kernel + hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + codes.size(0), + reinterpret_cast(codes.contiguous().data_ptr()), + reinterpret_cast(x.data_ptr()), + reinterpret_cast(y.data_ptr()), + reinterpret_cast(z.data_ptr()) + ); + + return std::make_tuple(x, y, z); +} diff --git a/extensions/vox2seq/src/api.h b/extensions/vox2seq/src/api.h new file mode 100644 index 0000000000000000000000000000000000000000..26a68348d56585d0e9e1dfb4900a0d23587df9a6 --- /dev/null +++ b/extensions/vox2seq/src/api.h @@ -0,0 +1,76 @@ +/* + * Serialize a voxel grid + * + * Copyright (C) 2024, Jianfeng XIANG + * All rights reserved. 
+ * + * Licensed under The MIT License [see LICENSE for details] + * + * Written by Jianfeng XIANG + */ + +#pragma once +#include + + +#define BLOCK_SIZE 256 + + +/** + * Z-order encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +torch::Tensor +z_order_encode( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +); + + +/** + * Z-order decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * + * @return 3 tensors [N] containing the x, y, z coordinates + */ +std::tuple +z_order_decode( + const torch::Tensor& codes +); + + +/** + * Hilbert encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the Hilbert encoded values + */ +torch::Tensor +hilbert_encode( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +); + + +/** + * Hilbert decode 3D points + * + * @param codes [N] tensor containing the Hilbert encoded values + * + * @return 3 tensors [N] containing the x, y, z coordinates + */ +std::tuple +hilbert_decode( + const torch::Tensor& codes +); diff --git a/extensions/vox2seq/src/ext.cpp b/extensions/vox2seq/src/ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72e76d3b361eb8f355760f067f71005d4e37902c --- /dev/null +++ b/extensions/vox2seq/src/ext.cpp @@ -0,0 +1,10 @@ +#include +#include "api.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("z_order_encode", &z_order_encode); + m.def("z_order_decode", &z_order_decode); + m.def("hilbert_encode", &hilbert_encode); + m.def("hilbert_decode", &hilbert_decode); +} \ No newline at end of file diff --git a/extensions/vox2seq/src/hilbert.cu b/extensions/vox2seq/src/hilbert.cu new file mode 100644 index 0000000000000000000000000000000000000000..b3c5bb19474a528cf6f3102e728fa7550588ca61 --- /dev/null +++ b/extensions/vox2seq/src/hilbert.cu @@ -0,0 +1,133 @@ +#include +#include +#include + +#include +#include +namespace cg = cooperative_groups; + +#include "hilbert.h" + + +// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit. +static __device__ uint32_t expandBits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + + +// Removes 2 zeros after each bit in a 30-bit integer. 
+static __device__ uint32_t extractBits(uint32_t v) +{ + v = v & 0x49249249; + v = (v ^ (v >> 2)) & 0x030C30C3u; + v = (v ^ (v >> 4)) & 0x0300F00Fu; + v = (v ^ (v >> 8)) & 0x030000FFu; + v = (v ^ (v >> 16)) & 0x000003FFu; + return v; +} + + +__global__ void hilbert_encode_cuda( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]}; + + uint32_t m = 1 << 9, q, p, t; + + // Inverse undo excess work + q = m; + while (q > 1) { + p = q - 1; + for (int i = 0; i < 3; i++) { + if (point[i] & q) { + point[0] ^= p; // invert + } else { + t = (point[0] ^ point[i]) & p; + point[0] ^= t; + point[i] ^= t; + } + } + q >>= 1; + } + + // Gray encode + for (int i = 1; i < 3; i++) { + point[i] ^= point[i - 1]; + } + t = 0; + q = m; + while (q > 1) { + if (point[2] & q) { + t ^= q - 1; + } + q >>= 1; + } + for (int i = 0; i < 3; i++) { + point[i] ^= t; + } + + // Convert to 3D Hilbert code + uint32_t xx = expandBits(point[0]); + uint32_t yy = expandBits(point[1]); + uint32_t zz = expandBits(point[2]); + + codes[thread_id] = xx * 4 + yy * 2 + zz; +} + + +__global__ void hilbert_decode_cuda( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t point[3]; + point[0] = extractBits(codes[thread_id] >> 2); + point[1] = extractBits(codes[thread_id] >> 1); + point[2] = extractBits(codes[thread_id]); + + uint32_t m = 2 << 9, q, p, t; + + // Gray decode by H ^ (H/2) + t = point[2] >> 1; + for (int i = 2; i > 0; i--) { + point[i] ^= point[i - 1]; + } + point[0] ^= t; + + // Undo excess work + q = 2; + while (q != m) { + p = q - 1; + for (int i = 2; i >= 0; i--) { + if (point[i] & q) { + point[0] ^= p; + } else { + t = (point[0] ^ point[i]) & p; + point[0] ^= t; + point[i] ^= t; + } + } + q <<= 1; + } + + x[thread_id] = point[0]; + y[thread_id] = point[1]; + z[thread_id] = point[2]; +} diff --git a/extensions/vox2seq/src/hilbert.h b/extensions/vox2seq/src/hilbert.h new file mode 100644 index 0000000000000000000000000000000000000000..4896bf6006f43e5e527d8bde691ce7a54b38c4d7 --- /dev/null +++ b/extensions/vox2seq/src/hilbert.h @@ -0,0 +1,35 @@ +#pragma once + +/** + * Hilbert encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +__global__ void hilbert_encode_cuda( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +); + + +/** + * Hilbert decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + */ +__global__ void hilbert_decode_cuda( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +); diff --git a/extensions/vox2seq/src/z_order.cu b/extensions/vox2seq/src/z_order.cu new file mode 100644 index 0000000000000000000000000000000000000000..ba6f5a91e55588d1ca4bea7cb45e6936330a694b --- /dev/null +++ b/extensions/vox2seq/src/z_order.cu @@ -0,0 +1,66 @@ +#include +#include +#include + +#include +#include +namespace cg = cooperative_groups; + +#include "z_order.h" 
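Both kernels build on the classic bit-dilation tricks: `expandBits` spreads each of the 10 bits of a coordinate two positions apart so that three coordinates interleave into a single 30-bit Morton key, and `extractBits` inverts it. A pure-Python reference using the same constants can be handy for sanity-checking the CUDA path on the CPU; this is a sketch, not part of the extension:

def expand_bits(v: int) -> int:
    # Dilate a 10-bit integer: insert two zero bits after every bit.
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def extract_bits(v: int) -> int:
    # Inverse of expand_bits: keep every third bit and compact.
    v &= 0x49249249
    v = (v ^ (v >> 2)) & 0x030C30C3
    v = (v ^ (v >> 4)) & 0x0300F00F
    v = (v ^ (v >> 8)) & 0x030000FF
    v = (v ^ (v >> 16)) & 0x000003FF
    return v

def morton_encode(x: int, y: int, z: int) -> int:
    # Same layout as the kernels: x occupies the highest bit of each 3-bit group.
    return expand_bits(x) * 4 + expand_bits(y) * 2 + expand_bits(z)

def morton_decode(code: int) -> tuple:
    return extract_bits(code >> 2), extract_bits(code >> 1), extract_bits(code)

assert all(morton_decode(morton_encode(x, y, z)) == (x, y, z)
           for x, y, z in [(0, 0, 0), (1, 2, 3), (1023, 511, 7)])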
+ + +// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit. +static __device__ uint32_t expandBits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + + +// Removes 2 zeros after each bit in a 30-bit integer. +static __device__ uint32_t extractBits(uint32_t v) +{ + v = v & 0x49249249; + v = (v ^ (v >> 2)) & 0x030C30C3u; + v = (v ^ (v >> 4)) & 0x0300F00Fu; + v = (v ^ (v >> 8)) & 0x030000FFu; + v = (v ^ (v >> 16)) & 0x000003FFu; + return v; +} + + +__global__ void z_order_encode_cuda( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t xx = expandBits(x[thread_id]); + uint32_t yy = expandBits(y[thread_id]); + uint32_t zz = expandBits(z[thread_id]); + + codes[thread_id] = xx * 4 + yy * 2 + zz; +} + + +__global__ void z_order_decode_cuda( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + x[thread_id] = extractBits(codes[thread_id] >> 2); + y[thread_id] = extractBits(codes[thread_id] >> 1); + z[thread_id] = extractBits(codes[thread_id]); +} diff --git a/extensions/vox2seq/src/z_order.h b/extensions/vox2seq/src/z_order.h new file mode 100644 index 0000000000000000000000000000000000000000..a4aa857d064e375c8f2eb023abd9ac4af5a2d8f5 --- /dev/null +++ b/extensions/vox2seq/src/z_order.h @@ -0,0 +1,35 @@ +#pragma once + +/** + * Z-order encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +__global__ void z_order_encode_cuda( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +); + + +/** + * Z-order decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + */ +__global__ void z_order_decode_cuda( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +); diff --git a/extensions/vox2seq/test.py b/extensions/vox2seq/test.py new file mode 100644 index 0000000000000000000000000000000000000000..60f4fc0ae1aa0ff55744455e66b0854706908506 --- /dev/null +++ b/extensions/vox2seq/test.py @@ -0,0 +1,25 @@ +import torch +import vox2seq + + +if __name__ == "__main__": + RES = 256 + coords = torch.meshgrid(torch.arange(RES), torch.arange(RES), torch.arange(RES)) + coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda() + code_z_cuda = vox2seq.encode(coords, mode='z_order') + code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order') + code_h_cuda = vox2seq.encode(coords, mode='hilbert') + code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert') + assert torch.equal(code_z_cuda, code_z_pytorch) + assert torch.equal(code_h_cuda, code_h_pytorch) + + code = torch.arange(RES**3).int().cuda() + coords_z_cuda = vox2seq.decode(code, mode='z_order') + coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order') + coords_h_cuda = vox2seq.decode(code, mode='hilbert') + coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert') + assert 
torch.equal(coords_z_cuda, coords_z_pytorch) + assert torch.equal(coords_h_cuda, coords_h_pytorch) + + print("All tests passed.") + diff --git a/extensions/vox2seq/vox2seq/__init__.py b/extensions/vox2seq/vox2seq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba13bb5f46b2ba90c7882f23b7d97c3b6fe960ac --- /dev/null +++ b/extensions/vox2seq/vox2seq/__init__.py @@ -0,0 +1,50 @@ + +from typing import * +import torch +from . import _C +from . import pytorch + + +@torch.no_grad() +def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Encodes 3D coordinates into a 30-bit code. + + Args: + coords: a tensor of shape [N, 3] containing the 3D coordinates. + permute: the permutation of the coordinates. + mode: the encoding mode to use. + """ + assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]" + x = coords[:, permute[0]].int() + y = coords[:, permute[1]].int() + z = coords[:, permute[2]].int() + if mode == 'z_order': + return _C.z_order_encode(x, y, z) + elif mode == 'hilbert': + return _C.hilbert_encode(x, y, z) + else: + raise ValueError(f"Unknown encoding mode: {mode}") + + +@torch.no_grad() +def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Decodes a 30-bit code into 3D coordinates. + + Args: + code: a tensor of shape [N] containing the 30-bit code. + permute: the permutation of the coordinates. + mode: the decoding mode to use. + """ + assert code.ndim == 1, "Input code must be of shape [N]" + if mode == 'z_order': + coords = _C.z_order_decode(code) + elif mode == 'hilbert': + coords = _C.hilbert_decode(code) + else: + raise ValueError(f"Unknown decoding mode: {mode}") + x = coords[permute.index(0)] + y = coords[permute.index(1)] + z = coords[permute.index(2)] + return torch.stack([x, y, z], dim=-1) diff --git a/extensions/vox2seq/vox2seq/pytorch/__init__.py b/extensions/vox2seq/vox2seq/pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25c74c42feb802b24eee9a1bc8744040468c927d --- /dev/null +++ b/extensions/vox2seq/vox2seq/pytorch/__init__.py @@ -0,0 +1,48 @@ +import torch +from typing import * + +from .default import ( + encode, + decode, + z_order_encode, + z_order_decode, + hilbert_encode, + hilbert_decode, +) + + +@torch.no_grad() +def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Encodes 3D coordinates into a 30-bit code. + + Args: + coords: a tensor of shape [N, 3] containing the 3D coordinates. + permute: the permutation of the coordinates. + mode: the encoding mode to use. + """ + if mode == 'z_order': + return z_order_encode(coords[:, permute], depth=10).int() + elif mode == 'hilbert': + return hilbert_encode(coords[:, permute], depth=10).int() + else: + raise ValueError(f"Unknown encoding mode: {mode}") + + +@torch.no_grad() +def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Decodes a 30-bit code into 3D coordinates. + + Args: + code: a tensor of shape [N] containing the 30-bit code. + permute: the permutation of the coordinates. + mode: the decoding mode to use. 
+ """ + if mode == 'z_order': + return z_order_decode(code, depth=10)[:, permute].float() + elif mode == 'hilbert': + return hilbert_decode(code, depth=10)[:, permute].float() + else: + raise ValueError(f"Unknown decoding mode: {mode}") + \ No newline at end of file diff --git a/extensions/vox2seq/vox2seq/pytorch/default.py b/extensions/vox2seq/vox2seq/pytorch/default.py new file mode 100644 index 0000000000000000000000000000000000000000..906f9bfbe80fcf71977ca774b6491ff63a1ee43b --- /dev/null +++ b/extensions/vox2seq/vox2seq/pytorch/default.py @@ -0,0 +1,59 @@ +import torch +from .z_order import xyz2key as z_order_encode_ +from .z_order import key2xyz as z_order_decode_ +from .hilbert import encode as hilbert_encode_ +from .hilbert import decode as hilbert_decode_ + + +@torch.inference_mode() +def encode(grid_coord, batch=None, depth=16, order="z"): + assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} + if order == "z": + code = z_order_encode(grid_coord, depth=depth) + elif order == "z-trans": + code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) + elif order == "hilbert": + code = hilbert_encode(grid_coord, depth=depth) + elif order == "hilbert-trans": + code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) + else: + raise NotImplementedError + if batch is not None: + batch = batch.long() + code = batch << depth * 3 | code + return code + + +@torch.inference_mode() +def decode(code, depth=16, order="z"): + assert order in {"z", "hilbert"} + batch = code >> depth * 3 + code = code & ((1 << depth * 3) - 1) + if order == "z": + grid_coord = z_order_decode(code, depth=depth) + elif order == "hilbert": + grid_coord = hilbert_decode(code, depth=depth) + else: + raise NotImplementedError + return grid_coord, batch + + +def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): + x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() + # we block the support to batch, maintain batched code in Point class + code = z_order_encode_(x, y, z, b=None, depth=depth) + return code + + +def z_order_decode(code: torch.Tensor, depth): + x, y, z, _ = z_order_decode_(code, depth=depth) + grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) + return grid_coord + + +def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): + return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) + + +def hilbert_decode(code: torch.Tensor, depth: int = 16): + return hilbert_decode_(code, num_dims=3, num_bits=depth) \ No newline at end of file diff --git a/extensions/vox2seq/vox2seq/pytorch/hilbert.py b/extensions/vox2seq/vox2seq/pytorch/hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..c3fb6565ff855c50553d6215eb74407f88b43a01 --- /dev/null +++ b/extensions/vox2seq/vox2seq/pytorch/hilbert.py @@ -0,0 +1,303 @@ +""" +Hilbert Order +Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu +Please cite our work if the code is helpful to you. +""" + +import torch + + +def right_shift(binary, k=1, axis=-1): + """Right shift an array of binary values. + + Parameters: + ----------- + binary: An ndarray of binary values. + + k: The number of bits to shift. Default 1. + + axis: The axis along which to shift. Default -1. + + Returns: + -------- + Returns an ndarray with zero prepended and the ends truncated, along + whatever axis was specified.""" + + # If we're shifting the whole thing, just return zeros. 
+ if binary.shape[axis] <= k: + return torch.zeros_like(binary) + + # Determine the padding pattern. + # padding = [(0,0)] * len(binary.shape) + # padding[axis] = (k,0) + + # Determine the slicing pattern to eliminate just the last one. + slicing = [slice(None)] * len(binary.shape) + slicing[axis] = slice(None, -k) + shifted = torch.nn.functional.pad( + binary[tuple(slicing)], (k, 0), mode="constant", value=0 + ) + + return shifted + + +def binary2gray(binary, axis=-1): + """Convert an array of binary values into Gray codes. + + This uses the classic X ^ (X >> 1) trick to compute the Gray code. + + Parameters: + ----------- + binary: An ndarray of binary values. + + axis: The axis along which to compute the gray code. Default=-1. + + Returns: + -------- + Returns an ndarray of Gray codes. + """ + shifted = right_shift(binary, axis=axis) + + # Do the X ^ (X >> 1) trick. + gray = torch.logical_xor(binary, shifted) + + return gray + + +def gray2binary(gray, axis=-1): + """Convert an array of Gray codes back into binary values. + + Parameters: + ----------- + gray: An ndarray of gray codes. + + axis: The axis along which to perform Gray decoding. Default=-1. + + Returns: + -------- + Returns an ndarray of binary values. + """ + + # Loop the log2(bits) number of times necessary, with shift and xor. + shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1) + while shift > 0: + gray = torch.logical_xor(gray, right_shift(gray, shift)) + shift = torch.div(shift, 2, rounding_mode="floor") + return gray + + +def encode(locs, num_dims, num_bits): + """Decode an array of locations in a hypercube into a Hilbert integer. + + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + + Params: + ------- + locs - An ndarray of locations in a hypercube of num_dims dimensions, in + which each dimension runs from 0 to 2**num_bits-1. The shape can + be arbitrary, as long as the last dimension of the same has size + num_dims. + + num_dims - The dimensionality of the hypercube. Integer. + + num_bits - The number of bits for each dimension. Integer. + + Returns: + -------- + The output is an ndarray of uint64 integers with the same shape as the + input, excluding the last dimension, which needs to be num_dims. + """ + + # Keep around the original shape for later. + orig_shape = locs.shape + bitpack_mask = 1 << torch.arange(0, 8).to(locs.device) + bitpack_mask_rev = bitpack_mask.flip(-1) + + if orig_shape[-1] != num_dims: + raise ValueError( + """ + The shape of locs was surprising in that the last dimension was of size + %d, but num_dims=%d. These need to be equal. + """ + % (orig_shape[-1], num_dims) + ) + + if num_dims * num_bits > 63: + raise ValueError( + """ + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a int64. Are you sure you need that many points on your Hilbert + curve? + """ + % (num_dims, num_bits, num_dims * num_bits) + ) + + # Treat the location integers as 64-bit unsigned and then split them up into + # a sequence of uint8s. Preserve the association by dimension. + locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) + + # Now turn these into bits and truncate to num_bits. 
+ gray = ( + locs_uint8.unsqueeze(-1) + .bitwise_and(bitpack_mask_rev) + .ne(0) + .byte() + .flatten(-2, -1)[..., -num_bits:] + ) + + # Run the decoding process the other way. + # Iterate forwards through the bits. + for bit in range(0, num_bits): + # Iterate forwards through the dimensions. + for dim in range(0, num_dims): + # Identify which ones have this bit active. + mask = gray[:, dim, bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:, 0, bit + 1 :] = torch.logical_xor( + gray[:, 0, bit + 1 :], mask[:, None] + ) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = torch.logical_and( + torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1), + torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), + ) + gray[:, dim, bit + 1 :] = torch.logical_xor( + gray[:, dim, bit + 1 :], to_flip + ) + gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) + + # Now flatten out. + gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims)) + + # Convert Gray back to binary. + hh_bin = gray2binary(gray) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits * num_dims + padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0) + + # Convert binary values into uint8s. + hh_uint8 = ( + (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask) + .sum(2) + .squeeze() + .type(torch.uint8) + ) + + # Convert uint8s into uint64s. + hh_uint64 = hh_uint8.view(torch.int64).squeeze() + + return hh_uint64 + + +def decode(hilberts, num_dims, num_bits): + """Decode an array of Hilbert integers into locations in a hypercube. + + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + + Params: + ------- + hilberts - An ndarray of Hilbert integers. Must be an integer dtype and + cannot have fewer bits than num_dims * num_bits. + + num_dims - The dimensionality of the hypercube. Integer. + + num_bits - The number of bits for each dimension. Integer. + + Returns: + -------- + The output is an ndarray of unsigned integers with the same shape as hilberts + but with an additional dimension of size num_dims. + """ + + if num_dims * num_bits > 64: + raise ValueError( + """ + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a uint64. Are you sure you need that many points on your Hilbert + curve? + """ + % (num_dims, num_bits) + ) + + # Handle the case where we got handed a naked integer. + hilberts = torch.atleast_1d(hilberts) + + # Keep around the shape for later. + orig_shape = hilberts.shape + bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device) + bitpack_mask_rev = bitpack_mask.flip(-1) + + # Treat each of the hilberts as a s equence of eight uint8. + # This treats all of the inputs as uint64 and makes things uniform. + hh_uint8 = ( + hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1) + ) + + # Turn these lists of uints into lists of bits and then truncate to the size + # we actually need for using Skilling's procedure. + hh_bits = ( + hh_uint8.unsqueeze(-1) + .bitwise_and(bitpack_mask_rev) + .ne(0) + .byte() + .flatten(-2, -1)[:, -num_dims * num_bits :] + ) + + # Take the sequence of bits and Gray-code it. + gray = binary2gray(hh_bits) + + # There has got to be a better way to do this. 
+ # I could index them differently, but the eventual packbits likes it this way. + gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2) + + # Iterate backwards through the bits. + for bit in range(num_bits - 1, -1, -1): + # Iterate backwards through the dimensions. + for dim in range(num_dims - 1, -1, -1): + # Identify which ones have this bit active. + mask = gray[:, dim, bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:, 0, bit + 1 :] = torch.logical_xor( + gray[:, 0, bit + 1 :], mask[:, None] + ) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = torch.logical_and( + torch.logical_not(mask[:, None]), + torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), + ) + gray[:, dim, bit + 1 :] = torch.logical_xor( + gray[:, dim, bit + 1 :], to_flip + ) + gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits + padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0) + + # Now chop these up into blocks of 8. + locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8)) + + # Take those blocks and turn them unto uint8s. + # from IPython import embed; embed() + locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8) + + # Finally, treat these as uint64s. + flat_locs = locs_uint8.view(torch.int64) + + # Return them in the expected shape. + return flat_locs.reshape((*orig_shape, num_dims)) \ No newline at end of file diff --git a/extensions/vox2seq/vox2seq/pytorch/z_order.py b/extensions/vox2seq/vox2seq/pytorch/z_order.py new file mode 100644 index 0000000000000000000000000000000000000000..b33963de44ddfcbecdf2e403ea70166f2d0e4eb0 --- /dev/null +++ b/extensions/vox2seq/vox2seq/pytorch/z_order.py @@ -0,0 +1,126 @@ +# -------------------------------------------------------- +# Octree-based Sparse Convolutional Neural Networks +# Copyright (c) 2022 Peng-Shuai Wang +# Licensed under The MIT License [see LICENSE for details] +# Written by Peng-Shuai Wang +# -------------------------------------------------------- + +import torch +from typing import Optional, Union + + +class KeyLUT: + def __init__(self): + r256 = torch.arange(256, dtype=torch.int64) + r512 = torch.arange(512, dtype=torch.int64) + zero = torch.zeros(256, dtype=torch.int64) + device = torch.device("cpu") + + self._encode = { + device: ( + self.xyz2key(r256, zero, zero, 8), + self.xyz2key(zero, r256, zero, 8), + self.xyz2key(zero, zero, r256, 8), + ) + } + self._decode = {device: self.key2xyz(r512, 9)} + + def encode_lut(self, device=torch.device("cpu")): + if device not in self._encode: + cpu = torch.device("cpu") + self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) + return self._encode[device] + + def decode_lut(self, device=torch.device("cpu")): + if device not in self._decode: + cpu = torch.device("cpu") + self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) + return self._decode[device] + + def xyz2key(self, x, y, z, depth): + key = torch.zeros_like(x) + for i in range(depth): + mask = 1 << i + key = ( + key + | ((x & mask) << (2 * i + 2)) + | ((y & mask) << (2 * i + 1)) + | ((z & mask) << (2 * i + 0)) + ) + return key + + def key2xyz(self, key, depth): + x = torch.zeros_like(key) + y = torch.zeros_like(key) + z = torch.zeros_like(key) + for i in range(depth): + x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) + y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) + z = z | ((key & (1 << (3 * i + 
0))) >> (2 * i + 0)) + return x, y, z + + +_key_lut = KeyLUT() + + +def xyz2key( + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + b: Optional[Union[torch.Tensor, int]] = None, + depth: int = 16, +): + r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys + based on pre-computed look up tables. The speed of this function is much + faster than the method based on for-loop. + + Args: + x (torch.Tensor): The x coordinate. + y (torch.Tensor): The y coordinate. + z (torch.Tensor): The z coordinate. + b (torch.Tensor or int): The batch index of the coordinates, and should be + smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of + :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). + """ + + EX, EY, EZ = _key_lut.encode_lut(x.device) + x, y, z = x.long(), y.long(), z.long() + + mask = 255 if depth > 8 else (1 << depth) - 1 + key = EX[x & mask] | EY[y & mask] | EZ[z & mask] + if depth > 8: + mask = (1 << (depth - 8)) - 1 + key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] + key = key16 << 24 | key + + if b is not None: + b = b.long() + key = b << 48 | key + + return key + + +def key2xyz(key: torch.Tensor, depth: int = 16): + r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates + and the batch index based on pre-computed look up tables. + + Args: + key (torch.Tensor): The shuffled key. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). + """ + + DX, DY, DZ = _key_lut.decode_lut(key.device) + x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) + + b = key >> 48 + key = key & ((1 << 48) - 1) + + n = (depth + 2) // 3 + for i in range(n): + k = key >> (i * 9) & 511 + x = x | (DX[k] << (i * 3)) + y = y | (DY[k] << (i * 3)) + z = z | (DZ[k] << (i * 3)) + + return x, y, z, b \ No newline at end of file
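The torch Hilbert implementation above is built around Gray codes: `binary2gray` computes B xor (B >> 1) bitwise over tensors of bits, and `gray2binary` undoes it with a logarithmic number of shifted XORs. A scalar illustration of the same identities, for intuition only:

def binary2gray(b: int) -> int:
    return b ^ (b >> 1)

def gray2binary(g: int, bits: int = 16) -> int:
    # XOR-ing in all higher-order shifts undoes the encoding; the tensor
    # version performs the same reduction with log2(bits) doubling shifts.
    b, shift = g, 1
    while shift < bits:
        b ^= b >> shift
        shift <<= 1
    return b

# Consecutive integers differ in exactly one bit after Gray coding:
codes = [binary2gray(i) for i in range(8)]   # [0, 1, 3, 2, 6, 7, 5, 4]
assert all(gray2binary(c) == i for i, c in enumerate(codes))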
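Taken together, the extension exposes `vox2seq.encode` / `vox2seq.decode` (CUDA) with a pure-PyTorch fallback under `vox2seq.pytorch`, turning [N, 3] integer voxel coordinates into Morton or Hilbert keys; the keys are presumably used to order sparse voxels into a locality-preserving sequence. A short usage sketch, assuming the extension has been built from `extensions/vox2seq/setup.py` and the coordinates lie in [0, 1023]:

import torch
import vox2seq  # requires the compiled _C extension

# 4096 random occupied voxels in a 64^3 grid.
coords = torch.randint(0, 64, (4096, 3), dtype=torch.int32, device="cuda")

# Serialize along a Hilbert curve and reorder so that neighbours in the
# sequence tend to be neighbours in space.
codes = vox2seq.encode(coords, mode="hilbert")
coords_sorted = coords[torch.argsort(codes)]

# The z-order path round-trips exactly for coordinates below 1024.
z_codes = vox2seq.encode(coords, mode="z_order")
assert torch.equal(vox2seq.decode(z_codes, mode="z_order"), coords)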