dkatz2391 committed
Commit c7a648a · verified · Parent: e819550

Update app.py

Files changed (1): app.py (+95 / -130)
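
The headline change in this diff is that the TRELLIS pipeline is now constructed at module import time rather than inside the `__main__` guard, so the `pipeline` global already exists when Spaces GPU workers and Gradio handler threads call it. A minimal sketch of that pattern, reduced from the diff below (the import path for TrellisTextTo3DPipeline is an assumption based on the TRELLIS repo layout):

    # Sketch only: build the heavy pipeline once at import time so every handler sees the same global.
    import torch
    from trellis.pipelines import TrellisTextTo3DPipeline  # assumed import path

    pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
    if torch.cuda.is_available():
        pipeline = pipeline.to("cuda")  # move weights to the GPU when one is present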
app.py CHANGED
@@ -1,11 +1,11 @@
-# Version: 1.1.2 - Removed torch_dtype from from_pretrained call
-# Applied:
+# Version: 1.1.3 - Load pipeline at module level for Spaces environment
+# Applied targeted fixes:
 # - Removed unsupported inputs/outputs kwargs on demo.load/unload
 # - Converted NumPy arrays to lists in pack_state for JSON safety
 # - Fixed indentation in Blocks event-handlers
 # - Verified clear() callbacks use only callback + outputs
-# - Removed `torch_dtype` arg from TrellisTextTo3DPipeline.from_pretrained
-# - Bumped version, added comments at change sites
+# - Removed `torch_dtype` arg from from_pretrained
+# - Moved pipeline initialization to module level so it's available in threads
 
 import gradio as gr
 import spaces
@@ -25,10 +25,26 @@ from trellis.utils import render_utils, postprocessing_utils
 import traceback
 import sys
 
+# --- Global Config ---
 MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 os.makedirs(TMP_DIR, exist_ok=True)
 
+# --- Initialize Trellis Pipeline at import time ---
+print("[Startup] Loading Trellis pipeline...")
+try:
+    pipeline = TrellisTextTo3DPipeline.from_pretrained(
+        "JeffreyXiang/TRELLIS-text-xlarge"
+    )
+    if torch.cuda.is_available():
+        pipeline = pipeline.to("cuda")
+        print("[Startup] Trellis pipeline loaded to GPU.")
+    else:
+        print("[Startup] Trellis pipeline loaded to CPU.")
+except Exception as e:
+    print(f"❌ [Startup] Failed to load Trellis pipeline: {e}")
+    raise
+
 
 def start_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -50,10 +66,9 @@ def end_session(req: gr.Request):
 
 def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
     """Packs Gaussian and Mesh data into a JSON-serializable dictionary."""
-    packed_data = {
+    return {
         'gaussian': {
             **{k: v for k, v in gs.init_params.items()},
-            # FIX: convert arrays to lists for JSON
             '_xyz': gs._xyz.detach().cpu().numpy().tolist(),
             '_features_dc': gs._features_dc.detach().cpu().numpy().tolist(),
             '_scaling': gs._scaling.detach().cpu().numpy().tolist(),
@@ -65,178 +80,128 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
             'faces': mesh.faces.detach().cpu().numpy().tolist(),
         },
     }
-    return packed_data
 
 
 def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
-    print("[unpack_state] Unpacking state from dictionary... ")
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    gauss_data = state_dict['gaussian']
-    mesh_data = state_dict['mesh']
+    gd = state_dict['gaussian']
+    md = state_dict['mesh']
     gs = Gaussian(
-        aabb=gauss_data.get('aabb'),
-        sh_degree=gauss_data.get('sh_degree'),
-        mininum_kernel_size=gauss_data.get('mininum_kernel_size'),
-        scaling_bias=gauss_data.get('scaling_bias'),
-        opacity_bias=gauss_data.get('opacity_bias'),
-        scaling_activation=gauss_data.get('scaling_activation'),
+        aabb=gd.get('aabb'), sh_degree=gd.get('sh_degree'),
+        mininum_kernel_size=gd.get('mininum_kernel_size'),
+        scaling_bias=gd.get('scaling_bias'), opacity_bias=gd.get('opacity_bias'),
+        scaling_activation=gd.get('scaling_activation')
     )
-    gs._xyz = torch.tensor(np.array(gauss_data['_xyz']), device=device, dtype=torch.float32)
-    gs._features_dc = torch.tensor(np.array(gauss_data['_features_dc']), device=device, dtype=torch.float32)
-    gs._scaling = torch.tensor(np.array(gauss_data['_scaling']), device=device, dtype=torch.float32)
-    gs._rotation = torch.tensor(np.array(gauss_data['_rotation']), device=device, dtype=torch.float32)
-    gs._opacity = torch.tensor(np.array(gauss_data['_opacity']), device=device, dtype=torch.float32)
+    gs._xyz = torch.tensor(np.array(gd['_xyz']), device=device, dtype=torch.float32)
+    gs._features_dc = torch.tensor(np.array(gd['_features_dc']), device=device, dtype=torch.float32)
+    gs._scaling = torch.tensor(np.array(gd['_scaling']), device=device, dtype=torch.float32)
+    gs._rotation = torch.tensor(np.array(gd['_rotation']), device=device, dtype=torch.float32)
+    gs._opacity = torch.tensor(np.array(gd['_opacity']), device=device, dtype=torch.float32)
     mesh = edict(
-        vertices=torch.tensor(np.array(mesh_data['vertices']), device=device, dtype=torch.float32),
-        faces=torch.tensor(np.array(mesh_data['faces']), device=device, dtype=torch.int64),
+        vertices=torch.tensor(np.array(md['vertices']), device=device, dtype=torch.float32),
+        faces=torch.tensor(np.array(md['faces']), device=device, dtype=torch.int64),
     )
     return gs, mesh
 
 
 def get_seed(randomize_seed: bool, seed: int) -> int:
-    new_seed = np.random.randint(0, MAX_SEED) if randomize_seed else seed
-    return int(new_seed)
+    return int(np.random.randint(0, MAX_SEED) if randomize_seed else seed)
 
 @spaces.GPU
 def text_to_3d(
-    prompt: str,
-    seed: int,
-    ss_guidance_strength: float,
-    ss_sampling_steps: int,
-    slat_guidance_strength: float,
-    slat_sampling_steps: int,
-    req: gr.Request,
+    prompt: str, seed: int,
+    ss_guidance_strength: float, ss_sampling_steps: int,
+    slat_guidance_strength: float, slat_sampling_steps: int,
+    req: gr.Request
 ) -> Tuple[dict, str]:
-    outputs = pipeline.run(
-        prompt,
-        seed=seed,
-        formats=["gaussian", "mesh"],
-        sparse_structure_sampler_params={"steps": int(ss_sampling_steps), "cfg_strength": float(ss_guidance_strength)},
-        slat_sampler_params={"steps": int(slat_sampling_steps), "cfg_strength": float(slat_guidance_strength)},
+    out = pipeline.run(
+        prompt, seed=seed,
+        formats=["gaussian","mesh"],
+        sparse_structure_sampler_params={"steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength},
+        slat_sampler_params={"steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength}
    )
-    state_dict = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
-    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
-    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
-    video_combined = [np.concatenate([v.astype(np.uint8), vg.astype(np.uint8)], axis=1) for v, vg in zip(video, video_geo)]
-    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    os.makedirs(user_dir, exist_ok=True)
-    video_path = os.path.join(user_dir, 'sample.mp4')
-    imageio.mimsave(video_path, video_combined, fps=15, quality=8)
+    state = pack_state(out['gaussian'][0], out['mesh'][0])
+    vid_c = render_utils.render_video(out['gaussian'][0],num_frames=120)['color']
+    vid_n = render_utils.render_video(out['mesh'][0],num_frames=120)['normal']
+    vid = [np.concatenate([c.astype(np.uint8), n.astype(np.uint8)], axis=1) for c,n in zip(vid_c,vid_n)]
+    ud = os.path.join(TMP_DIR,str(req.session_hash)); os.makedirs(ud,exist_ok=True)
+    vp = os.path.join(ud,'sample.mp4'); imageio.mimsave(vp,vid,fps=15,quality=8)
     if torch.cuda.is_available(): torch.cuda.empty_cache()
-    return state_dict, video_path
+    return state, vp
 
 @spaces.GPU(duration=120)
-def extract_glb(
-    state_dict: dict,
-    mesh_simplify: float,
-    texture_size: int,
-    req: gr.Request,
-) -> Tuple[str, str]:
+def extract_glb(state_dict: dict, mesh_simplify: float, texture_size: int, req: gr.Request):
     gs, mesh = unpack_state(state_dict)
-    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    os.makedirs(user_dir, exist_ok=True)
-    glb = postprocessing_utils.to_glb(gs, mesh, simplify=float(mesh_simplify), texture_size=int(texture_size), verbose=True)
-    glb_path = os.path.join(user_dir, 'sample.glb')
-    glb.export(glb_path)
+    ud = os.path.join(TMP_DIR, str(req.session_hash)); os.makedirs(ud, exist_ok=True)
+    glb = postprocessing_utils.to_glb(gs,mesh,simplify=mesh_simplify,texture_size=texture_size,verbose=True)
+    gp = os.path.join(ud,'sample.glb'); glb.export(gp)
     if torch.cuda.is_available(): torch.cuda.empty_cache()
-    return glb_path, glb_path
+    return gp, gp
 
 @spaces.GPU
-def extract_gaussian(
-    state_dict: dict,
-    req: gr.Request
-) -> Tuple[str, str]:
+def extract_gaussian(state_dict: dict, req: gr.Request):
     gs, _ = unpack_state(state_dict)
-    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    os.makedirs(user_dir, exist_ok=True)
-    gaussian_path = os.path.join(user_dir, 'sample.ply')
-    gs.save_ply(gaussian_path)
+    ud = os.path.join(TMP_DIR, str(req.session_hash)); os.makedirs(ud, exist_ok=True)
+    pp = os.path.join(ud,'sample.ply'); gs.save_ply(pp)
     if torch.cuda.is_available(): torch.cuda.empty_cache()
-    return gaussian_path, gaussian_path
+    return pp, pp
 
-# --- Gradio UI Definition ---
-with gr.Blocks(delete_cache=(600, 600), title="TRELLIS Text-to-3D") as demo:
+# --- Gradio UI ---
+with gr.Blocks(delete_cache=(600,600), title="TRELLIS Text-to-3D") as demo:
     gr.Markdown("""
-    # Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
+    # Text to 3D Asset with TRELLIS
     """)
-
-    # State buffer
     output_buf = gr.State()
-
     with gr.Row():
         with gr.Column(scale=1):
             text_prompt = gr.Textbox(label="Text Prompt", lines=5)
-            with gr.Accordion(label="Generation Settings", open=False):
+            with gr.Accordion("Generation Settings", open=False):
                 seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                gr.Markdown("---\n**Stage 1**")
-                ss_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
-                ss_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
-                gr.Markdown("---\n**Stage 2**")
-                slat_guidance_strength = gr.Slider(0.0, 15.0, label="Guidance Strength", value=7.5, step=0.1)
-                slat_sampling_steps = gr.Slider(10, 50, label="Sampling Steps", value=25, step=1)
-            generate_btn = gr.Button("Generate 3D Preview", variant="primary")
-            with gr.Accordion(label="GLB Extraction Settings", open=True):
-                mesh_simplify = gr.Slider(0.9, 0.99, label="Simplify Factor", value=0.95, step=0.01)
-                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+                gr.Markdown("--- Stage 1 ---")
+                ss_guidance_strength = gr.Slider(0.0,15.0,label="Guidance Strength",value=7.5,step=0.1)
+                ss_sampling_steps = gr.Slider(10,50,label="Steps",value=25,step=1)
+                gr.Markdown("--- Stage 2 ---")
+                slat_guidance_strength = gr.Slider(0.0,15.0,label="Guidance Strength",value=7.5,step=0.1)
+                slat_sampling_steps = gr.Slider(10,50,label="Steps",value=25,step=1)
+            generate_btn = gr.Button("Generate 3D Preview")
+            with gr.Accordion("GLB Extraction Settings", open=True):
+                mesh_simplify = gr.Slider(0.9,0.99,label="Simplify",value=0.95,step=0.01)
+                texture_size = gr.Slider(512,2048,label="Texture Size",value=1024,step=512)
             extract_glb_btn = gr.Button("Extract GLB", interactive=False)
-            extract_gs_btn = gr.Button("Extract Gaussian (PLY)", interactive=False)
-            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
-            download_gs = gr.DownloadButton(label="Download Gaussian (PLY)", interactive=False)
+            extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
+            download_glb = gr.DownloadButton("Download GLB", interactive=False)
+            download_gs = gr.DownloadButton("Download Gaussian", interactive=False)
         with gr.Column(scale=1):
-            video_output = gr.Video(label="3D Preview", autoplay=True, loop=True)
-            model_output = gr.Model3D(label="Extracted Model Preview")
+            video_output = gr.Video(autoplay=True,loop=True)
+            model_output = gr.Model3D()
 
-    # --- Event handlers ---
-    demo.load(start_session) # FIX: remove inputs/outputs kwargs
-    demo.unload(end_session) # FIX: remove inputs/outputs kwargs
+    # --- Handlers ---
+    demo.load(start_session)
+    demo.unload(end_session)
 
-    # Align indentation to one level under Blocks
     generate_event = generate_btn.click(
         get_seed,
-        inputs=[randomize_seed, seed],
-        outputs=[seed],
+        inputs=[randomize_seed,seed], outputs=[seed]
     ).then(
         text_to_3d,
-        inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
-        outputs=[output_buf, video_output],
-    ).then(
-        lambda: (extract_glb_btn.update(interactive=True), extract_gs_btn.update(interactive=True)),
-        outputs=[extract_glb_btn, extract_gs_btn],
-    )
+        inputs=[text_prompt,seed,ss_guidance_strength,ss_sampling_steps,slat_guidance_strength,slat_sampling_steps],
+        outputs=[output_buf,video_output]
+    ).then(lambda: (extract_glb_btn.update(interactive=True),extract_gs_btn.update(interactive=True)), outputs=[extract_glb_btn,extract_gs_btn])
 
-    extract_glb_event = extract_glb_btn.click(
+    extract_glb_btn.click(
         extract_glb,
-        inputs=[output_buf, mesh_simplify, texture_size],
-        outputs=[model_output, download_glb],
-    ).then(
-        lambda: download_glb.update(interactive=True),
-        outputs=[download_glb],
-    )
+        inputs=[output_buf,mesh_simplify,texture_size],
+        outputs=[model_output,download_glb]
+    ).then(lambda: download_glb.update(interactive=True), outputs=[download_glb])
 
-    extract_gs_event = extract_gs_btn.click(
+    extract_gs_btn.click(
         extract_gaussian,
-        inputs=[output_buf],
-        outputs=[model_output, download_gs],
-    ).then(
-        lambda: download_gaussian.update(interactive=True),
-        outputs=[download_gs],
-    )
+        inputs=[output_buf], outputs=[model_output,download_gs]
+    ).then(lambda: download_gs.update(interactive=True), outputs=[download_gs])
 
-    # Clear callbacks
-    model_output.clear(
-        lambda: (download_glb.update(interactive=False), download_gs.update(interactive=False)),
-        outputs=[download_glb, download_gs],
-    )
-    video_output.clear(
-        lambda: (extract_glb_btn.update(interactive=False), extract_gs_btn.update(interactive=False), download_glb.update(interactive=False), download_gs.update(interactive=False)),
-        outputs=[extract_glb_btn, extract_gs_btn, download_glb, download_gs],
-    )
+    model_output.clear(lambda: (download_glb.update(interactive=False),download_gs.update(interactive=False)), outputs=[download_glb,download_gs])
+    video_output.clear(lambda: (extract_glb_btn.update(interactive=False),extract_gs_btn.update(interactive=False),download_glb.update(interactive=False),download_gs.update(interactive=False)), outputs=[extract_glb_btn,extract_gs_btn,download_glb,download_gs])
 
 if __name__ == "__main__":
-    # Removed torch_dtype argument to match current API
-    pipeline = TrellisTextTo3DPipeline.from_pretrained(
-        "JeffreyXiang/TRELLIS-text-xlarge"
-    )
-    if torch.cuda.is_available(): pipeline = pipeline.to("cuda")
     demo.queue().launch(debug=True)
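
For reference, a self-contained sketch of the tensor-to-list-to-tensor round-trip that the reworked pack_state / unpack_state pair relies on; toy tensors stand in for the real Gaussian and mesh fields, and the helper names here are illustrative only:

    import json
    import numpy as np
    import torch

    def pack(t: torch.Tensor) -> list:
        # .tolist() yields plain Python floats, so the packed dict is JSON-serializable
        return t.detach().cpu().numpy().tolist()

    def unpack(data: list, device: str) -> torch.Tensor:
        return torch.tensor(np.array(data), device=device, dtype=torch.float32)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    state = {'_xyz': pack(torch.randn(4, 3))}
    blob = json.dumps(state)  # would raise TypeError if the values were tensors or ndarrays
    restored = unpack(json.loads(blob)['_xyz'], device)
    assert restored.shape == (4, 3)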