cavargas10 committed on
Commit
512a50c
verified
1 Parent(s): e5dd267

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -274
app.py CHANGED
@@ -1,274 +1,227 @@
1
- import gradio as gr
2
- import spaces
3
-
4
- import os
5
- import shutil
6
- os.environ['TOKENIZERS_PARALLELISM'] = 'true'
7
- os.environ['SPCONV_ALGO'] = 'native'
8
- from typing import *
9
- import torch
10
- import numpy as np
11
- import imageio
12
- from easydict import EasyDict as edict
13
- from trellis.pipelines import TrellisTextTo3DPipeline
14
- from trellis.representations import Gaussian, MeshExtractResult
15
- from trellis.utils import render_utils, postprocessing_utils
16
-
17
- import traceback
18
- import sys
19
-
20
-
21
- MAX_SEED = np.iinfo(np.int32).max
22
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
- os.makedirs(TMP_DIR, exist_ok=True)
24
-
25
-
26
- def start_session(req: gr.Request):
27
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
28
- os.makedirs(user_dir, exist_ok=True)
29
-
30
-
31
- def end_session(req: gr.Request):
32
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
- shutil.rmtree(user_dir)
34
-
35
-
36
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
37
- return {
38
- 'gaussian': {
39
- **gs.init_params,
40
- '_xyz': gs._xyz.cpu().numpy(),
41
- '_features_dc': gs._features_dc.cpu().numpy(),
42
- '_scaling': gs._scaling.cpu().numpy(),
43
- '_rotation': gs._rotation.cpu().numpy(),
44
- '_opacity': gs._opacity.cpu().numpy(),
45
- },
46
- 'mesh': {
47
- 'vertices': mesh.vertices.cpu().numpy(),
48
- 'faces': mesh.faces.cpu().numpy(),
49
- },
50
- }
51
-
52
-
53
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
54
- gs = Gaussian(
55
- aabb=state['gaussian']['aabb'],
56
- sh_degree=state['gaussian']['sh_degree'],
57
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
58
- scaling_bias=state['gaussian']['scaling_bias'],
59
- opacity_bias=state['gaussian']['opacity_bias'],
60
- scaling_activation=state['gaussian']['scaling_activation'],
61
- )
62
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
63
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
64
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
65
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
66
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
67
-
68
- mesh = edict(
69
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
70
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
71
- )
72
-
73
- return gs, mesh
74
-
75
-
76
- def get_seed(randomize_seed: bool, seed: int) -> int:
77
- """
78
- Get the random seed.
79
- """
80
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
81
-
82
-
83
- @spaces.GPU
84
- def text_to_3d(
85
- prompt: str,
86
- seed: int,
87
- ss_guidance_strength: float,
88
- ss_sampling_steps: int,
89
- slat_guidance_strength: float,
90
- slat_sampling_steps: int,
91
- req: gr.Request,
92
- ) -> Tuple[dict, str]:
93
- """
94
- Convert an text prompt to a 3D model.
95
-
96
- Args:
97
- prompt (str): The text prompt.
98
- seed (int): The random seed.
99
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
100
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
101
- slat_guidance_strength (float): The guidance strength for structured latent generation.
102
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
103
-
104
- Returns:
105
- dict: The information of the generated 3D model.
106
- str: The path to the video of the 3D model.
107
- """
108
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
109
- outputs = pipeline.run(
110
- prompt,
111
- seed=seed,
112
- formats=["gaussian", "mesh"],
113
- sparse_structure_sampler_params={
114
- "steps": ss_sampling_steps,
115
- "cfg_strength": ss_guidance_strength,
116
- },
117
- slat_sampler_params={
118
- "steps": slat_sampling_steps,
119
- "cfg_strength": slat_guidance_strength,
120
- },
121
- )
122
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
123
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
124
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
125
- video_path = os.path.join(user_dir, 'sample.mp4')
126
- imageio.mimsave(video_path, video, fps=15)
127
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
128
- torch.cuda.empty_cache()
129
- return state, video_path
130
-
131
-
132
- @spaces.GPU(duration=90)
133
- def extract_glb(
134
- state: dict,
135
- mesh_simplify: float,
136
- texture_size: int,
137
- req: gr.Request,
138
- ) -> Tuple[str, str]:
139
- """
140
- Extract a GLB file from the 3D model.
141
-
142
- Args:
143
- state (dict): The state of the generated 3D model.
144
- mesh_simplify (float): The mesh simplification factor.
145
- texture_size (int): The texture resolution.
146
-
147
- Returns:
148
- str: The path to the extracted GLB file.
149
- """
150
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
151
- gs, mesh = unpack_state(state)
152
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
153
- glb_path = os.path.join(user_dir, 'sample.glb')
154
- glb.export(glb_path)
155
- torch.cuda.empty_cache()
156
- return glb_path, glb_path
157
-
158
-
159
- @spaces.GPU
160
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
161
- """
162
- Extract a Gaussian file from the 3D model.
163
-
164
- Args:
165
- state (dict): The state of the generated 3D model.
166
-
167
- Returns:
168
- str: The path to the extracted Gaussian file.
169
- """
170
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
171
- gs, _ = unpack_state(state)
172
- gaussian_path = os.path.join(user_dir, 'sample.ply')
173
- gs.save_ply(gaussian_path)
174
- torch.cuda.empty_cache()
175
- return gaussian_path, gaussian_path
176
-
177
-
178
- with gr.Blocks(delete_cache=(600, 600)) as demo:
179
- gr.Markdown("""
180
- ## Text to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
181
- * Type a text prompt and click "Generate" to create a 3D asset.
182
- * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
183
- """)
184
-
185
- with gr.Row():
186
- with gr.Column():
187
- text_prompt = gr.Textbox(label="Text Prompt", lines=5)
188
-
189
- with gr.Accordion(label="Generation Settings", open=False):
190
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
191
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
192
- gr.Markdown("Stage 1: Sparse Structure Generation")
193
- with gr.Row():
194
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
195
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
196
- gr.Markdown("Stage 2: Structured Latent Generation")
197
- with gr.Row():
198
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
199
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
200
-
201
- generate_btn = gr.Button("Generate")
202
-
203
- with gr.Accordion(label="GLB Extraction Settings", open=False):
204
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
205
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
206
-
207
- with gr.Row():
208
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
209
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
210
- gr.Markdown("""
211
- *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
212
- """)
213
-
214
- with gr.Column():
215
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
216
- model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
217
-
218
- with gr.Row():
219
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
220
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
221
-
222
- output_buf = gr.State()
223
-
224
- # Handlers
225
- demo.load(start_session)
226
- demo.unload(end_session)
227
-
228
- generate_btn.click(
229
- get_seed,
230
- inputs=[randomize_seed, seed],
231
- outputs=[seed],
232
- ).then(
233
- text_to_3d,
234
- inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
235
- outputs=[output_buf, video_output],
236
- ).then(
237
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
238
- outputs=[extract_glb_btn, extract_gs_btn],
239
- )
240
-
241
- video_output.clear(
242
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
243
- outputs=[extract_glb_btn, extract_gs_btn],
244
- )
245
-
246
- extract_glb_btn.click(
247
- extract_glb,
248
- inputs=[output_buf, mesh_simplify, texture_size],
249
- outputs=[model_output, download_glb],
250
- ).then(
251
- lambda: gr.Button(interactive=True),
252
- outputs=[download_glb],
253
- )
254
-
255
- extract_gs_btn.click(
256
- extract_gaussian,
257
- inputs=[output_buf],
258
- outputs=[model_output, download_gs],
259
- ).then(
260
- lambda: gr.Button(interactive=True),
261
- outputs=[download_gs],
262
- )
263
-
264
- model_output.clear(
265
- lambda: gr.Button(interactive=False),
266
- outputs=[download_glb],
267
- )
268
-
269
-
270
- # Launch the Gradio app
271
- if __name__ == "__main__":
272
- pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
273
- pipeline.cuda()
274
- demo.launch()
 
1
+ import gradio as gr
2
+ import spaces
3
+
4
+ import os
5
+ import shutil
6
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from easydict import EasyDict as edict
13
+ from trellis.pipelines import TrellisTextTo3DPipeline
14
+ from trellis.representations import Gaussian, MeshExtractResult
15
+ from trellis.utils import render_utils, postprocessing_utils
16
+
17
+ import traceback
18
+ import sys
19
+
20
+ MAX_SEED = np.iinfo(np.int32).max
21
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
+ os.makedirs(TMP_DIR, exist_ok=True)
23
+
24
+ def start_session(req: gr.Request):
25
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
+ os.makedirs(user_dir, exist_ok=True)
27
+
28
+ def end_session(req: gr.Request):
29
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
30
+ shutil.rmtree(user_dir)
31
+
32
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
33
+ return {
34
+ 'gaussian': {
35
+ **gs.init_params,
36
+ '_xyz': gs._xyz.cpu().numpy(),
37
+ '_features_dc': gs._features_dc.cpu().numpy(),
38
+ '_scaling': gs._scaling.cpu().numpy(),
39
+ '_rotation': gs._rotation.cpu().numpy(),
40
+ '_opacity': gs._opacity.cpu().numpy(),
41
+ },
42
+ 'mesh': {
43
+ 'vertices': mesh.vertices.cpu().numpy(),
44
+ 'faces': mesh.faces.cpu().numpy(),
45
+ },
46
+ }
47
+
48
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
49
+ gs = Gaussian(
50
+ aabb=state['gaussian']['aabb'],
51
+ sh_degree=state['gaussian']['sh_degree'],
52
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
53
+ scaling_bias=state['gaussian']['scaling_bias'],
54
+ opacity_bias=state['gaussian']['opacity_bias'],
55
+ scaling_activation=state['gaussian']['scaling_activation'],
56
+ )
57
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
58
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
59
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
60
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
61
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
62
+
63
+ mesh = edict(
64
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
65
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
66
+ )
67
+
68
+ return gs, mesh
69
+
70
+ def get_seed(randomize_seed: bool, seed: int) -> int:
71
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
72
+
73
+ @spaces.GPU
74
+ def text_to_3d(
75
+ prompt: str,
76
+ seed: int,
77
+ ss_guidance_strength: float,
78
+ ss_sampling_steps: int,
79
+ slat_guidance_strength: float,
80
+ slat_sampling_steps: int,
81
+ req: gr.Request,
82
+ ) -> Tuple[dict, str]:
83
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
84
+ outputs = pipeline.run(
85
+ prompt,
86
+ seed=seed,
87
+ formats=["gaussian", "mesh"],
88
+ sparse_structure_sampler_params={
89
+ "steps": ss_sampling_steps,
90
+ "cfg_strength": ss_guidance_strength,
91
+ },
92
+ slat_sampler_params={
93
+ "steps": slat_sampling_steps,
94
+ "cfg_strength": slat_guidance_strength,
95
+ },
96
+ )
97
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
98
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
99
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
100
+ video_path = os.path.join(user_dir, 'sample.mp4')
101
+ imageio.mimsave(video_path, video, fps=15)
102
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
103
+ torch.cuda.empty_cache()
104
+ return state, video_path
105
+
106
+ @spaces.GPU(duration=90)
107
+ def extract_glb(
108
+ state: dict,
109
+ mesh_simplify: float,
110
+ texture_size: int,
111
+ req: gr.Request,
112
+ ) -> Tuple[str, str]:
113
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
114
+ gs, mesh = unpack_state(state)
115
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
116
+ glb_path = os.path.join(user_dir, 'sample.glb')
117
+ glb.export(glb_path)
118
+ torch.cuda.empty_cache()
119
+ return glb_path, glb_path
120
+
121
+ @spaces.GPU
122
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
123
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
124
+ gs, _ = unpack_state(state)
125
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
126
+ gs.save_ply(gaussian_path)
127
+ torch.cuda.empty_cache()
128
+ return gaussian_path, gaussian_path
129
+
130
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
131
+ gr.Markdown("""
132
+ # UTPL - Conversión de Texto a Imagen a objetos 3D usando IA
133
+ ### Tesis: *"Objetos tridimensionales creados por IA: Innovación en entornos virtuales"*
134
+ **Autor:** Carlos Vargas
135
+ **Base técnica:** Adaptación de [TRELLIS](https://trellis3d.github.io/) y [FLUX](https://huggingface.co/camenduru/FLUX.1-dev-diffusers) (herramientas de código abierto para generación 3D)
136
+ **Propósito educativo:** Demostraciones académicas e Investigación en modelado 3D automático
137
+ """)
138
+
139
+ with gr.Row():
140
+ with gr.Column():
141
+ text_prompt = gr.Textbox(label="Text Prompt", lines=5)
142
+
143
+ with gr.Accordion(label="Generation Settings", open=False):
144
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
145
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
146
+ gr.Markdown("Stage 1: Sparse Structure Generation")
147
+ with gr.Row():
148
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
149
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
150
+ gr.Markdown("Stage 2: Structured Latent Generation")
151
+ with gr.Row():
152
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
153
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=25, step=1)
154
+
155
+ generate_btn = gr.Button("Generate")
156
+
157
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
158
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
159
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
160
+
161
+ with gr.Row():
162
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
163
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
164
+ gr.Markdown("""
165
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
166
+ """)
167
+
168
+ with gr.Column():
169
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
170
+ model_output = gr.Model3D(label="Extracted GLB/Gaussian", height=300)
171
+
172
+ with gr.Row():
173
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
174
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
175
+
176
+ output_buf = gr.State()
177
+
178
+ # Handlers
179
+ demo.load(start_session)
180
+ demo.unload(end_session)
181
+
182
+ generate_btn.click(
183
+ get_seed,
184
+ inputs=[randomize_seed, seed],
185
+ outputs=[seed],
186
+ ).then(
187
+ text_to_3d,
188
+ inputs=[text_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
189
+ outputs=[output_buf, video_output],
190
+ ).then(
191
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
192
+ outputs=[extract_glb_btn, extract_gs_btn],
193
+ )
194
+
195
+ video_output.clear(
196
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
197
+ outputs=[extract_glb_btn, extract_gs_btn],
198
+ )
199
+
200
+ extract_glb_btn.click(
201
+ extract_glb,
202
+ inputs=[output_buf, mesh_simplify, texture_size],
203
+ outputs=[model_output, download_glb],
204
+ ).then(
205
+ lambda: gr.Button(interactive=True),
206
+ outputs=[download_glb],
207
+ )
208
+
209
+ extract_gs_btn.click(
210
+ extract_gaussian,
211
+ inputs=[output_buf],
212
+ outputs=[model_output, download_gs],
213
+ ).then(
214
+ lambda: gr.Button(interactive=True),
215
+ outputs=[download_gs],
216
+ )
217
+
218
+ model_output.clear(
219
+ lambda: gr.Button(interactive=False),
220
+ outputs=[download_glb],
221
+ )
222
+
223
+ # Launch the Gradio app
224
+ if __name__ == "__main__":
225
+ pipeline = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
226
+ pipeline.cuda()
227
+ demo.launch()