zhuhai111 commited on
Commit
75b2ef1
·
verified ·
1 Parent(s): 53ef571

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +245 -200
  2. requirements.txt +30 -31
app.py CHANGED
@@ -1,200 +1,245 @@
1
- import gradio as gr
2
- import os
3
- import shutil
4
- import tempfile
5
- import datetime
6
- import numpy as np
7
- import torch
8
- import imageio
9
- import trimesh
10
-
11
- from PIL import Image
12
- from typing import *
13
- from gradio_litmodel3d import LitModel3D
14
- from trellis.pipelines import TrellisImageTo3DPipeline
15
- from trellis.utils import render_utils
16
-
17
- os.environ['SPCONV_ALGO'] = 'native'
18
-
19
- MAX_SEED = np.iinfo(np.int32).max
20
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
- os.makedirs(TMP_DIR, exist_ok=True)
22
-
23
- def preprocess_mesh(mesh_prompt):
24
- print("Processing mesh")
25
- trimesh_mesh = trimesh.load_mesh(mesh_prompt)
26
- trimesh_mesh.export(mesh_prompt+'.glb')
27
- return mesh_prompt+'.glb'
28
-
29
- def preprocess_image(image):
30
- if image is None:
31
- return None
32
- image = pipeline.preprocess_image(image, resolution=1024)
33
- return image
34
-
35
- def generate_3d(image, seed=-1,
36
- ss_guidance_strength=3, ss_sampling_steps=50,
37
- slat_guidance_strength=3, slat_sampling_steps=6,):
38
- if image is None:
39
- return None, None, None, None
40
-
41
- if seed == -1:
42
- seed = np.random.randint(0, MAX_SEED)
43
-
44
- image = pipeline.preprocess_image(image, resolution=1024)
45
- normal_image = normal_predictor(image, resolution=768, match_input_resolution=True, data_type='object')
46
-
47
- outputs = pipeline.run(
48
- normal_image,
49
- seed=seed,
50
- formats=["mesh",],
51
- preprocess_image=False,
52
- sparse_structure_sampler_params={
53
- "steps": ss_sampling_steps,
54
- "cfg_strength": ss_guidance_strength,
55
- },
56
- slat_sampler_params={
57
- "steps": slat_sampling_steps,
58
- "cfg_strength": slat_guidance_strength,
59
- },
60
- )
61
- generated_mesh = outputs['mesh'][0]
62
-
63
- output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
64
- os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
65
- mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb"
66
-
67
- render_results = render_utils.render_video(generated_mesh, resolution=1024, ssaa=1, num_frames=8, pitch=0.25, inverse_direction=True)
68
-
69
- def combine_diagonal(color_np, normal_np):
70
- h, w, c = color_np.shape
71
- mask = np.fromfunction(lambda y, x: x > y, (h, w)).astype(bool)
72
- mask = np.stack([mask] * c, axis=-1)
73
- combined_np = np.where(mask, color_np, normal_np)
74
- return Image.fromarray(combined_np)
75
-
76
- preview_images = [combine_diagonal(c, n) for c, n in zip(render_results['color'], render_results['normal'])]
77
-
78
- trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
79
- trimesh_mesh.export(mesh_path)
80
-
81
- return preview_images, normal_image, mesh_path, mesh_path
82
-
83
- def convert_mesh(mesh_path, export_format):
84
- if not mesh_path:
85
- return None
86
- temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
87
- mesh = trimesh.load_mesh(mesh_path)
88
- mesh.export(temp_file.name)
89
- return temp_file.name
90
-
91
- with gr.Blocks(css="footer {visibility: hidden}") as demo:
92
- gr.Markdown("""
93
- <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
94
- <p style='text-align: center;'>
95
- <strong>V0.1, Introduced By
96
- <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and
97
- <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong>
98
- </p>
99
- """)
100
-
101
- with gr.Row():
102
- with gr.Column(scale=1):
103
- with gr.Tabs():
104
- with gr.Tab("Single Image"):
105
- with gr.Row():
106
- image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
107
- normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
108
- with gr.Tab("Multiple Images"):
109
- gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>")
110
- with gr.Accordion("Advanced Settings", open=False):
111
- seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
112
- gr.Markdown("#### Stage 1: Sparse Structure Generation")
113
- with gr.Row():
114
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3, step=0.1)
115
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=50, step=1)
116
- gr.Markdown("#### Stage 2: Structured Latent Generation")
117
- with gr.Row():
118
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
119
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)
120
- with gr.Group():
121
- with gr.Row():
122
- gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
123
-
124
- with gr.Column(scale=1):
125
- with gr.Tabs():
126
- with gr.Tab("Preview"):
127
- output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto", show_label=False)
128
- with gr.Tab("3D Model"):
129
- with gr.Column():
130
- model_output = gr.Model3D(label="3D Model Preview (Each model is approx. 40MB)")
131
- with gr.Column():
132
- export_format = gr.Dropdown(
133
- choices=["obj", "glb", "ply", "stl"],
134
- value="glb",
135
- label="File Format"
136
- )
137
- download_btn = gr.DownloadButton(label="Export Mesh", interactive=False)
138
-
139
- image_prompt.upload(
140
- preprocess_image,
141
- inputs=[image_prompt],
142
- outputs=[image_prompt]
143
- )
144
-
145
- gen_shape_btn.click(
146
- generate_3d,
147
- inputs=[
148
- image_prompt, seed,
149
- ss_guidance_strength, ss_sampling_steps,
150
- slat_guidance_strength, slat_sampling_steps
151
- ],
152
- outputs=[output_gallery, normal_output, model_output, download_btn]
153
- ).then(
154
- lambda: gr.Button(interactive=True),
155
- outputs=[download_btn],
156
- )
157
-
158
- def update_download_button(mesh_path, export_format):
159
- if not mesh_path:
160
- return gr.File.update(value=None, interactive=False)
161
- download_path = convert_mesh(mesh_path, export_format)
162
- return download_path
163
-
164
- export_format.change(
165
- update_download_button,
166
- inputs=[model_output, export_format],
167
- outputs=[download_btn]
168
- ).then(
169
- lambda: gr.Button(interactive=True),
170
- outputs=[download_btn],
171
- )
172
-
173
- examples = gr.Examples(
174
- examples=[
175
- f'assets/example_image/{image}'
176
- for image in os.listdir("assets/example_image")
177
- ],
178
- inputs=image_prompt,
179
- )
180
-
181
- gr.Markdown("""
182
- **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We acknowledge contributions from:
183
- - [Trellis 3D](https://github.com/microsoft/TRELLIS)
184
- - [StableNormal](https://github.com/hugoycj/StableNormal)
185
- """)
186
-
187
- if __name__ == "__main__":
188
- # ✅ 强制使用 CPU
189
- pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
190
- pipeline.to("cpu") # <-- 强制使用 CPU
191
-
192
- normal_predictor = torch.hub.load(
193
- "hugoycj/StableNormal",
194
- "StableNormal_turbo",
195
- trust_repo=True,
196
- yoso_version="yoso-normal-v1-8-1"
197
- )
198
- normal_predictor.to("cpu") # <-- 也强制使用 CPU
199
-
200
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
+
5
+ import os
6
+ import shutil
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from PIL import Image
13
+ from trellis.pipelines import TrellisImageTo3DPipeline
14
+ from trellis.utils import render_utils
15
+ import trimesh
16
+ import tempfile
17
+
18
+ MAX_SEED = np.iinfo(np.int32).max
19
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
20
+ os.makedirs(TMP_DIR, exist_ok=True)
21
+
22
+ def preprocess_mesh(mesh_prompt):
23
+ print("Processing mesh")
24
+ trimesh_mesh = trimesh.load_mesh(mesh_prompt)
25
+ trimesh_mesh.export(mesh_prompt+'.glb')
26
+ return mesh_prompt+'.glb'
27
+
28
+ def preprocess_image(image):
29
+ if image is None:
30
+ return None
31
+ image = pipeline.preprocess_image(image, resolution=1024)
32
+ return image
33
+
34
+ # Removed @spaces.GPU decorator to allow CPU execution
35
+ def generate_3d(image, seed=-1,
36
+ ss_guidance_strength=3, ss_sampling_steps=50,
37
+ slat_guidance_strength=3, slat_sampling_steps=6,):
38
+ if image is None:
39
+ return None, None, None
40
+
41
+ if seed == -1:
42
+ seed = np.random.randint(0, MAX_SEED)
43
+
44
+ image = pipeline.preprocess_image(image, resolution=1024)
45
+ normal_image = normal_predictor(image, resolution=768, match_input_resolution=True, data_type='object')
46
+
47
+ outputs = pipeline.run(
48
+ normal_image,
49
+ seed=seed,
50
+ formats=["mesh",],
51
+ preprocess_image=False,
52
+ sparse_structure_sampler_params={
53
+ "steps": ss_sampling_steps,
54
+ "cfg_strength": ss_guidance_strength,
55
+ },
56
+ slat_sampler_params={
57
+ "steps": slat_sampling_steps,
58
+ "cfg_strength": slat_guidance_strength,
59
+ },
60
+ )
61
+ generated_mesh = outputs['mesh'][0]
62
+
63
+ # Save outputs
64
+ import datetime
65
+ output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
66
+ os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
67
+ mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb"
68
+
69
+ render_results = render_utils.render_video(generated_mesh, resolution=1024, ssaa=1, num_frames=8, pitch=0.25, inverse_direction=True)
70
+ def combine_diagonal(color_np, normal_np):
71
+ # Convert images to numpy arrays
72
+ h, w, c = color_np.shape
73
+ # Create a boolean mask that is True for pixels where x > y (diagonally)
74
+ mask = np.fromfunction(lambda y, x: x > y, (h, w))
75
+ mask = mask.astype(bool)
76
+ mask = np.stack([mask] * c, axis=-1)
77
+ # Where mask is True take color, else normal
78
+ combined_np = np.where(mask, color_np, normal_np)
79
+ return Image.fromarray(combined_np)
80
+
81
+ preview_images = [combine_diagonal(c, n) for c, n in zip(render_results['color'], render_results['normal'])]
82
+
83
+ # Export mesh
84
+ trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
85
+
86
+ trimesh_mesh.export(mesh_path)
87
+
88
+ return preview_images, normal_image, mesh_path, mesh_path
89
+
90
+ def convert_mesh(mesh_path, export_format):
91
+ """Download the mesh in the selected format."""
92
+ if not mesh_path:
93
+ return None
94
+
95
+ # Create a temporary file to store the mesh data
96
+ temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
97
+ temp_file_path = temp_file.name
98
+
99
+ new_mesh_path = mesh_path.replace(".glb", f".{export_format}")
100
+ mesh = trimesh.load_mesh(mesh_path)
101
+ mesh.export(temp_file_path) # Export to the temporary file
102
+
103
+ return temp_file_path # Return the path to the temporary file
104
+
105
+ # Create the Gradio interface with improved layout
106
+ with gr.Blocks(css="footer {visibility: hidden}") as demo:
107
+ gr.Markdown(
108
+ """
109
+ <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
110
+ <p style='text-align: center;'>
111
+ <strong>V0.1, Introduced By
112
+ <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and
113
+ <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong>
114
+ </p>
115
+ """
116
+ )
117
+
118
+ with gr.Row():
119
+ gr.Markdown("""
120
+ <p align="center">
121
+ <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
122
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
123
+ </a>
124
+ <a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
125
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
126
+ </a>
127
+ <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
128
+ <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
129
+ </a>
130
+ <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
131
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
132
+ </a>
133
+ </p>
134
+ """)
135
+
136
+ with gr.Row():
137
+ with gr.Column(scale=1):
138
+ with gr.Tabs():
139
+
140
+ with gr.Tab("Single Image"):
141
+ with gr.Row():
142
+ image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
143
+ normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
144
+
145
+ with gr.Tab("Multiple Images"):
146
+ gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>")
147
+
148
+ with gr.Accordion("Advanced Settings", open=False):
149
+ seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
150
+ gr.Markdown("#### Stage 1: Sparse Structure Generation")
151
+ with gr.Row():
152
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3, step=0.1)
153
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=50, step=1)
154
+ gr.Markdown("#### Stage 2: Structured Latent Generation")
155
+ with gr.Row():
156
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
157
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)
158
+
159
+ with gr.Group():
160
+ with gr.Row():
161
+ gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
162
+
163
+ # Right column - Output
164
+ with gr.Column(scale=1):
165
+ with gr.Tabs():
166
+ with gr.Tab("Preview"):
167
+ output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto",show_label=False)
168
+ with gr.Tab("3D Model"):
169
+ with gr.Column():
170
+ model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)")
171
+ with gr.Column():
172
+ export_format = gr.Dropdown(
173
+ choices=["obj", "glb", "ply", "stl"],
174
+ value="glb",
175
+ label="File Format"
176
+ )
177
+ download_btn = gr.DownloadButton(label="Export Mesh", interactive=False)
178
+
179
+ image_prompt.upload(
180
+ preprocess_image,
181
+ inputs=[image_prompt],
182
+ outputs=[image_prompt]
183
+ )
184
+
185
+ gen_shape_btn.click(
186
+ generate_3d,
187
+ inputs=[
188
+ image_prompt, seed,
189
+ ss_guidance_strength, ss_sampling_steps,
190
+ slat_guidance_strength, slat_sampling_steps
191
+ ],
192
+ outputs=[output_gallery, normal_output, model_output, download_btn]
193
+ ).then(
194
+ lambda: gr.Button(interactive=True),
195
+ outputs=[download_btn],
196
+ )
197
+
198
+
199
+ def update_download_button(mesh_path, export_format):
200
+ if not mesh_path:
201
+ return gr.File.update(value=None, interactive=False)
202
+
203
+ download_path = convert_mesh(mesh_path, export_format)
204
+ return download_path
205
+
206
+ export_format.change(
207
+ update_download_button,
208
+ inputs=[model_output, export_format],
209
+ outputs=[download_btn]
210
+ ).then(
211
+ lambda: gr.Button(interactive=True),
212
+ outputs=[download_btn],
213
+ )
214
+
215
+ examples = gr.Examples(
216
+ examples=[
217
+ f'assets/example_image/{image}'
218
+ for image in os.listdir("assets/example_image")
219
+ ],
220
+ inputs=image_prompt,
221
+ )
222
+
223
+ gr.Markdown(
224
+ """
225
+ **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
226
+ - **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS) and we draw inspiration from the teams behind [Rodin](https://hyperhuman.deemos.com/rodin), [Tripo](https://www.tripo3d.ai/app/home), and [Dora](https://github.com/Seed3D/Dora).
227
+ - **Normal Estimation:** Our Normal Estimation Model builds on the leading normal estimation research such as [StableNormal](https://github.com/hugoycj/StableNormal) and [GenPercept](https://github.com/aim-uofa/GenPercept).
228
+
229
+ **Your contributions and collaboration push the boundaries of 3D modeling!**
230
+ """
231
+ )
232
+
233
+ if __name__ == "__main__":
234
+ # Initialize pipeline
235
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
236
+ # Use CPU instead of GPU
237
+ pipeline.to("cpu")
238
+
239
+ # Initialize normal predictor
240
+ normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal_turbo", trust_repo=True, yoso_version='yoso-normal-v1-8-1')
241
+ # Ensure normal predictor is on CPU
242
+ normal_predictor.to("cpu")
243
+
244
+ # Launch the app
245
+ demo.launch()
requirements.txt CHANGED
@@ -1,31 +1,30 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu121
2
- huggingface_hub==0.25.0
3
- diffusers==0.28.0
4
- accelerate==1.2.1
5
- kornia==0.8.0
6
- timm==0.6.7
7
-
8
- torch==2.4.0
9
- torchvision==0.19.0
10
- pillow==10.4.0
11
- imageio==2.36.1
12
- imageio-ffmpeg==0.5.1
13
- tqdm==4.67.1
14
- easydict==1.13
15
- opencv-python-headless==4.10.0.84
16
- scipy==1.14.1
17
- rembg==2.0.60
18
- onnxruntime==1.20.1
19
- trimesh==4.5.3
20
- xatlas==0.0.9
21
- pyvista==0.44.2
22
- pymeshfix==0.17.0
23
- igraph==0.11.8
24
- git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
25
- xformers==0.0.27.post2
26
- spconv-cu120==2.3.6
27
- transformers==4.46.3
28
- gradio_litmodel3d==0.0.1
29
- https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
30
- https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
31
- https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
 
1
+ huggingface_hub==0.25.0
2
+ diffusers==0.28.0
3
+ accelerate==1.2.1
4
+ kornia==0.8.0
5
+ timm==0.6.7
6
+
7
+ # CPU versions of PyTorch packages
8
+ torch==2.4.0+cpu
9
+ torchvision==0.19.0+cpu
10
+ --extra-index-url https://download.pytorch.org/whl/cpu
11
+
12
+ pillow==10.4.0
13
+ imageio==2.36.1
14
+ imageio-ffmpeg==0.5.1
15
+ tqdm==4.67.1
16
+ easydict==1.13
17
+ opencv-python-headless==4.10.0.84
18
+ scipy==1.14.1
19
+ rembg==2.0.60
20
+ onnxruntime==1.20.1
21
+ trimesh==4.5.3
22
+ xatlas==0.0.9
23
+ pyvista==0.44.2
24
+ pymeshfix==0.17.0
25
+ igraph==0.11.8
26
+ git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
27
+
28
+ # Remove GPU-specific packages
29
+ transformers==4.46.3
30
+ gradio_litmodel3d==0.0.1