from typing import *
import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPTextModel, AutoTokenizer
# import open3d as o3d
from .base import Pipeline
from . import samplers
from ..modules import sparse as sp


class TrellisTextTo3DPipeline(Pipeline):
    """
    Pipeline for inferring Trellis text-to-3D models.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured latent
            (its 'mean' and 'std' entries are used to de-normalize sampled latents).
        text_cond_model (str): The name of the text conditioning model.
    """

    # Output formats decoded when the caller does not request specific ones.
    # A tuple (immutable) so it can never be modified through a default argument.
    _DEFAULT_FORMATS = ('mesh', 'gaussian', 'radiance_field')

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        sparse_structure_sampler: samplers.Sampler = None,
        slat_sampler: samplers.Sampler = None,
        slat_normalization: dict = None,
        text_cond_model: str = None,
    ):
        # A bare instance (models=None) is what from_pretrained creates before
        # replacing its __dict__ with the loaded pipeline's state.
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.slat_sampler = slat_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler_params = {}
        self.slat_normalization = slat_normalization
        self._init_text_cond_model(text_cond_model)

    @staticmethod
    def from_pretrained(path: str) -> "TrellisTextTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.

        Returns:
            TrellisTextTo3DPipeline: The loaded pipeline, with its samplers, sampler
            parameters, latent normalization and text conditioning model configured
            from the pretrained arguments.
        """
        # Let the base class load the models and pretrained args, then move that
        # state onto an instance of this class.
        pipeline = super(TrellisTextTo3DPipeline, TrellisTextTo3DPipeline).from_pretrained(path)
        new_pipeline = TrellisTextTo3DPipeline()
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        # Samplers are looked up by class name in the `samplers` module.
        ss_cfg = args['sparse_structure_sampler']
        new_pipeline.sparse_structure_sampler = getattr(samplers, ss_cfg['name'])(**ss_cfg['args'])
        new_pipeline.sparse_structure_sampler_params = ss_cfg['params']

        slat_cfg = args['slat_sampler']
        new_pipeline.slat_sampler = getattr(samplers, slat_cfg['name'])(**slat_cfg['args'])
        new_pipeline.slat_sampler_params = slat_cfg['params']

        new_pipeline.slat_normalization = args['slat_normalization']
        new_pipeline._init_text_cond_model(args['text_cond_model'])
        return new_pipeline

    def _init_text_cond_model(self, name: str):
        """
        Initialize the text conditioning model.

        Loads a CLIP text encoder and its tokenizer from `name`, moves the encoder
        to CUDA in eval mode, and caches the embedding of the empty prompt as
        'null_cond' (used as the negative condition by get_cond).

        Args:
            name (str): Model name or path accepted by `from_pretrained`.
        """
        # load model
        model = CLIPTextModel.from_pretrained(name)
        tokenizer = AutoTokenizer.from_pretrained(name)
        model.eval()
        model = model.cuda()
        self.text_cond_model = {
            'model': model,
            'tokenizer': tokenizer,
        }
        self.text_cond_model['null_cond'] = self.encode_text([''])

    @torch.no_grad()
    def encode_text(self, text: List[str]) -> torch.Tensor:
        """
        Encode the text.

        Args:
            text (List[str]): The prompts to encode.

        Returns:
            torch.Tensor: The text model's `last_hidden_state`, with one entry per
            prompt; each prompt is padded/truncated to 77 tokens.
        """
        assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
        model = self.text_cond_model['model']
        encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
        # Send the tokens to wherever the text model's weights live instead of
        # assuming CUDA, so encoding keeps working if the model is moved.
        tokens = encoding['input_ids'].to(model.device)
        embeddings = model(input_ids=tokens).last_hidden_state
        return embeddings

    def get_cond(self, prompt: List[str]) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            prompt (List[str]): The text prompt.

        Returns:
            dict: The conditioning information: 'cond' is the encoding of `prompt`,
            'neg_cond' is the cached encoding of the empty prompt.
        """
        cond = self.encode_text(prompt)
        neg_cond = self.text_cond_model['null_cond']
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def sample_sparse_structure(
        self,
        cond: dict,
        num_samples: int = 1,
        sampler_params: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            num_samples (int): The number of samples to generate.
            sampler_params (dict, optional): Additional parameters for the sampler;
                these override the pipeline's default sparse-structure sampler params.

        Returns:
            torch.Tensor: Integer coordinates of the occupied voxels, taken from
            columns [0, 2, 3, 4] of `argwhere` over the decoded grid. Column 1
            (the channel axis of the noise layout) is dropped, so each row is
            presumably (sample index, x, y, z).
        """
        # Sample occupancy latent
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
        sampler_params = {**self.sparse_structure_sampler_params, **(sampler_params or {})}
        z_s = self.sparse_structure_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Decode occupancy latent; positive decoder outputs count as occupied.
        decoder = self.models['sparse_structure_decoder']
        coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int()

        return coords

    def decode_slat(
        self,
        slat: sp.SparseTensor,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Decode the structured latent.

        Args:
            slat (sp.SparseTensor): The structured latent.
            formats (List[str], optional): The formats to decode the structured latent to.
                Any of 'mesh', 'gaussian', 'radiance_field'; defaults to all three.

        Returns:
            dict: The decoded structured latent, keyed by format name.
        """
        if formats is None:
            formats = self._DEFAULT_FORMATS
        ret = {}
        if 'mesh' in formats:
            ret['mesh'] = self.models['slat_decoder_mesh'](slat)
        if 'gaussian' in formats:
            ret['gaussian'] = self.models['slat_decoder_gs'](slat)
        if 'radiance_field' in formats:
            ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
        return ret

    def sample_slat(
        self,
        cond: dict,
        coords: torch.Tensor,
        sampler_params: Optional[dict] = None,
    ) -> sp.SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict, optional): Additional parameters for the sampler;
                these override the pipeline's default structured-latent sampler params.

        Returns:
            sp.SparseTensor: The sampled structured latent, de-normalized with
            `slat_normalization`.
        """
        # Sample structured latent
        flow_model = self.models['slat_flow_model']
        noise = sp.SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )
        sampler_params = {**self.slat_sampler_params, **(sampler_params or {})}
        slat = self.slat_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Undo the normalization applied to the latent space.
        std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
        slat = slat * std + mean

        return slat

    @torch.no_grad()
    def run(
        self,
        prompt: str,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: Optional[dict] = None,
        slat_sampler_params: Optional[dict] = None,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Run the pipeline.

        Args:
            prompt (str): The text prompt.
            num_samples (int): The number of samples to generate.
            seed (int): The random seed.
            sparse_structure_sampler_params (dict, optional): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict, optional): Additional parameters for the structured latent sampler.
            formats (List[str], optional): The formats to decode the structured latent to;
                defaults to 'mesh', 'gaussian' and 'radiance_field'.

        Returns:
            dict: The decoded outputs, keyed by format name (see decode_slat).
        """
        cond = self.get_cond([prompt])
        torch.manual_seed(seed)
        coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

    # Disabled: the methods below need open3d, whose import is commented out at
    # the top of the file. They are kept as an inert string literal.
    # NOTE(review): voxelize's docstring lists sha256/output_dir parameters that
    # its signature does not take.
    '''
    def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
        """
        Voxelize a mesh.

        Args:
            mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.
            sha256 (str): The SHA256 hash of the mesh.
            output_dir (str): The output directory.
        """
        vertices = np.asarray(mesh.vertices)
        aabb = np.stack([vertices.min(0), vertices.max(0)])
        center = (aabb[0] + aabb[1]) / 2
        scale = (aabb[1] - aabb[0]).max()
        vertices = (vertices - center) / scale
        vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
        mesh.vertices = o3d.utility.Vector3dVector(vertices)
        voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
        vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
        return torch.tensor(vertices).int().cuda()

    @torch.no_grad()
    def run_variant(
        self,
        mesh: o3d.geometry.TriangleMesh,
        prompt: str,
        num_samples: int = 1,
        seed: int = 42,
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
    ) -> dict:
        """
        Run the pipeline for making variants of an asset.

        Args:
            mesh (o3d.geometry.TriangleMesh): The base mesh.
            prompt (str): The text prompt.
            num_samples (int): The number of samples to generate.
            seed (int): The random seed
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): The formats to decode the structured latent to.
        """
        cond = self.get_cond([prompt])
        coords = self.voxelize(mesh)
        coords = torch.cat([
            torch.arange(num_samples).repeat_interleave(coords.shape[0], 0)[:, None].int().cuda(),
            coords.repeat(num_samples, 1)
        ], 1)
        torch.manual_seed(seed)
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)
    '''