xinjie.wang commited on
Commit
7182bc7
·
1 Parent(s): c9bd2e0
Files changed (1) hide show
  1. common.py +2 -2
common.py CHANGED
@@ -323,7 +323,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
323
  }
324
 
325
 
326
- def unpack_state(state: dict, device: str = "cuda") -> tuple[Gaussian, dict]:
 
327
  print("debug11")
328
  gs = Gaussian(
329
  aabb=state["gaussian"]["aabb"],
@@ -332,7 +333,6 @@ def unpack_state(state: dict, device: str = "cuda") -> tuple[Gaussian, dict]:
332
  scaling_bias=state["gaussian"]["scaling_bias"],
333
  opacity_bias=state["gaussian"]["opacity_bias"],
334
  scaling_activation=state["gaussian"]["scaling_activation"],
335
- device="cpu",
336
  )
337
  print("debug12")
338
  gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
 
323
  }
324
 
325
 
326
+ @spaces.GPU
327
+ def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
328
  print("debug11")
329
  gs = Gaussian(
330
  aabb=state["gaussian"]["aabb"],
 
333
  scaling_bias=state["gaussian"]["scaling_bias"],
334
  opacity_bias=state["gaussian"]["opacity_bias"],
335
  scaling_activation=state["gaussian"]["scaling_activation"],
 
336
  )
337
  print("debug12")
338
  gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)