ZhiyuanthePony committed
Commit fc44d4b · 1 Parent(s): dfeea18

remove_type_annotator

app.py CHANGED
@@ -1,11 +1,11 @@
 import os
+
 import subprocess
 import sys
 try:
     import spaces
 except:
     pass
-os.environ["PYDANTIC_STRICT_TYPE_CHECKING"] = "0"

 # Check if setup has been run
 setup_marker = ".setup_complete"
@@ -23,7 +23,6 @@ if not os.path.exists(setup_marker):

 import torch
 import gradio as gr
-from typing import Tuple, List, Dict, Any, Optional
 from collections import deque
 from diffusers import StableDiffusionPipeline

@@ -58,7 +57,7 @@ def initialize_pipeline():
     return PIPELINE

 @spaces.GPU
-def generate_3d_mesh(prompt: str) -> Tuple[Optional[str], Optional[str]]:
+def generate_3d_mesh(prompt):
     """Generate 3D mesh from text prompt"""
     global PIPELINE, OBJ_FILE_QUEUE

example.py CHANGED
@@ -17,8 +17,8 @@ from triplaneturbo_executable import TriplaneTurboTextTo3DPipeline, TriplaneTurb

 # Initialize configuration and parameters
 prompt = "a beautiful girl"
-output_dir = "examples/output"
-adapter_name_or_path = "/home/user/app/pretrained/triplane_turbo_sd_v1.pth"
+output_dir = "output"
+adapter_name_or_path = "pretrained/triplane_turbo_sd_v1.pth"
 num_results_per_prompt = 1
 seed = 42
 device = "cuda"
setup.sh CHANGED
@@ -17,12 +17,12 @@ pip install --force-reinstall -v "numpy==1.25.2"
 # cd ..
 # cd ..

-
+echo "Setup completed successfully!"

 echo "Installing other requirements..."
 pip install -r requirements.txt

-
+# Download and install the pre-compiled DISO wheel from your Hugging Face repo
 echo "Installing pre-compiled DISO wheel package..."
 huggingface-cli download --resume-download ZhiyuanthePony/TriplaneTurbo \
     --include "diso-0.1.4-*.whl" \
@@ -30,12 +30,3 @@ huggingface-cli download --resume-download ZhiyuanthePony/TriplaneTurbo \
     --local-dir-use-symlinks False

 pip install ./diso_package/diso-0.1.4-*.whl
-echo "Setup completed successfully!"
-
-echo "Installing compatible dependency versions..."
-pip uninstall -y pydantic
-pip install pydantic==1.10.8 # Install compatible older version
-
-# Ensure Gradio and other dependencies are installed correctly
-pip install "gradio>=4.0.0,<5.0.0"
-pip install "fastapi<0.103.0" # Ensure compatible FastAPI version
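
A quick post-install check for the wheel, mirroring the import that mesh_exporter.py performs lazily at runtime; a minimal sketch, assuming the wheel installed into the active environment:

# Verify the DISO wheel: mesh_exporter.py does `from diso import DiffMC` at runtime.
import torch
from diso import DiffMC

mc = DiffMC(dtype=torch.float32)
print("DiffMC ready:", mc is not None)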
triplaneturbo_executable/extern/sd_dual_triplane_modules.py CHANGED
@@ -2,7 +2,6 @@ import re
 import torch
 import torch.nn as nn
 from dataclasses import dataclass
-from typing import Optional, Union, Tuple

 from diffusers.models.attention_processor import Attention
 from diffusers import (
@@ -39,9 +38,9 @@ class LoRALinearLayerwBias(nn.Module):
         in_features: int,
         out_features: int,
         rank: int = 4,
-        network_alpha: Optional[float] = None,
-        device: Optional[Union[torch.device, str]] = None,
-        dtype: Optional[torch.dtype] = None,
+        network_alpha=None,
+        device=None,
+        dtype=None,
         with_bias: bool = False
     ):
         super().__init__()
@@ -105,10 +104,10 @@ class TriplaneLoRAConv2dLayer(nn.Module):
         in_features: int,
         out_features: int,
         rank: int = 4,
-        kernel_size: Union[int, Tuple[int, int]] = (1, 1),
-        stride: Union[int, Tuple[int, int]] = (1, 1),
-        padding: Union[int, Tuple[int, int], str] = 0,
-        network_alpha: Optional[float] = None,
+        kernel_size=(1, 1),
+        stride=(1, 1),
+        padding=0,
+        network_alpha=None,
         with_bias: bool = False,
         locon_type: str = "hexa_v1", #hexa_v2, vanilla_v1, vanilla_v2
     ):
@@ -220,7 +219,7 @@ class TriplaneSelfAttentionLoRAAttnProcessor(nn.Module):
         self,
         hidden_size: int,
         rank: int = 4,
-        network_alpha: Optional[float] = None,
+        network_alpha=None,
         with_bias: bool = False,
         lora_type: str = "hexa_v1", # vanilla,
     ):
@@ -492,7 +491,7 @@ class TriplaneCrossAttentionLoRAAttnProcessor(nn.Module):
         hidden_size: int,
         cross_attention_dim: int,
         rank: int = 4,
-        network_alpha: Optional[float] = None,
+        network_alpha=None,
         with_bias: bool = False,
         lora_type: str = "hexa_v1", # vanilla,
     ):
@@ -713,7 +712,7 @@ class OneStepTriplaneDualStableDiffusion(nn.Module):
     """
     def __init__(
         self,
-        config: Union[dict, GeneratorConfig],
+        config,
         vae: AutoencoderKL,
         unet: UNet2DConditionModel,
     ):
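
The constructors keep their defaults, so existing call sites are unaffected. A hedged instantiation sketch; the feature sizes here are made up, and the forward behavior is the usual LoRA low-rank update, which this diff does not show:

import torch
from triplaneturbo_executable.extern.sd_dual_triplane_modules import LoRALinearLayerwBias

lora = LoRALinearLayerwBias(
    in_features=320,
    out_features=320,
    rank=4,
    network_alpha=None,  # None skips the usual alpha/rank rescaling
    device="cpu",
    dtype=torch.float32,
    with_bias=False,
)
delta = lora(torch.randn(2, 320))  # assumed: low-rank update with shape (2, 320)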
triplaneturbo_executable/models/networks.py CHANGED
@@ -3,7 +3,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from ..utils.general_utils import config_to_primitive
 from dataclasses import dataclass
-from typing import Optional, Literal

 def get_activation(name):
     if name is None:
@@ -21,7 +20,7 @@ def get_activation(name):


 class VanillaMLP(nn.Module):
-    def __init__(self, dim_in: int, dim_out: int, config: dict):
+    def __init__(self, dim_in, dim_out, config):
         super().__init__()
         # Convert dict to MLPConfig if needed
         if isinstance(config, dict):
@@ -70,7 +69,7 @@ class MLPConfig:
     n_neurons: int = 64
     n_hidden_layers: int = 2

-def get_mlp(input_dim: int, output_dim: int, config: dict) -> nn.Module:
+def get_mlp(input_dim, output_dim, config):
     """Create MLP network based on config"""
     # Convert dict to MLPConfig
     if isinstance(config, dict):
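
get_mlp now accepts an untyped config; a minimal sketch using only the two MLPConfig fields visible in this hunk (any other fields keep their defaults, and the forward shape is an assumption):

import torch
from triplaneturbo_executable.models.networks import get_mlp

# A plain dict is converted to MLPConfig inside get_mlp.
mlp = get_mlp(64, 3, {"n_neurons": 64, "n_hidden_layers": 2})
rgb = mlp(torch.randn(1024, 64))  # assumed forward: (N, 64) -> (N, 3)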
triplaneturbo_executable/pipelines/triplaneturbo_text_to_3d.py CHANGED
@@ -4,7 +4,6 @@ import json
 from tqdm import tqdm

 import torch
-from typing import *
 from dataclasses import dataclass, field
 from diffusers import StableDiffusionPipeline

@@ -21,11 +20,6 @@ class TriplaneTurboTextTo3DPipelineConfig:
     # Basic pipeline settings
     base_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base"

-    num_inference_steps: int = 4
-    num_results_per_prompt: int = 1
-    latent_channels: int = 4
-    latent_height: int = 64
-    latent_width: int = 64

     # Training/sampling settings
     num_steps_sampling: int = 4
@@ -72,7 +66,7 @@ class TriplaneTurboTextTo3DPipelineConfig:
     color_activation: str = "sigmoid-mipnerf"

     @classmethod
-    def from_pretrained(cls, pretrained_path: str) -> "TriplaneTurboTextTo3DPipelineConfig":
+    def from_pretrained(cls, pretrained_path):
         """Load config from pretrained path"""
         config_path = os.path.join(pretrained_path, "config.json")
         if os.path.exists(config_path):
@@ -91,11 +85,11 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):

     def __init__(
         self,
-        geometry: StableDiffusionTriplaneDualAttention,
-        material: Callable,
-        base_pipeline: StableDiffusionPipeline,
-        sample_scheduler: Callable,
-        isosurface_helper: Callable,
+        geometry,
+        material,
+        base_pipeline,
+        sample_scheduler,
+        isosurface_helper,
         **kwargs,
     ):
         super().__init__()
@@ -116,7 +110,7 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
     @classmethod
     def from_pretrained(
         cls,
-        pretrained_model_name_or_path: str,
+        pretrained_model_name_or_path,
         **kwargs,
     ):
         """
@@ -197,10 +191,10 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):

     def encode_prompt(
         self,
-        prompt: Union[str, List[str]],
-        device: str,
-        num_results_per_prompt: int = 1,
-    ) -> torch.FloatTensor:
+        prompt,
+        device,
+        num_results_per_prompt=1,
+    ):
         """
         Encodes the prompt into text encoder hidden states.

@@ -227,14 +221,13 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
-        num_inference_steps: int = 4,
-        num_results_per_prompt: int = 1,
-        generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        return_dict: bool = True,
-        colorize: bool = True,
-        **kwargs,
+        prompt,
+        num_results_per_prompt=1,
+        generator=None,
+        device=None,
+        return_dict=True,
+        num_inference_steps=4,
+        colorize=True,
     ):
         # Implementation similar to Zero123Pipeline
         # Reference code from: https://github.com/zero123/zero123-diffusers
@@ -251,15 +244,18 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
         # Get the device from the first available module

         # Generate latents if not provided
-        if latents is None:
-            latents = torch.randn(
-                (batch_size * 6, 4, 32, 32), # hard-coded for now
-                generator=generator,
-                device=self.device,
-            )
+        if device is None:
+            device = self.device
+        if generator is None:
+            generator = torch.Generator(device=device)
+        latents = torch.randn(
+            (batch_size * 6, 4, 32, 32), # hard-coded for now
+            generator=generator,
+            device=device,
+        )

         # Process text prompt through geometry module
-        text_embed, _ = self.encode_prompt(prompt, self.device, num_results_per_prompt)
+        text_embed, _ = self.encode_prompt(prompt, device, num_results_per_prompt)

         # Run diffusion process
         # Set up timesteps for sampling
@@ -282,7 +278,7 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
             pred = self.geometry.denoise(
                 noisy_input=noisy_latent_input,
                 text_embed=text_embed,
-                timestep=t.to(self.device),
+                timestep=t.to(device),
             )

             # Update latents
@@ -311,20 +307,19 @@ class TriplaneTurboTextTo3DPipeline(Pipeline):
             activation=self.material,
         )

-        # decide output type based on return_dict
         if return_dict:
             return {
                 "space_cache": space_cache,
                 "latents": latents,
                 "mesh": mesh_list,
             }
         else:
             return mesh_list

     def _set_timesteps(
         self,
         scheduler,
-        num_steps: int,
+        num_steps,
     ):
         """Set up timesteps for sampling.

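
The reworked __call__ derives device and generator itself when they are omitted, and latents are now always drawn internally. A minimal call sketch against the new signature; everything follows the code above except the pretrained path, which is a placeholder:

import torch
from triplaneturbo_executable import TriplaneTurboTextTo3DPipeline

pipeline = TriplaneTurboTextTo3DPipeline.from_pretrained("pretrained/triplane_turbo_sd_v1.pth")

# device=None falls back to self.device; generator=None gets a fresh torch.Generator(device).
out = pipeline("a beautiful girl", num_inference_steps=4, return_dict=True)
print(out["latents"].shape)  # (batch_size * 6, 4, 32, 32), hard-coded for now
meshes = out["mesh"]         # return_dict=False would return this list directly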
triplaneturbo_executable/utils/general_utils.py CHANGED
@@ -2,17 +2,28 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from typing import *
-from jaxtyping import Float
-from omegaconf import OmegaConf
+import numpy as np
+from dataclasses import asdict, is_dataclass
+import gc

-def config_to_primitive(config, resolve: bool = True) -> Any:
-    return OmegaConf.to_container(config, resolve=resolve)
+def config_to_primitive(config):
+    """Convert a dataclass config to a dictionary recursively."""
+    if is_dataclass(config):
+        config_dict = asdict(config)
+        return {k: config_to_primitive(v) for k, v in config_dict.items()}
+    elif isinstance(config, dict):
+        return {k: config_to_primitive(v) for k, v in config.items()}
+    elif isinstance(config, list):
+        return [config_to_primitive(v) for v in config]
+    elif isinstance(config, tuple):
+        return tuple(config_to_primitive(v) for v in config)
+    else:
+        return config

 def scale_tensor(
-    dat: Float[Tensor, "... D"],
-    inp_scale: Union[Tuple[float, float], Float[Tensor, "2 D"]],
-    tgt_scale: Union[Tuple[float, float], Float[Tensor, "2 D"]]
+    dat,
+    inp_scale,
+    tgt_scale
 ):
     if inp_scale is None:
         inp_scale = (0, 1)
@@ -25,8 +36,8 @@
     return dat

 def contract_to_unisphere_custom(
-    x: Float[Tensor, "... 3"], bbox: Float[Tensor, "2 3"], unbounded: bool = False
-) -> Float[Tensor, "... 3"]:
+    x, bbox, unbounded=False
+):
     if unbounded:
         x = scale_tensor(x, bbox, (-1, 1))
         x = x * 2 - 1  # aabb is at [-1, 1]
@@ -81,7 +92,7 @@ def project_onto_planes(planes, coordinates):
     projections = torch.bmm(coordinates, inv_planes)
     return projections[..., :2]

-def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=2, interpolate_feat: Optional[str] = 'None'):
+def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=2, interpolate_feat=None):
     assert padding_mode == 'zeros'
     N, n_planes, C, H, W = plane_features.shape
     _, M, _ = coordinates.shape
@@ -101,4 +112,10 @@
     output_features = output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
     output_features = output_features.permute(0, 2, 1, 3).reshape(N, M, n_planes*C)

-    return output_features.contiguous()
+    return output_features.contiguous()
+
+def cleanup():
+    """Cleanup torch memory."""
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
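
The new config_to_primitive replaces the OmegaConf-based version with a recursive dataclass walk; a self-contained usage example (the dataclasses here are illustrative only):

from dataclasses import dataclass, field
from triplaneturbo_executable.utils.general_utils import config_to_primitive

@dataclass
class HeadConfig:
    n_neurons: int = 64
    n_hidden_layers: int = 2

@dataclass
class ModelConfig:
    head: HeadConfig = field(default_factory=HeadConfig)
    tags: list = field(default_factory=lambda: ["triplane", "turbo"])

# asdict() already recurses into nested dataclasses; the wrapper then walks dicts/lists/tuples.
print(config_to_primitive(ModelConfig()))
# {'head': {'n_neurons': 64, 'n_hidden_layers': 2}, 'tags': ['triplane', 'turbo']}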
triplaneturbo_executable/utils/mesh.py CHANGED
@@ -1,77 +1,54 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-
-from typing import Any, Dict, Optional, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from jaxtyping import Float, Integer
-from torch import Tensor
+import trimesh

 def dot(x, y):
     return torch.sum(x * y, -1, keepdim=True)

 class Mesh:
     def __init__(
-        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
-    ) -> None:
-        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
-        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
-        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
-        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
-        self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
-        self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None
-        self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None
-        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
-        self.extras: Dict[str, Any] = {}
-        for k, v in kwargs.items():
-            self.add_extra(k, v)
+        self, v_pos, t_pos_idx, material=None
+    ):
+        self.v_pos = v_pos
+        self.t_pos_idx = t_pos_idx
+        self.material = material
+        self._v_nrm = None
+        self._v_tng = None
+        self._v_tex = None
+        self._t_tex_idx = None
+        self._v_rgb = None
+        self._edges = None
+        self.extras = {}

     def add_extra(self, k, v) -> None:
         self.extras[k] = v

-    def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]):
-
-        # use trimesh to first split the mesh into connected components
-        # then remove the components with less than n_face_threshold faces
-        import trimesh
-
-        # construct a trimesh object
-        mesh = trimesh.Trimesh(
-            vertices=self.v_pos.detach().cpu().numpy(),
-            faces=self.t_pos_idx.detach().cpu().numpy(),
-        )
-
-        # split the mesh into connected components
-        components = mesh.split(only_watertight=False)
-
-        n_faces_threshold: int
-        if isinstance(outlier_n_faces_threshold, float):
-            # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold
-            n_faces_threshold = int(
-                max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold
-            )
-        else:
-            # set the threshold directly to outlier_n_faces_threshold
-            n_faces_threshold = outlier_n_faces_threshold
-
-        # remove the components with less than n_face_threshold faces
-        components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]
-
-        # merge the components
-        mesh = trimesh.util.concatenate(components)
-
-        # convert back to our mesh format
-        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
-        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
-
-        clean_mesh = Mesh(v_pos, t_pos_idx)
-        # keep the extras unchanged
-
-        return clean_mesh
+    def remove_outlier(self, n_face_threshold=5):
+        """Remove outlier components with fewer faces than threshold."""
+        # Convert to trimesh
+        trimesh_mesh = self.as_trimesh()
+
+        # Split into connected components
+        components = trimesh_mesh.split(only_watertight=False)
+
+        # Filter components with few faces
+        valid_components = [c for c in components if len(c.faces) > n_face_threshold]
+
+        if len(valid_components) == 0:
+            # If no valid components, return the original mesh
+            return self
+
+        # Combine valid components
+        combined = trimesh.util.concatenate(valid_components)
+
+        # Convert back to our Mesh format
+        new_mesh = Mesh(
+            torch.tensor(combined.vertices, dtype=self.v_pos.dtype, device=self.v_pos.device),
+            torch.tensor(combined.faces, dtype=self.t_pos_idx.dtype, device=self.t_pos_idx.device)
+        )
+
+        return new_mesh

     @property
     def requires_grad(self):
@@ -245,8 +222,8 @@ class Mesh:
         edges = torch.unique(edges, dim=0)
         return edges

-    def normal_consistency(self) -> Float[Tensor, ""]:
-        edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges]
+    def normal_consistency(self):
+        edge_nrm = self.v_nrm[self.edges]
         nc = (
             1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
         ).mean()
@@ -279,10 +256,45 @@ class Mesh:
         # correct diagonal
         return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()

-    def laplacian(self) -> Float[Tensor, ""]:
+    def laplacian(self):
         with torch.no_grad():
             L = self._laplacian_uniform()
         loss = L.mm(self.v_pos)
         loss = loss.norm(dim=1)
         loss = loss.mean()
         return loss
+
+    def to(self, device):
+        v_pos = self.v_pos.to(device)
+        t_pos_idx = self.t_pos_idx.to(device)
+        return Mesh(v_pos, t_pos_idx)
+
+    def as_trimesh(self):
+        vertices = self.v_pos.detach().cpu().numpy()
+        faces = self.t_pos_idx.detach().cpu().numpy()
+
+        mesh = trimesh.Trimesh(
+            vertices=vertices,
+            faces=faces,
+            process=False
+        )
+
+        # Add texture if available
+        if hasattr(self, 'albedo_map') and self.albedo_map is not None:
+            uv = self.v_tex.detach().cpu().numpy()
+
+            # Create texture visuals
+            visual = trimesh.visual.texture.TextureVisuals(
+                uv=uv,
+                material=trimesh.visual.material.SimpleMaterial()
+            )
+            mesh.visual = visual
+
+        return mesh
+
+def scale_tensor(x, input_range, target_range):
+    """Scale tensor from input_range to target_range."""
+    x_unit = (x - input_range[0]) / (input_range[1] - input_range[0])
+    x_scaled = x_unit * (target_range[1] - target_range[0]) + target_range[0]
+    return x_scaled
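
A hedged round-trip sketch for the rewritten helpers; the two-triangle geometry is made up, and the threshold semantics follow the new strict inequality (a component's face count must exceed n_face_threshold to survive):

import torch
from triplaneturbo_executable.utils.mesh import Mesh

# Two disconnected single-triangle components.
v_pos = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.],
                      [5., 5., 5.], [6., 5., 5.], [5., 6., 5.]])
t_pos_idx = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.long)

mesh = Mesh(v_pos, t_pos_idx)
tm = mesh.as_trimesh()                             # trimesh.Trimesh; process=False keeps vertex order
cleaned = mesh.remove_outlier(n_face_threshold=0)  # each component has 1 face > 0, so both survive
moved = mesh.to("cpu")                             # note: material/extras are not carried over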
triplaneturbo_executable/utils/mesh_exporter.py CHANGED
@@ -1,6 +1,3 @@
-from typing import Callable, Dict, List, Optional, Tuple, Any
-from jaxtyping import Float
-from torch import Tensor
 from dataclasses import dataclass

 import torch
@@ -16,36 +13,35 @@ from ..utils.general_utils import scale_tensor
 class ExporterOutput:
     save_name: str
     save_type: str
-    params: Dict[str, Any]
+    params: dict


 class IsosurfaceHelper(nn.Module):
-    points_range: Tuple[float, float] = (0, 1)
+    points_range = (0, 1)

     @property
-    def grid_vertices(self) -> Float[Tensor, "N 3"]:
+    def grid_vertices(self):
         raise NotImplementedError

 class DiffMarchingCubeHelper(IsosurfaceHelper):
     def __init__(
         self,
-        resolution: int,
-        point_range: Tuple[float, float] = (0, 1)
-    ) -> None:
+        resolution,
+        point_range=(0, 1)
+    ):
         super().__init__()
         self.resolution = resolution
         self.points_range = point_range

         from diso import DiffMC
-        self.mc_func: Callable = DiffMC(dtype=torch.float32)
-        self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None
-        self._dummy: Float[Tensor, "..."]
+        self.mc_func = DiffMC(dtype=torch.float32)
+        self._grid_vertices = None
         self.register_buffer(
             "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
         )

     @property
-    def grid_vertices(self) -> Float[Tensor, "N3 3"]:
+    def grid_vertices(self):
         if self._grid_vertices is None:
             # keep the vertices on CPU so that we can support very large resolution
             x, y, z = (
@@ -62,10 +58,10 @@ class DiffMarchingCubeHelper(IsosurfaceHelper):

     def forward(
         self,
-        level: Float[Tensor, "N3 1"],
-        deformation: Optional[Float[Tensor, "N3 3"]] = None,
+        level,
+        deformation=None,
         isovalue=0.0,
-    ) -> Mesh:
+    ):
         level = level.view(self.resolution, self.resolution, self.resolution)
         if deformation is not None:
             deformation = deformation.view(self.resolution, self.resolution, self.resolution, 3)
@@ -76,17 +72,17 @@


 def isosurface(
-    space_cache: Float[Tensor, "B ..."],
-    forward_field: Callable,
-    isosurface_helper: Callable,
-) -> List[Mesh]:
+    space_cache,
+    forward_field,
+    isosurface_helper,
+):

     # the isosurface is dependent on the space cache
     # randomly detach isosurface method if it is differentiable
     # get the batchsize
     if torch.is_tensor(space_cache): #space cache
         batch_size = space_cache.shape[0]
-    elif isinstance(space_cache, Dict): #hyper net
+    elif isinstance(space_cache, dict): #hyper net
         # Dict[str, List[Float[Tensor, "B ..."]]]
         for key in space_cache.keys():
             batch_size = space_cache[key][0].shape[0]
@@ -141,11 +137,11 @@ def isosurface(
     return mesh_list

 def colorize_mesh(
-    space_cache: Any,
-    export_fn: Callable,
-    mesh_list: List[Mesh],
-    activation: Callable,
-) -> List[Mesh]:
+    space_cache,
+    export_fn,
+    mesh_list,
+    activation,
+):
     """Colorize the mesh using the geometry's export function and space cache.

     Args:
@@ -199,10 +195,10 @@ class MeshExporter(SaverMixin):
         return x

 def export_obj(
-    mesh: Mesh,
-    save_path: str,
-    save_normal: bool = False,
-) -> List[str]:
+    mesh,
+    save_path,
+    save_normal=False,
+):
     """
     Export mesh data to OBJ file format.

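
A hedged sketch of driving DiffMarchingCubeHelper directly with a sphere SDF; diso generally expects CUDA tensors, and the (resolution**3, 1) level shape follows the annotation this diff removes:

import torch
from triplaneturbo_executable.utils.mesh_exporter import DiffMarchingCubeHelper

helper = DiffMarchingCubeHelper(resolution=64, point_range=(0, 1)).cuda()
# grid_vertices is kept on CPU for large resolutions; move the level field to the GPU.
verts = helper.grid_vertices                            # (64**3, 3) points in [0, 1]
level = ((verts - 0.5).norm(dim=-1, keepdim=True) - 0.3).cuda()
mesh = helper(level, deformation=None, isovalue=0.0)    # returns a Mesh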
triplaneturbo_executable/utils/saving.py CHANGED
@@ -13,22 +13,15 @@ import wandb
 from matplotlib import cm
 from matplotlib.colors import LinearSegmentedColormap
 from PIL import Image, ImageDraw
-# from pytorch_lightning.loggers import WandbLogger

-from ..utils.mesh import Mesh
-
-from typing import Dict, List, Optional, Union, Any
-from omegaconf import DictConfig
-from jaxtyping import Float
-from torch import Tensor

 import threading

 class SaverMixin:
-    _save_dir: Optional[str] = None
-    # _wandb_logger: Optional[WandbLogger] = None
+    _save_dir = None
+    # _wandb_logger = None

-    def set_save_dir(self, save_dir: str):
+    def set_save_dir(self, save_dir):
         self._save_dir = save_dir

     def get_save_dir(self):
@@ -58,17 +51,6 @@
         os.makedirs(os.path.dirname(save_path), exist_ok=True)
         return save_path

-    # def create_loggers(self, cfg_loggers: DictConfig) -> None:
-    #     if "wandb" in cfg_loggers.keys() and cfg_loggers.wandb.enable:
-    #         self._wandb_logger = WandbLogger(
-    #             project=cfg_loggers.wandb.project, name=cfg_loggers.wandb.name
-    #         )
-
-    # def get_loggers(self) -> List:
-    #     if self._wandb_logger:
-    #         return [self._wandb_logger]
-    #     else:
-    #         return []

     DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
     DEFAULT_UV_KWARGS = {
@@ -119,8 +101,8 @@
         img,
         data_format,
         data_range,
-        name: Optional[str] = None,
-        step: Optional[int] = None,
+        name=None,
+        step=None,
     ):
         img = self.get_rgb_image_(img, data_format, data_range)
         cv2.imwrite(filename, img)
@@ -138,8 +120,8 @@
         img,
         data_format=DEFAULT_RGB_KWARGS["data_format"],
         data_range=DEFAULT_RGB_KWARGS["data_range"],
-        name: Optional[str] = None,
-        step: Optional[int] = None,
+        name=None,
+        step=None,
     ) -> str:
         save_path = self.get_save_path(filename)
         self._save_rgb_image(save_path, img, data_format, data_range, name, step)
@@ -231,8 +213,8 @@
         img,
         data_range,
         cmap,
-        name: Optional[str] = None,
-        step: Optional[int] = None,
+        name=None,
+        step=None,
     ):
         img = self.get_grayscale_image_(img, data_range, cmap)
         cv2.imwrite(filename, img)
@@ -250,8 +232,8 @@
         img,
         data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
         cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
-        name: Optional[str] = None,
-        step: Optional[int] = None,
+        name=None,
+        step=None,
     ) -> str:
         save_path = self.get_save_path(filename)
         self._save_grayscale_image(save_path, img, data_range, cmap, name, step)
@@ -308,9 +290,9 @@
         filename,
         imgs,
         align=DEFAULT_GRID_KWARGS["align"],
-        name: Optional[str] = None,
-        step: Optional[int] = None,
-        texts: Optional[List[float]] = None,
+        name=None,
+        step=None,
+        texts=None,
     ):
         save_path = self.get_save_path(filename)
         img = self.get_image_grid_(imgs, align=align)
@@ -404,8 +386,8 @@
     #     matcher,
     #     save_format="mp4",
     #     fps=30,
-    #     name: Optional[str] = None,
-    #     step: Optional[int] = None,
+    #     name=None,
+    #     step=None,
     # ) -> str:
     #     assert save_format in ["gif", "mp4"]
     #     if not filename.endswith(save_format):
@@ -442,9 +424,9 @@
         matcher,
         save_format="mp4",
         fps=30,
-        name: Optional[str] = None,
-        step: Optional[int] = None,
-        multithreaded: bool = False
+        name=None,
+        step=None,
+        multithreaded=False
     ) -> str:
         assert save_format in ["gif", "mp4"]
         if not filename.endswith(save_format):
@@ -494,20 +476,19 @@

     def save_obj(
         self,
-        filename: str,
-        mesh: Mesh,
-        save_mat: bool = False,
-        save_normal: bool = False,
-        save_uv: bool = False,
-        save_vertex_color: bool = False,
-        map_Kd: Optional[Float[Tensor, "H W 3"]] = None,
-        map_Ks: Optional[Float[Tensor, "H W 3"]] = None,
-        map_Bump: Optional[Float[Tensor, "H W 3"]] = None,
-        map_Pm: Optional[Float[Tensor, "H W 1"]] = None,
-        map_Pr: Optional[Float[Tensor, "H W 1"]] = None,
-        map_format: str = "jpg",
-    ) -> List[str]:
-
+        filename,
+        mesh,
+        save_mat=False,
+        save_normal=False,
+        save_uv=False,
+        save_vertex_color=False,
+        map_Kd=None,
+        map_Ks=None,
+        map_Bump=None,
+        map_Pm=None,
+        map_Pr=None,
+        map_format="jpg",
+    ):
         if not filename.endswith(".obj"):
             filename += ".obj"
         save_path = self.get_save_path(filename)
@@ -658,8 +639,8 @@
         map_Pm=None,
         map_Pr=None,
         map_format="jpg",
-        step: Optional[int] = None,
-    ) -> List[str]:
+        step=None,
+    ):
         mtl_save_path = self.get_save_path(filename)
         save_paths = [mtl_save_path]
         mtl_str = f"newmtl {matname}\n"