diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..2cec29e71fd5f2242a7ab04e9b24f5beab6ebd3d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,67 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/T.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_building_building.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_building_castle.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_building_colorful_cottage.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_building_maya_pyramid.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_building_mushroom.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_building_space_station.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_creature_dragon.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_creature_elephant.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_creature_furry.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_creature_quadruped.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_creature_robot_crab.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_creature_robot_dinosour.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_creature_rock_monster.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_humanoid_block_robot.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_humanoid_dragonborn.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_humanoid_dwarf.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_humanoid_goblin.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_humanoid_mech.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_crate.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_fireplace.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_gate.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_lantern.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_magicbook.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_mailbox.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_monster_chest.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_paper_machine.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_phonograph.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_portal2.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_storage_chest.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_telephone.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_television.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_misc_workbench.png filter=lfs diff=lfs merge=lfs -text 
+TRELLIS/assets/example_image/typical_vehicle_biplane.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_vehicle_bulldozer.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_vehicle_cart.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_vehicle_excavator.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_vehicle_helicopter.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_vehicle_locomotive.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/typical_vehicle_pirate_ship.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_image/weatherworn_misc_paper_machine3.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/character_1.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/character_2.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/character_3.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/mushroom_1.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/mushroom_2.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/mushroom_3.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/orangeguy_1.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/orangeguy_2.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/orangeguy_3.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/popmart_1.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/popmart_2.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/popmart_3.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/rabbit_1.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/rabbit_2.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/rabbit_3.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/tiger_1.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/tiger_2.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/tiger_3.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/yoimiya_1.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/yoimiya_2.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/example_multi_image/yoimiya_3.png filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/logo.webp filter=lfs diff=lfs merge=lfs -text +TRELLIS/assets/teaser.png filter=lfs diff=lfs merge=lfs -text diff --git a/TRELLIS/.gitignore b/TRELLIS/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f83ffb815447f086e49e3fe602ab2801b3096d3b --- /dev/null +++ b/TRELLIS/.gitignore @@ -0,0 +1,398 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. 
+## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. 
+!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) 
+*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files +*.ncb +*.aps + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml diff --git a/TRELLIS/.gitmodules b/TRELLIS/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..8647a964590832cb5ec99636d8eed7757dbd79c0 --- /dev/null +++ b/TRELLIS/.gitmodules @@ -0,0 +1,3 @@ +[submodule "trellis/representations/mesh/flexicubes"] + path = trellis/representations/mesh/flexicubes + url = https://github.com/MaxtirError/FlexiCubes.git diff --git a/TRELLIS/CODE_OF_CONDUCT.md b/TRELLIS/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..c72a5749c52ac97bca71c672ef5295d303d22b05 --- /dev/null +++ b/TRELLIS/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
+ +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/TRELLIS/DATASET.md b/TRELLIS/DATASET.md new file mode 100644 index 0000000000000000000000000000000000000000..e140262b16b78f463a7e22eea65dc169721da2fb --- /dev/null +++ b/TRELLIS/DATASET.md @@ -0,0 +1,231 @@ +# TRELLIS-500K + +TRELLIS-500K is a dataset of 500K 3D assets curated from [Objaverse(XL)](https://objaverse.allenai.org/), [ABO](https://amazon-berkeley-objects.s3.amazonaws.com/index.html), [3D-FUTURE](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future), [HSSD](https://huggingface.co/datasets/hssd/hssd-models), and [Toys4k](https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k), filtered based on aesthetic scores. +This dataset is intended for 3D generation tasks. + +The dataset is provided as csv files containing the 3D assets' metadata. + +## Dataset Statistics + +The following table summarizes the dataset's filtering and composition: + +***NOTE: Some of the 3D assets lack text captions. Please filter out such assets if captions are required.*** +| Source | Aesthetic Score Threshold | Filtered Size | With Captions | +|:-:|:-:|:-:|:-:| +| ObjaverseXL (sketchfab) | 5.5 | 168307 | 167638 | +| ObjaverseXL (github) | 5.5 | 311843 | 306790 | +| ABO | 4.5 | 4485 | 4390 | +| 3D-FUTURE | 4.5 | 9472 | 9291 | +| HSSD | 4.5 | 6670 | 6661 | +| All (training set) | - | 500777 | 494770 | +| Toys4k (evaluation set) | 4.5 | 3229 | 3180 | + +## Dataset Location + +The dataset is hosted on Hugging Face Datasets. You can preview the dataset at + +[https://huggingface.co/datasets/JeffreyXiang/TRELLIS-500K](https://huggingface.co/datasets/JeffreyXiang/TRELLIS-500K) + +There is no need to download the csv files manually. We provide toolkits to load and prepare the dataset. + +## Dataset Toolkits + +We provide [toolkits](dataset_toolkits) for data preparation. + +### Step 1: Install Dependencies + +``` +. ./dataset_toolkits/setup.sh +``` + +### Step 2: Load Metadata + +First, we need to load the metadata of the dataset. + +``` +python dataset_toolkits/build_metadata.py <SUBSET> --output_dir <OUTPUT_DIR> [--source <SOURCE>] +``` + +- `SUBSET`: The subset of the dataset to load. Options are `ObjaverseXL`, `ABO`, `3D-FUTURE`, `HSSD`, and `Toys4k`. +- `OUTPUT_DIR`: The directory to save the data. +- `SOURCE`: Required if `SUBSET` is `ObjaverseXL`. Options are `sketchfab` and `github`. + +For example, to load the metadata of the ObjaverseXL (sketchfab) subset and save it to `datasets/ObjaverseXL_sketchfab`, we can run: + +``` +python dataset_toolkits/build_metadata.py ObjaverseXL --source sketchfab --output_dir datasets/ObjaverseXL_sketchfab +``` + +### Step 3: Download Data + +Next, we need to download the 3D assets. + +``` +python dataset_toolkits/download.py <SUBSET> --output_dir <OUTPUT_DIR> [--rank <RANK> --world_size <WORLD_SIZE>] +``` + +- `SUBSET`: The subset of the dataset to download. Options are `ObjaverseXL`, `ABO`, `3D-FUTURE`, `HSSD`, and `Toys4k`. +- `OUTPUT_DIR`: The directory to save the data. + +You can also specify the `RANK` and `WORLD_SIZE` of the current process if you are using multiple nodes for data preparation.
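A minimal sketch of how such a `RANK`/`WORLD_SIZE` split is commonly implemented is shown below. This is only an illustration of the idea, not the actual logic of `dataset_toolkits/download.py`; the `sha256` column name and the interleaved assignment are assumptions.

```python
import pandas as pd

# Illustration only: split a metadata table across WORLD_SIZE workers so that
# each node downloads a disjoint subset of assets. The real toolkit may shard
# differently (e.g., contiguous chunks) and may use different column names.
def shard_metadata(metadata: pd.DataFrame, rank: int, world_size: int) -> pd.DataFrame:
    """Return the rows assigned to worker `rank` out of `world_size` workers."""
    return metadata.iloc[rank::world_size]

if __name__ == "__main__":
    metadata = pd.DataFrame({"sha256": [f"asset_{i:03d}" for i in range(10)]})
    for rank in range(3):
        shard = shard_metadata(metadata, rank, world_size=3)
        print(f"rank {rank}: {shard['sha256'].tolist()}")
```

With an interleaved split like this, each worker receives roughly `len(metadata) / WORLD_SIZE` rows, which is also why the demonstration command below, run with a very large `--world_size`, downloads only a small portion of the dataset.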
+ +For example, to download the ObjaverseXL (sketchfab) subset and save it to `datasets/ObjaverseXL_sketchfab`, we can run: + +***NOTE: The example command below sets a large `WORLD_SIZE` for demonstration purposes. Only a small portion of the dataset will be downloaded.*** + +``` +python dataset_toolkits/download.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab --world_size 160000 +``` + +Some datasets may require interactive login to Hugging Face or manual downloading. Please follow the instructions given by the toolkits. + +After downloading, update the metadata file with: + +``` +python dataset_toolkits/build_metadata.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +### Step 4: Render Multiview Images + +Multiview images can be rendered with: + +``` +python dataset_toolkits/render.py <SUBSET> --output_dir <OUTPUT_DIR> [--num_views <NUM_VIEWS>] [--rank <RANK> --world_size <WORLD_SIZE>] +``` + +- `SUBSET`: The subset of the dataset to render. Options are `ObjaverseXL`, `ABO`, `3D-FUTURE`, `HSSD`, and `Toys4k`. +- `OUTPUT_DIR`: The directory to save the data. +- `NUM_VIEWS`: The number of views to render. Default is 150. +- `RANK` and `WORLD_SIZE`: Multi-node configuration. + +For example, to render the ObjaverseXL (sketchfab) subset and save it to `datasets/ObjaverseXL_sketchfab`, we can run: + +``` +python dataset_toolkits/render.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +Don't forget to update the metadata file with: + +``` +python dataset_toolkits/build_metadata.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +### Step 5: Voxelize 3D Models + +We can voxelize the 3D models with: + +``` +python dataset_toolkits/voxelize.py <SUBSET> --output_dir <OUTPUT_DIR> [--rank <RANK> --world_size <WORLD_SIZE>] +``` + +- `SUBSET`: The subset of the dataset to voxelize. Options are `ObjaverseXL`, `ABO`, `3D-FUTURE`, `HSSD`, and `Toys4k`. +- `OUTPUT_DIR`: The directory to save the data. +- `RANK` and `WORLD_SIZE`: Multi-node configuration. + +For example, to voxelize the ObjaverseXL (sketchfab) subset and save it to `datasets/ObjaverseXL_sketchfab`, we can run: +``` +python dataset_toolkits/voxelize.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +Then update the metadata file with: + +``` +python dataset_toolkits/build_metadata.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +### Step 6: Extract DINO Features + +To prepare the training data for the SLat VAE, we need to extract DINO features from multiview images and aggregate them into sparse voxel grids. + +``` +python dataset_toolkits/extract_features.py --output_dir <OUTPUT_DIR> [--rank <RANK> --world_size <WORLD_SIZE>] +``` + +- `OUTPUT_DIR`: The directory to save the data. +- `RANK` and `WORLD_SIZE`: Multi-node configuration. + +For example, to extract DINO features from the ObjaverseXL (sketchfab) subset and save it to `datasets/ObjaverseXL_sketchfab`, we can run: + +``` +python dataset_toolkits/extract_feature.py --output_dir datasets/ObjaverseXL_sketchfab +``` + +Then update the metadata file with: + +``` +python dataset_toolkits/build_metadata.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +### Step 7: Encode Sparse Structures + +Encode the sparse structures into latents to train the first-stage generator: + +``` +python dataset_toolkits/encode_ss_latent.py --output_dir <OUTPUT_DIR> [--rank <RANK> --world_size <WORLD_SIZE>] +``` + +- `OUTPUT_DIR`: The directory to save the data. +- `RANK` and `WORLD_SIZE`: Multi-node configuration.
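For intuition, the "sparse structure" being encoded here can be thought of as the set of occupied cells of a voxel occupancy grid (a 64³ resolution is assumed below purely for illustration). The helpers in this sketch are not part of `dataset_toolkits`; they only show the dense-to-sparse conversion that such an encoding operates on, while the learned latent encoding itself is handled by the script above.

```python
import numpy as np

# Illustration only: convert between a dense boolean occupancy grid and the
# sparse list of occupied voxel coordinates. The repo's encoder maps such
# structures into latents with a trained model, which is not shown here.
def dense_to_sparse(occupancy: np.ndarray) -> np.ndarray:
    """Return an (N, 3) array of occupied voxel indices from a dense boolean grid."""
    return np.argwhere(occupancy)

def sparse_to_dense(coords: np.ndarray, resolution: int = 64) -> np.ndarray:
    """Rebuild a dense boolean occupancy grid from (N, 3) voxel indices."""
    grid = np.zeros((resolution,) * 3, dtype=bool)
    grid[coords[:, 0], coords[:, 1], coords[:, 2]] = True
    return grid

if __name__ == "__main__":
    occ = np.random.rand(64, 64, 64) > 0.99   # a random, mostly-empty occupancy grid
    coords = dense_to_sparse(occ)
    assert np.array_equal(sparse_to_dense(coords), occ)
    print(f"{coords.shape[0]} occupied voxels out of {occ.size}")
```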
+ +For example, to encode the sparse structures into latents for the ObjaverseXL (sketchfab) subset and save it to `datasets/ObjaverseXL_sketchfab`, we can run: + +``` +python dataset_toolkits/encode_ss_latent.py --output_dir datasets/ObjaverseXL_sketchfab +``` + +Then update the metadata file with: + +``` +python dataset_toolkits/build_metadata.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +### Step 8: Encode SLat + +Encode SLat for second-stage generator training: + +``` +python dataset_toolkits/encode_latent.py --output_dir <OUTPUT_DIR> [--rank <RANK> --world_size <WORLD_SIZE>] +``` + +- `OUTPUT_DIR`: The directory to save the data. +- `RANK` and `WORLD_SIZE`: Multi-node configuration. + +For example, to encode SLat for the ObjaverseXL (sketchfab) subset and save it to `datasets/ObjaverseXL_sketchfab`, we can run: + +``` +python dataset_toolkits/encode_latent.py --output_dir datasets/ObjaverseXL_sketchfab +``` + +Then update the metadata file with: + +``` +python dataset_toolkits/build_metadata.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +### Step 9: Render Image Conditions + +To train the image-conditioned generator, we need to render image conditions with augmented views. + +``` +python dataset_toolkits/render_cond.py <SUBSET> --output_dir <OUTPUT_DIR> [--num_views <NUM_VIEWS>] [--rank <RANK> --world_size <WORLD_SIZE>] +``` + +- `SUBSET`: The subset of the dataset to render. Options are `ObjaverseXL`, `ABO`, `3D-FUTURE`, `HSSD`, and `Toys4k`. +- `OUTPUT_DIR`: The directory to save the data. +- `NUM_VIEWS`: The number of views to render. Default is 24. +- `RANK` and `WORLD_SIZE`: Multi-node configuration. + +For example, to render image conditions for the ObjaverseXL (sketchfab) subset and save it to `datasets/ObjaverseXL_sketchfab`, we can run: + +``` +python dataset_toolkits/render_cond.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + +Then update the metadata file with: + +``` +python dataset_toolkits/build_metadata.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab +``` + + diff --git a/TRELLIS/DORA.png b/TRELLIS/DORA.png new file mode 100644 index 0000000000000000000000000000000000000000..c8eee9cae48df40018b319b5c71ede8ffbb2ffe7 Binary files /dev/null and b/TRELLIS/DORA.png differ diff --git a/TRELLIS/LICENSE b/TRELLIS/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3d8b93bc7987d14c848448c089e2ae15311380d7 --- /dev/null +++ b/TRELLIS/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/TRELLIS/README.md b/TRELLIS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..11ed3474ab2453cfc55ebd70447fe808d511d899 --- /dev/null +++ b/TRELLIS/README.md @@ -0,0 +1,230 @@ + +

+# Structured 3D Latents for Scalable and Versatile 3D Generation + +[arXiv](https://arxiv.org/abs/2412.01506) · [Project Page](https://trellis3d.github.io) +
+ +TRELLIS is a large 3D asset generation model. It takes in text or image prompts and generates high-quality 3D assets in various formats, such as Radiance Fields, 3D Gaussians, and meshes. The cornerstone of TRELLIS is a unified Structured LATent (SLAT) representation that allows decoding to different output formats, with Rectified Flow Transformers tailored for SLAT as the powerful backbones. We provide large-scale pre-trained models with up to 2 billion parameters, trained on a large 3D asset dataset of 500K diverse objects. TRELLIS significantly surpasses existing methods, including recent ones at similar scales, and showcases flexible output format selection and local 3D editing capabilities which were not offered by previous models. + +***Check out our [Project Page](https://trellis3d.github.io) for more videos and interactive demos!*** + +
## 🌟 Features +- **High Quality**: It produces diverse 3D assets at high quality with intricate shape and texture details. +- **Versatility**: It takes text or image prompts and can generate various final 3D representations including but not limited to *Radiance Fields*, *3D Gaussians*, and *meshes*, accommodating diverse downstream requirements. +- **Flexible Editing**: It allows for easy editing of generated 3D assets, such as generating variants of the same object or local editing of the 3D asset. + +
## ⏩ Updates + +**12/26/2024** +- Release [**TRELLIS-500K**](https://github.com/microsoft/TRELLIS#-dataset) dataset and toolkits for data preparation. + +**12/18/2024** +- Implementation of multi-image conditioning for the TRELLIS-image model. ([#7](https://github.com/microsoft/TRELLIS/issues/7)). This is based on a tuning-free algorithm without training a specialized model, so it may not give the best results for all input images. +- Add Gaussian export in `app.py` and `example.py`. ([#40](https://github.com/microsoft/TRELLIS/issues/40)) + +
## 🚧 TODO List +- [x] Release inference code and TRELLIS-image-large model +- [x] Release dataset and dataset toolkits +- [ ] Release TRELLIS-text model series +- [ ] Release training code + +
## 📦 Installation + +### Prerequisites +- **System**: The code is currently tested only on **Linux**. For Windows setup, you may refer to [#3](https://github.com/microsoft/TRELLIS/issues/3) (not fully tested). +- **Hardware**: An NVIDIA GPU with at least 16GB of memory is necessary. The code has been verified on NVIDIA A100 and A6000 GPUs. +- **Software**: + - The [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) is needed to compile certain submodules. The code has been tested with CUDA versions 11.8 and 12.2. + - [Conda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) is recommended for managing dependencies. + - Python version 3.8 or higher is required. + +### Installation Steps +1. Clone the repo: + ```sh + git clone --recurse-submodules https://github.com/microsoft/TRELLIS.git + cd TRELLIS + ``` + +2. Install the dependencies: + + **Before running the following command, there are some things to note:** + - By adding `--new-env`, a new conda environment named `trellis` will be created. If you want to use an existing conda environment, please remove this flag. + - By default, the `trellis` environment will use PyTorch 2.4.0 with CUDA 11.8. If you want to use a different version of CUDA (e.g., if you have CUDA Toolkit 12.2 installed and do not want to install another 11.8 version for submodule compilation), you can remove the `--new-env` flag and manually install the required dependencies.
Refer to [PyTorch](https://pytorch.org/get-started/previous-versions/) for the installation command. + - If you have multiple CUDA Toolkit versions installed, `PATH` should be set to the correct version before running the command. For example, if you have CUDA Toolkit 11.8 and 12.2 installed, you should run `export PATH=/usr/local/cuda-11.8/bin:$PATH` before running the command. + - By default, the code uses the `flash-attn` backend for attention. For GPUs that do not support `flash-attn` (e.g., NVIDIA V100), you can remove the `--flash-attn` flag to install `xformers` only and set the `ATTN_BACKEND` environment variable to `xformers` before running the code. See the [Minimal Example](#minimal-example) for more details. + - The installation may take a while due to the large number of dependencies. Please be patient. If you encounter any issues, you can try to install the dependencies one by one, specifying one flag at a time. + - If you encounter any issues during the installation, feel free to open an issue or contact us. + + Create a new conda environment named `trellis` and install the dependencies: + ```sh + . ./setup.sh --new-env --basic --xformers --flash-attn --diffoctreerast --spconv --mipgaussian --kaolin --nvdiffrast + ``` + The detailed usage of `setup.sh` can be found by running `. ./setup.sh --help`. + ```sh + Usage: setup.sh [OPTIONS] + Options: + -h, --help Display this help message + --new-env Create a new conda environment + --basic Install basic dependencies + --xformers Install xformers + --flash-attn Install flash-attn + --diffoctreerast Install diffoctreerast + --vox2seq Install vox2seq + --spconv Install spconv + --mipgaussian Install mip-splatting + --kaolin Install kaolin + --nvdiffrast Install nvdiffrast + --demo Install all dependencies for demo + ``` + +
## 🤖 Pretrained Models + +We provide the following pretrained models: + +| Model | Description | #Params | Download | +| --- | --- | --- | --- | +| TRELLIS-image-large | Large image-to-3D model | 1.2B | [Download](https://huggingface.co/JeffreyXiang/TRELLIS-image-large) | +| TRELLIS-text-base | Base text-to-3D model | 342M | Coming Soon | +| TRELLIS-text-large | Large text-to-3D model | 1.1B | Coming Soon | +| TRELLIS-text-xlarge | Extra-large text-to-3D model | 2.0B | Coming Soon | + +The models are hosted on Hugging Face. You can directly load the models with their repository names in the code: +```python +TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") +``` + +If you prefer loading the model from a local folder, you can download the model files from the links above and load the model with the folder path (folder structure should be maintained): +```python +TrellisImageTo3DPipeline.from_pretrained("/path/to/TRELLIS-image-large") +``` + +
## 💡 Usage + +### Minimal Example + +Here is an [example](example.py) of how to use the pretrained models for 3D asset generation. + +```python +import os +# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn' +os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'. + # 'auto' is faster but will do benchmarking at the beginning. + # Recommended to set to 'native' if run only once. + +import imageio +from PIL import Image +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.utils import render_utils, postprocessing_utils + +# Load a pipeline from a model folder or a Hugging Face model hub.
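# (As noted in the Pretrained Models section above, a local checkpoint folder also works,
# e.g. TrellisImageTo3DPipeline.from_pretrained("/path/to/TRELLIS-image-large"),
# provided the downloaded folder structure is kept intact.)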
+pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") +pipeline.cuda() + +# Load an image +image = Image.open("assets/example_image/T.png") + +# Run the pipeline +outputs = pipeline.run( + image, + seed=1, + # Optional parameters + # sparse_structure_sampler_params={ + # "steps": 12, + # "cfg_strength": 7.5, + # }, + # slat_sampler_params={ + # "steps": 12, + # "cfg_strength": 3, + # }, +) +# outputs is a dictionary containing generated 3D assets in different formats: +# - outputs['gaussian']: a list of 3D Gaussians +# - outputs['radiance_field']: a list of radiance fields +# - outputs['mesh']: a list of meshes + +# Render the outputs +video = render_utils.render_video(outputs['gaussian'][0])['color'] +imageio.mimsave("sample_gs.mp4", video, fps=30) +video = render_utils.render_video(outputs['radiance_field'][0])['color'] +imageio.mimsave("sample_rf.mp4", video, fps=30) +video = render_utils.render_video(outputs['mesh'][0])['normal'] +imageio.mimsave("sample_mesh.mp4", video, fps=30) + +# GLB files can be extracted from the outputs +glb = postprocessing_utils.to_glb( + outputs['gaussian'][0], + outputs['mesh'][0], + # Optional parameters + simplify=0.95, # Ratio of triangles to remove in the simplification process + texture_size=1024, # Size of the texture used for the GLB +) +glb.export("sample.glb") + +# Save Gaussians as PLY files +outputs['gaussian'][0].save_ply("sample.ply") +``` + +After running the code, you will get the following files: +- `sample_gs.mp4`: a video showing the 3D Gaussian representation +- `sample_rf.mp4`: a video showing the Radiance Field representation +- `sample_mesh.mp4`: a video showing the mesh representation +- `sample.glb`: a GLB file containing the extracted textured mesh +- `sample.ply`: a PLY file containing the 3D Gaussian representation + + +### Web Demo + +[app.py](app.py) provides a simple web demo for 3D asset generation. Since this demo is based on [Gradio](https://gradio.app/), additional dependencies are required: +```sh +. ./setup.sh --demo +``` + +After installing the dependencies, you can run the demo with the following command: +```sh +python app.py +``` + +Then, you can access the demo at the address shown in the terminal. + +***The web demo is also available on [Hugging Face Spaces](https://huggingface.co/spaces/JeffreyXiang/TRELLIS)!*** + + + +## 📚 Dataset + +We provide **TRELLIS-500K**, a large-scale dataset containing 500K 3D assets curated from [Objaverse(XL)](https://objaverse.allenai.org/), [ABO](https://amazon-berkeley-objects.s3.amazonaws.com/index.html), [3D-FUTURE](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future), [HSSD](https://huggingface.co/datasets/hssd/hssd-models), and [Toys4k](https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k), filtered based on aesthetic scores. Please refer to the [dataset README](DATASET.md) for more details. + + +## ⚖️ License + +TRELLIS models and the majority of the code are licensed under the [MIT License](LICENSE). The following submodules may have different licenses: +- [**diffoctreerast**](https://github.com/JeffreyXiang/diffoctreerast): We developed a CUDA-based real-time differentiable octree renderer for rendering radiance fields as part of this project. This renderer is derived from the [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization) project and is available under the [LICENSE](https://github.com/JeffreyXiang/diffoctreerast/blob/master/LICENSE). 
+ + +- [**Modified Flexicubes**](https://github.com/MaxtirError/FlexiCubes): In this project, we used a modified version of [Flexicubes](https://github.com/nv-tlabs/FlexiCubes) to support vertex attributes. This modified version is licensed under the [LICENSE](https://github.com/nv-tlabs/FlexiCubes/blob/main/LICENSE.txt). + + + + + +## 📜 Citation + +If you find this work helpful, please consider citing our paper: + +```bibtex +@article{xiang2024structured, + title = {Structured 3D Latents for Scalable and Versatile 3D Generation}, + author = {Xiang, Jianfeng and Lv, Zelong and Xu, Sicheng and Deng, Yu and Wang, Ruicheng and Zhang, Bowen and Chen, Dong and Tong, Xin and Yang, Jiaolong}, + journal = {arXiv preprint arXiv:2412.01506}, + year = {2024} +} +``` + diff --git a/TRELLIS/SECURITY.md b/TRELLIS/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..6b906d43bc2057e6a832c3236897b2e514d6e1e7 --- /dev/null +++ b/TRELLIS/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. 
+ +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). + + diff --git a/TRELLIS/SUPPORT.md b/TRELLIS/SUPPORT.md new file mode 100644 index 0000000000000000000000000000000000000000..291d4d43733f4c15a81ff598ec1c99fd6c18f64c --- /dev/null +++ b/TRELLIS/SUPPORT.md @@ -0,0 +1,25 @@ +# TODO: The maintainer of this repo has not yet edited this file + +**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? + +- **No CSS support:** Fill out this template with information about how to file issues and get help. +- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. +- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. + +*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* + +# Support + +## How to file issues and get help + +This project uses GitHub Issues to track bugs and feature requests. Please search the existing +issues before filing new issues to avoid duplicates. For new issues, file your bug or +feature request as a new Issue. + +For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE +FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER +CHANNEL. WHERE WILL YOU HELP PEOPLE?**. + +## Microsoft Support Policy + +Support for this **PROJECT or PRODUCT** is limited to the resources listed above. diff --git a/TRELLIS/app.py b/TRELLIS/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a79f28d36bc5d94d2c9e258b4ff2013ab7e1a9b6 --- /dev/null +++ b/TRELLIS/app.py @@ -0,0 +1,403 @@ +import gradio as gr +from gradio_litmodel3d import LitModel3D + +import os +import shutil +from typing import * +import torch +import numpy as np +import imageio +from easydict import EasyDict as edict +from PIL import Image +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.representations import Gaussian, MeshExtractResult +from trellis.utils import render_utils, postprocessing_utils + + +MAX_SEED = np.iinfo(np.int32).max +TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') +os.makedirs(TMP_DIR, exist_ok=True) + + +def start_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + + +def end_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + shutil.rmtree(user_dir) + + +def preprocess_image(image: Image.Image) -> Image.Image: + """ + Preprocess the input image. + + Args: + image (Image.Image): The input image. + + Returns: + Image.Image: The preprocessed image. + """ + processed_image = pipeline.preprocess_image(image) + return processed_image + + +def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]: + """ + Preprocess a list of input images. + + Args: + images (List[Tuple[Image.Image, str]]): The input images. + + Returns: + List[Image.Image]: The preprocessed images. 
+ """ + images = [image[0] for image in images] + processed_images = [pipeline.preprocess_image(image) for image in images] + return processed_images + + +def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict: + return { + 'gaussian': { + **gs.init_params, + '_xyz': gs._xyz.cpu().numpy(), + '_features_dc': gs._features_dc.cpu().numpy(), + '_scaling': gs._scaling.cpu().numpy(), + '_rotation': gs._rotation.cpu().numpy(), + '_opacity': gs._opacity.cpu().numpy(), + }, + 'mesh': { + 'vertices': mesh.vertices.cpu().numpy(), + 'faces': mesh.faces.cpu().numpy(), + }, + } + + +def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]: + gs = Gaussian( + aabb=state['gaussian']['aabb'], + sh_degree=state['gaussian']['sh_degree'], + mininum_kernel_size=state['gaussian']['mininum_kernel_size'], + scaling_bias=state['gaussian']['scaling_bias'], + opacity_bias=state['gaussian']['opacity_bias'], + scaling_activation=state['gaussian']['scaling_activation'], + ) + gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda') + gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda') + gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda') + gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda') + gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda') + + mesh = edict( + vertices=torch.tensor(state['mesh']['vertices'], device='cuda'), + faces=torch.tensor(state['mesh']['faces'], device='cuda'), + ) + + return gs, mesh + + +def get_seed(randomize_seed: bool, seed: int) -> int: + """ + Get the random seed. + """ + return np.random.randint(0, MAX_SEED) if randomize_seed else seed + + +def image_to_3d( + image: Image.Image, + multiimages: List[Tuple[Image.Image, str]], + is_multiimage: bool, + seed: int, + ss_guidance_strength: float, + ss_sampling_steps: int, + slat_guidance_strength: float, + slat_sampling_steps: int, + multiimage_algo: Literal["multidiffusion", "stochastic"], + req: gr.Request, +) -> Tuple[dict, str]: + """ + Convert an image to a 3D model. + + Args: + image (Image.Image): The input image. + multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode. + is_multiimage (bool): Whether is in multi-image mode. + seed (int): The random seed. + ss_guidance_strength (float): The guidance strength for sparse structure generation. + ss_sampling_steps (int): The number of sampling steps for sparse structure generation. + slat_guidance_strength (float): The guidance strength for structured latent generation. + slat_sampling_steps (int): The number of sampling steps for structured latent generation. + multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation. + + Returns: + dict: The information of the generated 3D model. + str: The path to the video of the 3D model. 
+ """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + if not is_multiimage: + outputs = pipeline.run( + image, + seed=seed, + formats=["gaussian", "mesh"], + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "cfg_strength": ss_guidance_strength, + }, + slat_sampler_params={ + "steps": slat_sampling_steps, + "cfg_strength": slat_guidance_strength, + }, + ) + else: + outputs = pipeline.run_multi_image( + [image[0] for image in multiimages], + seed=seed, + formats=["gaussian", "mesh"], + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "cfg_strength": ss_guidance_strength, + }, + slat_sampler_params={ + "steps": slat_sampling_steps, + "cfg_strength": slat_guidance_strength, + }, + mode=multiimage_algo, + ) + video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color'] + video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal'] + video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))] + video_path = os.path.join(user_dir, 'sample.mp4') + imageio.mimsave(video_path, video, fps=15) + state = pack_state(outputs['gaussian'][0], outputs['mesh'][0]) + torch.cuda.empty_cache() + return state, video_path + + +def extract_glb( + state: dict, + mesh_simplify: float, + texture_size: int, + req: gr.Request, +) -> Tuple[str, str]: + """ + Extract a GLB file from the 3D model. + + Args: + state (dict): The state of the generated 3D model. + mesh_simplify (float): The mesh simplification factor. + texture_size (int): The texture resolution. + + Returns: + str: The path to the extracted GLB file. + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + gs, mesh = unpack_state(state) + glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False) + glb_path = os.path.join(user_dir, 'sample.glb') + glb.export(glb_path) + torch.cuda.empty_cache() + return glb_path, glb_path + + +def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]: + """ + Extract a Gaussian file from the 3D model. + + Args: + state (dict): The state of the generated 3D model. + + Returns: + str: The path to the extracted Gaussian file. + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + gs, _ = unpack_state(state) + gaussian_path = os.path.join(user_dir, 'sample.ply') + gs.save_ply(gaussian_path) + torch.cuda.empty_cache() + return gaussian_path, gaussian_path + + +def prepare_multi_example() -> List[Image.Image]: + multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")])) + images = [] + for case in multi_case: + _images = [] + for i in range(1, 4): + img = Image.open(f'assets/example_multi_image/{case}_{i}.png') + W, H = img.size + img = img.resize((int(W / H * 512), 512)) + _images.append(np.array(img)) + images.append(Image.fromarray(np.concatenate(_images, axis=1))) + return images + + +def split_image(image: Image.Image) -> List[Image.Image]: + """ + Split an image into multiple views. 
+ """ + image = np.array(image) + alpha = image[..., 3] + alpha = np.any(alpha>0, axis=0) + start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist() + end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist() + images = [] + for s, e in zip(start_pos, end_pos): + images.append(Image.fromarray(image[:, s:e+1])) + return [preprocess_image(image) for image in images] + + +with gr.Blocks(delete_cache=(600, 600)) as demo: + gr.Markdown(""" + ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/) + * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background. + * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it. + """) + + with gr.Row(): + with gr.Column(): + with gr.Tabs() as input_tabs: + with gr.Tab(label="Single Image", id=0) as single_image_input_tab: + image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300) + with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab: + multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3) + gr.Markdown(""" + Input different views of the object in separate images. + + *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.* + """) + + with gr.Accordion(label="Generation Settings", open=False): + seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) + gr.Markdown("Stage 1: Sparse Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + gr.Markdown("Stage 2: Structured Latent Generation") + with gr.Row(): + slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1) + slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic") + + generate_btn = gr.Button("Generate") + + with gr.Accordion(label="GLB Extraction Settings", open=False): + mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01) + texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512) + + with gr.Row(): + extract_glb_btn = gr.Button("Extract GLB", interactive=False) + extract_gs_btn = gr.Button("Extract Gaussian", interactive=False) + gr.Markdown(""" + *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.* + """) + + with gr.Column(): + video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300) + model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300) + + with gr.Row(): + download_glb = gr.DownloadButton(label="Download GLB", interactive=False) + download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False) + + is_multiimage = gr.State(False) + output_buf = gr.State() + + # Example images at the bottom of the page + with gr.Row() as single_image_example: + examples = gr.Examples( + examples=[ + f'assets/example_image/{image}' + for image in os.listdir("assets/example_image") + ], + inputs=[image_prompt], + fn=preprocess_image, + 
outputs=[image_prompt], + run_on_click=True, + examples_per_page=64, + ) + with gr.Row(visible=False) as multiimage_example: + examples_multi = gr.Examples( + examples=prepare_multi_example(), + inputs=[image_prompt], + fn=split_image, + outputs=[multiimage_prompt], + run_on_click=True, + examples_per_page=8, + ) + + # Handlers + demo.load(start_session) + demo.unload(end_session) + + single_image_input_tab.select( + lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]), + outputs=[is_multiimage, single_image_example, multiimage_example] + ) + multiimage_input_tab.select( + lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]), + outputs=[is_multiimage, single_image_example, multiimage_example] + ) + + image_prompt.upload( + preprocess_image, + inputs=[image_prompt], + outputs=[image_prompt], + ) + multiimage_prompt.upload( + preprocess_images, + inputs=[multiimage_prompt], + outputs=[multiimage_prompt], + ) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).then( + image_to_3d, + inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo], + outputs=[output_buf, video_output], + ).then( + lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]), + outputs=[extract_glb_btn, extract_gs_btn], + ) + + video_output.clear( + lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]), + outputs=[extract_glb_btn, extract_gs_btn], + ) + + extract_glb_btn.click( + extract_glb, + inputs=[output_buf, mesh_simplify, texture_size], + outputs=[model_output, download_glb], + ).then( + lambda: gr.Button(interactive=True), + outputs=[download_glb], + ) + + extract_gs_btn.click( + extract_gaussian, + inputs=[output_buf], + outputs=[model_output, download_gs], + ).then( + lambda: gr.Button(interactive=True), + outputs=[download_gs], + ) + + model_output.clear( + lambda: gr.Button(interactive=False), + outputs=[download_glb], + ) + + +# Launch the Gradio app +if __name__ == "__main__": + pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") + pipeline.cuda() + demo.launch() diff --git a/TRELLIS/assets/example_image/T.png b/TRELLIS/assets/example_image/T.png new file mode 100644 index 0000000000000000000000000000000000000000..861ee434cb123a74d50b843631db47d0646675ed --- /dev/null +++ b/TRELLIS/assets/example_image/T.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e29ddc83a5bd3a05fe9b34732169bc4ea7131f7c36527fdc5f626a90a73076d2 +size 954642 diff --git a/TRELLIS/assets/example_image/typical_building_building.png b/TRELLIS/assets/example_image/typical_building_building.png new file mode 100644 index 0000000000000000000000000000000000000000..515be4e5bb423b92933f1dc11438b570c5f4db95 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_building_building.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8faa11d557be95c000c475247e61a773d511114c7d1e517c04f8d3d88a6049ec +size 546948 diff --git a/TRELLIS/assets/example_image/typical_building_castle.png b/TRELLIS/assets/example_image/typical_building_castle.png new file mode 100644 index 0000000000000000000000000000000000000000..8b4f705e8f3e37eb96cf948eb822b193e47c3bf8 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_building_castle.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:076f0554b087b921863643d2b1ab3e0572a13a347fd66bc29cd9d194034affae +size 426358 diff --git a/TRELLIS/assets/example_image/typical_building_colorful_cottage.png b/TRELLIS/assets/example_image/typical_building_colorful_cottage.png new file mode 100644 index 0000000000000000000000000000000000000000..d9f451150d723f39a402590c520702e4e7fd8e44 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_building_colorful_cottage.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:687305b4e35da759692be0de614d728583a2a9cd2fd3a55593fa753e567d0d47 +size 609383 diff --git a/TRELLIS/assets/example_image/typical_building_maya_pyramid.png b/TRELLIS/assets/example_image/typical_building_maya_pyramid.png new file mode 100644 index 0000000000000000000000000000000000000000..d5db764b32e08a93ac098c06aac9329dba743ea9 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_building_maya_pyramid.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d514f7f4db244ee184af4ddfbc5948d417b4e5bf1c6ee5f5a592679561690df +size 231664 diff --git a/TRELLIS/assets/example_image/typical_building_mushroom.png b/TRELLIS/assets/example_image/typical_building_mushroom.png new file mode 100644 index 0000000000000000000000000000000000000000..fcc169f1f0eef1725af0ba60b429a5d6c550dfa5 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_building_mushroom.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de9b72d3e13e967e70844ddc54643832a84a1b35ca043a11e7c774371d0ccdab +size 488032 diff --git a/TRELLIS/assets/example_image/typical_building_space_station.png b/TRELLIS/assets/example_image/typical_building_space_station.png new file mode 100644 index 0000000000000000000000000000000000000000..bffc6286658add1b747af63e29d42c2735ed37a3 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_building_space_station.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:212c7b4c27ba1e01a7908dbc7f245e7115850eadbc9974aa726327cf35062846 +size 619626 diff --git a/TRELLIS/assets/example_image/typical_creature_dragon.png b/TRELLIS/assets/example_image/typical_creature_dragon.png new file mode 100644 index 0000000000000000000000000000000000000000..d62c7e18b52a7f1ac603c143abfc092b321e2734 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_creature_dragon.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e8d6720dfa1e7b332b76e897e617b7f0863187f30879451b4724f482c84185a +size 563963 diff --git a/TRELLIS/assets/example_image/typical_creature_elephant.png b/TRELLIS/assets/example_image/typical_creature_elephant.png new file mode 100644 index 0000000000000000000000000000000000000000..d4f189675bc888238896eab93f4aefa6e2e32a9b --- /dev/null +++ b/TRELLIS/assets/example_image/typical_creature_elephant.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86a171e37a3d781e7215977f565cd63e813341c1f89e2c586fa61937e4ed6916 +size 481919 diff --git a/TRELLIS/assets/example_image/typical_creature_furry.png b/TRELLIS/assets/example_image/typical_creature_furry.png new file mode 100644 index 0000000000000000000000000000000000000000..58033811f29f4757759eb66a351bbb82228c0f07 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_creature_furry.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b5445b8f1996cf6d72497b2d7564c656f4048e6c1fa626fd7bb3ee582fee671 +size 647568 diff --git a/TRELLIS/assets/example_image/typical_creature_quadruped.png 
b/TRELLIS/assets/example_image/typical_creature_quadruped.png new file mode 100644 index 0000000000000000000000000000000000000000..503ff7d93aed42ef8b8ccaa741559007f4627e7e --- /dev/null +++ b/TRELLIS/assets/example_image/typical_creature_quadruped.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7469f43f58389adec101e9685f60188bd4e7fbede77eef975102f6a8865bc786 +size 685321 diff --git a/TRELLIS/assets/example_image/typical_creature_robot_crab.png b/TRELLIS/assets/example_image/typical_creature_robot_crab.png new file mode 100644 index 0000000000000000000000000000000000000000..1546f322acc31caeb524a518979afbffe8de197f --- /dev/null +++ b/TRELLIS/assets/example_image/typical_creature_robot_crab.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7e716abe8f8895080f562d1dc26b14fa0e20a05aa5beb2770c6fb3b87b3476a +size 594232 diff --git a/TRELLIS/assets/example_image/typical_creature_robot_dinosour.png b/TRELLIS/assets/example_image/typical_creature_robot_dinosour.png new file mode 100644 index 0000000000000000000000000000000000000000..b8a802f1f64424db980dc09e35c7be98e526d9dd --- /dev/null +++ b/TRELLIS/assets/example_image/typical_creature_robot_dinosour.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0986f29557a6fddf9b52b5251a6b6103728c61e201b1cfad1e709b090b72f56 +size 632292 diff --git a/TRELLIS/assets/example_image/typical_creature_rock_monster.png b/TRELLIS/assets/example_image/typical_creature_rock_monster.png new file mode 100644 index 0000000000000000000000000000000000000000..8a6987064fddd7162ff55bd9a98c47c32a6b2397 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_creature_rock_monster.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e29458a6110bee8374c0d4d12471e7167a6c1c98c18f6e2d7ff4f5f0ca3fa01b +size 648344 diff --git a/TRELLIS/assets/example_image/typical_humanoid_block_robot.png b/TRELLIS/assets/example_image/typical_humanoid_block_robot.png new file mode 100644 index 0000000000000000000000000000000000000000..d509f6d257ff3851eae6f12820ec2cf605bec85f --- /dev/null +++ b/TRELLIS/assets/example_image/typical_humanoid_block_robot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a0acbb532668e1bf35f3eef5bcbfdd094c22219ef2d837fa01ccf51cce75ca3 +size 441407 diff --git a/TRELLIS/assets/example_image/typical_humanoid_dragonborn.png b/TRELLIS/assets/example_image/typical_humanoid_dragonborn.png new file mode 100644 index 0000000000000000000000000000000000000000..88f2d9070bb76eb31cd2b9e3c9677f70998e5d3b --- /dev/null +++ b/TRELLIS/assets/example_image/typical_humanoid_dragonborn.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d7c547909a6c12da55dbab1c1c98181ff09e58c9ba943682ca105e71be9548e +size 481371 diff --git a/TRELLIS/assets/example_image/typical_humanoid_dwarf.png b/TRELLIS/assets/example_image/typical_humanoid_dwarf.png new file mode 100644 index 0000000000000000000000000000000000000000..6bcc3945be7038c12c61d72950bab8bcb8475f10 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_humanoid_dwarf.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4a7c157d5d8071128c27594e45a7a03e5113b3333b7f1c5ff1379481e3e0264 +size 498425 diff --git a/TRELLIS/assets/example_image/typical_humanoid_goblin.png b/TRELLIS/assets/example_image/typical_humanoid_goblin.png new file mode 100644 index 0000000000000000000000000000000000000000..6f4a8142d9f452b233b6d5b0bd0d6dac5e41f94e --- /dev/null +++ 
b/TRELLIS/assets/example_image/typical_humanoid_goblin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b0e9a04ae3e7bef44b7180a70306f95374b60727ffa0f6f01fd6c746595cd77 +size 496430 diff --git a/TRELLIS/assets/example_image/typical_humanoid_mech.png b/TRELLIS/assets/example_image/typical_humanoid_mech.png new file mode 100644 index 0000000000000000000000000000000000000000..e0e07443d760b405979253c85bd28747d686976c --- /dev/null +++ b/TRELLIS/assets/example_image/typical_humanoid_mech.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a244ec54b7984e646e54d433de6897657081dd5b9cd5ccd3d865328d813beb49 +size 849947 diff --git a/TRELLIS/assets/example_image/typical_misc_crate.png b/TRELLIS/assets/example_image/typical_misc_crate.png new file mode 100644 index 0000000000000000000000000000000000000000..b66a57feffa7117f1f3c2d614d24001fb2611467 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_crate.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59fd9884301faca93265166d90078e8c31e76c7f93524b1db31975df4b450748 +size 642105 diff --git a/TRELLIS/assets/example_image/typical_misc_fireplace.png b/TRELLIS/assets/example_image/typical_misc_fireplace.png new file mode 100644 index 0000000000000000000000000000000000000000..c8f352fbb6b256d56d43febb052b91943b6d4e3d --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_fireplace.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2288c034603e289192d63cbc73565107caefd99e81c4b7afa2983c8b13e34440 +size 558466 diff --git a/TRELLIS/assets/example_image/typical_misc_gate.png b/TRELLIS/assets/example_image/typical_misc_gate.png new file mode 100644 index 0000000000000000000000000000000000000000..58f96079232d47f7615a491b95d2fc7cd635ef1d --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_gate.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec8db5389b74fe56b826e3c6d860234541033387350e09268591c46d411cc8e9 +size 572414 diff --git a/TRELLIS/assets/example_image/typical_misc_lantern.png b/TRELLIS/assets/example_image/typical_misc_lantern.png new file mode 100644 index 0000000000000000000000000000000000000000..0917aa4be24bec12ba961475efd5bef1c6f8fd2a --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_lantern.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e17bd83adf433ebfca17abd220097b2b7f08affc649518bd7822e03797e83d41 +size 300270 diff --git a/TRELLIS/assets/example_image/typical_misc_magicbook.png b/TRELLIS/assets/example_image/typical_misc_magicbook.png new file mode 100644 index 0000000000000000000000000000000000000000..f268ca7e37851578d888cb26f8057846f9e12e78 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_magicbook.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aff9c14589c340e31b61bf82e4506d77d72c511e741260fa1e600cefa4e103e6 +size 496050 diff --git a/TRELLIS/assets/example_image/typical_misc_mailbox.png b/TRELLIS/assets/example_image/typical_misc_mailbox.png new file mode 100644 index 0000000000000000000000000000000000000000..31a5b45556a0da3b7e36301230c7ffa1f15adad1 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_mailbox.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01e86a5d68edafb7e11d7a86f7e8081f5ed1b02578198a3271554c5fb8fb9fcf +size 631146 diff --git a/TRELLIS/assets/example_image/typical_misc_monster_chest.png b/TRELLIS/assets/example_image/typical_misc_monster_chest.png new file 
mode 100644 index 0000000000000000000000000000000000000000..660a8a9bbec77a967b0263b98900411d69588d71 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_monster_chest.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c57a598e842225a31b9770bf3bbb9ae86197ec57d0c2883caf8cb5eed4908fbc +size 690259 diff --git a/TRELLIS/assets/example_image/typical_misc_paper_machine.png b/TRELLIS/assets/example_image/typical_misc_paper_machine.png new file mode 100644 index 0000000000000000000000000000000000000000..db27c6c49446d85e081133a195be4aff90e66f33 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_paper_machine.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d55400ae5d4df2377258400d800ece75766d5274e80ce07c3b29a4d1fd1fa36 +size 613984 diff --git a/TRELLIS/assets/example_image/typical_misc_phonograph.png b/TRELLIS/assets/example_image/typical_misc_phonograph.png new file mode 100644 index 0000000000000000000000000000000000000000..d2f4c143fa51c16c4a3ab4a0a625319d95d4a8be --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_phonograph.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14fff9a27ea769d3ca711e9ff55ab3d9385486a5e8b99117f506df326a0a357e +size 517193 diff --git a/TRELLIS/assets/example_image/typical_misc_portal2.png b/TRELLIS/assets/example_image/typical_misc_portal2.png new file mode 100644 index 0000000000000000000000000000000000000000..6ab0c8b0784eaa370d566854974a7b4f1e0f4ff7 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_portal2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57aab2bba56bc946523a3fca77ca70651a4ad8c6fbf1b91a1a824418df48faae +size 386106 diff --git a/TRELLIS/assets/example_image/typical_misc_storage_chest.png b/TRELLIS/assets/example_image/typical_misc_storage_chest.png new file mode 100644 index 0000000000000000000000000000000000000000..2fcd82551d835903a914c8737b5650373776d363 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_storage_chest.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e4ac1c67fdda902ecb709447b8defd949c738954c844c1b8364b8e3f7d9e55a +size 632296 diff --git a/TRELLIS/assets/example_image/typical_misc_telephone.png b/TRELLIS/assets/example_image/typical_misc_telephone.png new file mode 100644 index 0000000000000000000000000000000000000000..58e7a3a434274b58ef0ad567a19ec46e16bd4ba4 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_telephone.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00048be46234a2709c12614b04cbad61c6e3c7e63c2a4ef33d999185f5393e36 +size 647557 diff --git a/TRELLIS/assets/example_image/typical_misc_television.png b/TRELLIS/assets/example_image/typical_misc_television.png new file mode 100644 index 0000000000000000000000000000000000000000..79042fd48c953ccabc2f41333dfa12942ebbb5cd --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_television.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a1947b737398bf535ec212668a4d78cd38fe84cf9da1ccd6c0c0d838337755e +size 627348 diff --git a/TRELLIS/assets/example_image/typical_misc_workbench.png b/TRELLIS/assets/example_image/typical_misc_workbench.png new file mode 100644 index 0000000000000000000000000000000000000000..c964138ac20d2378c9ba5ecd35eca70f1c728ed2 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_misc_workbench.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:a6d9ed4d005a5253b8571fd976b0d102e293512d7b5a8ed5e3f7f17c5f4e19da +size 462693 diff --git a/TRELLIS/assets/example_image/typical_vehicle_biplane.png b/TRELLIS/assets/example_image/typical_vehicle_biplane.png new file mode 100644 index 0000000000000000000000000000000000000000..4bca4e16ec6453e008d21b5a82dd5667a6ef453a --- /dev/null +++ b/TRELLIS/assets/example_image/typical_vehicle_biplane.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c73e98112eb603b4ba635b8965cad7807d0588f083811bc2faa0c7ab9668a65a +size 574448 diff --git a/TRELLIS/assets/example_image/typical_vehicle_bulldozer.png b/TRELLIS/assets/example_image/typical_vehicle_bulldozer.png new file mode 100644 index 0000000000000000000000000000000000000000..9fa27c21bb6beca459d19d2cbd536f0a41fdce0d --- /dev/null +++ b/TRELLIS/assets/example_image/typical_vehicle_bulldozer.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23d821b4daea61cbea28cc6ddd3ae46712514dfcdff995c2664f5a70d21f4ef3 +size 692767 diff --git a/TRELLIS/assets/example_image/typical_vehicle_cart.png b/TRELLIS/assets/example_image/typical_vehicle_cart.png new file mode 100644 index 0000000000000000000000000000000000000000..d8848f3473355f2040c67a803847947e086dc969 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_vehicle_cart.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b72c04a2aa5cf57717c05151a2982d6dc31afde130d5e830adf37a84a70616cb +size 693145 diff --git a/TRELLIS/assets/example_image/typical_vehicle_excavator.png b/TRELLIS/assets/example_image/typical_vehicle_excavator.png new file mode 100644 index 0000000000000000000000000000000000000000..3dddee1020e052155d2e8f404d982f45786c5d07 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_vehicle_excavator.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27a418853eefa197f1e10ed944a7bb071413fd2bc1681804ee773a6ce3799c52 +size 712062 diff --git a/TRELLIS/assets/example_image/typical_vehicle_helicopter.png b/TRELLIS/assets/example_image/typical_vehicle_helicopter.png new file mode 100644 index 0000000000000000000000000000000000000000..499cc8e87a468582d3c64b96c30ddb11dbb04f1f --- /dev/null +++ b/TRELLIS/assets/example_image/typical_vehicle_helicopter.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f1a1b37bc52417c0e1048927a30bf3a52dde81345f90114040608186196ffe7 +size 352731 diff --git a/TRELLIS/assets/example_image/typical_vehicle_locomotive.png b/TRELLIS/assets/example_image/typical_vehicle_locomotive.png new file mode 100644 index 0000000000000000000000000000000000000000..658e79d82868ae6e86ac108914169deafe02a53b --- /dev/null +++ b/TRELLIS/assets/example_image/typical_vehicle_locomotive.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67d5124e7069b133dc0aaa16047a52c6dc1d7c2a4e4510ffd3235fe95597fbef +size 806246 diff --git a/TRELLIS/assets/example_image/typical_vehicle_pirate_ship.png b/TRELLIS/assets/example_image/typical_vehicle_pirate_ship.png new file mode 100644 index 0000000000000000000000000000000000000000..a3baba4b951316dc7e0e940c22ba6f537bd5db14 --- /dev/null +++ b/TRELLIS/assets/example_image/typical_vehicle_pirate_ship.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8926ec7c9f36a52e3bf1ca4e8cfc75d297da934fe7c0e8d7a73f0d35a5ef38ad +size 611289 diff --git a/TRELLIS/assets/example_image/weatherworn_misc_paper_machine3.png b/TRELLIS/assets/example_image/weatherworn_misc_paper_machine3.png new file mode 100644 index 
0000000000000000000000000000000000000000..253a9c5dd331fc1a3d0d66cb5e9afb0b5b82ac46 --- /dev/null +++ b/TRELLIS/assets/example_image/weatherworn_misc_paper_machine3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c6fbf47ed53ffad1a3027f72bf0806c238682c7bf7604b8770aef428906d33b +size 502071 diff --git a/TRELLIS/assets/example_multi_image/character_1.png b/TRELLIS/assets/example_multi_image/character_1.png new file mode 100644 index 0000000000000000000000000000000000000000..9f56066845cfd18504ace4924c9a94544d55280c --- /dev/null +++ b/TRELLIS/assets/example_multi_image/character_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:729e2e0214232e1dd45c9187e339f8a2a87c6e41257ef701e578a4f0a8be7ef1 +size 171816 diff --git a/TRELLIS/assets/example_multi_image/character_2.png b/TRELLIS/assets/example_multi_image/character_2.png new file mode 100644 index 0000000000000000000000000000000000000000..32de7843e9458be83c0adbd5ddbd526d3db59c94 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/character_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8afc8af9960a5f2315d9d5b9815f29137ef9b63c4d512c451a8ba374003c3ac +size 198217 diff --git a/TRELLIS/assets/example_multi_image/character_3.png b/TRELLIS/assets/example_multi_image/character_3.png new file mode 100644 index 0000000000000000000000000000000000000000..036b8f12cc2e2744fa3ebf3fba41114b07b340be --- /dev/null +++ b/TRELLIS/assets/example_multi_image/character_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3413a2b4f67105b947a42ebbe14d3f6ab9f68a99f2258a86fd50d94342b49bdd +size 145868 diff --git a/TRELLIS/assets/example_multi_image/mushroom_1.png b/TRELLIS/assets/example_multi_image/mushroom_1.png new file mode 100644 index 0000000000000000000000000000000000000000..ac2f8342feadfab20e91fbbc101a2c4db48a8556 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/mushroom_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e5fd9ee75d39c827b0c5544392c255a89b4ca62bf3cf31f702d39b150bea00c +size 434321 diff --git a/TRELLIS/assets/example_multi_image/mushroom_2.png b/TRELLIS/assets/example_multi_image/mushroom_2.png new file mode 100644 index 0000000000000000000000000000000000000000..c81fb2dbf291c1898afbbfb61cce8e781ac7c42e --- /dev/null +++ b/TRELLIS/assets/example_multi_image/mushroom_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e5709b910341b12d149632b8442aebd218dd591b238ccb7e4e8b185860aae04 +size 461694 diff --git a/TRELLIS/assets/example_multi_image/mushroom_3.png b/TRELLIS/assets/example_multi_image/mushroom_3.png new file mode 100644 index 0000000000000000000000000000000000000000..67c68bcfccbf331fa132c3c41d8c40b8c0548a31 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/mushroom_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:115c9c3a11d3d08de568680468ad42ac1322c21bdad46a43e149fc97cc687e48 +size 424840 diff --git a/TRELLIS/assets/example_multi_image/orangeguy_1.png b/TRELLIS/assets/example_multi_image/orangeguy_1.png new file mode 100644 index 0000000000000000000000000000000000000000..a89f44e74fd8ed4d7d140b5447315d60266081be --- /dev/null +++ b/TRELLIS/assets/example_multi_image/orangeguy_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab30ee372fc365e5d100f2e06ea7cd17b3ea3f53b1a76ebe44e69a1cf834700e +size 632098 diff --git a/TRELLIS/assets/example_multi_image/orangeguy_2.png 
b/TRELLIS/assets/example_multi_image/orangeguy_2.png new file mode 100644 index 0000000000000000000000000000000000000000..b2554f8502d33981926afcfa8456dab03571ae63 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/orangeguy_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7ceecb68666ce692bf0a292fe4e125e657ece2babb67a7df89ab1827ad18665 +size 447045 diff --git a/TRELLIS/assets/example_multi_image/orangeguy_3.png b/TRELLIS/assets/example_multi_image/orangeguy_3.png new file mode 100644 index 0000000000000000000000000000000000000000..207e70ebc17912569261d5796b942b1d64dd39d7 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/orangeguy_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb89267bc91e0c0d1fdb148a5a5cd14e594f58e0a3da7b48f4e4cb61a3b7dcf1 +size 545243 diff --git a/TRELLIS/assets/example_multi_image/popmart_1.png b/TRELLIS/assets/example_multi_image/popmart_1.png new file mode 100644 index 0000000000000000000000000000000000000000..f4437187e4b20d0d35fcee7718bdea4aaebe5aa9 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/popmart_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90a1fff9944982d66d13abb78da87a35092c26b905efd5a58b3c077bb5fed2cd +size 311311 diff --git a/TRELLIS/assets/example_multi_image/popmart_2.png b/TRELLIS/assets/example_multi_image/popmart_2.png new file mode 100644 index 0000000000000000000000000000000000000000..3747f78e72c5d3199ab077511b1e1897356a7f85 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/popmart_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c74822f7fcf7b95d230e4578342272644a328ea175a90fa492f1b74ee438bdb4 +size 288176 diff --git a/TRELLIS/assets/example_multi_image/popmart_3.png b/TRELLIS/assets/example_multi_image/popmart_3.png new file mode 100644 index 0000000000000000000000000000000000000000..230775ef5b47d221e106aa2395cb70af93befa49 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/popmart_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:239b41cf5f42ddcb1d3ba3f273e4ec66f2601b3303562ee97080d1904d9707b9 +size 310746 diff --git a/TRELLIS/assets/example_multi_image/rabbit_1.png b/TRELLIS/assets/example_multi_image/rabbit_1.png new file mode 100644 index 0000000000000000000000000000000000000000..f8389134e05238573f19962c2051a4a287d5fe81 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/rabbit_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f64979aae844de9e04f08c3f2cb8b6c47e4681c015afc624375ed6ac6aff0cd +size 359949 diff --git a/TRELLIS/assets/example_multi_image/rabbit_2.png b/TRELLIS/assets/example_multi_image/rabbit_2.png new file mode 100644 index 0000000000000000000000000000000000000000..03f5d8d502fc6a10a1a915462272cc71075bc51e --- /dev/null +++ b/TRELLIS/assets/example_multi_image/rabbit_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85bb8760b5fb4c2e5c36e169a3cd0f95ca3164bd4764b8a830d220cd852f60ed +size 492291 diff --git a/TRELLIS/assets/example_multi_image/rabbit_3.png b/TRELLIS/assets/example_multi_image/rabbit_3.png new file mode 100644 index 0000000000000000000000000000000000000000..4e0f07e8a188e456451da5961b866b74d98d104f --- /dev/null +++ b/TRELLIS/assets/example_multi_image/rabbit_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1f10a6c04e10ce27a20fd583ff443238bff2b6a49c485e5cf0ed6952a20c9c5 +size 339992 diff --git a/TRELLIS/assets/example_multi_image/tiger_1.png 
b/TRELLIS/assets/example_multi_image/tiger_1.png new file mode 100644 index 0000000000000000000000000000000000000000..bfd8a2cb78ac39ff4fbd46555e0af273d2319a27 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/tiger_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2a61f5e48bd1a68d90c3afaf9ad8e5a7d955bc31ef7032b7fd876647922be1d +size 426296 diff --git a/TRELLIS/assets/example_multi_image/tiger_2.png b/TRELLIS/assets/example_multi_image/tiger_2.png new file mode 100644 index 0000000000000000000000000000000000000000..23ba80ad8bdd22f161e59b6d7b0a17e2f9082168 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/tiger_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5cab1d833a2541ca18e56fada83f5fe6b0cfe84c2d7e72382e2c93fcabc33a0 +size 731448 diff --git a/TRELLIS/assets/example_multi_image/tiger_3.png b/TRELLIS/assets/example_multi_image/tiger_3.png new file mode 100644 index 0000000000000000000000000000000000000000..0e0831948ddbc2ecb108a141ad7345ea55e408b4 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/tiger_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5614fcb8e85dba6f30af6c992d1fe63e33b95ab70b4c870de4f0761fcecd5d9c +size 646392 diff --git a/TRELLIS/assets/example_multi_image/yoimiya_1.png b/TRELLIS/assets/example_multi_image/yoimiya_1.png new file mode 100644 index 0000000000000000000000000000000000000000..bb6519735180acb52c952cfa99be6f57d859941c --- /dev/null +++ b/TRELLIS/assets/example_multi_image/yoimiya_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6d9e7e0a3c75a736a25aeca6986afaa2adcd1ce96277484cbc9c41bf63d0230 +size 245840 diff --git a/TRELLIS/assets/example_multi_image/yoimiya_2.png b/TRELLIS/assets/example_multi_image/yoimiya_2.png new file mode 100644 index 0000000000000000000000000000000000000000..e2f03461b9f43ab58ad7810371a256744c215d70 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/yoimiya_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e0703e3a557949a3176f2b02c8601ff883e22ce81962a75a21dbff9d5848ccf +size 182235 diff --git a/TRELLIS/assets/example_multi_image/yoimiya_3.png b/TRELLIS/assets/example_multi_image/yoimiya_3.png new file mode 100644 index 0000000000000000000000000000000000000000..54b268dfe229dcebc192e458b0f5c30ef34458d1 --- /dev/null +++ b/TRELLIS/assets/example_multi_image/yoimiya_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56352f95536fcfe0996c33909e39e5beb423bd22b4bc48936938717073785305 +size 239134 diff --git a/TRELLIS/assets/logo.webp b/TRELLIS/assets/logo.webp new file mode 100644 index 0000000000000000000000000000000000000000..aaf832024623414cb6336b46d5b8b27b7b7b039a --- /dev/null +++ b/TRELLIS/assets/logo.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1548a7b7f6b0fb3c06091529bb5052f0ee9a119eb4e1a014325d6561e9b9f2d1 +size 1403066 diff --git a/TRELLIS/assets/teaser.png b/TRELLIS/assets/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..5dae838acc6418b48af635c9bceefaa84e5ea446 --- /dev/null +++ b/TRELLIS/assets/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a83608f6f4ae71eb7b96965a16978b5f5fb0d594a34bf1066c105f93af71c4d5 +size 2098442 diff --git a/TRELLIS/dataset_toolkits/blender_script/io_scene_usdz.zip b/TRELLIS/dataset_toolkits/blender_script/io_scene_usdz.zip new file mode 100644 index 
0000000000000000000000000000000000000000..153a9d97245cfbf98c2c53869987d82f1396922e --- /dev/null +++ b/TRELLIS/dataset_toolkits/blender_script/io_scene_usdz.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec07ab6125fe0a021ed08c64169eceda126330401aba3d494d5203d26ac4b093 +size 34685 diff --git a/TRELLIS/dataset_toolkits/blender_script/render.py b/TRELLIS/dataset_toolkits/blender_script/render.py new file mode 100644 index 0000000000000000000000000000000000000000..7be59d36c3973093cc16d8858591870c31a16e14 --- /dev/null +++ b/TRELLIS/dataset_toolkits/blender_script/render.py @@ -0,0 +1,528 @@ +import argparse, sys, os, math, re, glob +from typing import * +import bpy +from mathutils import Vector, Matrix +import numpy as np +import json +import glob + + +"""=============== BLENDER ===============""" + +IMPORT_FUNCTIONS: Dict[str, Callable] = { + "obj": bpy.ops.import_scene.obj, + "glb": bpy.ops.import_scene.gltf, + "gltf": bpy.ops.import_scene.gltf, + "usd": bpy.ops.import_scene.usd, + "fbx": bpy.ops.import_scene.fbx, + "stl": bpy.ops.import_mesh.stl, + "usda": bpy.ops.import_scene.usda, + "dae": bpy.ops.wm.collada_import, + "ply": bpy.ops.import_mesh.ply, + "abc": bpy.ops.wm.alembic_import, + "blend": bpy.ops.wm.append, +} + +EXT = { + 'PNG': 'png', + 'JPEG': 'jpg', + 'OPEN_EXR': 'exr', + 'TIFF': 'tiff', + 'BMP': 'bmp', + 'HDR': 'hdr', + 'TARGA': 'tga' +} + +def init_render(engine='CYCLES', resolution=512, geo_mode=False): + bpy.context.scene.render.engine = engine + bpy.context.scene.render.resolution_x = resolution + bpy.context.scene.render.resolution_y = resolution + bpy.context.scene.render.resolution_percentage = 100 + bpy.context.scene.render.image_settings.file_format = 'PNG' + bpy.context.scene.render.image_settings.color_mode = 'RGBA' + bpy.context.scene.render.film_transparent = True + + bpy.context.scene.cycles.device = 'GPU' + bpy.context.scene.cycles.samples = 128 if not geo_mode else 1 + bpy.context.scene.cycles.filter_type = 'BOX' + bpy.context.scene.cycles.filter_width = 1 + bpy.context.scene.cycles.diffuse_bounces = 1 + bpy.context.scene.cycles.glossy_bounces = 1 + bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0 + bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1 + bpy.context.scene.cycles.use_denoising = True + + bpy.context.preferences.addons['cycles'].preferences.get_devices() + bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA' + +def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False): + if not any([save_depth, save_normal, save_albedo, save_mist]): + return {}, {} + outputs = {} + spec_nodes = {} + + bpy.context.scene.use_nodes = True + bpy.context.scene.view_layers['View Layer'].use_pass_z = save_depth + bpy.context.scene.view_layers['View Layer'].use_pass_normal = save_normal + bpy.context.scene.view_layers['View Layer'].use_pass_diffuse_color = save_albedo + bpy.context.scene.view_layers['View Layer'].use_pass_mist = save_mist + + nodes = bpy.context.scene.node_tree.nodes + links = bpy.context.scene.node_tree.links + for n in nodes: + nodes.remove(n) + + render_layers = nodes.new('CompositorNodeRLayers') + + if save_depth: + depth_file_output = nodes.new('CompositorNodeOutputFile') + depth_file_output.base_path = '' + depth_file_output.file_slots[0].use_node_format = True + depth_file_output.format.file_format = 'PNG' + depth_file_output.format.color_depth = '16' + depth_file_output.format.color_mode = 'BW' + # Remap to 0-1 + 
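# Note on the Map Range node created just below (a hedged reading, not part of the
# original patch): on Blender's CompositorNodeMapRange the input sockets are
# Value, From Min, From Max, To Min, To Max, so inputs[1]/[2] give the expected
# depth range and inputs[3]/[4] the 0..1 range written into the 16-bit PNG.
# The 0..10 values set here are only defaults; main() further down tightens them
# per view to view['radius'] +/- 0.5 * sqrt(3), i.e. the camera distance plus or
# minus half the diagonal of the unit cube the scene is normalized into.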
map = nodes.new(type="CompositorNodeMapRange") + map.inputs[1].default_value = 0 # (min value you will be getting) + map.inputs[2].default_value = 10 # (max value you will be getting) + map.inputs[3].default_value = 0 # (min value you will map to) + map.inputs[4].default_value = 1 # (max value you will map to) + + links.new(render_layers.outputs['Depth'], map.inputs[0]) + links.new(map.outputs[0], depth_file_output.inputs[0]) + + outputs['depth'] = depth_file_output + spec_nodes['depth_map'] = map + + if save_normal: + normal_file_output = nodes.new('CompositorNodeOutputFile') + normal_file_output.base_path = '' + normal_file_output.file_slots[0].use_node_format = True + normal_file_output.format.file_format = 'OPEN_EXR' + normal_file_output.format.color_mode = 'RGB' + normal_file_output.format.color_depth = '16' + + links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0]) + + outputs['normal'] = normal_file_output + + if save_albedo: + albedo_file_output = nodes.new('CompositorNodeOutputFile') + albedo_file_output.base_path = '' + albedo_file_output.file_slots[0].use_node_format = True + albedo_file_output.format.file_format = 'PNG' + albedo_file_output.format.color_mode = 'RGBA' + albedo_file_output.format.color_depth = '8' + + alpha_albedo = nodes.new('CompositorNodeSetAlpha') + + links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image']) + links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha']) + links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0]) + + outputs['albedo'] = albedo_file_output + + if save_mist: + bpy.data.worlds['World'].mist_settings.start = 0 + bpy.data.worlds['World'].mist_settings.depth = 10 + + mist_file_output = nodes.new('CompositorNodeOutputFile') + mist_file_output.base_path = '' + mist_file_output.file_slots[0].use_node_format = True + mist_file_output.format.file_format = 'PNG' + mist_file_output.format.color_mode = 'BW' + mist_file_output.format.color_depth = '16' + + links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0]) + + outputs['mist'] = mist_file_output + + return outputs, spec_nodes + +def init_scene() -> None: + """Resets the scene to a clean state. 
+ + Returns: + None + """ + # delete everything + for obj in bpy.data.objects: + bpy.data.objects.remove(obj, do_unlink=True) + + # delete all the materials + for material in bpy.data.materials: + bpy.data.materials.remove(material, do_unlink=True) + + # delete all the textures + for texture in bpy.data.textures: + bpy.data.textures.remove(texture, do_unlink=True) + + # delete all the images + for image in bpy.data.images: + bpy.data.images.remove(image, do_unlink=True) + +def init_camera(): + cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera')) + bpy.context.collection.objects.link(cam) + bpy.context.scene.camera = cam + cam.data.sensor_height = cam.data.sensor_width = 32 + cam_constraint = cam.constraints.new(type='TRACK_TO') + cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' + cam_constraint.up_axis = 'UP_Y' + cam_empty = bpy.data.objects.new("Empty", None) + cam_empty.location = (0, 0, 0) + bpy.context.scene.collection.objects.link(cam_empty) + cam_constraint.target = cam_empty + return cam + +def init_lighting(): + # Clear existing lights + bpy.ops.object.select_all(action="DESELECT") + bpy.ops.object.select_by_type(type="LIGHT") + bpy.ops.object.delete() + + # Create key light + default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT")) + bpy.context.collection.objects.link(default_light) + default_light.data.energy = 1000 + default_light.location = (4, 1, 6) + default_light.rotation_euler = (0, 0, 0) + + # create top light + top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA")) + bpy.context.collection.objects.link(top_light) + top_light.data.energy = 10000 + top_light.location = (0, 0, 10) + top_light.scale = (100, 100, 100) + + # create bottom light + bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA")) + bpy.context.collection.objects.link(bottom_light) + bottom_light.data.energy = 1000 + bottom_light.location = (0, 0, -10) + bottom_light.rotation_euler = (0, 0, 0) + + return { + "default_light": default_light, + "top_light": top_light, + "bottom_light": bottom_light + } + + +def load_object(object_path: str) -> None: + """Loads a model with a supported file extension into the scene. + + Args: + object_path (str): Path to the model file. + + Raises: + ValueError: If the file extension is not supported. 
+ + Returns: + None + """ + file_extension = object_path.split(".")[-1].lower() + if file_extension is None: + raise ValueError(f"Unsupported file type: {object_path}") + + if file_extension == "usdz": + # install usdz io package + dirname = os.path.dirname(os.path.realpath(__file__)) + usdz_package = os.path.join(dirname, "io_scene_usdz.zip") + bpy.ops.preferences.addon_install(filepath=usdz_package) + # enable it + addon_name = "io_scene_usdz" + bpy.ops.preferences.addon_enable(module=addon_name) + # import the usdz + from io_scene_usdz.import_usdz import import_usdz + + import_usdz(context, filepath=object_path, materials=True, animations=True) + return None + + # load from existing import functions + import_function = IMPORT_FUNCTIONS[file_extension] + + print(f"Loading object from {object_path}") + if file_extension == "blend": + import_function(directory=object_path, link=False) + elif file_extension in {"glb", "gltf"}: + import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS') + else: + import_function(filepath=object_path) + +def delete_invisible_objects() -> None: + """Deletes all invisible objects in the scene. + + Returns: + None + """ + # bpy.ops.object.mode_set(mode="OBJECT") + bpy.ops.object.select_all(action="DESELECT") + for obj in bpy.context.scene.objects: + if obj.hide_viewport or obj.hide_render: + obj.hide_viewport = False + obj.hide_render = False + obj.hide_select = False + obj.select_set(True) + bpy.ops.object.delete() + + # Delete invisible collections + invisible_collections = [col for col in bpy.data.collections if col.hide_viewport] + for col in invisible_collections: + bpy.data.collections.remove(col) + +def split_mesh_normal(): + bpy.ops.object.select_all(action="DESELECT") + objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"] + bpy.context.view_layer.objects.active = objs[0] + for obj in objs: + obj.select_set(True) + bpy.ops.object.mode_set(mode="EDIT") + bpy.ops.mesh.select_all(action='SELECT') + bpy.ops.mesh.split_normals() + bpy.ops.object.mode_set(mode='OBJECT') + bpy.ops.object.select_all(action="DESELECT") + +def delete_custom_normals(): + for this_obj in bpy.data.objects: + if this_obj.type == "MESH": + bpy.context.view_layer.objects.active = this_obj + bpy.ops.mesh.customdata_custom_splitnormals_clear() + +def override_material(): + new_mat = bpy.data.materials.new(name="Override0123456789") + new_mat.use_nodes = True + new_mat.node_tree.nodes.clear() + bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse') + bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1) + bsdf.inputs[1].default_value = 1 + output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial') + new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface']) + bpy.context.scene.view_layers['View Layer'].material_override = new_mat + +def unhide_all_objects() -> None: + """Unhides all objects in the scene. + + Returns: + None + """ + for obj in bpy.context.scene.objects: + obj.hide_set(False) + +def convert_to_meshes() -> None: + """Converts all objects in the scene to meshes. + + Returns: + None + """ + bpy.ops.object.select_all(action="DESELECT") + bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0] + for obj in bpy.context.scene.objects: + obj.select_set(True) + bpy.ops.object.convert(target="MESH") + +def triangulate_meshes() -> None: + """Triangulates all meshes in the scene. 
+ + Returns: + None + """ + bpy.ops.object.select_all(action="DESELECT") + objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"] + bpy.context.view_layer.objects.active = objs[0] + for obj in objs: + obj.select_set(True) + bpy.ops.object.mode_set(mode="EDIT") + bpy.ops.mesh.reveal() + bpy.ops.mesh.select_all(action="SELECT") + bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY") + bpy.ops.object.mode_set(mode="OBJECT") + bpy.ops.object.select_all(action="DESELECT") + +def scene_bbox() -> Tuple[Vector, Vector]: + """Returns the bounding box of the scene. + + Taken from Shap-E rendering script + (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82) + + Returns: + Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box. + """ + bbox_min = (math.inf,) * 3 + bbox_max = (-math.inf,) * 3 + found = False + scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)] + for obj in scene_meshes: + found = True + for coord in obj.bound_box: + coord = Vector(coord) + coord = obj.matrix_world @ coord + bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) + bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) + if not found: + raise RuntimeError("no objects in scene to compute bounding box for") + return Vector(bbox_min), Vector(bbox_max) + +def normalize_scene() -> Tuple[float, Vector]: + """Normalizes the scene by scaling and translating it to fit in a unit cube centered + at the origin. + + Mostly taken from the Point-E / Shap-E rendering script + (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112), + but fix for multiple root objects: (see bug report here: + https://github.com/openai/shap-e/pull/60). + + Returns: + Tuple[float, Vector]: The scale factor and the offset applied to the scene. + """ + scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent] + if len(scene_root_objects) > 1: + # create an empty object to be used as a parent for all root objects + scene = bpy.data.objects.new("ParentEmpty", None) + bpy.context.scene.collection.objects.link(scene) + + # parent all root objects to the empty object + for obj in scene_root_objects: + obj.parent = scene + else: + scene = scene_root_objects[0] + + bbox_min, bbox_max = scene_bbox() + scale = 1 / max(bbox_max - bbox_min) + scene.scale = scene.scale * scale + + # Apply scale to matrix_world. 
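# Note (a rough sketch of the bookkeeping, not part of the original patch): after
# the scale is applied the bounding box is recomputed and the scene is recentred so
# its AABB midpoint sits at the origin, which is why main() records
# aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]] in transforms.json. Roughly, a
# world-space point p of the source asset maps to p * scale + offset here, so a
# consumer of transforms.json can undo the normalization with
#   p_original = (p_normalized - offset) / scale
# using the 'scale' and 'offset' fields saved alongside the frames.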
+ bpy.context.view_layer.update() + bbox_min, bbox_max = scene_bbox() + offset = -(bbox_min + bbox_max) / 2 + scene.matrix_world.translation += offset + bpy.ops.object.select_all(action="DESELECT") + + return scale, offset + +def get_transform_matrix(obj: bpy.types.Object) -> list: + pos, rt, _ = obj.matrix_world.decompose() + rt = rt.to_matrix() + matrix = [] + for ii in range(3): + a = [] + for jj in range(3): + a.append(rt[ii][jj]) + a.append(pos[ii]) + matrix.append(a) + matrix.append([0, 0, 0, 1]) + return matrix + +def main(arg): + os.makedirs(arg.output_folder, exist_ok=True) + + # Initialize context + init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode) + outputs, spec_nodes = init_nodes( + save_depth=arg.save_depth, + save_normal=arg.save_normal, + save_albedo=arg.save_albedo, + save_mist=arg.save_mist + ) + if arg.object.endswith(".blend"): + delete_invisible_objects() + else: + init_scene() + load_object(arg.object) + if arg.split_normal: + split_mesh_normal() + # delete_custom_normals() + print('[INFO] Scene initialized.') + + # normalize scene + scale, offset = normalize_scene() + print('[INFO] Scene normalized.') + + # Initialize camera and lighting + cam = init_camera() + init_lighting() + print('[INFO] Camera and lighting initialized.') + + # Override material + if arg.geo_mode: + override_material() + + # Create a list of views + to_export = { + "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + "scale": scale, + "offset": [offset.x, offset.y, offset.z], + "frames": [] + } + views = json.loads(arg.views) + for i, view in enumerate(views): + cam.location = ( + view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']), + view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']), + view['radius'] * np.sin(view['pitch']) + ) + cam.data.lens = 16 / np.tan(view['fov'] / 2) + + if arg.save_depth: + spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3) + spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3) + + bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png') + for name, output in outputs.items(): + output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}') + + # Render the scene + bpy.ops.render.render(write_still=True) + bpy.context.view_layer.update() + for name, output in outputs.items(): + ext = EXT[output.format.file_format] + path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0] + os.rename(path, f'{output.file_slots[0].path}.{ext}') + + # Save camera parameters + metadata = { + "file_path": f'{i:03d}.png', + "camera_angle_x": view['fov'], + "transform_matrix": get_transform_matrix(cam) + } + if arg.save_depth: + metadata['depth'] = { + 'min': view['radius'] - 0.5 * np.sqrt(3), + 'max': view['radius'] + 0.5 * np.sqrt(3) + } + to_export["frames"].append(metadata) + + # Save the camera parameters + with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f: + json.dump(to_export, f, indent=4) + + if arg.save_mesh: + # triangulate meshes + unhide_all_objects() + convert_to_meshes() + triangulate_meshes() + print('[INFO] Meshes triangulated.') + + # export ply mesh + bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.') + parser.add_argument('--views', type=str, help='JSON string of views. 
Contains a list of {yaw, pitch, radius, fov} object.') + parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.') + parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.') + parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.') + parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...') + parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.') + parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.') + parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.') + parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.') + parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.') + parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.') + parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.') + argv = sys.argv[sys.argv.index("--") + 1:] + args = parser.parse_args(argv) + + main(args) + \ No newline at end of file diff --git a/TRELLIS/dataset_toolkits/build_metadata.py b/TRELLIS/dataset_toolkits/build_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c65d6be23ce3631d63be34e39ca1fa0d33840c --- /dev/null +++ b/TRELLIS/dataset_toolkits/build_metadata.py @@ -0,0 +1,270 @@ +import os +import shutil +import sys +import time +import importlib +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor +import utils3d + +def get_first_directory(path): + with os.scandir(path) as it: + for entry in it: + if entry.is_dir(): + return entry.name + return None + +def need_process(key): + return key in opt.field or opt.field == ['all'] + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--field', type=str, default='all', + help='Fields to process, separated by commas') + parser.add_argument('--from_file', action='store_true', + help='Build metadata from file instead of from records of processings.' 
+ + 'Useful when some processing fail to generate records but file already exists.') + dataset_utils.add_args(parser) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(opt.output_dir, exist_ok=True) + os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True) + + opt.field = opt.field.split(',') + + timestamp = str(int(time.time())) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + print('Loading previous metadata...') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + metadata = dataset_utils.get_metadata(**opt) + metadata.set_index('sha256', inplace=True) + + # merge downloaded + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + if 'local_path' in metadata.columns: + metadata.update(df, overwrite=True) + else: + metadata = metadata.join(df, on='sha256', how='left') + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # detect models + image_models = [] + if os.path.exists(os.path.join(opt.output_dir, 'features')): + image_models = os.listdir(os.path.join(opt.output_dir, 'features')) + latent_models = [] + if os.path.exists(os.path.join(opt.output_dir, 'latents')): + latent_models = os.listdir(os.path.join(opt.output_dir, 'latents')) + ss_latent_models = [] + if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')): + ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents')) + print(f'Image models: {image_models}') + print(f'Latent models: {latent_models}') + print(f'Sparse Structure latent models: {ss_latent_models}') + + if 'rendered' not in metadata.columns: + metadata['rendered'] = [False] * len(metadata) + if 'voxelized' not in metadata.columns: + metadata['voxelized'] = [False] * len(metadata) + if 'num_voxels' not in metadata.columns: + metadata['num_voxels'] = [0] * len(metadata) + if 'cond_rendered' not in metadata.columns: + metadata['cond_rendered'] = [False] * len(metadata) + for model in image_models: + if f'feature_{model}' not in metadata.columns: + metadata[f'feature_{model}'] = [False] * len(metadata) + for model in latent_models: + if f'latent_{model}' not in metadata.columns: + metadata[f'latent_{model}'] = [False] * len(metadata) + for model in ss_latent_models: + if f'ss_latent_{model}' not in metadata.columns: + metadata[f'ss_latent_{model}'] = [False] * len(metadata) + + # merge rendered + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('rendered_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge voxelized + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('voxelized_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + 
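# (The same merge idiom repeats for every record type in this script: shard CSVs
# written by workers are keyed by sha256, folded into the master metadata frame,
# and the shards are then archived under merged_records/ with a timestamp prefix.
# A minimal sketch of the pandas idiom, with made-up rows:
#   master = pd.DataFrame({'voxelized': [False, False]}, index=['a', 'b'])
#   shard  = pd.DataFrame({'voxelized': [True]}, index=['a'])
#   master.update(shard, overwrite=True)   # only row 'a' changes; 'b' is untouched
# DataFrame.update aligns on the index and only copies non-NA cells from the shard,
# so partial records never clobber unrelated rows or columns.)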
df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge cond_rendered + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('cond_rendered_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge features + for model in image_models: + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'feature_{model}_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge latents + for model in latent_models: + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'latent_{model}_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # merge sparse structure latents + for model in ss_latent_models: + df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'ss_latent_{model}_') and f.endswith('.csv')] + df_parts = [] + for f in df_files: + try: + df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f))) + except: + pass + if len(df_parts) > 0: + df = pd.concat(df_parts) + df.set_index('sha256', inplace=True) + metadata.update(df, overwrite=True) + for f in df_files: + shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}')) + + # build metadata from files + if opt.from_file: + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ + tqdm(total=len(metadata), desc="Building metadata") as pbar: + def worker(sha256): + try: + if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \ + os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')): + metadata.loc[sha256, 'rendered'] = True + if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \ + os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')): + try: + pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0] + metadata.loc[sha256, 'voxelized'] = True + metadata.loc[sha256, 'num_voxels'] = len(pts) + except Exception as e: + pass + if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \ + os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')): + metadata.loc[sha256, 'cond_rendered'] = True + for model in image_models: + if need_process(f'feature_{model}') 
and \ + metadata.loc[sha256, f'feature_{model}'] == False and \ + metadata.loc[sha256, 'rendered'] == True and \ + metadata.loc[sha256, 'voxelized'] == True and \ + os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')): + metadata.loc[sha256, f'feature_{model}'] = True + for model in latent_models: + if need_process(f'latent_{model}') and \ + metadata.loc[sha256, f'latent_{model}'] == False and \ + metadata.loc[sha256, 'rendered'] == True and \ + metadata.loc[sha256, 'voxelized'] == True and \ + os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')): + metadata.loc[sha256, f'latent_{model}'] = True + for model in ss_latent_models: + if need_process(f'ss_latent_{model}') and \ + metadata.loc[sha256, f'ss_latent_{model}'] == False and \ + metadata.loc[sha256, 'voxelized'] == True and \ + os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')): + metadata.loc[sha256, f'ss_latent_{model}'] = True + pbar.update() + except Exception as e: + print(f'Error processing {sha256}: {e}') + pbar.update() + + executor.map(worker, metadata.index) + executor.shutdown(wait=True) + + # statistics + metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv')) + num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0 + with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f: + f.write('Statistics:\n') + f.write(f' - Number of assets: {len(metadata)}\n') + f.write(f' - Number of assets downloaded: {num_downloaded}\n') + f.write(f' - Number of assets rendered: {metadata["rendered"].sum()}\n') + f.write(f' - Number of assets voxelized: {metadata["voxelized"].sum()}\n') + if len(image_models) != 0: + f.write(f' - Number of assets with image features extracted:\n') + for model in image_models: + f.write(f' - {model}: {metadata[f"feature_{model}"].sum()}\n') + if len(latent_models) != 0: + f.write(f' - Number of assets with latents extracted:\n') + for model in latent_models: + f.write(f' - {model}: {metadata[f"latent_{model}"].sum()}\n') + if len(ss_latent_models) != 0: + f.write(f' - Number of assets with sparse structure latents extracted:\n') + for model in ss_latent_models: + f.write(f' - {model}: {metadata[f"ss_latent_{model}"].sum()}\n') + f.write(f' - Number of assets with captions: {metadata["captions"].count()}\n') + f.write(f' - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n') + + with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f: + print(f.read()) \ No newline at end of file diff --git a/TRELLIS/dataset_toolkits/datasets/3D-FUTURE.py b/TRELLIS/dataset_toolkits/datasets/3D-FUTURE.py new file mode 100644 index 0000000000000000000000000000000000000000..8977c92b01e181d75737d4e4ba7422d39a0a80b7 --- /dev/null +++ b/TRELLIS/dataset_toolkits/datasets/3D-FUTURE.py @@ -0,0 +1,97 @@ +import os +import re +import argparse +import zipfile +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + pass + + +def get_metadata(**kwargs): + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(output_dir, exist_ok=True) + + if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')): + print("\033[93m") + print("3D-FUTURE have to be downloaded manually") + print(f"Please download the 3D-FUTURE-model.zip 
file and place it in the {output_dir}/raw directory") + print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information") + print("\033[0m") + raise FileNotFoundError("3D-FUTURE-model.zip not found") + + downloaded = {} + metadata = metadata.set_index("file_identifier") + with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref: + all_names = zip_ref.namelist() + instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)] + instances = list(filter(lambda x: x in metadata.index, instances)) + + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ + tqdm(total=len(instances), desc="Extracting") as pbar: + def worker(instance: str) -> str: + try: + instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names)) + zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files) + sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg")) + pbar.update() + return sha256 + except Exception as e: + pbar.update() + print(f"Error extracting for {instance}: {e}") + return None + + sha256s = executor.map(worker, instances) + executor.shutdown(wait=True) + + for k, sha256 in zip(instances, sha256s): + if sha256 is not None: + if sha256 == metadata.loc[k, "sha256"]: + downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj") + else: + print(f"Error downloading {k}: sha256s do not match") + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/TRELLIS/dataset_toolkits/datasets/ABO.py b/TRELLIS/dataset_toolkits/datasets/ABO.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4c804aeb99bea9bcf44cb1976464f342cda62c --- /dev/null +++ b/TRELLIS/dataset_toolkits/datasets/ABO.py @@ -0,0 +1,96 @@ +import os +import re +import argparse +import tarfile +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + pass + + +def get_metadata(**kwargs): + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + + if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')): + try: + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar 
https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar") + except: + print("\033[93m") + print("Error downloading ABO dataset. Please check your internet connection and try again.") + print("Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory") + print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information") + print("\033[0m") + raise FileNotFoundError("Error downloading ABO dataset") + + downloaded = {} + metadata = metadata.set_index("file_identifier") + with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar: + with ThreadPoolExecutor(max_workers=1) as executor, \ + tqdm(total=len(metadata), desc="Extracting") as pbar: + def worker(instance: str) -> str: + try: + tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw')) + sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance)) + pbar.update() + return sha256 + except Exception as e: + pbar.update() + print(f"Error extracting for {instance}: {e}") + return None + + sha256s = executor.map(worker, metadata.index) + executor.shutdown(wait=True) + + for k, sha256 in zip(metadata.index, sha256s): + if sha256 is not None: + if sha256 == metadata.loc[k, "sha256"]: + downloaded[sha256] = os.path.join('raw/3dmodels/original', k) + else: + print(f"Error downloading {k}: sha256s do not match") + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/TRELLIS/dataset_toolkits/datasets/HSSD.py b/TRELLIS/dataset_toolkits/datasets/HSSD.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b7f822063ab43daf1fe1bc915cfd6528b73ca9 --- /dev/null +++ b/TRELLIS/dataset_toolkits/datasets/HSSD.py @@ -0,0 +1,103 @@ +import os +import re +import argparse +import tarfile +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +import huggingface_hub +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + pass + + +def get_metadata(**kwargs): + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + + # check login + try: + huggingface_hub.whoami() + except: + print("\033[93m") + print("Haven't logged in to the Hugging Face Hub.") + print("Visit https://huggingface.co/settings/tokens to get a token.") + print("\033[0m") + huggingface_hub.login() + + try: 
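+ # Access probe: hssd/hssd-models is a gated dataset, so fetch its README first to confirm the logged-in account has been granted access before the per-file downloads start.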
+ huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset") + except: + print("\033[93m") + print("Error downloading HSSD dataset.") + print("Check if you have access to the HSSD dataset.") + print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information") + print("\033[0m") + + downloaded = {} + metadata = metadata.set_index("file_identifier") + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ + tqdm(total=len(metadata), desc="Downloading") as pbar: + def worker(instance: str) -> str: + try: + huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw')) + sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance)) + pbar.update() + return sha256 + except Exception as e: + pbar.update() + print(f"Error extracting for {instance}: {e}") + return None + + sha256s = executor.map(worker, metadata.index) + executor.shutdown(wait=True) + + for k, sha256 in zip(metadata.index, sha256s): + if sha256 is not None: + if sha256 == metadata.loc[k, "sha256"]: + downloaded[sha256] = os.path.join('raw', k) + else: + print(f"Error downloading {k}: sha256s do not match") + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/TRELLIS/dataset_toolkits/datasets/ObjaverseXL.py b/TRELLIS/dataset_toolkits/datasets/ObjaverseXL.py new file mode 100644 index 0000000000000000000000000000000000000000..455ec87bb95f0256810f2ace6455e5f994f7ad2f --- /dev/null +++ b/TRELLIS/dataset_toolkits/datasets/ObjaverseXL.py @@ -0,0 +1,92 @@ +import os +import argparse +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +import objaverse.xl as oxl +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + parser.add_argument('--source', type=str, default='sketchfab', + help='Data source to download annotations from (github, sketchfab)') + + +def get_metadata(source, **kwargs): + if source == 'sketchfab': + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv") + elif source == 'github': + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv") + else: + raise ValueError(f"Invalid source: {source}") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + + # download annotations + annotations = oxl.get_annotations() + annotations = 
annotations[annotations['sha256'].isin(metadata['sha256'].values)] + + # download and render objects + file_paths = oxl.download_objects( + annotations, + download_dir=os.path.join(output_dir, "raw"), + save_repo_format="zip", + ) + + downloaded = {} + metadata = metadata.set_index("file_identifier") + for k, v in file_paths.items(): + sha256 = metadata.loc[k, "sha256"] + downloaded[sha256] = os.path.relpath(v, output_dir) + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + import tempfile + import zipfile + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + if local_path.startswith('raw/github/repos/'): + path_parts = local_path.split('/') + file_name = os.path.join(*path_parts[5:]) + zip_file = os.path.join(output_dir, *path_parts[:5]) + with tempfile.TemporaryDirectory() as tmp_dir: + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(tmp_dir) + file = os.path.join(tmp_dir, file_name) + record = func(file, sha256) + else: + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/TRELLIS/dataset_toolkits/datasets/Toys4k.py b/TRELLIS/dataset_toolkits/datasets/Toys4k.py new file mode 100644 index 0000000000000000000000000000000000000000..17be4d28b8dcbf4a3ed115cad76b30db94690940 --- /dev/null +++ b/TRELLIS/dataset_toolkits/datasets/Toys4k.py @@ -0,0 +1,92 @@ +import os +import re +import argparse +import zipfile +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + pass + + +def get_metadata(**kwargs): + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/Toys4k.csv") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(output_dir, exist_ok=True) + + if not os.path.exists(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')): + print("\033[93m") + print("Toys4k have to be downloaded manually") + print(f"Please download the toys4k_blend_files.zip file and place it in the {output_dir}/raw directory") + print("Visit https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k for more information") + print("\033[0m") + raise FileNotFoundError("toys4k_blend_files.zip not found") + + downloaded = {} + metadata = metadata.set_index("file_identifier") + with zipfile.ZipFile(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')) as zip_ref: + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \ + tqdm(total=len(metadata), desc="Extracting") as pbar: + def worker(instance: str) -> str: + try: + zip_ref.extract(os.path.join('toys4k_blend_files', instance), os.path.join(output_dir, 'raw')) + sha256 = 
get_file_hash(os.path.join(output_dir, 'raw/toys4k_blend_files', instance)) + pbar.update() + return sha256 + except Exception as e: + pbar.update() + print(f"Error extracting for {instance}: {e}") + return None + + sha256s = executor.map(worker, metadata.index) + executor.shutdown(wait=True) + + for k, sha256 in zip(metadata.index, sha256s): + if sha256 is not None: + if sha256 == metadata.loc[k, "sha256"]: + downloaded[sha256] = os.path.join("raw/toys4k_blend_files", k) + else: + print(f"Error downloading {k}: sha256s do not match") + + return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/TRELLIS/dataset_toolkits/download.py b/TRELLIS/dataset_toolkits/download.py new file mode 100644 index 0000000000000000000000000000000000000000..f069ad097ae6eedd970ffd763cbd6e48293124b0 --- /dev/null +++ b/TRELLIS/dataset_toolkits/download.py @@ -0,0 +1,52 @@ +import os +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(opt.output_dir, exist_ok=True) + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'local_path' in metadata.columns: + metadata = metadata[metadata['local_path'].isna()] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = 
metadata[start:end] + + print(f'Processing {len(metadata)} objects...') + + # process objects + downloaded = dataset_utils.download(metadata, **opt) + downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False) diff --git a/TRELLIS/dataset_toolkits/encode_latent.py b/TRELLIS/dataset_toolkits/encode_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..4868ad8ba0fd92602d9675042bf758a5c8a54728 --- /dev/null +++ b/TRELLIS/dataset_toolkits/encode_latent.py @@ -0,0 +1,127 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +import copy +import json +import argparse +import torch +import numpy as np +import pandas as pd +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor +from queue import Queue + +import trellis.models as models +import trellis.modules.sparse as sp + + +torch.set_grad_enabled(False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg', + help='Feature model') + parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16', + help='Pretrained encoder model') + parser.add_argument('--model_root', type=str, default='results', + help='Root directory of models') + parser.add_argument('--enc_model', type=str, default=None, + help='Encoder model. if specified, use this model instead of pretrained model') + parser.add_argument('--ckpt', type=str, default=None, + help='Checkpoint to load') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + opt = parser.parse_args() + opt = edict(vars(opt)) + + if opt.enc_model is None: + latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}' + encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda() + else: + latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}' + cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r'))) + encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda() + ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt') + encoder.load_state_dict(torch.load(ckpt_path), strict=False) + encoder.eval() + print(f'Loaded model from {ckpt_path}') + + os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.instances is not None: + with open(opt.instances, 'r') as f: + sha256s = [line.strip() for line in f] + metadata = metadata[metadata['sha256'].isin(sha256s)] + else: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True] + if f'latent_{latent_name}' in metadata.columns: + metadata = metadata[metadata[f'latent_{latent_name}'] 
== False] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + sha256s = list(metadata['sha256'].values) + for sha256 in copy.copy(sha256s): + if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')): + records.append({'sha256': sha256, f'latent_{latent_name}': True}) + sha256s.remove(sha256) + + # encode latents + load_queue = Queue(maxsize=4) + try: + with ThreadPoolExecutor(max_workers=32) as loader_executor, \ + ThreadPoolExecutor(max_workers=32) as saver_executor: + def loader(sha256): + try: + feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz')) + load_queue.put((sha256, feats)) + except Exception as e: + print(f"Error loading features for {sha256}: {e}") + loader_executor.map(loader, sha256s) + + def saver(sha256, pack): + save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz') + np.savez_compressed(save_path, **pack) + records.append({'sha256': sha256, f'latent_{latent_name}': True}) + + for _ in tqdm(range(len(sha256s)), desc="Extracting latents"): + sha256, feats = load_queue.get() + feats = sp.SparseTensor( + feats = torch.from_numpy(feats['patchtokens']).float(), + coords = torch.cat([ + torch.zeros(feats['patchtokens'].shape[0], 1).int(), + torch.from_numpy(feats['indices']).int(), + ], dim=1), + ).cuda() + latent = encoder(feats, sample_posterior=False) + assert torch.isfinite(latent.feats).all(), "Non-finite latent" + pack = { + 'feats': latent.feats.cpu().numpy().astype(np.float32), + 'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8), + } + saver_executor.submit(saver, sha256, pack) + + saver_executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + records = pd.DataFrame.from_records(records) + records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False) diff --git a/TRELLIS/dataset_toolkits/encode_ss_latent.py b/TRELLIS/dataset_toolkits/encode_ss_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef54da917a8cd226e8c845a9fc2cfed08ff12f3 --- /dev/null +++ b/TRELLIS/dataset_toolkits/encode_ss_latent.py @@ -0,0 +1,128 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +import copy +import json +import argparse +import torch +import numpy as np +import pandas as pd +import utils3d +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor +from queue import Queue + +import trellis.models as models + + +torch.set_grad_enabled(False) + + +def get_voxels(instance): + position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0] + coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous() + ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long) + ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1 + return ss + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--enc_pretrained', type=str, default='JeffreyXiang/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16', + 
help='Pretrained encoder model') + parser.add_argument('--model_root', type=str, default='results', + help='Root directory of models') + parser.add_argument('--enc_model', type=str, default=None, + help='Encoder model. if specified, use this model instead of pretrained model') + parser.add_argument('--ckpt', type=str, default=None, + help='Checkpoint to load') + parser.add_argument('--resolution', type=int, default=64, + help='Resolution') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + opt = parser.parse_args() + opt = edict(vars(opt)) + + if opt.enc_model is None: + latent_name = f'{opt.enc_pretrained.split("/")[-1]}' + encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda() + else: + latent_name = f'{opt.enc_model}_{opt.ckpt}' + cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r'))) + encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda() + ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt') + encoder.load_state_dict(torch.load(ckpt_path), strict=False) + encoder.eval() + print(f'Loaded model from {ckpt_path}') + + os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.instances is not None: + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + metadata = metadata[metadata['sha256'].isin(instances)] + else: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + metadata = metadata[metadata['voxelized'] == True] + if f'ss_latent_{latent_name}' in metadata.columns: + metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + sha256s = list(metadata['sha256'].values) + for sha256 in copy.copy(sha256s): + if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')): + records.append({'sha256': sha256, f'ss_latent_{latent_name}': True}) + sha256s.remove(sha256) + + # encode latents + load_queue = Queue(maxsize=4) + try: + with ThreadPoolExecutor(max_workers=32) as loader_executor, \ + ThreadPoolExecutor(max_workers=32) as saver_executor: + def loader(sha256): + try: + ss = get_voxels(sha256)[None].float() + load_queue.put((sha256, ss)) + except Exception as e: + print(f"Error loading features for {sha256}: {e}") + loader_executor.map(loader, sha256s) + + def saver(sha256, pack): + save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz') + np.savez_compressed(save_path, **pack) + records.append({'sha256': sha256, f'ss_latent_{latent_name}': True}) + + for _ in tqdm(range(len(sha256s)), desc="Extracting latents"): + sha256, ss = load_queue.get() + ss = ss.cuda().float() + latent = encoder(ss, sample_posterior=False) + assert torch.isfinite(latent).all(), "Non-finite latent" + pack = { + 'mean': latent[0].cpu().numpy(), + } + saver_executor.submit(saver, sha256, pack) + + 
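# Block here until every queued saver task (np.savez_compressed) has finished, so all latents are on disk before the per-rank csv is written. +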
saver_executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + records = pd.DataFrame.from_records(records) + records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), index=False) diff --git a/TRELLIS/dataset_toolkits/extract_feature.py b/TRELLIS/dataset_toolkits/extract_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb27a068860527fbbb340c8ef073d8202165fdf --- /dev/null +++ b/TRELLIS/dataset_toolkits/extract_feature.py @@ -0,0 +1,179 @@ +import os +import copy +import sys +import json +import importlib +import argparse +import torch +import torch.nn.functional as F +import numpy as np +import pandas as pd +import utils3d +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor +from queue import Queue +from torchvision import transforms +from PIL import Image + + +torch.set_grad_enabled(False) + + +def get_data(frames, sha256): + with ThreadPoolExecutor(max_workers=16) as executor: + def worker(view): + image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path']) + try: + image = Image.open(image_path) + except: + print(f"Error loading image {image_path}") + return None + image = image.resize((518, 518), Image.Resampling.LANCZOS) + image = np.array(image).astype(np.float32) / 255 + image = image[:, :, :3] * image[:, :, 3:] + image = torch.from_numpy(image).permute(2, 0, 1).float() + + c2w = torch.tensor(view['transform_matrix']) + c2w[:3, 1:3] *= -1 + extrinsics = torch.inverse(c2w) + fov = view['camera_angle_x'] + intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov)) + + return { + 'image': image, + 'extrinsics': extrinsics, + 'intrinsics': intrinsics + } + + datas = executor.map(worker, frames) + for data in datas: + if data is not None: + yield data + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--model', type=str, default='dinov2_vitl14_reg', + help='Feature extraction model') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + opt = parser.parse_args() + opt = edict(vars(opt)) + + feature_name = opt.model + os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True) + + # load model + dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model) + dinov2_model.eval().cuda() + transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + n_patch = 518 // 14 + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.instances is not None: + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + metadata = metadata[metadata['sha256'].isin(instances)] + else: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if f'feature_{feature_name}' in 
metadata.columns: + metadata = metadata[metadata[f'feature_{feature_name}'] == False] + metadata = metadata[metadata['voxelized'] == True] + metadata = metadata[metadata['rendered'] == True] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + sha256s = list(metadata['sha256'].values) + for sha256 in copy.copy(sha256s): + if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')): + records.append({'sha256': sha256, f'feature_{feature_name}' : True}) + sha256s.remove(sha256) + + # extract features + load_queue = Queue(maxsize=4) + try: + with ThreadPoolExecutor(max_workers=8) as loader_executor, \ + ThreadPoolExecutor(max_workers=8) as saver_executor: + def loader(sha256): + try: + with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f: + metadata = json.load(f) + frames = metadata['frames'] + data = [] + for datum in get_data(frames, sha256): + datum['image'] = transform(datum['image']) + data.append(datum) + positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0] + load_queue.put((sha256, data, positions)) + except Exception as e: + print(f"Error loading data for {sha256}: {e}") + + loader_executor.map(loader, sha256s) + + def saver(sha256, pack, patchtokens, uv): + pack['patchtokens'] = F.grid_sample( + patchtokens, + uv.unsqueeze(1), + mode='bilinear', + align_corners=False, + ).squeeze(2).permute(0, 2, 1).cpu().numpy() + pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16) + save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz') + np.savez_compressed(save_path, **pack) + records.append({'sha256': sha256, f'feature_{feature_name}' : True}) + + for _ in tqdm(range(len(sha256s)), desc="Extracting features"): + sha256, data, positions = load_queue.get() + positions = torch.from_numpy(positions).float().cuda() + indices = ((positions + 0.5) * 64).long() + assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds" + n_views = len(data) + N = positions.shape[0] + pack = { + 'indices': indices.cpu().numpy().astype(np.uint8), + } + patchtokens_lst = [] + uv_lst = [] + for i in range(0, n_views, opt.batch_size): + batch_data = data[i:i+opt.batch_size] + bs = len(batch_data) + batch_images = torch.stack([d['image'] for d in batch_data]).cuda() + batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda() + batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda() + features = dinov2_model(batch_images, is_training=True) + uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1 + patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch) + patchtokens_lst.append(patchtokens) + uv_lst.append(uv) + patchtokens = torch.cat(patchtokens_lst, dim=0) + uv = torch.cat(uv_lst, dim=0) + + # save features + saver_executor.submit(saver, sha256, pack, patchtokens, uv) + + saver_executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + records = pd.DataFrame.from_records(records) + records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False) + \ No newline at end of file diff --git a/TRELLIS/dataset_toolkits/render.py b/TRELLIS/dataset_toolkits/render.py new 
file mode 100644 index 0000000000000000000000000000000000000000..ec6a0b8488c83c5dbbcd0c07449a7c09d8357174 --- /dev/null +++ b/TRELLIS/dataset_toolkits/render.py @@ -0,0 +1,121 @@ +import os +import json +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict +from functools import partial +from subprocess import DEVNULL, call +import numpy as np +from utils import sphere_hammersley_sequence + + +BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz' +BLENDER_INSTALLATION_PATH = '/tmp' +BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender' + +def _install_blender(): + if not os.path.exists(BLENDER_PATH): + os.system('sudo apt-get update') + os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6') + os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}') + os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}') + + +def _render(file_path, sha256, output_dir, num_views): + output_folder = os.path.join(output_dir, 'renders', sha256) + + # Build camera {yaw, pitch, radius, fov} + yaws = [] + pitchs = [] + offset = (np.random.rand(), np.random.rand()) + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views, offset) + yaws.append(y) + pitchs.append(p) + radius = [2] * num_views + fov = [40 / 180 * np.pi] * num_views + views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)] + + args = [ + BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'), + '--', + '--views', json.dumps(views), + '--object', os.path.expanduser(file_path), + '--resolution', '512', + '--output_folder', output_folder, + '--engine', 'CYCLES', + '--save_mesh', + ] + if file_path.endswith('.blend'): + args.insert(1, file_path) + + call(args, stdout=DEVNULL, stderr=DEVNULL) + + if os.path.exists(os.path.join(output_folder, 'transforms.json')): + return {'sha256': sha256, 'rendered': True} + + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--num_views', type=int, default=150, + help='Number of views to render') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + parser.add_argument('--max_workers', type=int, default=8) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True) + + # install blender + print('Checking blender...', flush=True) + _install_blender() + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + metadata = metadata[metadata['local_path'].notna()] + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 
'rendered' in metadata.columns: + metadata = metadata[metadata['rendered'] == False] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + for sha256 in copy.copy(metadata['sha256'].values): + if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')): + records.append({'sha256': sha256, 'rendered': True}) + metadata = metadata[metadata['sha256'] != sha256] + + print(f'Processing {len(metadata)} objects...') + + # process objects + func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views) + rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects') + rendered = pd.concat([rendered, pd.DataFrame.from_records(records)]) + rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False) diff --git a/TRELLIS/dataset_toolkits/render_cond.py b/TRELLIS/dataset_toolkits/render_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..ced971ab05495b410cae9c583c4729fc91302d7a --- /dev/null +++ b/TRELLIS/dataset_toolkits/render_cond.py @@ -0,0 +1,125 @@ +import os +import json +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict +from functools import partial +from subprocess import DEVNULL, call +import numpy as np +from utils import sphere_hammersley_sequence + + +BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz' +BLENDER_INSTALLATION_PATH = '/tmp' +BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender' + +def _install_blender(): + if not os.path.exists(BLENDER_PATH): + os.system('sudo apt-get update') + os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6') + os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}') + os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}') + + +def _render_cond(file_path, sha256, output_dir, num_views): + output_folder = os.path.join(output_dir, 'renders_cond', sha256) + + # Build camera {yaw, pitch, radius, fov} + yaws = [] + pitchs = [] + offset = (np.random.rand(), np.random.rand()) + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views, offset) + yaws.append(y) + pitchs.append(p) + fov_min, fov_max = 10, 70 + radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi) + radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi) + k_min = 1 / radius_max**2 + k_max = 1 / radius_min**2 + ks = np.random.uniform(k_min, k_max, (1000000,)) + radius = [1 / np.sqrt(k) for k in ks] + fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius] + views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)] + + args = [ + BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'), + '--', + '--views', json.dumps(views), + '--object', os.path.expanduser(file_path), + '--output_folder', os.path.expanduser(output_folder), + '--resolution', '1024', + ] + if file_path.endswith('.blend'): + args.insert(1, file_path) + + 
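# Launch headless Blender for the conditioning renders; stdout is silenced but stderr is left attached so errors from the render script remain visible. +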
call(args, stdout=DEVNULL) + + if os.path.exists(os.path.join(output_folder, 'transforms.json')): + return {'sha256': sha256, 'cond_rendered': True} + + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--num_views', type=int, default=24, + help='Number of views to render') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + parser.add_argument('--max_workers', type=int, default=8) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True) + + # install blender + print('Checking blender...', flush=True) + _install_blender() + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + metadata = metadata[metadata['local_path'].notna()] + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'cond_rendered' in metadata.columns: + metadata = metadata[metadata['cond_rendered'] == False] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + for sha256 in copy.copy(metadata['sha256'].values): + if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')): + records.append({'sha256': sha256, 'cond_rendered': True}) + metadata = metadata[metadata['sha256'] != sha256] + + print(f'Processing {len(metadata)} objects...') + + # process objects + func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views) + cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects') + cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)]) + cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False) diff --git a/TRELLIS/dataset_toolkits/setup.sh b/TRELLIS/dataset_toolkits/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..ef009d0e20a198f758cb236e64839b04069684fb --- /dev/null +++ b/TRELLIS/dataset_toolkits/setup.sh @@ -0,0 +1 @@ +pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub diff --git a/TRELLIS/dataset_toolkits/stat_latent.py b/TRELLIS/dataset_toolkits/stat_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..40290de42386f47b8e7da33f51fcf4b384e46a5e --- /dev/null +++ b/TRELLIS/dataset_toolkits/stat_latent.py @@ -0,0 +1,66 @@ +import os +import json +import 
argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16', + help='Latent model to use') + parser.add_argument('--num_samples', type=int, default=50000, + help='Number of samples to use for calculating stats') + opt = parser.parse_args() + opt = edict(vars(opt)) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + metadata = metadata[metadata[f'latent_{opt.model}'] == True] + sha256s = metadata['sha256'].values + sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False) + + # stats + means = [] + mean2s = [] + with ThreadPoolExecutor(max_workers=16) as executor, \ + tqdm(total=len(sha256s), desc="Extracting features") as pbar: + def worker(sha256): + try: + feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz')) + feats = feats['feats'] + means.append(feats.mean(axis=0)) + mean2s.append((feats ** 2).mean(axis=0)) + pbar.update() + except Exception as e: + print(f"Error extracting features for {sha256}: {e}") + pbar.update() + + executor.map(worker, sha256s) + executor.shutdown(wait=True) + + mean = np.array(means).mean(axis=0) + mean2 = np.array(mean2s).mean(axis=0) + std = np.sqrt(mean2 - mean ** 2) + + print('mean:', mean) + print('std:', std) + + with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f: + json.dump({ + 'mean': mean.tolist(), + 'std': std.tolist(), + }, f, indent=4) + \ No newline at end of file diff --git a/TRELLIS/dataset_toolkits/utils.py b/TRELLIS/dataset_toolkits/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8013c9a3854a1d696fd02d17bec87c5ff93563fc --- /dev/null +++ b/TRELLIS/dataset_toolkits/utils.py @@ -0,0 +1,43 @@ +from typing import * +import hashlib +import numpy as np + + +def get_file_hash(file: str) -> str: + sha256 = hashlib.sha256() + # Read the file from the path + with open(file, "rb") as f: + # Update the hash with the file content + for byte_block in iter(lambda: f.read(4096), b""): + sha256.update(byte_block) + return sha256.hexdigest() + +# ===============LOW DISCREPANCY SEQUENCES================ + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + u 
= 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] diff --git a/TRELLIS/dataset_toolkits/voxelize.py b/TRELLIS/dataset_toolkits/voxelize.py new file mode 100644 index 0000000000000000000000000000000000000000..355d59aaae7a4f23002584e0d03329a09d94e8a1 --- /dev/null +++ b/TRELLIS/dataset_toolkits/voxelize.py @@ -0,0 +1,86 @@ +import os +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict +from functools import partial +import numpy as np +import open3d as o3d +import utils3d + + +def _voxelize(file, sha256, output_dir): + mesh = o3d.io.read_triangle_mesh(os.path.join(output_dir, 'renders', sha256, 'mesh.ply')) + # clamp vertices to the range [-0.5, 0.5] + vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6) + mesh.vertices = o3d.utility.Vector3dVector(vertices) + voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5)) + vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()]) + assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds" + vertices = (vertices + 0.5) / 64 - 0.5 + utils3d.io.write_ply(os.path.join(output_dir, 'voxels', f'{sha256}.ply'), vertices) + return {'sha256': sha256, 'voxelized': True, 'num_voxels': len(vertices)} + + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--num_views', type=int, default=150, + help='Number of views to render') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + parser.add_argument('--max_workers', type=int, default=None) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(os.path.join(opt.output_dir, 'voxels'), exist_ok=True) + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'rendered' not in metadata.columns: + raise ValueError('metadata.csv does not have "rendered" column, please run "build_metadata.py" first') + metadata = metadata[metadata['rendered'] == True] + if 'voxelized' in metadata.columns: + metadata = metadata[metadata['voxelized'] == False] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + for sha256 in copy.copy(metadata['sha256'].values): + if 
os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')): + pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0] + records.append({'sha256': sha256, 'voxelized': True, 'num_voxels': len(pts)}) + metadata = metadata[metadata['sha256'] != sha256] + + print(f'Processing {len(metadata)} objects...') + + # process objects + func = partial(_voxelize, output_dir=opt.output_dir) + voxelized = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Voxelizing') + voxelized = pd.concat([voxelized, pd.DataFrame.from_records(records)]) + voxelized.to_csv(os.path.join(opt.output_dir, f'voxelized_{opt.rank}.csv'), index=False) diff --git a/TRELLIS/example.py b/TRELLIS/example.py new file mode 100644 index 0000000000000000000000000000000000000000..591faf33a2eb4f1984aa9dc20756839f748e17d5 --- /dev/null +++ b/TRELLIS/example.py @@ -0,0 +1,57 @@ +import os +# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn' +os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'. + # 'auto' is faster but will do benchmarking at the beginning. + # Recommended to set to 'native' if run only once. + +import imageio +from PIL import Image +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.utils import render_utils, postprocessing_utils + +# Load a pipeline from a model folder or a Hugging Face model hub. +pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") +pipeline.cuda() + +# Load an image +image = Image.open("assets/example_image/T.png") + +# Run the pipeline +outputs = pipeline.run( + image, + seed=1, + # Optional parameters + # sparse_structure_sampler_params={ + # "steps": 12, + # "cfg_strength": 7.5, + # }, + # slat_sampler_params={ + # "steps": 12, + # "cfg_strength": 3, + # }, +) +# outputs is a dictionary containing generated 3D assets in different formats: +# - outputs['gaussian']: a list of 3D Gaussians +# - outputs['radiance_field']: a list of radiance fields +# - outputs['mesh']: a list of meshes + +# Render the outputs +video = render_utils.render_video(outputs['gaussian'][0])['color'] +imageio.mimsave("sample_gs.mp4", video, fps=30) +video = render_utils.render_video(outputs['radiance_field'][0])['color'] +imageio.mimsave("sample_rf.mp4", video, fps=30) +video = render_utils.render_video(outputs['mesh'][0])['normal'] +imageio.mimsave("sample_mesh.mp4", video, fps=30) + +# GLB files can be extracted from the outputs +glb = postprocessing_utils.to_glb( + outputs['gaussian'][0], + outputs['mesh'][0], + # Optional parameters + simplify=0.95, # Ratio of triangles to remove in the simplification process + texture_size=1024, # Size of the texture used for the GLB +) +glb.export("sample.glb") + +# Save Gaussians as PLY files +outputs['gaussian'][0].save_ply("sample.ply") diff --git a/TRELLIS/example_multi_image.py b/TRELLIS/example_multi_image.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5c3ba7a0a9f01df2a1f6fe70df39d251064307 --- /dev/null +++ b/TRELLIS/example_multi_image.py @@ -0,0 +1,46 @@ +import os +# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn' +os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'. + # 'auto' is faster but will do benchmarking at the beginning. + # Recommended to set to 'native' if run only once. 
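+# Note: these variables only take effect if set before trellis is imported, which is why they sit above the imports.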
+ +import numpy as np +import imageio +from PIL import Image +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.utils import render_utils + +# Load a pipeline from a model folder or a Hugging Face model hub. +pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") +pipeline.cuda() + +# Load an image +images = [ + Image.open("assets/example_multi_image/character_1.png"), + Image.open("assets/example_multi_image/character_2.png"), + Image.open("assets/example_multi_image/character_3.png"), +] + +# Run the pipeline +outputs = pipeline.run_multi_image( + images, + seed=1, + # Optional parameters + sparse_structure_sampler_params={ + "steps": 12, + "cfg_strength": 7.5, + }, + slat_sampler_params={ + "steps": 12, + "cfg_strength": 3, + }, +) +# outputs is a dictionary containing generated 3D assets in different formats: +# - outputs['gaussian']: a list of 3D Gaussians +# - outputs['radiance_field']: a list of radiance fields +# - outputs['mesh']: a list of meshes + +video_gs = render_utils.render_video(outputs['gaussian'][0])['color'] +video_mesh = render_utils.render_video(outputs['mesh'][0])['normal'] +video = [np.concatenate([frame_gs, frame_mesh], axis=1) for frame_gs, frame_mesh in zip(video_gs, video_mesh)] +imageio.mimsave("sample_multi.mp4", video, fps=30) diff --git a/TRELLIS/extensions/vox2seq/benchmark.py b/TRELLIS/extensions/vox2seq/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..92eae879700c3517d66ba29cf687aadfb4925369 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/benchmark.py @@ -0,0 +1,45 @@ +import time +import torch +import vox2seq + + +if __name__ == "__main__": + stats = { + 'z_order_cuda': [], + 'z_order_pytorch': [], + 'hilbert_cuda': [], + 'hilbert_pytorch': [], + } + RES = [16, 32, 64, 128, 256] + for res in RES: + coords = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res)) + coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda() + + start = time.time() + for _ in range(100): + code_z_cuda = vox2seq.encode(coords, mode='z_order').cuda() + torch.cuda.synchronize() + stats['z_order_cuda'].append((time.time() - start) / 100) + + start = time.time() + for _ in range(100): + code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order').cuda() + torch.cuda.synchronize() + stats['z_order_pytorch'].append((time.time() - start) / 100) + + start = time.time() + for _ in range(100): + code_h_cuda = vox2seq.encode(coords, mode='hilbert').cuda() + torch.cuda.synchronize() + stats['hilbert_cuda'].append((time.time() - start) / 100) + + start = time.time() + for _ in range(100): + code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert').cuda() + torch.cuda.synchronize() + stats['hilbert_pytorch'].append((time.time() - start) / 100) + + print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}") + for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']): + print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}") + diff --git a/TRELLIS/extensions/vox2seq/setup.py b/TRELLIS/extensions/vox2seq/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..6184ca03ed2490879450c4a8c4a9383e230b54d4 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/setup.py @@ -0,0 +1,34 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO 
research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension +import os +os.path.dirname(os.path.abspath(__file__)) + +setup( + name="vox2seq", + packages=['vox2seq', 'vox2seq.pytorch'], + ext_modules=[ + CUDAExtension( + name="vox2seq._C", + sources=[ + "src/api.cu", + "src/z_order.cu", + "src/hilbert.cu", + "src/ext.cpp", + ], + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/TRELLIS/extensions/vox2seq/src/api.cu b/TRELLIS/extensions/vox2seq/src/api.cu new file mode 100644 index 0000000000000000000000000000000000000000..44635f94fc99966fb6fb4343eb5b8d1c843785b0 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/src/api.cu @@ -0,0 +1,92 @@ +#include +#include "api.h" +#include "z_order.h" +#include "hilbert.h" + + +torch::Tensor +z_order_encode( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +) { + // Allocate output tensor + torch::Tensor codes = torch::empty_like(x); + + // Call CUDA kernel + z_order_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + x.size(0), + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(y.contiguous().data_ptr()), + reinterpret_cast(z.contiguous().data_ptr()), + reinterpret_cast(codes.data_ptr()) + ); + + return codes; +} + + +std::tuple +z_order_decode( + const torch::Tensor& codes +) { + // Allocate output tensors + torch::Tensor x = torch::empty_like(codes); + torch::Tensor y = torch::empty_like(codes); + torch::Tensor z = torch::empty_like(codes); + + // Call CUDA kernel + z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + codes.size(0), + reinterpret_cast(codes.contiguous().data_ptr()), + reinterpret_cast(x.data_ptr()), + reinterpret_cast(y.data_ptr()), + reinterpret_cast(z.data_ptr()) + ); + + return std::make_tuple(x, y, z); +} + + +torch::Tensor +hilbert_encode( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +) { + // Allocate output tensor + torch::Tensor codes = torch::empty_like(x); + + // Call CUDA kernel + hilbert_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + x.size(0), + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(y.contiguous().data_ptr()), + reinterpret_cast(z.contiguous().data_ptr()), + reinterpret_cast(codes.data_ptr()) + ); + + return codes; +} + + +std::tuple +hilbert_decode( + const torch::Tensor& codes +) { + // Allocate output tensors + torch::Tensor x = torch::empty_like(codes); + torch::Tensor y = torch::empty_like(codes); + torch::Tensor z = torch::empty_like(codes); + + // Call CUDA kernel + hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + codes.size(0), + reinterpret_cast(codes.contiguous().data_ptr()), + reinterpret_cast(x.data_ptr()), + reinterpret_cast(y.data_ptr()), + reinterpret_cast(z.data_ptr()) + ); + + return std::make_tuple(x, y, z); +} diff --git a/TRELLIS/extensions/vox2seq/src/api.h b/TRELLIS/extensions/vox2seq/src/api.h new file mode 100644 index 0000000000000000000000000000000000000000..9ed6e33be15156ce4a425f10375af555530f5bbc --- /dev/null +++ b/TRELLIS/extensions/vox2seq/src/api.h @@ -0,0 +1,76 @@ +/* + * Serialize a voxel grid + * + * Copyright (C) 2024, Jianfeng XIANG + * All rights reserved. 
+ * + * Licensed under The MIT License [see LICENSE for details] + * + * Written by Jianfeng XIANG + */ + +#pragma once +#include + + +#define BLOCK_SIZE 256 + + +/** + * Z-order encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +torch::Tensor +z_order_encode( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +); + + +/** + * Z-order decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * + * @return 3 tensors [N] containing the x, y, z coordinates + */ +std::tuple +z_order_decode( + const torch::Tensor& codes +); + + +/** + * Hilbert encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the Hilbert encoded values + */ +torch::Tensor +hilbert_encode( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +); + + +/** + * Hilbert decode 3D points + * + * @param codes [N] tensor containing the Hilbert encoded values + * + * @return 3 tensors [N] containing the x, y, z coordinates + */ +std::tuple +hilbert_decode( + const torch::Tensor& codes +); diff --git a/TRELLIS/extensions/vox2seq/src/ext.cpp b/TRELLIS/extensions/vox2seq/src/ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15e12f57396504741613774cabab082b6c8d4f24 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/src/ext.cpp @@ -0,0 +1,10 @@ +#include +#include "api.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("z_order_encode", &z_order_encode); + m.def("z_order_decode", &z_order_decode); + m.def("hilbert_encode", &hilbert_encode); + m.def("hilbert_decode", &hilbert_decode); +} \ No newline at end of file diff --git a/TRELLIS/extensions/vox2seq/src/hilbert.cu b/TRELLIS/extensions/vox2seq/src/hilbert.cu new file mode 100644 index 0000000000000000000000000000000000000000..706440cc30935b0b963f41b3b2c816497bb298df --- /dev/null +++ b/TRELLIS/extensions/vox2seq/src/hilbert.cu @@ -0,0 +1,133 @@ +#include +#include +#include + +#include +#include +namespace cg = cooperative_groups; + +#include "hilbert.h" + + +// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit. +static __device__ uint32_t expandBits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + + +// Removes 2 zeros after each bit in a 30-bit integer. 
+static __device__ uint32_t extractBits(uint32_t v) +{ + v = v & 0x49249249; + v = (v ^ (v >> 2)) & 0x030C30C3u; + v = (v ^ (v >> 4)) & 0x0300F00Fu; + v = (v ^ (v >> 8)) & 0x030000FFu; + v = (v ^ (v >> 16)) & 0x000003FFu; + return v; +} + + +__global__ void hilbert_encode_cuda( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]}; + + uint32_t m = 1 << 9, q, p, t; + + // Inverse undo excess work + q = m; + while (q > 1) { + p = q - 1; + for (int i = 0; i < 3; i++) { + if (point[i] & q) { + point[0] ^= p; // invert + } else { + t = (point[0] ^ point[i]) & p; + point[0] ^= t; + point[i] ^= t; + } + } + q >>= 1; + } + + // Gray encode + for (int i = 1; i < 3; i++) { + point[i] ^= point[i - 1]; + } + t = 0; + q = m; + while (q > 1) { + if (point[2] & q) { + t ^= q - 1; + } + q >>= 1; + } + for (int i = 0; i < 3; i++) { + point[i] ^= t; + } + + // Convert to 3D Hilbert code + uint32_t xx = expandBits(point[0]); + uint32_t yy = expandBits(point[1]); + uint32_t zz = expandBits(point[2]); + + codes[thread_id] = xx * 4 + yy * 2 + zz; +} + + +__global__ void hilbert_decode_cuda( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t point[3]; + point[0] = extractBits(codes[thread_id] >> 2); + point[1] = extractBits(codes[thread_id] >> 1); + point[2] = extractBits(codes[thread_id]); + + uint32_t m = 2 << 9, q, p, t; + + // Gray decode by H ^ (H/2) + t = point[2] >> 1; + for (int i = 2; i > 0; i--) { + point[i] ^= point[i - 1]; + } + point[0] ^= t; + + // Undo excess work + q = 2; + while (q != m) { + p = q - 1; + for (int i = 2; i >= 0; i--) { + if (point[i] & q) { + point[0] ^= p; + } else { + t = (point[0] ^ point[i]) & p; + point[0] ^= t; + point[i] ^= t; + } + } + q <<= 1; + } + + x[thread_id] = point[0]; + y[thread_id] = point[1]; + z[thread_id] = point[2]; +} diff --git a/TRELLIS/extensions/vox2seq/src/hilbert.h b/TRELLIS/extensions/vox2seq/src/hilbert.h new file mode 100644 index 0000000000000000000000000000000000000000..e8fbea16ca6df74b03fa0a544ca75819a1867fc4 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/src/hilbert.h @@ -0,0 +1,35 @@ +#pragma once + +/** + * Hilbert encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +__global__ void hilbert_encode_cuda( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +); + + +/** + * Hilbert decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + */ +__global__ void hilbert_decode_cuda( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +); diff --git a/TRELLIS/extensions/vox2seq/src/z_order.cu b/TRELLIS/extensions/vox2seq/src/z_order.cu new file mode 100644 index 0000000000000000000000000000000000000000..23bcf7a1fa8decc1199badb1d6d266f5b076e3f8 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/src/z_order.cu @@ -0,0 +1,66 @@ +#include +#include +#include + +#include +#include +namespace 
cg = cooperative_groups; + +#include "z_order.h" + + +// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit. +static __device__ uint32_t expandBits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + + +// Removes 2 zeros after each bit in a 30-bit integer. +static __device__ uint32_t extractBits(uint32_t v) +{ + v = v & 0x49249249; + v = (v ^ (v >> 2)) & 0x030C30C3u; + v = (v ^ (v >> 4)) & 0x0300F00Fu; + v = (v ^ (v >> 8)) & 0x030000FFu; + v = (v ^ (v >> 16)) & 0x000003FFu; + return v; +} + + +__global__ void z_order_encode_cuda( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t xx = expandBits(x[thread_id]); + uint32_t yy = expandBits(y[thread_id]); + uint32_t zz = expandBits(z[thread_id]); + + codes[thread_id] = xx * 4 + yy * 2 + zz; +} + + +__global__ void z_order_decode_cuda( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + x[thread_id] = extractBits(codes[thread_id] >> 2); + y[thread_id] = extractBits(codes[thread_id] >> 1); + z[thread_id] = extractBits(codes[thread_id]); +} diff --git a/TRELLIS/extensions/vox2seq/src/z_order.h b/TRELLIS/extensions/vox2seq/src/z_order.h new file mode 100644 index 0000000000000000000000000000000000000000..ef9d5e0ff4c05bba2d9bd193c25c57e9dc9d69d9 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/src/z_order.h @@ -0,0 +1,35 @@ +#pragma once + +/** + * Z-order encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +__global__ void z_order_encode_cuda( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +); + + +/** + * Z-order decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + */ +__global__ void z_order_decode_cuda( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +); diff --git a/TRELLIS/extensions/vox2seq/test.py b/TRELLIS/extensions/vox2seq/test.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc16831d9bc4ec0e34695a3299db5326112056f --- /dev/null +++ b/TRELLIS/extensions/vox2seq/test.py @@ -0,0 +1,25 @@ +import torch +import vox2seq + + +if __name__ == "__main__": + RES = 256 + coords = torch.meshgrid(torch.arange(RES), torch.arange(RES), torch.arange(RES)) + coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda() + code_z_cuda = vox2seq.encode(coords, mode='z_order') + code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order') + code_h_cuda = vox2seq.encode(coords, mode='hilbert') + code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert') + assert torch.equal(code_z_cuda, code_z_pytorch) + assert torch.equal(code_h_cuda, code_h_pytorch) + + code = torch.arange(RES**3).int().cuda() + coords_z_cuda = vox2seq.decode(code, mode='z_order') + coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order') + coords_h_cuda = 
vox2seq.decode(code, mode='hilbert') + coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert') + assert torch.equal(coords_z_cuda, coords_z_pytorch) + assert torch.equal(coords_h_cuda, coords_h_pytorch) + + print("All tests passed.") + diff --git a/TRELLIS/extensions/vox2seq/vox2seq/__init__.py b/TRELLIS/extensions/vox2seq/vox2seq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71c0a1c49c27d4384fded2d7a60df7258903052c --- /dev/null +++ b/TRELLIS/extensions/vox2seq/vox2seq/__init__.py @@ -0,0 +1,50 @@ + +from typing import * +import torch +from . import _C +from . import pytorch + + +@torch.no_grad() +def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Encodes 3D coordinates into a 30-bit code. + + Args: + coords: a tensor of shape [N, 3] containing the 3D coordinates. + permute: the permutation of the coordinates. + mode: the encoding mode to use. + """ + assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]" + x = coords[:, permute[0]].int() + y = coords[:, permute[1]].int() + z = coords[:, permute[2]].int() + if mode == 'z_order': + return _C.z_order_encode(x, y, z) + elif mode == 'hilbert': + return _C.hilbert_encode(x, y, z) + else: + raise ValueError(f"Unknown encoding mode: {mode}") + + +@torch.no_grad() +def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Decodes a 30-bit code into 3D coordinates. + + Args: + code: a tensor of shape [N] containing the 30-bit code. + permute: the permutation of the coordinates. + mode: the decoding mode to use. + """ + assert code.ndim == 1, "Input code must be of shape [N]" + if mode == 'z_order': + coords = _C.z_order_decode(code) + elif mode == 'hilbert': + coords = _C.hilbert_decode(code) + else: + raise ValueError(f"Unknown decoding mode: {mode}") + x = coords[permute.index(0)] + y = coords[permute.index(1)] + z = coords[permute.index(2)] + return torch.stack([x, y, z], dim=-1) diff --git a/TRELLIS/extensions/vox2seq/vox2seq/pytorch/__init__.py b/TRELLIS/extensions/vox2seq/vox2seq/pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d82da4729218e4ebee34a953ef312e745fc913a --- /dev/null +++ b/TRELLIS/extensions/vox2seq/vox2seq/pytorch/__init__.py @@ -0,0 +1,48 @@ +import torch +from typing import * + +from .default import ( + encode, + decode, + z_order_encode, + z_order_decode, + hilbert_encode, + hilbert_decode, +) + + +@torch.no_grad() +def encode(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Encodes 3D coordinates into a 30-bit code. + + Args: + coords: a tensor of shape [N, 3] containing the 3D coordinates. + permute: the permutation of the coordinates. + mode: the encoding mode to use. + """ + if mode == 'z_order': + return z_order_encode(coords[:, permute], depth=10).int() + elif mode == 'hilbert': + return hilbert_encode(coords[:, permute], depth=10).int() + else: + raise ValueError(f"Unknown encoding mode: {mode}") + + +@torch.no_grad() +def decode(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Decodes a 30-bit code into 3D coordinates. + + Args: + code: a tensor of shape [N] containing the 30-bit code. + permute: the permutation of the coordinates. + mode: the decoding mode to use. 
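Editorial note: the two front ends documented above share one signature — vox2seq.encode/decode dispatch to the compiled _C kernels, while vox2seq.pytorch is the pure-PyTorch fallback that test.py uses as a reference. A minimal usage sketch, assuming the extension has been built (e.g. via setup.sh --vox2seq) and a CUDA device is available, since the kernels operate on GPU tensors:

    import torch
    import vox2seq

    # Voxel indices must fit in 10 bits per axis (0..1023); a 64^3 grid is shown here.
    coords = torch.randint(0, 64, (4096, 3), dtype=torch.int32).cuda()

    codes_cuda = vox2seq.encode(coords, mode='z_order')          # CUDA kernel
    codes_ref  = vox2seq.pytorch.encode(coords, mode='z_order')  # pure-PyTorch fallback
    assert torch.equal(codes_cuda, codes_ref)                    # mirrors the check in test.py

    # mode='hilbert' works the same way; decode() returns an [N, 3] tensor of coordinates.
    xyz = vox2seq.decode(codes_cuda, mode='z_order')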
+ """ + if mode == 'z_order': + return z_order_decode(code, depth=10)[:, permute].float() + elif mode == 'hilbert': + return hilbert_decode(code, depth=10)[:, permute].float() + else: + raise ValueError(f"Unknown decoding mode: {mode}") + \ No newline at end of file diff --git a/TRELLIS/extensions/vox2seq/vox2seq/pytorch/default.py b/TRELLIS/extensions/vox2seq/vox2seq/pytorch/default.py new file mode 100644 index 0000000000000000000000000000000000000000..07dac7aac53345a6b9552be8fde71fa75bdcc421 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/vox2seq/pytorch/default.py @@ -0,0 +1,59 @@ +import torch +from .z_order import xyz2key as z_order_encode_ +from .z_order import key2xyz as z_order_decode_ +from .hilbert import encode as hilbert_encode_ +from .hilbert import decode as hilbert_decode_ + + +@torch.inference_mode() +def encode(grid_coord, batch=None, depth=16, order="z"): + assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} + if order == "z": + code = z_order_encode(grid_coord, depth=depth) + elif order == "z-trans": + code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) + elif order == "hilbert": + code = hilbert_encode(grid_coord, depth=depth) + elif order == "hilbert-trans": + code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) + else: + raise NotImplementedError + if batch is not None: + batch = batch.long() + code = batch << depth * 3 | code + return code + + +@torch.inference_mode() +def decode(code, depth=16, order="z"): + assert order in {"z", "hilbert"} + batch = code >> depth * 3 + code = code & ((1 << depth * 3) - 1) + if order == "z": + grid_coord = z_order_decode(code, depth=depth) + elif order == "hilbert": + grid_coord = hilbert_decode(code, depth=depth) + else: + raise NotImplementedError + return grid_coord, batch + + +def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): + x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() + # we block the support to batch, maintain batched code in Point class + code = z_order_encode_(x, y, z, b=None, depth=depth) + return code + + +def z_order_decode(code: torch.Tensor, depth): + x, y, z, _ = z_order_decode_(code, depth=depth) + grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) + return grid_coord + + +def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): + return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) + + +def hilbert_decode(code: torch.Tensor, depth: int = 16): + return hilbert_decode_(code, num_dims=3, num_bits=depth) \ No newline at end of file diff --git a/TRELLIS/extensions/vox2seq/vox2seq/pytorch/hilbert.py b/TRELLIS/extensions/vox2seq/vox2seq/pytorch/hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..8efe44720caaff4c1343b53d3dce9764ae8fd938 --- /dev/null +++ b/TRELLIS/extensions/vox2seq/vox2seq/pytorch/hilbert.py @@ -0,0 +1,303 @@ +""" +Hilbert Order +Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu +Please cite our work if the code is helpful to you. +""" + +import torch + + +def right_shift(binary, k=1, axis=-1): + """Right shift an array of binary values. + + Parameters: + ----------- + binary: An ndarray of binary values. + + k: The number of bits to shift. Default 1. + + axis: The axis along which to shift. Default -1. + + Returns: + -------- + Returns an ndarray with zero prepended and the ends truncated, along + whatever axis was specified.""" + + # If we're shifting the whole thing, just return zeros. 
+ if binary.shape[axis] <= k: + return torch.zeros_like(binary) + + # Determine the padding pattern. + # padding = [(0,0)] * len(binary.shape) + # padding[axis] = (k,0) + + # Determine the slicing pattern to eliminate just the last one. + slicing = [slice(None)] * len(binary.shape) + slicing[axis] = slice(None, -k) + shifted = torch.nn.functional.pad( + binary[tuple(slicing)], (k, 0), mode="constant", value=0 + ) + + return shifted + + +def binary2gray(binary, axis=-1): + """Convert an array of binary values into Gray codes. + + This uses the classic X ^ (X >> 1) trick to compute the Gray code. + + Parameters: + ----------- + binary: An ndarray of binary values. + + axis: The axis along which to compute the gray code. Default=-1. + + Returns: + -------- + Returns an ndarray of Gray codes. + """ + shifted = right_shift(binary, axis=axis) + + # Do the X ^ (X >> 1) trick. + gray = torch.logical_xor(binary, shifted) + + return gray + + +def gray2binary(gray, axis=-1): + """Convert an array of Gray codes back into binary values. + + Parameters: + ----------- + gray: An ndarray of gray codes. + + axis: The axis along which to perform Gray decoding. Default=-1. + + Returns: + -------- + Returns an ndarray of binary values. + """ + + # Loop the log2(bits) number of times necessary, with shift and xor. + shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1) + while shift > 0: + gray = torch.logical_xor(gray, right_shift(gray, shift)) + shift = torch.div(shift, 2, rounding_mode="floor") + return gray + + +def encode(locs, num_dims, num_bits): + """Decode an array of locations in a hypercube into a Hilbert integer. + + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + + Params: + ------- + locs - An ndarray of locations in a hypercube of num_dims dimensions, in + which each dimension runs from 0 to 2**num_bits-1. The shape can + be arbitrary, as long as the last dimension of the same has size + num_dims. + + num_dims - The dimensionality of the hypercube. Integer. + + num_bits - The number of bits for each dimension. Integer. + + Returns: + -------- + The output is an ndarray of uint64 integers with the same shape as the + input, excluding the last dimension, which needs to be num_dims. + """ + + # Keep around the original shape for later. + orig_shape = locs.shape + bitpack_mask = 1 << torch.arange(0, 8).to(locs.device) + bitpack_mask_rev = bitpack_mask.flip(-1) + + if orig_shape[-1] != num_dims: + raise ValueError( + """ + The shape of locs was surprising in that the last dimension was of size + %d, but num_dims=%d. These need to be equal. + """ + % (orig_shape[-1], num_dims) + ) + + if num_dims * num_bits > 63: + raise ValueError( + """ + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a int64. Are you sure you need that many points on your Hilbert + curve? + """ + % (num_dims, num_bits, num_dims * num_bits) + ) + + # Treat the location integers as 64-bit unsigned and then split them up into + # a sequence of uint8s. Preserve the association by dimension. + locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) + + # Now turn these into bits and truncate to num_bits. 
+ gray = ( + locs_uint8.unsqueeze(-1) + .bitwise_and(bitpack_mask_rev) + .ne(0) + .byte() + .flatten(-2, -1)[..., -num_bits:] + ) + + # Run the decoding process the other way. + # Iterate forwards through the bits. + for bit in range(0, num_bits): + # Iterate forwards through the dimensions. + for dim in range(0, num_dims): + # Identify which ones have this bit active. + mask = gray[:, dim, bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:, 0, bit + 1 :] = torch.logical_xor( + gray[:, 0, bit + 1 :], mask[:, None] + ) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = torch.logical_and( + torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1), + torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), + ) + gray[:, dim, bit + 1 :] = torch.logical_xor( + gray[:, dim, bit + 1 :], to_flip + ) + gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) + + # Now flatten out. + gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims)) + + # Convert Gray back to binary. + hh_bin = gray2binary(gray) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits * num_dims + padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0) + + # Convert binary values into uint8s. + hh_uint8 = ( + (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask) + .sum(2) + .squeeze() + .type(torch.uint8) + ) + + # Convert uint8s into uint64s. + hh_uint64 = hh_uint8.view(torch.int64).squeeze() + + return hh_uint64 + + +def decode(hilberts, num_dims, num_bits): + """Decode an array of Hilbert integers into locations in a hypercube. + + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + + Params: + ------- + hilberts - An ndarray of Hilbert integers. Must be an integer dtype and + cannot have fewer bits than num_dims * num_bits. + + num_dims - The dimensionality of the hypercube. Integer. + + num_bits - The number of bits for each dimension. Integer. + + Returns: + -------- + The output is an ndarray of unsigned integers with the same shape as hilberts + but with an additional dimension of size num_dims. + """ + + if num_dims * num_bits > 64: + raise ValueError( + """ + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a uint64. Are you sure you need that many points on your Hilbert + curve? + """ + % (num_dims, num_bits) + ) + + # Handle the case where we got handed a naked integer. + hilberts = torch.atleast_1d(hilberts) + + # Keep around the shape for later. + orig_shape = hilberts.shape + bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device) + bitpack_mask_rev = bitpack_mask.flip(-1) + + # Treat each of the hilberts as a s equence of eight uint8. + # This treats all of the inputs as uint64 and makes things uniform. + hh_uint8 = ( + hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1) + ) + + # Turn these lists of uints into lists of bits and then truncate to the size + # we actually need for using Skilling's procedure. + hh_bits = ( + hh_uint8.unsqueeze(-1) + .bitwise_and(bitpack_mask_rev) + .ne(0) + .byte() + .flatten(-2, -1)[:, -num_dims * num_bits :] + ) + + # Take the sequence of bits and Gray-code it. + gray = binary2gray(hh_bits) + + # There has got to be a better way to do this. 
+ # I could index them differently, but the eventual packbits likes it this way. + gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2) + + # Iterate backwards through the bits. + for bit in range(num_bits - 1, -1, -1): + # Iterate backwards through the dimensions. + for dim in range(num_dims - 1, -1, -1): + # Identify which ones have this bit active. + mask = gray[:, dim, bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:, 0, bit + 1 :] = torch.logical_xor( + gray[:, 0, bit + 1 :], mask[:, None] + ) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = torch.logical_and( + torch.logical_not(mask[:, None]), + torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), + ) + gray[:, dim, bit + 1 :] = torch.logical_xor( + gray[:, dim, bit + 1 :], to_flip + ) + gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits + padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0) + + # Now chop these up into blocks of 8. + locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8)) + + # Take those blocks and turn them unto uint8s. + # from IPython import embed; embed() + locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8) + + # Finally, treat these as uint64s. + flat_locs = locs_uint8.view(torch.int64) + + # Return them in the expected shape. + return flat_locs.reshape((*orig_shape, num_dims)) \ No newline at end of file diff --git a/TRELLIS/extensions/vox2seq/vox2seq/pytorch/z_order.py b/TRELLIS/extensions/vox2seq/vox2seq/pytorch/z_order.py new file mode 100644 index 0000000000000000000000000000000000000000..e7b085b71678d6af153c7c30ad3d05b37f96a23d --- /dev/null +++ b/TRELLIS/extensions/vox2seq/vox2seq/pytorch/z_order.py @@ -0,0 +1,126 @@ +# -------------------------------------------------------- +# Octree-based Sparse Convolutional Neural Networks +# Copyright (c) 2022 Peng-Shuai Wang +# Licensed under The MIT License [see LICENSE for details] +# Written by Peng-Shuai Wang +# -------------------------------------------------------- + +import torch +from typing import Optional, Union + + +class KeyLUT: + def __init__(self): + r256 = torch.arange(256, dtype=torch.int64) + r512 = torch.arange(512, dtype=torch.int64) + zero = torch.zeros(256, dtype=torch.int64) + device = torch.device("cpu") + + self._encode = { + device: ( + self.xyz2key(r256, zero, zero, 8), + self.xyz2key(zero, r256, zero, 8), + self.xyz2key(zero, zero, r256, 8), + ) + } + self._decode = {device: self.key2xyz(r512, 9)} + + def encode_lut(self, device=torch.device("cpu")): + if device not in self._encode: + cpu = torch.device("cpu") + self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) + return self._encode[device] + + def decode_lut(self, device=torch.device("cpu")): + if device not in self._decode: + cpu = torch.device("cpu") + self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) + return self._decode[device] + + def xyz2key(self, x, y, z, depth): + key = torch.zeros_like(x) + for i in range(depth): + mask = 1 << i + key = ( + key + | ((x & mask) << (2 * i + 2)) + | ((y & mask) << (2 * i + 1)) + | ((z & mask) << (2 * i + 0)) + ) + return key + + def key2xyz(self, key, depth): + x = torch.zeros_like(key) + y = torch.zeros_like(key) + z = torch.zeros_like(key) + for i in range(depth): + x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) + y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) + z = z | 
((key & (1 << (3 * i + 0))) >> (2 * i + 0)) + return x, y, z + + +_key_lut = KeyLUT() + + +def xyz2key( + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + b: Optional[Union[torch.Tensor, int]] = None, + depth: int = 16, +): + r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys + based on pre-computed look up tables. The speed of this function is much + faster than the method based on for-loop. + + Args: + x (torch.Tensor): The x coordinate. + y (torch.Tensor): The y coordinate. + z (torch.Tensor): The z coordinate. + b (torch.Tensor or int): The batch index of the coordinates, and should be + smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of + :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). + """ + + EX, EY, EZ = _key_lut.encode_lut(x.device) + x, y, z = x.long(), y.long(), z.long() + + mask = 255 if depth > 8 else (1 << depth) - 1 + key = EX[x & mask] | EY[y & mask] | EZ[z & mask] + if depth > 8: + mask = (1 << (depth - 8)) - 1 + key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] + key = key16 << 24 | key + + if b is not None: + b = b.long() + key = b << 48 | key + + return key + + +def key2xyz(key: torch.Tensor, depth: int = 16): + r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates + and the batch index based on pre-computed look up tables. + + Args: + key (torch.Tensor): The shuffled key. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). + """ + + DX, DY, DZ = _key_lut.decode_lut(key.device) + x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) + + b = key >> 48 + key = key & ((1 << 48) - 1) + + n = (depth + 2) // 3 + for i in range(n): + k = key >> (i * 9) & 511 + x = x | (DX[k] << (i * 3)) + y = y | (DY[k] << (i * 3)) + z = z | (DZ[k] << (i * 3)) + + return x, y, z, b \ No newline at end of file diff --git a/TRELLIS/run.py b/TRELLIS/run.py new file mode 100644 index 0000000000000000000000000000000000000000..b77fc2b77cfb8d1bda5a732740608154ef8893fc --- /dev/null +++ b/TRELLIS/run.py @@ -0,0 +1,7 @@ +from trellis import ImageTo3D + +# Load the model +model = ImageTo3D() + +# Convert an image to 3D +model.convert("DORA.png", "output_3d.glb") diff --git a/TRELLIS/setup.sh b/TRELLIS/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..df9b9225277f2cba63a7ab021bcf0d75f3237502 --- /dev/null +++ b/TRELLIS/setup.sh @@ -0,0 +1,250 @@ +# Read Arguments +TEMP=`getopt -o h --long help,new-env,basic,xformers,flash-attn,diffoctreerast,vox2seq,spconv,mipgaussian,kaolin,nvdiffrast,demo -n 'setup.sh' -- "$@"` + +eval set -- "$TEMP" + +HELP=false +NEW_ENV=false +BASIC=false +XFORMERS=false +FLASHATTN=false +DIFFOCTREERAST=false +VOX2SEQ=false +LINEAR_ASSIGNMENT=false +SPCONV=false +ERROR=false +MIPGAUSSIAN=false +KAOLIN=false +NVDIFFRAST=false +DEMO=false + +if [ "$#" -eq 1 ] ; then + HELP=true +fi + +while true ; do + case "$1" in + -h|--help) HELP=true ; shift ;; + --new-env) NEW_ENV=true ; shift ;; + --basic) BASIC=true ; shift ;; + --xformers) XFORMERS=true ; shift ;; + --flash-attn) FLASHATTN=true ; shift ;; + --diffoctreerast) DIFFOCTREERAST=true ; shift ;; + --vox2seq) VOX2SEQ=true ; shift ;; + --spconv) SPCONV=true ; shift ;; + --mipgaussian) MIPGAUSSIAN=true ; shift ;; + --kaolin) KAOLIN=true ; shift ;; + --nvdiffrast) NVDIFFRAST=true ; shift ;; + --demo) DEMO=true ; shift ;; + --) shift ; break ;; + *) ERROR=true ; 
break ;; + esac +done + +if [ "$ERROR" = true ] ; then + echo "Error: Invalid argument" + HELP=true +fi + +if [ "$HELP" = true ] ; then + echo "Usage: setup.sh [OPTIONS]" + echo "Options:" + echo " -h, --help Display this help message" + echo " --new-env Create a new conda environment" + echo " --basic Install basic dependencies" + echo " --xformers Install xformers" + echo " --flash-attn Install flash-attn" + echo " --diffoctreerast Install diffoctreerast" + echo " --vox2seq Install vox2seq" + echo " --spconv Install spconv" + echo " --mipgaussian Install mip-splatting" + echo " --kaolin Install kaolin" + echo " --nvdiffrast Install nvdiffrast" + echo " --demo Install all dependencies for demo" + return +fi + +if [ "$NEW_ENV" = true ] ; then + conda create -n trellis python=3.10 + conda activate trellis + conda install pytorch==2.4.0 torchvision==0.19.0 pytorch-cuda=11.8 -c pytorch -c nvidia +fi + +# Get system information +WORKDIR=$(pwd) +PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__)") +PLATFORM=$(python -c "import torch; print(('cuda' if torch.version.cuda else ('hip' if torch.version.hip else 'unknown')) if torch.cuda.is_available() else 'cpu')") +case $PLATFORM in + cuda) + CUDA_VERSION=$(python -c "import torch; print(torch.version.cuda)") + CUDA_MAJOR_VERSION=$(echo $CUDA_VERSION | cut -d'.' -f1) + CUDA_MINOR_VERSION=$(echo $CUDA_VERSION | cut -d'.' -f2) + echo "[SYSTEM] PyTorch Version: $PYTORCH_VERSION, CUDA Version: $CUDA_VERSION" + ;; + hip) + HIP_VERSION=$(python -c "import torch; print(torch.version.hip)") + HIP_MAJOR_VERSION=$(echo $HIP_VERSION | cut -d'.' -f1) + HIP_MINOR_VERSION=$(echo $HIP_VERSION | cut -d'.' -f2) + # Install pytorch 2.4.1 for hip + if [ "$PYTORCH_VERSION" != "2.4.1+rocm6.1" ] ; then + echo "[SYSTEM] Installing PyTorch 2.4.1 for HIP ($PYTORCH_VERSION -> 2.4.1+rocm6.1)" + pip install torch==2.4.1 torchvision==0.19.1 --index-url https://download.pytorch.org/whl/rocm6.1 --user + mkdir -p /tmp/extensions + sudo cp /opt/rocm/share/amd_smi /tmp/extensions/amd_smi -r + cd /tmp/extensions/amd_smi + sudo chmod -R 777 . + pip install . 
+ cd $WORKDIR + PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__)") + fi + echo "[SYSTEM] PyTorch Version: $PYTORCH_VERSION, HIP Version: $HIP_VERSION" + ;; + *) + ;; +esac + +if [ "$BASIC" = true ] ; then + pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless scipy ninja rembg onnxruntime trimesh xatlas pyvista pymeshfix igraph transformers + pip install git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8 +fi + +if [ "$XFORMERS" = true ] ; then + # install xformers + if [ "$PLATFORM" = "cuda" ] ; then + if [ "$CUDA_VERSION" = "11.8" ] ; then + case $PYTORCH_VERSION in + 2.0.1) pip install https://files.pythonhosted.org/packages/52/ca/82aeee5dcc24a3429ff5de65cc58ae9695f90f49fbba71755e7fab69a706/xformers-0.0.22-cp310-cp310-manylinux2014_x86_64.whl ;; + 2.1.0) pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.1.1) pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.1.2) pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.2.0) pip install xformers==0.0.24 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.2.1) pip install xformers==0.0.25 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.2.2) pip install xformers==0.0.25.post1 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.3.0) pip install xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.4.0) pip install xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.4.1) pip install xformers==0.0.28 --index-url https://download.pytorch.org/whl/cu118 ;; + 2.5.0) pip install xformers==0.0.28.post2 --index-url https://download.pytorch.org/whl/cu118 ;; + *) echo "[XFORMERS] Unsupported PyTorch & CUDA version: $PYTORCH_VERSION & $CUDA_VERSION" ;; + esac + elif [ "$CUDA_VERSION" = "12.1" ] ; then + case $PYTORCH_VERSION in + 2.1.0) pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.1.1) pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.1.2) pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.2.0) pip install xformers==0.0.24 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.2.1) pip install xformers==0.0.25 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.2.2) pip install xformers==0.0.25.post1 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.3.0) pip install xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.4.0) pip install xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.4.1) pip install xformers==0.0.28 --index-url https://download.pytorch.org/whl/cu121 ;; + 2.5.0) pip install xformers==0.0.28.post2 --index-url https://download.pytorch.org/whl/cu121 ;; + *) echo "[XFORMERS] Unsupported PyTorch & CUDA version: $PYTORCH_VERSION & $CUDA_VERSION" ;; + esac + elif [ "$CUDA_VERSION" = "12.4" ] ; then + case $PYTORCH_VERSION in + 2.5.0) pip install xformers==0.0.28.post2 --index-url https://download.pytorch.org/whl/cu124 ;; + *) echo "[XFORMERS] Unsupported PyTorch & CUDA version: $PYTORCH_VERSION & $CUDA_VERSION" ;; + esac + else + echo "[XFORMERS] Unsupported CUDA version: $CUDA_MAJOR_VERSION" + fi + elif [ "$PLATFORM" = "hip" ] ; then + case $PYTORCH_VERSION in + 2.4.1\+rocm6.1) pip install xformers==0.0.28 --index-url https://download.pytorch.org/whl/rocm6.1 ;; + *) 
echo "[XFORMERS] Unsupported PyTorch version: $PYTORCH_VERSION" ;; + esac + else + echo "[XFORMERS] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$FLASHATTN" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + pip install flash-attn + elif [ "$PLATFORM" = "hip" ] ; then + echo "[FLASHATTN] Prebuilt binaries not found. Building from source..." + mkdir -p /tmp/extensions + git clone --recursive https://github.com/ROCm/flash-attention.git /tmp/extensions/flash-attention + cd /tmp/extensions/flash-attention + git checkout tags/v2.6.3-cktile + GPU_ARCHS=gfx942 python setup.py install #MI300 series + cd $WORKDIR + else + echo "[FLASHATTN] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$KAOLIN" = true ] ; then + # install kaolin + if [ "$PLATFORM" = "cuda" ] ; then + case $PYTORCH_VERSION in + 2.0.1) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.0.1_cu118.html;; + 2.1.0) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.1.0_cu118.html;; + 2.1.1) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.1.1_cu118.html;; + 2.2.0) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.2.0_cu118.html;; + 2.2.1) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.2.1_cu118.html;; + 2.2.2) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.2.2_cu118.html;; + 2.4.0) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.4.0_cu121.html;; + *) echo "[KAOLIN] Unsupported PyTorch version: $PYTORCH_VERSION" ;; + esac + else + echo "[KAOLIN] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$NVDIFFRAST" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + mkdir -p /tmp/extensions + git clone https://github.com/NVlabs/nvdiffrast.git /tmp/extensions/nvdiffrast + pip install /tmp/extensions/nvdiffrast + else + echo "[NVDIFFRAST] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$DIFFOCTREERAST" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + mkdir -p /tmp/extensions + git clone --recurse-submodules https://github.com/JeffreyXiang/diffoctreerast.git /tmp/extensions/diffoctreerast + pip install /tmp/extensions/diffoctreerast + else + echo "[DIFFOCTREERAST] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$MIPGAUSSIAN" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + mkdir -p /tmp/extensions + git clone https://github.com/autonomousvision/mip-splatting.git /tmp/extensions/mip-splatting + pip install /tmp/extensions/mip-splatting/submodules/diff-gaussian-rasterization/ + else + echo "[MIPGAUSSIAN] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$VOX2SEQ" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + mkdir -p /tmp/extensions + cp -r extensions/vox2seq /tmp/extensions/vox2seq + pip install /tmp/extensions/vox2seq + else + echo "[VOX2SEQ] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$SPCONV" = true ] ; then + # install spconv + if [ "$PLATFORM" = "cuda" ] ; then + case $CUDA_MAJOR_VERSION in + 11) pip install spconv-cu118 ;; + 12) pip install spconv-cu120 ;; + *) echo "[SPCONV] Unsupported PyTorch CUDA version: $CUDA_MAJOR_VERSION" ;; + esac + else + echo "[SPCONV] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$DEMO" = true ] ; then + pip install gradio==4.44.1 gradio_litmodel3d==0.0.1 +fi diff --git a/TRELLIS/trellis/__init__.py b/TRELLIS/trellis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b02ac31563fec7c36fa4bd5d420ac4af2472bba8 --- /dev/null 
+++ b/TRELLIS/trellis/__init__.py @@ -0,0 +1,6 @@ +from . import models +from . import modules +from . import pipelines +from . import renderers +from . import representations +from . import utils diff --git a/TRELLIS/trellis/__pycache__/__init__.cpython-311.pyc b/TRELLIS/trellis/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49bce200b8831e3bbdf260b78aa862dd35204ccb Binary files /dev/null and b/TRELLIS/trellis/__pycache__/__init__.cpython-311.pyc differ diff --git a/TRELLIS/trellis/models/__init__.py b/TRELLIS/trellis/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00fd66a3272a48d43f9853682db48bcde2959d63 --- /dev/null +++ b/TRELLIS/trellis/models/__init__.py @@ -0,0 +1,70 @@ +import importlib + +__attributes = { + 'SparseStructureEncoder': 'sparse_structure_vae', + 'SparseStructureDecoder': 'sparse_structure_vae', + 'SparseStructureFlowModel': 'sparse_structure_flow', + 'SLatEncoder': 'structured_latent_vae', + 'SLatGaussianDecoder': 'structured_latent_vae', + 'SLatRadianceFieldDecoder': 'structured_latent_vae', + 'SLatMeshDecoder': 'structured_latent_vae', + 'SLatFlowModel': 'structured_latent_flow', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. 
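Editorial note: a hedged sketch of how this loader is called. The checkpoint names below are placeholders — the real names are whatever {name}.json / {name}.safetensors pairs ship with a given release, either on disk or on the Hugging Face Hub, where the first two path components are treated as the repo id:

    from trellis import models

    # Local checkpoint: expects my_model.json and my_model.safetensors side by side.
    model = models.from_pretrained("/path/to/checkpoints/my_model")

    # Hub checkpoint (placeholder subpath): repo id followed by the path to the json/safetensors pair.
    # model = models.from_pretrained("JeffreyXiang/TRELLIS-image-large/<checkpoint_name>")
    model.eval()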
+ """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file)) + + return model + + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder + from .sparse_structure_flow import SparseStructureFlowModel + from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder + from .structured_latent_flow import SLatFlowModel diff --git a/TRELLIS/trellis/models/__pycache__/__init__.cpython-311.pyc b/TRELLIS/trellis/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a30d50ed9d1a32c9356603e23dbc18089f9d27f9 Binary files /dev/null and b/TRELLIS/trellis/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/TRELLIS/trellis/models/sparse_structure_flow.py b/TRELLIS/trellis/models/sparse_structure_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..baa5dd9644e569b73717d7e7a9ebed55e9930459 --- /dev/null +++ b/TRELLIS/trellis/models/sparse_structure_flow.py @@ -0,0 +1,200 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to_f16, convert_module_to_f32 +from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ..modules.spatial import patchify, unpatchify + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + + self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = patchify(x, self.patch_size) + h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + h = h.type(self.dtype) + cond = cond.type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + h = h.type(x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) + h = unpatchify(h, self.patch_size).contiguous() + + return h diff --git a/TRELLIS/trellis/models/sparse_structure_vae.py b/TRELLIS/trellis/models/sparse_structure_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed49ae65b9cde2a45a59beb6868981a644b75d3 --- /dev/null +++ b/TRELLIS/trellis/models/sparse_structure_vae.py @@ -0,0 +1,306 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..modules.norm import GroupNorm32, ChannelLayerNorm32 +from ..modules.spatial import pixel_shuffle_3d +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. 
+ """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. 
+ """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. 
+ """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h diff --git a/TRELLIS/trellis/models/structured_latent_flow.py b/TRELLIS/trellis/models/structured_latent_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..19c11597244ea53505d746b593d10cddad4bcb6f --- /dev/null +++ b/TRELLIS/trellis/models/structured_latent_flow.py @@ -0,0 +1,262 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules.norm import LayerNorm32 +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + emb_channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = 
zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(emb_channels, 2 * self.out_channels, bias=True), + ) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = None + if self.downsample: + self.updown = sp.SparseDownsample(2) + elif self.upsample: + self.updown = sp.SparseUpsample(2) + + def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.updown is not None: + x = self.updown(x) + return x + + def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor: + emb_out = self.emb_layers(emb).type(x.dtype) + scale, shift = torch.chunk(emb_out, 2, dim=1) + + x = self._updown(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + + return h + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + num_io_res_blocks: int = 2, + io_block_channels: List[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + use_skip_connection: bool = True, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.num_io_res_blocks = num_io_res_blocks + self.io_block_channels = io_block_channels + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.use_skip_connection = use_skip_connection + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + + assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2" + assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages" + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0]) + self.input_blocks = nn.ModuleList([]) + for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]): + self.input_blocks.extend([ + SparseResBlock3d( + chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.input_blocks.append( + SparseResBlock3d( + chs, + model_channels, + out_channels=next_chs, + downsample=True, + ) + ) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + 
use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_blocks = nn.ModuleList([]) + for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))): + self.out_blocks.append( + SparseResBlock3d( + prev_chs * 2 if self.use_skip_connection else prev_chs, + model_channels, + out_channels=chs, + upsample=True, + ) + ) + self.out_blocks.extend([ + SparseResBlock3d( + chs * 2 if self.use_skip_connection else chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.blocks.apply(convert_module_to_f16) + self.out_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.blocks.apply(convert_module_to_f32) + self.out_blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor: + h = self.input_layer(x).type(self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + cond = cond.type(self.dtype) + + skips = [] + # pack with input blocks + for block in self.input_blocks: + h = block(h, t_emb) + skips.append(h.feats) + + if self.pe_mode == "ape": + h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + # unpack with output blocks + for block, skip in zip(self.out_blocks, reversed(skips)): + if self.use_skip_connection: + h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb) + else: + h = block(h, t_emb) + + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h.type(x.dtype)) + return h diff --git a/TRELLIS/trellis/models/structured_latent_vae/__init__.py b/TRELLIS/trellis/models/structured_latent_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00cbf8826b328f9abe76cd641645f67128d7a04b --- /dev/null +++ 
b/TRELLIS/trellis/models/structured_latent_vae/__init__.py @@ -0,0 +1,4 @@ +from .encoder import SLatEncoder +from .decoder_gs import SLatGaussianDecoder +from .decoder_rf import SLatRadianceFieldDecoder +from .decoder_mesh import SLatMeshDecoder diff --git a/TRELLIS/trellis/models/structured_latent_vae/base.py b/TRELLIS/trellis/models/structured_latent_vae/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7b86006fb35dee6f4f61a6f827d13787e0a287b2 --- /dev/null +++ b/TRELLIS/trellis/models/structured_latent_vae/base.py @@ -0,0 +1,117 @@ +from typing import * +import torch +import torch.nn as nn +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer import SparseTransformerBlock + + +def block_attn_config(self): + """ + Return the attention configuration of the model. + """ + for i in range(self.num_blocks): + if self.attn_mode == "shift_window": + yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_sequence": + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_order": + yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] + elif self.attn_mode == "full": + yield "full", None, None, None, None + elif self.attn_mode == "swin": + yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None + + +class SparseTransformerBase(nn.Module): + """ + Sparse Transformer without output layers. + Serve as the base class for encoder and decoder. + """ + def __init__( + self, + in_channels: int, + model_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.window_size = window_size + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.attn_mode = attn_mode + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.qk_rms_norm = qk_rms_norm + self.dtype = torch.float16 if use_fp16 else torch.float32 + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + self.blocks = nn.ModuleList([ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) + ]) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. 
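+        Only the transformer blocks are cast; the input projection (and any output layers added by subclasses) keep float32 weights.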
+ """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.input_layer(x) + if self.pe_mode == "ape": + h = h + self.pos_embedder(x.coords[:, 1:]) + h = h.type(self.dtype) + for block in self.blocks: + h = block(h) + return h diff --git a/TRELLIS/trellis/models/structured_latent_vae/decoder_gs.py b/TRELLIS/trellis/models/structured_latent_vae/decoder_gs.py new file mode 100644 index 0000000000000000000000000000000000000000..b6948173f57063ea1eff411f4840c8e1a711bd69 --- /dev/null +++ b/TRELLIS/trellis/models/structured_latent_vae/decoder_gs.py @@ -0,0 +1,122 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from ...utils.random_utils import hammersley_sequence +from .base import SparseTransformerBase +from ...representations import Gaussian + + +class SLatGaussianDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + self._build_perturbation() + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _build_perturbation(self) -> None: + perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])] + perturbation = torch.tensor(perturbation).float() * 2 - 1 + perturbation = perturbation / self.rep_config['voxel_size'] + perturbation = torch.atanh(perturbation).to(self.device) + self.register_buffer('offset_perturbation', perturbation) + + def _calc_layout(self) -> None: + self.layout = { + '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': 
self.rep_config['num_gaussians'] * 4}, + '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Gaussian( + sh_degree=0, + aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], + mininum_kernel_size = self.rep_config['3d_filter_kernel_size'], + scaling_bias = self.rep_config['scaling_bias'], + opacity_bias = self.rep_config['opacity_bias'], + scaling_activation = self.rep_config['scaling_activation'] + ) + xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + for k, v in self.layout.items(): + if k == '_xyz': + offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']) + offset = offset * self.rep_config['lr'][k] + if self.rep_config['perturb_offset']: + offset = offset + self.offset_perturbation + offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size'] + _xyz = xyz.unsqueeze(1) + offset + setattr(representation, k, _xyz.flatten(0, 1)) + else: + feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) + feats = feats * self.rep_config['lr'][k] + setattr(representation, k, feats) + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Gaussian]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/TRELLIS/trellis/models/structured_latent_vae/decoder_mesh.py b/TRELLIS/trellis/models/structured_latent_vae/decoder_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..06c0e7286af45eab1c9861b65174431fc2210bcc --- /dev/null +++ b/TRELLIS/trellis/models/structured_latent_vae/decoder_mesh.py @@ -0,0 +1,167 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import MeshExtractResult +from ...representations.mesh import SparseFeatures2Mesh + + +class SparseSubdivideBlock3d(nn.Module): + """ + A 3D subdivide block that can subdivide the sparse tensor. + + Args: + channels: channels in the inputs and outputs. + out_channels: if specified, the number of output channels. + num_groups: the number of groups for the group norm. 
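+        resolution: the input grid resolution; the block subdivides it to resolution * 2.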
+ """ + def __init__( + self, + channels: int, + resolution: int, + out_channels: Optional[int] = None, + num_groups: int = 32 + ): + super().__init__() + self.channels = channels + self.resolution = resolution + self.out_resolution = resolution * 2 + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.sub = sp.SparseSubdivide() + + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Args: + x: an [N x C x ...] Tensor of features. + Returns: + an [N x C x ...] Tensor of outputs. + """ + h = self.act_layers(x) + h = self.sub(h) + x = self.sub(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h + + +class SLatMeshDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) + self.out_channels = self.mesh_extractor.feats_channels + self.upsample = nn.ModuleList([ + SparseSubdivideBlock3d( + channels=model_channels, + resolution=resolution, + out_channels=model_channels // 4 + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + resolution=resolution * 2, + out_channels=model_channels // 8 + ) + ]) + self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + super().convert_to_fp16() + self.upsample.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. 
+ """ + super().convert_to_fp32() + self.upsample.apply(convert_module_to_f32) + + def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + mesh = self.mesh_extractor(x[i], training=self.training) + ret.append(mesh) + return ret + + def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + h = super().forward(x) + for block in self.upsample: + h = block(h) + h = h.type(x.dtype) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/TRELLIS/trellis/models/structured_latent_vae/decoder_rf.py b/TRELLIS/trellis/models/structured_latent_vae/decoder_rf.py new file mode 100644 index 0000000000000000000000000000000000000000..4e916eebafd7e97fed82cadb567244719dbbcd83 --- /dev/null +++ b/TRELLIS/trellis/models/structured_latent_vae/decoder_rf.py @@ -0,0 +1,104 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import Strivec + + +class SLatRadianceFieldDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _calc_layout(self) -> None: + self.layout = { + 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']}, + 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']}, + 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Strivec]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. 
+ + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Strivec( + sh_degree=0, + resolution=self.resolution, + aabb=[-0.5, -0.5, -0.5, 1, 1, 1], + rank=self.rep_config['rank'], + dim=self.rep_config['dim'], + device='cuda', + ) + representation.density_shift = 0.0 + representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda') + for k, v in self.layout.items(): + setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])) + representation.trivec = representation.trivec + 1 + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Strivec]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/TRELLIS/trellis/models/structured_latent_vae/encoder.py b/TRELLIS/trellis/models/structured_latent_vae/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c04928bbd0a3fe88c05c687024a92daa0a1d6d --- /dev/null +++ b/TRELLIS/trellis/models/structured_latent_vae/encoder.py @@ -0,0 +1,72 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .base import SparseTransformerBase + + +class SLatEncoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z diff --git a/TRELLIS/trellis/modules/attention/__init__.py b/TRELLIS/trellis/modules/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ffebf7dbf737606b32ef62c2e86f568189d322f0 --- /dev/null +++ b/TRELLIS/trellis/modules/attention/__init__.py @@ -0,0 +1,36 @@ +from 
typing import * + +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_sttn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_sttn_debug is not None: + DEBUG = env_sttn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +from .full_attn import * +from .modules import * diff --git a/TRELLIS/trellis/modules/attention/full_attn.py b/TRELLIS/trellis/modules/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..68303dca94cefacb43865d3b737f2723dded20dd --- /dev/null +++ b/TRELLIS/trellis/modules/attention/full_attn.py @@ -0,0 +1,140 @@ +from typing import * +import torch +import math +from . import DEBUG, BACKEND + +if BACKEND == 'xformers': + import xformers.ops as xops +elif BACKEND == 'flash_attn': + import flash_attn +elif BACKEND == 'sdpa': + from torch.nn.functional import scaled_dot_product_attention as sdpa +elif BACKEND == 'naive': + pass +else: + raise ValueError(f"Unknown attention backend: {BACKEND}") + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... 
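+
+# Dispatch note: the implementation below counts the tensors it was given -- a packed
+# qkv, q plus a packed kv, or separate q/k/v -- and routes them to the backend selected
+# by BACKEND ('xformers', 'flash_attn', 'sdpa', or 'naive').
+# Illustrative call with assumed shapes; backends differ in dtype/device requirements
+# (flash_attn, for instance, expects fp16/bf16 CUDA tensors):
+#   qkv = torch.randn(2, 1024, 3, 8, 64)        # [N, L, 3, H, C]
+#   out = scaled_dot_product_attention(qkv)     # [N, L, H, C]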
+ +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif BACKEND == 'flash_attn': + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif BACKEND == 'sdpa': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {BACKEND}") + + return out diff --git a/TRELLIS/trellis/modules/attention/modules.py b/TRELLIS/trellis/modules/attention/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a82a9d27b6c02c7c88c724ddc0456fe61f6c0fb0 --- /dev/null +++ b/TRELLIS/trellis/modules/attention/modules.py @@ -0,0 +1,146 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class RotaryPositionEmbedder(nn.Module): + def __init__(self, hidden_size: int, in_channels: int = 3): + super().__init__() + assert 
hidden_size % 2 == 0, "Hidden size must be divisible by 2" + self.hidden_size = hidden_size + self.in_channels = in_channels + self.freq_dim = hidden_size // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (sp.SparseTensor): [..., N, D] tensor of queries + k (sp.SparseTensor): [..., N, D] tensor of keys + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + if indices is None: + indices = torch.arange(q.shape[-2], device=q.device) + if len(q.shape) > 2: + indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) + + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[1] < self.hidden_size // 2: + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), + torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) + )], dim=-1) + q_embed = self._rotary_embedding(q, phases) + k_embed = self._rotary_embedding(k, phases) + return q_embed, k_embed + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = 
None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + if self.use_rope: + q, k, v = qkv.unbind(dim=2) + q, k = self.rope(q, k, indices) + qkv = torch.stack([q, k, v], dim=2) + if self.attn_mode == "full": + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=2) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/TRELLIS/trellis/modules/norm.py b/TRELLIS/trellis/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8f1750b88aae9e5ea901fc7af9f978a0db82de6d --- /dev/null +++ b/TRELLIS/trellis/modules/norm.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/TRELLIS/trellis/modules/sparse/__init__.py b/TRELLIS/trellis/modules/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df649cc4431a0f7f1a49ed15780f4217399adf66 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/__init__.py @@ -0,0 +1,102 @@ +from typing import * + +BACKEND = 'spconv' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global BACKEND + global DEBUG + global ATTN + + env_sparse_backend = os.environ.get('SPARSE_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn is None: + env_sparse_attn = os.environ.get('ATTN_BACKEND') + + if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: + BACKEND = env_sparse_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + ATTN = env_sparse_attn + + print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") + + +__from_env() + + +def set_backend(backend: Literal['spconv', 'torchsparse']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn(attn: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = attn + + +import importlib + +__attributes = { + 'SparseTensor': 'basic', + 'sparse_batch_broadcast': 'basic', + 'sparse_batch_op': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 
'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide' : 'spatial' +} + +__submodules = ['transformer'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + import transformer diff --git a/TRELLIS/trellis/modules/sparse/attention/__init__.py b/TRELLIS/trellis/modules/sparse/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..400de9a25960cf8ed32a3d7ec143af95b5f862bc --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/attention/__init__.py @@ -0,0 +1,4 @@ +from .full_attn import * +from .serialized_attn import * +from .windowed_attn import * +from .modules import * diff --git a/TRELLIS/trellis/modules/sparse/attention/full_attn.py b/TRELLIS/trellis/modules/sparse/attention/full_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..c724327678d38a5625699ace09c67107458b4d0a --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/attention/full_attn.py @@ -0,0 +1,215 @@ +from typing import * +import torch +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... 
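+
+# The remaining overloads cover separate q/k/v operands. As with the packed variants,
+# sparse operands are flattened to their feature tensors and dispatched to the varlen
+# kernels of the configured ATTN backend; dense operands are reshaped to match.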
+ +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, SparseTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else 
kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, SparseTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if DEBUG: + if s is not None: + for i in range(s.shape[0]): + assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + if num_all_args in [2, 3]: + assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" + if num_all_args == 3: + assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" + + if ATTN == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif ATTN == 'flash_attn': + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + else: + raise ValueError(f"Unknown attention module: {ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/TRELLIS/trellis/modules/sparse/attention/modules.py b/TRELLIS/trellis/modules/sparse/attention/modules.py new file mode 100644 index 
0000000000000000000000000000000000000000..d8fbb572786483a840dc325097eacc08b815a0a5 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/attention/modules.py @@ -0,0 +1,139 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. import SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from ...attention import RotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, SparseTensor): + x = x.replace(F.normalize(x.feats, dim=-1)) + else: + x = F.normalize(x, dim=-1) + return (x * self.gamma * self.scale).to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "serialized", "windowed"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + self.channels = channels + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_sequence = shift_sequence + self.shift_window = shift_window + self.serialize_mode = serialize_mode + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + @staticmethod + def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + 
x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats + + def _rope(self, qkv: SparseTensor) -> SparseTensor: + q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] + q, k = self.rope(q, k, qkv.coords[:, 1:]) + qkv = qkv.replace(torch.stack([q, k, v], dim=1)) + return qkv + + def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.use_rope: + qkv = self._rope(qkv) + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=1) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "serialized": + h = sparse_serialized_scaled_dot_product_self_attention( + qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window + ) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=1) + k = self.k_rms_norm(k) + kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/TRELLIS/trellis/modules/sparse/attention/serialized_attn.py b/TRELLIS/trellis/modules/sparse/attention/serialized_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3da276c4b47db2e2816d61bfa66db413aa6b7aa --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/attention/serialized_attn.py @@ -0,0 +1,193 @@ +from typing import * +from enum import Enum +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_serialized_scaled_dot_product_self_attention', +] + + +class SerializeMode(Enum): + Z_ORDER = 0 + Z_ORDER_TRANSPOSED = 1 + HILBERT = 2 + HILBERT_TRANSPOSED = 3 + + +SerializeModes = [ + SerializeMode.Z_ORDER, + SerializeMode.Z_ORDER_TRANSPOSED, + SerializeMode.HILBERT, + SerializeMode.HILBERT_TRANSPOSED +] + + +def calc_serialization( + tensor: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (torch.Tensor, torch.Tensor): Forwards and backwards indices. 
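+        (List[int], List[int]): Per-window sequence lengths and the batch index of each window
+            (all four values are returned and unpacked by the caller).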
+ """ + fwd_indices = [] + bwd_indices = [] + seq_lens = [] + seq_batch_indices = [] + offsets = [0] + + if 'vox2seq' not in globals(): + import vox2seq + + # Serialize the input + serialize_coords = tensor.coords[:, 1:].clone() + serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) + if serialize_mode == SerializeMode.Z_ORDER: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) + elif serialize_mode == SerializeMode.HILBERT: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) + else: + raise ValueError(f"Unknown serialize mode: {serialize_mode}") + + for bi, s in enumerate(tensor.layout): + num_points = s.stop - s.start + num_windows = (num_points + window_size - 1) // window_size + valid_window_size = num_points / num_windows + to_ordered = torch.argsort(code[s.start:s.stop]) + if num_windows == 1: + fwd_indices.append(to_ordered) + bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) + fwd_indices[-1] += s.start + bwd_indices[-1] += offsets[-1] + seq_lens.append(num_points) + seq_batch_indices.append(bi) + offsets.append(offsets[-1] + seq_lens[-1]) + else: + # Partition the input + offset = 0 + mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] + split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] + bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) + for i in range(num_windows): + mid = mids[i] + valid_start = split[i] + valid_end = split[i + 1] + padded_start = math.floor(mid - 0.5 * window_size) + padded_end = padded_start + window_size + fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) + offset += valid_start - padded_start + bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) + offset += padded_end - valid_start + fwd_indices[-1] += s.start + seq_lens.extend([window_size] * num_windows) + seq_batch_indices.extend([bi] * num_windows) + bwd_indices.append(bwd_index + offsets[-1]) + offsets.append(offsets[-1] + num_windows * window_size) + + fwd_indices = torch.cat(fwd_indices) + bwd_indices = torch.cat(bwd_indices) + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_serialized_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply serialized scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. 
+ """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/TRELLIS/trellis/modules/sparse/attention/windowed_attn.py b/TRELLIS/trellis/modules/sparse/attention/windowed_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..11eebf851316d8b4c1f6ca39b881c422d8f2f088 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,135 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0 +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. 
+ shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (List[int]): Sequence lengths. + (List[int]): Sequence batch indices. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] + mask = seq_lens != 0 + seq_lens = seq_lens[mask].tolist() + seq_batch_indices = seq_batch_indices[mask].tolist() + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. 
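The window partition itself boils down to integer division of the (shifted) coordinates by the window size, a linearized window id, an argsort, and a bincount for the per-window lengths. A condensed standalone sketch of that logic:

import torch

coords = torch.randint(0, 16, (10, 4))    # [N, 4] = (batch, x, y, z)
coords[:, 0] = 0                          # single batch element for brevity
window_size = (8, 8, 8)

win = coords.clone()
win[:, 1:] = coords[:, 1:] // torch.tensor(window_size)
nx, ny, nz = (int(w.max()) + 1 for w in win[:, 1:].unbind(dim=-1))
# linearize (batch, wx, wy, wz) into one id, batch index most significant
offset = torch.tensor([nx * ny * nz, ny * nz, nz, 1])
win_id = (win * offset).sum(dim=-1)

fwd = torch.argsort(win_id)               # tokens grouped window by window
counts = torch.bincount(win_id)
seq_lens = counts[counts != 0].tolist()   # non-empty windows only
print(fwd.tolist(), seq_lens)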
+ """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/TRELLIS/trellis/modules/sparse/basic.py b/TRELLIS/trellis/modules/sparse/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..0fc685128f47906e5608d38d789366bb06837513 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/basic.py @@ -0,0 +1,459 @@ +from typing import * +import torch +import torch.nn as nn +from . import BACKEND, DEBUG +SparseTensorData = None # Lazy import + + +__all__ = [ + 'SparseTensor', + 'sparse_batch_broadcast', + 'sparse_batch_op', + 'sparse_cat', + 'sparse_unbind', +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. 
+ - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + global SparseTensorData + if SparseTensorData is None: + import importlib + if BACKEND == 'torchsparse': + SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif BACKEND == 'spconv': + SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape, layout = args + (None,) * (4 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + if shape is None: + shape = self.__cal_shape(feats, coords) + if layout is None: + layout = self.__cal_layout(coords, shape[0]) + if BACKEND == 'torchsparse': + self.data = SparseTensorData(feats, coords, **kwargs) + elif BACKEND == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1)[1:] + self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) + self.data._features = feats + elif method_id == 1: + data, shape, layout = args + (None,) * (3 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + self.data = data + if shape is None: + shape = self.__cal_shape(self.feats, self.coords) + if layout is None: + layout = self.__cal_layout(self.coords, shape[0]) + + self._shape = shape + self._layout = layout + self._scale = kwargs.get('scale', (1, 1, 1)) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - 
seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + @property + def shape(self) -> torch.Size: + return self._shape + + def dim(self) -> int: + return len(self.shape) + + @property + def layout(self) -> List[slice]: + return self._layout + + @property + def feats(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.F + elif BACKEND == 'spconv': + return self.data.features + + @feats.setter + def feats(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.F = value + elif BACKEND == 'spconv': + self.data.features = value + + @property + def coords(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.C + elif BACKEND == 'spconv': + return self.data.indices + + @coords.setter + def coords(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.C = value + elif BACKEND == 'spconv': + self.data.indices = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @overload + def to(self, dtype: torch.dtype) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + + new_feats = self.feats.to(device=device, dtype=dtype) + new_coords = self.coords.to(device=device) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def dense(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.dense() + elif BACKEND == 'spconv': + return self.data.dense() + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + new_shape = [self.shape[0]] + new_shape.extend(feats.shape[1:]) + if BACKEND == 'torchsparse': + new_data = SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif BACKEND == 'spconv': + new_data = SparseTensorData( + 
self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) + return new_tensor + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __neg__(self) -> 'SparseTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = sparse_batch_broadcast(self, other) + except: + pass + if isinstance(other, SparseTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): 
+ idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + coords = [] + feats = [] + for new_idx, old_idx in enumerate(idx): + coords.append(self.coords[self.layout[old_idx]].clone()) + coords[-1][:, 0] = new_idx + feats.append(self.feats[self.layout[old_idx]]) + coords = torch.cat(coords, dim=0).contiguous() + feats = torch.cat(feats, dim=0).contiguous() + return SparseTensor(feats=feats, coords=coords) + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + +def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + coords, feats = input.coords, input.feats + broadcasted = torch.zeros_like(feats) + for k in range(input.shape[0]): + broadcasted[input.layout[k]] = other[k] + return broadcasted + + +def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) + + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. 
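sparse_batch_broadcast expands row b of a [B, C] tensor over every voxel of batch b (using the layout slices) before the element-wise op applied by sparse_batch_op. A dense sketch of that step:

import torch

layout = [slice(0, 2), slice(2, 5)]       # per-batch row ranges into the flat feats
feats = torch.randn(5, 4)                 # 2 + 3 voxels
other = torch.randn(2, 4)                 # one row per batch element

broadcasted = torch.zeros_like(feats)
for b, sl in enumerate(layout):
    broadcasted[sl] = other[b]            # every voxel of batch b gets row b

result = feats + broadcasted              # e.g. op=torch.add in sparse_batch_op
print(result.shape)                       # torch.Size([5, 4])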
+ """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/TRELLIS/trellis/modules/sparse/conv/__init__.py b/TRELLIS/trellis/modules/sparse/conv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdaddf7e07d42e296d056df28e56a544d2db5f2 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/conv/__init__.py @@ -0,0 +1,21 @@ +from .. import BACKEND + + +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' + +def __from_env(): + import os + + global SPCONV_ALGO + env_spconv_algo = os.environ.get('SPCONV_ALGO') + if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: + SPCONV_ALGO = env_spconv_algo + print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") + + +__from_env() + +if BACKEND == 'torchsparse': + from .conv_torchsparse import * +elif BACKEND == 'spconv': + from .conv_spconv import * diff --git a/TRELLIS/trellis/modules/sparse/conv/conv_spconv.py b/TRELLIS/trellis/modules/sparse/conv/conv_spconv.py new file mode 100644 index 0000000000000000000000000000000000000000..856405dea4b24e5800cc056106bb34bb40f6eef0 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from .. import DEBUG +from . import SPCONV_ALGO + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + algo = None + if SPCONV_ALGO == 'native': + algo = spconv.ConvAlgo.Native + elif SPCONV_ALGO == 'implicit_gemm': + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features[fwd] + sorted_coords = new_data.indices[fwd] + unsorted_data = new_data + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +class SparseInverseConv3d(nn.Module): + def 
__init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + if DEBUG: + assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/TRELLIS/trellis/modules/sparse/conv/conv_torchsparse.py b/TRELLIS/trellis/modules/sparse/conv/conv_torchsparse.py new file mode 100644 index 0000000000000000000000000000000000000000..a10bd9105581a96117abcbb7349ea5975e4304ba --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +from .. import SparseTensor + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + + diff --git a/TRELLIS/trellis/modules/sparse/linear.py b/TRELLIS/trellis/modules/sparse/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..0c25a32a4ab08665a0d3f6bc44e61ef1c1cb2861 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . 
import SparseTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) diff --git a/TRELLIS/trellis/modules/sparse/nonlinearity.py b/TRELLIS/trellis/modules/sparse/nonlinearity.py new file mode 100644 index 0000000000000000000000000000000000000000..db81ee886f5d047a6bc98bb8f4d4ce867d2a302d --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/nonlinearity.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' +] + + +class SparseReLU(nn.ReLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(self.activation(input.feats)) + diff --git a/TRELLIS/trellis/modules/sparse/norm.py b/TRELLIS/trellis/modules/sparse/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..7c132d8389277bc9a60937633b58256784597c4d --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/norm.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from . import SparseTensor +from . import DEBUG + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + if DEBUG: + assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. 
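SparseGroupNorm cannot normalize the flat [T, C] feature tensor in one shot, because group statistics must not mix samples; it therefore loops over the layout slices and treats each batch's voxels as the spatial dimension. A rough dense equivalent of that loop:

import torch
import torch.nn as nn

gn = nn.GroupNorm(num_groups=2, num_channels=8)
feats = torch.randn(5, 8)                 # [T, C] flattened sparse features
layout = [slice(0, 2), slice(2, 5)]       # rows belonging to each batch element

out = torch.zeros_like(feats)
for sl in layout:
    b = feats[sl]                         # [n_b, C]
    b = b.permute(1, 0).reshape(1, 8, -1) # [1, C, n_b]: voxels act as the spatial dim
    out[sl] = gn(b).reshape(8, -1).permute(1, 0)
print(out.shape)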
+ """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) diff --git a/TRELLIS/trellis/modules/sparse/spatial.py b/TRELLIS/trellis/modules/sparse/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4713e82c42c6a1c7e72a00e3ab17e50f6f32a2 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/spatial.py @@ -0,0 +1,110 @@ +from typing import * +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', + 'SparseSubdivide' +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. + """ + def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): + super(SparseDownsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' + + coord = list(input.coords.unbind(dim=-1)) + for i, f in enumerate(factor): + coord[i+1] = coord[i+1] // f + + MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_feats = torch.scatter_reduce( + torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), + dim=0, + index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), + src=input.feats, + reduce='mean' + ) + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + out = SparseTensor(new_feats, new_coords, input.shape,) + out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + + out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) + out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) + out.register_spatial_cache(f'upsample_{factor}_idx', idx) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): + super(SparseUpsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' + + new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') + new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') + idx = input.get_spatial_cache(f'upsample_{factor}_idx') + if any([x is None for x in [new_coords, new_layout, idx]]): + raise ValueError('Upsample cache not found. 
SparseUpsample must be paired with SparseDownsample.') + new_feats = input.feats[idx] + out = SparseTensor(new_feats, new_coords, input.shape, new_layout) + out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + return out + +class SparseSubdivide(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self): + super(SparseSubdivide, self).__init__() + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + # upsample scale=2^DIM + n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) + n_coords = torch.nonzero(n_cube) + n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) + factor = n_coords.shape[0] + assert factor == 2 ** DIM + # print(n_coords.shape) + new_coords = input.coords.clone() + new_coords[:, 1:] *= 2 + new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) + + new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) + out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) + out._scale = input._scale * 2 + out._spatial_cache = input._spatial_cache + return out + diff --git a/TRELLIS/trellis/modules/sparse/transformer/__init__.py b/TRELLIS/trellis/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67336cacd084ef5e779bf5a601d66720ea275fe6 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/TRELLIS/trellis/modules/sparse/transformer/blocks.py b/TRELLIS/trellis/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..03b7283e3d2e6c8cdfd7fba82548a73ec7dd3130 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/transformer/blocks.py @@ -0,0 +1,151 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). 
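SparseDownsample pools voxels that share a coarse cell and caches the fine-to-coarse index so SparseUpsample can later gather the pooled features back to every fine voxel. A compact sketch of that round trip (using include_self=False for a strict mean, a simplification rather than the exact module code):

import torch

coords = torch.tensor([[0, 0], [0, 1], [0, 3], [1, 2]])   # (batch, x), 1-D "voxels"
feats = torch.randn(4, 6)
factor = 2

coarse = coords.clone()
coarse[:, 1] = coords[:, 1] // factor
code = coarse[:, 0] * 1000 + coarse[:, 1]                  # linearized coarse-cell id
uniq, idx = code.unique(return_inverse=True)               # idx: fine voxel -> coarse cell

pooled = torch.zeros(uniq.shape[0], feats.shape[1]).scatter_reduce(
    0, idx.unsqueeze(1).expand(-1, feats.shape[1]), feats,
    reduce='mean', include_self=False)                     # average pool per coarse cell
upsampled = pooled[idx]                                    # nearest-neighbour "unpool"
print(pooled.shape, upsampled.shape)                       # (3, 6) (4, 6)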
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, 
use_reentrant=False) + else: + return self._forward(x, context) diff --git a/TRELLIS/trellis/modules/sparse/transformer/modulated.py b/TRELLIS/trellis/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec00bbb567dd432130490d799ac6cf480107593 --- /dev/null +++ b/TRELLIS/trellis/modules/sparse/transformer/modulated.py @@ -0,0 +1,166 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
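The Modulated* blocks use DiT-style adaptive layer norm: the conditioning vector is projected to six chunks (shift/scale/gate for the attention branch and for the MLP branch) applied around affine-free LayerNorms. A minimal dense sketch of one gated, modulated residual branch, with a plain Linear standing in for the attention/MLP sub-module:

import torch
import torch.nn as nn

B, N, C = 2, 5, 16
x = torch.randn(B, N, C)                  # token features
cond = torch.randn(B, C)                  # conditioning vector (e.g. a timestep embedding)

adaLN = nn.Sequential(nn.SiLU(), nn.Linear(C, 6 * C))
norm = nn.LayerNorm(C, elementwise_affine=False, eps=1e-6)
branch = nn.Linear(C, C)                  # stand-in for the attention or MLP sub-module

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(cond).chunk(6, dim=1)

h = norm(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)   # modulate
h = branch(h)
x = x + gate_msa.unsqueeze(1) * h                               # gated residual
print(x.shape)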
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/TRELLIS/trellis/modules/spatial.py b/TRELLIS/trellis/modules/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3b750c1da9462818ad5e25cc50e59a7d92f786 --- /dev/null +++ b/TRELLIS/trellis/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. 
+ + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/TRELLIS/trellis/modules/transformer/__init__.py b/TRELLIS/trellis/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67336cacd084ef5e779bf5a601d66720ea275fe6 --- /dev/null +++ b/TRELLIS/trellis/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/TRELLIS/trellis/modules/transformer/blocks.py b/TRELLIS/trellis/modules/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..ab65a2a20172b00e1f35e8e2db45701f393ff82d --- /dev/null +++ b/TRELLIS/trellis/modules/transformer/blocks.py @@ -0,0 +1,182 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. 
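patchify folds non-overlapping spatial patches into the channel dimension and unpatchify inverts it exactly, so the pair round-trips losslessly. A usage sketch under the same import assumption as above:

import torch
from trellis.modules.spatial import patchify, unpatchify   # module added by this diff

x = torch.randn(2, 3, 8, 8, 8)             # (N, C, D, H, W), spatial dims divisible by patch_size
p = patchify(x, patch_size=2)
print(p.shape)                             # torch.Size([2, 24, 4, 4, 4]): 3 * 2**3 channels
assert torch.equal(unpatchify(p, patch_size=2), x)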
+ """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) + \ No newline at end of file diff --git a/TRELLIS/trellis/modules/transformer/modulated.py b/TRELLIS/trellis/modules/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b90190b6dabc04b38499a6033483334fcfa69a --- /dev/null +++ b/TRELLIS/trellis/modules/transformer/modulated.py @@ -0,0 +1,157 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. 
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) + \ No newline at end of file diff --git a/TRELLIS/trellis/modules/utils.py b/TRELLIS/trellis/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..215fe98059ab49eca42703a4f1f92d80c5343f6b --- /dev/null +++ b/TRELLIS/trellis/modules/utils.py @@ -0,0 +1,54 @@ +import torch.nn as nn +from ..modules import sparse as sp + +FP16_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/TRELLIS/trellis/pipelines/__init__.py b/TRELLIS/trellis/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c01e15250c6b976fb61567e887c6f33bb849fe50 --- /dev/null +++ b/TRELLIS/trellis/pipelines/__init__.py @@ -0,0 +1,24 @@ +from . import samplers +from .trellis_image_to_3d import TrellisImageTo3DPipeline + + +def from_pretrained(path: str): + """ + Load a pipeline from a model folder or a Hugging Face model hub. + + Args: + path: The path to the model. Can be either local path or a Hugging Face model name. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + config = json.load(f) + return globals()[config['name']].from_pretrained(path) diff --git a/TRELLIS/trellis/pipelines/__pycache__/__init__.cpython-311.pyc b/TRELLIS/trellis/pipelines/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a9ffdf39d69c0839a17f47b55d81b883023697c Binary files /dev/null and b/TRELLIS/trellis/pipelines/__pycache__/__init__.cpython-311.pyc differ diff --git a/TRELLIS/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-311.pyc b/TRELLIS/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..673a0f974e67c706f5266c1e0d3d665a01d6cabd Binary files /dev/null and b/TRELLIS/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-311.pyc differ diff --git a/TRELLIS/trellis/pipelines/base.py b/TRELLIS/trellis/pipelines/base.py new file mode 100644 index 0000000000000000000000000000000000000000..041ddce12a483557a25355bd6f2ccfe5bbfb5a17 --- /dev/null +++ b/TRELLIS/trellis/pipelines/base.py @@ -0,0 +1,66 @@ +from typing import * +import torch +import torch.nn as nn +from .. import models + + +class Pipeline: + """ + A base class for pipelines. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + ): + if models is None: + return + self.models = models + for model in self.models.values(): + model.eval() + + @staticmethod + def from_pretrained(path: str) -> "Pipeline": + """ + Load a pretrained model. 
+ """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + args = json.load(f)['args'] + + _models = { + k: models.from_pretrained(f"{path}/{v}") + for k, v in args['models'].items() + } + + new_pipeline = Pipeline(_models) + new_pipeline._pretrained_args = args + return new_pipeline + + @property + def device(self) -> torch.device: + for model in self.models.values(): + if hasattr(model, 'device'): + return model.device + for model in self.models.values(): + if hasattr(model, 'parameters'): + return next(model.parameters()).device + raise RuntimeError("No device found.") + + def to(self, device: torch.device) -> None: + for model in self.models.values(): + model.to(device) + + def cuda(self) -> None: + self.to(torch.device("cuda")) + + def cpu(self) -> None: + self.to(torch.device("cpu")) diff --git a/TRELLIS/trellis/pipelines/samplers/__init__.py b/TRELLIS/trellis/pipelines/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1111715e7d39c8b5db1b70bbc1360d9bb6e0c6 --- /dev/null +++ b/TRELLIS/trellis/pipelines/samplers/__init__.py @@ -0,0 +1,2 @@ +from .base import Sampler +from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler \ No newline at end of file diff --git a/TRELLIS/trellis/pipelines/samplers/__pycache__/__init__.cpython-311.pyc b/TRELLIS/trellis/pipelines/samplers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b500e95c277c6650268e1fb7fd8e56d7959d4daa Binary files /dev/null and b/TRELLIS/trellis/pipelines/samplers/__pycache__/__init__.cpython-311.pyc differ diff --git a/TRELLIS/trellis/pipelines/samplers/__pycache__/base.cpython-311.pyc b/TRELLIS/trellis/pipelines/samplers/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..581149868d4817a37f66d1c18cfd47caff42e7ef Binary files /dev/null and b/TRELLIS/trellis/pipelines/samplers/__pycache__/base.cpython-311.pyc differ diff --git a/TRELLIS/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-311.pyc b/TRELLIS/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2930369d9e505c8ec7976cca0abb9e143c4d83d Binary files /dev/null and b/TRELLIS/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-311.pyc differ diff --git a/TRELLIS/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-311.pyc b/TRELLIS/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b7d96c8443b61ede173c787fdd9fbd690b22579 Binary files /dev/null and b/TRELLIS/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-311.pyc differ diff --git a/TRELLIS/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-311.pyc b/TRELLIS/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc8e137f0e7441434d0734b86ef8714c8b5025d8 Binary files /dev/null and b/TRELLIS/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-311.pyc differ diff --git a/TRELLIS/trellis/pipelines/samplers/base.py 
b/TRELLIS/trellis/pipelines/samplers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..bb70700117317477e738845e566b9ea87a768d0a --- /dev/null +++ b/TRELLIS/trellis/pipelines/samplers/base.py @@ -0,0 +1,20 @@ +from typing import * +from abc import ABC, abstractmethod + + +class Sampler(ABC): + """ + A base class for samplers. + """ + + @abstractmethod + def sample( + self, + model, + **kwargs + ): + """ + Sample from a model. + """ + pass + \ No newline at end of file diff --git a/TRELLIS/trellis/pipelines/samplers/classifier_free_guidance_mixin.py b/TRELLIS/trellis/pipelines/samplers/classifier_free_guidance_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..076e1e3d9f6d1e7207c9659db530990894d614f9 --- /dev/null +++ b/TRELLIS/trellis/pipelines/samplers/classifier_free_guidance_mixin.py @@ -0,0 +1,12 @@ +from typing import * + + +class ClassifierFreeGuidanceSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs): + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred diff --git a/TRELLIS/trellis/pipelines/samplers/flow_euler.py b/TRELLIS/trellis/pipelines/samplers/flow_euler.py new file mode 100644 index 0000000000000000000000000000000000000000..a84c9d472b0c74b807663083b8aea63ff1eb3c7e --- /dev/null +++ b/TRELLIS/trellis/pipelines/samplers/flow_euler.py @@ -0,0 +1,199 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin + + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. + """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + return model(x_t, t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + return pred_x_0, pred_eps, pred_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. 
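+
+ Note: the Euler update implemented below is x_{t_prev} = x_t - (t - t_prev) * v, where v is the
+ velocity predicted by the model at (x_t, t).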
+ + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. 
+ cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) diff --git a/TRELLIS/trellis/pipelines/samplers/guidance_interval_mixin.py b/TRELLIS/trellis/pipelines/samplers/guidance_interval_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..10524ce011de121db28b17fcbc2589e60019042e --- /dev/null +++ b/TRELLIS/trellis/pipelines/samplers/guidance_interval_mixin.py @@ -0,0 +1,15 @@ +from typing import * + + +class GuidanceIntervalSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance with interval. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + if cfg_interval[0] <= t <= cfg_interval[1]: + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred + else: + return super()._inference_model(model, x_t, t, cond, **kwargs) diff --git a/TRELLIS/trellis/pipelines/trellis_image_to_3d.py b/TRELLIS/trellis/pipelines/trellis_image_to_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..faeb32297f62e82cdbed6692a2d2b30da1c9c56d --- /dev/null +++ b/TRELLIS/trellis/pipelines/trellis_image_to_3d.py @@ -0,0 +1,376 @@ +from typing import * +from contextlib import contextmanager +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from torchvision import transforms +from PIL import Image +import rembg +from .base import Pipeline +from . import samplers +from ..modules import sparse as sp +from ..representations import Gaussian, Strivec, MeshExtractResult + + +class TrellisImageTo3DPipeline(Pipeline): + """ + Pipeline for inferring Trellis image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + slat_sampler (samplers.Sampler): The sampler for the structured latent. + slat_normalization (dict): The normalization parameters for the structured latent. + image_cond_model (str): The name of the image conditioning model. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + slat_sampler: samplers.Sampler = None, + slat_normalization: dict = None, + image_cond_model: str = None, + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.slat_sampler = slat_sampler + self.sparse_structure_sampler_params = {} + self.slat_sampler_params = {} + self.slat_normalization = slat_normalization + self.rembg_session = None + self._init_image_cond_model(image_cond_model) + + @staticmethod + def from_pretrained(path: str) -> "TrellisImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. 
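+
+ Example (illustrative; the argument is a placeholder for a local folder or a Hugging Face repo id):
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("path/or/repo-id")
+ pipeline.cuda()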
+ """ + pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path) + new_pipeline = TrellisImageTo3DPipeline() + new_pipeline.__dict__ = pipeline.__dict__ + args = pipeline._pretrained_args + + new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args']) + new_pipeline.slat_sampler_params = args['slat_sampler']['params'] + + new_pipeline.slat_normalization = args['slat_normalization'] + + new_pipeline._init_image_cond_model(args['image_cond_model']) + + return new_pipeline + + def _init_image_cond_model(self, name: str): + """ + Initialize the image conditioning model. + """ + dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True) + dinov2_model.eval() + self.models['image_cond_model'] = dinov2_model + transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + self.image_cond_model_transform = transform + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. + """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + if has_alpha: + output = input + else: + input = input.convert('RGB') + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if getattr(self, 'rembg_session', None) is None: + self.rembg_session = rembg.new_session('u2net') + output = rembg.remove(input, session=self.rembg_session) + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1.2) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = output.resize((518, 518), Image.Resampling.LANCZOS) + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor: + """ + Encode the image. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image to encode + + Returns: + torch.Tensor: The encoded features. 
+ """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).to(self.device) + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.image_cond_model_transform(image).to(self.device) + features = self.models['image_cond_model'](image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + cond = self.encode_image(image) + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def sample_sparse_structure( + self, + cond: dict, + num_samples: int = 1, + sampler_params: dict = {}, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample occupancy latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + # Decode occupancy latent + decoder = self.models['sparse_structure_decoder'] + coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int() + + return coords + + def decode_slat( + self, + slat: sp.SparseTensor, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + ) -> dict: + """ + Decode the structured latent. + + Args: + slat (sp.SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + dict: The decoded structured latent. + """ + ret = {} + if 'mesh' in formats: + ret['mesh'] = self.models['slat_decoder_mesh'](slat) + if 'gaussian' in formats: + ret['gaussian'] = self.models['slat_decoder_gs'](slat) + if 'radiance_field' in formats: + ret['radiance_field'] = self.models['slat_decoder_rf'](slat) + return ret + + def sample_slat( + self, + cond: dict, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> sp.SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. 
+ """ + # Sample structured latent + flow_model = self.models['slat_flow_model'] + noise = sp.SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.slat_sampler_params, **sampler_params} + slat = self.slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + @torch.no_grad() + def run( + self, + image: Image.Image, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + preprocess_image: bool = True, + ) -> dict: + """ + Run the pipeline. + + Args: + image (Image.Image): The image prompt. + num_samples (int): The number of samples to generate. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + slat_sampler_params (dict): Additional parameters for the structured latent sampler. + preprocess_image (bool): Whether to preprocess the image. + """ + if preprocess_image: + image = self.preprocess_image(image) + cond = self.get_cond([image]) + torch.manual_seed(seed) + coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) + slat = self.sample_slat(cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats) + + @contextmanager + def inject_sampler_multi_image( + self, + sampler_name: str, + num_images: int, + num_steps: int, + mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + ): + """ + Inject a sampler with multiple images as condition. + + Args: + sampler_name (str): The name of the sampler to inject. + num_images (int): The number of images to condition on. + num_steps (int): The number of steps to run the sampler for. + """ + sampler = getattr(self, sampler_name) + setattr(sampler, f'_old_inference_model', sampler._inference_model) + + if mode == 'stochastic': + if num_images > num_steps: + print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. 
" + "This may lead to performance degradation.\033[0m") + + cond_indices = (np.arange(num_steps) % num_images).tolist() + def _new_inference_model(self, model, x_t, t, cond, **kwargs): + cond_idx = cond_indices.pop(0) + cond_i = cond[cond_idx:cond_idx+1] + return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs) + + elif mode =='multidiffusion': + from .samplers import FlowEulerSampler + def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + if cfg_interval[0] <= t <= cfg_interval[1]: + preds = [] + for i in range(len(cond)): + preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) + pred = sum(preds) / len(preds) + neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred + else: + preds = [] + for i in range(len(cond)): + preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) + pred = sum(preds) / len(preds) + return pred + + else: + raise ValueError(f"Unsupported mode: {mode}") + + sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler)) + + yield + + sampler._inference_model = sampler._old_inference_model + delattr(sampler, f'_old_inference_model') + + @torch.no_grad() + def run_multi_image( + self, + images: List[Image.Image], + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + preprocess_image: bool = True, + mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + ) -> dict: + """ + Run the pipeline with multiple images as condition + + Args: + images (List[Image.Image]): The multi-view images of the assets + num_samples (int): The number of samples to generate. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + slat_sampler_params (dict): Additional parameters for the structured latent sampler. + preprocess_image (bool): Whether to preprocess the image. 
+ """ + if preprocess_image: + images = [self.preprocess_image(image) for image in images] + cond = self.get_cond(images) + cond['neg_cond'] = cond['neg_cond'][:1] + torch.manual_seed(seed) + ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps') + with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode): + coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) + slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') + with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode): + slat = self.sample_slat(cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats) diff --git a/TRELLIS/trellis/renderers/__init__.py b/TRELLIS/trellis/renderers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec397d1e795baedf05d52ef49f5161885151cafc --- /dev/null +++ b/TRELLIS/trellis/renderers/__init__.py @@ -0,0 +1,31 @@ +import importlib + +__attributes = { + 'OctreeRenderer': 'octree_renderer', + 'GaussianRenderer': 'gaussian_render', + 'MeshRenderer': 'mesh_renderer', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .octree_renderer import OctreeRenderer + from .gaussian_render import GaussianRenderer + from .mesh_renderer import MeshRenderer \ No newline at end of file diff --git a/TRELLIS/trellis/renderers/gaussian_render.py b/TRELLIS/trellis/renderers/gaussian_render.py new file mode 100644 index 0000000000000000000000000000000000000000..272cf07ceaf2cb14bc7b9b82772721d97fd954c8 --- /dev/null +++ b/TRELLIS/trellis/renderers/gaussian_render.py @@ -0,0 +1,231 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from easydict import EasyDict as edict +import numpy as np +from ..representations.gaussian import Gaussian +from .sh_utils import eval_sh +import torch.nn.functional as F +from easydict import EasyDict as edict + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. 
+ return ret + + +def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if 'GaussianRasterizer' not in globals(): + from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + kernel_size = pipe.kernel_size + subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda") + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + kernel_size=kernel_size, + subpixel_offset=subpixel_offset, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = pc.get_xyz + means2D = screenspace_points + opacity = pc.get_opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) + dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) + dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + shs = pc.get_features + else: + colors_precomp = override_color + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp + ) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return edict({"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "radii": radii}) + + +class GaussianRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. 
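+
+ Example (illustrative; near/far values are arbitrary and the camera tensors are assumed to be given):
+ renderer = GaussianRenderer({"resolution": 512, "near": 0.8, "far": 1.6, "bg_color": (1, 1, 1)})
+ out = renderer.render(gaussian, extrinsics, intrinsics)  # out.color: (3, 512, 512)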
+ """ + + def __init__(self, rendering_options={}) -> None: + self.pipe = edict({ + "kernel_size": 0.1, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "scale_modifier": 1.0, + "debug": False + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + gausssian: Gaussian, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None + ) -> edict: + """ + Render the gausssian. + + Args: + gaussian : gaussianmodule + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color image + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, gausssian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier) + + if ssaa > 1: + render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'color': render_ret['render'] + }) + return ret diff --git a/TRELLIS/trellis/renderers/mesh_renderer.py b/TRELLIS/trellis/renderers/mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..45dd8a884a5a441d1a7bcabe87235046960ae961 --- /dev/null +++ b/TRELLIS/trellis/renderers/mesh_renderer.py @@ -0,0 +1,133 @@ +import torch +import nvdiffrast.torch as dr +from easydict import EasyDict as edict +from ..representations.mesh import MeshExtractResult +import torch.nn.functional as F + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 
2] = 1. + return ret + + +class MeshRenderer: + """ + Renderer for the Mesh representation. + + Args: + rendering_options (dict): Rendering options. + glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop. + """ + def __init__(self, rendering_options={}, device='cuda'): + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1 + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : MeshExtractResult, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types = ["mask", "normal", "depth"] + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + return_types (list): list of return types, can be "mask", "depth", "normal_map", "normal", "color" + + Returns: + edict based on return_types containing: + color (torch.Tensor): [3, H, W] rendered color image + depth (torch.Tensor): [H, W] rendered depth image + normal (torch.Tensor): [3, H, W] rendered normal image + normal_map (torch.Tensor): [3, H, W] rendered normal map image + mask (torch.Tensor): [H, W] rendered mask image + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device) + ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types} + return ret_dict + + perspective = intrinsics_to_projection(intrinsics, near, far) + + RT = extrinsics.unsqueeze(0) + full_proj = (perspective @ extrinsics).unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces_int = mesh.faces.int() + rast, _ = dr.rasterize( + self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa)) + + out_dict = edict() + for type in return_types: + img = None + if type == "mask" : + img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int) + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + elif type == "normal" : + img = dr.interpolate( + mesh.face_normal.reshape(1, -1, 3), rast, + torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3) + )[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + # normalize norm pictures + img = (img + 1) / 2 + elif type == "normal_map" : + img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + elif type == "color" : + img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + + if ssaa > 1: + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = img.squeeze() + else: + img = img.permute(0, 3, 1, 2).squeeze() + out_dict[type] = img + 
+ return out_dict diff --git a/TRELLIS/trellis/renderers/octree_renderer.py b/TRELLIS/trellis/renderers/octree_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..c72541888c60591109c8690bd269669faad667c0 --- /dev/null +++ b/TRELLIS/trellis/renderers/octree_renderer.py @@ -0,0 +1,300 @@ +import numpy as np +import torch +import torch.nn.functional as F +import math +import cv2 +from scipy.stats import qmc +from easydict import EasyDict as edict +from ..representations.octree import DfsOctree + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if 'OctreeTrivecRasterizer' not in globals(): + from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = edict( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=octree.active_sh_degree, + campos=viewpoint_camera.camera_center, + with_distloss=pipe.with_distloss, + jitter=pipe.jitter, + debug=pipe.debug, + ) + + positions = octree.get_xyz + if octree.primitive == "voxel": + densities = octree.get_density + elif octree.primitive == "gaussian": + opacities = octree.get_opacity + elif octree.primitive == "trivec": + trivecs = octree.get_trivec + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + elif octree.primitive == "decoupoly": + decoupolys_V, decoupolys_g = octree.get_decoupoly + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + else: + raise ValueError(f"Unknown primitive {octree.primitive}") + depths = octree.get_depth + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
+ colors_precomp = None + shs = octree.get_features + if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None: + colors_precomp = colors_overwrite + shs = None + + ret = edict() + + if octree.primitive == "voxel": + renderer = OctreeVoxelRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, distloss = renderer( + positions = positions, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + ret['distloss'] = distloss + elif octree.primitive == "gaussian": + renderer = OctreeGaussianRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + opacities = opacities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "trivec": + raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1] + renderer = OctreeTrivecRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, percent_depth = renderer( + positions = positions, + trivecs = trivecs, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + colors_overwrite = colors_overwrite, + depths = depths, + aabb = octree.aabb, + aux = aux, + halton_sampler = halton_sampler, + ) + ret['percent_depth'] = percent_depth + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "decoupoly": + raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1] + renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + decoupolys_V = decoupolys_V, + decoupolys_g = decoupolys_g, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + + return ret + + +class OctreeRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + try: + import diffoctreerast + except ImportError: + print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m") + self.unsupported = True + else: + self.unsupported = False + + self.pipe = edict({ + "with_distloss": False, + "with_aux": False, + "scale_modifier": 1.0, + "used_rank": None, + "jitter": False, + "debug": False, + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.halton_sampler = qmc.Halton(2, scramble=False) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + octree: DfsOctree, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: + """ + Render the octree. 
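+ If diffoctreerast is not installed, a placeholder image with the text "Unsupported" is returned instead.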
+ + Args: + octree (Octree): octree + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + distloss (Optional[torch.Tensor]): (H, W) rendered distance loss + percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth + aux (Optional[edict]): auxiliary tensors + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.unsupported: + image = np.zeros((512, 512, 3), dtype=np.uint8) + text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0] + origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2 + image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA) + return { + 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255, + } + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + if self.pipe["with_aux"]: + aux = { + 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + } + for k in aux.keys(): + aux[k].requires_grad_() + aux[k].retain_grad() + else: + aux = None + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler) + + if ssaa > 1: + render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + if hasattr(render_ret, 'percent_depth'): + render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'color': render_ret.rgb, + 'depth': render_ret.depth, + 'alpha': render_ret.alpha, + }) + if self.pipe["with_distloss"] and 'distloss' in render_ret: + 
ret['distloss'] = render_ret.distloss + if self.pipe["with_aux"]: + ret['aux'] = aux + if hasattr(render_ret, 'percent_depth'): + ret['percent_depth'] = render_ret.percent_depth + return ret diff --git a/TRELLIS/trellis/renderers/sh_utils.py b/TRELLIS/trellis/renderers/sh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a54612b24cde4e1ab6da3fb37142cc5d0248ada8 --- /dev/null +++ b/TRELLIS/trellis/renderers/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. 
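A minimal usage sketch for the OctreeRenderer defined above. The import path for the renderer and the camera values are assumptions for illustration, and diffoctreerast must be installed to get anything other than the "Unsupported" placeholder image:

import torch
from trellis.representations import Octree       # DfsOctree re-exported by the representations package
from trellis.renderers import OctreeRenderer     # import path assumed

octree = Octree(depth=6, primitive='trivec', primitive_config={'rank': 8, 'dim': 8})
renderer = OctreeRenderer({'resolution': 512, 'near': 0.8, 'far': 1.6, 'ssaa': 2, 'bg_color': (0, 0, 0)})

extrinsics = torch.eye(4, device='cuda')          # placeholder world-to-camera pose
intrinsics = torch.tensor([[1.0, 0.0, 0.5],
                           [0.0, 1.0, 0.5],
                           [0.0, 0.0, 1.0]], device='cuda')  # normalized pinhole: fov = 2*atan(0.5/focal) ~ 53 deg
out = renderer.render(octree, extrinsics, intrinsics)
color, depth, alpha = out['color'], out['depth'], out['alpha']  # (3, 512, 512), (512, 512), (512, 512)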
Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/TRELLIS/trellis/representations/__init__.py b/TRELLIS/trellis/representations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..26c62155fc77f668549092c55fca0c610aa02540 --- /dev/null +++ b/TRELLIS/trellis/representations/__init__.py @@ -0,0 +1,4 @@ +from .radiance_field import Strivec +from .octree import DfsOctree as Octree +from .gaussian import Gaussian +from .mesh import MeshExtractResult diff --git a/TRELLIS/trellis/representations/gaussian/__init__.py b/TRELLIS/trellis/representations/gaussian/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3de6e180bd732836af876d748255595be2d4d74 --- /dev/null +++ b/TRELLIS/trellis/representations/gaussian/__init__.py @@ -0,0 +1 @@ +from .gaussian_model import Gaussian \ No newline at end of file diff --git a/TRELLIS/trellis/representations/gaussian/gaussian_model.py b/TRELLIS/trellis/representations/gaussian/gaussian_model.py new file mode 100644 index 0000000000000000000000000000000000000000..373411cb16aa2baf466eaa600fb4d4248bd550c0 --- /dev/null +++ b/TRELLIS/trellis/representations/gaussian/gaussian_model.py @@ -0,0 +1,209 @@ +import torch +import numpy as np +from plyfile import PlyData, PlyElement +from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation +import utils3d + + +class Gaussian: + def __init__( + self, + aabb : list, + sh_degree : int = 0, + mininum_kernel_size : float = 0.0, + scaling_bias : float = 0.01, + opacity_bias : float = 0.1, + scaling_activation : str = "exp", + device='cuda' + ): + self.init_params = { + 'aabb': aabb, + 'sh_degree': sh_degree, + 'mininum_kernel_size': mininum_kernel_size, + 'scaling_bias': scaling_bias, + 'opacity_bias': opacity_bias, + 'scaling_activation': scaling_activation, + } + + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.mininum_kernel_size = 
mininum_kernel_size + self.scaling_bias = scaling_bias + self.opacity_bias = opacity_bias + self.scaling_activation_type = scaling_activation + self.device = device + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.setup_functions() + + self._xyz = None + self._features_dc = None + self._features_rest = None + self._scaling = None + self._rotation = None + self._opacity = None + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + if self.scaling_activation_type == "exp": + self.scaling_activation = torch.exp + self.inverse_scaling_activation = torch.log + elif self.scaling_activation_type == "softplus": + self.scaling_activation = torch.nn.functional.softplus + self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x)) + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda() + self.rots_bias = torch.zeros((4)).cuda() + self.rots_bias[0] = 1 + self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda() + + @property + def get_scaling(self): + scales = self.scaling_activation(self._scaling + self.scale_bias) + scales = torch.square(scales) + self.mininum_kernel_size ** 2 + scales = torch.sqrt(scales) + return scales + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation + self.rots_bias[None, :]) + + @property + def get_xyz(self): + return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3] + + @property + def get_features(self): + return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity + self.opacity_bias) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :]) + + def from_scaling(self, scales): + scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2) + self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias + + def from_rotation(self, rots): + self._rotation = rots - self.rots_bias[None, :] + + def from_xyz(self, xyz): + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + + def from_features(self, features): + self._features_dc = features + + def from_opacity(self, opacities): + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]): + xyz = self.get_xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 
2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy() + scale = torch.log(self.get_scaling).detach().cpu().numpy() + rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy() + + if transform is not None: + transform = np.array(transform) + xyz = np.matmul(xyz, transform.T) + rotation = utils3d.numpy.quaternion_to_matrix(rotation) + rotation = np.matmul(transform, rotation) + rotation = utils3d.numpy.matrix_to_quaternion(rotation) + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]): + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + if self.sh_degree > 0: + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + if transform is not None: + transform = np.array(transform) + xyz = np.matmul(xyz, transform) + rotation = utils3d.numpy.quaternion_to_matrix(rotation) + rotation = np.matmul(rotation, transform) + rotation = utils3d.numpy.matrix_to_quaternion(rotation) + + # convert to actual gaussian attributes + xyz = torch.tensor(xyz, dtype=torch.float, device=self.device) + features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + if self.sh_degree > 0: + features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device)) + scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device)) + rots = torch.tensor(rots, 
dtype=torch.float, device=self.device) + + # convert to _hidden attributes + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + self._features_dc = features_dc + if self.sh_degree > 0: + self._features_rest = features_extra + else: + self._features_rest = None + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias + self._rotation = rots - self.rots_bias[None, :] + \ No newline at end of file diff --git a/TRELLIS/trellis/representations/gaussian/general_utils.py b/TRELLIS/trellis/representations/gaussian/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae982066ab2d04fac15e997df9dbd37620ad08ed --- /dev/null +++ b/TRELLIS/trellis/representations/gaussian/general_utils.py @@ -0,0 +1,133 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. 
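Before the delay branch below: with no delay configured, the schedule returned by get_expon_lr_func is a pure log-linear interpolation from lr_init to lr_final, for example:

lr_fn = get_expon_lr_func(lr_init=1e-2, lr_final=1e-4, max_steps=100)
lr_fn(0)     # 1e-2
lr_fn(50)    # 1e-3  (the geometric midpoint, since the lerp happens in log space)
lr_fn(100)   # 1e-4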
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/TRELLIS/trellis/representations/mesh/__init__.py b/TRELLIS/trellis/representations/mesh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fffa2c12907ed38dce29455df884c474704bd663 --- /dev/null +++ b/TRELLIS/trellis/representations/mesh/__init__.py @@ -0,0 +1 @@ +from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult diff --git a/TRELLIS/trellis/representations/mesh/cube2mesh.py b/TRELLIS/trellis/representations/mesh/cube2mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..27193189c8044c2ff5ca3fdf198412db4061d18c --- /dev/null +++ b/TRELLIS/trellis/representations/mesh/cube2mesh.py @@ -0,0 +1,143 @@ +import torch +from ...modules.sparse import SparseTensor +from easydict import EasyDict as edict +from .utils_cube import * +from .flexicubes.flexicubes import FlexiCubes + + +class MeshExtractResult: + def __init__(self, + vertices, + faces, + vertex_attrs=None, + res=64 + ): + self.vertices = vertices + self.faces = faces.long() + self.vertex_attrs = vertex_attrs + self.face_normal = self.comput_face_normals(vertices, faces) + self.res = res + self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) + + # training only + self.tsdf_v = None + self.tsdf_s = None + self.reg_loss = None + + def comput_face_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + face_normals = torch.nn.functional.normalize(face_normals, dim=1) + # 
print(face_normals.min(), face_normals.max(), face_normals.shape) + return face_normals[:, None, :].repeat(1, 3, 1) + + def comput_v_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + v_normals = torch.zeros_like(verts) + v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) + + v_normals = torch.nn.functional.normalize(v_normals, dim=1) + return v_normals + + +class SparseFeatures2Mesh: + def __init__(self, device="cuda", res=64, use_color=True): + ''' + a model to generate a mesh from sparse features structures using flexicube + ''' + super().__init__() + self.device=device + self.res = res + self.mesh_extractor = FlexiCubes(device=device) + self.sdf_bias = -1.0 / res + verts, cube = construct_dense_grid(self.res, self.device) + self.reg_c = cube.to(self.device) + self.reg_v = verts.to(self.device) + self.use_color = use_color + self._calc_layout() + + def _calc_layout(self): + LAYOUTS = { + 'sdf': {'shape': (8, 1), 'size': 8}, + 'deform': {'shape': (8, 3), 'size': 8 * 3}, + 'weights': {'shape': (21,), 'size': 21} + } + if self.use_color: + ''' + 6 channel color including normal map + ''' + LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} + self.layouts = edict(LAYOUTS) + start = 0 + for k, v in self.layouts.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.feats_channels = start + + def get_layout(self, feats : torch.Tensor, name : str): + if name not in self.layouts: + return None + return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) + + def __call__(self, cubefeats : SparseTensor, training=False): + """ + Generates a mesh based on the specified sparse voxel structures. 
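For orientation, the channel layout produced by _calc_layout above works out as follows (a sketch assuming use_color=True and that the flexicubes dependency and a CUDA device are available):

m = SparseFeatures2Mesh(res=64, use_color=True)
# Insertion order of LAYOUTS fixes the slices of the per-voxel feature vector:
#   sdf     -> feats[:,   0:  8]   (8 corners x 1)
#   deform  -> feats[:,   8: 32]   (8 corners x 3)
#   weights -> feats[:,  32: 53]   (21 FlexiCubes weights)
#   color   -> feats[:,  53:101]   (8 corners x 6: RGB plus normal)
assert m.feats_channels == 101     # 53 when use_color=False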
+ Args: + cube_attrs [Nx21] : Sparse Tensor attrs about cube weights + verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal + Returns: + return the success tag and ni you loss, + """ + # add sdf bias to verts_attrs + coords = cubefeats.coords[:, 1:] + feats = cubefeats.feats + + sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] + sdf += self.sdf_bias + v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] + v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) + v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True) + weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) + if self.use_color: + sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:] + else: + sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] + colors_d = None + + x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) + + vertices, faces, L_dev, colors = self.mesh_extractor( + voxelgrid_vertices=x_nx3, + scalar_field=sdf_d, + cube_idx=self.reg_c, + resolution=self.res, + beta=weights_d[:, :12], + alpha=weights_d[:, 12:20], + gamma_f=weights_d[:, 20], + voxelgrid_colors=colors_d, + training=training) + + mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) + if training: + if mesh.success: + reg_loss += L_dev.mean() * 0.5 + reg_loss += (weights[:,:20]).abs().mean() * 0.2 + mesh.reg_loss = reg_loss + mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) + mesh.tsdf_s = v_attrs[:, 0] + return mesh diff --git a/TRELLIS/trellis/representations/mesh/utils_cube.py b/TRELLIS/trellis/representations/mesh/utils_cube.py new file mode 100644 index 0000000000000000000000000000000000000000..9befc1de4561d2d682873e0b752948275ddc9189 --- /dev/null +++ b/TRELLIS/trellis/representations/mesh/utils_cube.py @@ -0,0 +1,61 @@ +import torch +cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) +cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) +cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) + +def construct_dense_grid(res, device='cuda'): + '''construct a dense grid based on resolution''' + res_v = res + 1 + vertsid = torch.arange(res_v ** 3, device=device) + coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() + cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] + cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) + verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) + return verts, cube_fx8 + + +def construct_voxel_grid(coords): + verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) + verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) + cubes = inverse_indices.reshape(-1, 8) + return verts_unique, cubes + + +def cubes_to_verts(num_verts, cubes, value, reduce='mean'): + """ + Args: + cubes [Vx8] verts index for each cube + value [Vx8xM] value to be scattered + Operation: + reduced[cubes[i][j]][k] += value[i][k] + """ + M = value.shape[2] # number of channels + reduced = torch.zeros(num_verts, M, device=cubes.device) + return torch.scatter_reduce(reduced, 
0, + cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), + value.flatten(0, 1), reduce=reduce, include_self=False) + +def sparse_cube2verts(coords, feats, training=True): + new_coords, cubes = construct_voxel_grid(coords) + new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) + if training: + con_loss = torch.mean((feats - new_feats[cubes]) ** 2) + else: + con_loss = 0.0 + return new_coords, new_feats, con_loss + + +def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): + F = feats.shape[-1] + dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device) + if sdf_init: + dense_attrs[..., 0] = 1 # initial outside sdf value + dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats + return dense_attrs.reshape(-1, F) + + +def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): + return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) + \ No newline at end of file diff --git a/TRELLIS/trellis/representations/octree/__init__.py b/TRELLIS/trellis/representations/octree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f66a39a5a7498e2e99fe9d94d663796b3bc157b5 --- /dev/null +++ b/TRELLIS/trellis/representations/octree/__init__.py @@ -0,0 +1 @@ +from .octree_dfs import DfsOctree \ No newline at end of file diff --git a/TRELLIS/trellis/representations/octree/octree_dfs.py b/TRELLIS/trellis/representations/octree/octree_dfs.py new file mode 100644 index 0000000000000000000000000000000000000000..710f18b73c6a68acbc7bfb470efef7632fbbd6ed --- /dev/null +++ b/TRELLIS/trellis/representations/octree/octree_dfs.py @@ -0,0 +1,362 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +DEFAULT_TRIVEC_CONFIG = { + 'dim': 8, + 'rank': 8, +} + +DEFAULT_VOXEL_CONFIG = { + 'solid': False, +} + +DEFAULT_DECOPOLY_CONFIG = { + 'degree': 8, + 'rank': 16, +} + + +class DfsOctree: + """ + Sparse Voxel Octree (SVO) implementation for PyTorch. + Using Depth-First Search (DFS) order to store the octree. + DFS order suits rendering and ray tracing. + + The structure and data are separatedly stored. + Structure is stored as a continuous array, each element is a 3*32 bits descriptor. + |-----------------------------------------| + | 0:3 bits | 4:31 bits | + | leaf num | unused | + |-----------------------------------------| + | 0:31 bits | + | child ptr | + |-----------------------------------------| + | 0:31 bits | + | data ptr | + |-----------------------------------------| + Each element represents a non-leaf node in the octree. + The valid mask is used to indicate whether the children are valid. + The leaf mask is used to indicate whether the children are leaf nodes. + The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr. + The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr. + + There are also auxiliary arrays to store the additional structural information to facilitate parallel processing. + - Position: the position of the octree nodes. + - Depth: the depth of the octree nodes. + + Args: + depth (int): the depth of the octree. 
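For intuition, construct_voxel_grid above deduplicates corners shared between voxels; for two face-adjacent voxels the 16 raw corners collapse to 12 unique vertices (a small sketch, run against utils_cube above):

import torch
coords = torch.tensor([[0, 0, 0], [1, 0, 0]])   # two voxels sharing a face
verts, cubes = construct_voxel_grid(coords)
verts.shape   # torch.Size([12, 3])
cubes.shape   # torch.Size([2, 8])  -- per-cube indices into the deduplicated vertex list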
+ """ + + def __init__( + self, + depth, + aabb=[0,0,0,1,1,1], + sh_degree=2, + primitive='voxel', + primitive_config={}, + device='cuda', + ): + self.max_depth = depth + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.device = device + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.primitive = primitive + self.primitive_config = primitive_config + + self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device) + self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device) + self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device) + self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device) + self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device) + self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device) + self.depth[:, 0] = 1 + + self.data = ['position', 'depth'] + self.param_names = [] + + if primitive == 'voxel': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac'] + self.param_names += ['features_dc', 'features_ac'] + if not primitive_config.get('solid', False): + self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data.append('density') + self.param_names.append('density') + elif primitive == 'gaussian': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac', 'opacity'] + self.param_names += ['features_dc', 'features_ac', 'opacity'] + elif primitive == 'trivec': + self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['trivec', 'density', 'features_dc', 'features_ac'] + self.param_names += ['trivec', 'density', 'features_dc', 'features_ac'] + elif primitive == 'decoupoly': + self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device) + self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + + self.setup_functions() + + def setup_functions(self): + self.density_activation = (lambda x: 
torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x) + self.opacity_activation = lambda x: torch.sigmoid(x - 6) + self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6 + self.color_activation = lambda x: torch.sigmoid(x) + + @property + def num_non_leaf_nodes(self): + return self.structure.shape[0] + + @property + def num_leaf_nodes(self): + return self.depth.shape[0] + + @property + def cur_depth(self): + return self.depth.max().item() + + @property + def occupancy(self): + return self.num_leaf_nodes / 8 ** self.cur_depth + + @property + def get_xyz(self): + return self.position + + @property + def get_depth(self): + return self.depth + + @property + def get_density(self): + if self.primitive == 'voxel' and self.voxel_config['solid']: + return torch.full((self.position.shape[0], 1), 1000, dtype=torch.float32, device=self.device) + return self.density_activation(self.density) + + @property + def get_opacity(self): + return self.opacity_activation(self.density) + + @property + def get_trivec(self): + return self.trivec + + @property + def get_decoupoly(self): + return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g + + @property + def get_color(self): + return self.color_activation(self.colors) + + @property + def get_features(self): + if self.sh_degree == 0: + return self.features_dc + return torch.cat([self.features_dc, self.features_ac], dim=-2) + + def state_dict(self): + ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'trivec_config': self.trivec_config, 'voxel_config': self.voxel_config, 'primitive': self.primitive} + if hasattr(self, 'density_shift'): + ret['density_shift'] = self.density_shift + for data in set(self.data + self.param_names): + if not isinstance(getattr(self, data), nn.Module): + ret[data] = getattr(self, data) + else: + ret[data] = getattr(self, data).state_dict() + return ret + + def load_state_dict(self, state_dict): + keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth'])) + for key in keys: + if key not in state_dict: + print(f"Warning: key {key} not found in the state_dict.") + continue + try: + if not isinstance(getattr(self, key), nn.Module): + setattr(self, key, state_dict[key]) + else: + getattr(self, key).load_state_dict(state_dict[key]) + except Exception as e: + print(e) + raise ValueError(f"Error loading key {key}.") + + def gather_from_leaf_children(self, data): + """ + Gather the data from the leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + leaf_cnt = self.structure[:, 0] + leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device) + for i in range(8): + if leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[leaf_cnt_masks[i], 2] + for j in range(i+1): + ret[leaf_cnt_masks[i]] += data[start + j] + return ret + + def gather_from_non_leaf_children(self, data): + """ + Gather the data from the non-leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. 
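A quick sketch of the freshly constructed octree state described above (CUDA is the default device; the values follow directly from the __init__ buffers):

oct = DfsOctree(depth=6, primitive='trivec', primitive_config={'rank': 8, 'dim': 8})
oct.num_non_leaf_nodes   # 1   -- only the root; structure == [[8, 1, 0]]: 8 leaf children, child ptr, data ptr
oct.num_leaf_nodes       # 8   -- the root's children at depth 1, centered at 0.25 / 0.75 per axis
oct.cur_depth            # 1
oct.occupancy            # 1.0 -- all 8 cells of the depth-1 grid are occupied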
+ """ + non_leaf_cnt = 8 - self.structure[:, 0] + non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros_like(data, device=self.device) + for i in range(8): + if non_leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[non_leaf_cnt_masks[i], 1] + for j in range(i+1): + ret[non_leaf_cnt_masks[i]] += data[start + j] + return ret + + def structure_control(self, mask): + """ + Control the structure of the octree. + + Args: + mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep. + """ + # Dont subdivide when the depth is the maximum. + mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0) + # Dont merge when the depth is the minimum. + mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0) + + # Gather control mask + structre_ctrl = self.gather_from_leaf_children(mask) + structre_ctrl[structre_ctrl==-8] = -1 + + new_leaf_num = self.structure[:, 0].clone() + # Modify the leaf num. + structre_valid = structre_ctrl >= 0 + new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes. + structre_delete = structre_ctrl < 0 + merged_nodes = self.gather_from_non_leaf_children(structre_delete.int()) + new_leaf_num += merged_nodes # Delete the merged nodes. + + # Update the structure array to allocate new nodes. + mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes. + mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. + new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_structure_length = new_structre_idx[-1].item() + new_structre_idx = new_structre_idx[:-1] + new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device) + new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid] + + # Initialize the new nodes. + new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device) + new_node_mask[new_structre_idx[structre_valid]] = False + new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. + new_node_num = new_node_mask.sum().item() + + # Rebuild child ptr. + non_leaf_cnt = 8 - new_structure[:, 0] + new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 1] = new_child_ptr + 1 + + # Rebuild data ptr with old data. 
+ leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device) + leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0]) + old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + + # Update the data array + subdivide_mask = mask == 1 + merge_mask = mask == -1 + data_valid = ~(subdivide_mask | merge_mask) + mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes + mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes + mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes + mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes + new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_data_length = new_data_idx[-1].item() + new_data_idx = new_data_idx[:-1] + new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data} + for data in self.data: + new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid] + + # Rebuild data ptr + leaf_cnt = new_structure[:, 0] + new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 2] = new_data_ptr + + # Initialize the new data array + ## For subdivide nodes + if subdivide_mask.sum() > 0: + subdivide_data_ptr = new_structure[new_node_mask, 2] + for data in self.data: + for i in range(8): + if data == 'position': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5 + scale = 2 ** (-1.0 - self.depth[subdivide_mask]) + new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale + elif data == 'depth': + new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask]))) + elif data == 'trivec': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5 + coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1) + axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1) + coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1 + new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True) + else: + new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask] + ## For merge nodes + if merge_mask.sum() > 0: + merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device) + merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]]) + for i in range(8): + merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i + old_merge_data_ptr = 
self.structure[structre_delete, 2] + for data in self.data: + if data == 'position': + scale = 2 ** (1.0 - self.depth[old_merge_data_ptr]) + new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5 + elif data == 'depth': + new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2) + elif data == 'trivec': + new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr] + else: + new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr] + + # Update the structure and data array + self.structure = new_structure + for data in self.data: + setattr(self, data, new_data[data]) + + # Save data array control temp variables + self.data_rearrange_buffer = { + 'subdivide_mask': subdivide_mask, + 'merge_mask': merge_mask, + 'data_valid': data_valid, + 'new_data_idx': new_data_idx, + 'new_data_length': new_data_length, + 'new_data': new_data + } diff --git a/TRELLIS/trellis/representations/radiance_field/__init__.py b/TRELLIS/trellis/representations/radiance_field/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b72a1b7e76b509ee5a5e6979858eb17b4158a151 --- /dev/null +++ b/TRELLIS/trellis/representations/radiance_field/__init__.py @@ -0,0 +1 @@ +from .strivec import Strivec \ No newline at end of file diff --git a/TRELLIS/trellis/representations/radiance_field/strivec.py b/TRELLIS/trellis/representations/radiance_field/strivec.py new file mode 100644 index 0000000000000000000000000000000000000000..f2dc78cdd5a08e994daa57247a3f01f3d43986f9 --- /dev/null +++ b/TRELLIS/trellis/representations/radiance_field/strivec.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..octree import DfsOctree as Octree + + +class Strivec(Octree): + def __init__( + self, + resolution: int, + aabb: list, + sh_degree: int = 0, + rank: int = 8, + dim: int = 8, + device: str = "cuda", + ): + assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2" + self.resolution = resolution + depth = int(np.round(np.log2(resolution))) + super().__init__( + depth=depth, + aabb=aabb, + sh_degree=sh_degree, + primitive="trivec", + primitive_config={"rank": rank, "dim": dim}, + device=device, + ) diff --git a/TRELLIS/trellis/utils/__init__.py b/TRELLIS/trellis/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TRELLIS/trellis/utils/general_utils.py b/TRELLIS/trellis/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b91e6d075dbb2c02438b5f345ec3deb164fa7a8f --- /dev/null +++ b/TRELLIS/trellis/utils/general_utils.py @@ -0,0 +1,187 @@ +import numpy as np +import cv2 +import torch + + +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. 
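From the caller's side, structure_control is driven entirely by the per-leaf mask documented above (1 subdivide, -1 merge, 0 keep). Continuing the trivec octree sketch, subdividing a single leaf replaces it with its eight children (a sketch; the int32 mask matches the int32 bookkeeping tensors used internally):

mask = torch.zeros(oct.num_leaf_nodes, dtype=torch.int32, device='cuda')
mask[0] = 1                  # subdivide the first leaf only
oct.structure_control(mask)
oct.num_leaf_nodes           # 15  (8 - 1 + 8)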
+ """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. + """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. 
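A few concrete examples of the dictionary helpers above (note that dict_merge raises ValueError if the same non-dict key appears on both sides):

dict_merge({'model': {'lr': 1e-3}, 'steps': 1000}, {'model': {'wd': 1e-2}})
# -> {'model': {'lr': 0.001, 'wd': 0.01}, 'steps': 1000}  (key order may vary)

runs = [{'loss': 1.0, 'metrics': {'psnr': 30.0}},
        {'loss': 3.0, 'metrics': {'psnr': 32.0}}]
dict_reduce(runs, lambda v: sum(v) / len(v))
# -> {'loss': 2.0, 'metrics': {'psnr': 31.0}}

dict_flatten({'a': {'b': 1}, 'c': 2})
# -> {'a.b': 1, 'c': 2}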
+ """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if notes is not None: + img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def save_image_with_notes(img, path, notes=None): + """ + Save an image with notes. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy().transpose(1, 2, 0) + if img.dtype == np.float32 or img.dtype == np.float64: + img = np.clip(img * 255, 0, 255).astype(np.uint8) + img = notes_on_image(img, notes) + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + +# debug utils + +def atol(x, y): + """ + Absolute tolerance. + """ + return torch.abs(x - y) + + +def rtol(x, y): + """ + Relative tolerance. + """ + return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + + +# print utils +def indent(s, n=4): + """ + Indent a string. + """ + lines = s.split('\n') + for i in range(1, len(lines)): + lines[i] = ' ' * n + lines[i] + return '\n'.join(lines) + diff --git a/TRELLIS/trellis/utils/postprocessing_utils.py b/TRELLIS/trellis/utils/postprocessing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c35fe2f7138a18db2767c37646ba77358eba1b --- /dev/null +++ b/TRELLIS/trellis/utils/postprocessing_utils.py @@ -0,0 +1,587 @@ +from typing import * +import numpy as np +import torch +import utils3d +import nvdiffrast.torch as dr +from tqdm import tqdm +import trimesh +import trimesh.visual +import xatlas +import pyvista as pv +from pymeshfix import _meshfix +import igraph +import cv2 +from PIL import Image +from .random_utils import sphere_hammersley_sequence +from .render_utils import render_multiview +from ..renderers import GaussianRenderer +from ..representations import Strivec, Gaussian, MeshExtractResult + + +@torch.no_grad() +def _fill_holes( + verts, + faces, + max_hole_size=0.04, + max_hole_nbe=32, + resolution=128, + num_views=500, + debug=False, + verbose=False +): + """ + Rasterize a mesh from multiple views and remove invisible faces. + Also includes postprocessing to: + 1. Remove connected components that are have low visibility. + 2. 
Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole. + + Args: + verts (torch.Tensor): Vertices of the mesh. Shape (V, 3). + faces (torch.Tensor): Faces of the mesh. Shape (F, 3). + max_hole_size (float): Maximum area of a hole to fill. + resolution (int): Resolution of the rasterization. + num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + # Construct cameras + yaws = [] + pitchs = [] + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views) + yaws.append(y) + pitchs.append(p) + yaws = torch.tensor(yaws).cuda() + pitchs = torch.tensor(pitchs).cuda() + radius = 2.0 + fov = torch.deg2rad(torch.tensor(40)).cuda() + projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) + views = [] + for (yaw, pitch) in zip(yaws, pitchs): + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda().float() * radius + view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + views.append(view) + views = torch.stack(views, dim=0) + + # Rasterize + visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device) + rastctx = utils3d.torch.RastContext(backend='cuda') + for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'): + view = views[i] + buffers = utils3d.torch.rasterize_triangle_faces( + rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection + ) + face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1 + face_id = torch.unique(face_id).long() + visblity[face_id] += 1 + visblity = visblity.float() / num_views + + # Mincut + ## construct outer faces + edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) + boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) + connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge) + outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device) + for i in range(len(connected_components)): + outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) + outer_face_indices = outer_face_indices.nonzero().reshape(-1) + + ## construct inner faces + inner_face_indices = torch.nonzero(visblity == 0).reshape(-1) + if verbose: + tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces') + if inner_face_indices.shape[0] == 0: + return verts, faces + + ## Construct dual graph (faces as nodes, edges as edges) + dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge) + dual_edge2edge = edges[dual_edge2edge] + dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1) + if verbose: + tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges') + + ## solve mincut problem + ### construct main graph + g = igraph.Graph() + g.add_vertices(faces.shape[0]) + g.add_edges(dual_edges.cpu().numpy()) + g.es['weight'] = dual_edges_weights.cpu().numpy() + + ### source and target + g.add_vertex('s') + g.add_vertex('t') + + ### connect invisible faces to source + g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### connect outer faces to target + g.add_edges([(f, 't') for f in 
outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### solve mincut + cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist()) + remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device) + if verbose: + tqdm.write(f'Mincut solved, start checking the cut') + + ### check if the cut is valid with each connected component + to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices]) + if debug: + tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}') + valid_remove_cc = [] + cutting_edges = [] + for cc in to_remove_cc: + #### check if the connected component has low visibility + visblity_median = visblity[remove_face_indices[cc]].median() + if debug: + tqdm.write(f'visblity_median: {visblity_median}') + if visblity_median > 0.25: + continue + + #### check if the cuting loop is small enough + cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True) + cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] + cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)] + if len(cc_new_boundary_edge_indices) > 0: + cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices]) + cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc] + cc_new_boundary_edges_cc_area = [] + for i, edge_cc in enumerate(cc_new_boundary_edge_cc): + _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i] + _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i] + cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5) + if debug: + cutting_edges.append(cc_new_boundary_edge_indices) + tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}') + if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]): + continue + + valid_remove_cc.append(cc) + + if debug: + face_v = verts[faces].mean(dim=1).cpu().numpy() + vis_dual_edges = dual_edges.cpu().numpy() + vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8) + vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255] + vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0] + vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255] + if len(valid_remove_cc) > 0: + vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0] + utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors) + + vis_verts = verts.cpu().numpy() + vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy() + utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges) + + + if len(valid_remove_cc) > 0: + remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)] + mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device) + mask[remove_face_indices] = 0 + faces = faces[mask] + faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts) + if verbose: + tqdm.write(f'Removed {(~mask).sum()} faces by mincut') + else: + if verbose: + tqdm.write(f'Removed 0 faces by mincut') + + mesh = _meshfix.PyTMesh() + mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy()) + 
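The source/sink construction used for the mincut above, in miniature (a toy graph with made-up capacities, unrelated to any real mesh):

import igraph
g = igraph.Graph()
g.add_vertices(4)                        # pretend faces 0,1 are invisible and 2,3 are well-visible
g.add_vertex('s'); g.add_vertex('t')
g.add_edges([(0, 1), (1, 2), (2, 3)])    # dual-graph edges between adjacent faces
g.add_edges([(0, 's'), (1, 's')])        # invisible faces tied to the source
g.add_edges([(2, 't'), (3, 't')])        # well-visible outer faces tied to the sink
cut = g.mincut('s', 't', [1.0, 0.1, 1.0, 10.0, 10.0, 10.0, 10.0])
removable = [v for v in cut.partition[0] if v < 4]   # source side of the cut -> faces [0, 1]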
mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) + verts, faces = mesh.return_arrays() + verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32) + + return verts, faces + + +def postprocess_mesh( + vertices: np.array, + faces: np.array, + simplify: bool = True, + simplify_ratio: float = 0.9, + fill_holes: bool = True, + fill_holes_max_hole_size: float = 0.04, + fill_holes_max_hole_nbe: int = 32, + fill_holes_resolution: int = 1024, + fill_holes_num_views: int = 1000, + debug: bool = False, + verbose: bool = False, +): + """ + Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + simplify (bool): Whether to simplify the mesh, using quadric edge collapse. + simplify_ratio (float): Ratio of faces to keep after simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_hole_size (float): Maximum area of a hole to fill. + fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill. + fill_holes_resolution (int): Resolution of the rasterization. + fill_holes_num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + + if verbose: + tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Simplify + if simplify and simplify_ratio > 0: + mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)) + mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) + vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] + if verbose: + tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Remove invisible faces + if fill_holes: + vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda() + vertices, faces = _fill_holes( + vertices, faces, + max_hole_size=fill_holes_max_hole_size, + max_hole_nbe=fill_holes_max_hole_nbe, + resolution=fill_holes_resolution, + num_views=fill_holes_num_views, + debug=debug, + verbose=verbose, + ) + vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy() + if verbose: + tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + return vertices, faces + + +def parametrize_mesh(vertices: np.array, faces: np.array): + """ + Parametrize a mesh to a texture space, using xatlas. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + """ + + vmapping, indices, uvs = xatlas.parametrize(vertices, faces) + + vertices = vertices[vmapping] + faces = indices + + return vertices, faces, uvs + + +def bake_texture( + vertices: np.array, + faces: np.array, + uvs: np.array, + observations: List[np.array], + masks: List[np.array], + extrinsics: List[np.array], + intrinsics: List[np.array], + texture_size: int = 2048, + near: float = 0.1, + far: float = 10.0, + mode: Literal['fast', 'opt'] = 'opt', + lambda_tv: float = 1e-2, + verbose: bool = False, +): + """ + Bake texture to a mesh from multiple observations. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + uvs (np.array): UV coordinates of the mesh. Shape (V, 2). + observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3). 
+ masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W). + extrinsics (List[np.array]): List of extrinsics. Shape (4, 4). + intrinsics (List[np.array]): List of intrinsics. Shape (3, 3). + texture_size (int): Size of the texture. + near (float): Near plane of the camera. + far (float): Far plane of the camera. + mode (Literal['fast', 'opt']): Mode of texture baking. + lambda_tv (float): Weight of total variation loss in optimization. + verbose (bool): Whether to print progress. + """ + vertices = torch.tensor(vertices).cuda() + faces = torch.tensor(faces.astype(np.int32)).cuda() + uvs = torch.tensor(uvs).cuda() + observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations] + masks = [torch.tensor(m>0).bool().cuda() for m in masks] + views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics] + projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics] + + if mode == 'fast': + texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda() + texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda() + rastctx = utils3d.torch.RastContext(backend='cuda') + for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + ) + uv_map = rast['uv'][0].detach().flip(0) + mask = rast['mask'][0].detach().bool() & masks[0] + + # nearest neighbor interpolation + uv_map = (uv_map * texture_size).floor().long() + obs = observation[mask] + uv_map = uv_map[mask] + idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size + texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs) + texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device)) + + mask = texture_weights > 0 + texture[mask] /= texture_weights[mask][:, None] + texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8) + + # inpaint + mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + + elif mode == 'opt': + rastctx = utils3d.torch.RastContext(backend='cuda') + observations = [observations.flip(0) for observations in observations] + masks = [m.flip(0) for m in masks] + _uv = [] + _uv_dr = [] + for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + ) + _uv.append(rast['uv'].detach()) + _uv_dr.append(rast['uv_dr'].detach()) + + texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()) + optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - 
end_lr) * (1 + np.cos(np.pi * step / total_steps)) + + def tv_loss(texture): + return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \ + torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :]) + + total_steps = 2500 + with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar: + for step in range(total_steps): + optimizer.zero_grad() + selected = np.random.randint(0, len(views)) + uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected] + render = dr.texture(texture, uv, uv_dr)[0] + loss = torch.nn.functional.l1_loss(render[mask], observation[mask]) + if lambda_tv > 0: + loss += lambda_tv * tv_loss(texture) + loss.backward() + optimizer.step() + # annealing + optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5) + pbar.set_postfix({'loss': loss.item()}) + pbar.update() + texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) + mask = 1 - utils3d.torch.rasterize_triangle_faces( + rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size + )['mask'][0].detach().cpu().numpy().astype(np.uint8) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + else: + raise ValueError(f'Unknown mode: {mode}') + + return texture + + +def to_glb( + app_rep: Union[Strivec, Gaussian], + mesh: MeshExtractResult, + simplify: float = 0.95, + fill_holes: bool = True, + fill_holes_max_size: float = 0.04, + texture_size: int = 1024, + debug: bool = False, + verbose: bool = True, +) -> trimesh.Trimesh: + """ + Convert a generated asset to a glb file. + + Args: + app_rep (Union[Strivec, Gaussian]): Appearance representation. + mesh (MeshExtractResult): Extracted mesh. + simplify (float): Ratio of faces to remove in simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_size (float): Maximum area of a hole to fill. + texture_size (int): Size of the texture. + debug (bool): Whether to print debug information. + verbose (bool): Whether to print progress. 
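+    Returns:
+        trimesh.Trimesh: The textured mesh, rotated from z-up to y-up.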
+ """ + vertices = mesh.vertices.cpu().numpy() + faces = mesh.faces.cpu().numpy() + + # mesh postprocess + vertices, faces = postprocess_mesh( + vertices, faces, + simplify=simplify > 0, + simplify_ratio=simplify, + fill_holes=fill_holes, + fill_holes_max_hole_size=fill_holes_max_size, + fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)), + fill_holes_resolution=1024, + fill_holes_num_views=1000, + debug=debug, + verbose=verbose, + ) + + # parametrize mesh + vertices, faces, uvs = parametrize_mesh(vertices, faces) + + # bake texture + observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100) + masks = [np.any(observation > 0, axis=-1) for observation in observations] + extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] + intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] + texture = bake_texture( + vertices, faces, uvs, + observations, masks, extrinsics, intrinsics, + texture_size=texture_size, mode='opt', + lambda_tv=0.01, + verbose=verbose + ) + texture = Image.fromarray(texture) + + # rotate mesh (from z-up to y-up) + vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + material = trimesh.visual.material.PBRMaterial( + roughnessFactor=1.0, + baseColorTexture=texture, + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8) + ) + mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)) + return mesh + + +def simplify_gs( + gs: Gaussian, + simplify: float = 0.95, + verbose: bool = True, +): + """ + Simplify 3D Gaussians + NOTE: this function is not used in the current implementation for the unsatisfactory performance. + + Args: + gs (Gaussian): 3D Gaussian. + simplify (float): Ratio of Gaussians to remove in simplification. 
+ """ + if simplify <= 0: + return gs + + # simplify + observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100) + observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations] + + # Following https://arxiv.org/pdf/2411.06019 + renderer = GaussianRenderer({ + "resolution": 1024, + "near": 0.8, + "far": 1.6, + "ssaa": 1, + "bg_color": (0,0,0), + }) + new_gs = Gaussian(**gs.init_params) + new_gs._features_dc = gs._features_dc.clone() + new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None + new_gs._opacity = torch.nn.Parameter(gs._opacity.clone()) + new_gs._rotation = torch.nn.Parameter(gs._rotation.clone()) + new_gs._scaling = torch.nn.Parameter(gs._scaling.clone()) + new_gs._xyz = torch.nn.Parameter(gs._xyz.clone()) + + start_lr = [1e-4, 1e-3, 5e-3, 0.025] + end_lr = [1e-6, 1e-5, 5e-5, 0.00025] + optimizer = torch.optim.Adam([ + {"params": new_gs._xyz, "lr": start_lr[0]}, + {"params": new_gs._rotation, "lr": start_lr[1]}, + {"params": new_gs._scaling, "lr": start_lr[2]}, + {"params": new_gs._opacity, "lr": start_lr[3]}, + ], lr=start_lr[0]) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + + _zeta = new_gs.get_opacity.clone().detach().squeeze() + _lambda = torch.zeros_like(_zeta) + _delta = 1e-7 + _interval = 10 + num_target = int((1 - simplify) * _zeta.shape[0]) + + with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar: + for i in range(2500): + # prune + if i % 100 == 0: + mask = new_gs.get_opacity.squeeze() > 0.05 + mask = torch.nonzero(mask).squeeze() + new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask]) + new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask]) + new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask]) + new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask]) + new_gs._features_dc = new_gs._features_dc[mask] + new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None + _zeta = _zeta[mask] + _lambda = _lambda[mask] + # update optimizer state + for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]): + stored_state = optimizer.state[param_group['params'][0]] + if 'exp_avg' in stored_state: + stored_state['exp_avg'] = stored_state['exp_avg'][mask] + stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask] + del optimizer.state[param_group['params'][0]] + param_group['params'][0] = new_param + optimizer.state[param_group['params'][0]] = stored_state + + opacity = new_gs.get_opacity.squeeze() + + # sparisfy + if i % _interval == 0: + _zeta = _lambda + opacity.detach() + if opacity.shape[0] > num_target: + index = _zeta.topk(num_target)[1] + _m = torch.ones_like(_zeta, dtype=torch.bool) + _m[index] = 0 + _zeta[_m] = 0 + _lambda = _lambda + opacity.detach() - _zeta + + # sample a random view + view_idx = np.random.randint(len(observations)) + observation = observations[view_idx] + extrinsic = extrinsics[view_idx] + intrinsic = intrinsics[view_idx] + + color = renderer.render(new_gs, extrinsic, intrinsic)['color'] + rgb_loss = torch.nn.functional.l1_loss(color, observation) + loss = rgb_loss + \ + _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2)) + + 
optimizer.zero_grad() + loss.backward() + optimizer.step() + + # update lr + for j in range(len(optimizer.param_groups)): + optimizer.param_groups[j]['lr'] = cosine_anealing(optimizer, i, 2500, start_lr[j], end_lr[j]) + + pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()}) + pbar.update() + + new_gs._xyz = new_gs._xyz.data + new_gs._rotation = new_gs._rotation.data + new_gs._scaling = new_gs._scaling.data + new_gs._opacity = new_gs._opacity.data + + return new_gs diff --git a/TRELLIS/trellis/utils/random_utils.py b/TRELLIS/trellis/utils/random_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..420c4146cbb4c2973f2cc2e69f376a3836e65eeb --- /dev/null +++ b/TRELLIS/trellis/utils/random_utils.py @@ -0,0 +1,30 @@ +import numpy as np + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] \ No newline at end of file diff --git a/TRELLIS/trellis/utils/render_utils.py b/TRELLIS/trellis/utils/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54a33ce79c6d3e1e358ab1650ea14cfe1d50ba91 --- /dev/null +++ b/TRELLIS/trellis/utils/render_utils.py @@ -0,0 +1,116 @@ +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from PIL import Image + +from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer +from ..representations import Octree, Gaussian, MeshExtractResult +from ..modules import sparse as sp +from .random_utils import sphere_hammersley_sequence + + +def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).cuda() + yaw = torch.tensor(float(yaw)).cuda() + pitch = torch.tensor(float(pitch)).cuda() + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda() * r + extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs): + if isinstance(sample, Octree): + renderer = OctreeRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 0.8) + 
renderer.rendering_options.far = options.get('far', 1.6) + renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + renderer.pipe.primitive = sample.primitive + elif isinstance(sample, Gaussian): + renderer = GaussianRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 0.8) + renderer.rendering_options.far = options.get('far', 1.6) + renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) + renderer.rendering_options.ssaa = options.get('ssaa', 1) + renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1) + renderer.pipe.use_mip_gaussian = True + elif isinstance(sample, MeshExtractResult): + renderer = MeshRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 1) + renderer.rendering_options.far = options.get('far', 100) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + else: + raise ValueError(f'Unsupported sample type: {type(sample)}') + + rets = {} + for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose): + if not isinstance(sample, MeshExtractResult): + res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite) + if 'color' not in rets: rets['color'] = [] + if 'depth' not in rets: rets['depth'] = [] + rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + if 'percent_depth' in res: + rets['depth'].append(res['percent_depth'].detach().cpu().numpy()) + elif 'depth' in res: + rets['depth'].append(res['depth'].detach().cpu().numpy()) + else: + rets['depth'].append(None) + else: + res = renderer.render(sample, extr, intr) + if 'normal' not in rets: rets['normal'] = [] + rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + return rets + + +def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs): + yaws = torch.linspace(0, 2 * 3.1415, num_frames) + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitch = pitch.tolist() + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) + return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def render_multiview(sample, resolution=512, nviews=30): + r = 2 + fov = 40 + cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] + yaws = [cam[0] for cam in cams] + pitchs = [cam[1] for cam in cams] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) + return res['color'], extrinsics, intrinsics + + +def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs): + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = offset[0] + yaw = [y + yaw_offset for y in yaw] + pitch = [offset[1] for _ in range(4)] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) + return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
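+
+# Usage sketch (illustrative, not part of the module): render a turntable video of a
+# generated Gaussian and write it to disk. Assumes `sample` is a trellis Gaussian and that
+# imageio (with ffmpeg support) is installed.
+#
+#   import imageio
+#   frames = render_video(sample, resolution=512, num_frames=120)['color']
+#   imageio.mimsave('turntable.mp4', frames, fps=30)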