Aluren committed
Commit eebae35 · verified · 1 Parent(s): 367a6d1

Upload 35 files

Files changed (36)
  1. .gitattributes +11 -0
  2. app.py +246 -0
  3. assets/image/100.png +3 -0
  4. assets/image/34933195-9c2c-4271-8d31-a28bc5348b7a.png +0 -0
  5. assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png +3 -0
  6. assets/image/579584fb-8d1c-4312-a3f0-f7a81bd16493.png +0 -0
  7. assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png +3 -0
  8. assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png +3 -0
  9. assets/image/e799e6b4-3b47-40e0-befb-b156af8758ad.png +0 -0
  10. assets/model/100.glb +3 -0
  11. assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb +3 -0
  12. assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb +3 -0
  13. assets/model/579584fb-8d1c-4312-a3f0-f7a81bd16493.glb +3 -0
  14. assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb +3 -0
  15. assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb +3 -0
  16. assets/model/e799e6b4-3b47-40e0-befb-b156af8758ad.glb +3 -0
  17. detailgen3d/__init__.py +0 -0
  18. detailgen3d/inference_utils.py +17 -0
  19. detailgen3d/models/attention_processor.py +576 -0
  20. detailgen3d/models/autoencoders/__init__.py +1 -0
  21. detailgen3d/models/autoencoders/autoencoder_kl_triposg.py +536 -0
  22. detailgen3d/models/autoencoders/vae.py +69 -0
  23. detailgen3d/models/embeddings.py +96 -0
  24. detailgen3d/models/transformers/__init__.py +61 -0
  25. detailgen3d/models/transformers/detailgen3d_transformers.py +771 -0
  26. detailgen3d/models/transformers/modeling_outputs.py +8 -0
  27. detailgen3d/models/transformers/triposg_transformer.py +726 -0
  28. detailgen3d/pipelines/__init__.py +1 -0
  29. detailgen3d/pipelines/pipeline_detailgen3d.py +322 -0
  30. detailgen3d/pipelines/pipeline_detailgen3d_output.py +13 -0
  31. detailgen3d/pipelines/pipeline_utils.py +96 -0
  32. detailgen3d/schedulers/__init__.py +5 -0
  33. detailgen3d/schedulers/scheduling_rectified_flow.py +327 -0
  34. detailgen3d/utils/__init__.py +2 -0
  35. detailgen3d/utils/typing.py +64 -0
  36. scripts/inference_detailgen3d.py +70 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/image/100.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/model/100.glb filter=lfs diff=lfs merge=lfs -text
41
+ assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb filter=lfs diff=lfs merge=lfs -text
42
+ assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb filter=lfs diff=lfs merge=lfs -text
43
+ assets/model/579584fb-8d1c-4312-a3f0-f7a81bd16493.glb filter=lfs diff=lfs merge=lfs -text
44
+ assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb filter=lfs diff=lfs merge=lfs -text
45
+ assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb filter=lfs diff=lfs merge=lfs -text
46
+ assets/model/e799e6b4-3b47-40e0-befb-b156af8758ad.glb filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,246 @@
1
+ import os
2
+ import random
3
+ import tempfile
4
+ from typing import Any, List, Union
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ from gradio_image_prompter import ImagePrompter
10
+ from gradio_litmodel3d import LitModel3D
11
+ from huggingface_hub import snapshot_download
12
+ from PIL import Image
13
+ import trimesh
14
+ from skimage import measure
15
+
16
+ from detailgen3d.pipelines.pipeline_detailgen3d import DetailGen3DPipeline
17
+ from detailgen3d.inference_utils import generate_dense_grid_points
18
+
19
+ # Constants
20
+ MAX_SEED = np.iinfo(np.int32).max
21
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
22
+ DTYPE = torch.bfloat16
23
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
+ REPO_ID = ""  # apparently not available yet
25
+
26
+ MARKDOWN = """
27
+ ## Generating geometry details guided by a reference image with [DetailGen3D](https://detailgen3d.github.io/DetailGen3D/)
28
+ 1. Upload a detailed image of the frontal view and a coarse model, then click "Run Refinement" to generate the refined result.
29
+ 2. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
30
+ 3. If you want the refined result to be more consistent with the image, manually increase the CFG scale.
31
+ """
32
+ EXAMPLES = [
33
+ [
34
+ {
35
+ "image": "assets/image/100.png",
36
+ },
37
+ "assets/model/100.glb",
38
+ 42,
39
+ False,
40
+ ],
41
+ [
42
+ {
43
+ "image": "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png",
44
+ },
45
+ "assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb",
46
+ 42,
47
+ False,
48
+ ],
49
+ [
50
+ {
51
+ "image": "assets/image/34933195-9c2c-4271-8d31-a28bc5348b7a.png",
52
+ },
53
+ "assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb",
54
+ 42,
55
+ False,
56
+ ],
57
+ [
58
+ {
59
+ "image": "assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png",
60
+ },
61
+ "assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb",
62
+ 42,
63
+ False,
64
+ ],
65
+ [
66
+ {
67
+ "image": "assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png",
68
+ },
69
+ "assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb",
70
+ 42,
71
+ False,
72
+ ],
73
+ [
74
+ {
75
+ "image": "assets/image/e799e6b4-3b47-40e0-befb-b156af8758ad.png",
76
+ },
77
+ "assets/model/instant3d/e799e6b4-3b47-40e0-befb-b156af8758ad.glb",
78
+ 42,
79
+ False,
80
+ ],
81
+ ]
82
+
83
+ os.makedirs(TMP_DIR, exist_ok=True)
84
+
85
+ device = "cuda"
86
+ dtype = torch.float16
87
+
88
+ pipeline = DetailGen3DPipeline.from_pretrained(
89
+ "VAST-AI/DetailGen3D",
90
+ low_cpu_mem_usage=False
91
+ ).to(device, dtype=dtype)
92
+
93
+
94
+ def load_mesh(mesh_path, num_pc=20480):
95
+ mesh = trimesh.load(mesh_path,force="mesh")
96
+
97
+ center = mesh.bounding_box.centroid
98
+ mesh.apply_translation(-center)
99
+ scale = max(mesh.bounding_box.extents)
100
+ mesh.apply_scale(1.9 / scale)
101
+
102
+ surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000,)
103
+ normal = mesh.face_normals[face_indices]
104
+
105
+ rng = np.random.default_rng()
106
+ ind = rng.choice(surface.shape[0], num_pc, replace=False)
107
+ surface = torch.FloatTensor(surface[ind])
108
+ normal = torch.FloatTensor(normal[ind])
109
+ surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
110
+
111
+ return surface
112
+
113
+ @torch.no_grad()
114
+ @torch.autocast(device_type=DEVICE)
115
+ def run_detailgen3d(
116
+ pipeline,
117
+ image,
118
+ mesh,
119
+ seed,
120
+ num_inference_steps,
121
+ guidance_scale,
122
+ ):
123
+ surface = load_mesh(mesh)
124
+
125
+ batch_size = 1
126
+
127
+ # sample query points for decoding
128
+ box_min = np.array([-1.005, -1.005, -1.005])
129
+ box_max = np.array([1.005, 1.005, 1.005])
130
+ sampled_points, grid_size, bbox_size = generate_dense_grid_points(
131
+ bbox_min=box_min, bbox_max=box_max, octree_depth=8, indexing="ij"
132
+ )
133
+ sampled_points = torch.FloatTensor(sampled_points).to(device, dtype=dtype)
134
+ sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)
135
+
136
+ # inference pipeline
137
+ sample = pipeline.vae.encode(surface).latent_dist.sample()
138
+ occ = pipeline(image, latents=sample, sampled_points=sampled_points, guidance_scale=guidance_scale, noise_aug_level=0, num_inference_steps=num_inference_steps).samples[0]
139
+
140
+ # marching cubes
141
+ grid_logits = occ.view(grid_size).cpu().numpy()
142
+ vertices, faces, normals, _ = measure.marching_cubes(
143
+ grid_logits, 0, method="lewiner"
144
+ )
145
+ vertices = vertices / grid_size * bbox_size + box_min
146
+ mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
147
+ return mesh
148
+
149
+ @torch.no_grad()
150
+ @torch.autocast(device_type=DEVICE)
151
+ def run_refinement(
152
+ rgb_image: Any,
153
+ mesh: Any,
154
+ seed: int,
155
+ randomize_seed: bool = False,
156
+ num_inference_steps: int = 50,
157
+ guidance_scale: float = 4.0,
158
+ ):
159
+ if randomize_seed:
160
+ seed = random.randint(0, MAX_SEED)
161
+
162
+ # print("rgb_image", rgb_image)
163
+ # print("mesh", rgb_image)
164
+
165
+ if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
166
+ rgb_image = rgb_image["image"]
167
+
168
+ scene = run_detailgen3d(
169
+ pipeline,
170
+ rgb_image,
171
+ mesh,
172
+ seed,
173
+ num_inference_steps,
174
+ guidance_scale,
175
+ )
176
+
177
+ _, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
178
+ scene.export(tmp_path)
179
+
180
+ torch.cuda.empty_cache()
181
+
182
+ return tmp_path, tmp_path, seed
183
+
184
+ # Demo
185
+ with gr.Blocks() as demo:
186
+ gr.Markdown(MARKDOWN)
187
+
188
+ with gr.Row():
189
+ with gr.Column():
190
+ with gr.Row():
191
+ image_prompts = ImagePrompter(label="Input Image", type="pil")
192
+ mesh = gr.Model3D(label="Input Coarse Model",camera_position=(90,90,3))
193
+
194
+ with gr.Accordion("Generation Settings", open=False):
195
+ seed = gr.Slider(
196
+ label="Seed",
197
+ minimum=0,
198
+ maximum=MAX_SEED,
199
+ step=1,
200
+ value=0,
201
+ )
202
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
203
+ num_inference_steps = gr.Slider(
204
+ label="Number of inference steps",
205
+ minimum=1,
206
+ maximum=50,
207
+ step=1,
208
+ value=50,
209
+ )
210
+ guidance_scale = gr.Slider(
211
+ label="CFG scale",
212
+ minimum=0.0,
213
+ maximum=50.0,
214
+ step=0.1,
215
+ value=4.0,
216
+ )
217
+ gen_button = gr.Button("Run Refinement", variant="primary")
218
+
219
+ with gr.Column():
220
+ model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500,camera_position=(90,90,3))
221
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
222
+
223
+ with gr.Row():
224
+ gr.Examples(
225
+ examples=EXAMPLES,
226
+ fn=run_refinement,
227
+ inputs=[image_prompts, mesh, seed, randomize_seed],
228
+ outputs=[model_output, download_glb, seed],
229
+ cache_examples=False,
230
+ )
231
+
232
+ gen_button.click(
233
+ run_refinement,
234
+ inputs=[
235
+ image_prompts,
236
+ mesh,
237
+ seed,
238
+ randomize_seed,
239
+ num_inference_steps,
240
+ guidance_scale,
241
+ ],
242
+ outputs=[model_output, download_glb, seed],
243
+ ).then(lambda: gr.Button(interactive=True), outputs=[download_glb])
244
+
245
+
246
+ demo.launch()
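
Note: the Gradio app above chains four steps: normalize and sample the coarse mesh, encode it to latents, run the image-guided refinement, then marching-cubes the decoded occupancy. The following is a condensed sketch of that same flow as a plain script, kept deliberately close to the functions above; the paths, seed, and output name are illustrative, and the reference command-line entry point in this upload is scripts/inference_detailgen3d.py rather than this snippet.

import numpy as np
import torch
import trimesh
from PIL import Image
from skimage import measure

from detailgen3d.inference_utils import generate_dense_grid_points
from detailgen3d.pipelines.pipeline_detailgen3d import DetailGen3DPipeline

device, dtype = "cuda", torch.float16
pipe = DetailGen3DPipeline.from_pretrained(
    "VAST-AI/DetailGen3D", low_cpu_mem_usage=False
).to(device, dtype=dtype)

# Normalize the coarse mesh into [-0.95, 0.95] and sample an oriented point cloud (as in load_mesh above).
mesh = trimesh.load("assets/model/100.glb", force="mesh")
mesh.apply_translation(-mesh.bounding_box.centroid)
mesh.apply_scale(1.9 / max(mesh.bounding_box.extents))
points, face_idx = trimesh.sample.sample_surface(mesh, 1000000)
normals = mesh.face_normals[face_idx]
sel = np.random.default_rng(42).choice(points.shape[0], 20480, replace=False)
surface = torch.FloatTensor(np.concatenate([points[sel], normals[sel]], axis=-1)).unsqueeze(0).cuda()

# Dense query grid for occupancy decoding.
box_min = np.array([-1.005, -1.005, -1.005])
box_max = np.array([1.005, 1.005, 1.005])
queries, grid_size, bbox_size = generate_dense_grid_points(
    bbox_min=box_min, bbox_max=box_max, octree_depth=8, indexing="ij"
)
queries = torch.FloatTensor(queries).unsqueeze(0).to(device, dtype=dtype)

image = Image.open("assets/image/100.png")

with torch.no_grad(), torch.autocast(device_type="cuda"):
    latents = pipe.vae.encode(surface).latent_dist.sample()
    occ = pipe(
        image, latents=latents, sampled_points=queries,
        guidance_scale=4.0, noise_aug_level=0, num_inference_steps=50,
    ).samples[0]

# Extract the refined surface with marching cubes and export it.
verts, faces, _, _ = measure.marching_cubes(occ.view(grid_size).float().cpu().numpy(), 0, method="lewiner")
verts = verts / grid_size * bbox_size + box_min
trimesh.Trimesh(verts.astype(np.float32), np.ascontiguousarray(faces)).export("refined.glb")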
assets/image/100.png ADDED

Git LFS Details

  • SHA256: 59167d1793ff737650efa3495d45cb0230c298712dbc7a495757581eb8899f21
  • Pointer size: 132 Bytes
  • Size of remote file: 1.61 MB
assets/image/34933195-9c2c-4271-8d31-a28bc5348b7a.png ADDED
assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png ADDED

Git LFS Details

  • SHA256: 166a5c8ffbcd3775b2f2640dac1eb2b0c902fe7ebeeff04d33bf4abe8d67f080
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
assets/image/579584fb-8d1c-4312-a3f0-f7a81bd16493.png ADDED
assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png ADDED

Git LFS Details

  • SHA256: 4c8203fb4820393918ec6dee91e0bb7981e24a241a10613ed22aa53e505bf5de
  • Pointer size: 131 Bytes
  • Size of remote file: 152 kB
assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png ADDED

Git LFS Details

  • SHA256: 4a66b5dc5bf5d0ff8ec4b720e26fee8995c64807f147523f6f4d0c1b3d83097e
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
assets/image/e799e6b4-3b47-40e0-befb-b156af8758ad.png ADDED
assets/model/100.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:353a255228aa9a95a0607a9da07decc4f9fa72378b58773540029b62c56b0680
3
+ size 650964
assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f81d1f33696ad3bae8a81df4754d1a582ac819a32d5837cfd27ad3f9419f830e
3
+ size 969996
assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1dfd2100d0997a3ee423ee40c0a7ce40c04f99a0bb7147962444c5ab5ae8550
3
+ size 961580
assets/model/579584fb-8d1c-4312-a3f0-f7a81bd16493.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30c6f40b2fb6be7e887ba554e81da39ee7ee690238712eb8fb7dde6691c131ba
3
+ size 1886840
assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da2b13e6939c0cef2d5c092a1814905abc948814552d8d3ff71163a2cc9e25d5
3
+ size 958896
assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ff09758624653fec32bda11b47c04a44c1c79f327220a953f0ba4633f7ac871
3
+ size 951492
assets/model/e799e6b4-3b47-40e0-befb-b156af8758ad.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d2843167e957e71b9a144e104dea8e41c9eba50de912aa8fd48b27e642d8983
3
+ size 944340
detailgen3d/__init__.py ADDED
File without changes
detailgen3d/inference_utils.py ADDED
@@ -0,0 +1,17 @@
1
+ import numpy as np
2
+
3
+
4
+ def generate_dense_grid_points(
5
+ bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij"
6
+ ):
7
+ length = bbox_max - bbox_min
8
+ num_cells = np.exp2(octree_depth)
9
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
10
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
11
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
12
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
13
+ xyz = np.stack((xs, ys, zs), axis=-1)
14
+ xyz = xyz.reshape(-1, 3)
15
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
16
+
17
+ return xyz, grid_size, length
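
A small usage sketch of generate_dense_grid_points with the bounding box and octree depth that app.py uses. The shapes follow directly from the implementation above (2**8 + 1 = 257 samples per axis); the print comments show the expected values, not captured output.

import numpy as np
from detailgen3d.inference_utils import generate_dense_grid_points

xyz, grid_size, bbox_size = generate_dense_grid_points(
    bbox_min=np.array([-1.005, -1.005, -1.005]),
    bbox_max=np.array([1.005, 1.005, 1.005]),
    octree_depth=8,
    indexing="ij",
)
print(xyz.shape)   # (16974593, 3), i.e. 257**3 query points
print(grid_size)   # [257, 257, 257]
print(bbox_size)   # [2.01 2.01 2.01], bbox_max - bbox_min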
detailgen3d/models/attention_processor.py ADDED
@@ -0,0 +1,576 @@
1
+ from typing import Callable, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.utils import logging
7
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
8
+ from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
9
+ from einops import rearrange
10
+ from torch import nn
11
+
12
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
13
+
14
+ class FlashTripo2AttnProcessor2_0:
15
+ r"""
16
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
17
+ used in the Tripo2DiT model. It applies a normalization layer and a rotary embedding to the query and key vectors.
18
+ """
19
+
20
+ def __init__(self, topk=True):
21
+ if not hasattr(F, "scaled_dot_product_attention"):
22
+ raise ImportError(
23
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
24
+ )
25
+ self.topk = topk
26
+
27
+ def qkv(self, attn, q, k, v, attn_mask, dropout_p, is_causal):
28
+ if k.shape[-2] == 3072:
29
+ topk = 1024
30
+ elif k.shape[-2] == 512:
31
+ topk = 256
32
+ else:
33
+ topk = k.shape[-2] // 3
34
+
35
+ if self.topk is True:
36
+ q1 = q[:, :, ::100, :]
37
+ sim = q1 @ k.transpose(-1, -2)
38
+ sim = torch.mean(sim, -2)
39
+ topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
40
+ topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
41
+ v0 = torch.gather(v, dim=-2, index=topk_ind)
42
+ k0 = torch.gather(k, dim=-2, index=topk_ind)
43
+ out = F.scaled_dot_product_attention(q, k0, v0)
44
+ elif self.topk is False:
45
+ out = F.scaled_dot_product_attention(q, k, v)
46
+ else:
47
+ idx, counts = self.topk
48
+ start = 0
49
+ outs = []
50
+ for grid_coord, count in zip(idx, counts):
51
+ end = start + count
52
+ q_chunk = q[:, :, start:end, :]
53
+ q1 = q_chunk[:, :, ::50, :]
54
+ sim = q1 @ k.transpose(-1, -2)
55
+ sim = torch.mean(sim, -2)
56
+ topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
57
+ topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
58
+ v0 = torch.gather(v, dim=-2, index=topk_ind)
59
+ k0 = torch.gather(k, dim=-2, index=topk_ind)
60
+ out = F.scaled_dot_product_attention(q_chunk, k0, v0)
61
+ outs.append(out)
62
+ start += count
63
+ out = torch.cat(outs, dim=-2)
64
+ self.topk = False
65
+ return out
66
+
67
+ def __call__(
68
+ self,
69
+ attn: Attention,
70
+ hidden_states: torch.Tensor,
71
+ encoder_hidden_states: Optional[torch.Tensor] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ temb: Optional[torch.Tensor] = None,
74
+ image_rotary_emb: Optional[torch.Tensor] = None,
75
+ ) -> torch.Tensor:
76
+ from diffusers.models.embeddings import apply_rotary_emb
77
+
78
+ residual = hidden_states
79
+ if attn.spatial_norm is not None:
80
+ hidden_states = attn.spatial_norm(hidden_states, temb)
81
+
82
+ input_ndim = hidden_states.ndim
83
+
84
+ if input_ndim == 4:
85
+ batch_size, channel, height, width = hidden_states.shape
86
+ hidden_states = hidden_states.view(
87
+ batch_size, channel, height * width
88
+ ).transpose(1, 2)
89
+
90
+ batch_size, sequence_length, _ = (
91
+ hidden_states.shape
92
+ if encoder_hidden_states is None
93
+ else encoder_hidden_states.shape
94
+ )
95
+
96
+ if attention_mask is not None:
97
+ attention_mask = attn.prepare_attention_mask(
98
+ attention_mask, sequence_length, batch_size
99
+ )
100
+ # scaled_dot_product_attention expects attention_mask shape to be
101
+ # (batch, heads, source_length, target_length)
102
+ attention_mask = attention_mask.view(
103
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
104
+ )
105
+
106
+ if attn.group_norm is not None:
107
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
108
+ 1, 2
109
+ )
110
+
111
+ query = attn.to_q(hidden_states)
112
+
113
+ if encoder_hidden_states is None:
114
+ encoder_hidden_states = hidden_states
115
+ elif attn.norm_cross:
116
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
117
+ encoder_hidden_states
118
+ )
119
+
120
+ key = attn.to_k(encoder_hidden_states)
121
+ value = attn.to_v(encoder_hidden_states)
122
+
123
+ # NOTE that tripo2 split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
124
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
125
+ if not attn.is_cross_attention:
126
+ qkv = torch.cat((query, key, value), dim=-1)
127
+ split_size = qkv.shape[-1] // attn.heads // 3
128
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
129
+ query, key, value = torch.split(qkv, split_size, dim=-1)
130
+ else:
131
+ kv = torch.cat((key, value), dim=-1)
132
+ split_size = kv.shape[-1] // attn.heads // 2
133
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
134
+ key, value = torch.split(kv, split_size, dim=-1)
135
+
136
+ head_dim = key.shape[-1]
137
+
138
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
139
+
140
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
141
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
142
+
143
+ if attn.norm_q is not None:
144
+ query = attn.norm_q(query)
145
+ if attn.norm_k is not None:
146
+ key = attn.norm_k(key)
147
+
148
+ # Apply RoPE if needed
149
+ if image_rotary_emb is not None:
150
+ query = apply_rotary_emb(query, image_rotary_emb)
151
+ if not attn.is_cross_attention:
152
+ key = apply_rotary_emb(key, image_rotary_emb)
153
+
154
+ # flashvdm topk
155
+ hidden_states = self.qkv(attn, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
156
+
157
+ hidden_states = hidden_states.transpose(1, 2).reshape(
158
+ batch_size, -1, attn.heads * head_dim
159
+ )
160
+ hidden_states = hidden_states.to(query.dtype)
161
+
162
+ # linear proj
163
+ hidden_states = attn.to_out[0](hidden_states)
164
+ # dropout
165
+ hidden_states = attn.to_out[1](hidden_states)
166
+
167
+ if input_ndim == 4:
168
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
169
+ batch_size, channel, height, width
170
+ )
171
+
172
+ if attn.residual_connection:
173
+ hidden_states = hidden_states + residual
174
+
175
+ hidden_states = hidden_states / attn.rescale_output_factor
176
+
177
+ return hidden_states
178
+
179
+ class TripoSGAttnProcessor2_0:
180
+ r"""
181
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
182
+ used in the TripoSG model. It applies a normalization layer and a rotary embedding to the query and key vectors.
183
+ """
184
+
185
+ def __init__(self):
186
+ if not hasattr(F, "scaled_dot_product_attention"):
187
+ raise ImportError(
188
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
189
+ )
190
+
191
+ def __call__(
192
+ self,
193
+ attn: Attention,
194
+ hidden_states: torch.Tensor,
195
+ encoder_hidden_states: Optional[torch.Tensor] = None,
196
+ attention_mask: Optional[torch.Tensor] = None,
197
+ temb: Optional[torch.Tensor] = None,
198
+ image_rotary_emb: Optional[torch.Tensor] = None,
199
+ ) -> torch.Tensor:
200
+ from diffusers.models.embeddings import apply_rotary_emb
201
+
202
+ residual = hidden_states
203
+ if attn.spatial_norm is not None:
204
+ hidden_states = attn.spatial_norm(hidden_states, temb)
205
+
206
+ input_ndim = hidden_states.ndim
207
+
208
+ if input_ndim == 4:
209
+ batch_size, channel, height, width = hidden_states.shape
210
+ hidden_states = hidden_states.view(
211
+ batch_size, channel, height * width
212
+ ).transpose(1, 2)
213
+
214
+ batch_size, sequence_length, _ = (
215
+ hidden_states.shape
216
+ if encoder_hidden_states is None
217
+ else encoder_hidden_states.shape
218
+ )
219
+
220
+ if attention_mask is not None:
221
+ attention_mask = attn.prepare_attention_mask(
222
+ attention_mask, sequence_length, batch_size
223
+ )
224
+ # scaled_dot_product_attention expects attention_mask shape to be
225
+ # (batch, heads, source_length, target_length)
226
+ attention_mask = attention_mask.view(
227
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
228
+ )
229
+
230
+ if attn.group_norm is not None:
231
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
232
+ 1, 2
233
+ )
234
+
235
+ query = attn.to_q(hidden_states)
236
+
237
+ if encoder_hidden_states is None:
238
+ encoder_hidden_states = hidden_states
239
+ elif attn.norm_cross:
240
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
241
+ encoder_hidden_states
242
+ )
243
+
244
+ key = attn.to_k(encoder_hidden_states)
245
+ value = attn.to_v(encoder_hidden_states)
246
+
247
+ # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
248
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
249
+ if not attn.is_cross_attention:
250
+ qkv = torch.cat((query, key, value), dim=-1)
251
+ split_size = qkv.shape[-1] // attn.heads // 3
252
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
253
+ query, key, value = torch.split(qkv, split_size, dim=-1)
254
+ else:
255
+ kv = torch.cat((key, value), dim=-1)
256
+ split_size = kv.shape[-1] // attn.heads // 2
257
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
258
+ key, value = torch.split(kv, split_size, dim=-1)
259
+
260
+ head_dim = key.shape[-1]
261
+
262
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
263
+
264
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
265
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
266
+
267
+ if attn.norm_q is not None:
268
+ query = attn.norm_q(query)
269
+ if attn.norm_k is not None:
270
+ key = attn.norm_k(key)
271
+
272
+ # Apply RoPE if needed
273
+ if image_rotary_emb is not None:
274
+ query = apply_rotary_emb(query, image_rotary_emb)
275
+ if not attn.is_cross_attention:
276
+ key = apply_rotary_emb(key, image_rotary_emb)
277
+
278
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
279
+ # TODO: add support for attn.scale when we move to Torch 2.1
280
+ hidden_states = F.scaled_dot_product_attention(
281
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
282
+ )
283
+
284
+ hidden_states = hidden_states.transpose(1, 2).reshape(
285
+ batch_size, -1, attn.heads * head_dim
286
+ )
287
+ hidden_states = hidden_states.to(query.dtype)
288
+
289
+ # linear proj
290
+ hidden_states = attn.to_out[0](hidden_states)
291
+ # dropout
292
+ hidden_states = attn.to_out[1](hidden_states)
293
+
294
+ if input_ndim == 4:
295
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
296
+ batch_size, channel, height, width
297
+ )
298
+
299
+ if attn.residual_connection:
300
+ hidden_states = hidden_states + residual
301
+
302
+ hidden_states = hidden_states / attn.rescale_output_factor
303
+
304
+ return hidden_states
305
+
306
+
307
+ class FusedTripoSGAttnProcessor2_0:
308
+ r"""
309
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
310
+ projection layers. This is used in the TripoSG model. It applies a normalization layer and a rotary embedding to the
311
+ query and key vectors.
312
+ """
313
+
314
+ def __init__(self):
315
+ if not hasattr(F, "scaled_dot_product_attention"):
316
+ raise ImportError(
317
+ "FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
318
+ )
319
+
320
+ def __call__(
321
+ self,
322
+ attn: Attention,
323
+ hidden_states: torch.Tensor,
324
+ encoder_hidden_states: Optional[torch.Tensor] = None,
325
+ attention_mask: Optional[torch.Tensor] = None,
326
+ temb: Optional[torch.Tensor] = None,
327
+ image_rotary_emb: Optional[torch.Tensor] = None,
328
+ ) -> torch.Tensor:
329
+ from diffusers.models.embeddings import apply_rotary_emb
330
+
331
+ residual = hidden_states
332
+ if attn.spatial_norm is not None:
333
+ hidden_states = attn.spatial_norm(hidden_states, temb)
334
+
335
+ input_ndim = hidden_states.ndim
336
+
337
+ if input_ndim == 4:
338
+ batch_size, channel, height, width = hidden_states.shape
339
+ hidden_states = hidden_states.view(
340
+ batch_size, channel, height * width
341
+ ).transpose(1, 2)
342
+
343
+ batch_size, sequence_length, _ = (
344
+ hidden_states.shape
345
+ if encoder_hidden_states is None
346
+ else encoder_hidden_states.shape
347
+ )
348
+
349
+ if attention_mask is not None:
350
+ attention_mask = attn.prepare_attention_mask(
351
+ attention_mask, sequence_length, batch_size
352
+ )
353
+ # scaled_dot_product_attention expects attention_mask shape to be
354
+ # (batch, heads, source_length, target_length)
355
+ attention_mask = attention_mask.view(
356
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
357
+ )
358
+
359
+ if attn.group_norm is not None:
360
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
361
+ 1, 2
362
+ )
363
+
364
+ # NOTE that pre-trained split heads first, then split qkv
365
+ if encoder_hidden_states is None:
366
+ qkv = attn.to_qkv(hidden_states)
367
+ split_size = qkv.shape[-1] // attn.heads // 3
368
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
369
+ query, key, value = torch.split(qkv, split_size, dim=-1)
370
+ else:
371
+ if attn.norm_cross:
372
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
373
+ encoder_hidden_states
374
+ )
375
+ query = attn.to_q(hidden_states)
376
+
377
+ kv = attn.to_kv(encoder_hidden_states)
378
+ split_size = kv.shape[-1] // attn.heads // 2
379
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
380
+ key, value = torch.split(kv, split_size, dim=-1)
381
+
382
+ head_dim = key.shape[-1]
383
+
384
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
385
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
386
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
387
+
388
+ if attn.norm_q is not None:
389
+ query = attn.norm_q(query)
390
+ if attn.norm_k is not None:
391
+ key = attn.norm_k(key)
392
+
393
+ # Apply RoPE if needed
394
+ if image_rotary_emb is not None:
395
+ query = apply_rotary_emb(query, image_rotary_emb)
396
+ if not attn.is_cross_attention:
397
+ key = apply_rotary_emb(key, image_rotary_emb)
398
+
399
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
400
+ # TODO: add support for attn.scale when we move to Torch 2.1
401
+ hidden_states = F.scaled_dot_product_attention(
402
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
403
+ )
404
+
405
+ hidden_states = hidden_states.transpose(1, 2).reshape(
406
+ batch_size, -1, attn.heads * head_dim
407
+ )
408
+ hidden_states = hidden_states.to(query.dtype)
409
+
410
+ # linear proj
411
+ hidden_states = attn.to_out[0](hidden_states)
412
+ # dropout
413
+ hidden_states = attn.to_out[1](hidden_states)
414
+
415
+ if input_ndim == 4:
416
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
417
+ batch_size, channel, height, width
418
+ )
419
+
420
+ if attn.residual_connection:
421
+ hidden_states = hidden_states + residual
422
+
423
+ hidden_states = hidden_states / attn.rescale_output_factor
424
+
425
+ return hidden_states
426
+
427
+
428
+ class MIAttnProcessor2_0:
429
+ r"""
430
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
431
+ used in the TripoSG model. It applies a normalization layer and a rotary embedding to the query and key vectors.
432
+ """
433
+
434
+ def __init__(self, use_mi: bool = True):
435
+ if not hasattr(F, "scaled_dot_product_attention"):
436
+ raise ImportError(
437
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
438
+ )
439
+
440
+ self.use_mi = use_mi
441
+
442
+ def __call__(
443
+ self,
444
+ attn: Attention,
445
+ hidden_states: torch.Tensor,
446
+ encoder_hidden_states: Optional[torch.Tensor] = None,
447
+ attention_mask: Optional[torch.Tensor] = None,
448
+ temb: Optional[torch.Tensor] = None,
449
+ image_rotary_emb: Optional[torch.Tensor] = None,
450
+ num_instances: Optional[torch.IntTensor] = None,
451
+ ) -> torch.Tensor:
452
+ from diffusers.models.embeddings import apply_rotary_emb
453
+
454
+ residual = hidden_states
455
+ if attn.spatial_norm is not None:
456
+ hidden_states = attn.spatial_norm(hidden_states, temb)
457
+
458
+ input_ndim = hidden_states.ndim
459
+
460
+ if input_ndim == 4:
461
+ batch_size, channel, height, width = hidden_states.shape
462
+ hidden_states = hidden_states.view(
463
+ batch_size, channel, height * width
464
+ ).transpose(1, 2)
465
+
466
+ batch_size, sequence_length, _ = (
467
+ hidden_states.shape
468
+ if encoder_hidden_states is None
469
+ else encoder_hidden_states.shape
470
+ )
471
+
472
+ if attention_mask is not None:
473
+ attention_mask = attn.prepare_attention_mask(
474
+ attention_mask, sequence_length, batch_size
475
+ )
476
+ # scaled_dot_product_attention expects attention_mask shape to be
477
+ # (batch, heads, source_length, target_length)
478
+ attention_mask = attention_mask.view(
479
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
480
+ )
481
+
482
+ if attn.group_norm is not None:
483
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
484
+ 1, 2
485
+ )
486
+
487
+ query = attn.to_q(hidden_states)
488
+
489
+ if encoder_hidden_states is None:
490
+ encoder_hidden_states = hidden_states
491
+ elif attn.norm_cross:
492
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
493
+ encoder_hidden_states
494
+ )
495
+
496
+ key = attn.to_k(encoder_hidden_states)
497
+ value = attn.to_v(encoder_hidden_states)
498
+
499
+ # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
500
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
501
+ if not attn.is_cross_attention:
502
+ qkv = torch.cat((query, key, value), dim=-1)
503
+ split_size = qkv.shape[-1] // attn.heads // 3
504
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
505
+ query, key, value = torch.split(qkv, split_size, dim=-1)
506
+ else:
507
+ kv = torch.cat((key, value), dim=-1)
508
+ split_size = kv.shape[-1] // attn.heads // 2
509
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
510
+ key, value = torch.split(kv, split_size, dim=-1)
511
+
512
+ head_dim = key.shape[-1]
513
+
514
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
515
+
516
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
517
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
518
+
519
+ if attn.norm_q is not None:
520
+ query = attn.norm_q(query)
521
+ if attn.norm_k is not None:
522
+ key = attn.norm_k(key)
523
+
524
+ # Apply RoPE if needed
525
+ if image_rotary_emb is not None:
526
+ query = apply_rotary_emb(query, image_rotary_emb)
527
+ if not attn.is_cross_attention:
528
+ key = apply_rotary_emb(key, image_rotary_emb)
529
+
530
+ if self.use_mi and num_instances is not None:
531
+ key = rearrange(
532
+ key, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
533
+ ).repeat_interleave(num_instances, dim=0)
534
+ value = rearrange(
535
+ value, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
536
+ ).repeat_interleave(num_instances, dim=0)
537
+
538
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
539
+ hidden_states = F.scaled_dot_product_attention(
540
+ query,
541
+ key,
542
+ value,
543
+ dropout_p=0.0,
544
+ is_causal=False,
545
+ )
546
+ else:
547
+ hidden_states = F.scaled_dot_product_attention(
548
+ query,
549
+ key,
550
+ value,
551
+ attn_mask=attention_mask,
552
+ dropout_p=0.0,
553
+ is_causal=False,
554
+ )
555
+
556
+ hidden_states = hidden_states.transpose(1, 2).reshape(
557
+ batch_size, -1, attn.heads * head_dim
558
+ )
559
+ hidden_states = hidden_states.to(query.dtype)
560
+
561
+ # linear proj
562
+ hidden_states = attn.to_out[0](hidden_states)
563
+ # dropout
564
+ hidden_states = attn.to_out[1](hidden_states)
565
+
566
+ if input_ndim == 4:
567
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
568
+ batch_size, channel, height, width
569
+ )
570
+
571
+ if attn.residual_connection:
572
+ hidden_states = hidden_states + residual
573
+
574
+ hidden_states = hidden_states / attn.rescale_output_factor
575
+
576
+ return hidden_states
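
The distinctive part of FlashTripo2AttnProcessor2_0 is its qkv method: it scores all keys against a strided subset of the queries, keeps only the top-k keys/values, and runs scaled-dot-product attention against that reduced set. Below is a standalone sketch of just that trick; the tensor sizes and the query stride are illustrative and not tied to the model's real dimensions.

import torch
import torch.nn.functional as F

def topk_sdpa(q, k, v, topk, query_stride=100):
    # Score every key against a strided subset of queries, average over those queries,
    # keep only the top-k keys/values, then run full SDPA against that reduced set.
    sim = (q[:, :, ::query_stride, :] @ k.transpose(-1, -2)).mean(dim=-2)  # (B, H, Nk)
    idx = torch.topk(sim, k=topk, dim=-1).indices.unsqueeze(-1)            # (B, H, topk, 1)
    idx = idx.expand(-1, -1, -1, v.shape[-1])                              # (B, H, topk, Dv)
    k0 = torch.gather(k, dim=-2, index=idx)
    v0 = torch.gather(v, dim=-2, index=idx)
    return F.scaled_dot_product_attention(q, k0, v0)

q = torch.randn(1, 8, 4096, 64)   # queries: decoded grid points
k = torch.randn(1, 8, 3072, 64)   # keys/values: latent tokens
v = torch.randn_like(k)
out = topk_sdpa(q, k, v, topk=1024)
print(out.shape)  # torch.Size([1, 8, 4096, 64]), same shape as full attention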
detailgen3d/models/autoencoders/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .autoencoder_kl_triposg import TripoSGVAEModel
detailgen3d/models/autoencoders/autoencoder_kl_triposg.py ADDED
@@ -0,0 +1,536 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
8
+ from diffusers.models.autoencoders.vae import DecoderOutput
9
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.models.normalization import FP32LayerNorm, LayerNorm
12
+ from diffusers.utils import logging
13
+ from diffusers.utils.accelerate_utils import apply_forward_hook
14
+ from einops import repeat
15
+ from torch_cluster import fps
16
+ from tqdm import tqdm
17
+
18
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0, FlashTripo2AttnProcessor2_0
19
+ from ..embeddings import FrequencyPositionalEmbedding
20
+ from ..transformers.triposg_transformer import DiTBlock
21
+ from .vae import DiagonalGaussianDistribution
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+
26
+ class TripoSGEncoder(nn.Module):
27
+ def __init__(
28
+ self,
29
+ in_channels: int = 3,
30
+ dim: int = 512,
31
+ num_attention_heads: int = 8,
32
+ num_layers: int = 8,
33
+ ):
34
+ super().__init__()
35
+
36
+ self.proj_in = nn.Linear(in_channels, dim, bias=True)
37
+
38
+ self.blocks = nn.ModuleList(
39
+ [
40
+ DiTBlock(
41
+ dim=dim,
42
+ num_attention_heads=num_attention_heads,
43
+ use_self_attention=False,
44
+ use_cross_attention=True,
45
+ cross_attention_dim=dim,
46
+ cross_attention_norm_type="layer_norm",
47
+ activation_fn="gelu",
48
+ norm_type="fp32_layer_norm",
49
+ norm_eps=1e-5,
50
+ qk_norm=False,
51
+ qkv_bias=False,
52
+ ) # cross attention
53
+ ]
54
+ + [
55
+ DiTBlock(
56
+ dim=dim,
57
+ num_attention_heads=num_attention_heads,
58
+ use_self_attention=True,
59
+ self_attention_norm_type="fp32_layer_norm",
60
+ use_cross_attention=False,
61
+ activation_fn="gelu",
62
+ norm_type="fp32_layer_norm",
63
+ norm_eps=1e-5,
64
+ qk_norm=False,
65
+ qkv_bias=False,
66
+ )
67
+ for _ in range(num_layers) # self attention
68
+ ]
69
+ )
70
+
71
+ self.norm_out = LayerNorm(dim)
72
+
73
+ def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
74
+ hidden_states = self.proj_in(sample_1)
75
+ encoder_hidden_states = self.proj_in(sample_2)
76
+
77
+ for layer, block in enumerate(self.blocks):
78
+ if layer == 0:
79
+ hidden_states = block(
80
+ hidden_states, encoder_hidden_states=encoder_hidden_states
81
+ )
82
+ else:
83
+ hidden_states = block(hidden_states)
84
+
85
+ hidden_states = self.norm_out(hidden_states)
86
+
87
+ return hidden_states
88
+
89
+
90
+ class TripoSGDecoder(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_channels: int = 3,
94
+ out_channels: int = 1,
95
+ dim: int = 512,
96
+ num_attention_heads: int = 8,
97
+ num_layers: int = 16,
98
+ grad_type: str = "analytical",
99
+ grad_interval: float = 0.001,
100
+ ):
101
+ super().__init__()
102
+
103
+ if grad_type not in ["numerical", "analytical"]:
104
+ raise ValueError(f"grad_type must be one of ['numerical', 'analytical']")
105
+ self.grad_type = grad_type
106
+ self.grad_interval = grad_interval
107
+
108
+ self.blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim=dim,
112
+ num_attention_heads=num_attention_heads,
113
+ use_self_attention=True,
114
+ self_attention_norm_type="fp32_layer_norm",
115
+ use_cross_attention=False,
116
+ activation_fn="gelu",
117
+ norm_type="fp32_layer_norm",
118
+ norm_eps=1e-5,
119
+ qk_norm=False,
120
+ qkv_bias=False,
121
+ )
122
+ for _ in range(num_layers) # self attention
123
+ ]
124
+ + [
125
+ DiTBlock(
126
+ dim=dim,
127
+ num_attention_heads=num_attention_heads,
128
+ use_self_attention=False,
129
+ use_cross_attention=True,
130
+ cross_attention_dim=dim,
131
+ cross_attention_norm_type="layer_norm",
132
+ activation_fn="gelu",
133
+ norm_type="fp32_layer_norm",
134
+ norm_eps=1e-5,
135
+ qk_norm=False,
136
+ qkv_bias=False,
137
+ ) # cross attention
138
+ ]
139
+ )
140
+
141
+ self.proj_query = nn.Linear(in_channels, dim, bias=True)
142
+
143
+ self.norm_out = LayerNorm(dim)
144
+ self.proj_out = nn.Linear(dim, out_channels, bias=True)
145
+
146
+ def set_topk(self, topk):
147
+ self.blocks[-1].set_topk(topk)
148
+
149
+ def set_flash_processor(self, processor):
150
+ self.blocks[-1].set_flash_processor(processor)
151
+
152
+ def query_geometry(
153
+ self,
154
+ model_fn: callable,
155
+ queries: torch.Tensor,
156
+ sample: torch.Tensor,
157
+ grad: bool = False,
158
+ ):
159
+ logits = model_fn(queries, sample)
160
+ if grad:
161
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
162
+ if self.grad_type == "numerical":
163
+ interval = self.grad_interval
164
+ grad_value = []
165
+ for offset in [
166
+ (interval, 0, 0),
167
+ (0, interval, 0),
168
+ (0, 0, interval),
169
+ ]:
170
+ offset_tensor = torch.tensor(offset, device=queries.device)[
171
+ None, :
172
+ ]
173
+ res_p = model_fn(queries + offset_tensor, sample)[..., 0]
174
+ res_n = model_fn(queries - offset_tensor, sample)[..., 0]
175
+ grad_value.append((res_p - res_n) / (2 * interval))
176
+ grad_value = torch.stack(grad_value, dim=-1)
177
+ else:
178
+ queries_d = torch.clone(queries)
179
+ queries_d.requires_grad = True
180
+ with torch.enable_grad():
181
+ res_d = model_fn(queries_d, sample)
182
+ grad_value = torch.autograd.grad(
183
+ res_d,
184
+ [queries_d],
185
+ grad_outputs=torch.ones_like(res_d),
186
+ create_graph=self.training,
187
+ )[0]
188
+ else:
189
+ grad_value = None
190
+
191
+ return logits, grad_value
192
+
193
+ def forward(
194
+ self,
195
+ sample: torch.Tensor,
196
+ queries: torch.Tensor,
197
+ kv_cache: Optional[torch.Tensor] = None,
198
+ ):
199
+ if kv_cache is None:
200
+ hidden_states = sample
201
+ for _, block in enumerate(self.blocks[:-1]):
202
+ hidden_states = block(hidden_states)
203
+ kv_cache = hidden_states
204
+
205
+ # query grid logits by cross attention
206
+ def query_fn(q, kv):
207
+ q = self.proj_query(q)
208
+ l = self.blocks[-1](q, encoder_hidden_states=kv)
209
+ return self.proj_out(self.norm_out(l))
210
+
211
+ logits, grad = self.query_geometry(
212
+ query_fn, queries, kv_cache, grad=self.training
213
+ )
214
+ logits = logits * -1 if not isinstance(logits, Tuple) else logits[0] * -1
215
+
216
+ return logits, kv_cache
217
+
218
+
219
+ class TripoSGVAEModel(ModelMixin, ConfigMixin):
220
+ @register_to_config
221
+ def __init__(
222
+ self,
223
+ in_channels: int = 3, # NOTE xyz instead of feature dim
224
+ latent_channels: int = 64,
225
+ num_attention_heads: int = 8,
226
+ width_encoder: int = 512,
227
+ width_decoder: int = 1024,
228
+ num_layers_encoder: int = 8,
229
+ num_layers_decoder: int = 16,
230
+ embedding_type: str = "frequency",
231
+ embed_frequency: int = 8,
232
+ embed_include_pi: bool = False,
233
+ ):
234
+ super().__init__()
235
+
236
+ self.out_channels = 1
237
+
238
+ if embedding_type == "frequency":
239
+ self.embedder = FrequencyPositionalEmbedding(
240
+ num_freqs=embed_frequency,
241
+ logspace=True,
242
+ input_dim=in_channels,
243
+ include_pi=embed_include_pi,
244
+ )
245
+ else:
246
+ raise NotImplementedError(
247
+ f"Embedding type {embedding_type} is not supported."
248
+ )
249
+
250
+ self.encoder = TripoSGEncoder(
251
+ in_channels=in_channels + self.embedder.out_dim,
252
+ dim=width_encoder,
253
+ num_attention_heads=num_attention_heads,
254
+ num_layers=num_layers_encoder,
255
+ )
256
+ self.decoder = TripoSGDecoder(
257
+ in_channels=self.embedder.out_dim,
258
+ out_channels=self.out_channels,
259
+ dim=width_decoder,
260
+ num_attention_heads=num_attention_heads,
261
+ num_layers=num_layers_decoder,
262
+ )
263
+
264
+ self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
265
+ self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)
266
+
267
+ self.use_slicing = False
268
+ self.slicing_length = 1
269
+
270
+ def set_flash_decoder(self):
271
+ self.decoder.set_flash_processor(FlashTripo2AttnProcessor2_0())
272
+
273
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
274
+ def fuse_qkv_projections(self):
275
+ """
276
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
277
+ are fused. For cross-attention modules, key and value projection matrices are fused.
278
+
279
+ <Tip warning={true}>
280
+
281
+ This API is 🧪 experimental.
282
+
283
+ </Tip>
284
+ """
285
+ self.original_attn_processors = None
286
+
287
+ for _, attn_processor in self.attn_processors.items():
288
+ if "Added" in str(attn_processor.__class__.__name__):
289
+ raise ValueError(
290
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
291
+ )
292
+
293
+ self.original_attn_processors = self.attn_processors
294
+
295
+ for module in self.modules():
296
+ if isinstance(module, Attention):
297
+ module.fuse_projections(fuse=True)
298
+
299
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
300
+
301
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
302
+ def unfuse_qkv_projections(self):
303
+ """Disables the fused QKV projection if enabled.
304
+
305
+ <Tip warning={true}>
306
+
307
+ This API is 🧪 experimental.
308
+
309
+ </Tip>
310
+
311
+ """
312
+ if self.original_attn_processors is not None:
313
+ self.set_attn_processor(self.original_attn_processors)
314
+
315
+ @property
316
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
317
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
318
+ r"""
319
+ Returns:
320
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
321
+ indexed by its weight name.
322
+ """
323
+ # set recursively
324
+ processors = {}
325
+
326
+ def fn_recursive_add_processors(
327
+ name: str,
328
+ module: torch.nn.Module,
329
+ processors: Dict[str, AttentionProcessor],
330
+ ):
331
+ if hasattr(module, "get_processor"):
332
+ processors[f"{name}.processor"] = module.get_processor()
333
+
334
+ for sub_name, child in module.named_children():
335
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
336
+
337
+ return processors
338
+
339
+ for name, module in self.named_children():
340
+ fn_recursive_add_processors(name, module, processors)
341
+
342
+ return processors
343
+
344
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
345
+ def set_attn_processor(
346
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
347
+ ):
348
+ r"""
349
+ Sets the attention processor to use to compute attention.
350
+
351
+ Parameters:
352
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
353
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
354
+ for **all** `Attention` layers.
355
+
356
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
357
+ processor. This is strongly recommended when setting trainable attention processors.
358
+
359
+ """
360
+ count = len(self.attn_processors.keys())
361
+
362
+ if isinstance(processor, dict) and len(processor) != count:
363
+ raise ValueError(
364
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
365
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
366
+ )
367
+
368
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
369
+ if hasattr(module, "set_processor"):
370
+ if not isinstance(processor, dict):
371
+ module.set_processor(processor)
372
+ else:
373
+ module.set_processor(processor.pop(f"{name}.processor"))
374
+
375
+ for sub_name, child in module.named_children():
376
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
377
+
378
+ for name, module in self.named_children():
379
+ fn_recursive_attn_processor(name, module, processor)
380
+
381
+ def set_default_attn_processor(self):
382
+ """
383
+ Disables custom attention processors and sets the default attention implementation.
384
+ """
385
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
386
+
387
+ def enable_slicing(self, slicing_length: int = 1) -> None:
388
+ r"""
389
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
390
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
391
+ """
392
+ self.use_slicing = True
393
+ self.slicing_length = slicing_length
394
+
395
+ def disable_slicing(self) -> None:
396
+ r"""
397
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
398
+ decoding in one step.
399
+ """
400
+ self.use_slicing = False
401
+
402
+ def _sample_features(
403
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
404
+ ):
405
+ """
406
+ Sample points from features of the input point cloud.
407
+
408
+ Args:
409
+ x (torch.Tensor): The input point cloud. shape: (B, N, C)
410
+ num_tokens (int, optional): The number of points to sample. Defaults to 2048.
411
+ seed (Optional[int], optional): The random seed. Defaults to None.
412
+ """
413
+ rng = np.random.default_rng(seed)
414
+ indices = rng.choice(
415
+ x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
416
+ )
417
+ selected_points = x[:, indices]
418
+
419
+ batch_size, num_points, num_channels = selected_points.shape
420
+ flattened_points = selected_points.view(batch_size * num_points, num_channels)
421
+ batch_indices = (
422
+ torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
423
+ )
424
+
425
+ # fps sampling
426
+ sampling_ratio = 1.0 / 4
427
+ sampled_indices = fps(
428
+ flattened_points[:, :3],
429
+ batch_indices,
430
+ ratio=sampling_ratio,
431
+ random_start=self.training,
432
+ )
433
+ sampled_points = flattened_points[sampled_indices].view(
434
+ batch_size, -1, num_channels
435
+ )
436
+
437
+ return sampled_points
438
+
439
+ def _encode(
440
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
441
+ ):
442
+ position_channels = self.config.in_channels
443
+ positions, features = x[..., :position_channels], x[..., position_channels:]
444
+ x_kv = torch.cat([self.embedder(positions), features], dim=-1)
445
+
446
+ sampled_x = self._sample_features(x, num_tokens, seed)
447
+ positions, features = (
448
+ sampled_x[..., :position_channels],
449
+ sampled_x[..., position_channels:],
450
+ )
451
+ x_q = torch.cat([self.embedder(positions), features], dim=-1)
452
+
453
+ x = self.encoder(x_q, x_kv)
454
+
455
+ x = self.quant(x)
456
+
457
+ return x
458
+
459
+ @apply_forward_hook
460
+ def encode(
461
+ self, x: torch.Tensor, return_dict: bool = True, **kwargs
462
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
463
+ """
464
+ Encode a batch of point features into latents.
465
+ """
466
+ if self.use_slicing and x.shape[0] > 1:
467
+ encoded_slices = [
468
+ self._encode(x_slice, **kwargs)
469
+ for x_slice in x.split(self.slicing_length)
470
+ ]
471
+ h = torch.cat(encoded_slices)
472
+ else:
473
+ h = self._encode(x, **kwargs)
474
+
475
+ posterior = DiagonalGaussianDistribution(h, feature_dim=-1)
476
+
477
+ if not return_dict:
478
+ return (posterior,)
479
+ return AutoencoderKLOutput(latent_dist=posterior)
480
+
481
+ def _decode(
482
+ self,
483
+ z: torch.Tensor,
484
+ sampled_points: torch.Tensor,
485
+ num_chunks: int = 50000,
486
+ to_cpu: bool = False,
487
+ return_dict: bool = True,
488
+ ) -> Union[DecoderOutput, torch.Tensor]:
489
+ xyz_samples = sampled_points
490
+
491
+ z = self.post_quant(z)
492
+
493
+ num_points = xyz_samples.shape[1]
494
+ kv_cache = None
495
+ dec = []
496
+
497
+ for i in range(0, num_points, num_chunks):
498
+ queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
499
+ queries = self.embedder(queries)
500
+
501
+ z_, kv_cache = self.decoder(z, queries, kv_cache)
502
+ dec.append(z_ if not to_cpu else z_.cpu())
503
+
504
+ z = torch.cat(dec, dim=1)
505
+
506
+ if not return_dict:
507
+ return (z,)
508
+
509
+ return DecoderOutput(sample=z)
510
+
511
+ @apply_forward_hook
512
+ def decode(
513
+ self,
514
+ z: torch.Tensor,
515
+ sampled_points: torch.Tensor,
516
+ return_dict: bool = True,
517
+ **kwargs,
518
+ ) -> Union[DecoderOutput, torch.Tensor]:
519
+ if self.use_slicing and z.shape[0] > 1:
520
+ decoded_slices = [
521
+ self._decode(z_slice, p_slice, **kwargs).sample
522
+ for z_slice, p_slice in zip(
523
+ z.split(self.slicing_length),
524
+ sampled_points.split(self.slicing_length),
525
+ )
526
+ ]
527
+ decoded = torch.cat(decoded_slices)
528
+ else:
529
+ decoded = self._decode(z, sampled_points, **kwargs).sample
530
+
531
+ if not return_dict:
532
+ return (decoded,)
533
+ return DecoderOutput(sample=decoded)
534
+
535
+ def forward(self, x: torch.Tensor):
536
+ pass
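
A shape-only sketch of the TripoSGVAEModel encode/decode round trip, using dummy tensors in place of a real sampled surface and query grid (in the app those come from load_mesh and generate_dense_grid_points). The commented-out calls are the optional switches defined above; this is an assumption-laden illustration, not a supported entry point.

import torch
from detailgen3d.pipelines.pipeline_detailgen3d import DetailGen3DPipeline

pipe = DetailGen3DPipeline.from_pretrained(
    "VAST-AI/DetailGen3D", low_cpu_mem_usage=False
).to("cuda", dtype=torch.float16)
vae = pipe.vae  # TripoSGVAEModel

vae.set_default_attn_processor()   # plain TripoSGAttnProcessor2_0 everywhere
# vae.set_flash_decoder()          # top-k FlashTripo2AttnProcessor2_0 on the last decoder block
# vae.enable_slicing()             # process one batch element at a time to save memory

surface = torch.randn(1, 20480, 6, device="cuda")                 # dummy xyz + normal per point, float32
queries = torch.rand(1, 50000, 3, device="cuda") * 2.01 - 1.005   # dummy query points in [-1.005, 1.005]

with torch.no_grad(), torch.autocast(device_type="cuda"):
    latents = vae.encode(surface).latent_dist.sample()            # roughly (1, 2048, 64) with the defaults
    occ = vae.decode(latents, sampled_points=queries).sample      # (1, 50000, 1) occupancy logits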
detailgen3d/models/autoencoders/vae.py ADDED
@@ -0,0 +1,69 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers.utils.torch_utils import randn_tensor
6
+
7
+
8
+ class DiagonalGaussianDistribution(object):
9
+ def __init__(
10
+ self,
11
+ parameters: torch.Tensor,
12
+ deterministic: bool = False,
13
+ feature_dim: int = 1,
14
+ ):
15
+ self.parameters = parameters
16
+ self.feature_dim = feature_dim
17
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
18
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
19
+ self.deterministic = deterministic
20
+ self.std = torch.exp(0.5 * self.logvar)
21
+ self.var = torch.exp(self.logvar)
22
+ if self.deterministic:
23
+ self.var = self.std = torch.zeros_like(
24
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
25
+ )
26
+
27
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
28
+ # make sure sample is on the same device as the parameters and has same dtype
29
+ sample = randn_tensor(
30
+ self.mean.shape,
31
+ generator=generator,
32
+ device=self.parameters.device,
33
+ dtype=self.parameters.dtype,
34
+ )
35
+ x = self.mean + self.std * sample
36
+ return x
37
+
38
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
39
+ if self.deterministic:
40
+ return torch.Tensor([0.0])
41
+ else:
42
+ if other is None:
43
+ return 0.5 * torch.sum(
44
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
45
+ dim=[1, 2, 3],
46
+ )
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var
51
+ - 1.0
52
+ - self.logvar
53
+ + other.logvar,
54
+ dim=[1, 2, 3],
55
+ )
56
+
57
+ def nll(
58
+ self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
59
+ ) -> torch.Tensor:
60
+ if self.deterministic:
61
+ return torch.Tensor([0.0])
62
+ logtwopi = np.log(2.0 * np.pi)
63
+ return 0.5 * torch.sum(
64
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
65
+ dim=dims,
66
+ )
67
+
68
+ def mode(self) -> torch.Tensor:
69
+ return self.mean
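A short sketch of how this distribution is used with channels stacked on the last axis, mirroring the `feature_dim=-1` call in the encoder above (illustrative shapes, not part of the commit):

import torch

moments = torch.randn(2, 512, 128)                     # (batch, tokens, 2 * latent_channels)
posterior = DiagonalGaussianDistribution(moments, feature_dim=-1)

z = posterior.sample()        # reparameterized sample: mean + std * eps, shape (2, 512, 64)
mode = posterior.mode()       # deterministic latent (the mean)
# Note: kl() reduces over dims [1, 2, 3], which assumes a 4D parameter tensor;
# for the 3D token layout above it would need adjusting (nll() takes an explicit dims argument).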
detailgen3d/models/embeddings.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class FrequencyPositionalEmbedding(nn.Module):
6
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
7
+ each feature dimension of `x[..., i]` into:
8
+ [
9
+ sin(x[..., i]),
10
+ sin(f_1*x[..., i]),
11
+ sin(f_2*x[..., i]),
12
+ ...
13
+ sin(f_N * x[..., i]),
14
+ cos(x[..., i]),
15
+ cos(f_1*x[..., i]),
16
+ cos(f_2*x[..., i]),
17
+ ...
18
+ cos(f_N * x[..., i]),
19
+ x[..., i] # only present if include_input is True.
20
+ ], here f_i is the frequency.
21
+
22
+ The frequencies are indexed by i in [0, 1, 2, ..., num_freqs - 1].
23
+ If logspace is True, the frequency f_i is 2^i, i.e. [1, 2, 4, ..., 2^(num_freqs - 1)];
24
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
25
+
26
+ Args:
27
+ num_freqs (int): the number of frequencies, default is 6;
28
+ logspace (bool): If logspace is True, the frequency f_i is 2^i,
29
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
30
+ input_dim (int): the input dimension, default is 3;
31
+ include_input (bool): include the input tensor or not, default is True.
32
+
33
+ Attributes:
34
+ frequencies (torch.Tensor): If logspace is True, the frequency f_i is 2^i,
35
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
36
+
37
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
38
+ otherwise, it is input_dim * num_freqs * 2.
39
+
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ num_freqs: int = 6,
45
+ logspace: bool = True,
46
+ input_dim: int = 3,
47
+ include_input: bool = True,
48
+ include_pi: bool = True,
49
+ ) -> None:
50
+ """The initialization"""
51
+
52
+ super().__init__()
53
+
54
+ if logspace:
55
+ frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
56
+ else:
57
+ frequencies = torch.linspace(
58
+ 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
59
+ )
60
+
61
+ if include_pi:
62
+ frequencies *= torch.pi
63
+
64
+ self.register_buffer("frequencies", frequencies, persistent=False)
65
+ self.include_input = include_input
66
+ self.num_freqs = num_freqs
67
+
68
+ self.out_dim = self.get_dims(input_dim)
69
+
70
+ def get_dims(self, input_dim):
71
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
72
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
73
+
74
+ return out_dim
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ """Forward process.
78
+
79
+ Args:
80
+ x: tensor of shape [..., dim]
81
+
82
+ Returns:
83
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
84
+ where temp is 1 if include_input is True and 0 otherwise.
85
+ """
86
+
87
+ if self.num_freqs > 0:
88
+ embed = (x[..., None].contiguous() * self.frequencies).view(
89
+ *x.shape[:-1], -1
90
+ )
91
+ if self.include_input:
92
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
93
+ else:
94
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
95
+ else:
96
+ return x
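For reference, a small example of the embedding above (not part of the commit): with num_freqs=8 and include_input=True, 3D points map to 3 * (2 * 8 + 1) = 51 features.

import torch

embedder = FrequencyPositionalEmbedding(num_freqs=8, input_dim=3, include_input=True)
points = torch.rand(1, 1024, 3) * 2 - 1   # (batch, num_points, 3) in [-1, 1]^3
feats = embedder(points)                  # concat of [x, sin(f_i * x), cos(f_i * x)]
print(embedder.out_dim, feats.shape)      # 51  torch.Size([1, 1024, 51])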
detailgen3d/models/transformers/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ from typing import Callable, Optional
2
+
3
+ from .detailgen3d_transformers import DetailGen3DDiTModel
4
+
5
+
6
+ def default_set_attn_proc_func(
7
+ name: str,
8
+ hidden_size: int,
9
+ cross_attention_dim: Optional[int],
10
+ ori_attn_proc: object,
11
+ ) -> object:
12
+ return ori_attn_proc
13
+
14
+
15
+ def set_transformer_attn_processor(
16
+ transformer: DetailGen3DDiTModel,
17
+ set_self_attn_proc_func: Callable = default_set_attn_proc_func,
18
+ set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
19
+ set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
20
+ set_self_attn_module_names: Optional[list[str]] = None,
21
+ set_cross_attn_1_module_names: Optional[list[str]] = None,
22
+ set_cross_attn_2_module_names: Optional[list[str]] = None,
23
+ ) -> None:
24
+ do_set_processor = lambda name, module_names: (
25
+ any([name.startswith(module_name) for module_name in module_names])
26
+ if module_names is not None
27
+ else True
28
+ ) # prefix match
29
+
30
+ attn_procs = {}
31
+ for name, attn_processor in transformer.attn_processors.items():
32
+ hidden_size = transformer.config.width
33
+ if name.endswith("attn1.processor"):
34
+ # self attention
35
+ attn_procs[name] = (
36
+ set_self_attn_proc_func(name, hidden_size, None, attn_processor)
37
+ if do_set_processor(name, set_self_attn_module_names)
38
+ else attn_processor
39
+ )
40
+ elif name.endswith("attn2.processor"):
41
+ # cross attention
42
+ cross_attention_dim = transformer.config.cross_attention_dim
43
+ attn_procs[name] = (
44
+ set_cross_attn_1_proc_func(
45
+ name, hidden_size, cross_attention_dim, attn_processor
46
+ )
47
+ if do_set_processor(name, set_cross_attn_1_module_names)
48
+ else attn_processor
49
+ )
50
+ elif name.endswith("attn2_2.processor"):
51
+ # cross attention 2
52
+ cross_attention_dim = transformer.config.cross_attention_2_dim
53
+ attn_procs[name] = (
54
+ set_cross_attn_2_proc_func(
55
+ name, hidden_size, cross_attention_dim, attn_processor
56
+ )
57
+ if do_set_processor(name, set_cross_attn_2_module_names)
58
+ else attn_processor
59
+ )
60
+
61
+ transformer.set_attn_processor(attn_procs)
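A hypothetical sketch of the helper above (not part of the commit): replace only the cross-attention ("attn2") processors of the first three blocks, leaving all other processors untouched. `MyCrossAttnProcessor` is a placeholder for any diffusers-style attention processor class.

def make_cross_attn_proc(name, hidden_size, cross_attention_dim, ori_attn_proc):
    # swap in the custom processor; return ori_attn_proc instead to keep the default
    return MyCrossAttnProcessor()

set_transformer_attn_processor(
    transformer,                                       # a DetailGen3DDiTModel instance
    set_cross_attn_1_proc_func=make_cross_attn_proc,
    # trailing dots keep the prefix match from also catching e.g. "blocks.20"
    set_cross_attn_1_module_names=["blocks.0.", "blocks.1.", "blocks.2."],
)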
detailgen3d/models/transformers/detailgen3d_transformers.py ADDED
@@ -0,0 +1,771 @@
1
+ # Copyright (c) 2025 VAST-AI-Research and contributors
2
+
3
+ # This code is based on Tencent HunyuanDiT (https://huggingface.co/Tencent-Hunyuan/HunyuanDiT),
4
+ # which is licensed under the Tencent Hunyuan Community License Agreement.
5
+ # Portions of this code are copied or adapted from HunyuanDiT.
6
+ # See the original license below:
7
+
8
+ # ---- Start of Tencent Hunyuan Community License Agreement ----
9
+
10
+ # TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
11
+ # Tencent Hunyuan DiT Release Date: 14 May 2024
12
+ # THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
13
+ # By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
14
+ # 1. DEFINITIONS.
15
+ # a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
16
+ # b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
17
+ # c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
18
+ # d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
19
+ # e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
20
+ # f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
21
+ # g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
22
+ # h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
23
+ # i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
24
+ # j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan DiT released at https://huggingface.co/Tencent-Hunyuan/HunyuanDiT.
25
+ # k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
26
+ # l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
27
+ # m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
28
+ # n. “including” shall mean including but not limited to.
29
+ # 2. GRANT OF RIGHTS.
30
+ # We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
31
+ # 3. DISTRIBUTION.
32
+ # You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
33
+ # a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
34
+ # b. You must cause any modified files to carry prominent notices stating that You changed the files;
35
+ # c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
36
+ # d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
37
+ # You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
38
+ # 4. ADDITIONAL COMMERCIAL TERMS.
39
+ # If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
40
+ # 5. RULES OF USE.
41
+ # a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
42
+ # b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof).
43
+ # c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
44
+ # 6. INTELLECTUAL PROPERTY.
45
+ # a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
46
+ # b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
47
+ # c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
48
+ # d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
49
+ # 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
50
+ # a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
51
+ # b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
52
+ # c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
53
+ # 8. SURVIVAL AND TERMINATION.
54
+ # a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
55
+ # b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
56
+ # 9. GOVERNING LAW AND JURISDICTION.
57
+ # a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
58
+ # b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
59
+ #
60
+ # EXHIBIT A
61
+ # ACCEPTABLE USE POLICY
62
+
63
+ # Tencent reserves the right to update this Acceptable Use Policy from time to time.
64
+ # Last modified: [insert date]
65
+
66
+ # Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
67
+ # 1. Outside the Territory;
68
+ # 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
69
+ # 3. To harm Yourself or others;
70
+ # 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
71
+ # 5. To override or circumvent the safety guardrails and safeguards We have put in place;
72
+ # 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
73
+ # 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
74
+ # 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
75
+ # 9. To intentionally defame, disparage or otherwise harass others;
76
+ # 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
77
+ # 11. To generate or disseminate personal identifiable information with the purpose of harming others;
78
+ # 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
79
+ # 13. To impersonate another individual without consent, authorization, or legal right;
80
+ # 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
81
+ # 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
82
+ # 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
83
+ # 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
84
+ # 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
85
+ # 19. For military purposes;
86
+ # 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
87
+
88
+ # ---- End of Tencent Hunyuan Community License Agreement ----
89
+
90
+ # Please note that the use of this code is subject to the terms and conditions
91
+ # of the Tencent Hunyuan Community License Agreement, including the Acceptable Use Policy.
92
+
93
+ from typing import Any, Dict, Optional, Tuple, Union
94
+
95
+ import torch
96
+ import torch.utils.checkpoint
97
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
98
+ from diffusers.models.attention import FeedForward
99
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
100
+ from diffusers.models.embeddings import (
101
+ GaussianFourierProjection,
102
+ TimestepEmbedding,
103
+ Timesteps,
104
+ )
105
+ from diffusers.models.modeling_utils import ModelMixin
106
+ from diffusers.models.normalization import (
107
+ AdaLayerNormContinuous,
108
+ FP32LayerNorm,
109
+ LayerNorm,
110
+ )
111
+ from diffusers.utils import (
112
+ USE_PEFT_BACKEND,
113
+ is_torch_version,
114
+ logging,
115
+ scale_lora_layers,
116
+ unscale_lora_layers,
117
+ )
118
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
119
+ from torch import nn
120
+
121
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
122
+ from .modeling_outputs import Transformer1DModelOutput
123
+
124
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
125
+
126
+
127
+ @maybe_allow_in_graph
128
+ class DiTBlock(nn.Module):
129
+ r"""
130
+ Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
131
+ QKNorm
132
+
133
+ Parameters:
134
+ dim (`int`):
135
+ The number of channels in the input and output.
136
+ num_attention_heads (`int`):
137
+ The number of heads to use for multi-head attention.
138
+ cross_attention_dim (`int`,*optional*):
139
+ The size of the encoder_hidden_states vector for cross attention.
140
+ dropout(`float`, *optional*, defaults to 0.0):
141
+ The dropout probability to use.
142
+ activation_fn (`str`,*optional*, defaults to `"geglu"`):
143
+ Activation function to be used in feed-forward.
144
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
145
+ Whether to use learnable elementwise affine parameters for normalization.
146
+ norm_eps (`float`, *optional*, defaults to 1e-6):
147
+ A small constant added to the denominator in normalization layers to prevent division by zero.
148
+ final_dropout (`bool` *optional*, defaults to False):
149
+ Whether to apply a final dropout after the last feed-forward layer.
150
+ ff_inner_dim (`int`, *optional*):
151
+ The size of the hidden layer in the feed-forward block. Defaults to `None`.
152
+ ff_bias (`bool`, *optional*, defaults to `True`):
153
+ Whether to use bias in the feed-forward block.
154
+ skip (`bool`, *optional*, defaults to `False`):
155
+ Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
156
+ qk_norm (`bool`, *optional*, defaults to `True`):
157
+ Whether to use normalization in QK calculation. Defaults to `True`.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ dim: int,
163
+ num_attention_heads: int,
164
+ use_self_attention: bool = True,
165
+ use_cross_attention: bool = False,
166
+ self_attention_norm_type: Optional[str] = None, # ada layer norm
167
+ cross_attention_dim: Optional[int] = None,
168
+ cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
169
+ # parallel second cross attention
170
+ use_cross_attention_2: bool = False,
171
+ cross_attention_2_dim: Optional[int] = None,
172
+ cross_attention_2_norm_type: Optional[str] = None,
173
+ dropout=0.0,
174
+ activation_fn: str = "gelu",
175
+ norm_type: str = "fp32_layer_norm", # TODO
176
+ norm_elementwise_affine: bool = True,
177
+ norm_eps: float = 1e-5,
178
+ final_dropout: bool = False,
179
+ ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
180
+ ff_bias: bool = True,
181
+ skip: bool = False,
182
+ skip_concat_front: bool = False, # [x, skip] or [skip, x]
183
+ skip_norm_last: bool = False, # normalizes after skip_linear instead of before; known to be wrong (see "don't do this" in forward)
184
+ qk_norm: bool = True,
185
+ qkv_bias: bool = True,
186
+ ):
187
+ super().__init__()
188
+
189
+ self.use_self_attention = use_self_attention
190
+ self.use_cross_attention = use_cross_attention
191
+ self.use_cross_attention_2 = use_cross_attention_2
192
+ self.skip_concat_front = skip_concat_front
193
+ self.skip_norm_last = skip_norm_last
194
+ # Define 3 blocks. Each block has its own normalization layer.
195
+ # NOTE: when a new version comes, check norm2 and norm3
196
+ # 1. Self-Attn
197
+ if use_self_attention:
198
+ if (
199
+ self_attention_norm_type == "fp32_layer_norm"
200
+ or self_attention_norm_type is None
201
+ ):
202
+ self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
203
+ else:
204
+ raise NotImplementedError
205
+
206
+ self.attn1 = Attention(
207
+ query_dim=dim,
208
+ cross_attention_dim=None,
209
+ dim_head=dim // num_attention_heads,
210
+ heads=num_attention_heads,
211
+ qk_norm="rms_norm" if qk_norm else None,
212
+ eps=1e-6,
213
+ bias=qkv_bias,
214
+ processor=TripoSGAttnProcessor2_0(),
215
+ )
216
+
217
+ # 2. Cross-Attn
218
+ if use_cross_attention:
219
+ assert cross_attention_dim is not None
220
+
221
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
222
+
223
+ self.attn2 = Attention(
224
+ query_dim=dim,
225
+ cross_attention_dim=cross_attention_dim,
226
+ dim_head=dim // num_attention_heads,
227
+ heads=num_attention_heads,
228
+ qk_norm="rms_norm" if qk_norm else None,
229
+ cross_attention_norm=cross_attention_norm_type,
230
+ eps=1e-6,
231
+ bias=qkv_bias,
232
+ processor=TripoSGAttnProcessor2_0(),
233
+ )
234
+
235
+ # 2'. Parallel Second Cross-Attn
236
+ if use_cross_attention_2:
237
+ assert cross_attention_2_dim is not None
238
+
239
+ self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
240
+
241
+ self.attn2_2 = Attention(
242
+ query_dim=dim,
243
+ cross_attention_dim=cross_attention_2_dim,
244
+ dim_head=dim // num_attention_heads,
245
+ heads=num_attention_heads,
246
+ qk_norm="rms_norm" if qk_norm else None,
247
+ cross_attention_norm=cross_attention_2_norm_type,
248
+ eps=1e-6,
249
+ bias=qkv_bias,
250
+ processor=TripoSGAttnProcessor2_0(),
251
+ )
252
+
253
+ # 3. Feed-forward
254
+ self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
255
+
256
+ self.ff = FeedForward(
257
+ dim,
258
+ dropout=dropout, ### 0.0
259
+ activation_fn=activation_fn, ### approx GeLU
260
+ final_dropout=final_dropout, ### 0.0
261
+ inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
262
+ bias=ff_bias,
263
+ )
264
+
265
+ # 4. Skip Connection
266
+ if skip:
267
+ self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
268
+ self.skip_linear = nn.Linear(2 * dim, dim)
269
+ else:
270
+ self.skip_linear = None
271
+
272
+ # 5. adaLN time embedding
273
+ self.adaln_modulation = nn.Sequential(
274
+ nn.SiLU(),
275
+ nn.Linear(dim, 9 * dim, bias=True)
276
+ )
277
+
278
+ # let chunk size default to None
279
+ self._chunk_size = None
280
+ self._chunk_dim = 0
281
+
282
+
283
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
284
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
285
+ # Sets chunk feed-forward
286
+ self._chunk_size = chunk_size
287
+ self._chunk_dim = dim
288
+
289
+ def forward(
290
+ self,
291
+ hidden_states: torch.Tensor,
292
+ encoder_hidden_states: Optional[torch.Tensor] = None,
293
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
294
+ temb: Optional[torch.Tensor] = None,
295
+ image_rotary_emb: Optional[torch.Tensor] = None,
296
+ skip: Optional[torch.Tensor] = None,
297
+ attention_kwargs: Optional[Dict[str, Any]] = None,
298
+ ) -> torch.Tensor:
299
+ # Prepare attention kwargs
300
+ attention_kwargs = attention_kwargs or {}
301
+
302
+ # Notice that normalization is always applied before the real computation in the following blocks.
303
+ # 0. Long Skip Connection
304
+ if self.skip_linear is not None:
305
+ cat = torch.cat(
306
+ (
307
+ [skip, hidden_states]
308
+ if self.skip_concat_front
309
+ else [hidden_states, skip]
310
+ ),
311
+ dim=-1,
312
+ )
313
+ if self.skip_norm_last:
314
+ # don't do this
315
+ hidden_states = self.skip_linear(cat)
316
+ hidden_states = self.skip_norm(hidden_states)
317
+ else:
318
+ cat = self.skip_norm(cat)
319
+ hidden_states = self.skip_linear(cat)
320
+
321
+ # 0. adaLN time embedding
322
+ shift_msa, scale_msa, gate_msa, shift_mca, scale_mca, gate_mca, shift_mlp, scale_mlp, gate_mlp = self.adaln_modulation(
323
+ temb
324
+ ).chunk(9, dim=-1)
325
+
326
+ # 1. Self-Attention
327
+ if self.use_self_attention:
328
+ norm_hidden_states = self.norm1(hidden_states) * (1 + scale_msa) + shift_msa
329
+ attn_output = self.attn1(
330
+ norm_hidden_states,
331
+ image_rotary_emb=image_rotary_emb,
332
+ **attention_kwargs,
333
+ )
334
+ hidden_states = hidden_states + gate_msa * attn_output
335
+
336
+ # 2. Cross-Attention
337
+ if self.use_cross_attention:
338
+ if self.use_cross_attention_2:
339
+ hidden_states = (
340
+ hidden_states
341
+ + self.attn2(
342
+ self.norm2(hidden_states),
343
+ encoder_hidden_states=encoder_hidden_states,
344
+ image_rotary_emb=image_rotary_emb,
345
+ **attention_kwargs,
346
+ )
347
+ + self.attn2_2(
348
+ self.norm2_2(hidden_states),
349
+ encoder_hidden_states=encoder_hidden_states_2,
350
+ image_rotary_emb=image_rotary_emb,
351
+ **attention_kwargs,
352
+ )
353
+ )
354
+ else:
355
+ hidden_states = hidden_states + gate_mca * self.attn2(
356
+ self.norm2(hidden_states) * (1 + scale_mca) + shift_mca,
357
+ encoder_hidden_states=encoder_hidden_states,
358
+ image_rotary_emb=image_rotary_emb,
359
+ **attention_kwargs,
360
+ )
361
+
362
+ # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
363
+ mlp_inputs = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
364
+ hidden_states = hidden_states + gate_mlp * self.ff(mlp_inputs)
365
+
366
+ return hidden_states
367
+
368
+
369
+ class DetailGen3DDiTModel(ModelMixin, ConfigMixin):
370
+ """
371
+ DetailGen3DDiT: Diffusion model with a Transformer backbone.
372
+
373
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
374
+
375
+ Parameters:
376
+ num_attention_heads (`int`, *optional*, defaults to 16):
377
+ The number of heads to use for multi-head attention.
378
+ attention_head_dim (`int`, *optional*, defaults to 88):
379
+ The number of channels in each head.
380
+ in_channels (`int`, *optional*):
381
+ The number of channels in the input and output (specify if the input is **continuous**).
382
+ patch_size (`int`, *optional*):
383
+ The size of the patch to use for the input.
384
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
385
+ Activation function to use in feed-forward.
386
+ sample_size (`int`, *optional*):
387
+ The width of the latent images. This is fixed during training since it is used to learn a number of
388
+ position embeddings.
389
+ dropout (`float`, *optional*, defaults to 0.0):
390
+ The dropout probability to use.
391
+ cross_attention_dim (`int`, *optional*):
392
+ The number of dimensions in the clip text embedding.
393
+ hidden_size (`int`, *optional*):
394
+ The size of hidden layer in the conditioning embedding layers.
395
+ num_layers (`int`, *optional*, defaults to 1):
396
+ The number of layers of Transformer blocks to use.
397
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
398
+ The ratio of the hidden layer size to the input size.
399
+ learn_sigma (`bool`, *optional*, defaults to `True`):
400
+ Whether to predict variance.
401
+ cross_attention_dim_t5 (`int`, *optional*):
402
+ The number of dimensions in the T5 text embedding.
403
+ pooled_projection_dim (`int`, *optional*):
404
+ The size of the pooled projection.
405
+ text_len (`int`, *optional*):
406
+ The length of the clip text embedding.
407
+ text_len_t5 (`int`, *optional*):
408
+ The length of the T5 text embedding.
409
+ use_style_cond_and_image_meta_size (`bool`, *optional*):
410
+ Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
411
+ """
412
+
413
+ _supports_gradient_checkpointing = True
414
+
415
+ @register_to_config
416
+ def __init__(
417
+ self,
418
+ num_attention_heads: int = 12,
419
+ width: int = 768,
420
+ in_channels: int = 64,
421
+ num_layers: int = 24,
422
+ cross_attention_dim: int = 1024,
423
+ ):
424
+ super().__init__()
425
+ self.out_channels = in_channels
426
+ self.num_heads = num_attention_heads
427
+ self.inner_dim = width
428
+ self.mlp_ratio = 4.0
429
+
430
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
431
+ "positional",
432
+ inner_dim=self.inner_dim,
433
+ flip_sin_to_cos=False,
434
+ freq_shift=0,
435
+ time_embedding_dim=None,
436
+ )
437
+ self.time_proj = TimestepEmbedding(
438
+ timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
439
+ )
440
+ self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
441
+
442
+ self.blocks = nn.ModuleList(
443
+ [
444
+ DiTBlock(
445
+ dim=self.inner_dim,
446
+ num_attention_heads=self.config.num_attention_heads,
447
+ use_self_attention=True,
448
+ use_cross_attention=True,
449
+ self_attention_norm_type="fp32_layer_norm",
450
+ cross_attention_dim=self.config.cross_attention_dim,
451
+ cross_attention_norm_type=None,
452
+ use_cross_attention_2=False,
453
+ cross_attention_2_norm_type=None,
454
+ activation_fn="gelu",
455
+ norm_type="fp32_layer_norm", # TODO
456
+ norm_eps=1e-5,
457
+ ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
458
+ qk_norm=False, # See http://arxiv.org/abs/2302.05442 for details.
459
+ qkv_bias=False,
460
+ )
461
+ for layer in range(num_layers)
462
+ ]
463
+ )
464
+
465
+ self.norm_out = LayerNorm(self.inner_dim)
466
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
467
+
468
+ self.gradient_checkpointing = False
469
+
470
+ def _set_gradient_checkpointing(self, module, value=False):
471
+ self.gradient_checkpointing = value
472
+
473
+ def _set_time_proj(
474
+ self,
475
+ time_embedding_type: str,
476
+ inner_dim: int,
477
+ flip_sin_to_cos: bool,
478
+ freq_shift: float,
479
+ time_embedding_dim: int,
480
+ ) -> Tuple[int, int]:
481
+ if time_embedding_type == "fourier":
482
+ time_embed_dim = time_embedding_dim or inner_dim * 2
483
+ if time_embed_dim % 2 != 0:
484
+ raise ValueError(
485
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
486
+ )
487
+ self.time_embed = GaussianFourierProjection(
488
+ time_embed_dim // 2,
489
+ set_W_to_weight=False,
490
+ log=False,
491
+ flip_sin_to_cos=flip_sin_to_cos,
492
+ )
493
+ timestep_input_dim = time_embed_dim
494
+ elif time_embedding_type == "positional":
495
+ time_embed_dim = time_embedding_dim or inner_dim * 4
496
+
497
+ self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
498
+ timestep_input_dim = inner_dim
499
+ else:
500
+ raise ValueError(
501
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
502
+ )
503
+
504
+ return time_embed_dim, timestep_input_dim
505
+
506
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
507
+ def fuse_qkv_projections(self):
508
+ """
509
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
510
+ are fused. For cross-attention modules, key and value projection matrices are fused.
511
+
512
+ <Tip warning={true}>
513
+
514
+ This API is 🧪 experimental.
515
+
516
+ </Tip>
517
+ """
518
+ self.original_attn_processors = None
519
+
520
+ for _, attn_processor in self.attn_processors.items():
521
+ if "Added" in str(attn_processor.__class__.__name__):
522
+ raise ValueError(
523
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
524
+ )
525
+
526
+ self.original_attn_processors = self.attn_processors
527
+
528
+ for module in self.modules():
529
+ if isinstance(module, Attention):
530
+ module.fuse_projections(fuse=True)
531
+
532
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
533
+
534
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
535
+ def unfuse_qkv_projections(self):
536
+ """Disables the fused QKV projection if enabled.
537
+
538
+ <Tip warning={true}>
539
+
540
+ This API is 🧪 experimental.
541
+
542
+ </Tip>
543
+
544
+ """
545
+ if self.original_attn_processors is not None:
546
+ self.set_attn_processor(self.original_attn_processors)
547
+
548
+ @property
549
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
550
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
551
+ r"""
552
+ Returns:
553
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
554
+ indexed by their weight names.
555
+ """
556
+ # set recursively
557
+ processors = {}
558
+
559
+ def fn_recursive_add_processors(
560
+ name: str,
561
+ module: torch.nn.Module,
562
+ processors: Dict[str, AttentionProcessor],
563
+ ):
564
+ if hasattr(module, "get_processor"):
565
+ processors[f"{name}.processor"] = module.get_processor()
566
+
567
+ for sub_name, child in module.named_children():
568
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
569
+
570
+ return processors
571
+
572
+ for name, module in self.named_children():
573
+ fn_recursive_add_processors(name, module, processors)
574
+
575
+ return processors
576
+
577
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
578
+ def set_attn_processor(
579
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
580
+ ):
581
+ r"""
582
+ Sets the attention processor to use to compute attention.
583
+
584
+ Parameters:
585
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
586
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
587
+ for **all** `Attention` layers.
588
+
589
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
590
+ processor. This is strongly recommended when setting trainable attention processors.
591
+
592
+ """
593
+ count = len(self.attn_processors.keys())
594
+
595
+ if isinstance(processor, dict) and len(processor) != count:
596
+ raise ValueError(
597
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
598
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
599
+ )
600
+
601
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
602
+ if hasattr(module, "set_processor"):
603
+ if not isinstance(processor, dict):
604
+ module.set_processor(processor)
605
+ else:
606
+ module.set_processor(processor.pop(f"{name}.processor"))
607
+
608
+ for sub_name, child in module.named_children():
609
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
610
+
611
+ for name, module in self.named_children():
612
+ fn_recursive_attn_processor(name, module, processor)
613
+
614
+ def set_default_attn_processor(self):
615
+ """
616
+ Disables custom attention processors and sets the default attention implementation.
617
+ """
618
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
619
+
620
+ def forward(
621
+ self,
622
+ hidden_states: Optional[torch.Tensor],
623
+ timestep: Union[int, float, torch.LongTensor],
624
+ encoder_hidden_states: Optional[torch.Tensor] = None,
625
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
626
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
627
+ attention_kwargs: Optional[Dict[str, Any]] = None,
628
+ return_dict: bool = True,
629
+ ):
630
+ """
631
+ The [`DetailGen3DDiTModel`] forward method.
632
+
633
+ Args:
634
+ hidden_states (`torch.Tensor` of shape `(batch_size, num_tokens, in_channels)`):
635
+ The input tensor.
636
+ timestep ( `torch.LongTensor`, *optional*):
637
+ Used to indicate denoising step.
638
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
639
+ Conditional embeddings for cross attention layer.
640
+ encoder_hidden_states_2 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
641
+ Conditional embeddings for cross attention layer.
642
+ return_dict: bool
643
+ Whether to return a dictionary.
644
+ """
645
+
646
+ if attention_kwargs is not None:
647
+ attention_kwargs = attention_kwargs.copy()
648
+ lora_scale = attention_kwargs.pop("scale", 1.0)
649
+ else:
650
+ lora_scale = 1.0
651
+
652
+ if USE_PEFT_BACKEND:
653
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
654
+ scale_lora_layers(self, lora_scale)
655
+ else:
656
+ if (
657
+ attention_kwargs is not None
658
+ and attention_kwargs.get("scale", None) is not None
659
+ ):
660
+ logger.warning(
661
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
662
+ )
663
+
664
+ _, N, _ = hidden_states.shape
665
+
666
+ temb = self.time_embed(timestep).to(hidden_states.dtype)
667
+ temb = self.time_proj(temb)
668
+ temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
669
+
670
+ hidden_states = self.proj_in(hidden_states)
671
+
672
+ skips = []
673
+ for layer, block in enumerate(self.blocks):
674
+ skip = None if layer <= self.config.num_layers // 2 else skips.pop()
675
+
676
+ if self.training and self.gradient_checkpointing:
677
+
678
+ def create_custom_forward(module):
679
+ def custom_forward(*inputs):
680
+ return module(*inputs)
681
+
682
+ return custom_forward
683
+
684
+ ckpt_kwargs: Dict[str, Any] = (
685
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
686
+ )
687
+ hidden_states = torch.utils.checkpoint.checkpoint(
688
+ create_custom_forward(block),
689
+ hidden_states,
690
+ encoder_hidden_states,
691
+ encoder_hidden_states_2,
692
+ temb,
693
+ image_rotary_emb,
694
+ skip,
695
+ attention_kwargs,
696
+ **ckpt_kwargs,
697
+ )
698
+ else:
699
+ hidden_states = block(
700
+ hidden_states,
701
+ encoder_hidden_states=encoder_hidden_states,
702
+ encoder_hidden_states_2=encoder_hidden_states_2,
703
+ temb=temb,
704
+ image_rotary_emb=image_rotary_emb,
705
+ skip=skip,
706
+ attention_kwargs=attention_kwargs,
707
+ ) # (N, L, D)
708
+
709
+ if layer < self.config.num_layers // 2:
710
+ skips.append(hidden_states)
711
+
712
+ # final layer
713
+ hidden_states = self.norm_out(hidden_states)
714
+ hidden_states = self.proj_out(hidden_states)
715
+
716
+ if USE_PEFT_BACKEND:
717
+ # remove `lora_scale` from each PEFT layer
718
+ unscale_lora_layers(self, lora_scale)
719
+
720
+ if not return_dict:
721
+ return (hidden_states,)
722
+
723
+ return Transformer1DModelOutput(sample=hidden_states)
724
+
725
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
726
+ def enable_forward_chunking(
727
+ self, chunk_size: Optional[int] = None, dim: int = 0
728
+ ) -> None:
729
+ """
730
+ Sets the attention processor to use [feed forward
731
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
732
+
733
+ Parameters:
734
+ chunk_size (`int`, *optional*):
735
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
736
+ over each tensor of dim=`dim`.
737
+ dim (`int`, *optional*, defaults to `0`):
738
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
739
+ or dim=1 (sequence length).
740
+ """
741
+ if dim not in [0, 1]:
742
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
743
+
744
+ # By default chunk size is 1
745
+ chunk_size = chunk_size or 1
746
+
747
+ def fn_recursive_feed_forward(
748
+ module: torch.nn.Module, chunk_size: int, dim: int
749
+ ):
750
+ if hasattr(module, "set_chunk_feed_forward"):
751
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
752
+
753
+ for child in module.children():
754
+ fn_recursive_feed_forward(child, chunk_size, dim)
755
+
756
+ for module in self.children():
757
+ fn_recursive_feed_forward(module, chunk_size, dim)
758
+
759
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
760
+ def disable_forward_chunking(self):
761
+ def fn_recursive_feed_forward(
762
+ module: torch.nn.Module, chunk_size: int, dim: int
763
+ ):
764
+ if hasattr(module, "set_chunk_feed_forward"):
765
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
766
+
767
+ for child in module.children():
768
+ fn_recursive_feed_forward(child, chunk_size, dim)
769
+
770
+ for module in self.children():
771
+ fn_recursive_feed_forward(module, None, 0)
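A minimal forward-pass sketch for the model above (not part of the commit). Shapes follow the registered defaults: 64 latent channels per token and 1024-dim conditioning tokens; the conditioning token count is illustrative.

import torch

model = DetailGen3DDiTModel()        # num_attention_heads=12, width=768, in_channels=64,
                                     # num_layers=24, cross_attention_dim=1024
latents = torch.randn(1, 512, 64)    # (batch, num_tokens, in_channels)
cond = torch.randn(1, 257, 1024)     # conditioning tokens (length is illustrative)
timestep = torch.tensor([999])

with torch.no_grad():
    out = model(latents, timestep=timestep, encoder_hidden_states=cond).sample

print(out.shape)                     # torch.Size([1, 512, 64]) -- same layout as the input latents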
detailgen3d/models/transformers/modeling_outputs.py ADDED
@@ -0,0 +1,8 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
+ @dataclass
7
+ class Transformer1DModelOutput:
8
+ sample: torch.FloatTensor
detailgen3d/models/transformers/triposg_transformer.py ADDED
@@ -0,0 +1,726 @@
1
+ # Copyright (c) 2025 VAST-AI-Research and contributors
2
+
3
+ # This code is based on Tencent HunyuanDiT (https://huggingface.co/Tencent-Hunyuan/HunyuanDiT),
4
+ # which is licensed under the Tencent Hunyuan Community License Agreement.
5
+ # Portions of this code are copied or adapted from HunyuanDiT.
6
+ # See the original license below:
7
+
8
+ # ---- Start of Tencent Hunyuan Community License Agreement ----
9
+
10
+ # TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
11
+ # Tencent Hunyuan DiT Release Date: 14 May 2024
12
+ # THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
13
+ # By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
14
+ # 1. DEFINITIONS.
15
+ # a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
16
+ # b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
17
+ # c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
18
+ # d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
19
+ # e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
20
+ # f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
21
+ # g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
22
+ # h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
23
+ # i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
24
+ # j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan DiT released at https://huggingface.co/Tencent-Hunyuan/HunyuanDiT.
25
+ # k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
26
+ # l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
27
+ # m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
28
+ # n. “including” shall mean including but not limited to.
29
+ # 2. GRANT OF RIGHTS.
30
+ # We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
31
+ # 3. DISTRIBUTION.
32
+ # You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
33
+ # a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
34
+ # b. You must cause any modified files to carry prominent notices stating that You changed the files;
35
+ # c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
36
+ # d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
37
+ # You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
38
+ # 4. ADDITIONAL COMMERCIAL TERMS.
39
+ # If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
40
+ # 5. RULES OF USE.
41
+ # a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
42
+ # b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof).
43
+ # c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
44
+ # 6. INTELLECTUAL PROPERTY.
45
+ # a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
46
+ # b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
47
+ # c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
48
+ # d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
49
+ # 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
50
+ # a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
51
+ # b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
52
+ # c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
53
+ # 8. SURVIVAL AND TERMINATION.
54
+ # a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
55
+ # b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
56
+ # 9. GOVERNING LAW AND JURISDICTION.
57
+ # a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
58
+ # b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
59
+ #
60
+ # EXHIBIT A
61
+ # ACCEPTABLE USE POLICY
62
+
63
+ # Tencent reserves the right to update this Acceptable Use Policy from time to time.
64
+ # Last modified: [insert date]
65
+
66
+ # Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
67
+ # 1. Outside the Territory;
68
+ # 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
69
+ # 3. To harm Yourself or others;
70
+ # 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
71
+ # 5. To override or circumvent the safety guardrails and safeguards We have put in place;
72
+ # 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
73
+ # 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
74
+ # 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
75
+ # 9. To intentionally defame, disparage or otherwise harass others;
76
+ # 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
77
+ # 11. To generate or disseminate personal identifiable information with the purpose of harming others;
78
+ # 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
79
+ # 13. To impersonate another individual without consent, authorization, or legal right;
80
+ # 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
81
+ # 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
82
+ # 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
83
+ # 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
84
+ # 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
85
+ # 19. For military purposes;
86
+ # 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
87
+
88
+ # ---- End of Tencent Hunyuan Community License Agreement ----
89
+
90
+ # Please note that the use of this code is subject to the terms and conditions
91
+ # of the Tencent Hunyuan Community License Agreement, including the Acceptable Use Policy.
92
+
93
+ from typing import Any, Dict, Optional, Tuple, Union
94
+
95
+ import torch
96
+ import torch.utils.checkpoint
97
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
98
+ from diffusers.loaders import PeftAdapterMixin
99
+ from diffusers.models.attention import FeedForward
100
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
101
+ from diffusers.models.embeddings import (
102
+ GaussianFourierProjection,
103
+ TimestepEmbedding,
104
+ Timesteps,
105
+ )
106
+ from diffusers.models.modeling_utils import ModelMixin
107
+ from diffusers.models.normalization import (
108
+ AdaLayerNormContinuous,
109
+ FP32LayerNorm,
110
+ LayerNorm,
111
+ )
112
+ from diffusers.utils import (
113
+ USE_PEFT_BACKEND,
114
+ is_torch_version,
115
+ logging,
116
+ scale_lora_layers,
117
+ unscale_lora_layers,
118
+ )
119
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
120
+ from torch import nn
121
+
122
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
123
+ from .modeling_outputs import Transformer1DModelOutput
124
+
125
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
126
+
127
+
128
+ @maybe_allow_in_graph
129
+ class DiTBlock(nn.Module):
130
+ r"""
131
+ Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
132
+ QKNorm
133
+
134
+ Parameters:
135
+ dim (`int`):
136
+ The number of channels in the input and output.
137
+ num_attention_heads (`int`):
138
+ The number of heads to use for multi-head attention.
139
+ cross_attention_dim (`int`, *optional*):
140
+ The size of the encoder_hidden_states vector for cross attention.
141
+ dropout (`float`, *optional*, defaults to 0.0):
142
+ The dropout probability to use.
143
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
144
+ Activation function to be used in feed-forward.
145
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
146
+ Whether to use learnable elementwise affine parameters for normalization.
147
+ norm_eps (`float`, *optional*, defaults to 1e-6):
148
+ A small constant added to the denominator in normalization layers to prevent division by zero.
149
+ final_dropout (`bool`, *optional*, defaults to `False`):
150
+ Whether to apply a final dropout after the last feed-forward layer.
151
+ ff_inner_dim (`int`, *optional*):
152
+ The size of the hidden layer in the feed-forward block. Defaults to `None`.
153
+ ff_bias (`bool`, *optional*, defaults to `True`):
154
+ Whether to use bias in the feed-forward block.
155
+ skip (`bool`, *optional*, defaults to `False`):
156
+ Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
157
+ qk_norm (`bool`, *optional*, defaults to `True`):
158
+ Whether to use normalization in QK calculation. Defaults to `True`.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ num_attention_heads: int,
165
+ use_self_attention: bool = True,
166
+ self_attention_norm_type: Optional[str] = None,
167
+ use_cross_attention: bool = True, # ada layer norm
168
+ cross_attention_dim: Optional[int] = None,
169
+ cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
170
+ dropout=0.0,
171
+ activation_fn: str = "gelu",
172
+ norm_type: str = "fp32_layer_norm", # TODO
173
+ norm_elementwise_affine: bool = True,
174
+ norm_eps: float = 1e-5,
175
+ final_dropout: bool = False,
176
+ ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
177
+ ff_bias: bool = True,
178
+ skip: bool = False,
179
+ skip_concat_front: bool = False, # [x, skip] or [skip, x]
180
+ skip_norm_last: bool = False, # this is an error
181
+ qk_norm: bool = True,
182
+ qkv_bias: bool = True,
183
+ ):
184
+ super().__init__()
185
+
186
+ self.use_self_attention = use_self_attention
187
+ self.use_cross_attention = use_cross_attention
188
+ self.skip_concat_front = skip_concat_front
189
+ self.skip_norm_last = skip_norm_last
190
+ # Define 3 blocks. Each block has its own normalization layer.
191
+ # NOTE: when new version comes, check norm2 and norm 3
192
+ # 1. Self-Attn
193
+ if use_self_attention:
194
+ if (
195
+ self_attention_norm_type == "fp32_layer_norm"
196
+ or self_attention_norm_type is None
197
+ ):
198
+ self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
199
+ else:
200
+ raise NotImplementedError
201
+
202
+ self.attn1 = Attention(
203
+ query_dim=dim,
204
+ cross_attention_dim=None,
205
+ dim_head=dim // num_attention_heads,
206
+ heads=num_attention_heads,
207
+ qk_norm="rms_norm" if qk_norm else None,
208
+ eps=1e-6,
209
+ bias=qkv_bias,
210
+ processor=TripoSGAttnProcessor2_0(),
211
+ )
212
+
213
+ # 2. Cross-Attn
214
+ if use_cross_attention:
215
+ assert cross_attention_dim is not None
216
+
217
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
218
+
219
+ self.attn2 = Attention(
220
+ query_dim=dim,
221
+ cross_attention_dim=cross_attention_dim,
222
+ dim_head=dim // num_attention_heads,
223
+ heads=num_attention_heads,
224
+ qk_norm="rms_norm" if qk_norm else None,
225
+ cross_attention_norm=cross_attention_norm_type,
226
+ eps=1e-6,
227
+ bias=qkv_bias,
228
+ processor=TripoSGAttnProcessor2_0(),
229
+ )
230
+
231
+ # 3. Feed-forward
232
+ self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
233
+
234
+ self.ff = FeedForward(
235
+ dim,
236
+ dropout=dropout, ### 0.0
237
+ activation_fn=activation_fn, ### approx GeLU
238
+ final_dropout=final_dropout, ### 0.0
239
+ inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
240
+ bias=ff_bias,
241
+ )
242
+
243
+ # 4. Skip Connection
244
+ if skip:
245
+ self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
246
+ self.skip_linear = nn.Linear(2 * dim, dim)
247
+ else:
248
+ self.skip_linear = None
249
+
250
+ # let chunk size default to None
251
+ self._chunk_size = None
252
+ self._chunk_dim = 0
253
+
254
+ def set_topk(self, topk):
255
+ self.flash_processor.topk = topk
256
+
257
+ def set_flash_processor(self, flash_processor):
258
+ self.flash_processor = flash_processor
259
+ self.attn2.processor = self.flash_processor
260
+
261
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
262
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
263
+ # Sets chunk feed-forward
264
+ self._chunk_size = chunk_size
265
+ self._chunk_dim = dim
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states: torch.Tensor,
270
+ encoder_hidden_states: Optional[torch.Tensor] = None,
271
+ temb: Optional[torch.Tensor] = None,
272
+ image_rotary_emb: Optional[torch.Tensor] = None,
273
+ skip: Optional[torch.Tensor] = None,
274
+ attention_kwargs: Optional[Dict[str, Any]] = None,
275
+ ) -> torch.Tensor:
276
+ # Prepare attention kwargs
277
+ attention_kwargs = attention_kwargs or {}
278
+
279
+ # Notice that normalization is always applied before the real computation in the following blocks.
280
+ # 0. Long Skip Connection
281
+ if self.skip_linear is not None:
282
+ cat = torch.cat(
283
+ (
284
+ [skip, hidden_states]
285
+ if self.skip_concat_front
286
+ else [hidden_states, skip]
287
+ ),
288
+ dim=-1,
289
+ )
290
+ if self.skip_norm_last:
291
+ # don't do this
292
+ hidden_states = self.skip_linear(cat)
293
+ hidden_states = self.skip_norm(hidden_states)
294
+ else:
295
+ cat = self.skip_norm(cat)
296
+ hidden_states = self.skip_linear(cat)
297
+
298
+ # 1. Self-Attention
299
+ if self.use_self_attention:
300
+ norm_hidden_states = self.norm1(hidden_states)
301
+ attn_output = self.attn1(
302
+ norm_hidden_states,
303
+ image_rotary_emb=image_rotary_emb,
304
+ **attention_kwargs,
305
+ )
306
+ hidden_states = hidden_states + attn_output
307
+
308
+ # 2. Cross-Attention
309
+ if self.use_cross_attention:
310
+ hidden_states = hidden_states + self.attn2(
311
+ self.norm2(hidden_states),
312
+ encoder_hidden_states=encoder_hidden_states,
313
+ image_rotary_emb=image_rotary_emb,
314
+ **attention_kwargs,
315
+ )
316
+
317
+ # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
318
+ mlp_inputs = self.norm3(hidden_states)
319
+ hidden_states = hidden_states + self.ff(mlp_inputs)
320
+
321
+ return hidden_states
322
+
323
+
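For reference, the long-skip path handled at the top of `DiTBlock.forward` is just a concatenation of the current hidden states with a stored skip tensor, followed by a normalization and a linear projection back to the model width. A minimal stand-alone sketch of that fusion with plain torch modules (module choices and shapes here are illustrative, not taken from this repository):

import torch
from torch import nn

dim = 2048                                # model width used by TripoSGDiTModel
skip_norm = nn.LayerNorm(2 * dim)         # stands in for FP32LayerNorm in the real block
skip_linear = nn.Linear(2 * dim, dim)

hidden_states = torch.randn(1, 513, dim)  # N latent tokens + 1 time token
skip = torch.randn(1, 513, dim)           # activation stored by a mirrored early block

# skip_concat_front=True puts the stored skip tensor first, as in the blocks above
cat = torch.cat([skip, hidden_states], dim=-1)
fused = skip_linear(skip_norm(cat))       # normalize first, then project (the skip_norm_last=False path)
print(fused.shape)                        # torch.Size([1, 513, 2048])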
324
+ class TripoSGDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
325
+ """
326
+ TripoSG: Diffusion model with a Transformer backbone.
327
+
328
+ Inherits from ModelMixin and ConfigMixin so that it is compatible with diffusers pipelines such as StableDiffusionPipeline.
329
+
330
+ Parameters:
331
+ num_attention_heads (`int`, *optional*, defaults to 16):
332
+ The number of heads to use for multi-head attention.
333
+ width (`int`, *optional*, defaults to 2048):
334
+ The inner dimension (hidden size) of the transformer tokens.
335
+ in_channels (`int`, *optional*, defaults to 64):
336
+ The number of channels in the input and output latents.
337
+ num_layers (`int`, *optional*, defaults to 21):
338
+ The number of Transformer blocks to use.
339
+ cross_attention_dim (`int`, *optional*, defaults to 1024):
340
+ The dimension of the conditioning (image) embeddings used for cross attention.
366
+ """
367
+
368
+ _supports_gradient_checkpointing = True
369
+
370
+ @register_to_config
371
+ def __init__(
372
+ self,
373
+ num_attention_heads: int = 16,
374
+ width: int = 2048,
375
+ in_channels: int = 64,
376
+ num_layers: int = 21,
377
+ cross_attention_dim: int = 1024,
378
+ ):
379
+ super().__init__()
380
+ self.out_channels = in_channels
381
+ self.num_heads = num_attention_heads
382
+ self.inner_dim = width
383
+ self.mlp_ratio = 4.0
384
+
385
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
386
+ "positional",
387
+ inner_dim=self.inner_dim,
388
+ flip_sin_to_cos=False,
389
+ freq_shift=0,
390
+ time_embedding_dim=None,
391
+ )
392
+ self.time_proj = TimestepEmbedding(
393
+ timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
394
+ )
395
+ self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
396
+
397
+ self.blocks = nn.ModuleList(
398
+ [
399
+ DiTBlock(
400
+ dim=self.inner_dim,
401
+ num_attention_heads=self.config.num_attention_heads,
402
+ use_self_attention=True,
403
+ self_attention_norm_type="fp32_layer_norm",
404
+ use_cross_attention=True,
405
+ cross_attention_dim=cross_attention_dim,
406
+ cross_attention_norm_type=None,
407
+ activation_fn="gelu",
408
+ norm_type="fp32_layer_norm", # TODO
409
+ norm_eps=1e-5,
410
+ ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
411
+ skip=layer > num_layers // 2,
412
+ skip_concat_front=True,
413
+ skip_norm_last=True, # this is an error
414
+ qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
415
+ qkv_bias=False,
416
+ )
417
+ for layer in range(num_layers)
418
+ ]
419
+ )
420
+
421
+ self.norm_out = LayerNorm(self.inner_dim)
422
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
423
+
424
+ self.gradient_checkpointing = False
425
+
426
+ def _set_gradient_checkpointing(self, module, value=False):
427
+ self.gradient_checkpointing = value
428
+
429
+ def _set_time_proj(
430
+ self,
431
+ time_embedding_type: str,
432
+ inner_dim: int,
433
+ flip_sin_to_cos: bool,
434
+ freq_shift: float,
435
+ time_embedding_dim: Optional[int],
436
+ ) -> Tuple[int, int]:
437
+ if time_embedding_type == "fourier":
438
+ time_embed_dim = time_embedding_dim or inner_dim * 2
439
+ if time_embed_dim % 2 != 0:
440
+ raise ValueError(
441
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
442
+ )
443
+ self.time_embed = GaussianFourierProjection(
444
+ time_embed_dim // 2,
445
+ set_W_to_weight=False,
446
+ log=False,
447
+ flip_sin_to_cos=flip_sin_to_cos,
448
+ )
449
+ timestep_input_dim = time_embed_dim
450
+ elif time_embedding_type == "positional":
451
+ time_embed_dim = time_embedding_dim or inner_dim * 4
452
+
453
+ self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
454
+ timestep_input_dim = inner_dim
455
+ else:
456
+ raise ValueError(
457
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
458
+ )
459
+
460
+ return time_embed_dim, timestep_input_dim
461
+
462
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
463
+ def fuse_qkv_projections(self):
464
+ """
465
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
466
+ are fused. For cross-attention modules, key and value projection matrices are fused.
467
+
468
+ <Tip warning={true}>
469
+
470
+ This API is 🧪 experimental.
471
+
472
+ </Tip>
473
+ """
474
+ self.original_attn_processors = None
475
+
476
+ for _, attn_processor in self.attn_processors.items():
477
+ if "Added" in str(attn_processor.__class__.__name__):
478
+ raise ValueError(
479
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
480
+ )
481
+
482
+ self.original_attn_processors = self.attn_processors
483
+
484
+ for module in self.modules():
485
+ if isinstance(module, Attention):
486
+ module.fuse_projections(fuse=True)
487
+
488
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
489
+
490
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
491
+ def unfuse_qkv_projections(self):
492
+ """Disables the fused QKV projection if enabled.
493
+
494
+ <Tip warning={true}>
495
+
496
+ This API is 🧪 experimental.
497
+
498
+ </Tip>
499
+
500
+ """
501
+ if self.original_attn_processors is not None:
502
+ self.set_attn_processor(self.original_attn_processors)
503
+
504
+ @property
505
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
506
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
507
+ r"""
508
+ Returns:
509
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
510
+ indexed by their weight names.
511
+ """
512
+ # set recursively
513
+ processors = {}
514
+
515
+ def fn_recursive_add_processors(
516
+ name: str,
517
+ module: torch.nn.Module,
518
+ processors: Dict[str, AttentionProcessor],
519
+ ):
520
+ if hasattr(module, "get_processor"):
521
+ processors[f"{name}.processor"] = module.get_processor()
522
+
523
+ for sub_name, child in module.named_children():
524
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
525
+
526
+ return processors
527
+
528
+ for name, module in self.named_children():
529
+ fn_recursive_add_processors(name, module, processors)
530
+
531
+ return processors
532
+
533
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
534
+ def set_attn_processor(
535
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
536
+ ):
537
+ r"""
538
+ Sets the attention processor to use to compute attention.
539
+
540
+ Parameters:
541
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
542
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
543
+ for **all** `Attention` layers.
544
+
545
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
546
+ processor. This is strongly recommended when setting trainable attention processors.
547
+
548
+ """
549
+ count = len(self.attn_processors.keys())
550
+
551
+ if isinstance(processor, dict) and len(processor) != count:
552
+ raise ValueError(
553
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
554
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
555
+ )
556
+
557
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
558
+ if hasattr(module, "set_processor"):
559
+ if not isinstance(processor, dict):
560
+ module.set_processor(processor)
561
+ else:
562
+ module.set_processor(processor.pop(f"{name}.processor"))
563
+
564
+ for sub_name, child in module.named_children():
565
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
566
+
567
+ for name, module in self.named_children():
568
+ fn_recursive_attn_processor(name, module, processor)
569
+
570
+ def set_default_attn_processor(self):
571
+ """
572
+ Disables custom attention processors and sets the default attention implementation.
573
+ """
574
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
575
+
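As a hedged usage sketch of the processor API above: a single processor instance can be broadcast to all `Attention` layers, or a dict keyed by the `.processor` paths returned by `attn_processors` can assign them per layer (the existence of an instantiated `model` is assumed):

# `model` is assumed to be an instantiated TripoSGDiTModel from this repository.
from detailgen3d.models.attention_processor import TripoSGAttnProcessor2_0

# A single processor instance is broadcast to every Attention layer.
model.set_attn_processor(TripoSGAttnProcessor2_0())

# Alternatively, pass a dict keyed exactly like `model.attn_processors`
# to assign a (possibly different) processor per layer.
per_layer = {name: TripoSGAttnProcessor2_0() for name in model.attn_processors.keys()}
model.set_attn_processor(per_layer)

# Restore the default processor everywhere.
model.set_default_attn_processor()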
576
+ def forward(
577
+ self,
578
+ hidden_states: Optional[torch.Tensor],
579
+ timestep: Union[int, float, torch.LongTensor],
580
+ encoder_hidden_states: Optional[torch.Tensor] = None,
581
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
582
+ attention_kwargs: Optional[Dict[str, Any]] = None,
583
+ return_dict: bool = True,
584
+ ):
585
+ """
586
+ The [`TripoSGDiTModel`] forward method.
587
+
588
+ Args:
589
+ hidden_states (`torch.Tensor` of shape `(batch_size, num_tokens, in_channels)`):
590
+ The input tensor.
591
+ timestep ( `torch.LongTensor`, *optional*):
592
+ Used to indicate denoising step.
593
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
594
+ Conditional embeddings for cross attention layer.
595
+ return_dict: bool
596
+ Whether to return a dictionary.
597
+ """
598
+
599
+ if attention_kwargs is not None:
600
+ attention_kwargs = attention_kwargs.copy()
601
+ lora_scale = attention_kwargs.pop("scale", 1.0)
602
+ else:
603
+ lora_scale = 1.0
604
+
605
+ if USE_PEFT_BACKEND:
606
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
607
+ scale_lora_layers(self, lora_scale)
608
+ else:
609
+ if (
610
+ attention_kwargs is not None
611
+ and attention_kwargs.get("scale", None) is not None
612
+ ):
613
+ logger.warning(
614
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
615
+ )
616
+
617
+ _, N, _ = hidden_states.shape
618
+
619
+ temb = self.time_embed(timestep).to(hidden_states.dtype)
620
+ temb = self.time_proj(temb)
621
+ temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
622
+
623
+ hidden_states = self.proj_in(hidden_states)
624
+
625
+ # N + 1 token
626
+ hidden_states = torch.cat([temb, hidden_states], dim=1)
627
+
628
+ skips = []
629
+ for layer, block in enumerate(self.blocks):
630
+ skip = None if layer <= self.config.num_layers // 2 else skips.pop()
631
+
632
+ if self.training and self.gradient_checkpointing:
633
+
634
+ def create_custom_forward(module):
635
+ def custom_forward(*inputs):
636
+ return module(*inputs)
637
+
638
+ return custom_forward
639
+
640
+ ckpt_kwargs: Dict[str, Any] = (
641
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
642
+ )
643
+ hidden_states = torch.utils.checkpoint.checkpoint(
644
+ create_custom_forward(block),
645
+ hidden_states,
646
+ encoder_hidden_states,
647
+ temb,
648
+ image_rotary_emb,
649
+ skip,
650
+ attention_kwargs,
651
+ **ckpt_kwargs,
652
+ )
653
+ else:
654
+ hidden_states = block(
655
+ hidden_states,
656
+ encoder_hidden_states=encoder_hidden_states,
657
+ temb=temb,
658
+ image_rotary_emb=image_rotary_emb,
659
+ skip=skip,
660
+ attention_kwargs=attention_kwargs,
661
+ ) # (N, L, D)
662
+
663
+ if layer < self.config.num_layers // 2:
664
+ skips.append(hidden_states)
665
+
666
+ # final layer
667
+ hidden_states = self.norm_out(hidden_states)
668
+ hidden_states = hidden_states[:, -N:]
669
+ hidden_states = self.proj_out(hidden_states)
670
+
671
+ if USE_PEFT_BACKEND:
672
+ # remove `lora_scale` from each PEFT layer
673
+ unscale_lora_layers(self, lora_scale)
674
+
675
+ if not return_dict:
676
+ return (hidden_states,)
677
+
678
+ return Transformer1DModelOutput(sample=hidden_states)
679
+
680
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
681
+ def enable_forward_chunking(
682
+ self, chunk_size: Optional[int] = None, dim: int = 0
683
+ ) -> None:
684
+ """
685
+ Sets the attention processor to use [feed forward
686
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
687
+
688
+ Parameters:
689
+ chunk_size (`int`, *optional*):
690
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
691
+ over each tensor of dim=`dim`.
692
+ dim (`int`, *optional*, defaults to `0`):
693
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
694
+ or dim=1 (sequence length).
695
+ """
696
+ if dim not in [0, 1]:
697
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
698
+
699
+ # By default chunk size is 1
700
+ chunk_size = chunk_size or 1
701
+
702
+ def fn_recursive_feed_forward(
703
+ module: torch.nn.Module, chunk_size: int, dim: int
704
+ ):
705
+ if hasattr(module, "set_chunk_feed_forward"):
706
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
707
+
708
+ for child in module.children():
709
+ fn_recursive_feed_forward(child, chunk_size, dim)
710
+
711
+ for module in self.children():
712
+ fn_recursive_feed_forward(module, chunk_size, dim)
713
+
714
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
715
+ def disable_forward_chunking(self):
716
+ def fn_recursive_feed_forward(
717
+ module: torch.nn.Module, chunk_size: int, dim: int
718
+ ):
719
+ if hasattr(module, "set_chunk_feed_forward"):
720
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
721
+
722
+ for child in module.children():
723
+ fn_recursive_feed_forward(child, chunk_size, dim)
724
+
725
+ for module in self.children():
726
+ fn_recursive_feed_forward(module, None, 0)
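A rough smoke-test sketch of the transformer defined above. The configuration is deliberately small for a quick CPU run (the defaults in `__init__` are width=2048 and num_layers=21), and the tensor shapes are illustrative; whether it runs in a given environment also depends on the sibling attention/embedding modules not shown here:

import torch

from detailgen3d.models.transformers.triposg_transformer import TripoSGDiTModel

# Small configuration for a quick check; not the released defaults.
model = TripoSGDiTModel(
    num_attention_heads=4, width=256, in_channels=64, num_layers=5, cross_attention_dim=1024
).eval()

latents = torch.randn(1, 512, 64)    # (batch, num_tokens, in_channels)
timestep = torch.tensor([500.0])
cond = torch.randn(1, 257, 1024)     # e.g. DINOv2-style tokens, matching cross_attention_dim

with torch.no_grad():
    out = model(latents, timestep, encoder_hidden_states=cond).sample
print(out.shape)                     # (1, 512, 64): same token count and channels as the input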
detailgen3d/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .pipeline_detailgen3d import DetailGen3DPipeline
detailgen3d/pipelines/pipeline_detailgen3d.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import PIL.Image
8
+ import torch
9
+ from diffusers.image_processor import PipelineImageInput
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler # not sure
12
+ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
13
+ from diffusers.utils import logging
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from transformers import (
16
+ BitImageProcessor,
17
+ CLIPImageProcessor,
18
+ CLIPVisionModelWithProjection,
19
+ Dinov2Model,
20
+ )
21
+
22
+ from ..models.autoencoders import TripoSGVAEModel
23
+ from ..models.transformers import DetailGen3DDiTModel
24
+ from .pipeline_detailgen3d_output import DetailGen3DPipelineOutput
25
+ from .pipeline_utils import TransformerDiffusionMixin
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
31
+ def retrieve_timesteps(
32
+ scheduler,
33
+ num_inference_steps: Optional[int] = None,
34
+ device: Optional[Union[str, torch.device]] = None,
35
+ timesteps: Optional[List[int]] = None,
36
+ sigmas: Optional[List[float]] = None,
37
+ **kwargs,
38
+ ):
39
+ """
40
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
41
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
42
+
43
+ Args:
44
+ scheduler (`SchedulerMixin`):
45
+ The scheduler to get timesteps from.
46
+ num_inference_steps (`int`):
47
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
48
+ must be `None`.
49
+ device (`str` or `torch.device`, *optional*):
50
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
51
+ timesteps (`List[int]`, *optional*):
52
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
53
+ `num_inference_steps` and `sigmas` must be `None`.
54
+ sigmas (`List[float]`, *optional*):
55
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
56
+ `num_inference_steps` and `timesteps` must be `None`.
57
+
58
+ Returns:
59
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
60
+ second element is the number of inference steps.
61
+ """
62
+ if timesteps is not None and sigmas is not None:
63
+ raise ValueError(
64
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
65
+ )
66
+ if timesteps is not None:
67
+ accepts_timesteps = "timesteps" in set(
68
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
69
+ )
70
+ if not accepts_timesteps:
71
+ raise ValueError(
72
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
73
+ f" timestep schedules. Please check whether you are using the correct scheduler."
74
+ )
75
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
76
+ timesteps = scheduler.timesteps
77
+ num_inference_steps = len(timesteps)
78
+ elif sigmas is not None:
79
+ accept_sigmas = "sigmas" in set(
80
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
81
+ )
82
+ if not accept_sigmas:
83
+ raise ValueError(
84
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
85
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
86
+ )
87
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
88
+ timesteps = scheduler.timesteps
89
+ num_inference_steps = len(timesteps)
90
+ else:
91
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
92
+ timesteps = scheduler.timesteps
93
+ return timesteps, num_inference_steps
94
+
95
+
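A small usage sketch of the helper above, paired with the flow-matching scheduler this pipeline registers (values are illustrative):

import torch
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

from detailgen3d.pipelines.pipeline_detailgen3d import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.0)

# Default spacing: the scheduler builds the 10-step schedule itself.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=10, device="cpu")
print(num_steps, timesteps)

# Custom `timesteps` or `sigmas` lists can be passed instead, provided the
# scheduler's `set_timesteps` accepts them (checked via its signature above).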
96
+ class DetailGen3DPipeline(
97
+ DiffusionPipeline, TransformerDiffusionMixin
98
+ ):
99
+ """
100
+ Pipeline for detail generation using DetailGen3D.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ vae: TripoSGVAEModel,
106
+ transformer: DetailGen3DDiTModel,
107
+ scheduler: FlowMatchEulerDiscreteScheduler,
108
+ noise_scheduler: DDPMScheduler,
109
+ image_encoder_1: Dinov2Model,
110
+ feature_extractor_1: BitImageProcessor,
111
+ ):
112
+ super().__init__()
113
+
114
+ self.register_modules(
115
+ vae=vae,
116
+ transformer=transformer,
117
+ scheduler=scheduler,
118
+ noise_scheduler=noise_scheduler,
119
+ image_encoder_1=image_encoder_1,
120
+ feature_extractor_1=feature_extractor_1,
121
+ )
122
+
123
+ @property
124
+ def guidance_scale(self):
125
+ return self._guidance_scale
126
+
127
+ @property
128
+ def do_classifier_free_guidance(self):
129
+ return self._guidance_scale > 1
130
+
131
+ @property
132
+ def num_timesteps(self):
133
+ return self._num_timesteps
134
+
135
+ @property
136
+ def attention_kwargs(self):
137
+ return self._attention_kwargs
138
+
139
+ @property
140
+ def interrupt(self):
141
+ return self._interrupt
142
+
143
+ def encode_image_1(self, image, device, num_images_per_prompt):
144
+ dtype = next(self.image_encoder_1.parameters()).dtype
145
+
146
+ if not isinstance(image, torch.Tensor):
147
+ image = self.feature_extractor_1(image, return_tensors="pt").pixel_values
148
+
149
+ image = image.to(device=device, dtype=dtype)
150
+ image_embeds = self.image_encoder_1(image).last_hidden_state
151
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
152
+ uncond_image_embeds = torch.zeros_like(image_embeds)
153
+
154
+ return image_embeds, uncond_image_embeds
155
+
156
+ def prepare_latents(
157
+ self,
158
+ batch_size,
159
+ num_tokens,
160
+ num_channels_latents,
161
+ dtype,
162
+ device,
163
+ generator,
164
+ latents: Optional[torch.Tensor] = None,
165
+ noise_aug_level = 0,
166
+ ):
167
+ if latents is not None:
168
+ latents = latents.to(device=device, dtype=dtype)
169
+ latents = self.noise_scheduler.add_noise(latents, torch.randn_like(latents), torch.tensor(noise_aug_level))
170
+ return latents
171
+
172
+ raise Exception(
173
+ f"You have to pass latents of geometry you want to refine."
174
+ )
175
+
176
+ @torch.no_grad()
177
+ def __call__(
178
+ self,
179
+ image: PipelineImageInput,
180
+ image_2: Optional[PipelineImageInput] = None,
181
+ num_inference_steps: int = 10,
182
+ timesteps: List[int] = None,
183
+ guidance_scale: float = 4.0,
184
+ num_images_per_prompt: int = 1,
185
+ sampled_points: Optional[torch.Tensor] = None,
186
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
187
+ latents: Optional[torch.FloatTensor] = None,
188
+ attention_kwargs: Optional[Dict[str, Any]] = None,
189
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
190
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
191
+ output_type: Optional[str] = "mesh_vf",
192
+ return_dict: bool = True,
193
+ noise_aug_level = 0,
194
+ ):
195
+ # 1. Check inputs. Raise error if not correct
196
+ # TODO
197
+
198
+ self._guidance_scale = guidance_scale
199
+ self._attention_kwargs = attention_kwargs
200
+ self._interrupt = False
201
+
202
+ # 2. Define call parameters
203
+ if isinstance(image, PIL.Image.Image):
204
+ batch_size = 1
205
+ elif isinstance(image, list):
206
+ batch_size = len(image)
207
+ elif isinstance(image, torch.Tensor):
208
+ batch_size = image.shape[0]
209
+ else:
210
+ raise ValueError("Invalid input type for image")
211
+
212
+ device = self._execution_device
213
+
214
+ # 3. Encode condition
215
+ image_embeds_1, negative_image_embeds_1 = self.encode_image_1(
216
+ image, device, num_images_per_prompt
217
+ )
218
+
219
+ if self.do_classifier_free_guidance:
220
+ image_embeds_1 = torch.cat([negative_image_embeds_1, image_embeds_1], dim=0)
221
+
222
+ # 4. Prepare timesteps
223
+ timesteps, num_inference_steps = retrieve_timesteps(
224
+ self.scheduler, num_inference_steps, device, timesteps
225
+ )
226
+ num_warmup_steps = max(
227
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
228
+ )
229
+ self._num_timesteps = len(timesteps)
230
+
231
+ # 5. Prepare latent variables
232
+ num_tokens = self.transformer.config.width
233
+ num_channels_latents = self.transformer.config.in_channels
234
+ latents = self.prepare_latents(
235
+ batch_size * num_images_per_prompt,
236
+ num_tokens,
237
+ num_channels_latents,
238
+ image_embeds_1.dtype,
239
+ device,
240
+ generator,
241
+ latents,
242
+ noise_aug_level,
243
+ )
244
+
245
+ # 6. Denoising loop
246
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
247
+ for i, t in enumerate(timesteps):
248
+ if self.interrupt:
249
+ continue
250
+
251
+ # expand the latents if we are doing classifier free guidance
252
+ latent_model_input = (
253
+ torch.cat([latents] * 2)
254
+ if self.do_classifier_free_guidance
255
+ else latents
256
+ )
257
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
258
+ timestep = t.expand(latent_model_input.shape[0])
259
+
260
+ noise_pred = self.transformer(
261
+ latent_model_input,
262
+ timestep,
263
+ encoder_hidden_states=image_embeds_1,
264
+ attention_kwargs=attention_kwargs,
265
+ return_dict=False,
266
+ )[0]
267
+
268
+ # perform guidance
269
+ if self.do_classifier_free_guidance:
270
+ noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
271
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
272
+ noise_pred_image - noise_pred_uncond
273
+ )
274
+
275
+ # compute the previous noisy sample x_t -> x_t-1
276
+ latents_dtype = latents.dtype
277
+ latents = self.scheduler.step(
278
+ noise_pred, t, latents, return_dict=False
279
+ )[0]
280
+
281
+ if latents.dtype != latents_dtype:
282
+ if torch.backends.mps.is_available():
283
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
284
+ latents = latents.to(latents_dtype)
285
+
286
+ if callback_on_step_end is not None:
287
+ callback_kwargs = {}
288
+ for k in callback_on_step_end_tensor_inputs:
289
+ callback_kwargs[k] = locals()[k]
290
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
291
+
292
+ latents = callback_outputs.pop("latents", latents)
293
+ image_embeds_1 = callback_outputs.pop(
294
+ "image_embeds_1", image_embeds_1
295
+ )
296
+ negative_image_embeds_1 = callback_outputs.pop(
297
+ "negative_image_embeds_1", negative_image_embeds_1
298
+ )
299
+
300
+ # call the callback, if provided
301
+ if i == len(timesteps) - 1 or (
302
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
303
+ ):
304
+ progress_bar.update()
305
+
306
+ if output_type == "latent":
307
+ output = latents
308
+ else:
309
+ if sampled_points is None:
310
+ raise ValueError(
311
+ "sampled_points must be provided when output_type is not 'latent'"
312
+ )
313
+
314
+ output = self.vae.decode(latents, sampled_points=sampled_points).sample
315
+
316
+ # Offload all models
317
+ self.maybe_free_model_hooks()
318
+
319
+ if not return_dict:
320
+ return (output,)
321
+
322
+ return DetailGen3DPipelineOutput(samples=output)
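A hedged end-to-end sketch of calling this pipeline. The checkpoint id is a placeholder assumption, and the coarse latents and query points are shown with placeholder shapes only; in practice they are prepared upstream (e.g. by encoding a coarse mesh with the VAE), and scripts/inference_detailgen3d.py in this upload shows the intended entry point:

import torch
from PIL import Image

from detailgen3d.pipelines import DetailGen3DPipeline

# Placeholder repo id; point it at the actual weights location.
pipe = DetailGen3DPipeline.from_pretrained(
    "VAST-AI/DetailGen3D", torch_dtype=torch.float16
).to("cuda")

image = Image.open("assets/image/100.png").convert("RGB")

# `latents` must be the VAE latents of the coarse geometry to refine and `sampled_points`
# the query points at which the refined field is decoded; shapes below are placeholders.
coarse_latents = torch.randn(1, 512, 64, device="cuda", dtype=torch.float16)
query_points = torch.rand(1, 32768, 3, device="cuda", dtype=torch.float16) * 2 - 1

out = pipe(
    image=image,
    latents=coarse_latents,
    sampled_points=query_points,
    num_inference_steps=10,
    guidance_scale=4.0,
    noise_aug_level=0,
).samples
print(out.shape)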
detailgen3d/pipelines/pipeline_detailgen3d_output.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from diffusers.utils import BaseOutput
5
+
6
+
7
+ @dataclass
8
+ class DetailGen3DPipelineOutput(BaseOutput):
9
+ r"""
10
+ Output class for DetailGen3D pipelines.
11
+ """
12
+
13
+ samples: torch.Tensor
detailgen3d/pipelines/pipeline_utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers.utils import logging
2
+
3
+ logger = logging.get_logger(__name__)
4
+
5
+
6
+ class TransformerDiffusionMixin:
7
+ r"""
8
+ Helper mixin for DiffusionPipeline subclasses that use a VAE and a transformer (mainly for DiT-style models).
9
+ """
10
+
11
+ def enable_vae_slicing(self):
12
+ r"""
13
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
14
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
15
+ """
16
+ self.vae.enable_slicing()
17
+
18
+ def disable_vae_slicing(self):
19
+ r"""
20
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
21
+ computing decoding in one step.
22
+ """
23
+ self.vae.disable_slicing()
24
+
25
+ def enable_vae_tiling(self):
26
+ r"""
27
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
28
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
29
+ processing larger images.
30
+ """
31
+ self.vae.enable_tiling()
32
+
33
+ def disable_vae_tiling(self):
34
+ r"""
35
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
36
+ computing decoding in one step.
37
+ """
38
+ self.vae.disable_tiling()
39
+
40
+ def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
41
+ """
42
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
43
+ are fused. For cross-attention modules, key and value projection matrices are fused.
44
+
45
+ <Tip warning={true}>
46
+
47
+ This API is 🧪 experimental.
48
+
49
+ </Tip>
50
+
51
+ Args:
52
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
53
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
54
+ """
55
+ self.fusing_transformer = False
56
+ self.fusing_vae = False
57
+
58
+ if transformer:
59
+ self.fusing_transformer = True
60
+ self.transformer.fuse_qkv_projections()
61
+
62
+ if vae:
63
+ self.fusing_vae = True
64
+ self.vae.fuse_qkv_projections()
65
+
66
+ def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
67
+ """Disable QKV projection fusion if enabled.
68
+
69
+ <Tip warning={true}>
70
+
71
+ This API is 🧪 experimental.
72
+
73
+ </Tip>
74
+
75
+ Args:
76
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
77
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
78
+
79
+ """
80
+ if transformer:
81
+ if not self.fusing_transformer:
82
+ logger.warning(
83
+ "The UNet was not initially fused for QKV projections. Doing nothing."
84
+ )
85
+ else:
86
+ self.transformer.unfuse_qkv_projections()
87
+ self.fusing_transformer = False
88
+
89
+ if vae:
90
+ if not self.fusing_vae:
91
+ logger.warning(
92
+ "The VAE was not initially fused for QKV projections. Doing nothing."
93
+ )
94
+ else:
95
+ self.vae.unfuse_qkv_projections()
96
+ self.fusing_vae = False
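A short, hedged sketch of how these helpers are meant to be used on a pipeline that mixes in TransformerDiffusionMixin, continuing the DetailGen3DPipeline example above; the actual memory savings depend on the registered VAE and transformer implementing slicing, tiling, and projection fusion:

# `pipe` is assumed to be a loaded DetailGen3DPipeline (see the earlier sketch).
pipe.enable_vae_slicing()    # decode in slices, if the registered VAE supports it
pipe.enable_vae_tiling()     # decode in tiles, if the registered VAE supports it
pipe.fuse_qkv_projections(transformer=True, vae=False)  # fuse Q/K/V only in the DiT

# ... run inference ...

pipe.unfuse_qkv_projections(transformer=True, vae=False)
pipe.disable_vae_tiling()
pipe.disable_vae_slicing()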
detailgen3d/schedulers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .scheduling_rectified_flow import (
2
+ RectifiedFlowScheduler,
3
+ compute_density_for_timestep_sampling,
4
+ compute_loss_weighting,
5
+ )
detailgen3d/schedulers/scheduling_rectified_flow.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py.
3
+ """
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
13
+ from diffusers.utils import BaseOutput, logging
14
+ from torch.distributions import LogisticNormal
15
+
16
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
+
18
+
19
+ # TODO: may move to training_utils.py
20
+ def compute_density_for_timestep_sampling(
21
+ weighting_scheme: str,
22
+ batch_size: int,
23
+ logit_mean: float = 0.0,
24
+ logit_std: float = 1.0,
25
+ mode_scale: float = None,
26
+ ):
27
+ if weighting_scheme == "logit_normal":
28
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
29
+ u = torch.normal(
30
+ mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
31
+ )
32
+ u = torch.nn.functional.sigmoid(u)
33
+ elif weighting_scheme == "logit_normal_dist":
34
+ u = (
35
+ LogisticNormal(loc=logit_mean, scale=logit_std)
36
+ .sample((batch_size,))[:, 0]
37
+ .to("cpu")
38
+ )
39
+ elif weighting_scheme == "mode":
40
+ u = torch.rand(size=(batch_size,), device="cpu")
41
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
42
+ else:
43
+ u = torch.rand(size=(batch_size,), device="cpu")
44
+ return u
45
+
46
+
47
+ def compute_loss_weighting(weighting_scheme: str, sigmas=None):
48
+ """
49
+ Computes loss weighting scheme for SD3 training.
50
+
51
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
52
+
53
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
54
+ """
55
+ if weighting_scheme == "sigma_sqrt":
56
+ weighting = (sigmas**-2.0).float()
57
+ elif weighting_scheme == "cosmap":
58
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
59
+ weighting = 2 / (math.pi * bot)
60
+ else:
61
+ weighting = torch.ones_like(sigmas)
62
+ return weighting
63
+
64
+
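A short training-side sketch tying the two helpers above together: sample a per-example value from the logit-normal density, treat it as sigma, and compute the corresponding loss weights. The batch size, scheme names, and the final loss line are illustrative, not taken from a training script in this upload:

import torch

from detailgen3d.schedulers import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting,
)

batch_size = 4
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=batch_size, logit_mean=0.0, logit_std=1.0
)
sigmas = u                    # u already lies in (0, 1) and plays the role of sigma here
timesteps = sigmas * 1000     # scale to the scheduler's num_train_timesteps

weights = compute_loss_weighting("cosmap", sigmas=sigmas)
# A rectified-flow training loss could then be weighted as:
# loss = (weights * (model_pred - target) ** 2).mean()
print(timesteps, weights)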
65
+ @dataclass
66
+ class RectifiedFlowSchedulerOutput(BaseOutput):
67
+ """
68
+ Output class for the scheduler's `step` function output.
69
+
70
+ Args:
71
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
72
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
73
+ denoising loop.
74
+ """
75
+
76
+ prev_sample: torch.FloatTensor
77
+
78
+
79
+ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin):
80
+ """
81
+ Scheduler that propagates the diffusion process along a rectified-flow (flow-matching) trajectory.
82
+
83
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
84
+ methods the library implements for all schedulers such as loading and saving.
85
+
86
+ Args:
87
+ num_train_timesteps (`int`, defaults to 1000):
88
+ The number of diffusion steps to train the model.
89
+ shift (`float`, defaults to 1.0):
90
+ The static shift value applied to the sigma/timestep schedule.
91
+ use_dynamic_shifting (`bool`, defaults to `False`):
92
+ Whether to apply resolution-dependent shifting at inference time (via `mu`) instead of the static `shift`.
94
+ """
95
+
96
+ _compatibles = []
97
+ order = 1
98
+
99
+ @register_to_config
100
+ def __init__(
101
+ self,
102
+ num_train_timesteps: int = 1000,
103
+ shift: float = 1.0,
104
+ use_dynamic_shifting: bool = False,
105
+ ):
106
+ # pre-compute timesteps and sigmas; not actually used at inference
107
+ # NOTE that shape diffusion samples training timesteps randomly or from a distribution,
108
+ # instead of sampling from the pre-defined linspace
109
+ timesteps = np.array(
110
+ [
111
+ (1.0 - i / num_train_timesteps) * num_train_timesteps
112
+ for i in range(num_train_timesteps)
113
+ ]
114
+ )
115
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
116
+
117
+ sigmas = timesteps / num_train_timesteps
118
+ if not use_dynamic_shifting:
119
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
120
+ sigmas = self.time_shift(sigmas)
121
+
122
+ self.timesteps = sigmas * num_train_timesteps
123
+
124
+ self._step_index = None
125
+ self._begin_index = None
126
+
127
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
128
+
129
+ @property
130
+ def step_index(self):
131
+ """
132
+ The index counter for current timestep. It will increase 1 after each scheduler step.
133
+ """
134
+ return self._step_index
135
+
136
+ @property
137
+ def begin_index(self):
138
+ """
139
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
140
+ """
141
+ return self._begin_index
142
+
143
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
144
+ def set_begin_index(self, begin_index: int = 0):
145
+ """
146
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
147
+
148
+ Args:
149
+ begin_index (`int`):
150
+ The begin index for the scheduler.
151
+ """
152
+ self._begin_index = begin_index
153
+
154
+ def _sigma_to_t(self, sigma):
155
+ return sigma * self.config.num_train_timesteps
156
+
157
+ def _t_to_sigma(self, timestep):
158
+ return timestep / self.config.num_train_timesteps
159
+
160
+ def time_shift_dynamic(self, mu: float, sigma: float, t: torch.Tensor):
161
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
162
+
163
+ def time_shift(self, t: torch.Tensor):
164
+ return self.config.shift * t / (1 + (self.config.shift - 1) * t)
165
+
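For intuition, the static shift above is the monotone map sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma); larger shifts push more of the schedule toward high noise levels. A quick numeric check with illustrative values:

import torch

shift = 3.0
sigma = torch.tensor([0.25, 0.50, 0.75])
shifted = shift * sigma / (1 + (shift - 1) * sigma)
print(shifted)   # tensor([0.5000, 0.7500, 0.9000])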
166
+ def set_timesteps(
167
+ self,
168
+ num_inference_steps: int = None,
169
+ device: Union[str, torch.device] = None,
170
+ sigmas: Optional[List[float]] = None,
171
+ mu: Optional[float] = None,
172
+ ):
173
+ """
174
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
175
+
176
+ Args:
177
+ num_inference_steps (`int`):
178
+ The number of diffusion steps used when generating samples with a pre-trained model.
179
+ device (`str` or `torch.device`, *optional*):
180
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
181
+ """
182
+
183
+ if self.config.use_dynamic_shifting and mu is None:
184
+ raise ValueError(
185
+ " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
186
+ )
187
+
188
+ if sigmas is None:
189
+ self.num_inference_steps = num_inference_steps
190
+ timesteps = np.array(
191
+ [
192
+ (1.0 - i / num_inference_steps) * self.config.num_train_timesteps
193
+ for i in range(num_inference_steps)
194
+ ]
195
+ ) # different from the original code in SD3
196
+ sigmas = timesteps / self.config.num_train_timesteps
197
+
198
+ if self.config.use_dynamic_shifting:
199
+ sigmas = self.time_shift_dynamic(mu, 1.0, sigmas)
200
+ else:
201
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
202
+
203
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
204
+ timesteps = sigmas * self.config.num_train_timesteps
205
+
206
+ self.timesteps = timesteps.to(device=device)
207
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
208
+
209
+ self._step_index = None
210
+ self._begin_index = None
211
+
212
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
213
+ if schedule_timesteps is None:
214
+ schedule_timesteps = self.timesteps
215
+
216
+ indices = (schedule_timesteps == timestep).nonzero()
217
+
218
+ # The sigma index that is taken for the **very** first `step`
219
+ # is always the second index (or the last index if there is only 1)
220
+ # This way we can ensure we don't accidentally skip a sigma in
221
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
222
+ pos = 1 if len(indices) > 1 else 0
223
+
224
+ return indices[pos].item()
225
+
226
+ def _init_step_index(self, timestep):
227
+ if self.begin_index is None:
228
+ if isinstance(timestep, torch.Tensor):
229
+ timestep = timestep.to(self.timesteps.device)
230
+ self._step_index = self.index_for_timestep(timestep)
231
+ else:
232
+ self._step_index = self._begin_index
233
+
234
+ def step(
235
+ self,
236
+ model_output: torch.FloatTensor,
237
+ timestep: Union[float, torch.FloatTensor],
238
+ sample: torch.FloatTensor,
239
+ s_churn: float = 0.0,
240
+ s_tmin: float = 0.0,
241
+ s_tmax: float = float("inf"),
242
+ s_noise: float = 1.0,
243
+ generator: Optional[torch.Generator] = None,
244
+ return_dict: bool = True,
245
+ ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
246
+ """
247
+ Predict the sample at the previous timestep by integrating the flow ODE. This function propagates the
248
+ sample using the learned model output (the predicted velocity).
249
+
250
+ Args:
251
+ model_output (`torch.FloatTensor`):
252
+ The direct output from learned diffusion model.
253
+ timestep (`float`):
254
+ The current discrete timestep in the diffusion chain.
255
+ sample (`torch.FloatTensor`):
256
+ A current instance of a sample created by the diffusion process.
257
+ s_churn (`float`):
258
+ s_tmin (`float`):
259
+ s_tmax (`float`):
260
+ s_noise (`float`, defaults to 1.0):
261
+ Scaling factor for noise added to the sample.
262
+ generator (`torch.Generator`, *optional*):
263
+ A random number generator.
264
+ return_dict (`bool`):
265
+ Whether or not to return a [`RectifiedFlowSchedulerOutput`] or a
266
+ tuple.
267
+
268
+ Returns:
269
+ [`RectifiedFlowSchedulerOutput`] or `tuple`:
270
+ If return_dict is `True`, [`RectifiedFlowSchedulerOutput`] is
271
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
272
+ """
273
+
274
+ if (
275
+ isinstance(timestep, int)
276
+ or isinstance(timestep, torch.IntTensor)
277
+ or isinstance(timestep, torch.LongTensor)
278
+ ):
279
+ raise ValueError(
280
+ (
281
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
282
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
283
+ " one of the `scheduler.timesteps` as a timestep."
284
+ ),
285
+ )
286
+
287
+ if self.step_index is None:
288
+ self._init_step_index(timestep)
289
+
290
+ # Upcast to avoid precision issues when computing prev_sample
291
+ sample = sample.to(torch.float32)
292
+
293
+ sigma = self.sigmas[self.step_index]
294
+ sigma_next = self.sigmas[self.step_index + 1]
295
+
296
+ # NOTE: a different sign convention is used for the flow-matching update here (step along sigma - sigma_next)
297
+ prev_sample = sample + (sigma - sigma_next) * model_output
298
+
299
+ # Cast sample back to model compatible dtype
300
+ prev_sample = prev_sample.to(model_output.dtype)
301
+
302
+ # upon completion increase step index by one
303
+ self._step_index += 1
304
+
305
+ if not return_dict:
306
+ return (prev_sample,)
307
+
308
+ return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
309
+
310
+ def scale_noise(
311
+ self,
312
+ original_samples: torch.Tensor,
313
+ noise: torch.Tensor,
314
+ timesteps: torch.IntTensor,
315
+ ) -> torch.Tensor:
316
+ """
317
+ Forward (noising) process for flow matching: linearly interpolates between the clean sample and noise.
318
+ """
319
+ sigmas = self._t_to_sigma(timesteps.to(dtype=torch.float32))
320
+
321
+ while len(sigmas.shape) < len(original_samples.shape):
322
+ sigmas = sigmas.unsqueeze(-1)
323
+
324
+ return (1.0 - sigmas) * original_samples + sigmas * noise
325
+
326
+ def __len__(self):
327
+ return self.config.num_train_timesteps
detailgen3d/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .constants import USE_FLASH3_BACKEND, USE_SDPA_BACKEND, disable_flash3
2
+ from .import_utils import is_flash3_available
detailgen3d/utils/typing.py ADDED
@@ -0,0 +1,64 @@
1
+ """
2
+ This module contains type annotations for the project, using
3
+ 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
4
+ 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
5
+
6
+ Two types of type checking can be used:
7
+ 1. Static type checking with mypy (install with pip and enable it as the default linter in VSCode)
8
+ 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
9
+ """
10
+
11
+ # Basic types
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ Dict,
16
+ Iterable,
17
+ List,
18
+ Literal,
19
+ NamedTuple,
20
+ NewType,
21
+ Optional,
22
+ Sized,
23
+ Tuple,
24
+ Type,
25
+ TypedDict,
26
+ TypeVar,
27
+ Union,
28
+ )
29
+
30
+ # Tensor dtype
31
+ # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
32
+ from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
33
+
34
+ # Config type
35
+ from omegaconf import DictConfig, ListConfig
36
+
37
+ # PyTorch Tensor type
38
+ from torch import Tensor
39
+
40
+ # Runtime type checking decorator
41
+ from typeguard import typechecked as typechecker
42
+
43
+
44
+ # Custom types
45
+ class FuncArgs(TypedDict):
46
+ """Type for instantiating a function with keyword arguments"""
47
+
48
+ name: str
49
+ kwargs: Dict[str, Any]
50
+
51
+ @staticmethod
52
+ def validate(variable):
53
+ necessary_keys = ["name", "kwargs"]
54
+ for key in necessary_keys:
55
+ assert key in variable, f"Key {key} is missing in {variable}"
56
+ if not isinstance(variable["name"], str):
57
+ raise TypeError(
58
+ f"Key 'name' should be a string, not {type(variable['name'])}"
59
+ )
60
+ if not isinstance(variable["kwargs"], dict):
61
+ raise TypeError(
62
+ f"Key 'kwargs' should be a dictionary, not {type(variable['kwargs'])}"
63
+ )
64
+ return variable
scripts/inference_detailgen3d.py ADDED
@@ -0,0 +1,70 @@
1
+ import numpy as np
2
+ import torch
3
+ import trimesh
4
+ from PIL import Image
5
+ from skimage import measure
6
+
7
+ from detailgen3d.inference_utils import generate_dense_grid_points
8
+ from detailgen3d.pipelines.pipeline_detailgen3d import (
9
+ DetailGen3DPipeline,
10
+ )
11
+
12
+ def load_mesh(mesh_path, num_pc=20480):
13
+ mesh = trimesh.load(mesh_path, force="mesh")
14
+
15
+ center = mesh.bounding_box.centroid
16
+ mesh.apply_translation(-center)
17
+ scale = max(mesh.bounding_box.extents)
18
+ mesh.apply_scale(1.9 / scale)
19
+
20
+ surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000)
21
+ normal = mesh.face_normals[face_indices]
22
+
23
+ rng = np.random.default_rng()
24
+ ind = rng.choice(surface.shape[0], num_pc, replace=False)
25
+ surface = torch.FloatTensor(surface[ind])
26
+ normal = torch.FloatTensor(normal[ind])
27
+ surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
28
+
29
+ return surface
30
+
31
+ if __name__ == "__main__":
32
+ device = "cuda"
33
+ dtype = torch.float16
34
+
35
+ # prepare pipeline
36
+ pipeline = DetailGen3DPipeline.from_pretrained(
37
+ "VAST-AI/DetailGen3D",
38
+ low_cpu_mem_usage=False
39
+ ).to(device, dtype=dtype)
40
+
41
+ # prepare data
42
+ image_path = "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png"
43
+ image = Image.open(image_path).convert("RGB")
44
+
45
+ mesh_path = "assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb"
46
+ surface = load_mesh(mesh_path).to(device, dtype=dtype)
47
+
48
+ batch_size = 1
49
+
50
+ # sample query points for decoding
51
+ box_min = np.array([-1.005, -1.005, -1.005])
52
+ box_max = np.array([1.005, 1.005, 1.005])
53
+ sampled_points, grid_size, bbox_size = generate_dense_grid_points(
54
+ bbox_min=box_min, bbox_max=box_max, octree_depth=9, indexing="ij"
55
+ )
56
+ sampled_points = torch.FloatTensor(sampled_points).to(device, dtype=dtype)
57
+ sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)
58
+
59
+ # inference pipeline
60
+ sample = pipeline.vae.encode(surface).latent_dist.sample()
61
+ sdf = pipeline(image, latents=sample, sampled_points=sampled_points, noise_aug_level=0).samples[0]
62
+
63
+ # marching cubes
64
+ grid_logits = sdf.view(grid_size).cpu().numpy()
65
+ vertices, faces, normals, _ = measure.marching_cubes(
66
+ grid_logits, 0, method="lewiner"
67
+ )
68
+ vertices = vertices / grid_size * bbox_size + box_min
69
+ mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
70
+ mesh.export("output.glb", file_type="glb")