cavargas10 committed on
Commit
8ddc992
·
verified ·
1 Parent(s): 5668340

Delete trellis/pipelines

Browse files
trellis/pipelines/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- from . import samplers
2
- from .trellis_image_to_3d import TrellisImageTo3DPipeline
3
- from .trellis_text_to_3d import TrellisTextTo3DPipeline
4
-
5
-
6
def from_pretrained(path: str):
    """
    Load a pipeline from a model folder or a Hugging Face model hub.

    The pipeline class is chosen by the ``name`` entry of ``pipeline.json``
    and looked up among the names defined in this module; loading is then
    delegated to that class's own ``from_pretrained``.

    Args:
        path: The path to the model. Can be either local path or a Hugging Face model name.

    Returns:
        Whatever ``from_pretrained(path)`` of the selected pipeline class returns.

    Raises:
        KeyError: If ``pipeline.json`` has no ``name`` entry, or if it names
            something that is not defined in this module.
    """
    import os
    import json

    config_file = f"{path}/pipeline.json"
    if not os.path.exists(config_file):
        # No local config: treat ``path`` as a Hub repository id instead.
        from huggingface_hub import hf_hub_download
        config_file = hf_hub_download(path, "pipeline.json")

    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)

    name = config['name']
    try:
        pipeline_cls = globals()[name]
    except KeyError:
        # Same exception type as before, but say what was wrong and where.
        raise KeyError(f"Unknown pipeline '{name}' in {config_file}") from None
    return pipeline_cls.from_pretrained(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/pipelines/base.py DELETED
@@ -1,68 +0,0 @@
1
- from typing import *
2
- import torch
3
- import torch.nn as nn
4
- from .. import models
5
-
6
-
7
class Pipeline:
    """
    A base class for pipelines.

    Holds a dict of named models, puts them in eval mode on construction and
    forwards device moves to every one of them.
    """
    def __init__(
        self,
        models: dict[str, nn.Module] = None,
    ):
        """
        Args:
            models: Mapping from model name to module. When omitted, the
                instance is left empty (``self.models`` is not set) so that a
                caller can populate its attributes afterwards.
        """
        if models is None:
            return
        self.models = models
        for model in self.models.values():
            model.eval()

    @staticmethod
    def from_pretrained(path: str) -> "Pipeline":
        """
        Load a pretrained model.

        Reads ``pipeline.json`` from ``path`` when that file exists locally,
        otherwise downloads it from the Hugging Face Hub using ``path`` as the
        repository id. Each entry of ``args['models']`` is then loaded through
        ``models.from_pretrained``.

        Args:
            path: Local model folder or Hugging Face repository id.

        Returns:
            A ``Pipeline`` holding the loaded models; the parsed ``args``
            section is kept on it as ``_pretrained_args``.
        """
        import os
        import json
        is_local = os.path.exists(f"{path}/pipeline.json")

        if is_local:
            config_file = f"{path}/pipeline.json"
        else:
            from huggingface_hub import hf_hub_download
            config_file = hf_hub_download(path, "pipeline.json")

        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(config_file, 'r', encoding='utf-8') as f:
            args = json.load(f)['args']

        _models = {}
        for k, v in args['models'].items():
            try:
                # First try the entry as a location relative to ``path``.
                _models[k] = models.from_pretrained(f"{path}/{v}")
            except Exception:
                # Fall back to the entry as given. ``Exception`` (not a bare
                # ``except``) so Ctrl-C / SystemExit are not swallowed.
                _models[k] = models.from_pretrained(v)

        new_pipeline = Pipeline(_models)
        new_pipeline._pretrained_args = args
        return new_pipeline

    @property
    def device(self) -> torch.device:
        """
        The device of the pipeline's models.

        Returns the ``device`` attribute of the first model that has one;
        otherwise the device of the first parameter found on any model.

        Raises:
            RuntimeError: If no model exposes a device or any parameter.
        """
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        for model in self.models.values():
            if hasattr(model, 'parameters'):
                # A parameter-free module must not abort the search with
                # StopIteration; move on to the next model instead.
                first_param = next(iter(model.parameters()), None)
                if first_param is not None:
                    return first_param.device
        raise RuntimeError("No device found.")

    def to(self, device: torch.device) -> None:
        """Move every model to ``device``."""
        for model in self.models.values():
            model.to(device)

    def cuda(self) -> None:
        """Move every model to the CUDA device."""
        self.to(torch.device("cuda"))

    def cpu(self) -> None:
        """Move every model to the CPU."""
        self.to(torch.device("cpu"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/pipelines/samplers/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .base import Sampler
2
- from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
 
 
 
trellis/pipelines/samplers/base.py DELETED
@@ -1,20 +0,0 @@
1
- from typing import *
2
- from abc import ABC, abstractmethod
3
-
4
-
5
class Sampler(ABC):
    """
    Abstract interface shared by all samplers.

    Concrete samplers must override :meth:`sample`.
    """

    @abstractmethod
    def sample(self, model, **kwargs):
        """
        Sample from a model.

        Args:
            model: The model to sample from.
            **kwargs: Sampler-specific arguments.
        """
        return None
20
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/pipelines/samplers/classifier_free_guidance_mixin.py DELETED
@@ -1,12 +0,0 @@
1
- from typing import *
2
-
3
-
4
class ClassifierFreeGuidanceSamplerMixin:
    """
    Mixin that adds classifier-free guidance to a sampler.

    It must precede the concrete sampler in the MRO: the model is queried
    through the next class's ``_inference_model``.
    """

    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
        """
        Combine a conditional and a negative-conditional prediction.

        Returns:
            ``(1 + cfg_strength) * pred - cfg_strength * neg_pred``.
        """
        base_inference = super()._inference_model
        # Conditional pass first, then the negative pass.
        cond_pred = base_inference(model, x_t, t, cond, **kwargs)
        uncond_pred = base_inference(model, x_t, t, neg_cond, **kwargs)
        guided = (1 + cfg_strength) * cond_pred - cfg_strength * uncond_pred
        return guided
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/pipelines/samplers/flow_euler.py DELETED
@@ -1,199 +0,0 @@
1
- from typing import *
2
- import torch
3
- import numpy as np
4
- from tqdm import tqdm
5
- from easydict import EasyDict as edict
6
- from .base import Sampler
7
- from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
8
- from .guidance_interval_mixin import GuidanceIntervalSamplerMixin
9
-
10
-
11
class FlowEulerSampler(Sampler):
    """
    Euler-method sampler for flow-matching models.

    Args:
        sigma_min: The minimum scale of noise in flow.
    """
    def __init__(self, sigma_min: float):
        self.sigma_min = sigma_min

    def _eps_to_xstart(self, x_t, t, eps):
        """Convert a noise prediction ``eps`` at time ``t`` into an x_0 estimate."""
        assert x_t.shape == eps.shape
        noise_scale = self.sigma_min + (1 - self.sigma_min) * t
        return (x_t - noise_scale * eps) / (1 - t)

    def _xstart_to_eps(self, x_t, t, x_0):
        """Convert an x_0 estimate at time ``t`` into a noise prediction."""
        assert x_t.shape == x_0.shape
        noise_scale = self.sigma_min + (1 - self.sigma_min) * t
        return (x_t - (1 - t) * x_0) / noise_scale

    def _v_to_xstart_eps(self, x_t, t, v):
        """Convert a velocity prediction ``v`` into the pair ``(x_0, eps)``."""
        assert x_t.shape == v.shape
        noise_scale = self.sigma_min + (1 - self.sigma_min) * t
        eps = (1 - t) * v + x_t
        x_0 = (1 - self.sigma_min) * x_t - noise_scale * v
        return x_0, eps

    def _inference_model(self, model, x_t, t, cond=None, **kwargs):
        """
        Call ``model`` at time ``t``.

        The scalar ``t`` is multiplied by 1000 and repeated ``x_t.shape[0]``
        times as a float32 tensor on ``x_t``'s device before being passed on.
        """
        batch_size = x_t.shape[0]
        timesteps = torch.tensor(
            [1000 * t] * batch_size,
            device=x_t.device,
            dtype=torch.float32,
        )
        return model(x_t, timesteps, cond, **kwargs)

    def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
        """Run the model and return ``(pred_x_0, pred_eps, pred_v)``."""
        velocity = self._inference_model(model, x_t, t, cond, **kwargs)
        x_0_hat, eps_hat = self._v_to_xstart_eps(x_t=x_t, t=t, v=velocity)
        return x_0_hat, eps_hat, velocity

    @torch.no_grad()
    def sample_once(
        self,
        model,
        x_t,
        t: float,
        t_prev: float,
        cond: Optional[Any] = None,
        **kwargs
    ):
        """
        Take a single Euler step from ``t`` to ``t_prev``.

        Args:
            model: The model to sample from.
            x_t: The [N x C x ...] tensor of noisy inputs at time t.
            t: The current timestep.
            t_prev: The previous timestep.
            cond: conditional information.
            **kwargs: Additional arguments for model inference.

        Returns:
            a dict containing the following
            - 'pred_x_prev': x_{t-1}.
            - 'pred_x_0': a prediction of x_0.
        """
        x_0_hat, _eps_hat, velocity = self._get_model_prediction(model, x_t, t, cond, **kwargs)
        step = t - t_prev
        x_prev = x_t - step * velocity
        return edict({"pred_x_prev": x_prev, "pred_x_0": x_0_hat})

    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond: Optional[Any] = None,
        steps: int = 50,
        rescale_t: float = 1.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Integrate from ``noise`` with ``steps`` Euler steps.

        Timesteps run from 1 down to 0 and are warped by ``rescale_t`` before
        use.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': the state after each step, in order.
            - 'pred_x_0': the x_0 prediction made at each step, in order.
        """
        t_seq = np.linspace(1, 0, steps + 1)
        t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
        # Consecutive (current, next) timestep pairs.
        schedule = list(zip(t_seq[:-1], t_seq[1:]))

        result = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
        current = noise
        for t, t_prev in tqdm(schedule, desc="Sampling", disable=not verbose):
            step_out = self.sample_once(model, current, t, t_prev, cond, **kwargs)
            current = step_out.pred_x_prev
            result.pred_x_t.append(step_out.pred_x_prev)
            result.pred_x_0.append(step_out.pred_x_0)
        result.samples = current
        return result
118
-
119
-
120
class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
    """
    Euler sampler for flow-matching models with classifier-free guidance.
    """
    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        cfg_strength: float = 3.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples with the Euler method under classifier-free guidance.

        The guidance arguments are forwarded as keyword arguments to the
        parent ``sample``.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            neg_cond: negative conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            cfg_strength: The strength of classifier-free guidance.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of prediction of x_t.
            - 'pred_x_0': a list of prediction of x_0.
        """
        return super().sample(
            model,
            noise,
            cond,
            steps,
            rescale_t,
            verbose,
            neg_cond=neg_cond,
            cfg_strength=cfg_strength,
            **kwargs,
        )
158
-
159
-
160
class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
    """
    Euler sampler for flow-matching models with classifier-free guidance
    restricted to a timestep interval.
    """
    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        cfg_strength: float = 3.0,
        cfg_interval: Tuple[float, float] = (0.0, 1.0),
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples with the Euler method under interval-limited guidance.

        The guidance arguments are forwarded as keyword arguments to the
        parent ``sample``.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            neg_cond: negative conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            cfg_strength: The strength of classifier-free guidance.
            cfg_interval: The interval for classifier-free guidance.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of prediction of x_t.
            - 'pred_x_0': a list of prediction of x_0.
        """
        return super().sample(
            model,
            noise,
            cond,
            steps,
            rescale_t,
            verbose,
            neg_cond=neg_cond,
            cfg_strength=cfg_strength,
            cfg_interval=cfg_interval,
            **kwargs,
        )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/pipelines/samplers/guidance_interval_mixin.py DELETED
@@ -1,15 +0,0 @@
1
- from typing import *
2
-
3
-
4
class GuidanceIntervalSamplerMixin:
    """
    Mixin that applies classifier-free guidance only inside a timestep interval.

    It must precede the concrete sampler in the MRO: the model is queried
    through the next class's ``_inference_model``.
    """

    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
        """
        Return the guided prediction when ``cfg_interval[0] <= t <= cfg_interval[1]``,
        otherwise the plain conditional prediction.
        """
        base_inference = super()._inference_model
        lower, upper = cfg_interval[0], cfg_interval[1]

        # The conditional pass is needed in both cases and always runs first.
        cond_pred = base_inference(model, x_t, t, cond, **kwargs)
        if not (lower <= t <= upper):
            return cond_pred

        uncond_pred = base_inference(model, x_t, t, neg_cond, **kwargs)
        return (1 + cfg_strength) * cond_pred - cfg_strength * uncond_pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/pipelines/trellis_image_to_3d.py DELETED
@@ -1,376 +0,0 @@
1
- from typing import *
2
- from contextlib import contextmanager
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- import numpy as np
7
- from tqdm import tqdm
8
- from easydict import EasyDict as edict
9
- from torchvision import transforms
10
- from PIL import Image
11
- import rembg
12
- from .base import Pipeline
13
- from . import samplers
14
- from ..modules import sparse as sp
15
- from ..representations import Gaussian, Strivec, MeshExtractResult
16
-
17
-
18
class TrellisImageTo3DPipeline(Pipeline):
    """
    Pipeline for inferring Trellis image-to-3D models.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured latent.
        image_cond_model (str): The name of the image conditioning model.
    """
    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        sparse_structure_sampler: samplers.Sampler = None,
        slat_sampler: samplers.Sampler = None,
        slat_normalization: dict = None,
        image_cond_model: str = None,
    ):
        # Without models the instance is left blank; ``from_pretrained``
        # relies on this to create an empty object and fill it afterwards.
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.slat_sampler = slat_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler_params = {}
        self.slat_normalization = slat_normalization
        self.rembg_session = None
        self._init_image_cond_model(image_cond_model)

    @staticmethod
    def from_pretrained(path: str) -> "TrellisImageTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.

        Returns:
            A pipeline whose models, samplers, sampler parameters, latent
            normalization and image conditioning model are set from the
            pretrained configuration.
        """
        pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path)
        # Adopt the state of the base pipeline on a blank instance of this class.
        new_pipeline = TrellisImageTo3DPipeline()
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        # Samplers are looked up by class name in the ``samplers`` package.
        new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
        new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']

        new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
        new_pipeline.slat_sampler_params = args['slat_sampler']['params']

        new_pipeline.slat_normalization = args['slat_normalization']

        new_pipeline._init_image_cond_model(args['image_cond_model'])

        return new_pipeline

    def _init_image_cond_model(self, name: str):
        """
        Initialize the image conditioning model.

        Loads ``name`` from the ``facebookresearch/dinov2`` torch.hub
        repository, stores it under ``self.models['image_cond_model']`` and
        sets the normalization transform applied before encoding.
        """
        dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True)
        dinov2_model.eval()
        self.models['image_cond_model'] = dinov2_model
        transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.image_cond_model_transform = transform

    def preprocess_image(self, input: Image.Image) -> Image.Image:
        """
        Preprocess the input image.

        An RGBA image with a non-opaque alpha channel is used as is; any
        other image has its background removed with rembg (after being
        downscaled so its longer side is at most 1024). The result is cropped
        to a square around the foreground, resized to 518x518 and multiplied
        by its alpha.

        Args:
            input: The image to preprocess.

        Returns:
            The 518x518 RGB image with the background blacked out.

        Raises:
            ValueError: If no pixel has alpha above 80%, i.e. there is no
                foreground to crop to.
        """
        # if has alpha channel, use it directly; otherwise, remove background
        has_alpha = False
        if input.mode == 'RGBA':
            alpha = np.array(input)[:, :, 3]
            if not np.all(alpha == 255):
                has_alpha = True
        if has_alpha:
            output = input
        else:
            input = input.convert('RGB')
            max_size = max(input.size)
            scale = min(1, 1024 / max_size)
            if scale < 1:
                input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
            # The rembg session is created on first use and then reused.
            if getattr(self, 'rembg_session', None) is None:
                self.rembg_session = rembg.new_session('u2net')
            output = rembg.remove(input, session=self.rembg_session)
        output_np = np.array(output)
        alpha = output_np[:, :, 3]
        foreground = np.argwhere(alpha > 0.8 * 255)
        if foreground.size == 0:
            # np.min/np.max on an empty array would raise an opaque
            # "zero-size array" ValueError; fail with a clear message instead.
            raise ValueError("No foreground found in the image: no pixel has alpha above 80%.")
        # argwhere yields (row, col); the bbox is (left, top, right, bottom).
        bbox = np.min(foreground[:, 1]), np.min(foreground[:, 0]), np.max(foreground[:, 1]), np.max(foreground[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1.2)  # 20% margin around the foreground
        bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
        output = output.crop(bbox)  # type: ignore
        output = output.resize((518, 518), Image.Resampling.LANCZOS)
        output = np.array(output).astype(np.float32) / 255
        # Multiply RGB by alpha so that the background becomes black.
        output = output[:, :, :3] * output[:, :, 3:4]
        output = Image.fromarray((output * 255).astype(np.uint8))
        return output

    @torch.no_grad()
    def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor:
        """
        Encode the image.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image to encode

        Returns:
            torch.Tensor: The encoded features.

        Raises:
            ValueError: If ``image`` is neither a tensor nor a list.
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            # Same resampling enum as the rest of this class.
            image = [i.resize((518, 518), Image.Resampling.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).to(self.device)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        image = self.image_cond_model_transform(image).to(self.device)
        features = self.models['image_cond_model'](image, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.

        Returns:
            dict: ``'cond'`` holds the encoded image features and
            ``'neg_cond'`` a zero tensor of the same shape.
        """
        cond = self.encode_image(image)
        neg_cond = torch.zeros_like(cond)
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def sample_sparse_structure(
        self,
        cond: dict,
        num_samples: int = 1,
        sampler_params: dict = {},
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            num_samples (int): The number of samples to generate.
            sampler_params (dict): Additional parameters for the sampler.
                They override the pipeline's stored sampler parameters; the
                dict is only read, never modified.

        Returns:
            torch.Tensor: Integer coordinates of the entries where the decoder
            output is positive (columns 0, 2, 3 and 4 of ``argwhere``).
        """
        # Sample occupancy latent
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
        sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
        z_s = self.sparse_structure_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Decode occupancy latent
        decoder = self.models['sparse_structure_decoder']
        coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()

        return coords

    def decode_slat(
        self,
        slat: sp.SparseTensor,
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
    ) -> dict:
        """
        Decode the structured latent.

        Args:
            slat (sp.SparseTensor): The structured latent.
            formats (List[str]): The formats to decode the structured latent to.

        Returns:
            dict: The decoded structured latent, with one entry per requested
            format among ``'mesh'``, ``'gaussian'`` and ``'radiance_field'``.
        """
        ret = {}
        if 'mesh' in formats:
            ret['mesh'] = self.models['slat_decoder_mesh'](slat)
        if 'gaussian' in formats:
            ret['gaussian'] = self.models['slat_decoder_gs'](slat)
        if 'radiance_field' in formats:
            ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
        return ret

    def sample_slat(
        self,
        cond: dict,
        coords: torch.Tensor,
        sampler_params: dict = {},
    ) -> sp.SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict): Additional parameters for the sampler.
                They override the pipeline's stored sampler parameters; the
                dict is only read, never modified.

        Returns:
            sp.SparseTensor: The sampled latent, rescaled with the ``std`` and
            ``mean`` of ``self.slat_normalization``.
        """
        # Sample structured latent
        flow_model = self.models['slat_flow_model']
        noise = sp.SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )
        sampler_params = {**self.slat_sampler_params, **sampler_params}
        slat = self.slat_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
        slat = slat * std + mean

        return slat

    @torch.no_grad()
    def run(
        self,
        image: Image.Image,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: dict = {},
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
        preprocess_image: bool = True,
    ) -> dict:
        """
        Run the pipeline.

        Args:
            image (Image.Image): The image prompt.
            num_samples (int): The number of samples to generate.
            seed (int): Seed passed to ``torch.manual_seed`` before sampling.
            sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): The formats to decode the structured latent to.
            preprocess_image (bool): Whether to preprocess the image.

        Returns:
            dict: The decoded outputs, keyed by format.
        """
        if preprocess_image:
            image = self.preprocess_image(image)
        cond = self.get_cond([image])
        torch.manual_seed(seed)
        coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

    @contextmanager
    def inject_sampler_multi_image(
        self,
        sampler_name: str,
        num_images: int,
        num_steps: int,
        mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
    ):
        """
        Inject a sampler with multiple images as condition.

        While the context is active, the sampler's ``_inference_model`` is
        replaced; the previous one is put back on exit, including when the
        body of the ``with`` block raises.

        Args:
            sampler_name (str): The name of the sampler to inject.
            num_images (int): The number of images to condition on.
            num_steps (int): The number of steps to run the sampler for.
            mode: ``'stochastic'`` conditions each model call on a single
                image, cycling through the images in order;
                ``'multidiffusion'`` averages the predictions obtained for
                every image in each call.

        Raises:
            ValueError: If ``mode`` is not supported. The sampler is left
                untouched in that case.
        """
        sampler = getattr(self, sampler_name)

        # Build the replacement first so that an invalid mode cannot leave
        # the sampler in a half-patched state.
        if mode == 'stochastic':
            if num_images > num_steps:
                print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
                    "This may lead to performance degradation.\033[0m")

            # One image index per step, consumed in order by successive calls.
            cond_indices = (np.arange(num_steps) % num_images).tolist()
            def _new_inference_model(self, model, x_t, t, cond, **kwargs):
                cond_idx = cond_indices.pop(0)
                cond_i = cond[cond_idx:cond_idx+1]
                return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs)

        elif mode =='multidiffusion':
            from .samplers import FlowEulerSampler
            def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
                # Average the plain predictions over all conditioning images.
                preds = []
                for i in range(len(cond)):
                    preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
                pred = sum(preds) / len(preds)
                if cfg_interval[0] <= t <= cfg_interval[1]:
                    # Inside the interval, apply classifier-free guidance.
                    neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
                    return (1 + cfg_strength) * pred - cfg_strength * neg_pred
                return pred

        else:
            raise ValueError(f"Unsupported mode: {mode}")

        setattr(sampler, '_old_inference_model', sampler._inference_model)
        sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))

        try:
            yield
        finally:
            # Always undo the patch, even if sampling failed.
            sampler._inference_model = sampler._old_inference_model
            delattr(sampler, '_old_inference_model')

    @torch.no_grad()
    def run_multi_image(
        self,
        images: List[Image.Image],
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: dict = {},
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
        preprocess_image: bool = True,
        mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
    ) -> dict:
        """
        Run the pipeline with multiple images as condition

        Args:
            images (List[Image.Image]): The multi-view images of the assets
            num_samples (int): The number of samples to generate.
            seed (int): Seed passed to ``torch.manual_seed`` before sampling.
            sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): The formats to decode the structured latent to.
            preprocess_image (bool): Whether to preprocess the image.
            mode: How the images are combined; see ``inject_sampler_multi_image``.

        Returns:
            dict: The decoded outputs, keyed by format.
        """
        if preprocess_image:
            images = [self.preprocess_image(image) for image in images]
        cond = self.get_cond(images)
        # Only the first negative condition is kept.
        cond['neg_cond'] = cond['neg_cond'][:1]
        torch.manual_seed(seed)
        ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps')
        with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode):
            coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
        slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps')
        with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode):
            slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trellis/pipelines/trellis_text_to_3d.py DELETED
@@ -1,278 +0,0 @@
1
- from typing import *
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- from transformers import CLIPTextModel, AutoTokenizer
6
- import open3d as o3d
7
- from .base import Pipeline
8
- from . import samplers
9
- from ..modules import sparse as sp
10
-
11
-
12
class TrellisTextTo3DPipeline(Pipeline):
    """
    Pipeline for inferring Trellis text-to-3D models.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured latent.
            Read with the keys ``'mean'`` and ``'std'``.
        text_cond_model (str): The name of the text conditioning model.
    """

    # Output formats decoded when the caller does not request specific ones.
    _DEFAULT_FORMATS = ('mesh', 'gaussian', 'radiance_field')

    def __init__(
        self,
        models: Optional[dict[str, nn.Module]] = None,
        sparse_structure_sampler: Optional[samplers.Sampler] = None,
        slat_sampler: Optional[samplers.Sampler] = None,
        slat_normalization: Optional[dict] = None,
        text_cond_model: Optional[str] = None,
    ):
        # An argument-less construction yields an empty shell; `from_pretrained`
        # relies on this and fills in the attributes afterwards.
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.slat_sampler = slat_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler_params = {}
        self.slat_normalization = slat_normalization
        self._init_text_cond_model(text_cond_model)

    @staticmethod
    def from_pretrained(path: str) -> "TrellisTextTo3DPipeline":
        """
        Load a pretrained model.

        The generic loader of the base class is run first; its instance state is
        then adopted by a fresh ``TrellisTextTo3DPipeline`` and the samplers,
        normalization parameters and text conditioning model are built from the
        loaded ``_pretrained_args``.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.

        Returns:
            TrellisTextTo3DPipeline: The configured pipeline.
        """
        pipeline = super(TrellisTextTo3DPipeline, TrellisTextTo3DPipeline).from_pretrained(path)
        new_pipeline = TrellisTextTo3DPipeline()
        # Share the loaded state (models, pretrained args, ...) with the typed instance.
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        ss_cfg = args['sparse_structure_sampler']
        new_pipeline.sparse_structure_sampler = getattr(samplers, ss_cfg['name'])(**ss_cfg['args'])
        new_pipeline.sparse_structure_sampler_params = ss_cfg['params']

        slat_cfg = args['slat_sampler']
        new_pipeline.slat_sampler = getattr(samplers, slat_cfg['name'])(**slat_cfg['args'])
        new_pipeline.slat_sampler_params = slat_cfg['params']

        new_pipeline.slat_normalization = args['slat_normalization']

        new_pipeline._init_text_cond_model(args['text_cond_model'])

        return new_pipeline

    def _init_text_cond_model(self, name: str):
        """
        Initialize the text conditioning model.

        Loads the CLIP text encoder and its tokenizer, puts the encoder in eval
        mode on CUDA, and caches the embedding of the empty prompt under
        ``'null_cond'`` (used as the negative condition).

        Args:
            name (str): The name passed to ``from_pretrained`` of the encoder and tokenizer.
        """
        # load model
        model = CLIPTextModel.from_pretrained(name)
        tokenizer = AutoTokenizer.from_pretrained(name)
        model.eval()
        model = model.cuda()
        self.text_cond_model = {
            'model': model,
            'tokenizer': tokenizer,
        }
        self.text_cond_model['null_cond'] = self.encode_text([''])

    @torch.no_grad()
    def encode_text(self, text: List[str]) -> torch.Tensor:
        """
        Encode the text.

        Args:
            text (List[str]): The prompts to encode. Each prompt is padded or
                truncated to 77 tokens.

        Returns:
            torch.Tensor: The last hidden state of the text encoder for every prompt.
        """
        assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
        model = self.text_cond_model['model']
        encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
        # Follow the encoder's device instead of assuming the current CUDA device.
        tokens = encoding['input_ids'].to(next(model.parameters()).device)
        embeddings = model(input_ids=tokens).last_hidden_state

        return embeddings

    def get_cond(self, prompt: List[str]) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            prompt (List[str]): The text prompt.

        Returns:
            dict: The conditioning information, with the encoded prompt under
                ``'cond'`` and the cached empty-prompt embedding under ``'neg_cond'``.
        """
        cond = self.encode_text(prompt)
        neg_cond = self.text_cond_model['null_cond']
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def sample_sparse_structure(
        self,
        cond: dict,
        num_samples: int = 1,
        sampler_params: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            num_samples (int): The number of samples to generate.
            sampler_params (dict): Additional parameters for the sampler. They
                override the pipeline's default sparse structure sampler parameters.

        Returns:
            torch.Tensor: Integer coordinates of the entries where the decoded
                occupancy is positive.
        """
        # Sample occupancy latent
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
        sampler_params = {**self.sparse_structure_sampler_params, **(sampler_params or {})}
        z_s = self.sparse_structure_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Decode occupancy latent
        decoder = self.models['sparse_structure_decoder']
        # Keep index columns 0, 2, 3, 4 of the positive entries
        # (presumably batch + three spatial axes, dropping the channel axis - TODO confirm).
        coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int()

        return coords

    def decode_slat(
        self,
        slat: sp.SparseTensor,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Decode the structured latent.

        Args:
            slat (sp.SparseTensor): The structured latent.
            formats (List[str]): The formats to decode the structured latent to.
                Defaults to ``['mesh', 'gaussian', 'radiance_field']``.

        Returns:
            dict: The decoded structured latent, keyed by the requested format names.
        """
        if formats is None:
            formats = list(self._DEFAULT_FORMATS)
        ret = {}
        if 'mesh' in formats:
            ret['mesh'] = self.models['slat_decoder_mesh'](slat)
        if 'gaussian' in formats:
            ret['gaussian'] = self.models['slat_decoder_gs'](slat)
        if 'radiance_field' in formats:
            ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
        return ret

    def sample_slat(
        self,
        cond: dict,
        coords: torch.Tensor,
        sampler_params: Optional[dict] = None,
    ) -> sp.SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict): Additional parameters for the sampler. They
                override the pipeline's default structured latent sampler parameters.

        Returns:
            sp.SparseTensor: The sampled structured latent, de-normalized with
                the pipeline's ``slat_normalization`` mean and std.
        """
        # Sample structured latent
        flow_model = self.models['slat_flow_model']
        noise = sp.SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )
        sampler_params = {**self.slat_sampler_params, **(sampler_params or {})}
        slat = self.slat_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Undo the latent normalization.
        std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
        slat = slat * std + mean

        return slat

    @torch.no_grad()
    def run(
        self,
        prompt: str,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: Optional[dict] = None,
        slat_sampler_params: Optional[dict] = None,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Run the pipeline.

        Args:
            prompt (str): The text prompt.
            num_samples (int): The number of samples to generate.
            seed (int): The random seed.
            sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): The formats to decode the structured latent to.
                Defaults to ``['mesh', 'gaussian', 'radiance_field']``.

        Returns:
            dict: The decoded outputs, keyed by format name.
        """
        cond = self.get_cond([prompt])
        torch.manual_seed(seed)
        coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

    def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
        """
        Voxelize a mesh.

        The vertices are centered on the bounding-box center, scaled by the
        longest bounding-box side, clipped just inside [-0.5, 0.5] and written
        back to ``mesh`` (the input mesh is modified in place). The mesh is then
        voxelized within [-0.5, 0.5]^3 using a voxel size of 1/64.

        Args:
            mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.

        Returns:
            torch.Tensor: Integer tensor of shape (N, 3) holding the grid index
                of every occupied voxel, placed on the pipeline's device.
        """
        vertices = np.asarray(mesh.vertices)
        aabb = np.stack([vertices.min(0), vertices.max(0)])
        center = (aabb[0] + aabb[1]) / 2
        scale = (aabb[1] - aabb[0]).max()
        vertices = (vertices - center) / scale
        vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
        mesh.vertices = o3d.utility.Vector3dVector(vertices)
        voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
        # reshape keeps an (N, 3) layout even when no voxel is occupied.
        indices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()]).reshape(-1, 3)
        # Use the pipeline's device, like the noise tensors these coordinates are paired with.
        return torch.tensor(indices).int().to(self.device)

    @torch.no_grad()
    def run_variant(
        self,
        mesh: o3d.geometry.TriangleMesh,
        prompt: str,
        num_samples: int = 1,
        seed: int = 42,
        slat_sampler_params: Optional[dict] = None,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Run the pipeline for making variants of an asset.

        The sparse structure is taken from the voxelized ``mesh`` instead of
        being sampled; only the structured latent is sampled from the prompt.

        Args:
            mesh (o3d.geometry.TriangleMesh): The base mesh. It is normalized in
                place by `voxelize`.
            prompt (str): The text prompt.
            num_samples (int): The number of samples to generate.
            seed (int): The random seed
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): The formats to decode the structured latent to.
                Defaults to ``['mesh', 'gaussian', 'radiance_field']``.

        Returns:
            dict: The decoded outputs, keyed by format name.
        """
        cond = self.get_cond([prompt])
        coords = self.voxelize(mesh)
        # Prepend a sample-index column so each sample owns a copy of the voxel coordinates.
        sample_index = torch.arange(num_samples, device=coords.device).repeat_interleave(coords.shape[0], 0)[:, None].int()
        coords = torch.cat([
            sample_index,
            coords.repeat(num_samples, 1)
        ], 1)
        torch.manual_seed(seed)
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)