diff --git a/trellis/__init__.py b/trellis/__init__.py deleted file mode 100644 index b02ac31563fec7c36fa4bd5d420ac4af2472bba8..0000000000000000000000000000000000000000 --- a/trellis/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import models -from . import modules -from . import pipelines -from . import renderers -from . import representations -from . import utils diff --git a/trellis/models/__init__.py b/trellis/models/__init__.py deleted file mode 100644 index 00fd66a3272a48d43f9853682db48bcde2959d63..0000000000000000000000000000000000000000 --- a/trellis/models/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -import importlib - -__attributes = { - 'SparseStructureEncoder': 'sparse_structure_vae', - 'SparseStructureDecoder': 'sparse_structure_vae', - 'SparseStructureFlowModel': 'sparse_structure_flow', - 'SLatEncoder': 'structured_latent_vae', - 'SLatGaussianDecoder': 'structured_latent_vae', - 'SLatRadianceFieldDecoder': 'structured_latent_vae', - 'SLatMeshDecoder': 'structured_latent_vae', - 'SLatFlowModel': 'structured_latent_flow', -} - -__submodules = [] - -__all__ = list(__attributes.keys()) + __submodules - -def __getattr__(name): - if name not in globals(): - if name in __attributes: - module_name = __attributes[name] - module = importlib.import_module(f".{module_name}", __name__) - globals()[name] = getattr(module, name) - elif name in __submodules: - module = importlib.import_module(f".{name}", __name__) - globals()[name] = module - else: - raise AttributeError(f"module {__name__} has no attribute {name}") - return globals()[name] - - -def from_pretrained(path: str, **kwargs): - """ - Load a model from a pretrained checkpoint. - - Args: - path: The path to the checkpoint. Can be either local path or a Hugging Face model name. - NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. - **kwargs: Additional arguments for the model constructor. - """ - import os - import json - from safetensors.torch import load_file - is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") - - if is_local: - config_file = f"{path}.json" - model_file = f"{path}.safetensors" - else: - from huggingface_hub import hf_hub_download - path_parts = path.split('/') - repo_id = f'{path_parts[0]}/{path_parts[1]}' - model_name = '/'.join(path_parts[2:]) - config_file = hf_hub_download(repo_id, f"{model_name}.json") - model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") - - with open(config_file, 'r') as f: - config = json.load(f) - model = __getattr__(config['name'])(**config['args'], **kwargs) - model.load_state_dict(load_file(model_file)) - - return model - - -# For Pylance -if __name__ == '__main__': - from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder - from .sparse_structure_flow import SparseStructureFlowModel - from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder - from .structured_latent_flow import SLatFlowModel diff --git a/trellis/models/sparse_structure_flow.py b/trellis/models/sparse_structure_flow.py deleted file mode 100644 index baa5dd9644e569b73717d7e7a9ebed55e9930459..0000000000000000000000000000000000000000 --- a/trellis/models/sparse_structure_flow.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from ..modules.utils import convert_module_to_f16, convert_module_to_f32 -from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock -from ..modules.spatial import patchify, unpatchify - - -class TimestepEmbedder(nn.Module): - """ - Embeds scalar timesteps into vector representations. - """ - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - - Args: - t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - dim: the dimension of the output. - max_period: controls the minimum frequency of the embeddings. - - Returns: - an (N, D) Tensor of positional embeddings. - """ - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - half = dim // 2 - freqs = torch.exp( - -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=t.device) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - - -class SparseStructureFlowModel(nn.Module): - def __init__( - self, - resolution: int, - in_channels: int, - model_channels: int, - cond_channels: int, - out_channels: int, - num_blocks: int, - num_heads: Optional[int] = None, - num_head_channels: Optional[int] = 64, - mlp_ratio: float = 4, - patch_size: int = 2, - pe_mode: Literal["ape", "rope"] = "ape", - use_fp16: bool = False, - use_checkpoint: bool = False, - share_mod: bool = False, - qk_rms_norm: bool = False, - qk_rms_norm_cross: bool = False, - ): - super().__init__() - self.resolution = resolution - self.in_channels = in_channels - self.model_channels = model_channels - self.cond_channels = cond_channels - self.out_channels = out_channels - self.num_blocks = num_blocks - self.num_heads = num_heads or model_channels // num_head_channels - self.mlp_ratio = mlp_ratio - self.patch_size = patch_size - self.pe_mode = pe_mode - self.use_fp16 = use_fp16 - self.use_checkpoint = use_checkpoint - self.share_mod = share_mod - self.qk_rms_norm = qk_rms_norm - self.qk_rms_norm_cross = qk_rms_norm_cross - self.dtype = torch.float16 if use_fp16 else torch.float32 - - self.t_embedder = TimestepEmbedder(model_channels) - if share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) - ) - - if pe_mode == "ape": - pos_embedder = AbsolutePositionEmbedder(model_channels, 3) - coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') - coords = torch.stack(coords, dim=-1).reshape(-1, 3) - pos_emb = pos_embedder(coords) - self.register_buffer("pos_emb", pos_emb) - - self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) - - self.blocks = nn.ModuleList([ - ModulatedTransformerCrossBlock( - model_channels, - cond_channels, - num_heads=self.num_heads, - mlp_ratio=self.mlp_ratio, - attn_mode='full', - use_checkpoint=self.use_checkpoint, - use_rope=(pe_mode == "rope"), - share_mod=share_mod, - qk_rms_norm=self.qk_rms_norm, - qk_rms_norm_cross=self.qk_rms_norm_cross, - ) - for _ in range(num_blocks) - ]) - - self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) - - self.initialize_weights() - if use_fp16: - self.convert_to_fp16() - - @property - def device(self) -> torch.device: - """ - Return the device of the model. - """ - return next(self.parameters()).device - - def convert_to_fp16(self) -> None: - """ - Convert the torso of the model to float16. - """ - self.blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self) -> None: - """ - Convert the torso of the model to float32. - """ - self.blocks.apply(convert_module_to_f32) - - def initialize_weights(self) -> None: - # Initialize transformer layers: - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - self.apply(_basic_init) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # Zero-out adaLN modulation layers in DiT blocks: - if self.share_mod: - nn.init.constant_(self.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.adaLN_modulation[-1].bias, 0) - else: - for block in self.blocks: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.out_layer.weight, 0) - nn.init.constant_(self.out_layer.bias, 0) - - def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: - assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ - f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" - - h = patchify(x, self.patch_size) - h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() - - h = self.input_layer(h) - h = h + self.pos_emb[None] - t_emb = self.t_embedder(t) - if self.share_mod: - t_emb = self.adaLN_modulation(t_emb) - t_emb = t_emb.type(self.dtype) - h = h.type(self.dtype) - cond = cond.type(self.dtype) - for block in self.blocks: - h = block(h, t_emb, cond) - h = h.type(x.dtype) - h = F.layer_norm(h, h.shape[-1:]) - h = self.out_layer(h) - - h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) - h = unpatchify(h, self.patch_size).contiguous() - - return h diff --git a/trellis/models/sparse_structure_vae.py b/trellis/models/sparse_structure_vae.py deleted file mode 100644 index 6ed49ae65b9cde2a45a59beb6868981a644b75d3..0000000000000000000000000000000000000000 --- a/trellis/models/sparse_structure_vae.py +++ /dev/null @@ -1,306 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -from ..modules.norm import GroupNorm32, ChannelLayerNorm32 -from ..modules.spatial import pixel_shuffle_3d -from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 - - -def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: - """ - Return a normalization layer. - """ - if norm_type == "group": - return GroupNorm32(32, *args, **kwargs) - elif norm_type == "layer": - return ChannelLayerNorm32(*args, **kwargs) - else: - raise ValueError(f"Invalid norm type {norm_type}") - - -class ResBlock3d(nn.Module): - def __init__( - self, - channels: int, - out_channels: Optional[int] = None, - norm_type: Literal["group", "layer"] = "layer", - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.norm1 = norm_layer(norm_type, channels) - self.norm2 = norm_layer(norm_type, self.out_channels) - self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) - self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) - self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self.norm1(x) - h = F.silu(h) - h = self.conv1(h) - h = self.norm2(h) - h = F.silu(h) - h = self.conv2(h) - h = h + self.skip_connection(x) - return h - - -class DownsampleBlock3d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - mode: Literal["conv", "avgpool"] = "conv", - ): - assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" - - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - if mode == "conv": - self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) - elif mode == "avgpool": - assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if hasattr(self, "conv"): - return self.conv(x) - else: - return F.avg_pool3d(x, 2) - - -class UpsampleBlock3d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - mode: Literal["conv", "nearest"] = "conv", - ): - assert mode in ["conv", "nearest"], f"Invalid mode {mode}" - - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - if mode == "conv": - self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) - elif mode == "nearest": - assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if hasattr(self, "conv"): - x = self.conv(x) - return pixel_shuffle_3d(x, 2) - else: - return F.interpolate(x, scale_factor=2, mode="nearest") - - -class SparseStructureEncoder(nn.Module): - """ - Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). - - Args: - in_channels (int): Channels of the input. - latent_channels (int): Channels of the latent representation. - num_res_blocks (int): Number of residual blocks at each resolution. - channels (List[int]): Channels of the encoder blocks. - num_res_blocks_middle (int): Number of residual blocks in the middle. - norm_type (Literal["group", "layer"]): Type of normalization layer. - use_fp16 (bool): Whether to use FP16. - """ - def __init__( - self, - in_channels: int, - latent_channels: int, - num_res_blocks: int, - channels: List[int], - num_res_blocks_middle: int = 2, - norm_type: Literal["group", "layer"] = "layer", - use_fp16: bool = False, - ): - super().__init__() - self.in_channels = in_channels - self.latent_channels = latent_channels - self.num_res_blocks = num_res_blocks - self.channels = channels - self.num_res_blocks_middle = num_res_blocks_middle - self.norm_type = norm_type - self.use_fp16 = use_fp16 - self.dtype = torch.float16 if use_fp16 else torch.float32 - - self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) - - self.blocks = nn.ModuleList([]) - for i, ch in enumerate(channels): - self.blocks.extend([ - ResBlock3d(ch, ch) - for _ in range(num_res_blocks) - ]) - if i < len(channels) - 1: - self.blocks.append( - DownsampleBlock3d(ch, channels[i+1]) - ) - - self.middle_block = nn.Sequential(*[ - ResBlock3d(channels[-1], channels[-1]) - for _ in range(num_res_blocks_middle) - ]) - - self.out_layer = nn.Sequential( - norm_layer(norm_type, channels[-1]), - nn.SiLU(), - nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) - ) - - if use_fp16: - self.convert_to_fp16() - - @property - def device(self) -> torch.device: - """ - Return the device of the model. - """ - return next(self.parameters()).device - - def convert_to_fp16(self) -> None: - """ - Convert the torso of the model to float16. - """ - self.use_fp16 = True - self.dtype = torch.float16 - self.blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self) -> None: - """ - Convert the torso of the model to float32. - """ - self.use_fp16 = False - self.dtype = torch.float32 - self.blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: - h = self.input_layer(x) - h = h.type(self.dtype) - - for block in self.blocks: - h = block(h) - h = self.middle_block(h) - - h = h.type(x.dtype) - h = self.out_layer(h) - - mean, logvar = h.chunk(2, dim=1) - - if sample_posterior: - std = torch.exp(0.5 * logvar) - z = mean + std * torch.randn_like(std) - else: - z = mean - - if return_raw: - return z, mean, logvar - return z - - -class SparseStructureDecoder(nn.Module): - """ - Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). - - Args: - out_channels (int): Channels of the output. - latent_channels (int): Channels of the latent representation. - num_res_blocks (int): Number of residual blocks at each resolution. - channels (List[int]): Channels of the decoder blocks. - num_res_blocks_middle (int): Number of residual blocks in the middle. - norm_type (Literal["group", "layer"]): Type of normalization layer. - use_fp16 (bool): Whether to use FP16. - """ - def __init__( - self, - out_channels: int, - latent_channels: int, - num_res_blocks: int, - channels: List[int], - num_res_blocks_middle: int = 2, - norm_type: Literal["group", "layer"] = "layer", - use_fp16: bool = False, - ): - super().__init__() - self.out_channels = out_channels - self.latent_channels = latent_channels - self.num_res_blocks = num_res_blocks - self.channels = channels - self.num_res_blocks_middle = num_res_blocks_middle - self.norm_type = norm_type - self.use_fp16 = use_fp16 - self.dtype = torch.float16 if use_fp16 else torch.float32 - - self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) - - self.middle_block = nn.Sequential(*[ - ResBlock3d(channels[0], channels[0]) - for _ in range(num_res_blocks_middle) - ]) - - self.blocks = nn.ModuleList([]) - for i, ch in enumerate(channels): - self.blocks.extend([ - ResBlock3d(ch, ch) - for _ in range(num_res_blocks) - ]) - if i < len(channels) - 1: - self.blocks.append( - UpsampleBlock3d(ch, channels[i+1]) - ) - - self.out_layer = nn.Sequential( - norm_layer(norm_type, channels[-1]), - nn.SiLU(), - nn.Conv3d(channels[-1], out_channels, 3, padding=1) - ) - - if use_fp16: - self.convert_to_fp16() - - @property - def device(self) -> torch.device: - """ - Return the device of the model. - """ - return next(self.parameters()).device - - def convert_to_fp16(self) -> None: - """ - Convert the torso of the model to float16. - """ - self.use_fp16 = True - self.dtype = torch.float16 - self.blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self) -> None: - """ - Convert the torso of the model to float32. - """ - self.use_fp16 = False - self.dtype = torch.float32 - self.blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = self.input_layer(x) - - h = h.type(self.dtype) - - h = self.middle_block(h) - for block in self.blocks: - h = block(h) - - h = h.type(x.dtype) - h = self.out_layer(h) - return h diff --git a/trellis/models/structured_latent_flow.py b/trellis/models/structured_latent_flow.py deleted file mode 100644 index 19c11597244ea53505d746b593d10cddad4bcb6f..0000000000000000000000000000000000000000 --- a/trellis/models/structured_latent_flow.py +++ /dev/null @@ -1,262 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 -from ..modules.transformer import AbsolutePositionEmbedder -from ..modules.norm import LayerNorm32 -from ..modules import sparse as sp -from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock -from .sparse_structure_flow import TimestepEmbedder - - -class SparseResBlock3d(nn.Module): - def __init__( - self, - channels: int, - emb_channels: int, - out_channels: Optional[int] = None, - downsample: bool = False, - upsample: bool = False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.out_channels = out_channels or channels - self.downsample = downsample - self.upsample = upsample - - assert not (downsample and upsample), "Cannot downsample and upsample at the same time" - - self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) - self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) - self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) - self.emb_layers = nn.Sequential( - nn.SiLU(), - nn.Linear(emb_channels, 2 * self.out_channels, bias=True), - ) - self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() - self.updown = None - if self.downsample: - self.updown = sp.SparseDownsample(2) - elif self.upsample: - self.updown = sp.SparseUpsample(2) - - def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: - if self.updown is not None: - x = self.updown(x) - return x - - def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor: - emb_out = self.emb_layers(emb).type(x.dtype) - scale, shift = torch.chunk(emb_out, 2, dim=1) - - x = self._updown(x) - h = x.replace(self.norm1(x.feats)) - h = h.replace(F.silu(h.feats)) - h = self.conv1(h) - h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift - h = h.replace(F.silu(h.feats)) - h = self.conv2(h) - h = h + self.skip_connection(x) - - return h - - -class SLatFlowModel(nn.Module): - def __init__( - self, - resolution: int, - in_channels: int, - model_channels: int, - cond_channels: int, - out_channels: int, - num_blocks: int, - num_heads: Optional[int] = None, - num_head_channels: Optional[int] = 64, - mlp_ratio: float = 4, - patch_size: int = 2, - num_io_res_blocks: int = 2, - io_block_channels: List[int] = None, - pe_mode: Literal["ape", "rope"] = "ape", - use_fp16: bool = False, - use_checkpoint: bool = False, - use_skip_connection: bool = True, - share_mod: bool = False, - qk_rms_norm: bool = False, - qk_rms_norm_cross: bool = False, - ): - super().__init__() - self.resolution = resolution - self.in_channels = in_channels - self.model_channels = model_channels - self.cond_channels = cond_channels - self.out_channels = out_channels - self.num_blocks = num_blocks - self.num_heads = num_heads or model_channels // num_head_channels - self.mlp_ratio = mlp_ratio - self.patch_size = patch_size - self.num_io_res_blocks = num_io_res_blocks - self.io_block_channels = io_block_channels - self.pe_mode = pe_mode - self.use_fp16 = use_fp16 - self.use_checkpoint = use_checkpoint - self.use_skip_connection = use_skip_connection - self.share_mod = share_mod - self.qk_rms_norm = qk_rms_norm - self.qk_rms_norm_cross = qk_rms_norm_cross - self.dtype = torch.float16 if use_fp16 else torch.float32 - - assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2" - assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages" - - self.t_embedder = TimestepEmbedder(model_channels) - if share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) - ) - - if pe_mode == "ape": - self.pos_embedder = AbsolutePositionEmbedder(model_channels) - - self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0]) - self.input_blocks = nn.ModuleList([]) - for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]): - self.input_blocks.extend([ - SparseResBlock3d( - chs, - model_channels, - out_channels=chs, - ) - for _ in range(num_io_res_blocks-1) - ]) - self.input_blocks.append( - SparseResBlock3d( - chs, - model_channels, - out_channels=next_chs, - downsample=True, - ) - ) - - self.blocks = nn.ModuleList([ - ModulatedSparseTransformerCrossBlock( - model_channels, - cond_channels, - num_heads=self.num_heads, - mlp_ratio=self.mlp_ratio, - attn_mode='full', - use_checkpoint=self.use_checkpoint, - use_rope=(pe_mode == "rope"), - share_mod=self.share_mod, - qk_rms_norm=self.qk_rms_norm, - qk_rms_norm_cross=self.qk_rms_norm_cross, - ) - for _ in range(num_blocks) - ]) - - self.out_blocks = nn.ModuleList([]) - for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))): - self.out_blocks.append( - SparseResBlock3d( - prev_chs * 2 if self.use_skip_connection else prev_chs, - model_channels, - out_channels=chs, - upsample=True, - ) - ) - self.out_blocks.extend([ - SparseResBlock3d( - chs * 2 if self.use_skip_connection else chs, - model_channels, - out_channels=chs, - ) - for _ in range(num_io_res_blocks-1) - ]) - self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels) - - self.initialize_weights() - if use_fp16: - self.convert_to_fp16() - - @property - def device(self) -> torch.device: - """ - Return the device of the model. - """ - return next(self.parameters()).device - - def convert_to_fp16(self) -> None: - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.blocks.apply(convert_module_to_f16) - self.out_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self) -> None: - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.blocks.apply(convert_module_to_f32) - self.out_blocks.apply(convert_module_to_f32) - - def initialize_weights(self) -> None: - # Initialize transformer layers: - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - self.apply(_basic_init) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # Zero-out adaLN modulation layers in DiT blocks: - if self.share_mod: - nn.init.constant_(self.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.adaLN_modulation[-1].bias, 0) - else: - for block in self.blocks: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.out_layer.weight, 0) - nn.init.constant_(self.out_layer.bias, 0) - - def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor: - h = self.input_layer(x).type(self.dtype) - t_emb = self.t_embedder(t) - if self.share_mod: - t_emb = self.adaLN_modulation(t_emb) - t_emb = t_emb.type(self.dtype) - cond = cond.type(self.dtype) - - skips = [] - # pack with input blocks - for block in self.input_blocks: - h = block(h, t_emb) - skips.append(h.feats) - - if self.pe_mode == "ape": - h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) - for block in self.blocks: - h = block(h, t_emb, cond) - - # unpack with output blocks - for block, skip in zip(self.out_blocks, reversed(skips)): - if self.use_skip_connection: - h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb) - else: - h = block(h, t_emb) - - h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) - h = self.out_layer(h.type(x.dtype)) - return h diff --git a/trellis/models/structured_latent_vae/__init__.py b/trellis/models/structured_latent_vae/__init__.py deleted file mode 100644 index 00cbf8826b328f9abe76cd641645f67128d7a04b..0000000000000000000000000000000000000000 --- a/trellis/models/structured_latent_vae/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .encoder import SLatEncoder -from .decoder_gs import SLatGaussianDecoder -from .decoder_rf import SLatRadianceFieldDecoder -from .decoder_mesh import SLatMeshDecoder diff --git a/trellis/models/structured_latent_vae/base.py b/trellis/models/structured_latent_vae/base.py deleted file mode 100644 index 7b86006fb35dee6f4f61a6f827d13787e0a287b2..0000000000000000000000000000000000000000 --- a/trellis/models/structured_latent_vae/base.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -from ...modules.utils import convert_module_to_f16, convert_module_to_f32 -from ...modules import sparse as sp -from ...modules.transformer import AbsolutePositionEmbedder -from ...modules.sparse.transformer import SparseTransformerBlock - - -def block_attn_config(self): - """ - Return the attention configuration of the model. - """ - for i in range(self.num_blocks): - if self.attn_mode == "shift_window": - yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER - elif self.attn_mode == "shift_sequence": - yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER - elif self.attn_mode == "shift_order": - yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] - elif self.attn_mode == "full": - yield "full", None, None, None, None - elif self.attn_mode == "swin": - yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None - - -class SparseTransformerBase(nn.Module): - """ - Sparse Transformer without output layers. - Serve as the base class for encoder and decoder. - """ - def __init__( - self, - in_channels: int, - model_channels: int, - num_blocks: int, - num_heads: Optional[int] = None, - num_head_channels: Optional[int] = 64, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", - window_size: Optional[int] = None, - pe_mode: Literal["ape", "rope"] = "ape", - use_fp16: bool = False, - use_checkpoint: bool = False, - qk_rms_norm: bool = False, - ): - super().__init__() - self.in_channels = in_channels - self.model_channels = model_channels - self.num_blocks = num_blocks - self.window_size = window_size - self.num_heads = num_heads or model_channels // num_head_channels - self.mlp_ratio = mlp_ratio - self.attn_mode = attn_mode - self.pe_mode = pe_mode - self.use_fp16 = use_fp16 - self.use_checkpoint = use_checkpoint - self.qk_rms_norm = qk_rms_norm - self.dtype = torch.float16 if use_fp16 else torch.float32 - - if pe_mode == "ape": - self.pos_embedder = AbsolutePositionEmbedder(model_channels) - - self.input_layer = sp.SparseLinear(in_channels, model_channels) - self.blocks = nn.ModuleList([ - SparseTransformerBlock( - model_channels, - num_heads=self.num_heads, - mlp_ratio=self.mlp_ratio, - attn_mode=attn_mode, - window_size=window_size, - shift_sequence=shift_sequence, - shift_window=shift_window, - serialize_mode=serialize_mode, - use_checkpoint=self.use_checkpoint, - use_rope=(pe_mode == "rope"), - qk_rms_norm=self.qk_rms_norm, - ) - for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) - ]) - - @property - def device(self) -> torch.device: - """ - Return the device of the model. - """ - return next(self.parameters()).device - - def convert_to_fp16(self) -> None: - """ - Convert the torso of the model to float16. - """ - self.blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self) -> None: - """ - Convert the torso of the model to float32. - """ - self.blocks.apply(convert_module_to_f32) - - def initialize_weights(self) -> None: - # Initialize transformer layers: - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - self.apply(_basic_init) - - def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: - h = self.input_layer(x) - if self.pe_mode == "ape": - h = h + self.pos_embedder(x.coords[:, 1:]) - h = h.type(self.dtype) - for block in self.blocks: - h = block(h) - return h diff --git a/trellis/models/structured_latent_vae/decoder_gs.py b/trellis/models/structured_latent_vae/decoder_gs.py deleted file mode 100644 index b6948173f57063ea1eff411f4840c8e1a711bd69..0000000000000000000000000000000000000000 --- a/trellis/models/structured_latent_vae/decoder_gs.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -from ...modules import sparse as sp -from ...utils.random_utils import hammersley_sequence -from .base import SparseTransformerBase -from ...representations import Gaussian - - -class SLatGaussianDecoder(SparseTransformerBase): - def __init__( - self, - resolution: int, - model_channels: int, - latent_channels: int, - num_blocks: int, - num_heads: Optional[int] = None, - num_head_channels: Optional[int] = 64, - mlp_ratio: float = 4, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", - window_size: int = 8, - pe_mode: Literal["ape", "rope"] = "ape", - use_fp16: bool = False, - use_checkpoint: bool = False, - qk_rms_norm: bool = False, - representation_config: dict = None, - ): - super().__init__( - in_channels=latent_channels, - model_channels=model_channels, - num_blocks=num_blocks, - num_heads=num_heads, - num_head_channels=num_head_channels, - mlp_ratio=mlp_ratio, - attn_mode=attn_mode, - window_size=window_size, - pe_mode=pe_mode, - use_fp16=use_fp16, - use_checkpoint=use_checkpoint, - qk_rms_norm=qk_rms_norm, - ) - self.resolution = resolution - self.rep_config = representation_config - self._calc_layout() - self.out_layer = sp.SparseLinear(model_channels, self.out_channels) - self._build_perturbation() - - self.initialize_weights() - if use_fp16: - self.convert_to_fp16() - - def initialize_weights(self) -> None: - super().initialize_weights() - # Zero-out output layers: - nn.init.constant_(self.out_layer.weight, 0) - nn.init.constant_(self.out_layer.bias, 0) - - def _build_perturbation(self) -> None: - perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])] - perturbation = torch.tensor(perturbation).float() * 2 - 1 - perturbation = perturbation / self.rep_config['voxel_size'] - perturbation = torch.atanh(perturbation).to(self.device) - self.register_buffer('offset_perturbation', perturbation) - - def _calc_layout(self) -> None: - self.layout = { - '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, - '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3}, - '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, - '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4}, - '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']}, - } - start = 0 - for k, v in self.layout.items(): - v['range'] = (start, start + v['size']) - start += v['size'] - self.out_channels = start - - def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: - """ - Convert a batch of network outputs to 3D representations. - - Args: - x: The [N x * x C] sparse tensor output by the network. - - Returns: - list of representations - """ - ret = [] - for i in range(x.shape[0]): - representation = Gaussian( - sh_degree=0, - aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], - mininum_kernel_size = self.rep_config['3d_filter_kernel_size'], - scaling_bias = self.rep_config['scaling_bias'], - opacity_bias = self.rep_config['opacity_bias'], - scaling_activation = self.rep_config['scaling_activation'] - ) - xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution - for k, v in self.layout.items(): - if k == '_xyz': - offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']) - offset = offset * self.rep_config['lr'][k] - if self.rep_config['perturb_offset']: - offset = offset + self.offset_perturbation - offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size'] - _xyz = xyz.unsqueeze(1) + offset - setattr(representation, k, _xyz.flatten(0, 1)) - else: - feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) - feats = feats * self.rep_config['lr'][k] - setattr(representation, k, feats) - ret.append(representation) - return ret - - def forward(self, x: sp.SparseTensor) -> List[Gaussian]: - h = super().forward(x) - h = h.type(x.dtype) - h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) - h = self.out_layer(h) - return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/decoder_mesh.py b/trellis/models/structured_latent_vae/decoder_mesh.py deleted file mode 100644 index 06c0e7286af45eab1c9861b65174431fc2210bcc..0000000000000000000000000000000000000000 --- a/trellis/models/structured_latent_vae/decoder_mesh.py +++ /dev/null @@ -1,167 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 -from ...modules import sparse as sp -from .base import SparseTransformerBase -from ...representations import MeshExtractResult -from ...representations.mesh import SparseFeatures2Mesh - - -class SparseSubdivideBlock3d(nn.Module): - """ - A 3D subdivide block that can subdivide the sparse tensor. - - Args: - channels: channels in the inputs and outputs. - out_channels: if specified, the number of output channels. - num_groups: the number of groups for the group norm. - """ - def __init__( - self, - channels: int, - resolution: int, - out_channels: Optional[int] = None, - num_groups: int = 32 - ): - super().__init__() - self.channels = channels - self.resolution = resolution - self.out_resolution = resolution * 2 - self.out_channels = out_channels or channels - - self.act_layers = nn.Sequential( - sp.SparseGroupNorm32(num_groups, channels), - sp.SparseSiLU() - ) - - self.sub = sp.SparseSubdivide() - - self.out_layers = nn.Sequential( - sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), - sp.SparseGroupNorm32(num_groups, self.out_channels), - sp.SparseSiLU(), - zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") - - def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - Args: - x: an [N x C x ...] Tensor of features. - Returns: - an [N x C x ...] Tensor of outputs. - """ - h = self.act_layers(x) - h = self.sub(h) - x = self.sub(x) - h = self.out_layers(h) - h = h + self.skip_connection(x) - return h - - -class SLatMeshDecoder(SparseTransformerBase): - def __init__( - self, - resolution: int, - model_channels: int, - latent_channels: int, - num_blocks: int, - num_heads: Optional[int] = None, - num_head_channels: Optional[int] = 64, - mlp_ratio: float = 4, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", - window_size: int = 8, - pe_mode: Literal["ape", "rope"] = "ape", - use_fp16: bool = False, - use_checkpoint: bool = False, - qk_rms_norm: bool = False, - representation_config: dict = None, - ): - super().__init__( - in_channels=latent_channels, - model_channels=model_channels, - num_blocks=num_blocks, - num_heads=num_heads, - num_head_channels=num_head_channels, - mlp_ratio=mlp_ratio, - attn_mode=attn_mode, - window_size=window_size, - pe_mode=pe_mode, - use_fp16=use_fp16, - use_checkpoint=use_checkpoint, - qk_rms_norm=qk_rms_norm, - ) - self.resolution = resolution - self.rep_config = representation_config - self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) - self.out_channels = self.mesh_extractor.feats_channels - self.upsample = nn.ModuleList([ - SparseSubdivideBlock3d( - channels=model_channels, - resolution=resolution, - out_channels=model_channels // 4 - ), - SparseSubdivideBlock3d( - channels=model_channels // 4, - resolution=resolution * 2, - out_channels=model_channels // 8 - ) - ]) - self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) - - self.initialize_weights() - if use_fp16: - self.convert_to_fp16() - - def initialize_weights(self) -> None: - super().initialize_weights() - # Zero-out output layers: - nn.init.constant_(self.out_layer.weight, 0) - nn.init.constant_(self.out_layer.bias, 0) - - def convert_to_fp16(self) -> None: - """ - Convert the torso of the model to float16. - """ - super().convert_to_fp16() - self.upsample.apply(convert_module_to_f16) - - def convert_to_fp32(self) -> None: - """ - Convert the torso of the model to float32. - """ - super().convert_to_fp32() - self.upsample.apply(convert_module_to_f32) - - def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: - """ - Convert a batch of network outputs to 3D representations. - - Args: - x: The [N x * x C] sparse tensor output by the network. - - Returns: - list of representations - """ - ret = [] - for i in range(x.shape[0]): - mesh = self.mesh_extractor(x[i], training=self.training) - ret.append(mesh) - return ret - - def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]: - h = super().forward(x) - for block in self.upsample: - h = block(h) - h = h.type(x.dtype) - h = self.out_layer(h) - return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/decoder_rf.py b/trellis/models/structured_latent_vae/decoder_rf.py deleted file mode 100644 index 4e916eebafd7e97fed82cadb567244719dbbcd83..0000000000000000000000000000000000000000 --- a/trellis/models/structured_latent_vae/decoder_rf.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from ...modules import sparse as sp -from .base import SparseTransformerBase -from ...representations import Strivec - - -class SLatRadianceFieldDecoder(SparseTransformerBase): - def __init__( - self, - resolution: int, - model_channels: int, - latent_channels: int, - num_blocks: int, - num_heads: Optional[int] = None, - num_head_channels: Optional[int] = 64, - mlp_ratio: float = 4, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", - window_size: int = 8, - pe_mode: Literal["ape", "rope"] = "ape", - use_fp16: bool = False, - use_checkpoint: bool = False, - qk_rms_norm: bool = False, - representation_config: dict = None, - ): - super().__init__( - in_channels=latent_channels, - model_channels=model_channels, - num_blocks=num_blocks, - num_heads=num_heads, - num_head_channels=num_head_channels, - mlp_ratio=mlp_ratio, - attn_mode=attn_mode, - window_size=window_size, - pe_mode=pe_mode, - use_fp16=use_fp16, - use_checkpoint=use_checkpoint, - qk_rms_norm=qk_rms_norm, - ) - self.resolution = resolution - self.rep_config = representation_config - self._calc_layout() - self.out_layer = sp.SparseLinear(model_channels, self.out_channels) - - self.initialize_weights() - if use_fp16: - self.convert_to_fp16() - - def initialize_weights(self) -> None: - super().initialize_weights() - # Zero-out output layers: - nn.init.constant_(self.out_layer.weight, 0) - nn.init.constant_(self.out_layer.bias, 0) - - def _calc_layout(self) -> None: - self.layout = { - 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']}, - 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']}, - 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3}, - } - start = 0 - for k, v in self.layout.items(): - v['range'] = (start, start + v['size']) - start += v['size'] - self.out_channels = start - - def to_representation(self, x: sp.SparseTensor) -> List[Strivec]: - """ - Convert a batch of network outputs to 3D representations. - - Args: - x: The [N x * x C] sparse tensor output by the network. - - Returns: - list of representations - """ - ret = [] - for i in range(x.shape[0]): - representation = Strivec( - sh_degree=0, - resolution=self.resolution, - aabb=[-0.5, -0.5, -0.5, 1, 1, 1], - rank=self.rep_config['rank'], - dim=self.rep_config['dim'], - device='cuda', - ) - representation.density_shift = 0.0 - representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution - representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda') - for k, v in self.layout.items(): - setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])) - representation.trivec = representation.trivec + 1 - ret.append(representation) - return ret - - def forward(self, x: sp.SparseTensor) -> List[Strivec]: - h = super().forward(x) - h = h.type(x.dtype) - h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) - h = self.out_layer(h) - return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/encoder.py b/trellis/models/structured_latent_vae/encoder.py deleted file mode 100644 index d3c04928bbd0a3fe88c05c687024a92daa0a1d6d..0000000000000000000000000000000000000000 --- a/trellis/models/structured_latent_vae/encoder.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -from ...modules import sparse as sp -from .base import SparseTransformerBase - - -class SLatEncoder(SparseTransformerBase): - def __init__( - self, - resolution: int, - in_channels: int, - model_channels: int, - latent_channels: int, - num_blocks: int, - num_heads: Optional[int] = None, - num_head_channels: Optional[int] = 64, - mlp_ratio: float = 4, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", - window_size: int = 8, - pe_mode: Literal["ape", "rope"] = "ape", - use_fp16: bool = False, - use_checkpoint: bool = False, - qk_rms_norm: bool = False, - ): - super().__init__( - in_channels=in_channels, - model_channels=model_channels, - num_blocks=num_blocks, - num_heads=num_heads, - num_head_channels=num_head_channels, - mlp_ratio=mlp_ratio, - attn_mode=attn_mode, - window_size=window_size, - pe_mode=pe_mode, - use_fp16=use_fp16, - use_checkpoint=use_checkpoint, - qk_rms_norm=qk_rms_norm, - ) - self.resolution = resolution - self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) - - self.initialize_weights() - if use_fp16: - self.convert_to_fp16() - - def initialize_weights(self) -> None: - super().initialize_weights() - # Zero-out output layers: - nn.init.constant_(self.out_layer.weight, 0) - nn.init.constant_(self.out_layer.bias, 0) - - def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): - h = super().forward(x) - h = h.type(x.dtype) - h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) - h = self.out_layer(h) - - # Sample from the posterior distribution - mean, logvar = h.feats.chunk(2, dim=-1) - if sample_posterior: - std = torch.exp(0.5 * logvar) - z = mean + std * torch.randn_like(std) - else: - z = mean - z = h.replace(z) - - if return_raw: - return z, mean, logvar - else: - return z diff --git a/trellis/modules/attention/__init__.py b/trellis/modules/attention/__init__.py deleted file mode 100644 index ffebf7dbf737606b32ef62c2e86f568189d322f0..0000000000000000000000000000000000000000 --- a/trellis/modules/attention/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import * - -BACKEND = 'flash_attn' -DEBUG = False - -def __from_env(): - import os - - global BACKEND - global DEBUG - - env_attn_backend = os.environ.get('ATTN_BACKEND') - env_sttn_debug = os.environ.get('ATTN_DEBUG') - - if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: - BACKEND = env_attn_backend - if env_sttn_debug is not None: - DEBUG = env_sttn_debug == '1' - - print(f"[ATTENTION] Using backend: {BACKEND}") - - -__from_env() - - -def set_backend(backend: Literal['xformers', 'flash_attn']): - global BACKEND - BACKEND = backend - -def set_debug(debug: bool): - global DEBUG - DEBUG = debug - - -from .full_attn import * -from .modules import * diff --git a/trellis/modules/attention/full_attn.py b/trellis/modules/attention/full_attn.py deleted file mode 100644 index 68303dca94cefacb43865d3b737f2723dded20dd..0000000000000000000000000000000000000000 --- a/trellis/modules/attention/full_attn.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import * -import torch -import math -from . import DEBUG, BACKEND - -if BACKEND == 'xformers': - import xformers.ops as xops -elif BACKEND == 'flash_attn': - import flash_attn -elif BACKEND == 'sdpa': - from torch.nn.functional import scaled_dot_product_attention as sdpa -elif BACKEND == 'naive': - pass -else: - raise ValueError(f"Unknown attention backend: {BACKEND}") - - -__all__ = [ - 'scaled_dot_product_attention', -] - - -def _naive_sdpa(q, k, v): - """ - Naive implementation of scaled dot product attention. - """ - q = q.permute(0, 2, 1, 3) # [N, H, L, C] - k = k.permute(0, 2, 1, 3) # [N, H, L, C] - v = v.permute(0, 2, 1, 3) # [N, H, L, C] - scale_factor = 1 / math.sqrt(q.size(-1)) - attn_weight = q @ k.transpose(-2, -1) * scale_factor - attn_weight = torch.softmax(attn_weight, dim=-1) - out = attn_weight @ v - out = out.permute(0, 2, 1, 3) # [N, L, H, C] - return out - - -@overload -def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: - """ - Apply scaled dot product attention. - - Args: - qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. - """ - ... - -@overload -def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: - """ - Apply scaled dot product attention. - - Args: - q (torch.Tensor): A [N, L, H, C] tensor containing Qs. - kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. - """ - ... - -@overload -def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - """ - Apply scaled dot product attention. - - Args: - q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. - k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. - v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. - - Note: - k and v are assumed to have the same coordinate map. - """ - ... - -def scaled_dot_product_attention(*args, **kwargs): - arg_names_dict = { - 1: ['qkv'], - 2: ['q', 'kv'], - 3: ['q', 'k', 'v'] - } - num_all_args = len(args) + len(kwargs) - assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" - for key in arg_names_dict[num_all_args][len(args):]: - assert key in kwargs, f"Missing argument {key}" - - if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs['qkv'] - assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" - device = qkv.device - - elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs['q'] - kv = args[1] if len(args) > 1 else kwargs['kv'] - assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" - assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" - assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" - device = q.device - - elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs['q'] - k = args[1] if len(args) > 1 else kwargs['k'] - v = args[2] if len(args) > 2 else kwargs['v'] - assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" - assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" - assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" - assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" - device = q.device - - if BACKEND == 'xformers': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - out = xops.memory_efficient_attention(q, k, v) - elif BACKEND == 'flash_attn': - if num_all_args == 1: - out = flash_attn.flash_attn_qkvpacked_func(qkv) - elif num_all_args == 2: - out = flash_attn.flash_attn_kvpacked_func(q, kv) - elif num_all_args == 3: - out = flash_attn.flash_attn_func(q, k, v) - elif BACKEND == 'sdpa': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - q = q.permute(0, 2, 1, 3) # [N, H, L, C] - k = k.permute(0, 2, 1, 3) # [N, H, L, C] - v = v.permute(0, 2, 1, 3) # [N, H, L, C] - out = sdpa(q, k, v) # [N, H, L, C] - out = out.permute(0, 2, 1, 3) # [N, L, H, C] - elif BACKEND == 'naive': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - out = _naive_sdpa(q, k, v) - else: - raise ValueError(f"Unknown attention module: {BACKEND}") - - return out diff --git a/trellis/modules/attention/modules.py b/trellis/modules/attention/modules.py deleted file mode 100644 index a82a9d27b6c02c7c88c724ddc0456fe61f6c0fb0..0000000000000000000000000000000000000000 --- a/trellis/modules/attention/modules.py +++ /dev/null @@ -1,146 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -from .full_attn import scaled_dot_product_attention - - -class MultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int): - super().__init__() - self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(heads, dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) - - -class RotaryPositionEmbedder(nn.Module): - def __init__(self, hidden_size: int, in_channels: int = 3): - super().__init__() - assert hidden_size % 2 == 0, "Hidden size must be divisible by 2" - self.hidden_size = hidden_size - self.in_channels = in_channels - self.freq_dim = hidden_size // in_channels // 2 - self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim - self.freqs = 1.0 / (10000 ** self.freqs) - - def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: - self.freqs = self.freqs.to(indices.device) - phases = torch.outer(indices, self.freqs) - phases = torch.polar(torch.ones_like(phases), phases) - return phases - - def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: - x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - x_rotated = x_complex * phases - x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) - return x_embed - - def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - q (sp.SparseTensor): [..., N, D] tensor of queries - k (sp.SparseTensor): [..., N, D] tensor of keys - indices (torch.Tensor): [..., N, C] tensor of spatial positions - """ - if indices is None: - indices = torch.arange(q.shape[-2], device=q.device) - if len(q.shape) > 2: - indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) - - phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) - if phases.shape[1] < self.hidden_size // 2: - phases = torch.cat([phases, torch.polar( - torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), - torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) - )], dim=-1) - q_embed = self._rotary_embedding(q, phases) - k_embed = self._rotary_embedding(k, phases) - return q_embed, k_embed - - -class MultiHeadAttention(nn.Module): - def __init__( - self, - channels: int, - num_heads: int, - ctx_channels: Optional[int]=None, - type: Literal["self", "cross"] = "self", - attn_mode: Literal["full", "windowed"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - qkv_bias: bool = True, - use_rope: bool = False, - qk_rms_norm: bool = False, - ): - super().__init__() - assert channels % num_heads == 0 - assert type in ["self", "cross"], f"Invalid attention type: {type}" - assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" - assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" - - if attn_mode == "windowed": - raise NotImplementedError("Windowed attention is not yet implemented") - - self.channels = channels - self.head_dim = channels // num_heads - self.ctx_channels = ctx_channels if ctx_channels is not None else channels - self.num_heads = num_heads - self._type = type - self.attn_mode = attn_mode - self.window_size = window_size - self.shift_window = shift_window - self.use_rope = use_rope - self.qk_rms_norm = qk_rms_norm - - if self._type == "self": - self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) - else: - self.to_q = nn.Linear(channels, channels, bias=qkv_bias) - self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) - - if self.qk_rms_norm: - self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) - self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) - - self.to_out = nn.Linear(channels, channels) - - if use_rope: - self.rope = RotaryPositionEmbedder(channels) - - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: - B, L, C = x.shape - if self._type == "self": - qkv = self.to_qkv(x) - qkv = qkv.reshape(B, L, 3, self.num_heads, -1) - if self.use_rope: - q, k, v = qkv.unbind(dim=2) - q, k = self.rope(q, k, indices) - qkv = torch.stack([q, k, v], dim=2) - if self.attn_mode == "full": - if self.qk_rms_norm: - q, k, v = qkv.unbind(dim=2) - q = self.q_rms_norm(q) - k = self.k_rms_norm(k) - h = scaled_dot_product_attention(q, k, v) - else: - h = scaled_dot_product_attention(qkv) - elif self.attn_mode == "windowed": - raise NotImplementedError("Windowed attention is not yet implemented") - else: - Lkv = context.shape[1] - q = self.to_q(x) - kv = self.to_kv(context) - q = q.reshape(B, L, self.num_heads, -1) - kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) - if self.qk_rms_norm: - q = self.q_rms_norm(q) - k, v = kv.unbind(dim=2) - k = self.k_rms_norm(k) - h = scaled_dot_product_attention(q, k, v) - else: - h = scaled_dot_product_attention(q, kv) - h = h.reshape(B, L, -1) - h = self.to_out(h) - return h diff --git a/trellis/modules/norm.py b/trellis/modules/norm.py deleted file mode 100644 index 8f1750b88aae9e5ea901fc7af9f978a0db82de6d..0000000000000000000000000000000000000000 --- a/trellis/modules/norm.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -import torch.nn as nn - - -class LayerNorm32(nn.LayerNorm): - def forward(self, x: torch.Tensor) -> torch.Tensor: - return super().forward(x.float()).type(x.dtype) - - -class GroupNorm32(nn.GroupNorm): - """ - A GroupNorm layer that converts to float32 before the forward pass. - """ - def forward(self, x: torch.Tensor) -> torch.Tensor: - return super().forward(x.float()).type(x.dtype) - - -class ChannelLayerNorm32(LayerNorm32): - def forward(self, x: torch.Tensor) -> torch.Tensor: - DIM = x.dim() - x = x.permute(0, *range(2, DIM), 1).contiguous() - x = super().forward(x) - x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() - return x - \ No newline at end of file diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py deleted file mode 100644 index df649cc4431a0f7f1a49ed15780f4217399adf66..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/__init__.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import * - -BACKEND = 'spconv' -DEBUG = False -ATTN = 'flash_attn' - -def __from_env(): - import os - - global BACKEND - global DEBUG - global ATTN - - env_sparse_backend = os.environ.get('SPARSE_BACKEND') - env_sparse_debug = os.environ.get('SPARSE_DEBUG') - env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') - if env_sparse_attn is None: - env_sparse_attn = os.environ.get('ATTN_BACKEND') - - if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: - BACKEND = env_sparse_backend - if env_sparse_debug is not None: - DEBUG = env_sparse_debug == '1' - if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: - ATTN = env_sparse_attn - - print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") - - -__from_env() - - -def set_backend(backend: Literal['spconv', 'torchsparse']): - global BACKEND - BACKEND = backend - -def set_debug(debug: bool): - global DEBUG - DEBUG = debug - -def set_attn(attn: Literal['xformers', 'flash_attn']): - global ATTN - ATTN = attn - - -import importlib - -__attributes = { - 'SparseTensor': 'basic', - 'sparse_batch_broadcast': 'basic', - 'sparse_batch_op': 'basic', - 'sparse_cat': 'basic', - 'sparse_unbind': 'basic', - 'SparseGroupNorm': 'norm', - 'SparseLayerNorm': 'norm', - 'SparseGroupNorm32': 'norm', - 'SparseLayerNorm32': 'norm', - 'SparseReLU': 'nonlinearity', - 'SparseSiLU': 'nonlinearity', - 'SparseGELU': 'nonlinearity', - 'SparseActivation': 'nonlinearity', - 'SparseLinear': 'linear', - 'sparse_scaled_dot_product_attention': 'attention', - 'SerializeMode': 'attention', - 'sparse_serialized_scaled_dot_product_self_attention': 'attention', - 'sparse_windowed_scaled_dot_product_self_attention': 'attention', - 'SparseMultiHeadAttention': 'attention', - 'SparseConv3d': 'conv', - 'SparseInverseConv3d': 'conv', - 'SparseDownsample': 'spatial', - 'SparseUpsample': 'spatial', - 'SparseSubdivide' : 'spatial' -} - -__submodules = ['transformer'] - -__all__ = list(__attributes.keys()) + __submodules - -def __getattr__(name): - if name not in globals(): - if name in __attributes: - module_name = __attributes[name] - module = importlib.import_module(f".{module_name}", __name__) - globals()[name] = getattr(module, name) - elif name in __submodules: - module = importlib.import_module(f".{name}", __name__) - globals()[name] = module - else: - raise AttributeError(f"module {__name__} has no attribute {name}") - return globals()[name] - - -# For Pylance -if __name__ == '__main__': - from .basic import * - from .norm import * - from .nonlinearity import * - from .linear import * - from .attention import * - from .conv import * - from .spatial import * - import transformer diff --git a/trellis/modules/sparse/attention/__init__.py b/trellis/modules/sparse/attention/__init__.py deleted file mode 100644 index 400de9a25960cf8ed32a3d7ec143af95b5f862bc..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/attention/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .full_attn import * -from .serialized_attn import * -from .windowed_attn import * -from .modules import * diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py deleted file mode 100644 index c724327678d38a5625699ace09c67107458b4d0a..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/attention/full_attn.py +++ /dev/null @@ -1,215 +0,0 @@ -from typing import * -import torch -from .. import SparseTensor -from .. import DEBUG, ATTN - -if ATTN == 'xformers': - import xformers.ops as xops -elif ATTN == 'flash_attn': - import flash_attn -else: - raise ValueError(f"Unknown attention module: {ATTN}") - - -__all__ = [ - 'sparse_scaled_dot_product_attention', -] - - -@overload -def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: - """ - Apply scaled dot product attention to a sparse tensor. - - Args: - qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. - """ - ... - -@overload -def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: - """ - Apply scaled dot product attention to a sparse tensor. - - Args: - q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. - kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. - """ - ... - -@overload -def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: - """ - Apply scaled dot product attention to a sparse tensor. - - Args: - q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. - kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. - """ - ... - -@overload -def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: - """ - Apply scaled dot product attention to a sparse tensor. - - Args: - q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. - k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. - v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. - - Note: - k and v are assumed to have the same coordinate map. - """ - ... - -@overload -def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: - """ - Apply scaled dot product attention to a sparse tensor. - - Args: - q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. - k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. - v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. - """ - ... - -@overload -def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: - """ - Apply scaled dot product attention to a sparse tensor. - - Args: - q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. - k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. - v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. - """ - ... - -def sparse_scaled_dot_product_attention(*args, **kwargs): - arg_names_dict = { - 1: ['qkv'], - 2: ['q', 'kv'], - 3: ['q', 'k', 'v'] - } - num_all_args = len(args) + len(kwargs) - assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" - for key in arg_names_dict[num_all_args][len(args):]: - assert key in kwargs, f"Missing argument {key}" - - if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs['qkv'] - assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" - assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" - device = qkv.device - - s = qkv - q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] - kv_seqlen = q_seqlen - qkv = qkv.feats # [T, 3, H, C] - - elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs['q'] - kv = args[1] if len(args) > 1 else kwargs['kv'] - assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ - isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ - f"Invalid types, got {type(q)} and {type(kv)}" - assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" - device = q.device - - if isinstance(q, SparseTensor): - assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" - s = q - q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] - q = q.feats # [T_Q, H, C] - else: - assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" - s = None - N, L, H, C = q.shape - q_seqlen = [L] * N - q = q.reshape(N * L, H, C) # [T_Q, H, C] - - if isinstance(kv, SparseTensor): - assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" - kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] - kv = kv.feats # [T_KV, 2, H, C] - else: - assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" - N, L, _, H, C = kv.shape - kv_seqlen = [L] * N - kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] - - elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs['q'] - k = args[1] if len(args) > 1 else kwargs['k'] - v = args[2] if len(args) > 2 else kwargs['v'] - assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ - isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ - f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" - assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" - device = q.device - - if isinstance(q, SparseTensor): - assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" - s = q - q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] - q = q.feats # [T_Q, H, Ci] - else: - assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" - s = None - N, L, H, CI = q.shape - q_seqlen = [L] * N - q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] - - if isinstance(k, SparseTensor): - assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" - assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" - kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] - k = k.feats # [T_KV, H, Ci] - v = v.feats # [T_KV, H, Co] - else: - assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" - assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" - N, L, H, CI, CO = *k.shape, v.shape[-1] - kv_seqlen = [L] * N - k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] - v = v.reshape(N * L, H, CO) # [T_KV, H, Co] - - if DEBUG: - if s is not None: - for i in range(s.shape[0]): - assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" - if num_all_args in [2, 3]: - assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" - if num_all_args == 3: - assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" - assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" - - if ATTN == 'xformers': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=1) - elif num_all_args == 2: - k, v = kv.unbind(dim=1) - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) - out = xops.memory_efficient_attention(q, k, v, mask)[0] - elif ATTN == 'flash_attn': - cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) - if num_all_args in [2, 3]: - cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) - if num_all_args == 1: - out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) - elif num_all_args == 2: - out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) - elif num_all_args == 3: - out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) - else: - raise ValueError(f"Unknown attention module: {ATTN}") - - if s is not None: - return s.replace(out) - else: - return out.reshape(N, L, H, -1) diff --git a/trellis/modules/sparse/attention/modules.py b/trellis/modules/sparse/attention/modules.py deleted file mode 100644 index d8fbb572786483a840dc325097eacc08b815a0a5..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/attention/modules.py +++ /dev/null @@ -1,139 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -import torch.nn.functional as F -from .. import SparseTensor -from .full_attn import sparse_scaled_dot_product_attention -from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention -from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention -from ...attention import RotaryPositionEmbedder - - -class SparseMultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int): - super().__init__() - self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(heads, dim)) - - def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: - x_type = x.dtype - x = x.float() - if isinstance(x, SparseTensor): - x = x.replace(F.normalize(x.feats, dim=-1)) - else: - x = F.normalize(x, dim=-1) - return (x * self.gamma * self.scale).to(x_type) - - -class SparseMultiHeadAttention(nn.Module): - def __init__( - self, - channels: int, - num_heads: int, - ctx_channels: Optional[int] = None, - type: Literal["self", "cross"] = "self", - attn_mode: Literal["full", "serialized", "windowed"] = "full", - window_size: Optional[int] = None, - shift_sequence: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - serialize_mode: Optional[SerializeMode] = None, - qkv_bias: bool = True, - use_rope: bool = False, - qk_rms_norm: bool = False, - ): - super().__init__() - assert channels % num_heads == 0 - assert type in ["self", "cross"], f"Invalid attention type: {type}" - assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" - assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" - assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" - self.channels = channels - self.ctx_channels = ctx_channels if ctx_channels is not None else channels - self.num_heads = num_heads - self._type = type - self.attn_mode = attn_mode - self.window_size = window_size - self.shift_sequence = shift_sequence - self.shift_window = shift_window - self.serialize_mode = serialize_mode - self.use_rope = use_rope - self.qk_rms_norm = qk_rms_norm - - if self._type == "self": - self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) - else: - self.to_q = nn.Linear(channels, channels, bias=qkv_bias) - self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) - - if self.qk_rms_norm: - self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) - self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) - - self.to_out = nn.Linear(channels, channels) - - if use_rope: - self.rope = RotaryPositionEmbedder(channels) - - @staticmethod - def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: - if isinstance(x, SparseTensor): - return x.replace(module(x.feats)) - else: - return module(x) - - @staticmethod - def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: - if isinstance(x, SparseTensor): - return x.reshape(*shape) - else: - return x.reshape(*x.shape[:2], *shape) - - def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: - if isinstance(x, SparseTensor): - x_feats = x.feats.unsqueeze(0) - else: - x_feats = x - x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) - return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats - - def _rope(self, qkv: SparseTensor) -> SparseTensor: - q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] - q, k = self.rope(q, k, qkv.coords[:, 1:]) - qkv = qkv.replace(torch.stack([q, k, v], dim=1)) - return qkv - - def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: - if self._type == "self": - qkv = self._linear(self.to_qkv, x) - qkv = self._fused_pre(qkv, num_fused=3) - if self.use_rope: - qkv = self._rope(qkv) - if self.qk_rms_norm: - q, k, v = qkv.unbind(dim=1) - q = self.q_rms_norm(q) - k = self.k_rms_norm(k) - qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) - if self.attn_mode == "full": - h = sparse_scaled_dot_product_attention(qkv) - elif self.attn_mode == "serialized": - h = sparse_serialized_scaled_dot_product_self_attention( - qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window - ) - elif self.attn_mode == "windowed": - h = sparse_windowed_scaled_dot_product_self_attention( - qkv, self.window_size, shift_window=self.shift_window - ) - else: - q = self._linear(self.to_q, x) - q = self._reshape_chs(q, (self.num_heads, -1)) - kv = self._linear(self.to_kv, context) - kv = self._fused_pre(kv, num_fused=2) - if self.qk_rms_norm: - q = self.q_rms_norm(q) - k, v = kv.unbind(dim=1) - k = self.k_rms_norm(k) - kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) - h = sparse_scaled_dot_product_attention(q, kv) - h = self._reshape_chs(h, (-1,)) - h = self._linear(self.to_out, h) - return h diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py deleted file mode 100644 index b3da276c4b47db2e2816d61bfa66db413aa6b7aa..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/attention/serialized_attn.py +++ /dev/null @@ -1,193 +0,0 @@ -from typing import * -from enum import Enum -import torch -import math -from .. import SparseTensor -from .. import DEBUG, ATTN - -if ATTN == 'xformers': - import xformers.ops as xops -elif ATTN == 'flash_attn': - import flash_attn -else: - raise ValueError(f"Unknown attention module: {ATTN}") - - -__all__ = [ - 'sparse_serialized_scaled_dot_product_self_attention', -] - - -class SerializeMode(Enum): - Z_ORDER = 0 - Z_ORDER_TRANSPOSED = 1 - HILBERT = 2 - HILBERT_TRANSPOSED = 3 - - -SerializeModes = [ - SerializeMode.Z_ORDER, - SerializeMode.Z_ORDER_TRANSPOSED, - SerializeMode.HILBERT, - SerializeMode.HILBERT_TRANSPOSED -] - - -def calc_serialization( - tensor: SparseTensor, - window_size: int, - serialize_mode: SerializeMode = SerializeMode.Z_ORDER, - shift_sequence: int = 0, - shift_window: Tuple[int, int, int] = (0, 0, 0) -) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: - """ - Calculate serialization and partitioning for a set of coordinates. - - Args: - tensor (SparseTensor): The input tensor. - window_size (int): The window size to use. - serialize_mode (SerializeMode): The serialization mode to use. - shift_sequence (int): The shift of serialized sequence. - shift_window (Tuple[int, int, int]): The shift of serialized coordinates. - - Returns: - (torch.Tensor, torch.Tensor): Forwards and backwards indices. - """ - fwd_indices = [] - bwd_indices = [] - seq_lens = [] - seq_batch_indices = [] - offsets = [0] - - if 'vox2seq' not in globals(): - import vox2seq - - # Serialize the input - serialize_coords = tensor.coords[:, 1:].clone() - serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) - if serialize_mode == SerializeMode.Z_ORDER: - code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) - elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: - code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) - elif serialize_mode == SerializeMode.HILBERT: - code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) - elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: - code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) - else: - raise ValueError(f"Unknown serialize mode: {serialize_mode}") - - for bi, s in enumerate(tensor.layout): - num_points = s.stop - s.start - num_windows = (num_points + window_size - 1) // window_size - valid_window_size = num_points / num_windows - to_ordered = torch.argsort(code[s.start:s.stop]) - if num_windows == 1: - fwd_indices.append(to_ordered) - bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) - fwd_indices[-1] += s.start - bwd_indices[-1] += offsets[-1] - seq_lens.append(num_points) - seq_batch_indices.append(bi) - offsets.append(offsets[-1] + seq_lens[-1]) - else: - # Partition the input - offset = 0 - mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] - split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] - bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) - for i in range(num_windows): - mid = mids[i] - valid_start = split[i] - valid_end = split[i + 1] - padded_start = math.floor(mid - 0.5 * window_size) - padded_end = padded_start + window_size - fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) - offset += valid_start - padded_start - bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) - offset += padded_end - valid_start - fwd_indices[-1] += s.start - seq_lens.extend([window_size] * num_windows) - seq_batch_indices.extend([bi] * num_windows) - bwd_indices.append(bwd_index + offsets[-1]) - offsets.append(offsets[-1] + num_windows * window_size) - - fwd_indices = torch.cat(fwd_indices) - bwd_indices = torch.cat(bwd_indices) - - return fwd_indices, bwd_indices, seq_lens, seq_batch_indices - - -def sparse_serialized_scaled_dot_product_self_attention( - qkv: SparseTensor, - window_size: int, - serialize_mode: SerializeMode = SerializeMode.Z_ORDER, - shift_sequence: int = 0, - shift_window: Tuple[int, int, int] = (0, 0, 0) -) -> SparseTensor: - """ - Apply serialized scaled dot product self attention to a sparse tensor. - - Args: - qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. - window_size (int): The window size to use. - serialize_mode (SerializeMode): The serialization mode to use. - shift_sequence (int): The shift of serialized sequence. - shift_window (Tuple[int, int, int]): The shift of serialized coordinates. - shift (int): The shift to use. - """ - assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" - - serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' - serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) - if serialization_spatial_cache is None: - fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) - qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) - else: - fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache - - M = fwd_indices.shape[0] - T = qkv.feats.shape[0] - H = qkv.feats.shape[2] - C = qkv.feats.shape[3] - - qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] - - if DEBUG: - start = 0 - qkv_coords = qkv.coords[fwd_indices] - for i in range(len(seq_lens)): - assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" - start += seq_lens[i] - - if all([seq_len == window_size for seq_len in seq_lens]): - B = len(seq_lens) - N = window_size - qkv_feats = qkv_feats.reshape(B, N, 3, H, C) - if ATTN == 'xformers': - q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] - out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] - elif ATTN == 'flash_attn': - out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] - else: - raise ValueError(f"Unknown attention module: {ATTN}") - out = out.reshape(B * N, H, C) # [M, H, C] - else: - if ATTN == 'xformers': - q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] - q = q.unsqueeze(0) # [1, M, H, C] - k = k.unsqueeze(0) # [1, M, H, C] - v = v.unsqueeze(0) # [1, M, H, C] - mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) - out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] - elif ATTN == 'flash_attn': - cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ - .to(qkv.device).int() - out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] - - out = out[bwd_indices] # [T, H, C] - - if DEBUG: - qkv_coords = qkv_coords[bwd_indices] - assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" - - return qkv.replace(out) diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py deleted file mode 100644 index 11eebf851316d8b4c1f6ca39b881c422d8f2f088..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/attention/windowed_attn.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import * -import torch -import math -from .. import SparseTensor -from .. import DEBUG, ATTN - -if ATTN == 'xformers': - import xformers.ops as xops -elif ATTN == 'flash_attn': - import flash_attn -else: - raise ValueError(f"Unknown attention module: {ATTN}") - - -__all__ = [ - 'sparse_windowed_scaled_dot_product_self_attention', -] - - -def calc_window_partition( - tensor: SparseTensor, - window_size: Union[int, Tuple[int, ...]], - shift_window: Union[int, Tuple[int, ...]] = 0 -) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: - """ - Calculate serialization and partitioning for a set of coordinates. - - Args: - tensor (SparseTensor): The input tensor. - window_size (int): The window size to use. - shift_window (Tuple[int, ...]): The shift of serialized coordinates. - - Returns: - (torch.Tensor): Forwards indices. - (torch.Tensor): Backwards indices. - (List[int]): Sequence lengths. - (List[int]): Sequence batch indices. - """ - DIM = tensor.coords.shape[1] - 1 - shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window - window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size - shifted_coords = tensor.coords.clone().detach() - shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) - - MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() - NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] - OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] - - shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) - shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) - fwd_indices = torch.argsort(shifted_indices) - bwd_indices = torch.empty_like(fwd_indices) - bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) - seq_lens = torch.bincount(shifted_indices) - seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] - mask = seq_lens != 0 - seq_lens = seq_lens[mask].tolist() - seq_batch_indices = seq_batch_indices[mask].tolist() - - return fwd_indices, bwd_indices, seq_lens, seq_batch_indices - - -def sparse_windowed_scaled_dot_product_self_attention( - qkv: SparseTensor, - window_size: int, - shift_window: Tuple[int, int, int] = (0, 0, 0) -) -> SparseTensor: - """ - Apply windowed scaled dot product self attention to a sparse tensor. - - Args: - qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. - window_size (int): The window size to use. - shift_window (Tuple[int, int, int]): The shift of serialized coordinates. - shift (int): The shift to use. - """ - assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" - - serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' - serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) - if serialization_spatial_cache is None: - fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) - qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) - else: - fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache - - M = fwd_indices.shape[0] - T = qkv.feats.shape[0] - H = qkv.feats.shape[2] - C = qkv.feats.shape[3] - - qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] - - if DEBUG: - start = 0 - qkv_coords = qkv.coords[fwd_indices] - for i in range(len(seq_lens)): - seq_coords = qkv_coords[start:start+seq_lens[i]] - assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" - assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ - f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" - start += seq_lens[i] - - if all([seq_len == window_size for seq_len in seq_lens]): - B = len(seq_lens) - N = window_size - qkv_feats = qkv_feats.reshape(B, N, 3, H, C) - if ATTN == 'xformers': - q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] - out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] - elif ATTN == 'flash_attn': - out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] - else: - raise ValueError(f"Unknown attention module: {ATTN}") - out = out.reshape(B * N, H, C) # [M, H, C] - else: - if ATTN == 'xformers': - q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] - q = q.unsqueeze(0) # [1, M, H, C] - k = k.unsqueeze(0) # [1, M, H, C] - v = v.unsqueeze(0) # [1, M, H, C] - mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) - out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] - elif ATTN == 'flash_attn': - cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ - .to(qkv.device).int() - out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] - - out = out[bwd_indices] # [T, H, C] - - if DEBUG: - qkv_coords = qkv_coords[bwd_indices] - assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" - - return qkv.replace(out) diff --git a/trellis/modules/sparse/basic.py b/trellis/modules/sparse/basic.py deleted file mode 100644 index 0fc685128f47906e5608d38d789366bb06837513..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/basic.py +++ /dev/null @@ -1,459 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -from . import BACKEND, DEBUG -SparseTensorData = None # Lazy import - - -__all__ = [ - 'SparseTensor', - 'sparse_batch_broadcast', - 'sparse_batch_op', - 'sparse_cat', - 'sparse_unbind', -] - - -class SparseTensor: - """ - Sparse tensor with support for both torchsparse and spconv backends. - - Parameters: - - feats (torch.Tensor): Features of the sparse tensor. - - coords (torch.Tensor): Coordinates of the sparse tensor. - - shape (torch.Size): Shape of the sparse tensor. - - layout (List[slice]): Layout of the sparse tensor for each batch - - data (SparseTensorData): Sparse tensor data used for convolusion - - NOTE: - - Data corresponding to a same batch should be contiguous. - - Coords should be in [0, 1023] - """ - @overload - def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... - - @overload - def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... - - def __init__(self, *args, **kwargs): - # Lazy import of sparse tensor backend - global SparseTensorData - if SparseTensorData is None: - import importlib - if BACKEND == 'torchsparse': - SparseTensorData = importlib.import_module('torchsparse').SparseTensor - elif BACKEND == 'spconv': - SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor - - method_id = 0 - if len(args) != 0: - method_id = 0 if isinstance(args[0], torch.Tensor) else 1 - else: - method_id = 1 if 'data' in kwargs else 0 - - if method_id == 0: - feats, coords, shape, layout = args + (None,) * (4 - len(args)) - if 'feats' in kwargs: - feats = kwargs['feats'] - del kwargs['feats'] - if 'coords' in kwargs: - coords = kwargs['coords'] - del kwargs['coords'] - if 'shape' in kwargs: - shape = kwargs['shape'] - del kwargs['shape'] - if 'layout' in kwargs: - layout = kwargs['layout'] - del kwargs['layout'] - - if shape is None: - shape = self.__cal_shape(feats, coords) - if layout is None: - layout = self.__cal_layout(coords, shape[0]) - if BACKEND == 'torchsparse': - self.data = SparseTensorData(feats, coords, **kwargs) - elif BACKEND == 'spconv': - spatial_shape = list(coords.max(0)[0] + 1)[1:] - self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) - self.data._features = feats - elif method_id == 1: - data, shape, layout = args + (None,) * (3 - len(args)) - if 'data' in kwargs: - data = kwargs['data'] - del kwargs['data'] - if 'shape' in kwargs: - shape = kwargs['shape'] - del kwargs['shape'] - if 'layout' in kwargs: - layout = kwargs['layout'] - del kwargs['layout'] - - self.data = data - if shape is None: - shape = self.__cal_shape(self.feats, self.coords) - if layout is None: - layout = self.__cal_layout(self.coords, shape[0]) - - self._shape = shape - self._layout = layout - self._scale = kwargs.get('scale', (1, 1, 1)) - self._spatial_cache = kwargs.get('spatial_cache', {}) - - if DEBUG: - try: - assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" - assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" - assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" - for i in range(self.shape[0]): - assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" - except Exception as e: - print('Debugging information:') - print(f"- Shape: {self.shape}") - print(f"- Layout: {self.layout}") - print(f"- Scale: {self._scale}") - print(f"- Coords: {self.coords}") - raise e - - def __cal_shape(self, feats, coords): - shape = [] - shape.append(coords[:, 0].max().item() + 1) - shape.extend([*feats.shape[1:]]) - return torch.Size(shape) - - def __cal_layout(self, coords, batch_size): - seq_len = torch.bincount(coords[:, 0], minlength=batch_size) - offset = torch.cumsum(seq_len, dim=0) - layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] - return layout - - @property - def shape(self) -> torch.Size: - return self._shape - - def dim(self) -> int: - return len(self.shape) - - @property - def layout(self) -> List[slice]: - return self._layout - - @property - def feats(self) -> torch.Tensor: - if BACKEND == 'torchsparse': - return self.data.F - elif BACKEND == 'spconv': - return self.data.features - - @feats.setter - def feats(self, value: torch.Tensor): - if BACKEND == 'torchsparse': - self.data.F = value - elif BACKEND == 'spconv': - self.data.features = value - - @property - def coords(self) -> torch.Tensor: - if BACKEND == 'torchsparse': - return self.data.C - elif BACKEND == 'spconv': - return self.data.indices - - @coords.setter - def coords(self, value: torch.Tensor): - if BACKEND == 'torchsparse': - self.data.C = value - elif BACKEND == 'spconv': - self.data.indices = value - - @property - def dtype(self): - return self.feats.dtype - - @property - def device(self): - return self.feats.device - - @overload - def to(self, dtype: torch.dtype) -> 'SparseTensor': ... - - @overload - def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... - - def to(self, *args, **kwargs) -> 'SparseTensor': - device = None - dtype = None - if len(args) == 2: - device, dtype = args - elif len(args) == 1: - if isinstance(args[0], torch.dtype): - dtype = args[0] - else: - device = args[0] - if 'dtype' in kwargs: - assert dtype is None, "to() received multiple values for argument 'dtype'" - dtype = kwargs['dtype'] - if 'device' in kwargs: - assert device is None, "to() received multiple values for argument 'device'" - device = kwargs['device'] - - new_feats = self.feats.to(device=device, dtype=dtype) - new_coords = self.coords.to(device=device) - return self.replace(new_feats, new_coords) - - def type(self, dtype): - new_feats = self.feats.type(dtype) - return self.replace(new_feats) - - def cpu(self) -> 'SparseTensor': - new_feats = self.feats.cpu() - new_coords = self.coords.cpu() - return self.replace(new_feats, new_coords) - - def cuda(self) -> 'SparseTensor': - new_feats = self.feats.cuda() - new_coords = self.coords.cuda() - return self.replace(new_feats, new_coords) - - def half(self) -> 'SparseTensor': - new_feats = self.feats.half() - return self.replace(new_feats) - - def float(self) -> 'SparseTensor': - new_feats = self.feats.float() - return self.replace(new_feats) - - def detach(self) -> 'SparseTensor': - new_coords = self.coords.detach() - new_feats = self.feats.detach() - return self.replace(new_feats, new_coords) - - def dense(self) -> torch.Tensor: - if BACKEND == 'torchsparse': - return self.data.dense() - elif BACKEND == 'spconv': - return self.data.dense() - - def reshape(self, *shape) -> 'SparseTensor': - new_feats = self.feats.reshape(self.feats.shape[0], *shape) - return self.replace(new_feats) - - def unbind(self, dim: int) -> List['SparseTensor']: - return sparse_unbind(self, dim) - - def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': - new_shape = [self.shape[0]] - new_shape.extend(feats.shape[1:]) - if BACKEND == 'torchsparse': - new_data = SparseTensorData( - feats=feats, - coords=self.data.coords if coords is None else coords, - stride=self.data.stride, - spatial_range=self.data.spatial_range, - ) - new_data._caches = self.data._caches - elif BACKEND == 'spconv': - new_data = SparseTensorData( - self.data.features.reshape(self.data.features.shape[0], -1), - self.data.indices, - self.data.spatial_shape, - self.data.batch_size, - self.data.grid, - self.data.voxel_num, - self.data.indice_dict - ) - new_data._features = feats - new_data.benchmark = self.data.benchmark - new_data.benchmark_record = self.data.benchmark_record - new_data.thrust_allocator = self.data.thrust_allocator - new_data._timer = self.data._timer - new_data.force_algo = self.data.force_algo - new_data.int8_scale = self.data.int8_scale - if coords is not None: - new_data.indices = coords - new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) - return new_tensor - - @staticmethod - def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': - N, C = dim - x = torch.arange(aabb[0], aabb[3] + 1) - y = torch.arange(aabb[1], aabb[4] + 1) - z = torch.arange(aabb[2], aabb[5] + 1) - coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) - coords = torch.cat([ - torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), - coords.repeat(N, 1), - ], dim=1).to(dtype=torch.int32, device=device) - feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) - return SparseTensor(feats=feats, coords=coords) - - def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: - new_cache = {} - for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): - if k in self._spatial_cache: - new_cache[k] = self._spatial_cache[k] - if k in other._spatial_cache: - if k not in new_cache: - new_cache[k] = other._spatial_cache[k] - else: - new_cache[k].update(other._spatial_cache[k]) - return new_cache - - def __neg__(self) -> 'SparseTensor': - return self.replace(-self.feats) - - def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': - if isinstance(other, torch.Tensor): - try: - other = torch.broadcast_to(other, self.shape) - other = sparse_batch_broadcast(self, other) - except: - pass - if isinstance(other, SparseTensor): - other = other.feats - new_feats = op(self.feats, other) - new_tensor = self.replace(new_feats) - if isinstance(other, SparseTensor): - new_tensor._spatial_cache = self.__merge_sparse_cache(other) - return new_tensor - - def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': - return self.__elemwise__(other, torch.add) - - def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': - return self.__elemwise__(other, torch.add) - - def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': - return self.__elemwise__(other, torch.sub) - - def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': - return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) - - def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': - return self.__elemwise__(other, torch.mul) - - def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': - return self.__elemwise__(other, torch.mul) - - def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': - return self.__elemwise__(other, torch.div) - - def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': - return self.__elemwise__(other, lambda x, y: torch.div(y, x)) - - def __getitem__(self, idx): - if isinstance(idx, int): - idx = [idx] - elif isinstance(idx, slice): - idx = range(*idx.indices(self.shape[0])) - elif isinstance(idx, torch.Tensor): - if idx.dtype == torch.bool: - assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" - idx = idx.nonzero().squeeze(1) - elif idx.dtype in [torch.int32, torch.int64]: - assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" - else: - raise ValueError(f"Unknown index type: {idx.dtype}") - else: - raise ValueError(f"Unknown index type: {type(idx)}") - - coords = [] - feats = [] - for new_idx, old_idx in enumerate(idx): - coords.append(self.coords[self.layout[old_idx]].clone()) - coords[-1][:, 0] = new_idx - feats.append(self.feats[self.layout[old_idx]]) - coords = torch.cat(coords, dim=0).contiguous() - feats = torch.cat(feats, dim=0).contiguous() - return SparseTensor(feats=feats, coords=coords) - - def register_spatial_cache(self, key, value) -> None: - """ - Register a spatial cache. - The spatial cache can be any thing you want to cache. - The registery and retrieval of the cache is based on current scale. - """ - scale_key = str(self._scale) - if scale_key not in self._spatial_cache: - self._spatial_cache[scale_key] = {} - self._spatial_cache[scale_key][key] = value - - def get_spatial_cache(self, key=None): - """ - Get a spatial cache. - """ - scale_key = str(self._scale) - cur_scale_cache = self._spatial_cache.get(scale_key, {}) - if key is None: - return cur_scale_cache - return cur_scale_cache.get(key, None) - - -def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: - """ - Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. - - Args: - input (torch.Tensor): 1D tensor to broadcast. - target (SparseTensor): Sparse tensor to broadcast to. - op (callable): Operation to perform after broadcasting. Defaults to torch.add. - """ - coords, feats = input.coords, input.feats - broadcasted = torch.zeros_like(feats) - for k in range(input.shape[0]): - broadcasted[input.layout[k]] = other[k] - return broadcasted - - -def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: - """ - Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. - - Args: - input (torch.Tensor): 1D tensor to broadcast. - target (SparseTensor): Sparse tensor to broadcast to. - op (callable): Operation to perform after broadcasting. Defaults to torch.add. - """ - return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) - - -def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: - """ - Concatenate a list of sparse tensors. - - Args: - inputs (List[SparseTensor]): List of sparse tensors to concatenate. - """ - if dim == 0: - start = 0 - coords = [] - for input in inputs: - coords.append(input.coords.clone()) - coords[-1][:, 0] += start - start += input.shape[0] - coords = torch.cat(coords, dim=0) - feats = torch.cat([input.feats for input in inputs], dim=0) - output = SparseTensor( - coords=coords, - feats=feats, - ) - else: - feats = torch.cat([input.feats for input in inputs], dim=dim) - output = inputs[0].replace(feats) - - return output - - -def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: - """ - Unbind a sparse tensor along a dimension. - - Args: - input (SparseTensor): Sparse tensor to unbind. - dim (int): Dimension to unbind. - """ - if dim == 0: - return [input[i] for i in range(input.shape[0])] - else: - feats = input.feats.unbind(dim) - return [input.replace(f) for f in feats] diff --git a/trellis/modules/sparse/conv/__init__.py b/trellis/modules/sparse/conv/__init__.py deleted file mode 100644 index 3fdaddf7e07d42e296d056df28e56a544d2db5f2..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/conv/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from .. import BACKEND - - -SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' - -def __from_env(): - import os - - global SPCONV_ALGO - env_spconv_algo = os.environ.get('SPCONV_ALGO') - if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']: - SPCONV_ALGO = env_spconv_algo - print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") - - -__from_env() - -if BACKEND == 'torchsparse': - from .conv_torchsparse import * -elif BACKEND == 'spconv': - from .conv_spconv import * diff --git a/trellis/modules/sparse/conv/conv_spconv.py b/trellis/modules/sparse/conv/conv_spconv.py deleted file mode 100644 index 856405dea4b24e5800cc056106bb34bb40f6eef0..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/conv/conv_spconv.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -import torch.nn as nn -from .. import SparseTensor -from .. import DEBUG -from . import SPCONV_ALGO - -class SparseConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): - super(SparseConv3d, self).__init__() - if 'spconv' not in globals(): - import spconv.pytorch as spconv - algo = None - if SPCONV_ALGO == 'native': - algo = spconv.ConvAlgo.Native - elif SPCONV_ALGO == 'implicit_gemm': - algo = spconv.ConvAlgo.MaskImplicitGemm - if stride == 1 and (padding is None): - self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) - else: - self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) - self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) - self.padding = padding - - def forward(self, x: SparseTensor) -> SparseTensor: - spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) - new_data = self.conv(x.data) - new_shape = [x.shape[0], self.conv.out_channels] - new_layout = None if spatial_changed else x.layout - - if spatial_changed and (x.shape[0] != 1): - # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords - fwd = new_data.indices[:, 0].argsort() - bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) - sorted_feats = new_data.features[fwd] - sorted_coords = new_data.indices[fwd] - unsorted_data = new_data - new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore - - out = SparseTensor( - new_data, shape=torch.Size(new_shape), layout=new_layout, - scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), - spatial_cache=x._spatial_cache, - ) - - if spatial_changed and (x.shape[0] != 1): - out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) - out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) - - return out - - -class SparseInverseConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): - super(SparseInverseConv3d, self).__init__() - if 'spconv' not in globals(): - import spconv.pytorch as spconv - self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) - self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) - - def forward(self, x: SparseTensor) -> SparseTensor: - spatial_changed = any(s != 1 for s in self.stride) - if spatial_changed: - # recover the original spconv order - data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') - bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') - data = data.replace_feature(x.feats[bwd]) - if DEBUG: - assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' - else: - data = x.data - - new_data = self.conv(data) - new_shape = [x.shape[0], self.conv.out_channels] - new_layout = None if spatial_changed else x.layout - out = SparseTensor( - new_data, shape=torch.Size(new_shape), layout=new_layout, - scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), - spatial_cache=x._spatial_cache, - ) - return out diff --git a/trellis/modules/sparse/conv/conv_torchsparse.py b/trellis/modules/sparse/conv/conv_torchsparse.py deleted file mode 100644 index a10bd9105581a96117abcbb7349ea5975e4304ba..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/conv/conv_torchsparse.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch -import torch.nn as nn -from .. import SparseTensor - - -class SparseConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): - super(SparseConv3d, self).__init__() - if 'torchsparse' not in globals(): - import torchsparse - self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) - - def forward(self, x: SparseTensor) -> SparseTensor: - out = self.conv(x.data) - new_shape = [x.shape[0], self.conv.out_channels] - out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) - out._spatial_cache = x._spatial_cache - out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) - return out - - -class SparseInverseConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): - super(SparseInverseConv3d, self).__init__() - if 'torchsparse' not in globals(): - import torchsparse - self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) - - def forward(self, x: SparseTensor) -> SparseTensor: - out = self.conv(x.data) - new_shape = [x.shape[0], self.conv.out_channels] - out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) - out._spatial_cache = x._spatial_cache - out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) - return out - - - diff --git a/trellis/modules/sparse/linear.py b/trellis/modules/sparse/linear.py deleted file mode 100644 index 0c25a32a4ab08665a0d3f6bc44e61ef1c1cb2861..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/linear.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -import torch.nn as nn -from . import SparseTensor - -__all__ = [ - 'SparseLinear' -] - - -class SparseLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=True): - super(SparseLinear, self).__init__(in_features, out_features, bias) - - def forward(self, input: SparseTensor) -> SparseTensor: - return input.replace(super().forward(input.feats)) diff --git a/trellis/modules/sparse/nonlinearity.py b/trellis/modules/sparse/nonlinearity.py deleted file mode 100644 index db81ee886f5d047a6bc98bb8f4d4ce867d2a302d..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/nonlinearity.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -import torch.nn as nn -from . import SparseTensor - -__all__ = [ - 'SparseReLU', - 'SparseSiLU', - 'SparseGELU', - 'SparseActivation' -] - - -class SparseReLU(nn.ReLU): - def forward(self, input: SparseTensor) -> SparseTensor: - return input.replace(super().forward(input.feats)) - - -class SparseSiLU(nn.SiLU): - def forward(self, input: SparseTensor) -> SparseTensor: - return input.replace(super().forward(input.feats)) - - -class SparseGELU(nn.GELU): - def forward(self, input: SparseTensor) -> SparseTensor: - return input.replace(super().forward(input.feats)) - - -class SparseActivation(nn.Module): - def __init__(self, activation: nn.Module): - super().__init__() - self.activation = activation - - def forward(self, input: SparseTensor) -> SparseTensor: - return input.replace(self.activation(input.feats)) - diff --git a/trellis/modules/sparse/norm.py b/trellis/modules/sparse/norm.py deleted file mode 100644 index 7c132d8389277bc9a60937633b58256784597c4d..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/norm.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import torch.nn as nn -from . import SparseTensor -from . import DEBUG - -__all__ = [ - 'SparseGroupNorm', - 'SparseLayerNorm', - 'SparseGroupNorm32', - 'SparseLayerNorm32', -] - - -class SparseGroupNorm(nn.GroupNorm): - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): - super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) - - def forward(self, input: SparseTensor) -> SparseTensor: - nfeats = torch.zeros_like(input.feats) - for k in range(input.shape[0]): - if DEBUG: - assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" - bfeats = input.feats[input.layout[k]] - bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) - bfeats = super().forward(bfeats) - bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) - nfeats[input.layout[k]] = bfeats - return input.replace(nfeats) - - -class SparseLayerNorm(nn.LayerNorm): - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): - super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) - - def forward(self, input: SparseTensor) -> SparseTensor: - nfeats = torch.zeros_like(input.feats) - for k in range(input.shape[0]): - bfeats = input.feats[input.layout[k]] - bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) - bfeats = super().forward(bfeats) - bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) - nfeats[input.layout[k]] = bfeats - return input.replace(nfeats) - - -class SparseGroupNorm32(SparseGroupNorm): - """ - A GroupNorm layer that converts to float32 before the forward pass. - """ - def forward(self, x: SparseTensor) -> SparseTensor: - return super().forward(x.float()).type(x.dtype) - -class SparseLayerNorm32(SparseLayerNorm): - """ - A LayerNorm layer that converts to float32 before the forward pass. - """ - def forward(self, x: SparseTensor) -> SparseTensor: - return super().forward(x.float()).type(x.dtype) diff --git a/trellis/modules/sparse/spatial.py b/trellis/modules/sparse/spatial.py deleted file mode 100644 index 7a4713e82c42c6a1c7e72a00e3ab17e50f6f32a2..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/spatial.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -from . import SparseTensor - -__all__ = [ - 'SparseDownsample', - 'SparseUpsample', - 'SparseSubdivide' -] - - -class SparseDownsample(nn.Module): - """ - Downsample a sparse tensor by a factor of `factor`. - Implemented as average pooling. - """ - def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): - super(SparseDownsample, self).__init__() - self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor - - def forward(self, input: SparseTensor) -> SparseTensor: - DIM = input.coords.shape[-1] - 1 - factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM - assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' - - coord = list(input.coords.unbind(dim=-1)) - for i, f in enumerate(factor): - coord[i+1] = coord[i+1] // f - - MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] - OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] - code = sum([c * o for c, o in zip(coord, OFFSET)]) - code, idx = code.unique(return_inverse=True) - - new_feats = torch.scatter_reduce( - torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), - dim=0, - index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), - src=input.feats, - reduce='mean' - ) - new_coords = torch.stack( - [code // OFFSET[0]] + - [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], - dim=-1 - ) - out = SparseTensor(new_feats, new_coords, input.shape,) - out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) - out._spatial_cache = input._spatial_cache - - out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) - out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) - out.register_spatial_cache(f'upsample_{factor}_idx', idx) - - return out - - -class SparseUpsample(nn.Module): - """ - Upsample a sparse tensor by a factor of `factor`. - Implemented as nearest neighbor interpolation. - """ - def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): - super(SparseUpsample, self).__init__() - self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor - - def forward(self, input: SparseTensor) -> SparseTensor: - DIM = input.coords.shape[-1] - 1 - factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM - assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' - - new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') - new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') - idx = input.get_spatial_cache(f'upsample_{factor}_idx') - if any([x is None for x in [new_coords, new_layout, idx]]): - raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') - new_feats = input.feats[idx] - out = SparseTensor(new_feats, new_coords, input.shape, new_layout) - out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) - out._spatial_cache = input._spatial_cache - return out - -class SparseSubdivide(nn.Module): - """ - Upsample a sparse tensor by a factor of `factor`. - Implemented as nearest neighbor interpolation. - """ - def __init__(self): - super(SparseSubdivide, self).__init__() - - def forward(self, input: SparseTensor) -> SparseTensor: - DIM = input.coords.shape[-1] - 1 - # upsample scale=2^DIM - n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) - n_coords = torch.nonzero(n_cube) - n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) - factor = n_coords.shape[0] - assert factor == 2 ** DIM - # print(n_coords.shape) - new_coords = input.coords.clone() - new_coords[:, 1:] *= 2 - new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) - - new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) - out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) - out._scale = input._scale * 2 - out._spatial_cache = input._spatial_cache - return out - diff --git a/trellis/modules/sparse/transformer/__init__.py b/trellis/modules/sparse/transformer/__init__.py deleted file mode 100644 index 67336cacd084ef5e779bf5a601d66720ea275fe6..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/transformer/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .blocks import * -from .modulated import * \ No newline at end of file diff --git a/trellis/modules/sparse/transformer/blocks.py b/trellis/modules/sparse/transformer/blocks.py deleted file mode 100644 index 03b7283e3d2e6c8cdfd7fba82548a73ec7dd3130..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/transformer/blocks.py +++ /dev/null @@ -1,151 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -from ..basic import SparseTensor -from ..linear import SparseLinear -from ..nonlinearity import SparseGELU -from ..attention import SparseMultiHeadAttention, SerializeMode -from ...norm import LayerNorm32 - - -class SparseFeedForwardNet(nn.Module): - def __init__(self, channels: int, mlp_ratio: float = 4.0): - super().__init__() - self.mlp = nn.Sequential( - SparseLinear(channels, int(channels * mlp_ratio)), - SparseGELU(approximate="tanh"), - SparseLinear(int(channels * mlp_ratio), channels), - ) - - def forward(self, x: SparseTensor) -> SparseTensor: - return self.mlp(x) - - -class SparseTransformerBlock(nn.Module): - """ - Sparse Transformer block (MSA + FFN). - """ - def __init__( - self, - channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", - window_size: Optional[int] = None, - shift_sequence: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - serialize_mode: Optional[SerializeMode] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - qk_rms_norm: bool = False, - qkv_bias: bool = True, - ln_affine: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.attn = SparseMultiHeadAttention( - channels, - num_heads=num_heads, - attn_mode=attn_mode, - window_size=window_size, - shift_sequence=shift_sequence, - shift_window=shift_window, - serialize_mode=serialize_mode, - qkv_bias=qkv_bias, - use_rope=use_rope, - qk_rms_norm=qk_rms_norm, - ) - self.mlp = SparseFeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - - def _forward(self, x: SparseTensor) -> SparseTensor: - h = x.replace(self.norm1(x.feats)) - h = self.attn(h) - x = x + h - h = x.replace(self.norm2(x.feats)) - h = self.mlp(h) - x = x + h - return x - - def forward(self, x: SparseTensor) -> SparseTensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) - else: - return self._forward(x) - - -class SparseTransformerCrossBlock(nn.Module): - """ - Sparse Transformer cross-attention block (MSA + MCA + FFN). - """ - def __init__( - self, - channels: int, - ctx_channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", - window_size: Optional[int] = None, - shift_sequence: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - serialize_mode: Optional[SerializeMode] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - qk_rms_norm: bool = False, - qk_rms_norm_cross: bool = False, - qkv_bias: bool = True, - ln_affine: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.self_attn = SparseMultiHeadAttention( - channels, - num_heads=num_heads, - type="self", - attn_mode=attn_mode, - window_size=window_size, - shift_sequence=shift_sequence, - shift_window=shift_window, - serialize_mode=serialize_mode, - qkv_bias=qkv_bias, - use_rope=use_rope, - qk_rms_norm=qk_rms_norm, - ) - self.cross_attn = SparseMultiHeadAttention( - channels, - ctx_channels=ctx_channels, - num_heads=num_heads, - type="cross", - attn_mode="full", - qkv_bias=qkv_bias, - qk_rms_norm=qk_rms_norm_cross, - ) - self.mlp = SparseFeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - - def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): - h = x.replace(self.norm1(x.feats)) - h = self.self_attn(h) - x = x + h - h = x.replace(self.norm2(x.feats)) - h = self.cross_attn(h, context) - x = x + h - h = x.replace(self.norm3(x.feats)) - h = self.mlp(h) - x = x + h - return x - - def forward(self, x: SparseTensor, context: torch.Tensor): - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) - else: - return self._forward(x, context) diff --git a/trellis/modules/sparse/transformer/modulated.py b/trellis/modules/sparse/transformer/modulated.py deleted file mode 100644 index 6ec00bbb567dd432130490d799ac6cf480107593..0000000000000000000000000000000000000000 --- a/trellis/modules/sparse/transformer/modulated.py +++ /dev/null @@ -1,166 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -from ..basic import SparseTensor -from ..attention import SparseMultiHeadAttention, SerializeMode -from ...norm import LayerNorm32 -from .blocks import SparseFeedForwardNet - - -class ModulatedSparseTransformerBlock(nn.Module): - """ - Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. - """ - def __init__( - self, - channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", - window_size: Optional[int] = None, - shift_sequence: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - serialize_mode: Optional[SerializeMode] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - qk_rms_norm: bool = False, - qkv_bias: bool = True, - share_mod: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.attn = SparseMultiHeadAttention( - channels, - num_heads=num_heads, - attn_mode=attn_mode, - window_size=window_size, - shift_sequence=shift_sequence, - shift_window=shift_window, - serialize_mode=serialize_mode, - qkv_bias=qkv_bias, - use_rope=use_rope, - qk_rms_norm=qk_rms_norm, - ) - self.mlp = SparseFeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - if not share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) - ) - - def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) - else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = x.replace(self.norm1(x.feats)) - h = h * (1 + scale_msa) + shift_msa - h = self.attn(h) - h = h * gate_msa - x = x + h - h = x.replace(self.norm2(x.feats)) - h = h * (1 + scale_mlp) + shift_mlp - h = self.mlp(h) - h = h * gate_mlp - x = x + h - return x - - def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) - else: - return self._forward(x, mod) - - -class ModulatedSparseTransformerCrossBlock(nn.Module): - """ - Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. - """ - def __init__( - self, - channels: int, - ctx_channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", - window_size: Optional[int] = None, - shift_sequence: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - serialize_mode: Optional[SerializeMode] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - qk_rms_norm: bool = False, - qk_rms_norm_cross: bool = False, - qkv_bias: bool = True, - share_mod: bool = False, - - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.self_attn = SparseMultiHeadAttention( - channels, - num_heads=num_heads, - type="self", - attn_mode=attn_mode, - window_size=window_size, - shift_sequence=shift_sequence, - shift_window=shift_window, - serialize_mode=serialize_mode, - qkv_bias=qkv_bias, - use_rope=use_rope, - qk_rms_norm=qk_rms_norm, - ) - self.cross_attn = SparseMultiHeadAttention( - channels, - ctx_channels=ctx_channels, - num_heads=num_heads, - type="cross", - attn_mode="full", - qkv_bias=qkv_bias, - qk_rms_norm=qk_rms_norm_cross, - ) - self.mlp = SparseFeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - if not share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) - ) - - def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: - if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) - else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = x.replace(self.norm1(x.feats)) - h = h * (1 + scale_msa) + shift_msa - h = self.self_attn(h) - h = h * gate_msa - x = x + h - h = x.replace(self.norm2(x.feats)) - h = self.cross_attn(h, context) - x = x + h - h = x.replace(self.norm3(x.feats)) - h = h * (1 + scale_mlp) + shift_mlp - h = self.mlp(h) - h = h * gate_mlp - x = x + h - return x - - def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) - else: - return self._forward(x, mod, context) diff --git a/trellis/modules/spatial.py b/trellis/modules/spatial.py deleted file mode 100644 index 5e3b750c1da9462818ad5e25cc50e59a7d92f786..0000000000000000000000000000000000000000 --- a/trellis/modules/spatial.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch - - -def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: - """ - 3D pixel shuffle. - """ - B, C, H, W, D = x.shape - C_ = C // scale_factor**3 - x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) - x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) - x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) - return x - - -def patchify(x: torch.Tensor, patch_size: int): - """ - Patchify a tensor. - - Args: - x (torch.Tensor): (N, C, *spatial) tensor - patch_size (int): Patch size - """ - DIM = x.dim() - 2 - for d in range(2, DIM + 2): - assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" - - x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) - x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) - x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) - return x - - -def unpatchify(x: torch.Tensor, patch_size: int): - """ - Unpatchify a tensor. - - Args: - x (torch.Tensor): (N, C, *spatial) tensor - patch_size (int): Patch size - """ - DIM = x.dim() - 2 - assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" - - x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) - x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) - x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) - return x diff --git a/trellis/modules/transformer/__init__.py b/trellis/modules/transformer/__init__.py deleted file mode 100644 index 67336cacd084ef5e779bf5a601d66720ea275fe6..0000000000000000000000000000000000000000 --- a/trellis/modules/transformer/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .blocks import * -from .modulated import * \ No newline at end of file diff --git a/trellis/modules/transformer/blocks.py b/trellis/modules/transformer/blocks.py deleted file mode 100644 index ab65a2a20172b00e1f35e8e2db45701f393ff82d..0000000000000000000000000000000000000000 --- a/trellis/modules/transformer/blocks.py +++ /dev/null @@ -1,182 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -from ..attention import MultiHeadAttention -from ..norm import LayerNorm32 - - -class AbsolutePositionEmbedder(nn.Module): - """ - Embeds spatial positions into vector representations. - """ - def __init__(self, channels: int, in_channels: int = 3): - super().__init__() - self.channels = channels - self.in_channels = in_channels - self.freq_dim = channels // in_channels // 2 - self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim - self.freqs = 1.0 / (10000 ** self.freqs) - - def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: - """ - Create sinusoidal position embeddings. - - Args: - x: a 1-D Tensor of N indices - - Returns: - an (N, D) Tensor of positional embeddings. - """ - self.freqs = self.freqs.to(x.device) - out = torch.outer(x, self.freqs) - out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) - return out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x (torch.Tensor): (N, D) tensor of spatial positions - """ - N, D = x.shape - assert D == self.in_channels, "Input dimension must match number of input channels" - embed = self._sin_cos_embedding(x.reshape(-1)) - embed = embed.reshape(N, -1) - if embed.shape[1] < self.channels: - embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) - return embed - - -class FeedForwardNet(nn.Module): - def __init__(self, channels: int, mlp_ratio: float = 4.0): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(channels, int(channels * mlp_ratio)), - nn.GELU(approximate="tanh"), - nn.Linear(int(channels * mlp_ratio), channels), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.mlp(x) - - -class TransformerBlock(nn.Module): - """ - Transformer block (MSA + FFN). - """ - def __init__( - self, - channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "windowed"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[int] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - qk_rms_norm: bool = False, - qkv_bias: bool = True, - ln_affine: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.attn = MultiHeadAttention( - channels, - num_heads=num_heads, - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, - qkv_bias=qkv_bias, - use_rope=use_rope, - qk_rms_norm=qk_rms_norm, - ) - self.mlp = FeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - - def _forward(self, x: torch.Tensor) -> torch.Tensor: - h = self.norm1(x) - h = self.attn(h) - x = x + h - h = self.norm2(x) - h = self.mlp(h) - x = x + h - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) - else: - return self._forward(x) - - -class TransformerCrossBlock(nn.Module): - """ - Transformer cross-attention block (MSA + MCA + FFN). - """ - def __init__( - self, - channels: int, - ctx_channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "windowed"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - qk_rms_norm: bool = False, - qk_rms_norm_cross: bool = False, - qkv_bias: bool = True, - ln_affine: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) - self.self_attn = MultiHeadAttention( - channels, - num_heads=num_heads, - type="self", - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, - qkv_bias=qkv_bias, - use_rope=use_rope, - qk_rms_norm=qk_rms_norm, - ) - self.cross_attn = MultiHeadAttention( - channels, - ctx_channels=ctx_channels, - num_heads=num_heads, - type="cross", - attn_mode="full", - qkv_bias=qkv_bias, - qk_rms_norm=qk_rms_norm_cross, - ) - self.mlp = FeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - - def _forward(self, x: torch.Tensor, context: torch.Tensor): - h = self.norm1(x) - h = self.self_attn(h) - x = x + h - h = self.norm2(x) - h = self.cross_attn(h, context) - x = x + h - h = self.norm3(x) - h = self.mlp(h) - x = x + h - return x - - def forward(self, x: torch.Tensor, context: torch.Tensor): - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) - else: - return self._forward(x, context) - \ No newline at end of file diff --git a/trellis/modules/transformer/modulated.py b/trellis/modules/transformer/modulated.py deleted file mode 100644 index a8b90190b6dabc04b38499a6033483334fcfa69a..0000000000000000000000000000000000000000 --- a/trellis/modules/transformer/modulated.py +++ /dev/null @@ -1,157 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -from ..attention import MultiHeadAttention -from ..norm import LayerNorm32 -from .blocks import FeedForwardNet - - -class ModulatedTransformerBlock(nn.Module): - """ - Transformer block (MSA + FFN) with adaptive layer norm conditioning. - """ - def __init__( - self, - channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "windowed"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - qk_rms_norm: bool = False, - qkv_bias: bool = True, - share_mod: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.attn = MultiHeadAttention( - channels, - num_heads=num_heads, - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, - qkv_bias=qkv_bias, - use_rope=use_rope, - qk_rms_norm=qk_rms_norm, - ) - self.mlp = FeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - if not share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) - ) - - def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: - if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) - else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = self.norm1(x) - h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) - h = self.attn(h) - h = h * gate_msa.unsqueeze(1) - x = x + h - h = self.norm2(x) - h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) - h = self.mlp(h) - h = h * gate_mlp.unsqueeze(1) - x = x + h - return x - - def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) - else: - return self._forward(x, mod) - - -class ModulatedTransformerCrossBlock(nn.Module): - """ - Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. - """ - def __init__( - self, - channels: int, - ctx_channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "windowed"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - qk_rms_norm: bool = False, - qk_rms_norm_cross: bool = False, - qkv_bias: bool = True, - share_mod: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.self_attn = MultiHeadAttention( - channels, - num_heads=num_heads, - type="self", - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, - qkv_bias=qkv_bias, - use_rope=use_rope, - qk_rms_norm=qk_rms_norm, - ) - self.cross_attn = MultiHeadAttention( - channels, - ctx_channels=ctx_channels, - num_heads=num_heads, - type="cross", - attn_mode="full", - qkv_bias=qkv_bias, - qk_rms_norm=qk_rms_norm_cross, - ) - self.mlp = FeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - if not share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) - ) - - def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): - if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) - else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = self.norm1(x) - h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) - h = self.self_attn(h) - h = h * gate_msa.unsqueeze(1) - x = x + h - h = self.norm2(x) - h = self.cross_attn(h, context) - x = x + h - h = self.norm3(x) - h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) - h = self.mlp(h) - h = h * gate_mlp.unsqueeze(1) - x = x + h - return x - - def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) - else: - return self._forward(x, mod, context) - \ No newline at end of file diff --git a/trellis/modules/utils.py b/trellis/modules/utils.py deleted file mode 100644 index 215fe98059ab49eca42703a4f1f92d80c5343f6b..0000000000000000000000000000000000000000 --- a/trellis/modules/utils.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch.nn as nn -from ..modules import sparse as sp - -FP16_MODULES = ( - nn.Conv1d, - nn.Conv2d, - nn.Conv3d, - nn.ConvTranspose1d, - nn.ConvTranspose2d, - nn.ConvTranspose3d, - nn.Linear, - sp.SparseConv3d, - sp.SparseInverseConv3d, - sp.SparseLinear, -) - -def convert_module_to_f16(l): - """ - Convert primitive modules to float16. - """ - if isinstance(l, FP16_MODULES): - for p in l.parameters(): - p.data = p.data.half() - - -def convert_module_to_f32(l): - """ - Convert primitive modules to float32, undoing convert_module_to_f16(). - """ - if isinstance(l, FP16_MODULES): - for p in l.parameters(): - p.data = p.data.float() - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def modulate(x, shift, scale): - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/trellis/pipelines/__init__.py b/trellis/pipelines/__init__.py deleted file mode 100644 index c01e15250c6b976fb61567e887c6f33bb849fe50..0000000000000000000000000000000000000000 --- a/trellis/pipelines/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from . import samplers -from .trellis_image_to_3d import TrellisImageTo3DPipeline - - -def from_pretrained(path: str): - """ - Load a pipeline from a model folder or a Hugging Face model hub. - - Args: - path: The path to the model. Can be either local path or a Hugging Face model name. - """ - import os - import json - is_local = os.path.exists(f"{path}/pipeline.json") - - if is_local: - config_file = f"{path}/pipeline.json" - else: - from huggingface_hub import hf_hub_download - config_file = hf_hub_download(path, "pipeline.json") - - with open(config_file, 'r') as f: - config = json.load(f) - return globals()[config['name']].from_pretrained(path) diff --git a/trellis/pipelines/base.py b/trellis/pipelines/base.py deleted file mode 100644 index 041ddce12a483557a25355bd6f2ccfe5bbfb5a17..0000000000000000000000000000000000000000 --- a/trellis/pipelines/base.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import * -import torch -import torch.nn as nn -from .. import models - - -class Pipeline: - """ - A base class for pipelines. - """ - def __init__( - self, - models: dict[str, nn.Module] = None, - ): - if models is None: - return - self.models = models - for model in self.models.values(): - model.eval() - - @staticmethod - def from_pretrained(path: str) -> "Pipeline": - """ - Load a pretrained model. - """ - import os - import json - is_local = os.path.exists(f"{path}/pipeline.json") - - if is_local: - config_file = f"{path}/pipeline.json" - else: - from huggingface_hub import hf_hub_download - config_file = hf_hub_download(path, "pipeline.json") - - with open(config_file, 'r') as f: - args = json.load(f)['args'] - - _models = { - k: models.from_pretrained(f"{path}/{v}") - for k, v in args['models'].items() - } - - new_pipeline = Pipeline(_models) - new_pipeline._pretrained_args = args - return new_pipeline - - @property - def device(self) -> torch.device: - for model in self.models.values(): - if hasattr(model, 'device'): - return model.device - for model in self.models.values(): - if hasattr(model, 'parameters'): - return next(model.parameters()).device - raise RuntimeError("No device found.") - - def to(self, device: torch.device) -> None: - for model in self.models.values(): - model.to(device) - - def cuda(self) -> None: - self.to(torch.device("cuda")) - - def cpu(self) -> None: - self.to(torch.device("cpu")) diff --git a/trellis/pipelines/samplers/__init__.py b/trellis/pipelines/samplers/__init__.py deleted file mode 100644 index 4b1111715e7d39c8b5db1b70bbc1360d9bb6e0c6..0000000000000000000000000000000000000000 --- a/trellis/pipelines/samplers/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base import Sampler -from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler \ No newline at end of file diff --git a/trellis/pipelines/samplers/base.py b/trellis/pipelines/samplers/base.py deleted file mode 100644 index bb70700117317477e738845e566b9ea87a768d0a..0000000000000000000000000000000000000000 --- a/trellis/pipelines/samplers/base.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import * -from abc import ABC, abstractmethod - - -class Sampler(ABC): - """ - A base class for samplers. - """ - - @abstractmethod - def sample( - self, - model, - **kwargs - ): - """ - Sample from a model. - """ - pass - \ No newline at end of file diff --git a/trellis/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py deleted file mode 100644 index 076e1e3d9f6d1e7207c9659db530990894d614f9..0000000000000000000000000000000000000000 --- a/trellis/pipelines/samplers/classifier_free_guidance_mixin.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import * - - -class ClassifierFreeGuidanceSamplerMixin: - """ - A mixin class for samplers that apply classifier-free guidance. - """ - - def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs): - pred = super()._inference_model(model, x_t, t, cond, **kwargs) - neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) - return (1 + cfg_strength) * pred - cfg_strength * neg_pred diff --git a/trellis/pipelines/samplers/flow_euler.py b/trellis/pipelines/samplers/flow_euler.py deleted file mode 100644 index a84c9d472b0c74b807663083b8aea63ff1eb3c7e..0000000000000000000000000000000000000000 --- a/trellis/pipelines/samplers/flow_euler.py +++ /dev/null @@ -1,199 +0,0 @@ -from typing import * -import torch -import numpy as np -from tqdm import tqdm -from easydict import EasyDict as edict -from .base import Sampler -from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin -from .guidance_interval_mixin import GuidanceIntervalSamplerMixin - - -class FlowEulerSampler(Sampler): - """ - Generate samples from a flow-matching model using Euler sampling. - - Args: - sigma_min: The minimum scale of noise in flow. - """ - def __init__( - self, - sigma_min: float, - ): - self.sigma_min = sigma_min - - def _eps_to_xstart(self, x_t, t, eps): - assert x_t.shape == eps.shape - return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) - - def _xstart_to_eps(self, x_t, t, x_0): - assert x_t.shape == x_0.shape - return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) - - def _v_to_xstart_eps(self, x_t, t, v): - assert x_t.shape == v.shape - eps = (1 - t) * v + x_t - x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v - return x_0, eps - - def _inference_model(self, model, x_t, t, cond=None, **kwargs): - t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) - return model(x_t, t, cond, **kwargs) - - def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): - pred_v = self._inference_model(model, x_t, t, cond, **kwargs) - pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) - return pred_x_0, pred_eps, pred_v - - @torch.no_grad() - def sample_once( - self, - model, - x_t, - t: float, - t_prev: float, - cond: Optional[Any] = None, - **kwargs - ): - """ - Sample x_{t-1} from the model using Euler method. - - Args: - model: The model to sample from. - x_t: The [N x C x ...] tensor of noisy inputs at time t. - t: The current timestep. - t_prev: The previous timestep. - cond: conditional information. - **kwargs: Additional arguments for model inference. - - Returns: - a dict containing the following - - 'pred_x_prev': x_{t-1}. - - 'pred_x_0': a prediction of x_0. - """ - pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) - pred_x_prev = x_t - (t - t_prev) * pred_v - return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) - - @torch.no_grad() - def sample( - self, - model, - noise, - cond: Optional[Any] = None, - steps: int = 50, - rescale_t: float = 1.0, - verbose: bool = True, - **kwargs - ): - """ - Generate samples from the model using Euler method. - - Args: - model: The model to sample from. - noise: The initial noise tensor. - cond: conditional information. - steps: The number of steps to sample. - rescale_t: The rescale factor for t. - verbose: If True, show a progress bar. - **kwargs: Additional arguments for model_inference. - - Returns: - a dict containing the following - - 'samples': the model samples. - - 'pred_x_t': a list of prediction of x_t. - - 'pred_x_0': a list of prediction of x_0. - """ - sample = noise - t_seq = np.linspace(1, 0, steps + 1) - t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) - t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) - ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) - for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): - out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) - sample = out.pred_x_prev - ret.pred_x_t.append(out.pred_x_prev) - ret.pred_x_0.append(out.pred_x_0) - ret.samples = sample - return ret - - -class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): - """ - Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. - """ - @torch.no_grad() - def sample( - self, - model, - noise, - cond, - neg_cond, - steps: int = 50, - rescale_t: float = 1.0, - cfg_strength: float = 3.0, - verbose: bool = True, - **kwargs - ): - """ - Generate samples from the model using Euler method. - - Args: - model: The model to sample from. - noise: The initial noise tensor. - cond: conditional information. - neg_cond: negative conditional information. - steps: The number of steps to sample. - rescale_t: The rescale factor for t. - cfg_strength: The strength of classifier-free guidance. - verbose: If True, show a progress bar. - **kwargs: Additional arguments for model_inference. - - Returns: - a dict containing the following - - 'samples': the model samples. - - 'pred_x_t': a list of prediction of x_t. - - 'pred_x_0': a list of prediction of x_0. - """ - return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) - - -class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): - """ - Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. - """ - @torch.no_grad() - def sample( - self, - model, - noise, - cond, - neg_cond, - steps: int = 50, - rescale_t: float = 1.0, - cfg_strength: float = 3.0, - cfg_interval: Tuple[float, float] = (0.0, 1.0), - verbose: bool = True, - **kwargs - ): - """ - Generate samples from the model using Euler method. - - Args: - model: The model to sample from. - noise: The initial noise tensor. - cond: conditional information. - neg_cond: negative conditional information. - steps: The number of steps to sample. - rescale_t: The rescale factor for t. - cfg_strength: The strength of classifier-free guidance. - cfg_interval: The interval for classifier-free guidance. - verbose: If True, show a progress bar. - **kwargs: Additional arguments for model_inference. - - Returns: - a dict containing the following - - 'samples': the model samples. - - 'pred_x_t': a list of prediction of x_t. - - 'pred_x_0': a list of prediction of x_0. - """ - return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) diff --git a/trellis/pipelines/samplers/guidance_interval_mixin.py b/trellis/pipelines/samplers/guidance_interval_mixin.py deleted file mode 100644 index 10524ce011de121db28b17fcbc2589e60019042e..0000000000000000000000000000000000000000 --- a/trellis/pipelines/samplers/guidance_interval_mixin.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import * - - -class GuidanceIntervalSamplerMixin: - """ - A mixin class for samplers that apply classifier-free guidance with interval. - """ - - def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): - if cfg_interval[0] <= t <= cfg_interval[1]: - pred = super()._inference_model(model, x_t, t, cond, **kwargs) - neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) - return (1 + cfg_strength) * pred - cfg_strength * neg_pred - else: - return super()._inference_model(model, x_t, t, cond, **kwargs) diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py deleted file mode 100644 index faeb32297f62e82cdbed6692a2d2b30da1c9c56d..0000000000000000000000000000000000000000 --- a/trellis/pipelines/trellis_image_to_3d.py +++ /dev/null @@ -1,376 +0,0 @@ -from typing import * -from contextlib import contextmanager -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from tqdm import tqdm -from easydict import EasyDict as edict -from torchvision import transforms -from PIL import Image -import rembg -from .base import Pipeline -from . import samplers -from ..modules import sparse as sp -from ..representations import Gaussian, Strivec, MeshExtractResult - - -class TrellisImageTo3DPipeline(Pipeline): - """ - Pipeline for inferring Trellis image-to-3D models. - - Args: - models (dict[str, nn.Module]): The models to use in the pipeline. - sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. - slat_sampler (samplers.Sampler): The sampler for the structured latent. - slat_normalization (dict): The normalization parameters for the structured latent. - image_cond_model (str): The name of the image conditioning model. - """ - def __init__( - self, - models: dict[str, nn.Module] = None, - sparse_structure_sampler: samplers.Sampler = None, - slat_sampler: samplers.Sampler = None, - slat_normalization: dict = None, - image_cond_model: str = None, - ): - if models is None: - return - super().__init__(models) - self.sparse_structure_sampler = sparse_structure_sampler - self.slat_sampler = slat_sampler - self.sparse_structure_sampler_params = {} - self.slat_sampler_params = {} - self.slat_normalization = slat_normalization - self.rembg_session = None - self._init_image_cond_model(image_cond_model) - - @staticmethod - def from_pretrained(path: str) -> "TrellisImageTo3DPipeline": - """ - Load a pretrained model. - - Args: - path (str): The path to the model. Can be either local path or a Hugging Face repository. - """ - pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path) - new_pipeline = TrellisImageTo3DPipeline() - new_pipeline.__dict__ = pipeline.__dict__ - args = pipeline._pretrained_args - - new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) - new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] - - new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args']) - new_pipeline.slat_sampler_params = args['slat_sampler']['params'] - - new_pipeline.slat_normalization = args['slat_normalization'] - - new_pipeline._init_image_cond_model(args['image_cond_model']) - - return new_pipeline - - def _init_image_cond_model(self, name: str): - """ - Initialize the image conditioning model. - """ - dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True) - dinov2_model.eval() - self.models['image_cond_model'] = dinov2_model - transform = transforms.Compose([ - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) - self.image_cond_model_transform = transform - - def preprocess_image(self, input: Image.Image) -> Image.Image: - """ - Preprocess the input image. - """ - # if has alpha channel, use it directly; otherwise, remove background - has_alpha = False - if input.mode == 'RGBA': - alpha = np.array(input)[:, :, 3] - if not np.all(alpha == 255): - has_alpha = True - if has_alpha: - output = input - else: - input = input.convert('RGB') - max_size = max(input.size) - scale = min(1, 1024 / max_size) - if scale < 1: - input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) - if getattr(self, 'rembg_session', None) is None: - self.rembg_session = rembg.new_session('u2net') - output = rembg.remove(input, session=self.rembg_session) - output_np = np.array(output) - alpha = output_np[:, :, 3] - bbox = np.argwhere(alpha > 0.8 * 255) - bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) - center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 - size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) - size = int(size * 1.2) - bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 - output = output.crop(bbox) # type: ignore - output = output.resize((518, 518), Image.Resampling.LANCZOS) - output = np.array(output).astype(np.float32) / 255 - output = output[:, :, :3] * output[:, :, 3:4] - output = Image.fromarray((output * 255).astype(np.uint8)) - return output - - @torch.no_grad() - def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor: - """ - Encode the image. - - Args: - image (Union[torch.Tensor, list[Image.Image]]): The image to encode - - Returns: - torch.Tensor: The encoded features. - """ - if isinstance(image, torch.Tensor): - assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" - elif isinstance(image, list): - assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" - image = [i.resize((518, 518), Image.LANCZOS) for i in image] - image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] - image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] - image = torch.stack(image).to(self.device) - else: - raise ValueError(f"Unsupported type of image: {type(image)}") - - image = self.image_cond_model_transform(image).to(self.device) - features = self.models['image_cond_model'](image, is_training=True)['x_prenorm'] - patchtokens = F.layer_norm(features, features.shape[-1:]) - return patchtokens - - def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict: - """ - Get the conditioning information for the model. - - Args: - image (Union[torch.Tensor, list[Image.Image]]): The image prompts. - - Returns: - dict: The conditioning information - """ - cond = self.encode_image(image) - neg_cond = torch.zeros_like(cond) - return { - 'cond': cond, - 'neg_cond': neg_cond, - } - - def sample_sparse_structure( - self, - cond: dict, - num_samples: int = 1, - sampler_params: dict = {}, - ) -> torch.Tensor: - """ - Sample sparse structures with the given conditioning. - - Args: - cond (dict): The conditioning information. - num_samples (int): The number of samples to generate. - sampler_params (dict): Additional parameters for the sampler. - """ - # Sample occupancy latent - flow_model = self.models['sparse_structure_flow_model'] - reso = flow_model.resolution - noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device) - sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} - z_s = self.sparse_structure_sampler.sample( - flow_model, - noise, - **cond, - **sampler_params, - verbose=True - ).samples - - # Decode occupancy latent - decoder = self.models['sparse_structure_decoder'] - coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int() - - return coords - - def decode_slat( - self, - slat: sp.SparseTensor, - formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], - ) -> dict: - """ - Decode the structured latent. - - Args: - slat (sp.SparseTensor): The structured latent. - formats (List[str]): The formats to decode the structured latent to. - - Returns: - dict: The decoded structured latent. - """ - ret = {} - if 'mesh' in formats: - ret['mesh'] = self.models['slat_decoder_mesh'](slat) - if 'gaussian' in formats: - ret['gaussian'] = self.models['slat_decoder_gs'](slat) - if 'radiance_field' in formats: - ret['radiance_field'] = self.models['slat_decoder_rf'](slat) - return ret - - def sample_slat( - self, - cond: dict, - coords: torch.Tensor, - sampler_params: dict = {}, - ) -> sp.SparseTensor: - """ - Sample structured latent with the given conditioning. - - Args: - cond (dict): The conditioning information. - coords (torch.Tensor): The coordinates of the sparse structure. - sampler_params (dict): Additional parameters for the sampler. - """ - # Sample structured latent - flow_model = self.models['slat_flow_model'] - noise = sp.SparseTensor( - feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), - coords=coords, - ) - sampler_params = {**self.slat_sampler_params, **sampler_params} - slat = self.slat_sampler.sample( - flow_model, - noise, - **cond, - **sampler_params, - verbose=True - ).samples - - std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device) - mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device) - slat = slat * std + mean - - return slat - - @torch.no_grad() - def run( - self, - image: Image.Image, - num_samples: int = 1, - seed: int = 42, - sparse_structure_sampler_params: dict = {}, - slat_sampler_params: dict = {}, - formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], - preprocess_image: bool = True, - ) -> dict: - """ - Run the pipeline. - - Args: - image (Image.Image): The image prompt. - num_samples (int): The number of samples to generate. - sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. - slat_sampler_params (dict): Additional parameters for the structured latent sampler. - preprocess_image (bool): Whether to preprocess the image. - """ - if preprocess_image: - image = self.preprocess_image(image) - cond = self.get_cond([image]) - torch.manual_seed(seed) - coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) - slat = self.sample_slat(cond, coords, slat_sampler_params) - return self.decode_slat(slat, formats) - - @contextmanager - def inject_sampler_multi_image( - self, - sampler_name: str, - num_images: int, - num_steps: int, - mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', - ): - """ - Inject a sampler with multiple images as condition. - - Args: - sampler_name (str): The name of the sampler to inject. - num_images (int): The number of images to condition on. - num_steps (int): The number of steps to run the sampler for. - """ - sampler = getattr(self, sampler_name) - setattr(sampler, f'_old_inference_model', sampler._inference_model) - - if mode == 'stochastic': - if num_images > num_steps: - print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. " - "This may lead to performance degradation.\033[0m") - - cond_indices = (np.arange(num_steps) % num_images).tolist() - def _new_inference_model(self, model, x_t, t, cond, **kwargs): - cond_idx = cond_indices.pop(0) - cond_i = cond[cond_idx:cond_idx+1] - return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs) - - elif mode =='multidiffusion': - from .samplers import FlowEulerSampler - def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): - if cfg_interval[0] <= t <= cfg_interval[1]: - preds = [] - for i in range(len(cond)): - preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) - pred = sum(preds) / len(preds) - neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs) - return (1 + cfg_strength) * pred - cfg_strength * neg_pred - else: - preds = [] - for i in range(len(cond)): - preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)) - pred = sum(preds) / len(preds) - return pred - - else: - raise ValueError(f"Unsupported mode: {mode}") - - sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler)) - - yield - - sampler._inference_model = sampler._old_inference_model - delattr(sampler, f'_old_inference_model') - - @torch.no_grad() - def run_multi_image( - self, - images: List[Image.Image], - num_samples: int = 1, - seed: int = 42, - sparse_structure_sampler_params: dict = {}, - slat_sampler_params: dict = {}, - formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], - preprocess_image: bool = True, - mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', - ) -> dict: - """ - Run the pipeline with multiple images as condition - - Args: - images (List[Image.Image]): The multi-view images of the assets - num_samples (int): The number of samples to generate. - sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. - slat_sampler_params (dict): Additional parameters for the structured latent sampler. - preprocess_image (bool): Whether to preprocess the image. - """ - if preprocess_image: - images = [self.preprocess_image(image) for image in images] - cond = self.get_cond(images) - cond['neg_cond'] = cond['neg_cond'][:1] - torch.manual_seed(seed) - ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps') - with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode): - coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) - slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') - with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode): - slat = self.sample_slat(cond, coords, slat_sampler_params) - return self.decode_slat(slat, formats) diff --git a/trellis/renderers/__init__.py b/trellis/renderers/__init__.py deleted file mode 100644 index ec397d1e795baedf05d52ef49f5161885151cafc..0000000000000000000000000000000000000000 --- a/trellis/renderers/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -import importlib - -__attributes = { - 'OctreeRenderer': 'octree_renderer', - 'GaussianRenderer': 'gaussian_render', - 'MeshRenderer': 'mesh_renderer', -} - -__submodules = [] - -__all__ = list(__attributes.keys()) + __submodules - -def __getattr__(name): - if name not in globals(): - if name in __attributes: - module_name = __attributes[name] - module = importlib.import_module(f".{module_name}", __name__) - globals()[name] = getattr(module, name) - elif name in __submodules: - module = importlib.import_module(f".{name}", __name__) - globals()[name] = module - else: - raise AttributeError(f"module {__name__} has no attribute {name}") - return globals()[name] - - -# For Pylance -if __name__ == '__main__': - from .octree_renderer import OctreeRenderer - from .gaussian_render import GaussianRenderer - from .mesh_renderer import MeshRenderer \ No newline at end of file diff --git a/trellis/renderers/gaussian_render.py b/trellis/renderers/gaussian_render.py deleted file mode 100644 index 272cf07ceaf2cb14bc7b9b82772721d97fd954c8..0000000000000000000000000000000000000000 --- a/trellis/renderers/gaussian_render.py +++ /dev/null @@ -1,231 +0,0 @@ -# -# Copyright (C) 2023, Inria -# GRAPHDECO research group, https://team.inria.fr/graphdeco -# All rights reserved. -# -# This software is free for non-commercial, research and evaluation use -# under the terms of the LICENSE.md file. -# -# For inquiries contact george.drettakis@inria.fr -# - -import torch -import math -from easydict import EasyDict as edict -import numpy as np -from ..representations.gaussian import Gaussian -from .sh_utils import eval_sh -import torch.nn.functional as F -from easydict import EasyDict as edict - - -def intrinsics_to_projection( - intrinsics: torch.Tensor, - near: float, - far: float, - ) -> torch.Tensor: - """ - OpenCV intrinsics to OpenGL perspective matrix - - Args: - intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix - near (float): near plane to clip - far (float): far plane to clip - Returns: - (torch.Tensor): [4, 4] OpenGL perspective matrix - """ - fx, fy = intrinsics[0, 0], intrinsics[1, 1] - cx, cy = intrinsics[0, 2], intrinsics[1, 2] - ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) - ret[0, 0] = 2 * fx - ret[1, 1] = 2 * fy - ret[0, 2] = 2 * cx - 1 - ret[1, 2] = - 2 * cy + 1 - ret[2, 2] = far / (far - near) - ret[2, 3] = near * far / (near - far) - ret[3, 2] = 1. - return ret - - -def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): - """ - Render the scene. - - Background tensor (bg_color) must be on GPU! - """ - # lazy import - if 'GaussianRasterizer' not in globals(): - from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings - - # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means - screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 - try: - screenspace_points.retain_grad() - except: - pass - # Set up rasterization configuration - tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) - tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) - - kernel_size = pipe.kernel_size - subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda") - - raster_settings = GaussianRasterizationSettings( - image_height=int(viewpoint_camera.image_height), - image_width=int(viewpoint_camera.image_width), - tanfovx=tanfovx, - tanfovy=tanfovy, - kernel_size=kernel_size, - subpixel_offset=subpixel_offset, - bg=bg_color, - scale_modifier=scaling_modifier, - viewmatrix=viewpoint_camera.world_view_transform, - projmatrix=viewpoint_camera.full_proj_transform, - sh_degree=pc.active_sh_degree, - campos=viewpoint_camera.camera_center, - prefiltered=False, - debug=pipe.debug - ) - - rasterizer = GaussianRasterizer(raster_settings=raster_settings) - - means3D = pc.get_xyz - means2D = screenspace_points - opacity = pc.get_opacity - - # If precomputed 3d covariance is provided, use it. If not, then it will be computed from - # scaling / rotation by the rasterizer. - scales = None - rotations = None - cov3D_precomp = None - if pipe.compute_cov3D_python: - cov3D_precomp = pc.get_covariance(scaling_modifier) - else: - scales = pc.get_scaling - rotations = pc.get_rotation - - # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors - # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. - shs = None - colors_precomp = None - if override_color is None: - if pipe.convert_SHs_python: - shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) - dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) - dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) - sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) - colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) - else: - shs = pc.get_features - else: - colors_precomp = override_color - - # Rasterize visible Gaussians to image, obtain their radii (on screen). - rendered_image, radii = rasterizer( - means3D = means3D, - means2D = means2D, - shs = shs, - colors_precomp = colors_precomp, - opacities = opacity, - scales = scales, - rotations = rotations, - cov3D_precomp = cov3D_precomp - ) - - # Those Gaussians that were frustum culled or had a radius of 0 were not visible. - # They will be excluded from value updates used in the splitting criteria. - return edict({"render": rendered_image, - "viewspace_points": screenspace_points, - "visibility_filter" : radii > 0, - "radii": radii}) - - -class GaussianRenderer: - """ - Renderer for the Voxel representation. - - Args: - rendering_options (dict): Rendering options. - """ - - def __init__(self, rendering_options={}) -> None: - self.pipe = edict({ - "kernel_size": 0.1, - "convert_SHs_python": False, - "compute_cov3D_python": False, - "scale_modifier": 1.0, - "debug": False - }) - self.rendering_options = edict({ - "resolution": None, - "near": None, - "far": None, - "ssaa": 1, - "bg_color": 'random', - }) - self.rendering_options.update(rendering_options) - self.bg_color = None - - def render( - self, - gausssian: Gaussian, - extrinsics: torch.Tensor, - intrinsics: torch.Tensor, - colors_overwrite: torch.Tensor = None - ) -> edict: - """ - Render the gausssian. - - Args: - gaussian : gaussianmodule - extrinsics (torch.Tensor): (4, 4) camera extrinsics - intrinsics (torch.Tensor): (3, 3) camera intrinsics - colors_overwrite (torch.Tensor): (N, 3) override color - - Returns: - edict containing: - color (torch.Tensor): (3, H, W) rendered color image - """ - resolution = self.rendering_options["resolution"] - near = self.rendering_options["near"] - far = self.rendering_options["far"] - ssaa = self.rendering_options["ssaa"] - - if self.rendering_options["bg_color"] == 'random': - self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") - if np.random.rand() < 0.5: - self.bg_color += 1 - else: - self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") - - view = extrinsics - perspective = intrinsics_to_projection(intrinsics, near, far) - camera = torch.inverse(view)[:3, 3] - focalx = intrinsics[0, 0] - focaly = intrinsics[1, 1] - fovx = 2 * torch.atan(0.5 / focalx) - fovy = 2 * torch.atan(0.5 / focaly) - - camera_dict = edict({ - "image_height": resolution * ssaa, - "image_width": resolution * ssaa, - "FoVx": fovx, - "FoVy": fovy, - "znear": near, - "zfar": far, - "world_view_transform": view.T.contiguous(), - "projection_matrix": perspective.T.contiguous(), - "full_proj_transform": (perspective @ view).T.contiguous(), - "camera_center": camera - }) - - # Render - render_ret = render(camera_dict, gausssian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier) - - if ssaa > 1: - render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() - - ret = edict({ - 'color': render_ret['render'] - }) - return ret diff --git a/trellis/renderers/mesh_renderer.py b/trellis/renderers/mesh_renderer.py deleted file mode 100644 index b0726def9a515834db145be3a24817fa97b09071..0000000000000000000000000000000000000000 --- a/trellis/renderers/mesh_renderer.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. -import torch -import nvdiffrast.torch as dr -from easydict import EasyDict as edict -from ..representations.mesh import MeshExtractResult -import torch.nn.functional as F - - -def intrinsics_to_projection( - intrinsics: torch.Tensor, - near: float, - far: float, - ) -> torch.Tensor: - """ - OpenCV intrinsics to OpenGL perspective matrix - - Args: - intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix - near (float): near plane to clip - far (float): far plane to clip - Returns: - (torch.Tensor): [4, 4] OpenGL perspective matrix - """ - fx, fy = intrinsics[0, 0], intrinsics[1, 1] - cx, cy = intrinsics[0, 2], intrinsics[1, 2] - ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) - ret[0, 0] = 2 * fx - ret[1, 1] = 2 * fy - ret[0, 2] = 2 * cx - 1 - ret[1, 2] = - 2 * cy + 1 - ret[2, 2] = far / (far - near) - ret[2, 3] = near * far / (near - far) - ret[3, 2] = 1. - return ret - - -class MeshRenderer: - """ - Renderer for the Mesh representation. - - Args: - rendering_options (dict): Rendering options. - glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop. - """ - def __init__(self, rendering_options={}, device='cuda'): - self.rendering_options = edict({ - "resolution": None, - "near": None, - "far": None, - "ssaa": 1 - }) - self.rendering_options.update(rendering_options) - self.glctx = dr.RasterizeCudaContext(device=device) - self.device=device - - def render( - self, - mesh : MeshExtractResult, - extrinsics: torch.Tensor, - intrinsics: torch.Tensor, - return_types = ["mask", "normal", "depth"] - ) -> edict: - """ - Render the mesh. - - Args: - mesh : meshmodel - extrinsics (torch.Tensor): (4, 4) camera extrinsics - intrinsics (torch.Tensor): (3, 3) camera intrinsics - return_types (list): list of return types, can be "mask", "depth", "normal_map", "normal", "color" - - Returns: - edict based on return_types containing: - color (torch.Tensor): [3, H, W] rendered color image - depth (torch.Tensor): [H, W] rendered depth image - normal (torch.Tensor): [3, H, W] rendered normal image - normal_map (torch.Tensor): [3, H, W] rendered normal map image - mask (torch.Tensor): [H, W] rendered mask image - """ - resolution = self.rendering_options["resolution"] - near = self.rendering_options["near"] - far = self.rendering_options["far"] - ssaa = self.rendering_options["ssaa"] - - if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: - default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device) - ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types} - return ret_dict - - perspective = intrinsics_to_projection(intrinsics, near, far) - - RT = extrinsics.unsqueeze(0) - full_proj = (perspective @ extrinsics).unsqueeze(0) - - vertices = mesh.vertices.unsqueeze(0) - - vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) - vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2)) - vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) - faces_int = mesh.faces.int() - rast, _ = dr.rasterize( - self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa)) - - out_dict = edict() - for type in return_types: - img = None - if type == "mask" : - img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int) - elif type == "depth": - img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0] - img = dr.antialias(img, rast, vertices_clip, faces_int) - elif type == "normal" : - img = dr.interpolate( - mesh.face_normal.reshape(1, -1, 3), rast, - torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3) - )[0] - img = dr.antialias(img, rast, vertices_clip, faces_int) - # normalize norm pictures - img = (img + 1) / 2 - elif type == "normal_map" : - img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0] - img = dr.antialias(img, rast, vertices_clip, faces_int) - elif type == "color" : - img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0] - img = dr.antialias(img, rast, vertices_clip, faces_int) - - if ssaa > 1: - img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) - img = img.squeeze() - else: - img = img.permute(0, 3, 1, 2).squeeze() - out_dict[type] = img - - return out_dict diff --git a/trellis/renderers/octree_renderer.py b/trellis/renderers/octree_renderer.py deleted file mode 100644 index c72541888c60591109c8690bd269669faad667c0..0000000000000000000000000000000000000000 --- a/trellis/renderers/octree_renderer.py +++ /dev/null @@ -1,300 +0,0 @@ -import numpy as np -import torch -import torch.nn.functional as F -import math -import cv2 -from scipy.stats import qmc -from easydict import EasyDict as edict -from ..representations.octree import DfsOctree - - -def intrinsics_to_projection( - intrinsics: torch.Tensor, - near: float, - far: float, - ) -> torch.Tensor: - """ - OpenCV intrinsics to OpenGL perspective matrix - - Args: - intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix - near (float): near plane to clip - far (float): far plane to clip - Returns: - (torch.Tensor): [4, 4] OpenGL perspective matrix - """ - fx, fy = intrinsics[0, 0], intrinsics[1, 1] - cx, cy = intrinsics[0, 2], intrinsics[1, 2] - ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) - ret[0, 0] = 2 * fx - ret[1, 1] = 2 * fy - ret[0, 2] = 2 * cx - 1 - ret[1, 2] = - 2 * cy + 1 - ret[2, 2] = far / (far - near) - ret[2, 3] = near * far / (near - far) - ret[3, 2] = 1. - return ret - - -def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None): - """ - Render the scene. - - Background tensor (bg_color) must be on GPU! - """ - # lazy import - if 'OctreeTrivecRasterizer' not in globals(): - from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer - - # Set up rasterization configuration - tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) - tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) - - raster_settings = edict( - image_height=int(viewpoint_camera.image_height), - image_width=int(viewpoint_camera.image_width), - tanfovx=tanfovx, - tanfovy=tanfovy, - bg=bg_color, - scale_modifier=scaling_modifier, - viewmatrix=viewpoint_camera.world_view_transform, - projmatrix=viewpoint_camera.full_proj_transform, - sh_degree=octree.active_sh_degree, - campos=viewpoint_camera.camera_center, - with_distloss=pipe.with_distloss, - jitter=pipe.jitter, - debug=pipe.debug, - ) - - positions = octree.get_xyz - if octree.primitive == "voxel": - densities = octree.get_density - elif octree.primitive == "gaussian": - opacities = octree.get_opacity - elif octree.primitive == "trivec": - trivecs = octree.get_trivec - densities = octree.get_density - raster_settings.density_shift = octree.density_shift - elif octree.primitive == "decoupoly": - decoupolys_V, decoupolys_g = octree.get_decoupoly - densities = octree.get_density - raster_settings.density_shift = octree.density_shift - else: - raise ValueError(f"Unknown primitive {octree.primitive}") - depths = octree.get_depth - - # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors - # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. - colors_precomp = None - shs = octree.get_features - if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None: - colors_precomp = colors_overwrite - shs = None - - ret = edict() - - if octree.primitive == "voxel": - renderer = OctreeVoxelRasterizer(raster_settings=raster_settings) - rgb, depth, alpha, distloss = renderer( - positions = positions, - densities = densities, - shs = shs, - colors_precomp = colors_precomp, - depths = depths, - aabb = octree.aabb, - aux = aux, - ) - ret['rgb'] = rgb - ret['depth'] = depth - ret['alpha'] = alpha - ret['distloss'] = distloss - elif octree.primitive == "gaussian": - renderer = OctreeGaussianRasterizer(raster_settings=raster_settings) - rgb, depth, alpha = renderer( - positions = positions, - opacities = opacities, - shs = shs, - colors_precomp = colors_precomp, - depths = depths, - aabb = octree.aabb, - aux = aux, - ) - ret['rgb'] = rgb - ret['depth'] = depth - ret['alpha'] = alpha - elif octree.primitive == "trivec": - raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1] - renderer = OctreeTrivecRasterizer(raster_settings=raster_settings) - rgb, depth, alpha, percent_depth = renderer( - positions = positions, - trivecs = trivecs, - densities = densities, - shs = shs, - colors_precomp = colors_precomp, - colors_overwrite = colors_overwrite, - depths = depths, - aabb = octree.aabb, - aux = aux, - halton_sampler = halton_sampler, - ) - ret['percent_depth'] = percent_depth - ret['rgb'] = rgb - ret['depth'] = depth - ret['alpha'] = alpha - elif octree.primitive == "decoupoly": - raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1] - renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings) - rgb, depth, alpha = renderer( - positions = positions, - decoupolys_V = decoupolys_V, - decoupolys_g = decoupolys_g, - densities = densities, - shs = shs, - colors_precomp = colors_precomp, - depths = depths, - aabb = octree.aabb, - aux = aux, - ) - ret['rgb'] = rgb - ret['depth'] = depth - ret['alpha'] = alpha - - return ret - - -class OctreeRenderer: - """ - Renderer for the Voxel representation. - - Args: - rendering_options (dict): Rendering options. - """ - - def __init__(self, rendering_options={}) -> None: - try: - import diffoctreerast - except ImportError: - print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m") - self.unsupported = True - else: - self.unsupported = False - - self.pipe = edict({ - "with_distloss": False, - "with_aux": False, - "scale_modifier": 1.0, - "used_rank": None, - "jitter": False, - "debug": False, - }) - self.rendering_options = edict({ - "resolution": None, - "near": None, - "far": None, - "ssaa": 1, - "bg_color": 'random', - }) - self.halton_sampler = qmc.Halton(2, scramble=False) - self.rendering_options.update(rendering_options) - self.bg_color = None - - def render( - self, - octree: DfsOctree, - extrinsics: torch.Tensor, - intrinsics: torch.Tensor, - colors_overwrite: torch.Tensor = None, - ) -> edict: - """ - Render the octree. - - Args: - octree (Octree): octree - extrinsics (torch.Tensor): (4, 4) camera extrinsics - intrinsics (torch.Tensor): (3, 3) camera intrinsics - colors_overwrite (torch.Tensor): (N, 3) override color - - Returns: - edict containing: - color (torch.Tensor): (3, H, W) rendered color - depth (torch.Tensor): (H, W) rendered depth - alpha (torch.Tensor): (H, W) rendered alpha - distloss (Optional[torch.Tensor]): (H, W) rendered distance loss - percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth - aux (Optional[edict]): auxiliary tensors - """ - resolution = self.rendering_options["resolution"] - near = self.rendering_options["near"] - far = self.rendering_options["far"] - ssaa = self.rendering_options["ssaa"] - - if self.unsupported: - image = np.zeros((512, 512, 3), dtype=np.uint8) - text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0] - origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2 - image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA) - return { - 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255, - } - - if self.rendering_options["bg_color"] == 'random': - self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") - if np.random.rand() < 0.5: - self.bg_color += 1 - else: - self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") - - if self.pipe["with_aux"]: - aux = { - 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0, - 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0, - } - for k in aux.keys(): - aux[k].requires_grad_() - aux[k].retain_grad() - else: - aux = None - - view = extrinsics - perspective = intrinsics_to_projection(intrinsics, near, far) - camera = torch.inverse(view)[:3, 3] - focalx = intrinsics[0, 0] - focaly = intrinsics[1, 1] - fovx = 2 * torch.atan(0.5 / focalx) - fovy = 2 * torch.atan(0.5 / focaly) - - camera_dict = edict({ - "image_height": resolution * ssaa, - "image_width": resolution * ssaa, - "FoVx": fovx, - "FoVy": fovy, - "znear": near, - "zfar": far, - "world_view_transform": view.T.contiguous(), - "projection_matrix": perspective.T.contiguous(), - "full_proj_transform": (perspective @ view).T.contiguous(), - "camera_center": camera - }) - - # Render - render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler) - - if ssaa > 1: - render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() - render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() - render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() - if hasattr(render_ret, 'percent_depth'): - render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() - - ret = edict({ - 'color': render_ret.rgb, - 'depth': render_ret.depth, - 'alpha': render_ret.alpha, - }) - if self.pipe["with_distloss"] and 'distloss' in render_ret: - ret['distloss'] = render_ret.distloss - if self.pipe["with_aux"]: - ret['aux'] = aux - if hasattr(render_ret, 'percent_depth'): - ret['percent_depth'] = render_ret.percent_depth - return ret diff --git a/trellis/renderers/sh_utils.py b/trellis/renderers/sh_utils.py deleted file mode 100644 index a54612b24cde4e1ab6da3fb37142cc5d0248ada8..0000000000000000000000000000000000000000 --- a/trellis/renderers/sh_utils.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2021 The PlenOctree Authors. -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. - -import torch - -C0 = 0.28209479177387814 -C1 = 0.4886025119029199 -C2 = [ - 1.0925484305920792, - -1.0925484305920792, - 0.31539156525252005, - -1.0925484305920792, - 0.5462742152960396 -] -C3 = [ - -0.5900435899266435, - 2.890611442640554, - -0.4570457994644658, - 0.3731763325901154, - -0.4570457994644658, - 1.445305721320277, - -0.5900435899266435 -] -C4 = [ - 2.5033429417967046, - -1.7701307697799304, - 0.9461746957575601, - -0.6690465435572892, - 0.10578554691520431, - -0.6690465435572892, - 0.47308734787878004, - -1.7701307697799304, - 0.6258357354491761, -] - - -def eval_sh(deg, sh, dirs): - """ - Evaluate spherical harmonics at unit directions - using hardcoded SH polynomials. - Works with torch/np/jnp. - ... Can be 0 or more batch dimensions. - Args: - deg: int SH deg. Currently, 0-3 supported - sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] - dirs: jnp.ndarray unit directions [..., 3] - Returns: - [..., C] - """ - assert deg <= 4 and deg >= 0 - coeff = (deg + 1) ** 2 - assert sh.shape[-1] >= coeff - - result = C0 * sh[..., 0] - if deg > 0: - x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] - result = (result - - C1 * y * sh[..., 1] + - C1 * z * sh[..., 2] - - C1 * x * sh[..., 3]) - - if deg > 1: - xx, yy, zz = x * x, y * y, z * z - xy, yz, xz = x * y, y * z, x * z - result = (result + - C2[0] * xy * sh[..., 4] + - C2[1] * yz * sh[..., 5] + - C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + - C2[3] * xz * sh[..., 7] + - C2[4] * (xx - yy) * sh[..., 8]) - - if deg > 2: - result = (result + - C3[0] * y * (3 * xx - yy) * sh[..., 9] + - C3[1] * xy * z * sh[..., 10] + - C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + - C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + - C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + - C3[5] * z * (xx - yy) * sh[..., 14] + - C3[6] * x * (xx - 3 * yy) * sh[..., 15]) - - if deg > 3: - result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + - C4[1] * yz * (3 * xx - yy) * sh[..., 17] + - C4[2] * xy * (7 * zz - 1) * sh[..., 18] + - C4[3] * yz * (7 * zz - 3) * sh[..., 19] + - C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + - C4[5] * xz * (7 * zz - 3) * sh[..., 21] + - C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + - C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + - C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) - return result - -def RGB2SH(rgb): - return (rgb - 0.5) / C0 - -def SH2RGB(sh): - return sh * C0 + 0.5 \ No newline at end of file diff --git a/trellis/representations/__init__.py b/trellis/representations/__init__.py deleted file mode 100644 index 26c62155fc77f668549092c55fca0c610aa02540..0000000000000000000000000000000000000000 --- a/trellis/representations/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .radiance_field import Strivec -from .octree import DfsOctree as Octree -from .gaussian import Gaussian -from .mesh import MeshExtractResult diff --git a/trellis/representations/gaussian/__init__.py b/trellis/representations/gaussian/__init__.py deleted file mode 100644 index e3de6e180bd732836af876d748255595be2d4d74..0000000000000000000000000000000000000000 --- a/trellis/representations/gaussian/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .gaussian_model import Gaussian \ No newline at end of file diff --git a/trellis/representations/gaussian/gaussian_model.py b/trellis/representations/gaussian/gaussian_model.py deleted file mode 100644 index 373411cb16aa2baf466eaa600fb4d4248bd550c0..0000000000000000000000000000000000000000 --- a/trellis/representations/gaussian/gaussian_model.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -import numpy as np -from plyfile import PlyData, PlyElement -from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation -import utils3d - - -class Gaussian: - def __init__( - self, - aabb : list, - sh_degree : int = 0, - mininum_kernel_size : float = 0.0, - scaling_bias : float = 0.01, - opacity_bias : float = 0.1, - scaling_activation : str = "exp", - device='cuda' - ): - self.init_params = { - 'aabb': aabb, - 'sh_degree': sh_degree, - 'mininum_kernel_size': mininum_kernel_size, - 'scaling_bias': scaling_bias, - 'opacity_bias': opacity_bias, - 'scaling_activation': scaling_activation, - } - - self.sh_degree = sh_degree - self.active_sh_degree = sh_degree - self.mininum_kernel_size = mininum_kernel_size - self.scaling_bias = scaling_bias - self.opacity_bias = opacity_bias - self.scaling_activation_type = scaling_activation - self.device = device - self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) - self.setup_functions() - - self._xyz = None - self._features_dc = None - self._features_rest = None - self._scaling = None - self._rotation = None - self._opacity = None - - def setup_functions(self): - def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): - L = build_scaling_rotation(scaling_modifier * scaling, rotation) - actual_covariance = L @ L.transpose(1, 2) - symm = strip_symmetric(actual_covariance) - return symm - - if self.scaling_activation_type == "exp": - self.scaling_activation = torch.exp - self.inverse_scaling_activation = torch.log - elif self.scaling_activation_type == "softplus": - self.scaling_activation = torch.nn.functional.softplus - self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x)) - - self.covariance_activation = build_covariance_from_scaling_rotation - - self.opacity_activation = torch.sigmoid - self.inverse_opacity_activation = inverse_sigmoid - - self.rotation_activation = torch.nn.functional.normalize - - self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda() - self.rots_bias = torch.zeros((4)).cuda() - self.rots_bias[0] = 1 - self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda() - - @property - def get_scaling(self): - scales = self.scaling_activation(self._scaling + self.scale_bias) - scales = torch.square(scales) + self.mininum_kernel_size ** 2 - scales = torch.sqrt(scales) - return scales - - @property - def get_rotation(self): - return self.rotation_activation(self._rotation + self.rots_bias[None, :]) - - @property - def get_xyz(self): - return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3] - - @property - def get_features(self): - return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc - - @property - def get_opacity(self): - return self.opacity_activation(self._opacity + self.opacity_bias) - - def get_covariance(self, scaling_modifier = 1): - return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :]) - - def from_scaling(self, scales): - scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2) - self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias - - def from_rotation(self, rots): - self._rotation = rots - self.rots_bias[None, :] - - def from_xyz(self, xyz): - self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] - - def from_features(self, features): - self._features_dc = features - - def from_opacity(self, opacities): - self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias - - def construct_list_of_attributes(self): - l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] - # All channels except the 3 DC - for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): - l.append('f_dc_{}'.format(i)) - l.append('opacity') - for i in range(self._scaling.shape[1]): - l.append('scale_{}'.format(i)) - for i in range(self._rotation.shape[1]): - l.append('rot_{}'.format(i)) - return l - - def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]): - xyz = self.get_xyz.detach().cpu().numpy() - normals = np.zeros_like(xyz) - f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() - opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy() - scale = torch.log(self.get_scaling).detach().cpu().numpy() - rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy() - - if transform is not None: - transform = np.array(transform) - xyz = np.matmul(xyz, transform.T) - rotation = utils3d.numpy.quaternion_to_matrix(rotation) - rotation = np.matmul(transform, rotation) - rotation = utils3d.numpy.matrix_to_quaternion(rotation) - - dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] - - elements = np.empty(xyz.shape[0], dtype=dtype_full) - attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1) - elements[:] = list(map(tuple, attributes)) - el = PlyElement.describe(elements, 'vertex') - PlyData([el]).write(path) - - def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]): - plydata = PlyData.read(path) - - xyz = np.stack((np.asarray(plydata.elements[0]["x"]), - np.asarray(plydata.elements[0]["y"]), - np.asarray(plydata.elements[0]["z"])), axis=1) - opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] - - features_dc = np.zeros((xyz.shape[0], 3, 1)) - features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) - features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) - features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) - - if self.sh_degree > 0: - extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] - extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) - assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3 - features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) - for idx, attr_name in enumerate(extra_f_names): - features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) - # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) - features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) - - scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] - scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) - scales = np.zeros((xyz.shape[0], len(scale_names))) - for idx, attr_name in enumerate(scale_names): - scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) - - rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] - rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) - rots = np.zeros((xyz.shape[0], len(rot_names))) - for idx, attr_name in enumerate(rot_names): - rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) - - if transform is not None: - transform = np.array(transform) - xyz = np.matmul(xyz, transform) - rotation = utils3d.numpy.quaternion_to_matrix(rotation) - rotation = np.matmul(rotation, transform) - rotation = utils3d.numpy.matrix_to_quaternion(rotation) - - # convert to actual gaussian attributes - xyz = torch.tensor(xyz, dtype=torch.float, device=self.device) - features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() - if self.sh_degree > 0: - features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() - opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device)) - scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device)) - rots = torch.tensor(rots, dtype=torch.float, device=self.device) - - # convert to _hidden attributes - self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] - self._features_dc = features_dc - if self.sh_degree > 0: - self._features_rest = features_extra - else: - self._features_rest = None - self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias - self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias - self._rotation = rots - self.rots_bias[None, :] - \ No newline at end of file diff --git a/trellis/representations/gaussian/general_utils.py b/trellis/representations/gaussian/general_utils.py deleted file mode 100644 index ae982066ab2d04fac15e997df9dbd37620ad08ed..0000000000000000000000000000000000000000 --- a/trellis/representations/gaussian/general_utils.py +++ /dev/null @@ -1,133 +0,0 @@ -# -# Copyright (C) 2023, Inria -# GRAPHDECO research group, https://team.inria.fr/graphdeco -# All rights reserved. -# -# This software is free for non-commercial, research and evaluation use -# under the terms of the LICENSE.md file. -# -# For inquiries contact george.drettakis@inria.fr -# - -import torch -import sys -from datetime import datetime -import numpy as np -import random - -def inverse_sigmoid(x): - return torch.log(x/(1-x)) - -def PILtoTorch(pil_image, resolution): - resized_image_PIL = pil_image.resize(resolution) - resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 - if len(resized_image.shape) == 3: - return resized_image.permute(2, 0, 1) - else: - return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) - -def get_expon_lr_func( - lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 -): - """ - Copied from Plenoxels - - Continuous learning rate decay function. Adapted from JaxNeRF - The returned rate is lr_init when step=0 and lr_final when step=max_steps, and - is log-linearly interpolated elsewhere (equivalent to exponential decay). - If lr_delay_steps>0 then the learning rate will be scaled by some smooth - function of lr_delay_mult, such that the initial learning rate is - lr_init*lr_delay_mult at the beginning of optimization but will be eased back - to the normal learning rate when steps>lr_delay_steps. - :param conf: config subtree 'lr' or similar - :param max_steps: int, the number of steps during optimization. - :return HoF which takes step as input - """ - - def helper(step): - if step < 0 or (lr_init == 0.0 and lr_final == 0.0): - # Disable this parameter - return 0.0 - if lr_delay_steps > 0: - # A kind of reverse cosine decay. - delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( - 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) - ) - else: - delay_rate = 1.0 - t = np.clip(step / max_steps, 0, 1) - log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) - return delay_rate * log_lerp - - return helper - -def strip_lowerdiag(L): - uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") - - uncertainty[:, 0] = L[:, 0, 0] - uncertainty[:, 1] = L[:, 0, 1] - uncertainty[:, 2] = L[:, 0, 2] - uncertainty[:, 3] = L[:, 1, 1] - uncertainty[:, 4] = L[:, 1, 2] - uncertainty[:, 5] = L[:, 2, 2] - return uncertainty - -def strip_symmetric(sym): - return strip_lowerdiag(sym) - -def build_rotation(r): - norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) - - q = r / norm[:, None] - - R = torch.zeros((q.size(0), 3, 3), device='cuda') - - r = q[:, 0] - x = q[:, 1] - y = q[:, 2] - z = q[:, 3] - - R[:, 0, 0] = 1 - 2 * (y*y + z*z) - R[:, 0, 1] = 2 * (x*y - r*z) - R[:, 0, 2] = 2 * (x*z + r*y) - R[:, 1, 0] = 2 * (x*y + r*z) - R[:, 1, 1] = 1 - 2 * (x*x + z*z) - R[:, 1, 2] = 2 * (y*z - r*x) - R[:, 2, 0] = 2 * (x*z - r*y) - R[:, 2, 1] = 2 * (y*z + r*x) - R[:, 2, 2] = 1 - 2 * (x*x + y*y) - return R - -def build_scaling_rotation(s, r): - L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") - R = build_rotation(r) - - L[:,0,0] = s[:,0] - L[:,1,1] = s[:,1] - L[:,2,2] = s[:,2] - - L = R @ L - return L - -def safe_state(silent): - old_f = sys.stdout - class F: - def __init__(self, silent): - self.silent = silent - - def write(self, x): - if not self.silent: - if x.endswith("\n"): - old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) - else: - old_f.write(x) - - def flush(self): - old_f.flush() - - sys.stdout = F(silent) - - random.seed(0) - np.random.seed(0) - torch.manual_seed(0) - torch.cuda.set_device(torch.device("cuda:0")) diff --git a/trellis/representations/mesh/__init__.py b/trellis/representations/mesh/__init__.py deleted file mode 100644 index fffa2c12907ed38dce29455df884c474704bd663..0000000000000000000000000000000000000000 --- a/trellis/representations/mesh/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult diff --git a/trellis/representations/mesh/cube2mesh.py b/trellis/representations/mesh/cube2mesh.py deleted file mode 100644 index 8a6d3b57c49b7140198b5d438a3e618e0ba23d64..0000000000000000000000000000000000000000 --- a/trellis/representations/mesh/cube2mesh.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. -import torch -from ...modules.sparse import SparseTensor -from easydict import EasyDict as edict -from .utils_cube import * -from .flexicube import FlexiCubes - - -class MeshExtractResult: - def __init__(self, - vertices, - faces, - vertex_attrs=None, - res=64 - ): - self.vertices = vertices - self.faces = faces.long() - self.vertex_attrs = vertex_attrs - self.face_normal = self.comput_face_normals(vertices, faces) - self.res = res - self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) - - # training only - self.tsdf_v = None - self.tsdf_s = None - self.reg_loss = None - - def comput_face_normals(self, verts, faces): - i0 = faces[..., 0].long() - i1 = faces[..., 1].long() - i2 = faces[..., 2].long() - - v0 = verts[i0, :] - v1 = verts[i1, :] - v2 = verts[i2, :] - face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) - face_normals = torch.nn.functional.normalize(face_normals, dim=1) - # print(face_normals.min(), face_normals.max(), face_normals.shape) - return face_normals[:, None, :].repeat(1, 3, 1) - - def comput_v_normals(self, verts, faces): - i0 = faces[..., 0].long() - i1 = faces[..., 1].long() - i2 = faces[..., 2].long() - - v0 = verts[i0, :] - v1 = verts[i1, :] - v2 = verts[i2, :] - face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) - v_normals = torch.zeros_like(verts) - v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) - v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) - v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) - - v_normals = torch.nn.functional.normalize(v_normals, dim=1) - return v_normals - - -class SparseFeatures2Mesh: - def __init__(self, device="cuda", res=64, use_color=True): - ''' - a model to generate a mesh from sparse features structures using flexicube - ''' - super().__init__() - self.device=device - self.res = res - self.mesh_extractor = FlexiCubes(device=device) - self.sdf_bias = -1.0 / res - verts, cube = construct_dense_grid(self.res, self.device) - self.reg_c = cube.to(self.device) - self.reg_v = verts.to(self.device) - self.use_color = use_color - self._calc_layout() - - def _calc_layout(self): - LAYOUTS = { - 'sdf': {'shape': (8, 1), 'size': 8}, - 'deform': {'shape': (8, 3), 'size': 8 * 3}, - 'weights': {'shape': (21,), 'size': 21} - } - if self.use_color: - ''' - 6 channel color including normal map - ''' - LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} - self.layouts = edict(LAYOUTS) - start = 0 - for k, v in self.layouts.items(): - v['range'] = (start, start + v['size']) - start += v['size'] - self.feats_channels = start - - def get_layout(self, feats : torch.Tensor, name : str): - if name not in self.layouts: - return None - return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) - - def __call__(self, cubefeats : SparseTensor, training=False): - """ - Generates a mesh based on the specified sparse voxel structures. - Args: - cube_attrs [Nx21] : Sparse Tensor attrs about cube weights - verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal - Returns: - return the success tag and ni you loss, - """ - # add sdf bias to verts_attrs - coords = cubefeats.coords[:, 1:] - feats = cubefeats.feats - - sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] - sdf += self.sdf_bias - v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] - v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) - v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True) - weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) - if self.use_color: - sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:] - else: - sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] - colors_d = None - - x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) - - vertices, faces, L_dev, colors = self.mesh_extractor( - voxelgrid_vertices=x_nx3, - scalar_field=sdf_d, - cube_idx=self.reg_c, - resolution=self.res, - beta=weights_d[:, :12], - alpha=weights_d[:, 12:20], - gamma_f=weights_d[:, 20], - voxelgrid_colors=colors_d, - training=training) - - mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) - if training: - if mesh.success: - reg_loss += L_dev.mean() * 0.5 - reg_loss += (weights[:,:20]).abs().mean() * 0.2 - mesh.reg_loss = reg_loss - mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) - mesh.tsdf_s = v_attrs[:, 0] - return mesh diff --git a/trellis/representations/mesh/flexicube.py b/trellis/representations/mesh/flexicube.py deleted file mode 100644 index 0071b1ba000a826e870e43971a68a4963b23a6c3..0000000000000000000000000000000000000000 --- a/trellis/representations/mesh/flexicube.py +++ /dev/null @@ -1,359 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -from .tables import * - -__all__ = [ - 'FlexiCubes' -] - - -class FlexiCubes: - def __init__(self, device="cuda"): - - self.device = device - self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) - self.num_vd_table = torch.tensor(num_vd_table, - dtype=torch.long, device=device, requires_grad=False) - self.check_table = torch.tensor( - check_table, - dtype=torch.long, device=device, requires_grad=False) - - self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) - self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) - self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) - self.quad_split_train = torch.tensor( - [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) - - self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ - 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) - self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) - self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, - 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) - - self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], - dtype=torch.long, device=device) - self.dir_faces_table = torch.tensor([ - [[5, 4], [3, 2], [4, 5], [2, 3]], - [[5, 4], [1, 0], [4, 5], [0, 1]], - [[3, 2], [1, 0], [2, 3], [0, 1]] - ], dtype=torch.long, device=device) - self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) - - def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3, - weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False): - surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx) - if surf_cubes.sum() == 0: - return ( - torch.zeros((0, 3), device=self.device), - torch.zeros((0, 3), dtype=torch.long, device=self.device), - torch.zeros((0), device=self.device), - torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None - ) - beta, alpha, gamma_f = self._normalize_weights( - beta, alpha, gamma_f, surf_cubes, weight_scale) - - if voxelgrid_colors is not None: - voxelgrid_colors = torch.sigmoid(voxelgrid_colors) - - case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution) - - surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges( - scalar_field, cube_idx, surf_cubes - ) - - vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd( - voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field, - case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors) - vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate( - scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, - vd_idx_map, surf_edges_mask, training, vd_color) - return vertices, faces, L_dev, vertices_color - - def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): - """ - Regularizer L_dev as in Equation 8 - """ - dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) - mean_l2 = torch.zeros_like(vd[:, 0]) - mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() - mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() - return mad - - def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale): - """ - Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. - """ - n_cubes = surf_cubes.shape[0] - - if beta is not None: - beta = (torch.tanh(beta) * weight_scale + 1) - else: - beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) - - if alpha is not None: - alpha = (torch.tanh(alpha) * weight_scale + 1) - else: - alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) - - if gamma_f is not None: - gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2 - else: - gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) - - return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes] - - @torch.no_grad() - def _get_case_id(self, occ_fx8, surf_cubes, res): - """ - Obtains the ID of topology cases based on cell corner occupancy. This function resolves the - ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the - supplementary material. It should be noted that this function assumes a regular grid. - """ - case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) - - problem_config = self.check_table.to(self.device)[case_ids] - to_check = problem_config[..., 0] == 1 - problem_config = problem_config[to_check] - if not isinstance(res, (list, tuple)): - res = [res, res, res] - - # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, - # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). - # This allows efficient checking on adjacent cubes. - problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) - vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 - vol_idx_problem = vol_idx[surf_cubes][to_check] - problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config - vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] - - within_range = ( - vol_idx_problem_adj[..., 0] >= 0) & ( - vol_idx_problem_adj[..., 0] < res[0]) & ( - vol_idx_problem_adj[..., 1] >= 0) & ( - vol_idx_problem_adj[..., 1] < res[1]) & ( - vol_idx_problem_adj[..., 2] >= 0) & ( - vol_idx_problem_adj[..., 2] < res[2]) - - vol_idx_problem = vol_idx_problem[within_range] - vol_idx_problem_adj = vol_idx_problem_adj[within_range] - problem_config = problem_config[within_range] - problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], - vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] - # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. - to_invert = (problem_config_adj[..., 0] == 1) - idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] - case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) - return case_ids - - @torch.no_grad() - def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes): - """ - Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge - can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge - and marks the cube edges with this index. - """ - occ_n = scalar_field < 0 - all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2) - unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) - - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 - - surf_edges_mask = mask_edges[_idx_map] - counts = counts[_idx_map] - - mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1 - mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device) - # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index - # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. - idx_map = mapping[_idx_map] - surf_edges = unique_edges[mask_edges] - return surf_edges, idx_map, counts, surf_edges_mask - - @torch.no_grad() - def _identify_surf_cubes(self, scalar_field, cube_idx): - """ - Identifies grid cubes that intersect with the underlying surface by checking if the signs at - all corners are not identical. - """ - occ_n = scalar_field < 0 - occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8) - _occ_sum = torch.sum(occ_fx8, -1) - surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) - return surf_cubes, occ_fx8 - - def _linear_interp(self, edges_weight, edges_x): - """ - Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. - """ - edge_dim = edges_weight.dim() - 2 - assert edges_weight.shape[edge_dim] == 2 - edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - - torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)] - , edge_dim) - denominator = edges_weight.sum(edge_dim) - ue = (edges_x * edges_weight).sum(edge_dim) / denominator - return ue - - def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale): - p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) - norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) - c_bx3 = c_bx3.reshape(-1, 3) - A = norm_bxnx3 - B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) - - A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) - B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1) - A = torch.cat([A, A_reg], 1) - B = torch.cat([B, B_reg], 1) - dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) - return dual_verts - - def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field, - case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors): - """ - Computes the location of dual vertices as described in Section 4.2 - """ - alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2) - surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) - surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) - zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) - - if voxelgrid_colors is not None: - C = voxelgrid_colors.shape[-1] - surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C) - - idx_map = idx_map.reshape(-1, 12) - num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) - edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] - - # if color is not None: - # vd_color = [] - - total_num_vd = 0 - vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) - - for num in torch.unique(num_vd): - cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) - curr_num_vd = cur_cubes.sum() * num - curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) - curr_edge_group_to_vd = torch.arange( - curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd - total_num_vd += curr_num_vd - curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ - cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) - - curr_mask = (curr_edge_group != -1) - edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) - edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) - edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) - vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) - vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) - # if color is not None: - # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3)) - - edge_group = torch.cat(edge_group) - edge_group_to_vd = torch.cat(edge_group_to_vd) - edge_group_to_cube = torch.cat(edge_group_to_cube) - vd_num_edges = torch.cat(vd_num_edges) - vd_gamma = torch.cat(vd_gamma) - # if color is not None: - # vd_color = torch.cat(vd_color) - # else: - # vd_color = None - - vd = torch.zeros((total_num_vd, 3), device=self.device) - beta_sum = torch.zeros((total_num_vd, 1), device=self.device) - - idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) - - x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) - s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) - - - zero_crossing_group = torch.index_select( - input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) - - alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, - index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) - ue_group = self._linear_interp(s_group * alpha_group, x_group) - - beta_group = torch.gather(input=beta.reshape(-1), dim=0, - index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) - beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) - vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum - - ''' - interpolate colors use the same method as dual vertices - ''' - if voxelgrid_colors is not None: - vd_color = torch.zeros((total_num_vd, C), device=self.device) - c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C) - uc_group = self._linear_interp(s_group * alpha_group, c_group) - vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum - else: - vd_color = None - - L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) - - v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd - - vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * - 12 + edge_group, src=v_idx[edge_group_to_vd]) - - return vd, L_dev, vd_gamma, vd_idx_map, vd_color - - def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color): - """ - Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into - triangles based on the gamma parameter, as described in Section 4.3. - """ - with torch.no_grad(): - group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. - group = idx_map.reshape(-1)[group_mask] - vd_idx = vd_idx_map[group_mask] - edge_indices, indices = torch.sort(group, stable=True) - quad_vd_idx = vd_idx[indices].reshape(-1, 4) - - # Ensure all face directions point towards the positive SDF to maintain consistent winding. - s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) - flip_mask = s_edges[:, 0] > 0 - quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], - quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) - - quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) - gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2] - gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3] - if not training: - mask = (gamma_02 > gamma_13) - faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) - faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] - faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] - faces = faces.reshape(-1, 3) - else: - vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) - vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2 - vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2 - weight_sum = (gamma_02 + gamma_13) + 1e-8 - vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) - - if vd_color is not None: - color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1]) - color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2 - color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2 - color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) - vd_color = torch.cat([vd_color, color_center]) - - - vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] - vd = torch.cat([vd, vd_center]) - faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) - faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) - return vd, faces, s_edges, edge_indices, vd_color diff --git a/trellis/representations/mesh/tables.py b/trellis/representations/mesh/tables.py deleted file mode 100644 index 203d6b1c40410e57d56f5d6f86973d22eea61057..0000000000000000000000000000000000000000 --- a/trellis/representations/mesh/tables.py +++ /dev/null @@ -1,791 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. -dmc_table = [ -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] -] -num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, -2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, -1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, -1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, -2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, -3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, -2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, -1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, -1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] -check_table = [ -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 194], -[1, -1, 0, 0, 193], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 164], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 161], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 152], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 145], -[1, 0, 0, 1, 144], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 137], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 133], -[1, 0, 1, 0, 132], -[1, 1, 0, 0, 131], -[1, 1, 0, 0, 130], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 100], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 98], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 96], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 88], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 82], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 74], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 72], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 70], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 67], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 65], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 56], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 52], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 44], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 40], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 38], -[1, 0, -1, 0, 37], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 33], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 28], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 26], -[1, 0, 0, -1, 25], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 20], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 18], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 9], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 6], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0] -] -tet_table = [ -[-1, -1, -1, -1, -1, -1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, -1], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, -1], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, -1, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, -1, 2, 4, 4, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, 5, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, -1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[-1, 1, 1, 4, 4, 1], -[0, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[8, 8, 8, 8, 8, 8], -[1, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 4, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 5, 5, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[6, 6, 6, 6, 6, 6], -[6, -1, 0, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 4, -1, 6, 4, 6], -[6, 4, 0, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 2, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 1, 1, 6, -1, 6], -[6, 1, 1, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 4], -[2, 2, 2, 2, 2, 2], -[6, 1, 1, 6, 4, 6], -[6, 1, 1, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 5, 0, 5, 0, 5], -[5, 5, 5, 5, 5, 5], -[5, 5, 5, 5, 5, 5], -[0, 5, 0, 5, 0, 5], -[-1, 5, 0, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[4, 5, -1, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[4, 5, 0, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 6, 6, 6, 6, 6], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, -1, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[2, 5, 2, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 4], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 6, 2, 6, 6, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 1, 4, 1], -[0, 1, 1, 1, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 0, 0, 6, 0, 6], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[5, 5, 5, 5, 5, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 4, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[4, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[8, 8, 8, 8, 8, 8], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 1, 1, 4, 4, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 4, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[12, 12, 12, 12, 12, 12] -] \ No newline at end of file diff --git a/trellis/representations/mesh/utils_cube.py b/trellis/representations/mesh/utils_cube.py deleted file mode 100644 index 9befc1de4561d2d682873e0b752948275ddc9189..0000000000000000000000000000000000000000 --- a/trellis/representations/mesh/utils_cube.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ - 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) -cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) -cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, - 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) - -def construct_dense_grid(res, device='cuda'): - '''construct a dense grid based on resolution''' - res_v = res + 1 - vertsid = torch.arange(res_v ** 3, device=device) - coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() - cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] - cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) - verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) - return verts, cube_fx8 - - -def construct_voxel_grid(coords): - verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) - verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) - cubes = inverse_indices.reshape(-1, 8) - return verts_unique, cubes - - -def cubes_to_verts(num_verts, cubes, value, reduce='mean'): - """ - Args: - cubes [Vx8] verts index for each cube - value [Vx8xM] value to be scattered - Operation: - reduced[cubes[i][j]][k] += value[i][k] - """ - M = value.shape[2] # number of channels - reduced = torch.zeros(num_verts, M, device=cubes.device) - return torch.scatter_reduce(reduced, 0, - cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), - value.flatten(0, 1), reduce=reduce, include_self=False) - -def sparse_cube2verts(coords, feats, training=True): - new_coords, cubes = construct_voxel_grid(coords) - new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) - if training: - con_loss = torch.mean((feats - new_feats[cubes]) ** 2) - else: - con_loss = 0.0 - return new_coords, new_feats, con_loss - - -def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): - F = feats.shape[-1] - dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device) - if sdf_init: - dense_attrs[..., 0] = 1 # initial outside sdf value - dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats - return dense_attrs.reshape(-1, F) - - -def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): - return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) - \ No newline at end of file diff --git a/trellis/representations/octree/__init__.py b/trellis/representations/octree/__init__.py deleted file mode 100644 index f66a39a5a7498e2e99fe9d94d663796b3bc157b5..0000000000000000000000000000000000000000 --- a/trellis/representations/octree/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .octree_dfs import DfsOctree \ No newline at end of file diff --git a/trellis/representations/octree/octree_dfs.py b/trellis/representations/octree/octree_dfs.py deleted file mode 100644 index 710f18b73c6a68acbc7bfb470efef7632fbbd6ed..0000000000000000000000000000000000000000 --- a/trellis/representations/octree/octree_dfs.py +++ /dev/null @@ -1,362 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -DEFAULT_TRIVEC_CONFIG = { - 'dim': 8, - 'rank': 8, -} - -DEFAULT_VOXEL_CONFIG = { - 'solid': False, -} - -DEFAULT_DECOPOLY_CONFIG = { - 'degree': 8, - 'rank': 16, -} - - -class DfsOctree: - """ - Sparse Voxel Octree (SVO) implementation for PyTorch. - Using Depth-First Search (DFS) order to store the octree. - DFS order suits rendering and ray tracing. - - The structure and data are separatedly stored. - Structure is stored as a continuous array, each element is a 3*32 bits descriptor. - |-----------------------------------------| - | 0:3 bits | 4:31 bits | - | leaf num | unused | - |-----------------------------------------| - | 0:31 bits | - | child ptr | - |-----------------------------------------| - | 0:31 bits | - | data ptr | - |-----------------------------------------| - Each element represents a non-leaf node in the octree. - The valid mask is used to indicate whether the children are valid. - The leaf mask is used to indicate whether the children are leaf nodes. - The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr. - The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr. - - There are also auxiliary arrays to store the additional structural information to facilitate parallel processing. - - Position: the position of the octree nodes. - - Depth: the depth of the octree nodes. - - Args: - depth (int): the depth of the octree. - """ - - def __init__( - self, - depth, - aabb=[0,0,0,1,1,1], - sh_degree=2, - primitive='voxel', - primitive_config={}, - device='cuda', - ): - self.max_depth = depth - self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) - self.device = device - self.sh_degree = sh_degree - self.active_sh_degree = sh_degree - self.primitive = primitive - self.primitive_config = primitive_config - - self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device) - self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device) - self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device) - self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device) - self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device) - self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device) - self.depth[:, 0] = 1 - - self.data = ['position', 'depth'] - self.param_names = [] - - if primitive == 'voxel': - self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) - self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) - self.data += ['features_dc', 'features_ac'] - self.param_names += ['features_dc', 'features_ac'] - if not primitive_config.get('solid', False): - self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device) - self.data.append('density') - self.param_names.append('density') - elif primitive == 'gaussian': - self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) - self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) - self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device) - self.data += ['features_dc', 'features_ac', 'opacity'] - self.param_names += ['features_dc', 'features_ac', 'opacity'] - elif primitive == 'trivec': - self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device) - self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) - self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) - self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) - self.density_shift = 0 - self.data += ['trivec', 'density', 'features_dc', 'features_ac'] - self.param_names += ['trivec', 'density', 'features_dc', 'features_ac'] - elif primitive == 'decoupoly': - self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device) - self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device) - self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) - self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) - self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) - self.density_shift = 0 - self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] - self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] - - self.setup_functions() - - def setup_functions(self): - self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x) - self.opacity_activation = lambda x: torch.sigmoid(x - 6) - self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6 - self.color_activation = lambda x: torch.sigmoid(x) - - @property - def num_non_leaf_nodes(self): - return self.structure.shape[0] - - @property - def num_leaf_nodes(self): - return self.depth.shape[0] - - @property - def cur_depth(self): - return self.depth.max().item() - - @property - def occupancy(self): - return self.num_leaf_nodes / 8 ** self.cur_depth - - @property - def get_xyz(self): - return self.position - - @property - def get_depth(self): - return self.depth - - @property - def get_density(self): - if self.primitive == 'voxel' and self.voxel_config['solid']: - return torch.full((self.position.shape[0], 1), 1000, dtype=torch.float32, device=self.device) - return self.density_activation(self.density) - - @property - def get_opacity(self): - return self.opacity_activation(self.density) - - @property - def get_trivec(self): - return self.trivec - - @property - def get_decoupoly(self): - return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g - - @property - def get_color(self): - return self.color_activation(self.colors) - - @property - def get_features(self): - if self.sh_degree == 0: - return self.features_dc - return torch.cat([self.features_dc, self.features_ac], dim=-2) - - def state_dict(self): - ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'trivec_config': self.trivec_config, 'voxel_config': self.voxel_config, 'primitive': self.primitive} - if hasattr(self, 'density_shift'): - ret['density_shift'] = self.density_shift - for data in set(self.data + self.param_names): - if not isinstance(getattr(self, data), nn.Module): - ret[data] = getattr(self, data) - else: - ret[data] = getattr(self, data).state_dict() - return ret - - def load_state_dict(self, state_dict): - keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth'])) - for key in keys: - if key not in state_dict: - print(f"Warning: key {key} not found in the state_dict.") - continue - try: - if not isinstance(getattr(self, key), nn.Module): - setattr(self, key, state_dict[key]) - else: - getattr(self, key).load_state_dict(state_dict[key]) - except Exception as e: - print(e) - raise ValueError(f"Error loading key {key}.") - - def gather_from_leaf_children(self, data): - """ - Gather the data from the leaf children. - - Args: - data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. - """ - leaf_cnt = self.structure[:, 0] - leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)] - ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device) - for i in range(8): - if leaf_cnt_masks[i].sum() == 0: - continue - start = self.structure[leaf_cnt_masks[i], 2] - for j in range(i+1): - ret[leaf_cnt_masks[i]] += data[start + j] - return ret - - def gather_from_non_leaf_children(self, data): - """ - Gather the data from the non-leaf children. - - Args: - data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. - """ - non_leaf_cnt = 8 - self.structure[:, 0] - non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)] - ret = torch.zeros_like(data, device=self.device) - for i in range(8): - if non_leaf_cnt_masks[i].sum() == 0: - continue - start = self.structure[non_leaf_cnt_masks[i], 1] - for j in range(i+1): - ret[non_leaf_cnt_masks[i]] += data[start + j] - return ret - - def structure_control(self, mask): - """ - Control the structure of the octree. - - Args: - mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep. - """ - # Dont subdivide when the depth is the maximum. - mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0) - # Dont merge when the depth is the minimum. - mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0) - - # Gather control mask - structre_ctrl = self.gather_from_leaf_children(mask) - structre_ctrl[structre_ctrl==-8] = -1 - - new_leaf_num = self.structure[:, 0].clone() - # Modify the leaf num. - structre_valid = structre_ctrl >= 0 - new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes. - structre_delete = structre_ctrl < 0 - merged_nodes = self.gather_from_non_leaf_children(structre_delete.int()) - new_leaf_num += merged_nodes # Delete the merged nodes. - - # Update the structure array to allocate new nodes. - mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device) - mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes. - mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. - new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) - new_structure_length = new_structre_idx[-1].item() - new_structre_idx = new_structre_idx[:-1] - new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device) - new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid] - - # Initialize the new nodes. - new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device) - new_node_mask[new_structre_idx[structre_valid]] = False - new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. - new_node_num = new_node_mask.sum().item() - - # Rebuild child ptr. - non_leaf_cnt = 8 - new_structure[:, 0] - new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]]) - new_structure[:, 1] = new_child_ptr + 1 - - # Rebuild data ptr with old data. - leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device) - leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0]) - old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) - - # Update the data array - subdivide_mask = mask == 1 - merge_mask = mask == -1 - data_valid = ~(subdivide_mask | merge_mask) - mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device) - mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes - mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes - mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes - mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes - new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) - new_data_length = new_data_idx[-1].item() - new_data_idx = new_data_idx[:-1] - new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data} - for data in self.data: - new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid] - - # Rebuild data ptr - leaf_cnt = new_structure[:, 0] - new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) - new_structure[:, 2] = new_data_ptr - - # Initialize the new data array - ## For subdivide nodes - if subdivide_mask.sum() > 0: - subdivide_data_ptr = new_structure[new_node_mask, 2] - for data in self.data: - for i in range(8): - if data == 'position': - offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5 - scale = 2 ** (-1.0 - self.depth[subdivide_mask]) - new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale - elif data == 'depth': - new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1 - elif data == 'opacity': - new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask]))) - elif data == 'trivec': - offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5 - coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1) - axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1) - coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1 - new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True) - else: - new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask] - ## For merge nodes - if merge_mask.sum() > 0: - merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device) - merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]]) - for i in range(8): - merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i - old_merge_data_ptr = self.structure[structre_delete, 2] - for data in self.data: - if data == 'position': - scale = 2 ** (1.0 - self.depth[old_merge_data_ptr]) - new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5 - elif data == 'depth': - new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1 - elif data == 'opacity': - new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2) - elif data == 'trivec': - new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr] - else: - new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr] - - # Update the structure and data array - self.structure = new_structure - for data in self.data: - setattr(self, data, new_data[data]) - - # Save data array control temp variables - self.data_rearrange_buffer = { - 'subdivide_mask': subdivide_mask, - 'merge_mask': merge_mask, - 'data_valid': data_valid, - 'new_data_idx': new_data_idx, - 'new_data_length': new_data_length, - 'new_data': new_data - } diff --git a/trellis/representations/radiance_field/__init__.py b/trellis/representations/radiance_field/__init__.py deleted file mode 100644 index b72a1b7e76b509ee5a5e6979858eb17b4158a151..0000000000000000000000000000000000000000 --- a/trellis/representations/radiance_field/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .strivec import Strivec \ No newline at end of file diff --git a/trellis/representations/radiance_field/strivec.py b/trellis/representations/radiance_field/strivec.py deleted file mode 100644 index f2dc78cdd5a08e994daa57247a3f01f3d43986f9..0000000000000000000000000000000000000000 --- a/trellis/representations/radiance_field/strivec.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from ..octree import DfsOctree as Octree - - -class Strivec(Octree): - def __init__( - self, - resolution: int, - aabb: list, - sh_degree: int = 0, - rank: int = 8, - dim: int = 8, - device: str = "cuda", - ): - assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2" - self.resolution = resolution - depth = int(np.round(np.log2(resolution))) - super().__init__( - depth=depth, - aabb=aabb, - sh_degree=sh_degree, - primitive="trivec", - primitive_config={"rank": rank, "dim": dim}, - device=device, - ) diff --git a/trellis/utils/__init__.py b/trellis/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/trellis/utils/general_utils.py b/trellis/utils/general_utils.py deleted file mode 100644 index b91e6d075dbb2c02438b5f345ec3deb164fa7a8f..0000000000000000000000000000000000000000 --- a/trellis/utils/general_utils.py +++ /dev/null @@ -1,187 +0,0 @@ -import numpy as np -import cv2 -import torch - - -# Dictionary utils -def _dict_merge(dicta, dictb, prefix=''): - """ - Merge two dictionaries. - """ - assert isinstance(dicta, dict), 'input must be a dictionary' - assert isinstance(dictb, dict), 'input must be a dictionary' - dict_ = {} - all_keys = set(dicta.keys()).union(set(dictb.keys())) - for key in all_keys: - if key in dicta.keys() and key in dictb.keys(): - if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): - dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') - else: - raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') - elif key in dicta.keys(): - dict_[key] = dicta[key] - else: - dict_[key] = dictb[key] - return dict_ - - -def dict_merge(dicta, dictb): - """ - Merge two dictionaries. - """ - return _dict_merge(dicta, dictb, prefix='') - - -def dict_foreach(dic, func, special_func={}): - """ - Recursively apply a function to all non-dictionary leaf values in a dictionary. - """ - assert isinstance(dic, dict), 'input must be a dictionary' - for key in dic.keys(): - if isinstance(dic[key], dict): - dic[key] = dict_foreach(dic[key], func) - else: - if key in special_func.keys(): - dic[key] = special_func[key](dic[key]) - else: - dic[key] = func(dic[key]) - return dic - - -def dict_reduce(dicts, func, special_func={}): - """ - Reduce a list of dictionaries. Leaf values must be scalars. - """ - assert isinstance(dicts, list), 'input must be a list of dictionaries' - assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' - assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' - all_keys = set([key for dict_ in dicts for key in dict_.keys()]) - reduced_dict = {} - for key in all_keys: - vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] - if isinstance(vlist[0], dict): - reduced_dict[key] = dict_reduce(vlist, func, special_func) - else: - if key in special_func.keys(): - reduced_dict[key] = special_func[key](vlist) - else: - reduced_dict[key] = func(vlist) - return reduced_dict - - -def dict_any(dic, func): - """ - Recursively apply a function to all non-dictionary leaf values in a dictionary. - """ - assert isinstance(dic, dict), 'input must be a dictionary' - for key in dic.keys(): - if isinstance(dic[key], dict): - if dict_any(dic[key], func): - return True - else: - if func(dic[key]): - return True - return False - - -def dict_all(dic, func): - """ - Recursively apply a function to all non-dictionary leaf values in a dictionary. - """ - assert isinstance(dic, dict), 'input must be a dictionary' - for key in dic.keys(): - if isinstance(dic[key], dict): - if not dict_all(dic[key], func): - return False - else: - if not func(dic[key]): - return False - return True - - -def dict_flatten(dic, sep='.'): - """ - Flatten a nested dictionary into a dictionary with no nested dictionaries. - """ - assert isinstance(dic, dict), 'input must be a dictionary' - flat_dict = {} - for key in dic.keys(): - if isinstance(dic[key], dict): - sub_dict = dict_flatten(dic[key], sep=sep) - for sub_key in sub_dict.keys(): - flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] - else: - flat_dict[key] = dic[key] - return flat_dict - - -def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): - num_images = len(images) - if nrow is None and ncol is None: - if aspect_ratio is not None: - nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) - else: - nrow = int(np.sqrt(num_images)) - ncol = (num_images + nrow - 1) // nrow - elif nrow is None and ncol is not None: - nrow = (num_images + ncol - 1) // ncol - elif nrow is not None and ncol is None: - ncol = (num_images + nrow - 1) // nrow - else: - assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' - - grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) - for i, img in enumerate(images): - row = i // ncol - col = i % ncol - grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img - return grid - - -def notes_on_image(img, notes=None): - img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - if notes is not None: - img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - return img - - -def save_image_with_notes(img, path, notes=None): - """ - Save an image with notes. - """ - if isinstance(img, torch.Tensor): - img = img.cpu().numpy().transpose(1, 2, 0) - if img.dtype == np.float32 or img.dtype == np.float64: - img = np.clip(img * 255, 0, 255).astype(np.uint8) - img = notes_on_image(img, notes) - cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) - - -# debug utils - -def atol(x, y): - """ - Absolute tolerance. - """ - return torch.abs(x - y) - - -def rtol(x, y): - """ - Relative tolerance. - """ - return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) - - -# print utils -def indent(s, n=4): - """ - Indent a string. - """ - lines = s.split('\n') - for i in range(1, len(lines)): - lines[i] = ' ' * n + lines[i] - return '\n'.join(lines) - diff --git a/trellis/utils/postprocessing_utils.py b/trellis/utils/postprocessing_utils.py deleted file mode 100644 index e1c35fe2f7138a18db2767c37646ba77358eba1b..0000000000000000000000000000000000000000 --- a/trellis/utils/postprocessing_utils.py +++ /dev/null @@ -1,587 +0,0 @@ -from typing import * -import numpy as np -import torch -import utils3d -import nvdiffrast.torch as dr -from tqdm import tqdm -import trimesh -import trimesh.visual -import xatlas -import pyvista as pv -from pymeshfix import _meshfix -import igraph -import cv2 -from PIL import Image -from .random_utils import sphere_hammersley_sequence -from .render_utils import render_multiview -from ..renderers import GaussianRenderer -from ..representations import Strivec, Gaussian, MeshExtractResult - - -@torch.no_grad() -def _fill_holes( - verts, - faces, - max_hole_size=0.04, - max_hole_nbe=32, - resolution=128, - num_views=500, - debug=False, - verbose=False -): - """ - Rasterize a mesh from multiple views and remove invisible faces. - Also includes postprocessing to: - 1. Remove connected components that are have low visibility. - 2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole. - - Args: - verts (torch.Tensor): Vertices of the mesh. Shape (V, 3). - faces (torch.Tensor): Faces of the mesh. Shape (F, 3). - max_hole_size (float): Maximum area of a hole to fill. - resolution (int): Resolution of the rasterization. - num_views (int): Number of views to rasterize the mesh. - verbose (bool): Whether to print progress. - """ - # Construct cameras - yaws = [] - pitchs = [] - for i in range(num_views): - y, p = sphere_hammersley_sequence(i, num_views) - yaws.append(y) - pitchs.append(p) - yaws = torch.tensor(yaws).cuda() - pitchs = torch.tensor(pitchs).cuda() - radius = 2.0 - fov = torch.deg2rad(torch.tensor(40)).cuda() - projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) - views = [] - for (yaw, pitch) in zip(yaws, pitchs): - orig = torch.tensor([ - torch.sin(yaw) * torch.cos(pitch), - torch.cos(yaw) * torch.cos(pitch), - torch.sin(pitch), - ]).cuda().float() * radius - view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) - views.append(view) - views = torch.stack(views, dim=0) - - # Rasterize - visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device) - rastctx = utils3d.torch.RastContext(backend='cuda') - for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'): - view = views[i] - buffers = utils3d.torch.rasterize_triangle_faces( - rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection - ) - face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1 - face_id = torch.unique(face_id).long() - visblity[face_id] += 1 - visblity = visblity.float() / num_views - - # Mincut - ## construct outer faces - edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) - boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) - connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge) - outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device) - for i in range(len(connected_components)): - outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) - outer_face_indices = outer_face_indices.nonzero().reshape(-1) - - ## construct inner faces - inner_face_indices = torch.nonzero(visblity == 0).reshape(-1) - if verbose: - tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces') - if inner_face_indices.shape[0] == 0: - return verts, faces - - ## Construct dual graph (faces as nodes, edges as edges) - dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge) - dual_edge2edge = edges[dual_edge2edge] - dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1) - if verbose: - tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges') - - ## solve mincut problem - ### construct main graph - g = igraph.Graph() - g.add_vertices(faces.shape[0]) - g.add_edges(dual_edges.cpu().numpy()) - g.es['weight'] = dual_edges_weights.cpu().numpy() - - ### source and target - g.add_vertex('s') - g.add_vertex('t') - - ### connect invisible faces to source - g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) - - ### connect outer faces to target - g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) - - ### solve mincut - cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist()) - remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device) - if verbose: - tqdm.write(f'Mincut solved, start checking the cut') - - ### check if the cut is valid with each connected component - to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices]) - if debug: - tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}') - valid_remove_cc = [] - cutting_edges = [] - for cc in to_remove_cc: - #### check if the connected component has low visibility - visblity_median = visblity[remove_face_indices[cc]].median() - if debug: - tqdm.write(f'visblity_median: {visblity_median}') - if visblity_median > 0.25: - continue - - #### check if the cuting loop is small enough - cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True) - cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] - cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)] - if len(cc_new_boundary_edge_indices) > 0: - cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices]) - cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc] - cc_new_boundary_edges_cc_area = [] - for i, edge_cc in enumerate(cc_new_boundary_edge_cc): - _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i] - _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i] - cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5) - if debug: - cutting_edges.append(cc_new_boundary_edge_indices) - tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}') - if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]): - continue - - valid_remove_cc.append(cc) - - if debug: - face_v = verts[faces].mean(dim=1).cpu().numpy() - vis_dual_edges = dual_edges.cpu().numpy() - vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8) - vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255] - vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0] - vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255] - if len(valid_remove_cc) > 0: - vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0] - utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors) - - vis_verts = verts.cpu().numpy() - vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy() - utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges) - - - if len(valid_remove_cc) > 0: - remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)] - mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device) - mask[remove_face_indices] = 0 - faces = faces[mask] - faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts) - if verbose: - tqdm.write(f'Removed {(~mask).sum()} faces by mincut') - else: - if verbose: - tqdm.write(f'Removed 0 faces by mincut') - - mesh = _meshfix.PyTMesh() - mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy()) - mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) - verts, faces = mesh.return_arrays() - verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32) - - return verts, faces - - -def postprocess_mesh( - vertices: np.array, - faces: np.array, - simplify: bool = True, - simplify_ratio: float = 0.9, - fill_holes: bool = True, - fill_holes_max_hole_size: float = 0.04, - fill_holes_max_hole_nbe: int = 32, - fill_holes_resolution: int = 1024, - fill_holes_num_views: int = 1000, - debug: bool = False, - verbose: bool = False, -): - """ - Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces. - - Args: - vertices (np.array): Vertices of the mesh. Shape (V, 3). - faces (np.array): Faces of the mesh. Shape (F, 3). - simplify (bool): Whether to simplify the mesh, using quadric edge collapse. - simplify_ratio (float): Ratio of faces to keep after simplification. - fill_holes (bool): Whether to fill holes in the mesh. - fill_holes_max_hole_size (float): Maximum area of a hole to fill. - fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill. - fill_holes_resolution (int): Resolution of the rasterization. - fill_holes_num_views (int): Number of views to rasterize the mesh. - verbose (bool): Whether to print progress. - """ - - if verbose: - tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces') - - # Simplify - if simplify and simplify_ratio > 0: - mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)) - mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) - vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] - if verbose: - tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces') - - # Remove invisible faces - if fill_holes: - vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda() - vertices, faces = _fill_holes( - vertices, faces, - max_hole_size=fill_holes_max_hole_size, - max_hole_nbe=fill_holes_max_hole_nbe, - resolution=fill_holes_resolution, - num_views=fill_holes_num_views, - debug=debug, - verbose=verbose, - ) - vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy() - if verbose: - tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces') - - return vertices, faces - - -def parametrize_mesh(vertices: np.array, faces: np.array): - """ - Parametrize a mesh to a texture space, using xatlas. - - Args: - vertices (np.array): Vertices of the mesh. Shape (V, 3). - faces (np.array): Faces of the mesh. Shape (F, 3). - """ - - vmapping, indices, uvs = xatlas.parametrize(vertices, faces) - - vertices = vertices[vmapping] - faces = indices - - return vertices, faces, uvs - - -def bake_texture( - vertices: np.array, - faces: np.array, - uvs: np.array, - observations: List[np.array], - masks: List[np.array], - extrinsics: List[np.array], - intrinsics: List[np.array], - texture_size: int = 2048, - near: float = 0.1, - far: float = 10.0, - mode: Literal['fast', 'opt'] = 'opt', - lambda_tv: float = 1e-2, - verbose: bool = False, -): - """ - Bake texture to a mesh from multiple observations. - - Args: - vertices (np.array): Vertices of the mesh. Shape (V, 3). - faces (np.array): Faces of the mesh. Shape (F, 3). - uvs (np.array): UV coordinates of the mesh. Shape (V, 2). - observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3). - masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W). - extrinsics (List[np.array]): List of extrinsics. Shape (4, 4). - intrinsics (List[np.array]): List of intrinsics. Shape (3, 3). - texture_size (int): Size of the texture. - near (float): Near plane of the camera. - far (float): Far plane of the camera. - mode (Literal['fast', 'opt']): Mode of texture baking. - lambda_tv (float): Weight of total variation loss in optimization. - verbose (bool): Whether to print progress. - """ - vertices = torch.tensor(vertices).cuda() - faces = torch.tensor(faces.astype(np.int32)).cuda() - uvs = torch.tensor(uvs).cuda() - observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations] - masks = [torch.tensor(m>0).bool().cuda() for m in masks] - views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics] - projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics] - - if mode == 'fast': - texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda() - texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda() - rastctx = utils3d.torch.RastContext(backend='cuda') - for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'): - with torch.no_grad(): - rast = utils3d.torch.rasterize_triangle_faces( - rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection - ) - uv_map = rast['uv'][0].detach().flip(0) - mask = rast['mask'][0].detach().bool() & masks[0] - - # nearest neighbor interpolation - uv_map = (uv_map * texture_size).floor().long() - obs = observation[mask] - uv_map = uv_map[mask] - idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size - texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs) - texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device)) - - mask = texture_weights > 0 - texture[mask] /= texture_weights[mask][:, None] - texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8) - - # inpaint - mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size) - texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) - - elif mode == 'opt': - rastctx = utils3d.torch.RastContext(backend='cuda') - observations = [observations.flip(0) for observations in observations] - masks = [m.flip(0) for m in masks] - _uv = [] - _uv_dr = [] - for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'): - with torch.no_grad(): - rast = utils3d.torch.rasterize_triangle_faces( - rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection - ) - _uv.append(rast['uv'].detach()) - _uv_dr.append(rast['uv_dr'].detach()) - - texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()) - optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2) - - def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): - return start_lr * (end_lr / start_lr) ** (step / total_steps) - - def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): - return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) - - def tv_loss(texture): - return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \ - torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :]) - - total_steps = 2500 - with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar: - for step in range(total_steps): - optimizer.zero_grad() - selected = np.random.randint(0, len(views)) - uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected] - render = dr.texture(texture, uv, uv_dr)[0] - loss = torch.nn.functional.l1_loss(render[mask], observation[mask]) - if lambda_tv > 0: - loss += lambda_tv * tv_loss(texture) - loss.backward() - optimizer.step() - # annealing - optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5) - pbar.set_postfix({'loss': loss.item()}) - pbar.update() - texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) - mask = 1 - utils3d.torch.rasterize_triangle_faces( - rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size - )['mask'][0].detach().cpu().numpy().astype(np.uint8) - texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) - else: - raise ValueError(f'Unknown mode: {mode}') - - return texture - - -def to_glb( - app_rep: Union[Strivec, Gaussian], - mesh: MeshExtractResult, - simplify: float = 0.95, - fill_holes: bool = True, - fill_holes_max_size: float = 0.04, - texture_size: int = 1024, - debug: bool = False, - verbose: bool = True, -) -> trimesh.Trimesh: - """ - Convert a generated asset to a glb file. - - Args: - app_rep (Union[Strivec, Gaussian]): Appearance representation. - mesh (MeshExtractResult): Extracted mesh. - simplify (float): Ratio of faces to remove in simplification. - fill_holes (bool): Whether to fill holes in the mesh. - fill_holes_max_size (float): Maximum area of a hole to fill. - texture_size (int): Size of the texture. - debug (bool): Whether to print debug information. - verbose (bool): Whether to print progress. - """ - vertices = mesh.vertices.cpu().numpy() - faces = mesh.faces.cpu().numpy() - - # mesh postprocess - vertices, faces = postprocess_mesh( - vertices, faces, - simplify=simplify > 0, - simplify_ratio=simplify, - fill_holes=fill_holes, - fill_holes_max_hole_size=fill_holes_max_size, - fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)), - fill_holes_resolution=1024, - fill_holes_num_views=1000, - debug=debug, - verbose=verbose, - ) - - # parametrize mesh - vertices, faces, uvs = parametrize_mesh(vertices, faces) - - # bake texture - observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100) - masks = [np.any(observation > 0, axis=-1) for observation in observations] - extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] - intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] - texture = bake_texture( - vertices, faces, uvs, - observations, masks, extrinsics, intrinsics, - texture_size=texture_size, mode='opt', - lambda_tv=0.01, - verbose=verbose - ) - texture = Image.fromarray(texture) - - # rotate mesh (from z-up to y-up) - vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) - material = trimesh.visual.material.PBRMaterial( - roughnessFactor=1.0, - baseColorTexture=texture, - baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8) - ) - mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)) - return mesh - - -def simplify_gs( - gs: Gaussian, - simplify: float = 0.95, - verbose: bool = True, -): - """ - Simplify 3D Gaussians - NOTE: this function is not used in the current implementation for the unsatisfactory performance. - - Args: - gs (Gaussian): 3D Gaussian. - simplify (float): Ratio of Gaussians to remove in simplification. - """ - if simplify <= 0: - return gs - - # simplify - observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100) - observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations] - - # Following https://arxiv.org/pdf/2411.06019 - renderer = GaussianRenderer({ - "resolution": 1024, - "near": 0.8, - "far": 1.6, - "ssaa": 1, - "bg_color": (0,0,0), - }) - new_gs = Gaussian(**gs.init_params) - new_gs._features_dc = gs._features_dc.clone() - new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None - new_gs._opacity = torch.nn.Parameter(gs._opacity.clone()) - new_gs._rotation = torch.nn.Parameter(gs._rotation.clone()) - new_gs._scaling = torch.nn.Parameter(gs._scaling.clone()) - new_gs._xyz = torch.nn.Parameter(gs._xyz.clone()) - - start_lr = [1e-4, 1e-3, 5e-3, 0.025] - end_lr = [1e-6, 1e-5, 5e-5, 0.00025] - optimizer = torch.optim.Adam([ - {"params": new_gs._xyz, "lr": start_lr[0]}, - {"params": new_gs._rotation, "lr": start_lr[1]}, - {"params": new_gs._scaling, "lr": start_lr[2]}, - {"params": new_gs._opacity, "lr": start_lr[3]}, - ], lr=start_lr[0]) - - def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): - return start_lr * (end_lr / start_lr) ** (step / total_steps) - - def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): - return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) - - _zeta = new_gs.get_opacity.clone().detach().squeeze() - _lambda = torch.zeros_like(_zeta) - _delta = 1e-7 - _interval = 10 - num_target = int((1 - simplify) * _zeta.shape[0]) - - with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar: - for i in range(2500): - # prune - if i % 100 == 0: - mask = new_gs.get_opacity.squeeze() > 0.05 - mask = torch.nonzero(mask).squeeze() - new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask]) - new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask]) - new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask]) - new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask]) - new_gs._features_dc = new_gs._features_dc[mask] - new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None - _zeta = _zeta[mask] - _lambda = _lambda[mask] - # update optimizer state - for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]): - stored_state = optimizer.state[param_group['params'][0]] - if 'exp_avg' in stored_state: - stored_state['exp_avg'] = stored_state['exp_avg'][mask] - stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask] - del optimizer.state[param_group['params'][0]] - param_group['params'][0] = new_param - optimizer.state[param_group['params'][0]] = stored_state - - opacity = new_gs.get_opacity.squeeze() - - # sparisfy - if i % _interval == 0: - _zeta = _lambda + opacity.detach() - if opacity.shape[0] > num_target: - index = _zeta.topk(num_target)[1] - _m = torch.ones_like(_zeta, dtype=torch.bool) - _m[index] = 0 - _zeta[_m] = 0 - _lambda = _lambda + opacity.detach() - _zeta - - # sample a random view - view_idx = np.random.randint(len(observations)) - observation = observations[view_idx] - extrinsic = extrinsics[view_idx] - intrinsic = intrinsics[view_idx] - - color = renderer.render(new_gs, extrinsic, intrinsic)['color'] - rgb_loss = torch.nn.functional.l1_loss(color, observation) - loss = rgb_loss + \ - _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2)) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # update lr - for j in range(len(optimizer.param_groups)): - optimizer.param_groups[j]['lr'] = cosine_anealing(optimizer, i, 2500, start_lr[j], end_lr[j]) - - pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()}) - pbar.update() - - new_gs._xyz = new_gs._xyz.data - new_gs._rotation = new_gs._rotation.data - new_gs._scaling = new_gs._scaling.data - new_gs._opacity = new_gs._opacity.data - - return new_gs diff --git a/trellis/utils/random_utils.py b/trellis/utils/random_utils.py deleted file mode 100644 index 420c4146cbb4c2973f2cc2e69f376a3836e65eeb..0000000000000000000000000000000000000000 --- a/trellis/utils/random_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -import numpy as np - -PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] - -def radical_inverse(base, n): - val = 0 - inv_base = 1.0 / base - inv_base_n = inv_base - while n > 0: - digit = n % base - val += digit * inv_base_n - n //= base - inv_base_n *= inv_base - return val - -def halton_sequence(dim, n): - return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] - -def hammersley_sequence(dim, n, num_samples): - return [n / num_samples] + halton_sequence(dim - 1, n) - -def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): - u, v = hammersley_sequence(2, n, num_samples) - u += offset[0] / num_samples - v += offset[1] - if remap: - u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 - theta = np.arccos(1 - 2 * u) - np.pi / 2 - phi = v * 2 * np.pi - return [phi, theta] \ No newline at end of file diff --git a/trellis/utils/render_utils.py b/trellis/utils/render_utils.py deleted file mode 100644 index 54a33ce79c6d3e1e358ab1650ea14cfe1d50ba91..0000000000000000000000000000000000000000 --- a/trellis/utils/render_utils.py +++ /dev/null @@ -1,116 +0,0 @@ -import torch -import numpy as np -from tqdm import tqdm -import utils3d -from PIL import Image - -from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer -from ..representations import Octree, Gaussian, MeshExtractResult -from ..modules import sparse as sp -from .random_utils import sphere_hammersley_sequence - - -def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): - is_list = isinstance(yaws, list) - if not is_list: - yaws = [yaws] - pitchs = [pitchs] - if not isinstance(rs, list): - rs = [rs] * len(yaws) - if not isinstance(fovs, list): - fovs = [fovs] * len(yaws) - extrinsics = [] - intrinsics = [] - for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): - fov = torch.deg2rad(torch.tensor(float(fov))).cuda() - yaw = torch.tensor(float(yaw)).cuda() - pitch = torch.tensor(float(pitch)).cuda() - orig = torch.tensor([ - torch.sin(yaw) * torch.cos(pitch), - torch.cos(yaw) * torch.cos(pitch), - torch.sin(pitch), - ]).cuda() * r - extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) - intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) - extrinsics.append(extr) - intrinsics.append(intr) - if not is_list: - extrinsics = extrinsics[0] - intrinsics = intrinsics[0] - return extrinsics, intrinsics - - -def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs): - if isinstance(sample, Octree): - renderer = OctreeRenderer() - renderer.rendering_options.resolution = options.get('resolution', 512) - renderer.rendering_options.near = options.get('near', 0.8) - renderer.rendering_options.far = options.get('far', 1.6) - renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) - renderer.rendering_options.ssaa = options.get('ssaa', 4) - renderer.pipe.primitive = sample.primitive - elif isinstance(sample, Gaussian): - renderer = GaussianRenderer() - renderer.rendering_options.resolution = options.get('resolution', 512) - renderer.rendering_options.near = options.get('near', 0.8) - renderer.rendering_options.far = options.get('far', 1.6) - renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) - renderer.rendering_options.ssaa = options.get('ssaa', 1) - renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1) - renderer.pipe.use_mip_gaussian = True - elif isinstance(sample, MeshExtractResult): - renderer = MeshRenderer() - renderer.rendering_options.resolution = options.get('resolution', 512) - renderer.rendering_options.near = options.get('near', 1) - renderer.rendering_options.far = options.get('far', 100) - renderer.rendering_options.ssaa = options.get('ssaa', 4) - else: - raise ValueError(f'Unsupported sample type: {type(sample)}') - - rets = {} - for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose): - if not isinstance(sample, MeshExtractResult): - res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite) - if 'color' not in rets: rets['color'] = [] - if 'depth' not in rets: rets['depth'] = [] - rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) - if 'percent_depth' in res: - rets['depth'].append(res['percent_depth'].detach().cpu().numpy()) - elif 'depth' in res: - rets['depth'].append(res['depth'].detach().cpu().numpy()) - else: - rets['depth'].append(None) - else: - res = renderer.render(sample, extr, intr) - if 'normal' not in rets: rets['normal'] = [] - rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) - return rets - - -def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs): - yaws = torch.linspace(0, 2 * 3.1415, num_frames) - pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) - yaws = yaws.tolist() - pitch = pitch.tolist() - extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) - return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) - - -def render_multiview(sample, resolution=512, nviews=30): - r = 2 - fov = 40 - cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] - yaws = [cam[0] for cam in cams] - pitchs = [cam[1] for cam in cams] - extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) - res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) - return res['color'], extrinsics, intrinsics - - -def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs): - yaw = [0, np.pi/2, np.pi, 3*np.pi/2] - yaw_offset = offset[0] - yaw = [y + yaw_offset for y in yaw] - pitch = [offset[1] for _ in range(4)] - extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) - return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)