xinjie.wang commited on
Commit
b75f769
·
1 Parent(s): 7182bc7
Files changed (1) hide show
  1. common.py +36 -45
common.py CHANGED
@@ -11,6 +11,8 @@ import gradio as gr
11
  import numpy as np
12
  import spaces
13
  import torch
 
 
14
  import trimesh
15
  from easydict import EasyDict as edict
16
  from PIL import Image
@@ -56,6 +58,8 @@ from thirdparty.TRELLIS.trellis.representations import (
56
  Gaussian,
57
  MeshExtractResult,
58
  )
 
 
59
  from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
60
  from thirdparty.TRELLIS.trellis.utils.render_utils import (
61
  render_frames,
@@ -79,6 +83,36 @@ DELIGHT = DelightingModel()
79
  IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def download_kolors_weights() -> None:
83
  logger.info(f"Download kolors weights from huggingface...")
84
  subprocess.run(
@@ -323,7 +357,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
323
  }
324
 
325
 
326
- @spaces.GPU
327
  def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
328
  print("debug11")
329
  gs = Gaussian(
@@ -333,6 +367,7 @@ def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
333
  scaling_bias=state["gaussian"]["scaling_bias"],
334
  opacity_bias=state["gaussian"]["opacity_bias"],
335
  scaling_activation=state["gaussian"]["scaling_activation"],
 
336
  )
337
  print("debug12")
338
  gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
@@ -449,50 +484,6 @@ def image_to_3d(
449
  return state, video_path
450
 
451
 
452
- @spaces.GPU
453
- def extract_3d_representations(
454
- state: dict, enable_delight: bool, req: gr.Request
455
- ):
456
- output_root = TMP_DIR
457
- output_root = os.path.join(output_root, str(req.session_hash))
458
- gs_model, mesh_model = unpack_state(state)
459
-
460
- mesh = postprocessing_utils.to_glb(
461
- gs_model,
462
- mesh_model,
463
- simplify=0.9,
464
- texture_size=1024,
465
- verbose=True,
466
- )
467
- filename = "sample"
468
- gs_path = os.path.join(output_root, f"{filename}_gs.ply")
469
- gs_model.save_ply(gs_path)
470
-
471
- # Rotate mesh and GS by 90 degrees around Z-axis.
472
- rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
473
- # Addtional rotation for GS to align mesh.
474
- gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
475
- rot_matrix
476
- )
477
- pose = GaussianOperator.trans_to_quatpose(gs_rot)
478
- aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
479
- GaussianOperator.resave_ply(
480
- in_ply=gs_path,
481
- out_ply=aligned_gs_path,
482
- instance_pose=pose,
483
- )
484
-
485
- mesh.vertices = mesh.vertices @ np.array(rot_matrix)
486
- mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
487
- mesh.export(mesh_obj_path)
488
- mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
489
- mesh.export(mesh_glb_path)
490
-
491
- torch.cuda.empty_cache()
492
-
493
- return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
494
-
495
-
496
  def extract_3d_representations_v2(
497
  state: dict,
498
  enable_delight: bool,
 
11
  import numpy as np
12
  import spaces
13
  import torch
14
+ import torch
15
+ import torch.nn.functional as F
16
  import trimesh
17
  from easydict import EasyDict as edict
18
  from PIL import Image
 
58
  Gaussian,
59
  MeshExtractResult,
60
  )
61
+ from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation
62
+
63
  from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
64
  from thirdparty.TRELLIS.trellis.utils.render_utils import (
65
  render_frames,
 
83
  IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
84
 
85
 
86
+ def inverse_softplus(x):
87
+ return x + torch.log(-torch.expm1(-x))
88
+
89
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
90
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
91
+ actual_covariance = L @ L.transpose(1, 2)
92
+ symm = strip_symmetric(actual_covariance)
93
+ return symm
94
+
95
+ def patched_setup_functions(self):
96
+ if self.scaling_activation_type == "exp":
97
+ self.scaling_activation = torch.exp
98
+ self.inverse_scaling_activation = torch.log
99
+ elif self.scaling_activation_type == "softplus":
100
+ self.scaling_activation = F.softplus
101
+ self.inverse_scaling_activation = inverse_softplus
102
+
103
+ self.covariance_activation = build_covariance_from_scaling_rotation
104
+ self.opacity_activation = torch.sigmoid
105
+ self.inverse_opacity_activation = inverse_sigmoid
106
+ self.rotation_activation = F.normalize
107
+
108
+ self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
109
+ self.rots_bias = torch.zeros((4)).to(self.device)
110
+ self.rots_bias[0] = 1
111
+ self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
112
+
113
+ Gaussian.setup_functions = patched_setup_functions
114
+
115
+
116
  def download_kolors_weights() -> None:
117
  logger.info(f"Download kolors weights from huggingface...")
118
  subprocess.run(
 
357
  }
358
 
359
 
360
+ # @spaces.GPU
361
  def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
362
  print("debug11")
363
  gs = Gaussian(
 
367
  scaling_bias=state["gaussian"]["scaling_bias"],
368
  opacity_bias=state["gaussian"]["opacity_bias"],
369
  scaling_activation=state["gaussian"]["scaling_activation"],
370
+ device=device,
371
  )
372
  print("debug12")
373
  gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
 
484
  return state, video_path
485
 
486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  def extract_3d_representations_v2(
488
  state: dict,
489
  enable_delight: bool,