dkatz2391 committed
Commit e7bca60 · verified · 1 Parent(s): c7a648a

revert back to this AM restart of space cursor

Files changed (1):
  1. app.py +203 -136
app.py CHANGED
@@ -1,19 +1,10 @@
-# Version: 1.1.3 - Load pipeline at module level for Spaces environment
-# Applied targeted fixes:
-# - Removed unsupported inputs/outputs kwargs on demo.load/unload
-# - Converted NumPy arrays to lists in pack_state for JSON safety
-# - Fixed indentation in Blocks event-handlers
-# - Verified clear() callbacks use only callback + outputs
-# - Removed `torch_dtype` arg from from_pretrained
-# - Moved pipeline initialization to module level so it's available in threads
-
 import gradio as gr
 import spaces
+
 import os
 import shutil
 os.environ['TOKENIZERS_PARALLELISM'] = 'true'
 os.environ['SPCONV_ALGO'] = 'native'
-
 from typing import *
 import torch
 import numpy as np
@@ -22,186 +13,262 @@ from easydict import EasyDict as edict
 from trellis.pipelines import TrellisTextTo3DPipeline
 from trellis.representations import Gaussian, MeshExtractResult
 from trellis.utils import render_utils, postprocessing_utils
+
 import traceback
 import sys
 
-# --- Global Config ---
+
 MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 os.makedirs(TMP_DIR, exist_ok=True)
 
-# --- Initialize Trellis Pipeline at import time ---
-print("[Startup] Loading Trellis pipeline...")
-try:
-    pipeline = TrellisTextTo3DPipeline.from_pretrained(
-        "JeffreyXiang/TRELLIS-text-xlarge"
-    )
-    if torch.cuda.is_available():
-        pipeline = pipeline.to("cuda")
-        print("[Startup] Trellis pipeline loaded to GPU.")
-    else:
-        print("[Startup] Trellis pipeline loaded to CPU.")
-except Exception as e:
-    print(f"❌ [Startup] Failed to load Trellis pipeline: {e}")
-    raise
-
 
 def start_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     os.makedirs(user_dir, exist_ok=True)
-    print(f"Started session, created directory: {user_dir}")
-
-
+
+
 def end_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    if os.path.exists(user_dir):
-        try:
-            shutil.rmtree(user_dir)
-            print(f"Ended session, removed directory: {user_dir}")
-        except OSError as e:
-            print(f"Error removing tmp directory {user_dir}: {e.strerror}", file=sys.stderr)
-    else:
-        print(f"Ended session, directory already removed: {user_dir}")
+    shutil.rmtree(user_dir)
 
 
 def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
-    """Packs Gaussian and Mesh data into a JSON-serializable dictionary."""
     return {
         'gaussian': {
-            **{k: v for k, v in gs.init_params.items()},
-            '_xyz': gs._xyz.detach().cpu().numpy().tolist(),
-            '_features_dc': gs._features_dc.detach().cpu().numpy().tolist(),
-            '_scaling': gs._scaling.detach().cpu().numpy().tolist(),
-            '_rotation': gs._rotation.detach().cpu().numpy().tolist(),
-            '_opacity': gs._opacity.detach().cpu().numpy().tolist(),
+            **gs.init_params,
+            '_xyz': gs._xyz.cpu().numpy(),
+            '_features_dc': gs._features_dc.cpu().numpy(),
+            '_scaling': gs._scaling.cpu().numpy(),
+            '_rotation': gs._rotation.cpu().numpy(),
+            '_opacity': gs._opacity.cpu().numpy(),
         },
         'mesh': {
-            'vertices': mesh.vertices.detach().cpu().numpy().tolist(),
-            'faces': mesh.faces.detach().cpu().numpy().tolist(),
+            'vertices': mesh.vertices.cpu().numpy(),
+            'faces': mesh.faces.cpu().numpy(),
         },
     }
-
-
-def unpack_state(state_dict: dict) -> Tuple[Gaussian, edict]:
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    gd = state_dict['gaussian']
-    md = state_dict['mesh']
+
+
+def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
     gs = Gaussian(
-        aabb=gd.get('aabb'), sh_degree=gd.get('sh_degree'),
-        mininum_kernel_size=gd.get('mininum_kernel_size'),
-        scaling_bias=gd.get('scaling_bias'), opacity_bias=gd.get('opacity_bias'),
-        scaling_activation=gd.get('scaling_activation')
+        aabb=state['gaussian']['aabb'],
+        sh_degree=state['gaussian']['sh_degree'],
+        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
+        scaling_bias=state['gaussian']['scaling_bias'],
+        opacity_bias=state['gaussian']['opacity_bias'],
+        scaling_activation=state['gaussian']['scaling_activation'],
     )
-    gs._xyz = torch.tensor(np.array(gd['_xyz']), device=device, dtype=torch.float32)
-    gs._features_dc = torch.tensor(np.array(gd['_features_dc']), device=device, dtype=torch.float32)
-    gs._scaling = torch.tensor(np.array(gd['_scaling']), device=device, dtype=torch.float32)
-    gs._rotation = torch.tensor(np.array(gd['_rotation']), device=device, dtype=torch.float32)
-    gs._opacity = torch.tensor(np.array(gd['_opacity']), device=device, dtype=torch.float32)
+    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
+    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
+    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
+    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
+    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+
     mesh = edict(
-        vertices=torch.tensor(np.array(md['vertices']), device=device, dtype=torch.float32),
-        faces=torch.tensor(np.array(md['faces']), device=device, dtype=torch.int64),
+        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
+        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )
+
     return gs, mesh
 
 
 def get_seed(randomize_seed: bool, seed: int) -> int:
-    return int(np.random.randint(0, MAX_SEED) if randomize_seed else seed)
+    """
+    Get the random seed.
+    """
+    return np.random.randint(0, MAX_SEED) if randomize_seed else seed
+
 
 @spaces.GPU
 def text_to_3d(
-    prompt: str, seed: int,
-    ss_guidance_strength: float, ss_sampling_steps: int,
-    slat_guidance_strength: float, slat_sampling_steps: int,
-    req: gr.Request
+    prompt: str,
+    seed: int,
+    ss_guidance_strength: float,
+    ss_sampling_steps: int,
+    slat_guidance_strength: float,
+    slat_sampling_steps: int,
+    req: gr.Request,
 ) -> Tuple[dict, str]:
-    out = pipeline.run(
-        prompt, seed=seed,
-        formats=["gaussian","mesh"],
-        sparse_structure_sampler_params={"steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength},
-        slat_sampler_params={"steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength}
+    """
+    Convert an text prompt to a 3D model.
+    Args:
+        prompt (str): The text prompt.
+        seed (int): The random seed.
+        ss_guidance_strength (float): The guidance strength for sparse structure generation.
+        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
+        slat_guidance_strength (float): The guidance strength for structured latent generation.
+        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
+    Returns:
+        dict: The information of the generated 3D model.
+        str: The path to the video of the 3D model.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+    outputs = pipeline.run(
+        prompt,
+        seed=seed,
+        formats=["gaussian", "mesh"],
+        sparse_structure_sampler_params={
+            "steps": ss_sampling_steps,
+            "cfg_strength": ss_guidance_strength,
+        },
+        slat_sampler_params={
+            "steps": slat_sampling_steps,
+            "cfg_strength": slat_guidance_strength,
+        },
     )
-    state = pack_state(out['gaussian'][0], out['mesh'][0])
-    vid_c = render_utils.render_video(out['gaussian'][0],num_frames=120)['color']
-    vid_n = render_utils.render_video(out['mesh'][0],num_frames=120)['normal']
-    vid = [np.concatenate([c.astype(np.uint8), n.astype(np.uint8)], axis=1) for c,n in zip(vid_c,vid_n)]
-    ud = os.path.join(TMP_DIR,str(req.session_hash)); os.makedirs(ud,exist_ok=True)
-    vp = os.path.join(ud,'sample.mp4'); imageio.mimsave(vp,vid,fps=15,quality=8)
-    if torch.cuda.is_available(): torch.cuda.empty_cache()
-    return state, vp
-
-@spaces.GPU(duration=120)
-def extract_glb(state_dict: dict, mesh_simplify: float, texture_size: int, req: gr.Request):
-    gs, mesh = unpack_state(state_dict)
-    ud = os.path.join(TMP_DIR, str(req.session_hash)); os.makedirs(ud, exist_ok=True)
-    glb = postprocessing_utils.to_glb(gs,mesh,simplify=mesh_simplify,texture_size=texture_size,verbose=True)
-    gp = os.path.join(ud,'sample.glb'); glb.export(gp)
-    if torch.cuda.is_available(): torch.cuda.empty_cache()
-    return gp, gp
+    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+    video_path = os.path.join(user_dir, 'sample.mp4')
+    imageio.mimsave(video_path, video, fps=15)
+    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+    torch.cuda.empty_cache()
+    return state, video_path
+
+
+@spaces.GPU(duration=90)
+def extract_glb(
+    state: dict,
+    mesh_simplify: float,
+    texture_size: int,
+    req: gr.Request,
+) -> Tuple[str, str]:
+    """
+    Extract a GLB file from the 3D model.
+    Args:
+        state (dict): The state of the generated 3D model.
+        mesh_simplify (float): The mesh simplification factor.
+        texture_size (int): The texture resolution.
+    Returns:
+        str: The path to the extracted GLB file.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+    gs, mesh = unpack_state(state)
+    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
+    glb_path = os.path.join(user_dir, 'sample.glb')
+    glb.export(glb_path)
+    torch.cuda.empty_cache()
+    return glb_path, glb_path
+
 
 @spaces.GPU
-def extract_gaussian(state_dict: dict, req: gr.Request):
-    gs, _ = unpack_state(state_dict)
-    ud = os.path.join(TMP_DIR, str(req.session_hash)); os.makedirs(ud, exist_ok=True)
-    pp = os.path.join(ud,'sample.ply'); gs.save_ply(pp)
-    if torch.cuda.is_available(): torch.cuda.empty_cache()
-    return pp, pp
-
-# --- Gradio UI ---
-with gr.Blocks(delete_cache=(600,600), title="TRELLIS Text-to-3D") as demo:
+def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
+    """
+    Extract a Gaussian file from the 3D model.
+    Args:
+        state (dict): The state of the generated 3D model.
+    Returns:
+        str: The path to the extracted Gaussian file.
+    """
+    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+    gs, _ = unpack_state(state)
+    gaussian_path = os.path.join(user_dir, 'sample.ply')
+    gs.save_ply(gaussian_path)
+    torch.cuda.empty_cache()
+    return gaussian_path, gaussian_path
+
+
+output_buf = gr.State()
+video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+
+with gr.Blocks(delete_cache=(600, 600)) as demo:
     gr.Markdown("""
-    # Text to 3D Asset with TRELLIS
+    ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
+    * Type a text prompt and click "Generate" to create a 3D asset.
+    * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
     """)
-    output_buf = gr.State()
+
     with gr.Row():
-        with gr.Column(scale=1):
+        with gr.Column():
             text_prompt = gr.Textbox(label="Text Prompt", lines=5)
-            with gr.Accordion("Generation Settings", open=False):
+
+            with gr.Accordion(label="Generation Settings", open=False):
                 seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                gr.Markdown("--- Stage 1 ---")
-                ss_guidance_strength = gr.Slider(0.0,15.0,label="Guidance Strength",value=7.5,step=0.1)
-                ss_sampling_steps = gr.Slider(10,50,label="Steps",value=25,step=1)
-                gr.Markdown("--- Stage 2 ---")
-                slat_guidance_strength = gr.Slider(0.0,15.0,label="Guidance Strength",value=7.5,step=0.1)
-                slat_sampling_steps = gr.Slider(10,50,label="Steps",value=25,step=1)
-            generate_btn = gr.Button("Generate 3D Preview")
-            with gr.Accordion("GLB Extraction Settings", open=True):
-                mesh_simplify = gr.Slider(0.9,0.99,label="Simplify",value=0.95,step=0.01)
-                texture_size = gr.Slider(512,2048,label="Texture Size",value=1024,step=512)
-            extract_glb_btn = gr.Button("Extract GLB", interactive=False)
-            extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
-            download_glb = gr.DownloadButton("Download GLB", interactive=False)
-            download_gs = gr.DownloadButton("Download Gaussian", interactive=False)
-        with gr.Column(scale=1):
-            video_output = gr.Video(autoplay=True,loop=True)
-            model_output = gr.Model3D()
-
-    # --- Handlers ---
+                gr.Markdown("Stage 1: Sparse Structure Generation")
+                with gr.Row():
+                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
+                gr.Markdown("Stage 2: Structured Latent Generation")
+                with gr.Row():
+                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
+
+            generate_btn = gr.Button("Generate")
+
+            with gr.Accordion(label="GLB Extraction Settings", open=False):
+                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+
+            with gr.Row():
+                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
+            gr.Markdown("""
+                        *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
+                        """)
+
+        with gr.Column():
+            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+            model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
+
+            with gr.Row():
+                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
+
+    output_buf = gr.State()
+
+    # Handlers
     demo.load(start_session)
     demo.unload(end_session)
 
-    generate_event = generate_btn.click(
+    generate_btn.click(
         get_seed,
-        inputs=[randomize_seed,seed], outputs=[seed]
+        inputs=[randomize_seed, seed],
+        outputs=[seed],
     ).then(
         text_to_3d,
-        inputs=[text_prompt,seed,ss_guidance_strength,ss_sampling_steps,slat_guidance_strength,slat_sampling_steps],
-        outputs=[output_buf,video_output]
-    ).then(lambda: (extract_glb_btn.update(interactive=True),extract_gs_btn.update(interactive=True)), outputs=[extract_glb_btn,extract_gs_btn])
+        inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        outputs=[output_buf, video_output],
+    ).then(
+        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    video_output.clear(
+        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
 
     extract_glb_btn.click(
         extract_glb,
-        inputs=[output_buf,mesh_simplify,texture_size],
-        outputs=[model_output,download_glb]
-    ).then(lambda: download_glb.update(interactive=True), outputs=[download_glb])
-
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_glb],
+    )
+
     extract_gs_btn.click(
         extract_gaussian,
-        inputs=[output_buf], outputs=[model_output,download_gs]
-    ).then(lambda: download_gs.update(interactive=True), outputs=[download_gs])
+        inputs=[output_buf],
+        outputs=[model_output, download_gs],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_gs],
+    )
 
-    model_output.clear(lambda: (download_glb.update(interactive=False),download_gs.update(interactive=False)), outputs=[download_glb,download_gs])
-    video_output.clear(lambda: (extract_glb_btn.update(interactive=False),extract_gs_btn.update(interactive=False),download_glb.update(interactive=False),download_gs.update(interactive=False)), outputs=[extract_glb_btn,extract_gs_btn,download_glb,download_gs])
+    model_output.clear(
+        lambda: gr.Button(interactive=False),
+        outputs=[download_glb],
+    )
+
 
+# Launch the Gradio app
 if __name__ == "__main__":
-    demo.queue().launch(debug=True)
+    pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
+    pipeline.cuda()
+    demo.launch()