Vibu46vk commited on
Commit
0dd3455
·
verified ·
1 Parent(s): eec8d8f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +403 -0
app.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_litmodel3d import LitModel3D
3
+
4
+ import os
5
+ import shutil
6
+ from typing import *
7
+ import torch
8
+ import numpy as np
9
+ import imageio
10
+ from easydict import EasyDict as edict
11
+ from PIL import Image
12
+ from trellis.pipelines import TrellisImageTo3DPipeline
13
+ from trellis.representations import Gaussian, MeshExtractResult
14
+ from trellis.utils import render_utils, postprocessing_utils
15
+
16
+
17
+ MAX_SEED = np.iinfo(np.int32).max
18
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
19
+ os.makedirs(TMP_DIR, exist_ok=True)
20
+
21
+
22
+ def start_session(req: gr.Request):
23
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
24
+ os.makedirs(user_dir, exist_ok=True)
25
+
26
+
27
+ def end_session(req: gr.Request):
28
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
29
+ shutil.rmtree(user_dir)
30
+
31
+
32
+ def preprocess_image(image: Image.Image) -> Image.Image:
33
+ """
34
+ Preprocess the input image.
35
+
36
+ Args:
37
+ image (Image.Image): The input image.
38
+
39
+ Returns:
40
+ Image.Image: The preprocessed image.
41
+ """
42
+ processed_image = pipeline.preprocess_image(image)
43
+ return processed_image
44
+
45
+
46
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
47
+ """
48
+ Preprocess a list of input images.
49
+
50
+ Args:
51
+ images (List[Tuple[Image.Image, str]]): The input images.
52
+
53
+ Returns:
54
+ List[Image.Image]: The preprocessed images.
55
+ """
56
+ images = [image[0] for image in images]
57
+ processed_images = [pipeline.preprocess_image(image) for image in images]
58
+ return processed_images
59
+
60
+
61
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
62
+ return {
63
+ 'gaussian': {
64
+ **gs.init_params,
65
+ '_xyz': gs._xyz.cpu().numpy(),
66
+ '_features_dc': gs._features_dc.cpu().numpy(),
67
+ '_scaling': gs._scaling.cpu().numpy(),
68
+ '_rotation': gs._rotation.cpu().numpy(),
69
+ '_opacity': gs._opacity.cpu().numpy(),
70
+ },
71
+ 'mesh': {
72
+ 'vertices': mesh.vertices.cpu().numpy(),
73
+ 'faces': mesh.faces.cpu().numpy(),
74
+ },
75
+ }
76
+
77
+
78
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
79
+ gs = Gaussian(
80
+ aabb=state['gaussian']['aabb'],
81
+ sh_degree=state['gaussian']['sh_degree'],
82
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
83
+ scaling_bias=state['gaussian']['scaling_bias'],
84
+ opacity_bias=state['gaussian']['opacity_bias'],
85
+ scaling_activation=state['gaussian']['scaling_activation'],
86
+ )
87
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
88
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
89
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
90
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
91
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
92
+
93
+ mesh = edict(
94
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
95
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
96
+ )
97
+
98
+ return gs, mesh
99
+
100
+
101
+ def get_seed(randomize_seed: bool, seed: int) -> int:
102
+ """
103
+ Get the random seed.
104
+ """
105
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
106
+
107
+
108
+ def image_to_3d(
109
+ image: Image.Image,
110
+ multiimages: List[Tuple[Image.Image, str]],
111
+ is_multiimage: bool,
112
+ seed: int,
113
+ ss_guidance_strength: float,
114
+ ss_sampling_steps: int,
115
+ slat_guidance_strength: float,
116
+ slat_sampling_steps: int,
117
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
118
+ req: gr.Request,
119
+ ) -> Tuple[dict, str]:
120
+ """
121
+ Convert an image to a 3D model.
122
+
123
+ Args:
124
+ image (Image.Image): The input image.
125
+ multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
126
+ is_multiimage (bool): Whether is in multi-image mode.
127
+ seed (int): The random seed.
128
+ ss_guidance_strength (float): The guidance strength for sparse structure generation.
129
+ ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
130
+ slat_guidance_strength (float): The guidance strength for structured latent generation.
131
+ slat_sampling_steps (int): The number of sampling steps for structured latent generation.
132
+ multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
133
+
134
+ Returns:
135
+ dict: The information of the generated 3D model.
136
+ str: The path to the video of the 3D model.
137
+ """
138
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
139
+ if not is_multiimage:
140
+ outputs = pipeline.run(
141
+ image,
142
+ seed=seed,
143
+ formats=["gaussian", "mesh"],
144
+ preprocess_image=False,
145
+ sparse_structure_sampler_params={
146
+ "steps": ss_sampling_steps,
147
+ "cfg_strength": ss_guidance_strength,
148
+ },
149
+ slat_sampler_params={
150
+ "steps": slat_sampling_steps,
151
+ "cfg_strength": slat_guidance_strength,
152
+ },
153
+ )
154
+ else:
155
+ outputs = pipeline.run_multi_image(
156
+ [image[0] for image in multiimages],
157
+ seed=seed,
158
+ formats=["gaussian", "mesh"],
159
+ preprocess_image=False,
160
+ sparse_structure_sampler_params={
161
+ "steps": ss_sampling_steps,
162
+ "cfg_strength": ss_guidance_strength,
163
+ },
164
+ slat_sampler_params={
165
+ "steps": slat_sampling_steps,
166
+ "cfg_strength": slat_guidance_strength,
167
+ },
168
+ mode=multiimage_algo,
169
+ )
170
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
171
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
172
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
173
+ video_path = os.path.join(user_dir, 'sample.mp4')
174
+ imageio.mimsave(video_path, video, fps=15)
175
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
176
+ torch.cuda.empty_cache()
177
+ return state, video_path
178
+
179
+
180
+ def extract_glb(
181
+ state: dict,
182
+ mesh_simplify: float,
183
+ texture_size: int,
184
+ req: gr.Request,
185
+ ) -> Tuple[str, str]:
186
+ """
187
+ Extract a GLB file from the 3D model.
188
+
189
+ Args:
190
+ state (dict): The state of the generated 3D model.
191
+ mesh_simplify (float): The mesh simplification factor.
192
+ texture_size (int): The texture resolution.
193
+
194
+ Returns:
195
+ str: The path to the extracted GLB file.
196
+ """
197
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
198
+ gs, mesh = unpack_state(state)
199
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
200
+ glb_path = os.path.join(user_dir, 'sample.glb')
201
+ glb.export(glb_path)
202
+ torch.cuda.empty_cache()
203
+ return glb_path, glb_path
204
+
205
+
206
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
207
+ """
208
+ Extract a Gaussian file from the 3D model.
209
+
210
+ Args:
211
+ state (dict): The state of the generated 3D model.
212
+
213
+ Returns:
214
+ str: The path to the extracted Gaussian file.
215
+ """
216
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
217
+ gs, _ = unpack_state(state)
218
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
219
+ gs.save_ply(gaussian_path)
220
+ torch.cuda.empty_cache()
221
+ return gaussian_path, gaussian_path
222
+
223
+
224
+ def prepare_multi_example() -> List[Image.Image]:
225
+ multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
226
+ images = []
227
+ for case in multi_case:
228
+ _images = []
229
+ for i in range(1, 4):
230
+ img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
231
+ W, H = img.size
232
+ img = img.resize((int(W / H * 512), 512))
233
+ _images.append(np.array(img))
234
+ images.append(Image.fromarray(np.concatenate(_images, axis=1)))
235
+ return images
236
+
237
+
238
+ def split_image(image: Image.Image) -> List[Image.Image]:
239
+ """
240
+ Split an image into multiple views.
241
+ """
242
+ image = np.array(image)
243
+ alpha = image[..., 3]
244
+ alpha = np.any(alpha>0, axis=0)
245
+ start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
246
+ end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
247
+ images = []
248
+ for s, e in zip(start_pos, end_pos):
249
+ images.append(Image.fromarray(image[:, s:e+1]))
250
+ return [preprocess_image(image) for image in images]
251
+
252
+
253
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
254
+ gr.Markdown("""
255
+ ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
256
+ * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
257
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
258
+ """)
259
+
260
+ with gr.Row():
261
+ with gr.Column():
262
+ with gr.Tabs() as input_tabs:
263
+ with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
264
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
265
+ with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
266
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
267
+ gr.Markdown("""
268
+ Input different views of the object in separate images.
269
+
270
+ *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
271
+ """)
272
+
273
+ with gr.Accordion(label="Generation Settings", open=False):
274
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
275
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
276
+ gr.Markdown("Stage 1: Sparse Structure Generation")
277
+ with gr.Row():
278
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
279
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
280
+ gr.Markdown("Stage 2: Structured Latent Generation")
281
+ with gr.Row():
282
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
283
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
284
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
285
+
286
+ generate_btn = gr.Button("Generate")
287
+
288
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
289
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
290
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
291
+
292
+ with gr.Row():
293
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
294
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
295
+ gr.Markdown("""
296
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
297
+ """)
298
+
299
+ with gr.Column():
300
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
301
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
302
+
303
+ with gr.Row():
304
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
305
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
306
+
307
+ is_multiimage = gr.State(False)
308
+ output_buf = gr.State()
309
+
310
+ # Example images at the bottom of the page
311
+ with gr.Row() as single_image_example:
312
+ examples = gr.Examples(
313
+ examples=[
314
+ f'assets/example_image/{image}'
315
+ for image in os.listdir("assets/example_image")
316
+ ],
317
+ inputs=[image_prompt],
318
+ fn=preprocess_image,
319
+ outputs=[image_prompt],
320
+ run_on_click=True,
321
+ examples_per_page=64,
322
+ )
323
+ with gr.Row(visible=False) as multiimage_example:
324
+ examples_multi = gr.Examples(
325
+ examples=prepare_multi_example(),
326
+ inputs=[image_prompt],
327
+ fn=split_image,
328
+ outputs=[multiimage_prompt],
329
+ run_on_click=True,
330
+ examples_per_page=8,
331
+ )
332
+
333
+ # Handlers
334
+ demo.load(start_session)
335
+ demo.unload(end_session)
336
+
337
+ single_image_input_tab.select(
338
+ lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
339
+ outputs=[is_multiimage, single_image_example, multiimage_example]
340
+ )
341
+ multiimage_input_tab.select(
342
+ lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
343
+ outputs=[is_multiimage, single_image_example, multiimage_example]
344
+ )
345
+
346
+ image_prompt.upload(
347
+ preprocess_image,
348
+ inputs=[image_prompt],
349
+ outputs=[image_prompt],
350
+ )
351
+ multiimage_prompt.upload(
352
+ preprocess_images,
353
+ inputs=[multiimage_prompt],
354
+ outputs=[multiimage_prompt],
355
+ )
356
+
357
+ generate_btn.click(
358
+ get_seed,
359
+ inputs=[randomize_seed, seed],
360
+ outputs=[seed],
361
+ ).then(
362
+ image_to_3d,
363
+ inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
364
+ outputs=[output_buf, video_output],
365
+ ).then(
366
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
367
+ outputs=[extract_glb_btn, extract_gs_btn],
368
+ )
369
+
370
+ video_output.clear(
371
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
372
+ outputs=[extract_glb_btn, extract_gs_btn],
373
+ )
374
+
375
+ extract_glb_btn.click(
376
+ extract_glb,
377
+ inputs=[output_buf, mesh_simplify, texture_size],
378
+ outputs=[model_output, download_glb],
379
+ ).then(
380
+ lambda: gr.Button(interactive=True),
381
+ outputs=[download_glb],
382
+ )
383
+
384
+ extract_gs_btn.click(
385
+ extract_gaussian,
386
+ inputs=[output_buf],
387
+ outputs=[model_output, download_gs],
388
+ ).then(
389
+ lambda: gr.Button(interactive=True),
390
+ outputs=[download_gs],
391
+ )
392
+
393
+ model_output.clear(
394
+ lambda: gr.Button(interactive=False),
395
+ outputs=[download_glb],
396
+ )
397
+
398
+
399
+ # Launch the Gradio app
400
+ if __name__ == "__main__":
401
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
402
+ pipeline.cuda()
403
+ demo.launch()