xinjie.wang committed
Commit 146eff7 · 1 Parent(s): 0111b97
asset3d_gen/data/backproject_v2.py CHANGED
@@ -2,7 +2,7 @@ import argparse
 import logging
 import math
 import os
-
+import spaces
 import cv2
 import numpy as np
 import nvdiffrast.torch as dr
@@ -247,7 +247,7 @@ class TextureBacker:
             (2 / 512) * max(self.render_wh[0], self.render_wh[1])
         )
 
-    def load_mesh(self, mesh: trimesh.Trimesh) -> None:
+    def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
         mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
         self.scale, self.center = scale, center
 
@@ -257,9 +257,7 @@ class TextureBacker:
         mesh.faces = indices
         mesh.visual.uv = uvs
 
-        self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
-        self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
-        self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
+        return mesh
 
     def get_mesh_np_attrs(
         self,
@@ -397,32 +395,32 @@ class TextureBacker:
         return texture_merge, trust_map_merge > 1e-8
 
     def uv_inpaint(
-        self, texture: torch.Tensor, mask: torch.Tensor
+        self, texture: np.ndarray, mask: np.ndarray
     ) -> np.ndarray:
-        texture_np = texture.cpu().numpy()
-        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
         vertices, faces, uv_map = self.get_mesh_np_attrs()
 
-        texture_np, mask_np = _texture_inpaint_smooth(
-            texture_np, mask_np, vertices, faces, uv_map
+        texture, mask = _texture_inpaint_smooth(
+            texture, mask, vertices, faces, uv_map
         )
-        texture_np = texture_np.clip(0, 1)
-        texture_np = cv2.inpaint(
-            (texture_np * 255).astype(np.uint8),
-            255 - mask_np,
+        texture = texture.clip(0, 1)
+        texture = cv2.inpaint(
+            (texture * 255).astype(np.uint8),
+            255 - mask,
             3,
             cv2.INPAINT_NS,
         )
 
-        return texture_np
+        return texture
 
-    def __call__(
+    @spaces.GPU
+    def cuda_forward(
         self,
         colors: list[Image.Image],
         mesh: trimesh.Trimesh,
-        output_path: str,
     ) -> trimesh.Trimesh:
-        self.load_mesh(mesh)
+        self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
+        self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
+        self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
         rendered_depth, masks = self.renderer.render_depth(
             self.vertices, self.faces
         )
@@ -448,12 +446,26 @@ class TextureBacker:
             weighted_cos_maps.append(weight * (cos_map**4))
 
         texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
-        texture_np = self.uv_inpaint(texture, mask)
+
+        texture_np = texture.cpu().numpy()
+        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
+
+        return texture_np, mask_np
+
+    def __call__(
+        self,
+        colors: list[Image.Image],
+        mesh: trimesh.Trimesh,
+        output_path: str,
+    ) -> trimesh.Trimesh:
+        mesh = self.load_mesh(mesh)
+        texture_np, mask_np = self.cuda_forward(colors, mesh)
+
+        texture_np = self.uv_inpaint(texture_np, mask_np)
         texture_np = post_process_texture(texture_np)
         vertices, faces, uv_map = self.get_mesh_np_attrs(
             self.scale, self.center
         )
-
         textured_mesh = save_mesh_with_mtl(
             vertices, faces, uv_map, texture_np, output_path
         )
@@ -567,7 +579,6 @@ def entrypoint(
         )
         save_dir = os.path.dirname(args.output_path)
         os.makedirs(save_dir, exist_ok=True)
-        color_grid.save(f"{save_dir}/color_grid.png")
         color_grid = delight_model(color_grid)
         color_grid.save(f"{save_dir}/color_grid_delight.png")
 
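The net effect of the change above is to split TextureBacker.__call__ into a CPU-only mesh preparation step (load_mesh), a GPU stage (cuda_forward) marked with the Hugging Face ZeroGPU spaces.GPU decorator, and CPU post-processing (uv_inpaint and mesh export). Below is a minimal, self-contained sketch of that pattern, assuming the spaces package is available; the class and method names are illustrative and are not part of this repository.

import numpy as np
import spaces
import torch


class GpuStageSketch:
    """Keep data as numpy on the CPU; touch CUDA only inside the decorated stage."""

    def __init__(self, device: str = "cuda"):
        self.device = device

    def prepare(self, array: np.ndarray) -> np.ndarray:
        # CPU-only preparation: no CUDA calls here, so a ZeroGPU worker can run
        # this before a GPU has been attached to the process.
        return np.ascontiguousarray(array, dtype=np.float32)

    @spaces.GPU  # a GPU is attached only for the duration of this call
    def cuda_stage(self, array: np.ndarray) -> np.ndarray:
        tensor = torch.from_numpy(array).to(self.device)
        result = tensor * 2  # stand-in for the rendering / baking work
        return result.cpu().numpy()  # hand plain numpy back to the CPU side

    def __call__(self, array: np.ndarray) -> np.ndarray:
        prepared = self.prepare(array)
        baked = self.cuda_stage(prepared)
        return baked  # CPU post-processing (inpainting, saving) would go here

This also explains why uv_inpaint now takes numpy arrays and why cuda_forward returns texture_np and mask_np: everything crossing the decorated boundary is plain numpy, so inpainting and mesh export no longer need a GPU.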
asset3d_gen/data/backup/backproject_v2 copy.py DELETED
@@ -1,652 +0,0 @@
1
- import argparse
2
- import logging
3
- import math
4
- import os
5
-
6
- import cv2
7
- import numpy as np
8
- import nvdiffrast.torch as dr
9
- import torch
10
- import torch.nn.functional as F
11
- try:
12
- from torchvision.transforms import functional as tF
13
- except ImportError as e:
14
- tF = None
15
- import trimesh
16
- import xatlas
17
- from PIL import Image
18
- from asset3d_gen.data.mesh_operator import MeshFixer
19
- from asset3d_gen.data.utils import (
20
- CameraSetting,
21
- DiffrastRender,
22
- get_images_from_grid,
23
- init_kal_camera,
24
- normalize_vertices_array,
25
- post_process_texture,
26
- save_mesh_with_mtl,
27
- )
28
- from asset3d_gen.models.delight_model import DelightingModel
29
- from asset3d_gen.models.sr_model import ImageRealESRGAN
30
-
31
- logging.basicConfig(
32
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
33
- )
34
- logger = logging.getLogger(__name__)
35
-
36
-
37
- __all__ = [
38
- "TextureBacker",
39
- ]
40
-
41
-
42
- def transform_vertices(
43
- mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
44
- ) -> torch.Tensor:
45
- """Transform 3D vertices using a projection matrix."""
46
- t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
47
- if pos.size(-1) == 3:
48
- pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
49
-
50
- result = pos @ t_mtx.T
51
-
52
- return result if keepdim else result.unsqueeze(0)
53
-
54
-
55
- def _bilinear_interpolation_scattering(
56
- image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
57
- ) -> torch.Tensor:
58
- """Bilinear interpolation scattering for grid-based value accumulation."""
59
- device = values.device
60
- dtype = values.dtype
61
- C = values.shape[-1]
62
-
63
- indices = coords * torch.tensor(
64
- [image_h - 1, image_w - 1], dtype=dtype, device=device
65
- )
66
- i, j = indices.unbind(-1)
67
-
68
- i0, j0 = (
69
- indices.floor()
70
- .long()
71
- .clamp(0, image_h - 2)
72
- .clamp(0, image_w - 2)
73
- .unbind(-1)
74
- )
75
- i1, j1 = i0 + 1, j0 + 1
76
-
77
- w_i = i - i0.float()
78
- w_j = j - j0.float()
79
- weights = torch.stack(
80
- [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
81
- dim=1,
82
- )
83
-
84
- indices_comb = torch.stack(
85
- [
86
- torch.stack([i0, j0], dim=1),
87
- torch.stack([i0, j1], dim=1),
88
- torch.stack([i1, j0], dim=1),
89
- torch.stack([i1, j1], dim=1),
90
- ],
91
- dim=1,
92
- )
93
-
94
- grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
95
- cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
96
-
97
- for k in range(4):
98
- idx = indices_comb[:, k]
99
- w = weights[:, k].unsqueeze(-1)
100
-
101
- stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
102
- flat_idx = (idx * stride).sum(-1)
103
-
104
- grid.view(-1, C).scatter_add_(
105
- 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
106
- )
107
- cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
108
-
109
- mask = cnt.squeeze(-1) > 0
110
- grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
111
-
112
- return grid
113
-
114
-
115
- def _texture_inpaint_smooth(
116
- texture: np.ndarray,
117
- mask: np.ndarray,
118
- vertices: np.ndarray,
119
- faces: np.ndarray,
120
- uv_map: np.ndarray,
121
- ) -> tuple[np.ndarray, np.ndarray]:
122
- """Perform texture inpainting using vertex-based color propagation."""
123
- image_h, image_w, C = texture.shape
124
- N = vertices.shape[0]
125
-
126
- # Initialize vertex data structures
127
- vtx_mask = np.zeros(N, dtype=np.float32)
128
- vtx_colors = np.zeros((N, C), dtype=np.float32)
129
- unprocessed = []
130
- adjacency = [[] for _ in range(N)]
131
-
132
- # Build adjacency graph and initial color assignment
133
- for face_idx in range(faces.shape[0]):
134
- for k in range(3):
135
- uv_idx_k = faces[face_idx, k]
136
- v_idx = faces[face_idx, k]
137
-
138
- # Convert UV to pixel coordinates with boundary clamping
139
- u = np.clip(
140
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
141
- )
142
- v = np.clip(
143
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
144
- 0,
145
- image_h - 1,
146
- )
147
-
148
- if mask[v, u]:
149
- vtx_mask[v_idx] = 1.0
150
- vtx_colors[v_idx] = texture[v, u]
151
- elif v_idx not in unprocessed:
152
- unprocessed.append(v_idx)
153
-
154
- # Build undirected adjacency graph
155
- neighbor = faces[face_idx, (k + 1) % 3]
156
- if neighbor not in adjacency[v_idx]:
157
- adjacency[v_idx].append(neighbor)
158
- if v_idx not in adjacency[neighbor]:
159
- adjacency[neighbor].append(v_idx)
160
-
161
- # Color propagation with dynamic stopping
162
- remaining_iters, prev_count = 2, 0
163
- while remaining_iters > 0:
164
- current_unprocessed = []
165
-
166
- for v_idx in unprocessed:
167
- valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
168
- if not valid_neighbors:
169
- current_unprocessed.append(v_idx)
170
- continue
171
-
172
- # Calculate inverse square distance weights
173
- neighbors_pos = vertices[valid_neighbors]
174
- dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
175
- weights = 1 / np.maximum(dist_sq, 1e-8)
176
-
177
- vtx_colors[v_idx] = np.average(
178
- vtx_colors[valid_neighbors], weights=weights, axis=0
179
- )
180
- vtx_mask[v_idx] = 1.0
181
-
182
- # Update iteration control
183
- if len(current_unprocessed) == prev_count:
184
- remaining_iters -= 1
185
- else:
186
- remaining_iters = min(remaining_iters + 1, 2)
187
- prev_count = len(current_unprocessed)
188
- unprocessed = current_unprocessed
189
-
190
- # Generate output texture
191
- inpainted_texture, updated_mask = texture.copy(), mask.copy()
192
- for face_idx in range(faces.shape[0]):
193
- for k in range(3):
194
- v_idx = faces[face_idx, k]
195
- if not vtx_mask[v_idx]:
196
- continue
197
-
198
- # UV coordinate conversion
199
- uv_idx_k = faces[face_idx, k]
200
- u = np.clip(
201
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
202
- )
203
- v = np.clip(
204
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
205
- 0,
206
- image_h - 1,
207
- )
208
-
209
- inpainted_texture[v, u] = vtx_colors[v_idx]
210
- updated_mask[v, u] = 255
211
-
212
- return inpainted_texture, updated_mask
213
-
214
-
215
- def interp_tensers(tensors: list[torch.Tensor], target_wh: tuple[int, int]) -> list[torch.Tensor]:
216
- for idx in range(len(tensors)):
217
- tensor = tensors[idx].permute(2, 0, 1)
218
- tensor = tF.resize(tensor, target_wh[::-1], antialias=True)
219
- tensors[idx] = tensor.permute(1, 2, 0)
220
-
221
- return tensors
222
-
223
-
224
- class TextureBacker:
225
- """Texture baking pipeline for multi-view projection and fusion."""
226
-
227
- def __init__(
228
- self,
229
- camera_params: CameraSetting,
230
- view_weights: list[float],
231
- render_wh: tuple[int, int] = (2048, 2048),
232
- texture_wh: tuple[int, int] = (2048, 2048),
233
- bake_angle_thresh: int = 75,
234
- mask_thresh: float = 0.5,
235
- ):
236
- camera = init_kal_camera(camera_params)
237
- mv = camera.view_matrix() # (n 4 4) world2cam
238
- p = camera.intrinsics.projection_matrix()
239
- # NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
240
- p[:, 1, 1] = -p[:, 1, 1]
241
- self.renderer = DiffrastRender(
242
- p_matrix=p,
243
- mv_matrix=mv,
244
- resolution_hw=camera_params.resolution_hw,
245
- context=dr.RasterizeCudaContext(),
246
- mask_thresh=mask_thresh,
247
- grad_db=False,
248
- device=camera_params.device,
249
- antialias_mask=True,
250
- )
251
- self.camera = camera
252
- self.view_weights = view_weights
253
- self.device = camera_params.device
254
- self.render_wh = render_wh
255
- self.texture_wh = texture_wh
256
-
257
- self.bake_angle_thresh = bake_angle_thresh
258
- self.bake_unreliable_kernel_size = int(
259
- (2 / 512) * max(self.render_wh[0], self.render_wh[1])
260
- )
261
-
262
- def load_mesh(self, mesh: trimesh.Trimesh) -> None:
263
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
264
- self.scale, self.center = scale, center
265
-
266
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
267
- uvs[:, 1] = 1 - uvs[:, 1]
268
- mesh.vertices = mesh.vertices[vmapping]
269
- mesh.faces = indices
270
- mesh.visual.uv = uvs
271
-
272
- self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
273
- self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
274
- self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
275
-
276
- def get_mesh_np_attrs(
277
- self,
278
- scale: float = None,
279
- center: np.ndarray = None,
280
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
281
- vertices = self.vertices.cpu().numpy()
282
- faces = self.faces.cpu().numpy()
283
- uv_map = self.uv_map.cpu().numpy()
284
- uv_map[:, 1] = 1.0 - uv_map[:, 1]
285
-
286
- if scale is not None:
287
- vertices = vertices / scale
288
- if center is not None:
289
- vertices = vertices + center
290
-
291
- return vertices, faces, uv_map
292
-
293
- def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
294
- depth_image_np = depth_image.cpu().numpy()
295
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
296
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
297
- sketch_image = (
298
- torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
299
- )
300
- sketch_image = sketch_image.unsqueeze(-1)
301
-
302
- return sketch_image
303
-
304
- def compute_enhanced_viewnormal(
305
- self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
306
- ) -> torch.Tensor:
307
- rast, _ = self.renderer.compute_dr_raster(vertices, faces)
308
- rendered_view_normals = []
309
- for idx in range(len(mv_mtx)):
310
- pos_cam = transform_vertices(mv_mtx[idx], vertices, keepdim=True)
311
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
312
- v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
313
- face_norm = F.normalize(
314
- torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
315
- )
316
- vertex_norm = (
317
- torch.from_numpy(
318
- trimesh.geometry.mean_vertex_normals(
319
- len(pos_cam), faces.cpu(), face_norm.cpu()
320
- )
321
- )
322
- .to(vertices.device)
323
- .contiguous()
324
- )
325
- im_base_normals, _ = dr.interpolate(
326
- vertex_norm[None, ...].float(),
327
- rast[idx : idx + 1],
328
- faces.to(torch.int32),
329
- )
330
- rendered_view_normals.append(im_base_normals)
331
-
332
- rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
333
-
334
- return rendered_view_normals
335
-
336
- def back_project(
337
- self, image, vis_mask, depth, normal, uv
338
- ) -> tuple[torch.Tensor, torch.Tensor]:
339
- image = np.array(image)
340
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
341
- if image.ndim == 2:
342
- image = image.unsqueeze(-1)
343
- image = image / 255
344
-
345
- depth_inv = (1.0 - depth) * vis_mask
346
- sketch_image = self._render_depth_edges(depth_inv)
347
-
348
- cos = F.cosine_similarity(
349
- torch.tensor([[0, 0, 1]], device=self.device),
350
- normal.view(-1, 3),
351
- ).view_as(normal[..., :1])
352
- cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
353
-
354
- k = self.bake_unreliable_kernel_size * 2 + 1
355
- kernel = torch.ones((1, 1, k, k), device=self.device)
356
-
357
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
358
- vis_mask = F.conv2d(
359
- 1.0 - vis_mask,
360
- kernel,
361
- padding=k // 2,
362
- )
363
- vis_mask = 1.0 - (vis_mask > 0).float()
364
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
365
-
366
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
367
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
368
- sketch_image = (sketch_image > 0).float()
369
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
370
- vis_mask = vis_mask * (sketch_image < 0.5)
371
-
372
- cos[vis_mask == 0] = 0
373
- valid_pixels = (vis_mask != 0).view(-1)
374
-
375
- return (
376
- self._scatter_texture(uv, image, valid_pixels),
377
- self._scatter_texture(uv, cos, valid_pixels),
378
- )
379
-
380
- def _scatter_texture(self, uv, data, mask):
381
- def __filter_data(data, mask):
382
- return data.view(-1, data.shape[-1])[mask]
383
-
384
- return _bilinear_interpolation_scattering(
385
- self.texture_wh[1],
386
- self.texture_wh[0],
387
- __filter_data(uv, mask)[..., [1, 0]],
388
- __filter_data(data, mask),
389
- )
390
-
391
- @torch.no_grad()
392
- def fast_bake_texture(
393
- self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
394
- ) -> tuple[torch.Tensor, torch.Tensor]:
395
- channel = textures[0].shape[-1]
396
- texture_merge = torch.zeros(self.texture_wh + [channel]).to(
397
- self.device
398
- )
399
- trust_map_merge = torch.zeros(self.texture_wh + [1]).to(self.device)
400
- for texture, cos_map in zip(textures, confidence_maps):
401
- view_sum = (cos_map > 0).sum()
402
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
403
- if painted_sum / view_sum > 0.99:
404
- continue
405
- texture_merge += texture * cos_map
406
- trust_map_merge += cos_map
407
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
408
-
409
- return texture_merge, trust_map_merge > 1e-8
410
-
411
- def uv_inpaint(
412
- self, texture: torch.Tensor, mask: torch.Tensor
413
- ) -> np.ndarray:
414
- texture_np = texture.cpu().numpy()
415
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
416
- vertices, faces, uv_map = self.get_mesh_np_attrs()
417
-
418
- texture_np, mask_np = _texture_inpaint_smooth(
419
- texture_np, mask_np, vertices, faces, uv_map
420
- )
421
- texture_np = texture_np.clip(0, 1)
422
- texture_np = cv2.inpaint(
423
- (texture_np * 255).astype(np.uint8),
424
- 255 - mask_np,
425
- 3,
426
- cv2.INPAINT_NS,
427
- )
428
-
429
- return texture_np
430
-
431
- def __call__(
432
- self,
433
- colors: list[Image.Image],
434
- mesh: trimesh.Trimesh,
435
- output_path: str,
436
- ) -> trimesh.Trimesh:
437
- import time
438
- start = time.time()
439
- self.load_mesh(mesh)
440
- print("load_mesh", time.time() - start)
441
-
442
- start = time.time()
443
- rendered_depth, masks = self.renderer.render_depth(
444
- self.vertices, self.faces
445
- )
446
- norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
447
- render_uvs, _ = self.renderer.render_uv(
448
- self.vertices, self.faces, self.uv_map
449
- )
450
- view_normals = self.compute_enhanced_viewnormal(
451
- self.renderer.mv_mtx, self.vertices, self.faces
452
- )
453
- print("0", time.time() - start)
454
-
455
- textures, weighted_cos_maps = [], []
456
-
457
- start = time.time()
458
- for color, mask, dep, normal, uv, weight in zip(
459
- colors,
460
- masks,
461
- norm_deps,
462
- view_normals,
463
- render_uvs,
464
- self.view_weights,
465
- ):
466
- mask, dep, normal, uv = interp_tensers([mask, dep, normal, uv], self.render_wh)
467
- texture, cos_map = self.back_project(color, mask, dep, normal, uv)
468
- textures.append(texture)
469
- weighted_cos_maps.append(weight * (cos_map**4))
470
- print("1", time.time() - start)
471
- start = time.time()
472
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
473
- print("2", time.time() - start)
474
- start = time.time()
475
- texture_np = self.uv_inpaint(texture, mask)
476
- print("3", time.time() - start)
477
- start = time.time()
478
- texture_np = post_process_texture(texture_np)
479
- vertices, faces, uv_map = self.get_mesh_np_attrs(
480
- self.scale, self.center
481
- )
482
-
483
- textured_mesh = save_mesh_with_mtl(
484
- vertices, faces, uv_map, texture_np, output_path
485
- )
486
- print("4", time.time() - start)
487
-
488
- return textured_mesh
489
-
490
-
491
- def parse_args():
492
- parser = argparse.ArgumentParser(description="Backproject texture")
493
- parser.add_argument(
494
- "--color_path",
495
- type=str,
496
- help="Multiview color image in 6x512x512 file path",
497
- )
498
- parser.add_argument(
499
- "--mesh_path",
500
- type=str,
501
- help="Mesh path, .obj, .glb or .ply",
502
- )
503
- parser.add_argument(
504
- "--output_path",
505
- type=str,
506
- help="Output mesh path with suffix",
507
- )
508
- parser.add_argument(
509
- "--num_images", type=int, default=6, help="Number of images to render."
510
- )
511
- parser.add_argument(
512
- "--elevation",
513
- nargs=2,
514
- type=float,
515
- default=[20.0, -10.0],
516
- help="Elevation angles for the camera (default: [20.0, -10.0])",
517
- )
518
- parser.add_argument(
519
- "--distance",
520
- type=float,
521
- default=5,
522
- help="Camera distance (default: 5)",
523
- )
524
- parser.add_argument(
525
- "--resolution_hw",
526
- type=int,
527
- nargs=2,
528
- default=(2048, 2048),
529
- help="Resolution of the mesh rendering",
530
- )
531
- parser.add_argument(
532
- "--target_hw",
533
- type=int,
534
- nargs=2,
535
- default=(2048, 2048),
536
- help="Target rendering images resolution",
537
- )
538
- parser.add_argument(
539
- "--fov",
540
- type=float,
541
- default=30,
542
- help="Field of view in degrees (default: 30)",
543
- )
544
- parser.add_argument(
545
- "--device",
546
- type=str,
547
- choices=["cpu", "cuda"],
548
- default="cuda",
549
- help="Device to run on (default: `cuda`)",
550
- )
551
- parser.add_argument(
552
- "--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
553
- )
554
- parser.add_argument(
555
- "--texture_wh",
556
- nargs=2,
557
- type=int,
558
- default=[2048, 2048],
559
- help="Texture resolution width and height",
560
- )
561
- parser.add_argument(
562
- "--mesh_sipmlify_ratio",
563
- type=float,
564
- default=0.9,
565
- help="Mesh simplification ratio (default: 0.9)",
566
- )
567
- parser.add_argument(
568
- "--delight", action="store_true", help="Use delighting model."
569
- )
570
- args = parser.parse_args()
571
-
572
- return args
573
-
574
-
575
- def entrypoint(
576
- delight_model: DelightingModel = None,
577
- imagesr_model: ImageRealESRGAN = None,
578
- **kwargs,
579
- ) -> trimesh.Trimesh:
580
- args = parse_args()
581
- for k, v in kwargs.items():
582
- if hasattr(args, k) and v is not None:
583
- setattr(args, k, v)
584
-
585
- # Setup camera parameters.
586
- camera_params = CameraSetting(
587
- num_images=args.num_images,
588
- elevation=args.elevation,
589
- distance=args.distance,
590
- resolution_hw=args.resolution_hw,
591
- fov=math.radians(args.fov),
592
- device=args.device,
593
- )
594
- view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
595
-
596
- color_grid = Image.open(args.color_path)
597
- if args.delight:
598
- if delight_model is None:
599
- delight_model = DelightingModel(
600
- model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
601
- )
602
- save_dir = os.path.dirname(args.output_path)
603
- os.makedirs(save_dir, exist_ok=True)
604
- color_grid.save(f"{save_dir}/color_grid.png")
605
- color_grid = delight_model(color_grid)
606
- color_grid.save(f"{save_dir}/color_grid_delight.png")
607
-
608
- multiviews = get_images_from_grid(color_grid, img_size=512)
609
-
610
- # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
611
- if imagesr_model is None:
612
- imagesr_model = ImageRealESRGAN(outscale=4)
613
- multiviews = [imagesr_model(img.convert("RGB")) for img in multiviews]
614
- multiviews = [img.resize(args.target_hw[::-1]) for img in multiviews]
615
-
616
- mesh = trimesh.load(args.mesh_path)
617
- if isinstance(mesh, trimesh.Scene):
618
- mesh = mesh.dump(concatenate=True)
619
-
620
- if not args.skip_fix_mesh:
621
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
622
- mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
623
- mesh.vertices, mesh.faces = mesh_fixer(
624
- filter_ratio=args.mesh_sipmlify_ratio,
625
- max_hole_size=0.04,
626
- resolution=1024,
627
- num_views=1000,
628
- norm_mesh_ratio=0.5,
629
- )
630
- # Restore scale.
631
- mesh.vertices = mesh.vertices / scale
632
- mesh.vertices = mesh.vertices + center
633
-
634
- # Baking texture to mesh.
635
- import time
636
- start = time.time()
637
- texture_backer = TextureBacker(
638
- camera_params=camera_params,
639
- view_weights=view_weights,
640
- render_wh=args.target_hw,
641
- texture_wh=args.texture_wh,
642
- )
643
- print(time.time()-start)
644
- start = time.time()
645
- textured_mesh = texture_backer(multiviews, mesh, args.output_path)
646
- print(f"Texture backproject time: {time.time() - start:.2f}s")
647
-
648
- return textured_mesh
649
-
650
-
651
- if __name__ == "__main__":
652
- entrypoint()
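For reference, the weighted fusion performed by fast_bake_texture in the deleted file above can be summarised with a small numpy sketch: each view's texture is blended with a confidence proportional to the fourth power of the view-angle cosine, and texels never covered by a confident view are left for UV inpainting. The function below is illustrative only and omits the early skip for views that are already more than 99% painted.

import numpy as np


def fuse_views(textures, cos_maps, view_weights, eps=1e-8):
    """Blend (H, W, 3) per-view textures using (H, W, 1) cosine confidence maps."""
    acc = np.zeros_like(textures[0])
    trust = np.zeros_like(cos_maps[0])
    for tex, cos, weight in zip(textures, cos_maps, view_weights):
        conf = weight * cos**4      # sharply down-weight grazing viewing angles
        acc += tex * conf
        trust += conf
    fused = acc / np.clip(trust, eps, None)
    return fused, trust > eps       # fused texture and its "trusted texel" mask


# Toy usage with two random views.
h = w = 4
textures = [np.random.rand(h, w, 3) for _ in range(2)]
cos_maps = [np.random.rand(h, w, 1) for _ in range(2)]
fused, trusted = fuse_views(textures, cos_maps, view_weights=[1.0, 0.5])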
asset3d_gen/data/backup/backproject_v2.py DELETED
@@ -1,700 +0,0 @@
1
- import logging
2
- import math
3
- from typing import Union
4
-
5
- import custom_rasterizer as cr
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- import trimesh
11
- import xatlas
12
- from PIL import Image
13
- from asset3d_gen.data.utils import (
14
- get_images_from_file,
15
- normalize_vertices_array,
16
- post_process_texture,
17
- save_mesh_with_mtl,
18
- )
19
-
20
- logging.basicConfig(
21
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
22
- )
23
- logger = logging.getLogger(__name__)
24
-
25
-
26
- __all__ = ["TextureBacker", "Image_Super_Net", "Image_GANNet"]
27
-
28
-
29
- import math
30
- import numpy as np
31
-
32
-
33
- def get_perspective_projection(
34
- fov: float, aspect_wh: float, near: float = 0.01, far: float = 100
35
- ) -> np.ndarray:
36
- """Compute the perspective projection matrix for 3D rendering."""
37
- fov_rad = math.radians(fov)
38
- tan_half_fov = math.tan(fov_rad / 2.0)
39
-
40
- return np.array(
41
- [
42
- [1.0 / (tan_half_fov * aspect_wh), 0.0, 0.0, 0.0],
43
- [0.0, 1.0 / tan_half_fov, 0.0, 0.0],
44
- [
45
- 0.0,
46
- 0.0,
47
- -(far + near) / (far - near),
48
- -(2.0 * far * near) / (far - near),
49
- ],
50
- [0.0, 0.0, -1.0, 0.0],
51
- ],
52
- dtype=np.float32,
53
- )
54
-
55
-
56
- def transform_vertices(
57
- mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
58
- ) -> torch.Tensor:
59
- """Transform 3D vertices using a projection matrix."""
60
- t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
61
- if pos.size(-1) == 3:
62
- pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
63
-
64
- result = pos @ t_mtx.T
65
-
66
- return result if keepdim else result.unsqueeze(0)
67
-
68
-
69
- def compute_w2c_matrix(
70
- elev_deg: float, azim_deg: float, cam_dist: float
71
- ) -> np.ndarray:
72
- """Compute w2c 4x4 transformation matrix from spherical coordinates."""
73
-
74
- elev_rad = math.radians(-elev_deg)
75
- azim_rad = math.radians(azim_deg)
76
-
77
- sin_elev = math.sin(elev_rad)
78
- cos_elev = math.cos(elev_rad)
79
- sin_azim = math.sin(azim_rad)
80
- cos_azim = math.cos(azim_rad)
81
-
82
- cam_pos = np.array(
83
- [
84
- cam_dist * cos_elev * cos_azim,
85
- cam_dist * cos_elev * sin_azim,
86
- cam_dist * sin_elev,
87
- ]
88
- )
89
-
90
- look_dir = -cam_pos / np.linalg.norm(cam_pos)
91
- right_dir = np.cross(look_dir, [0, 0, 1])
92
- right_dir /= np.linalg.norm(right_dir)
93
- up_dir = np.cross(right_dir, look_dir)
94
-
95
- c2w = np.eye(4)
96
- c2w[:3, 0] = right_dir
97
- c2w[:3, 1] = up_dir
98
- c2w[:3, 2] = -look_dir
99
- c2w[:3, 3] = cam_pos
100
-
101
- try:
102
- w2c = np.linalg.inv(c2w)
103
- except np.linalg.LinAlgError as e:
104
- raise ArithmeticError("Failed to invert camera-to-world matrix") from e
105
-
106
- return w2c.astype(np.float32)
107
-
108
-
109
- def _bilinear_interpolation_scattering(
110
- image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
111
- ) -> torch.Tensor:
112
- """Bilinear interpolation scattering for grid-based value accumulation."""
113
- device = values.device
114
- dtype = values.dtype
115
- C = values.shape[-1]
116
-
117
- indices = coords * torch.tensor(
118
- [image_h - 1, image_w - 1], dtype=dtype, device=device
119
- )
120
- i, j = indices.unbind(-1)
121
-
122
- i0, j0 = (
123
- indices.floor()
124
- .long()
125
- .clamp(0, image_h - 2)
126
- .clamp(0, image_w - 2)
127
- .unbind(-1)
128
- )
129
- i1, j1 = i0 + 1, j0 + 1
130
-
131
- w_i = i - i0.float()
132
- w_j = j - j0.float()
133
- weights = torch.stack(
134
- [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
135
- dim=1,
136
- )
137
-
138
- indices_comb = torch.stack(
139
- [
140
- torch.stack([i0, j0], dim=1),
141
- torch.stack([i0, j1], dim=1),
142
- torch.stack([i1, j0], dim=1),
143
- torch.stack([i1, j1], dim=1),
144
- ],
145
- dim=1,
146
- )
147
-
148
- grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
149
- cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
150
-
151
- for k in range(4):
152
- idx = indices_comb[:, k]
153
- w = weights[:, k].unsqueeze(-1)
154
-
155
- stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
156
- flat_idx = (idx * stride).sum(-1)
157
-
158
- grid.view(-1, C).scatter_add_(
159
- 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
160
- )
161
- cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
162
-
163
- mask = cnt.squeeze(-1) > 0
164
- grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
165
-
166
- return grid
167
-
168
-
169
- def _texture_inpaint_smooth(
170
- texture: np.ndarray,
171
- mask: np.ndarray,
172
- vertices: np.ndarray,
173
- faces: np.ndarray,
174
- uv_map: np.ndarray,
175
- ) -> tuple[np.ndarray, np.ndarray]:
176
- """Perform texture inpainting using vertex-based color propagation."""
177
- image_h, image_w, C = texture.shape
178
- N = vertices.shape[0]
179
-
180
- # Initialize vertex data structures
181
- vtx_mask = np.zeros(N, dtype=np.float32)
182
- vtx_colors = np.zeros((N, C), dtype=np.float32)
183
- unprocessed = []
184
- adjacency = [[] for _ in range(N)]
185
-
186
- # Build adjacency graph and initial color assignment
187
- for face_idx in range(faces.shape[0]):
188
- for k in range(3):
189
- uv_idx_k = faces[face_idx, k]
190
- v_idx = faces[face_idx, k]
191
-
192
- # Convert UV to pixel coordinates with boundary clamping
193
- u = np.clip(
194
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
195
- )
196
- v = np.clip(
197
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
198
- 0,
199
- image_h - 1,
200
- )
201
-
202
- if mask[v, u]:
203
- vtx_mask[v_idx] = 1.0
204
- vtx_colors[v_idx] = texture[v, u]
205
- elif v_idx not in unprocessed:
206
- unprocessed.append(v_idx)
207
-
208
- # Build undirected adjacency graph
209
- neighbor = faces[face_idx, (k + 1) % 3]
210
- if neighbor not in adjacency[v_idx]:
211
- adjacency[v_idx].append(neighbor)
212
- if v_idx not in adjacency[neighbor]:
213
- adjacency[neighbor].append(v_idx)
214
-
215
- # Color propagation with dynamic stopping
216
- remaining_iters, prev_count = 2, 0
217
- while remaining_iters > 0:
218
- current_unprocessed = []
219
-
220
- for v_idx in unprocessed:
221
- valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
222
- if not valid_neighbors:
223
- current_unprocessed.append(v_idx)
224
- continue
225
-
226
- # Calculate inverse square distance weights
227
- neighbors_pos = vertices[valid_neighbors]
228
- dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
229
- weights = 1 / np.maximum(dist_sq, 1e-8)
230
-
231
- vtx_colors[v_idx] = np.average(
232
- vtx_colors[valid_neighbors], weights=weights, axis=0
233
- )
234
- vtx_mask[v_idx] = 1.0
235
-
236
- # Update iteration control
237
- if len(current_unprocessed) == prev_count:
238
- remaining_iters -= 1
239
- else:
240
- remaining_iters = min(remaining_iters + 1, 2)
241
- prev_count = len(current_unprocessed)
242
- unprocessed = current_unprocessed
243
-
244
- # Generate output texture
245
- inpainted_texture, updated_mask = texture.copy(), mask.copy()
246
- for face_idx in range(faces.shape[0]):
247
- for k in range(3):
248
- v_idx = faces[face_idx, k]
249
- if not vtx_mask[v_idx]:
250
- continue
251
-
252
- # UV coordinate conversion
253
- uv_idx_k = faces[face_idx, k]
254
- u = np.clip(
255
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
256
- )
257
- v = np.clip(
258
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
259
- 0,
260
- image_h - 1,
261
- )
262
-
263
- inpainted_texture[v, u] = vtx_colors[v_idx]
264
- updated_mask[v, u] = 255
265
-
266
- return inpainted_texture, updated_mask
267
-
268
-
269
- class TextureBacker:
270
- """Texture baking pipeline for multi-view projection and fusion."""
271
-
272
- def __init__(
273
- self,
274
- camera_elevs: list[float],
275
- camera_azims: list[float],
276
- camera_distance: int,
277
- camera_fov: float,
278
- view_weights: list[float] = None,
279
- render_wh: tuple[int, int] = (2048, 2048),
280
- texture_wh: tuple[int, int] = (2048, 2048),
281
- use_antialias: bool = True,
282
- bake_angle_thres: int = 75,
283
- device="cuda",
284
- ):
285
- self.camera_elevs = camera_elevs
286
- self.camera_azims = camera_azims
287
- self.view_weights = (
288
- view_weights
289
- if view_weights is not None
290
- else [1] * len(camera_elevs)
291
- )
292
- self.device = device
293
- self.render_wh = render_wh
294
- self.texture_wh = texture_wh
295
-
296
- self.camera_distance = camera_distance
297
- self.use_antialias = use_antialias
298
-
299
- self.bake_angle_thres = bake_angle_thres
300
- self.bake_unreliable_kernel_size = int(
301
- (2 / 512) * max(self.render_wh[0], self.render_wh[1])
302
- )
303
-
304
- self.camera_proj_mat = get_perspective_projection(
305
- camera_fov,
306
- self.render_wh[1] / self.render_wh[0],
307
- )
308
- self.cnt = 0
309
-
310
- def rasterize_mesh(
311
- self,
312
- vertex: torch.Tensor,
313
- face: torch.Tensor,
314
- resolution: tuple[int, int],
315
- ) -> torch.Tensor:
316
- vertex = vertex[None] if vertex.ndim == 2 else vertex
317
- indices, weights = cr.rasterize(vertex, face, resolution)
318
-
319
- return torch.cat(
320
- [weights, indices.unsqueeze(-1).to(weights.dtype)], dim=-1
321
- ).unsqueeze(0)
322
-
323
- def raster_interpolate(
324
- self, uv: torch.Tensor, rast_out: torch.Tensor, faces: torch.Tensor
325
- ) -> torch.Tensor:
326
- barycentric = rast_out[0, ..., :-1]
327
- findices = rast_out[0, ..., -1]
328
- if uv.dim() == 2:
329
- uv = uv.unsqueeze(0)
330
-
331
- return cr.interpolate(uv, findices, barycentric, faces)[0]
332
-
333
- def load_mesh(self, mesh_path: str) -> None:
334
- mesh = trimesh.load(mesh_path)
335
- if isinstance(mesh, trimesh.Scene):
336
- mesh = mesh.dump(concatenate=True)
337
-
338
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
339
- self.scale, self.center = scale, center
340
-
341
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
342
- mesh.vertices = mesh.vertices[vmapping]
343
- mesh.faces = indices
344
- mesh.visual.uv = uvs
345
-
346
- self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
347
- self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
348
- self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
349
-
350
- # Transformation of coordinate system
351
- self.vertices[:, [0, 1]] = -self.vertices[:, [0, 1]]
352
- self.vertices[:, [1, 2]] = self.vertices[:, [2, 1]]
353
- self.uv_map[:, 1] = 1 - self.uv_map[:, 1]
354
-
355
- def get_mesh_attrs(
356
- self,
357
- scale: float = None,
358
- center: np.ndarray = None,
359
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
360
- vertices = self.vertices.cpu().numpy()
361
- faces = self.faces.cpu().numpy()
362
- uv_map = self.uv_map.cpu().numpy()
363
-
364
- # Inverse transformation of coordinate system
365
- vertices[:, [1, 2]] = vertices[:, [2, 1]]
366
- vertices[:, [0, 1]] = -vertices[:, [0, 1]]
367
- uv_map[:, 1] = 1.0 - uv_map[:, 1]
368
-
369
- if scale is not None:
370
- vertices = vertices / scale
371
- if center is not None:
372
- vertices = vertices + center
373
-
374
- return vertices, faces, uv_map
375
-
376
- def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
377
- depth_image_np = depth_image.cpu().numpy()
378
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
379
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
380
- combined_edges = depth_edges
381
- sketch_image = (
382
- torch.from_numpy(combined_edges).to(depth_image.device).float()
383
- / 255
384
- )
385
- sketch_image = sketch_image.unsqueeze(-1)
386
-
387
- return sketch_image
388
-
389
- def back_project(
390
- self, image: Image.Image, elev: float, azim: float
391
- ) -> tuple[torch.Tensor, torch.Tensor]:
392
- if isinstance(image, Image.Image):
393
- image = np.array(image)
394
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
395
- if image.ndim == 2:
396
- image = image.unsqueeze(-1)
397
- image = image / 255.0
398
-
399
- view_mat = compute_w2c_matrix(elev, azim, self.camera_distance)
400
- import pdb
401
-
402
- pdb.set_trace()
403
- pos_cam = transform_vertices(view_mat, self.vertices, keepdim=True)
404
- pos_clip = transform_vertices(self.camera_proj_mat, pos_cam)
405
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
406
-
407
- v0, v1, v2 = (pos_cam[self.faces[:, i]] for i in range(3))
408
- face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
409
- vertex_norm = (
410
- torch.from_numpy(
411
- trimesh.geometry.mean_vertex_normals(
412
- len(pos_cam), self.faces.cpu(), face_norm.cpu()
413
- )
414
- )
415
- .to(self.device)
416
- .contiguous()
417
- )
418
-
419
- rast_out = self.rasterize_mesh(pos_clip, self.faces, image.shape[:2])
420
- vis_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0]
421
-
422
- interp_data = {
423
- "normal": self.raster_interpolate(
424
- vertex_norm[None], rast_out, self.faces
425
- ),
426
- "uv": self.raster_interpolate(
427
- self.uv_map[None], rast_out, self.faces
428
- ),
429
- "depth": self.raster_interpolate(
430
- pos_cam[:, 2].reshape(1, -1, 1), rast_out, self.faces
431
- ),
432
- }
433
-
434
- valid_depth = interp_data["depth"][vis_mask > 0]
435
- depth_norm = (interp_data["depth"] - valid_depth.min()) / (
436
- valid_depth.max() - valid_depth.min()
437
- )
438
- # depth_norm[vis_mask <= 0] = 0
439
- sketch_image = self._render_depth_edges(depth_norm * vis_mask)
440
-
441
- # ddd = depth_norm * vis_mask
442
- # cv2.imwrite(f"v2_depth_d{self.cnt}.png", (ddd.cpu().numpy() * 255).astype(np.uint8))
443
-
444
- cv2.imwrite(
445
- f"v2_vis_mask{self.cnt}.png",
446
- (vis_mask.cpu().numpy() * 255).astype(np.uint8),
447
- )
448
- cv2.imwrite(
449
- f"v2_normal{self.cnt}.png",
450
- (interp_data["normal"].cpu().numpy() * 255).astype(np.uint8),
451
- )
452
- cv2.imwrite(
453
- f"v2_depth{self.cnt}.png",
454
- (depth_norm.cpu().numpy() * 255).astype(np.uint8),
455
- )
456
- cv2.imwrite(
457
- f"v2_uv{self.cnt}.png",
458
- (interp_data["uv"][..., 0].cpu().numpy() * 255).astype(np.uint8),
459
- )
460
- cv2.imwrite(
461
- f"v2_sketch{self.cnt}.png",
462
- (sketch_image.cpu().numpy() * 255).astype(np.uint8),
463
- )
464
-
465
- self.cnt += 1
466
-
467
- cos = F.cosine_similarity(
468
- torch.tensor([[0, 0, -1]], device=self.device),
469
- interp_data["normal"].view(-1, 3),
470
- ).view_as(interp_data["normal"][..., :1])
471
- cos[cos < np.cos(np.radians(self.bake_angle_thres))] = 0
472
-
473
- cv2.imwrite(
474
- f"v2_cos{self.cnt}.png", (cos.cpu().numpy() * 255).astype(np.uint8)
475
- )
476
-
477
- k = self.bake_unreliable_kernel_size * 2 + 1
478
- kernel = torch.ones((1, 1, k, k), device=self.device)
479
-
480
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
481
- vis_mask = F.conv2d(
482
- 1.0 - vis_mask,
483
- kernel,
484
- padding=k // 2,
485
- )
486
- vis_mask = 1.0 - (vis_mask > 0).float()
487
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
488
-
489
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
490
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
491
- sketch_image = (sketch_image > 0).float()
492
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
493
- vis_mask = vis_mask * (sketch_image < 0.5)
494
-
495
- cos[vis_mask == 0] = 0
496
-
497
- vis_mask = cv2.imread(
498
- f"v3_db_mask{self.cnt}.png", cv2.IMREAD_GRAYSCALE
499
- )
500
- vis_mask = (
501
- torch.from_numpy(vis_mask[..., None]).to(self.device).float() / 255
502
- )
503
- # cos2 = cv2.imread(f"v3_db_cos{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
504
- # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
505
- # cos = cos2
506
-
507
- valid_pixels = (vis_mask != 0).view(-1)
508
- # import pdb; pdb.set_trace()
509
-
510
- cv2.imwrite(
511
- f"v2_db_sketch{self.cnt}.png",
512
- (sketch_image.cpu().numpy() * 255).astype(np.uint8),
513
- )
514
- cv2.imwrite(
515
- f"v2_db_uv{self.cnt}.png",
516
- (interp_data["uv"][..., 0].cpu().numpy() * 255).astype(np.uint8),
517
- )
518
- cv2.imwrite(
519
- f"v2_db_uv2{self.cnt}.png",
520
- (interp_data["uv"][..., 1].cpu().numpy() * 255).astype(np.uint8),
521
- )
522
- cv2.imwrite(
523
- f"v2_db_color{self.cnt}.png",
524
- (image.cpu().numpy() * 255).astype(np.uint8),
525
- )
526
- cv2.imwrite(
527
- f"v2_db_cos{self.cnt}.png",
528
- (cos.cpu().numpy() * 255).astype(np.uint8),
529
- )
530
- cv2.imwrite(
531
- f"v2_db_mask{self.cnt}.png",
532
- (vis_mask.cpu().numpy() * 255).astype(np.uint8),
533
- )
534
- # import pdb; pdb.set_trace()
535
- return (
536
- self._scatter_texture(interp_data["uv"], image, valid_pixels),
537
- self._scatter_texture(interp_data["uv"], cos, valid_pixels),
538
- )
539
-
540
- def _scatter_texture(self, uv, data, mask):
541
- def __filter_data(data, mask):
542
- return data.view(-1, data.shape[-1])[mask]
543
-
544
- return _bilinear_interpolation_scattering(
545
- self.texture_wh[1],
546
- self.texture_wh[0],
547
- __filter_data(uv, mask)[..., [1, 0]],
548
- __filter_data(data, mask),
549
- )
550
-
551
- @torch.no_grad()
552
- def fast_bake_texture(
553
- self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
554
- ) -> tuple[torch.Tensor, torch.Tensor]:
555
- channel = textures[0].shape[-1]
556
- texture_merge = torch.zeros(self.texture_wh + (channel,)).to(
557
- self.device
558
- )
559
- trust_map_merge = torch.zeros(self.texture_wh + (1,)).to(self.device)
560
- for texture, cos_map in zip(textures, confidence_maps):
561
- view_sum = (cos_map > 0).sum()
562
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
563
- if painted_sum / view_sum > 0.99:
564
- continue
565
- texture_merge += texture * cos_map
566
- trust_map_merge += cos_map
567
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
568
-
569
- return texture_merge, trust_map_merge > 1e-8
570
-
571
- def uv_inpaint(
572
- self, texture: torch.Tensor, mask: torch.Tensor
573
- ) -> np.ndarray:
574
- texture_np = texture.cpu().numpy()
575
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
576
- vertices, faces, uv_map = self.get_mesh_attrs()
577
- # import pdb; pdb.set_trace()
578
- texture_np, mask_np = _texture_inpaint_smooth(
579
- texture_np, mask_np, vertices, faces, uv_map
580
- )
581
- texture_np = texture_np.clip(0, 1)
582
- texture_np = cv2.inpaint(
583
- (texture_np * 255).astype(np.uint8),
584
- 255 - mask_np,
585
- 3,
586
- cv2.INPAINT_NS,
587
- )
588
-
589
- return texture_np
590
-
591
- def __call__(
592
- self, colors: list[Image.Image], input_mesh: str, output_path: str
593
- ) -> trimesh.Trimesh:
594
- self.load_mesh(input_mesh)
595
-
596
- textures, weighted_cos_maps = [], []
597
- for color, cam_elev, cam_azim, weight in zip(
598
- colors, self.camera_elevs, self.camera_azims, self.view_weights
599
- ):
600
- texture, cos_map = self.back_project(color, cam_elev, cam_azim)
601
- cv2.imwrite(
602
- f"v2_texture{self.cnt}.png",
603
- (texture.cpu().numpy() * 255).astype(np.uint8),
604
- )
605
- cv2.imwrite(
606
- f"v2_texture_cos{self.cnt}.png",
607
- (cos_map.cpu().numpy() * 255).astype(np.uint8),
608
- )
609
- # import pdb; pdb.set_trace()
610
- textures.append(texture)
611
- weighted_cos_maps.append(weight * (cos_map**4))
612
-
613
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
614
- texture_np = self.uv_inpaint(texture, mask)
615
- texture_np = post_process_texture(texture_np)
616
- vertices, faces, uvs = self.get_mesh_attrs(self.scale, self.center)
617
- # import pdb; pdb.set_trace()
618
- cv2.imwrite("v2_texture_np.png", texture_np)
619
-
620
- textured_mesh = save_mesh_with_mtl(
621
- vertices, faces, uvs, texture_np, output_path
622
- )
623
-
624
- return textured_mesh
625
-
626
-
627
- class Image_Super_Net:
628
- def __init__(self, device="cuda"):
629
- from diffusers import StableDiffusionUpscalePipeline
630
-
631
- self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
632
- "stabilityai/stable-diffusion-x4-upscaler",
633
- torch_dtype=torch.float16,
634
- ).to(device)
635
- self.up_pipeline_x4.set_progress_bar_config(disable=True)
636
-
637
- def __call__(self, image, prompt=""):
638
- with torch.no_grad():
639
- upscaled_image = self.up_pipeline_x4(
640
- prompt=[prompt],
641
- image=image,
642
- num_inference_steps=10,
643
- ).images[0]
644
-
645
- return upscaled_image
646
-
647
-
648
- class Image_GANNet:
649
- def __init__(self, outscale: int):
650
- from basicsr.archs.rrdbnet_arch import RRDBNet
651
- from realesrgan import RealESRGANer
652
-
653
- self.outscale = outscale
654
- model = RRDBNet(
655
- num_in_ch=3,
656
- num_out_ch=3,
657
- num_feat=64,
658
- num_block=23,
659
- num_grow_ch=32,
660
- scale=4,
661
- )
662
- self.upsampler = RealESRGANer(
663
- scale=4,
664
- model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth", # noqa
665
- model=model,
666
- pre_pad=0,
667
- half=True,
668
- )
669
-
670
- def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
671
- if isinstance(image, Image.Image):
672
- image = np.array(image)
673
- output, _ = self.upsampler.enhance(image, outscale=self.outscale)
674
-
675
- return Image.fromarray(output)
676
-
677
-
678
- if __name__ == "__main__":
679
- device = "cuda"
680
- color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
681
- mesh_path = "outputs/texture_mesh_gen/texture_mesh/kettle_color.glb"
682
- output_path = "robot_test_v2/robot.obj"
683
- target_image_size = (2048, 2048)
684
-
685
- super_model = Image_GANNet(outscale=4)
686
- multiviews = get_images_from_file(color_path, img_size=512)
687
-
688
- texture_backer = TextureBacker(
689
- camera_elevs=[20, 20, 20, -10, -10, -10],
690
- camera_azims=[-180, -60, 60, -120, 0, 120],
691
- view_weights=[1, 0.2, 0.2, 0.2, 1, 0.2],
692
- camera_distance=5,
693
- camera_fov=30,
694
- render_wh=(2048, 2048),
695
- texture_wh=(2048, 2048),
696
- )
697
-
698
- multiviews = [super_model(img) for img in multiviews]
699
- multiviews = [img.convert("RGB") for img in multiviews]
700
- textured_mesh = texture_backer(multiviews, mesh_path, output_path)
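The deleted file above builds its cameras from get_perspective_projection and compute_w2c_matrix. As a quick orientation, the sketch below shows how a world-space point would be mapped to normalized device coordinates with those two 4x4 matrices; the wrapper function and the sample values are illustrative and assume the helpers above are importable.

import numpy as np


def project_point(point_xyz, w2c, proj):
    """Map a world-space point (3,) to NDC using 4x4 world-to-camera and projection matrices."""
    p = np.append(np.asarray(point_xyz, dtype=np.float32), 1.0)  # homogeneous coordinates
    cam = w2c @ p      # world space -> camera space
    clip = proj @ cam  # camera space -> clip space
    return clip[:3] / clip[3]  # perspective divide -> NDC in [-1, 1]


# Example with hypothetical values, assuming the helpers defined above:
# w2c = compute_w2c_matrix(elev_deg=20.0, azim_deg=0.0, cam_dist=5.0)
# proj = get_perspective_projection(fov=30.0, aspect_wh=1.0)
# ndc = project_point([0.0, 0.0, 0.0], w2c, proj)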
asset3d_gen/data/backup/backproject_v3.py DELETED
@@ -1,866 +0,0 @@
1
- import logging
2
- import math
3
- from typing import Union
4
-
5
- import custom_rasterizer as cr
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- import trimesh
11
- import xatlas
12
- from PIL import Image
13
- from asset3d_gen.data.utils import (
14
- get_images_from_file,
15
- normalize_vertices_array,
16
- post_process_texture,
17
- save_mesh_with_mtl,
18
- )
19
-
20
- logging.basicConfig(
21
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
22
- )
23
- logger = logging.getLogger(__name__)
24
-
25
-
26
- __all__ = ["TextureBacker", "Image_Super_Net", "Image_GANNet"]
27
-
28
-
29
- import math
30
- import numpy as np
31
-
32
-
33
- def get_perspective_projection(
34
- fov: float, aspect_wh: float, near: float = 0.01, far: float = 100
35
- ) -> np.ndarray:
36
- """Compute the perspective projection matrix for 3D rendering."""
37
- fov_rad = math.radians(fov)
38
- tan_half_fov = math.tan(fov_rad / 2.0)
39
-
40
- return np.array(
41
- [
42
- [1.0 / (tan_half_fov * aspect_wh), 0.0, 0.0, 0.0],
43
- [0.0, 1.0 / tan_half_fov, 0.0, 0.0],
44
- [
45
- 0.0,
46
- 0.0,
47
- -(far + near) / (far - near),
48
- -(2.0 * far * near) / (far - near),
49
- ],
50
- [0.0, 0.0, -1.0, 0.0],
51
- ],
52
- dtype=np.float32,
53
- )
54
-
55
-
56
- def transform_vertices(
57
- mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
58
- ) -> torch.Tensor:
59
- """Transform 3D vertices using a projection matrix."""
60
- t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
61
- if pos.size(-1) == 3:
62
- pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
63
-
64
- result = pos @ t_mtx.T
65
-
66
- return result if keepdim else result.unsqueeze(0)
67
-
68
-
69
- def compute_w2c_matrix(
70
- elev_deg: float, azim_deg: float, cam_dist: float
71
- ) -> np.ndarray:
72
- """Compute w2c 4x4 transformation matrix from spherical coordinates."""
73
-
74
- elev_rad = math.radians(-elev_deg)
75
- azim_rad = math.radians(azim_deg)
76
-
77
- sin_elev = math.sin(elev_rad)
78
- cos_elev = math.cos(elev_rad)
79
- sin_azim = math.sin(azim_rad)
80
- cos_azim = math.cos(azim_rad)
81
-
82
- cam_pos = np.array(
83
- [
84
- cam_dist * cos_elev * cos_azim,
85
- cam_dist * cos_elev * sin_azim,
86
- cam_dist * sin_elev,
87
- ]
88
- )
89
-
90
- look_dir = -cam_pos / np.linalg.norm(cam_pos)
91
- right_dir = np.cross(look_dir, [0, 0, 1])
92
- right_dir /= np.linalg.norm(right_dir)
93
- up_dir = np.cross(right_dir, look_dir)
94
-
95
- c2w = np.eye(4)
96
- c2w[:3, 0] = right_dir
97
- c2w[:3, 1] = up_dir
98
- c2w[:3, 2] = -look_dir
99
- c2w[:3, 3] = cam_pos
100
-
101
- try:
102
- w2c = np.linalg.inv(c2w)
103
- except np.linalg.LinAlgError as e:
104
- raise ArithmeticError("Failed to invert camera-to-world matrix") from e
105
-
106
- return w2c.astype(np.float32)
107
-
108
-
109
- def _bilinear_interpolation_scattering(
110
- image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
111
- ) -> torch.Tensor:
112
- """Bilinear interpolation scattering for grid-based value accumulation."""
113
- device = values.device
114
- dtype = values.dtype
115
- C = values.shape[-1]
116
-
117
- indices = coords * torch.tensor(
118
- [image_h - 1, image_w - 1], dtype=dtype, device=device
119
- )
120
- i, j = indices.unbind(-1)
121
-
122
- i0, j0 = (
123
- indices.floor()
124
- .long()
125
- .clamp(0, image_h - 2)
126
- .clamp(0, image_w - 2)
127
- .unbind(-1)
128
- )
129
- i1, j1 = i0 + 1, j0 + 1
130
-
131
- w_i = i - i0.float()
132
- w_j = j - j0.float()
133
- weights = torch.stack(
134
- [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
135
- dim=1,
136
- )
137
-
138
- indices_comb = torch.stack(
139
- [
140
- torch.stack([i0, j0], dim=1),
141
- torch.stack([i0, j1], dim=1),
142
- torch.stack([i1, j0], dim=1),
143
- torch.stack([i1, j1], dim=1),
144
- ],
145
- dim=1,
146
- )
147
-
148
- grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
149
- cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
150
-
151
- for k in range(4):
152
- idx = indices_comb[:, k]
153
- w = weights[:, k].unsqueeze(-1)
154
-
155
- stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
156
- flat_idx = (idx * stride).sum(-1)
157
-
158
- grid.view(-1, C).scatter_add_(
159
- 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
160
- )
161
- cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
162
-
163
- mask = cnt.squeeze(-1) > 0
164
- grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
165
-
166
- return grid
167
-
168
-
169
- def _texture_inpaint_smooth(
170
- texture: np.ndarray,
171
- mask: np.ndarray,
172
- vertices: np.ndarray,
173
- faces: np.ndarray,
174
- uv_map: np.ndarray,
175
- ) -> tuple[np.ndarray, np.ndarray]:
176
- """Perform texture inpainting using vertex-based color propagation."""
177
- image_h, image_w, C = texture.shape
178
- N = vertices.shape[0]
179
-
180
- # Initialize vertex data structures
181
- vtx_mask = np.zeros(N, dtype=np.float32)
182
- vtx_colors = np.zeros((N, C), dtype=np.float32)
183
- unprocessed = []
184
- adjacency = [[] for _ in range(N)]
185
-
186
- # Build adjacency graph and initial color assignment
187
- for face_idx in range(faces.shape[0]):
188
- for k in range(3):
189
- uv_idx_k = faces[face_idx, k]
190
- v_idx = faces[face_idx, k]
191
-
192
- # Convert UV to pixel coordinates with boundary clamping
193
- u = np.clip(
194
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
195
- )
196
- v = np.clip(
197
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
198
- 0,
199
- image_h - 1,
200
- )
201
-
202
- if mask[v, u]:
203
- vtx_mask[v_idx] = 1.0
204
- vtx_colors[v_idx] = texture[v, u]
205
- elif v_idx not in unprocessed:
206
- unprocessed.append(v_idx)
207
-
208
- # Build undirected adjacency graph
209
- neighbor = faces[face_idx, (k + 1) % 3]
210
- if neighbor not in adjacency[v_idx]:
211
- adjacency[v_idx].append(neighbor)
212
- if v_idx not in adjacency[neighbor]:
213
- adjacency[neighbor].append(v_idx)
214
-
215
- # Color propagation with dynamic stopping
216
- remaining_iters, prev_count = 2, 0
217
- while remaining_iters > 0:
218
- current_unprocessed = []
219
-
220
- for v_idx in unprocessed:
221
- valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
222
- if not valid_neighbors:
223
- current_unprocessed.append(v_idx)
224
- continue
225
-
226
- # Calculate inverse square distance weights
227
- neighbors_pos = vertices[valid_neighbors]
228
- dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
229
- weights = 1 / np.maximum(dist_sq, 1e-8)
230
-
231
- vtx_colors[v_idx] = np.average(
232
- vtx_colors[valid_neighbors], weights=weights, axis=0
233
- )
234
- vtx_mask[v_idx] = 1.0
235
-
236
- # Update iteration control
237
- if len(current_unprocessed) == prev_count:
238
- remaining_iters -= 1
239
- else:
240
- remaining_iters = min(remaining_iters + 1, 2)
241
- prev_count = len(current_unprocessed)
242
- unprocessed = current_unprocessed
243
-
244
- # Generate output texture
245
- inpainted_texture, updated_mask = texture.copy(), mask.copy()
246
- for face_idx in range(faces.shape[0]):
247
- for k in range(3):
248
- v_idx = faces[face_idx, k]
249
- if not vtx_mask[v_idx]:
250
- continue
251
-
252
- # UV coordinate conversion
253
- uv_idx_k = faces[face_idx, k]
254
- u = np.clip(
255
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
256
- )
257
- v = np.clip(
258
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
259
- 0,
260
- image_h - 1,
261
- )
262
-
263
- inpainted_texture[v, u] = vtx_colors[v_idx]
264
- updated_mask[v, u] = 255
265
-
266
- return inpainted_texture, updated_mask
267
-
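- # In `uv_inpaint` below, this vertex-propagation pass fills texels at vertex UV
- # locations first; the texels that remain unfilled are handled by cv2.inpaint.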
268
-
269
- class TextureBacker:
270
- """Texture baking pipeline for multi-view projection and fusion."""
271
-
272
- def __init__(
273
- self,
274
- camera_elevs: list[float],
275
- camera_azims: list[float],
276
- camera_distance: int,
277
- camera_fov: float,
278
- view_weights: list[float] = None,
279
- render_wh: tuple[int, int] = (2048, 2048),
280
- texture_wh: tuple[int, int] = (2048, 2048),
281
- use_antialias: bool = True,
282
- bake_angle_thresh: int = 75,
283
- device="cuda",
284
- ):
285
- self.camera_elevs = camera_elevs
286
- self.camera_azims = camera_azims
287
- self.view_weights = (
288
- view_weights
289
- if view_weights is not None
290
- else [1] * len(camera_elevs)
291
- )
292
- self.device = device
293
- self.render_wh = render_wh
294
- self.texture_wh = texture_wh
295
-
296
- self.camera_distance = camera_distance
297
- self.use_antialias = use_antialias
298
-
299
- self.bake_angle_thresh = bake_angle_thresh
300
- self.bake_unreliable_kernel_size = int(
301
- (2 / 512) * max(self.render_wh[0], self.render_wh[1])
302
- )
303
-
304
- self.camera_proj_mat = get_perspective_projection(
305
- camera_fov,
306
- self.render_wh[1] / self.render_wh[0],
307
- )
308
- self.cnt = 0
309
-
310
- def rasterize_mesh(
311
- self,
312
- vertex: torch.Tensor,
313
- face: torch.Tensor,
314
- resolution: tuple[int, int],
315
- ) -> torch.Tensor:
316
- vertex = vertex[None] if vertex.ndim == 2 else vertex
317
- indices, weights = cr.rasterize(vertex, face, resolution)
318
-
319
- return torch.cat(
320
- [weights, indices.unsqueeze(-1).to(weights.dtype)], dim=-1
321
- ).unsqueeze(0)
322
-
323
- def raster_interpolate(
324
- self, uv: torch.Tensor, rast_out: torch.Tensor, faces: torch.Tensor
325
- ) -> torch.Tensor:
326
- barycentric = rast_out[0, ..., :-1]
327
- findices = rast_out[0, ..., -1]
328
- if uv.dim() == 2:
329
- uv = uv.unsqueeze(0)
330
-
331
- return cr.interpolate(uv, findices, barycentric, faces)[0]
332
-
333
- def load_mesh(self, mesh_path: str) -> None:
334
- mesh = trimesh.load(mesh_path)
335
- if isinstance(mesh, trimesh.Scene):
336
- mesh = mesh.dump(concatenate=True)
337
-
338
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
339
- self.scale, self.center = scale, center
340
-
341
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
342
- mesh.vertices = mesh.vertices[vmapping]
343
- mesh.faces = indices
344
- mesh.visual.uv = uvs
345
-
346
- self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
347
- self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
348
- self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
349
-
350
- # Transformation of coordinate system
351
- self.vertices[:, [0, 1]] = -self.vertices[:, [0, 1]]
352
- self.vertices[:, [1, 2]] = self.vertices[:, [2, 1]]
353
- self.uv_map[:, 1] = 1 - self.uv_map[:, 1]
354
-
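- # NOTE: `get_mesh_attrs` below returns vertices still in this flipped rendering
- # frame; only scale and center are undone, not the axis swaps above.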
355
- def get_mesh_attrs(
356
- self,
357
- scale: float = None,
358
- center: np.ndarray = None,
359
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
360
- vertices = self.vertices.cpu().numpy()
361
- faces = self.faces.cpu().numpy()
362
- uv_map = self.uv_map.cpu().numpy()
363
-
364
- if scale is not None:
365
- vertices = vertices / scale
366
- if center is not None:
367
- vertices = vertices + center
368
-
369
- return vertices, faces, uv_map
370
-
371
- def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
372
- depth_image_np = depth_image.cpu().numpy()
373
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
374
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
375
- sketch_image = (
376
- torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
377
- )
378
- sketch_image = sketch_image.unsqueeze(-1)
379
-
380
- return sketch_image
381
-
382
- def back_project(
383
- self, image: Image.Image, elev: float, azim: float
384
- ) -> tuple[torch.Tensor, torch.Tensor]:
385
- if isinstance(image, Image.Image):
386
- image = np.array(image)
387
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
388
- if image.ndim == 2:
389
- image = image.unsqueeze(-1)
390
- image = image / 255.0
391
-
392
- view_mat = compute_w2c_matrix(elev, azim, self.camera_distance)
393
- pos_cam = transform_vertices(view_mat, self.vertices, keepdim=True)
394
- pos_clip = transform_vertices(self.camera_proj_mat, pos_cam)
395
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
396
-
397
- v0, v1, v2 = (pos_cam[self.faces[:, i]] for i in range(3))
398
- face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
399
- vertex_norm = (
400
- torch.from_numpy(
401
- trimesh.geometry.mean_vertex_normals(
402
- len(pos_cam), self.faces.cpu(), face_norm.cpu()
403
- )
404
- )
405
- .to(self.device)
406
- .contiguous()
407
- )
408
-
409
- rast_out = self.rasterize_mesh(pos_clip, self.faces, image.shape[:2])
410
- vis_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0]
411
-
412
- interp_data = {
413
- "normal": self.raster_interpolate(
414
- vertex_norm[None], rast_out, self.faces
415
- ),
416
- "uv": self.raster_interpolate(
417
- self.uv_map[None], rast_out, self.faces
418
- ),
419
- "depth": self.raster_interpolate(
420
- pos_cam[:, 2].reshape(1, -1, 1), rast_out, self.faces
421
- ),
422
- }
423
-
424
- valid_depth = interp_data["depth"][vis_mask > 0]
425
- depth_norm = (interp_data["depth"] - valid_depth.min()) / (
426
- valid_depth.max() - valid_depth.min()
427
- )
428
- depth_norm[vis_mask <= 0] = 0
429
- sketch_image = self._render_depth_edges(depth_norm * vis_mask)
430
-
431
- # cv2.imwrite("vis_mask.png", (vis_mask.cpu().numpy() * 255).astype(np.uint8))
432
- # cv2.imwrite("normal.png", (interp_data['normal'].cpu().numpy() * 255).astype(np.uint8))
433
- # cv2.imwrite("depth.png", (depth_norm.cpu().numpy() * 255).astype(np.uint8))
434
- # cv2.imwrite("uv.png", (interp_data['uv'][..., 0].cpu().numpy() * 255).astype(np.uint8))
435
- # import pdb; pdb.set_trace()
436
-
437
- cos = F.cosine_similarity(
438
- torch.tensor([[0, 0, -1]], device=self.device),
439
- interp_data["normal"].view(-1, 3),
440
- ).view_as(interp_data["normal"][..., :1])
441
- cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
442
-
443
- k = self.bake_unreliable_kernel_size * 2 + 1
444
- kernel = torch.ones((1, 1, k, k), device=self.device)
445
-
446
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
447
- vis_mask = F.conv2d(
448
- 1.0 - vis_mask,
449
- kernel,
450
- padding=k // 2,
451
- )
452
- vis_mask = 1.0 - (vis_mask > 0).float()
453
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
454
-
455
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
456
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
457
- sketch_image = (sketch_image > 0).float()
458
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
459
- vis_mask = vis_mask * (sketch_image < 0.5)
460
-
461
- cos[vis_mask == 0] = 0
462
- valid_pixels = (vis_mask != 0).view(-1)
463
-
464
- return (
465
- self._scatter_texture(interp_data["uv"], image, valid_pixels),
466
- self._scatter_texture(interp_data["uv"], cos, valid_pixels),
467
- )
468
-
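- # back_project returns a UV-space color texture plus a cos-confidence map;
- # texels viewed beyond `bake_angle_thresh` or near depth discontinuities
- # (dilated Canny edges of the depth render) are zeroed before scattering.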
469
- def back_project2(
470
- self, image, vis_mask, depth, normal, uv
471
- ) -> tuple[torch.Tensor, torch.Tensor]:
472
- if isinstance(image, Image.Image):
473
- image = np.array(image)
474
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
475
- if image.ndim == 2:
476
- image = image.unsqueeze(-1)
477
- image = image / 255.0
478
-
479
- depth_inv = (1.0 - depth) * vis_mask
480
- sketch_image = self._render_depth_edges(depth_inv)
481
-
482
- cv2.imwrite(
483
- f"v3_depth_inv{self.cnt}.png",
484
- (depth_inv.cpu().numpy() * 255).astype(np.uint8),
485
- )
486
-
487
- cos = F.cosine_similarity(
488
- torch.tensor([[0, 0, 1]], device=self.device),
489
- normal.view(-1, 3),
490
- ).view_as(normal[..., :1])
491
- cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
492
- # import pdb; pdb.set_trace()
493
- # cv2.imwrite(f"v3_cos{self.cnt}.png", (cos.cpu().numpy() * 255).astype(np.uint8))
494
- # cv2.imwrite(f"v3_sketch{self.cnt}.png", (sketch_image.cpu().numpy() * 255).astype(np.uint8))
495
-
496
- # cos2 = cv2.imread(f"v2_cos{self.cnt+1}.png", cv2.IMREAD_GRAYSCALE)
497
- # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
498
- # cos = cos2
499
-
500
- self.cnt += 1
501
-
502
- k = self.bake_unreliable_kernel_size * 2 + 1
503
- kernel = torch.ones((1, 1, k, k), device=self.device)
504
-
505
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
506
- vis_mask = F.conv2d(
507
- 1.0 - vis_mask,
508
- kernel,
509
- padding=k // 2,
510
- )
511
- vis_mask = 1.0 - (vis_mask > 0).float()
512
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
513
-
514
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
515
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
516
- sketch_image = (sketch_image > 0).float()
517
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
518
- vis_mask = vis_mask * (sketch_image < 0.5)
519
- # import pdb; pdb.set_trace()
520
- cv2.imwrite(
521
- f"v3_db_sketch{self.cnt}.png",
522
- (sketch_image.cpu().numpy() * 255).astype(np.uint8),
523
- )
524
-
525
- cos[vis_mask == 0] = 0
526
- # import pdb; pdb.set_trace()
527
- # vis_mask = cv2.imread(f"v2_db_mask{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
528
- # vis_mask = torch.from_numpy(vis_mask[..., None]).to(self.device).float() / 255
529
- # cos2 = cv2.imread(f"v2_db_cos{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
530
- # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
531
- # cos = cos2
532
-
533
- valid_pixels = (vis_mask != 0).view(-1)
534
- # import pdb; pdb.set_trace()
535
- cv2.imwrite(
536
- f"v3_db_uv{self.cnt}.png",
537
- (uv[..., 0].cpu().numpy() * 255).astype(np.uint8),
538
- )
539
- cv2.imwrite(
540
- f"v3_db_uv2{self.cnt}.png",
541
- (uv[..., 1].cpu().numpy() * 255).astype(np.uint8),
542
- )
543
- cv2.imwrite(
544
- f"v3_db_color{self.cnt}.png",
545
- (image.cpu().numpy() * 255).astype(np.uint8),
546
- )
547
- cv2.imwrite(
548
- f"v3_db_cos{self.cnt}.png",
549
- (cos.cpu().numpy() * 255).astype(np.uint8),
550
- )
551
- cv2.imwrite(
552
- f"v3_db_mask{self.cnt}.png",
553
- (vis_mask.cpu().numpy() * 255).astype(np.uint8),
554
- )
555
-
556
- return (
557
- self._scatter_texture(uv, image, valid_pixels),
558
- self._scatter_texture(uv, cos, valid_pixels),
559
- )
560
-
561
- def _scatter_texture(self, uv, data, mask):
562
- def __filter_data(data, mask):
563
- return data.view(-1, data.shape[-1])[mask]
564
-
565
- return _bilinear_interpolation_scattering(
566
- self.texture_wh[1],
567
- self.texture_wh[0],
568
- __filter_data(uv, mask)[..., [1, 0]],
569
- __filter_data(data, mask),
570
- )
571
-
572
- @torch.no_grad()
573
- def fast_bake_texture(
574
- self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
575
- ) -> tuple[torch.Tensor, torch.Tensor]:
576
- channel = textures[0].shape[-1]
577
- texture_merge = torch.zeros(self.texture_wh + (channel,)).to(
578
- self.device
579
- )
580
- trust_map_merge = torch.zeros(self.texture_wh + (1,)).to(self.device)
581
- for texture, cos_map in zip(textures, confidence_maps):
582
- view_sum = (cos_map > 0).sum()
583
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
584
- if painted_sum / view_sum > 0.99:
585
- continue
586
- texture_merge += texture * cos_map
587
- trust_map_merge += cos_map
588
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
589
-
590
- return texture_merge, trust_map_merge > 1e-8
591
-
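- # The confidence maps passed in are already `view_weight * cos**4` (see
- # __call__ / forward); a view adding less than ~1% new coverage is skipped,
- # and the returned mask marks texels with any accumulated confidence.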
592
- def uv_inpaint(
593
- self, texture: torch.Tensor, mask: torch.Tensor
594
- ) -> np.ndarray:
595
- texture_np = texture.cpu().numpy()
596
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
597
- vertices, faces, uv_map = self.get_mesh_attrs()
598
- # import pdb; pdb.set_trace()
599
- texture_np, mask_np = _texture_inpaint_smooth(
600
- texture_np, mask_np, vertices, faces, uv_map
601
- )
602
- texture_np = texture_np.clip(0, 1)
603
- texture_np = cv2.inpaint(
604
- (texture_np * 255).astype(np.uint8),
605
- 255 - mask_np,
606
- 3,
607
- cv2.INPAINT_NS,
608
- )
609
-
610
- return texture_np
611
-
612
- def __call__(
613
- self, colors: list[Image.Image], input_mesh: str, output_path: str
614
- ) -> trimesh.Trimesh:
615
- self.load_mesh(input_mesh)
616
-
617
- textures, weighted_cos_maps = [], []
618
- for color, cam_elev, cam_azim, weight in zip(
619
- colors, self.camera_elevs, self.camera_azims, self.view_weights
620
- ):
621
- texture, cos_map = self.back_project(color, cam_elev, cam_azim)
622
- textures.append(texture)
623
- weighted_cos_maps.append(weight * (cos_map**4))
624
-
625
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
626
- texture_np = self.uv_inpaint(texture, mask)
627
- texture_np = post_process_texture(texture_np)
628
- vertices, faces, uv_map = self.get_mesh_attrs(self.scale, self.center)
629
- # import pdb; pdb.set_trace()
630
- textured_mesh = save_mesh_with_mtl(
631
- vertices, faces, uv_map, texture_np, output_path
632
- )
633
-
634
- return textured_mesh
635
-
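- # Usage sketch (illustrative paths and values):
- #   backer = TextureBacker(
- #       camera_elevs=[20, -10], camera_azims=[0, 180],
- #       camera_distance=5, camera_fov=30, view_weights=[1, 1],
- #   )
- #   textured = backer(multiview_images, "mesh.glb", "out/textured.obj")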
636
- def forward(
637
- self,
638
- colors: list[Image.Image],
639
- masks,
640
- depths,
641
- normals,
642
- uvs,
643
- ) -> trimesh.Trimesh:
644
- textures, weighted_cos_maps = [], []
645
- for color, mask, depth, normal, uv, weight in zip(
646
- colors, masks, depths, normals, uvs, self.view_weights
647
- ):
648
- texture, cos_map = self.back_project2(
649
- color, mask, depth, normal, uv
650
- )
651
- cv2.imwrite(
652
- f"v3_texture{self.cnt}.png",
653
- (texture.cpu().numpy() * 255).astype(np.uint8),
654
- )
655
- cv2.imwrite(
656
- f"v3_texture_cos{self.cnt}.png",
657
- (cos_map.cpu().numpy() * 255).astype(np.uint8),
658
- )
659
-
660
- textures.append(texture)
661
- weighted_cos_maps.append(weight * (cos_map**4))
662
-
663
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
664
- texture_np = self.uv_inpaint(texture, mask)
665
- texture_np = post_process_texture(texture_np)
666
- vertices, faces, uv_map = self.get_mesh_attrs(self.scale, self.center)
667
- # import pdb; pdb.set_trace()
668
- cv2.imwrite("v3_texture_np.png", texture_np)
669
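- # NOTE: `output_path` is not a parameter of forward(); it refers to the
- # module-level variable defined in the __main__ block below.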
- textured_mesh = save_mesh_with_mtl(
670
- vertices, faces, uv_map, texture_np, output_path
671
- )
672
-
673
- return textured_mesh
674
-
675
-
676
- class Image_Super_Net:
677
- def __init__(self, device="cuda"):
678
- from diffusers import StableDiffusionUpscalePipeline
679
-
680
- self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
681
- "stabilityai/stable-diffusion-x4-upscaler",
682
- torch_dtype=torch.float16,
683
- ).to(device)
684
- self.up_pipeline_x4.set_progress_bar_config(disable=True)
685
-
686
- def __call__(self, image, prompt=""):
687
- with torch.no_grad():
688
- upscaled_image = self.up_pipeline_x4(
689
- prompt=[prompt],
690
- image=image,
691
- num_inference_steps=10,
692
- ).images[0]
693
-
694
- return upscaled_image
695
-
696
-
697
- class Image_GANNet:
698
- def __init__(self, outscale: int):
699
- from basicsr.archs.rrdbnet_arch import RRDBNet
700
- from realesrgan import RealESRGANer
701
-
702
- self.outscale = outscale
703
- model = RRDBNet(
704
- num_in_ch=3,
705
- num_out_ch=3,
706
- num_feat=64,
707
- num_block=23,
708
- num_grow_ch=32,
709
- scale=4,
710
- )
711
- self.upsampler = RealESRGANer(
712
- scale=4,
713
- model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth", # noqa
714
- model=model,
715
- pre_pad=0,
716
- half=True,
717
- )
718
-
719
- def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
720
- if isinstance(image, Image.Image):
721
- image = np.array(image)
722
- output, _ = self.upsampler.enhance(image, outscale=self.outscale)
723
-
724
- return Image.fromarray(output)
725
-
726
-
727
- if __name__ == "__main__":
728
- device = "cuda"
729
- color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
730
- mesh_path = "outputs/texture_mesh_gen/texture_mesh/kettle_color.glb"
731
- output_path = "robot_test_v6/robot.obj"
732
- target_image_size = (2048, 2048)
733
-
734
- super_model = Image_GANNet(outscale=4)
735
- multiviews = get_images_from_file(color_path, img_size=512)
736
- multiviews = [super_model(img) for img in multiviews]
737
- multiviews = [img.convert("RGB") for img in multiviews]
738
-
739
- from asset3d_gen.data.utils import (
740
- CameraSetting,
741
- init_kal_camera,
742
- DiffrastRender,
743
- )
744
- import nvdiffrast.torch as dr
745
-
746
- camera_params = CameraSetting(
747
- num_images=6,
748
- elevation=[20.0, -10.0],
749
- distance=5,
750
- resolution_hw=(2048, 2048),
751
- fov=math.radians(30),
752
- device="cuda",
753
- )
754
- camera = init_kal_camera(camera_params)
755
- mv = camera.view_matrix() # (n 4 4) world2cam
756
- p = camera.intrinsics.projection_matrix()
757
- # NOTE: negate P[1, 1] because the y axis is flipped in `nvdiffrast` output. # noqa
758
- p[:, 1, 1] = -p[:, 1, 1]
759
- renderer = DiffrastRender(
760
- p_matrix=p,
761
- mv_matrix=mv,
762
- resolution_hw=camera_params.resolution_hw,
763
- context=dr.RasterizeCudaContext(),
764
- mask_thresh=0.5,
765
- grad_db=False,
766
- device=camera_params.device,
767
- antialias_mask=True,
768
- )
769
-
770
- mesh = trimesh.load(mesh_path)
771
- if isinstance(mesh, trimesh.Scene):
772
- mesh = mesh.dump(concatenate=True)
773
-
774
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
775
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
776
- uvs[:, 1] = 1 - uvs[:, 1]
777
- mesh.vertices = mesh.vertices[vmapping]
778
- mesh.faces = indices
779
- mesh.visual.uv = uvs
780
-
781
- vertices = torch.from_numpy(mesh.vertices).to(camera_params.device).float()
782
- faces = (
783
- torch.from_numpy(mesh.faces).to(camera_params.device).to(torch.int64)
784
- )
785
- uvs = torch.from_numpy(mesh.visual.uv).to(camera_params.device).float()
786
-
787
- rendered_view_normals = []
788
- rast, vertices_clip = renderer.compute_dr_raster(vertices, faces)
789
- for idx in range(len(mv)):
790
- pos_cam = transform_vertices(mv[idx], vertices, keepdim=True)
791
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
792
- v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
793
- face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
794
- vertex_norm = (
795
- torch.from_numpy(
796
- trimesh.geometry.mean_vertex_normals(
797
- len(pos_cam), faces.cpu(), face_norm.cpu()
798
- )
799
- )
800
- .to(camera_params.device)
801
- .contiguous()
802
- )
803
- im_base_normals, _ = dr.interpolate(
804
- vertex_norm[None, ...].float(),
805
- rast[idx : idx + 1],
806
- faces.to(torch.int32),
807
- )
808
- rendered_view_normals.append(im_base_normals)
809
-
810
- rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
811
-
812
- rendered_depth, masks = renderer.render_depth(vertices, faces)
813
- norm_depths = []
814
- for idx in range(len(rendered_depth)):
815
- norm_depth = renderer.normalize_map_by_mask(
816
- rendered_depth[idx : idx + 1], masks[idx : idx + 1]
817
- )
818
- norm_depths.append(norm_depth)
819
- norm_depths = torch.cat(norm_depths, dim=0)
820
- render_uvs, _ = renderer.render_uv(vertices, faces, uvs)
821
-
822
- for index in range(6):
823
- cv2.imwrite(
824
- f"v3_mask{index}.png",
825
- (masks[index] * 255).cpu().numpy().astype(np.uint8),
826
- )
827
- cv2.imwrite(
828
- f"v3_normalv2{index}.png",
829
- (rendered_view_normals[index] * 255)
830
- .cpu()
831
- .numpy()
832
- .astype(np.uint8)[..., ::-1],
833
- )
834
- cv2.imwrite(
835
- f"v3_depth{index}.png",
836
- (norm_depths[index] * 255).cpu().numpy().astype(np.uint8),
837
- )
838
- cv2.imwrite(
839
- f"v3_uv{index}.png",
840
- (render_uvs[index, ..., 0] * 255).cpu().numpy().astype(np.uint8),
841
- )
842
- multiviews[index].save(f"v3_color{index}.png")
843
-
844
- texture_backer = TextureBacker(
845
- camera_elevs=[20, 20, 20, -10, -10, -10],
846
- camera_azims=[-180, -60, 60, -120, 0, 120],
847
- view_weights=[1, 0.2, 0.2, 0.2, 1, 0.2],
848
- camera_distance=5,
849
- camera_fov=30,
850
- render_wh=(2048, 2048),
851
- texture_wh=(2048, 2048),
852
- )
853
- texture_backer.vertices = vertices
854
- texture_backer.faces = faces
855
- uvs[:, 1] = 1.0 - uvs[:, 1]
856
- texture_backer.uv_map = uvs
857
- texture_backer.center = center
858
- texture_backer.scale = scale
859
-
860
- textured_mesh = texture_backer.forward(
861
- multiviews, masks, norm_depths, rendered_view_normals, render_uvs
862
- )
863
-
864
- # multiviews = [super_model(img) for img in multiviews]
865
- # multiviews = [img.convert("RGB") for img in multiviews]
866
- # textured_mesh = texture_backer(multiviews, mesh_path, output_path)
asset3d_gen/data/backup/backprojectv2.py DELETED
@@ -1,835 +0,0 @@
1
- from PIL import Image
2
- import torch
3
- import torch.nn.functional as F
4
- import numpy as np
5
- import math
6
- import trimesh
7
- import cv2
8
- import xatlas
9
- from typing import Union
10
-
11
-
12
- def get_perspective_projection_matrix(fovy, aspect_wh, near, far):
13
- fovy_rad = math.radians(fovy)
14
- return np.array(
15
- [
16
- [1.0 / (math.tan(fovy_rad / 2.0) * aspect_wh), 0, 0, 0],
17
- [0, 1.0 / math.tan(fovy_rad / 2.0), 0, 0],
18
- [
19
- 0,
20
- 0,
21
- -(far + near) / (far - near),
22
- -2.0 * far * near / (far - near),
23
- ],
24
- [0, 0, -1, 0],
25
- ]
26
- ).astype(np.float32)
27
-
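- # Standard OpenGL-style perspective projection: `fovy` is in degrees,
- # `aspect_wh` is the image aspect ratio, and view-space z between `near`
- # and `far` is mapped into clip space (the w-divide happens after the matmul).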
28
-
29
- def load_mesh(mesh):
30
- vtx_pos = mesh.vertices if hasattr(mesh, "vertices") else None
31
- pos_idx = mesh.faces if hasattr(mesh, "faces") else None
32
-
33
- vtx_uv = mesh.visual.uv if hasattr(mesh.visual, "uv") else None
34
- uv_idx = mesh.faces if hasattr(mesh, "faces") else None
35
-
36
- texture_data = None
37
-
38
- return vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data
39
-
40
-
41
- def save_mesh(mesh, texture_data):
42
- material = trimesh.visual.texture.SimpleMaterial(
43
- image=texture_data, diffuse=(255, 255, 255)
44
- )
45
- texture_visuals = trimesh.visual.TextureVisuals(
46
- uv=mesh.visual.uv, image=texture_data, material=material
47
- )
48
- mesh.visual = texture_visuals
49
- return mesh
50
-
51
-
52
- def transform_pos(mtx, pos, keepdim=False):
53
- t_mtx = (
54
- torch.from_numpy(mtx).to(pos.device)
55
- if isinstance(mtx, np.ndarray)
56
- else mtx
57
- )
58
- if pos.shape[-1] == 3:
59
- posw = torch.cat(
60
- [pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1
61
- )
62
- else:
63
- posw = pos
64
-
65
- if keepdim:
66
- return torch.matmul(posw, t_mtx.t())[...]
67
- else:
68
- return torch.matmul(posw, t_mtx.t())[None, ...]
69
-
70
-
71
- def get_mv_matrix(elev, azim, camera_distance, center=None):
72
- elev = -elev
73
-
74
- elev_rad = math.radians(elev)
75
- azim_rad = math.radians(azim)
76
-
77
- camera_position = np.array(
78
- [
79
- camera_distance * math.cos(elev_rad) * math.cos(azim_rad),
80
- camera_distance * math.cos(elev_rad) * math.sin(azim_rad),
81
- camera_distance * math.sin(elev_rad),
82
- ]
83
- )
84
-
85
- if center is None:
86
- center = np.array([0, 0, 0])
87
- else:
88
- center = np.array(center)
89
-
90
- lookat = center - camera_position
91
- lookat = lookat / np.linalg.norm(lookat)
92
-
93
- up = np.array([0, 0, 1.0])
94
- right = np.cross(lookat, up)
95
- right = right / np.linalg.norm(right)
96
- up = np.cross(right, lookat)
97
- up = up / np.linalg.norm(up)
98
-
99
- c2w = np.concatenate(
100
- [np.stack([right, up, -lookat], axis=-1), camera_position[:, None]],
101
- axis=-1,
102
- )
103
-
104
- w2c = np.zeros((4, 4))
105
- w2c[:3, :3] = np.transpose(c2w[:3, :3], (1, 0))
106
- w2c[:3, 3:] = -np.matmul(np.transpose(c2w[:3, :3], (1, 0)), c2w[:3, 3:])
107
- w2c[3, 3] = 1.0
108
-
109
- return w2c.astype(np.float32)
110
-
111
-
112
- def stride_from_shape(shape):
113
- stride = [1]
114
- for x in reversed(shape[1:]):
115
- stride.append(stride[-1] * x)
116
- return list(reversed(stride))
117
-
118
-
119
- def scatter_add_nd_with_count(input, count, indices, values, weights=None):
120
- # input: [..., C], D dimension + C channel
121
- # count: [..., 1], D dimension
122
- # indices: [N, D], long
123
- # values: [N, C]
124
-
125
- D = indices.shape[-1]
126
- C = input.shape[-1]
127
- size = input.shape[:-1]
128
- stride = stride_from_shape(size)
129
-
130
- assert len(size) == D
131
-
132
- input = input.view(-1, C) # [HW, C]
133
- count = count.view(-1, 1)
134
-
135
- flatten_indices = (
136
- indices * torch.tensor(stride, dtype=torch.long, device=indices.device)
137
- ).sum(
138
- -1
139
- ) # [N]
140
-
141
- if weights is None:
142
- weights = torch.ones_like(values[..., :1])
143
-
144
- input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
145
- count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)
146
-
147
- return input.view(*size, C), count.view(*size, 1)
148
-
149
-
150
- def linear_grid_put_2d(H, W, coords, values, return_count=False):
151
- # coords: [N, 2], float in [0, 1]
152
- # values: [N, C]
153
-
154
- C = values.shape[-1]
155
-
156
- indices = coords * torch.tensor(
157
- [H - 1, W - 1], dtype=torch.float32, device=coords.device
158
- )
159
- indices_00 = indices.floor().long() # [N, 2]
160
- indices_00[:, 0].clamp_(0, H - 2)
161
- indices_00[:, 1].clamp_(0, W - 2)
162
- indices_01 = indices_00 + torch.tensor(
163
- [0, 1], dtype=torch.long, device=indices.device
164
- )
165
- indices_10 = indices_00 + torch.tensor(
166
- [1, 0], dtype=torch.long, device=indices.device
167
- )
168
- indices_11 = indices_00 + torch.tensor(
169
- [1, 1], dtype=torch.long, device=indices.device
170
- )
171
-
172
- h = indices[..., 0] - indices_00[..., 0].float()
173
- w = indices[..., 1] - indices_00[..., 1].float()
174
- w_00 = (1 - h) * (1 - w)
175
- w_01 = (1 - h) * w
176
- w_10 = h * (1 - w)
177
- w_11 = h * w
178
-
179
- result = torch.zeros(
180
- H, W, C, device=values.device, dtype=values.dtype
181
- ) # [H, W, C]
182
- count = torch.zeros(
183
- H, W, 1, device=values.device, dtype=values.dtype
184
- ) # [H, W, 1]
185
- weights = torch.ones_like(values[..., :1]) # [N, 1]
186
-
187
- result, count = scatter_add_nd_with_count(
188
- result,
189
- count,
190
- indices_00,
191
- values * w_00.unsqueeze(1),
192
- weights * w_00.unsqueeze(1),
193
- )
194
- result, count = scatter_add_nd_with_count(
195
- result,
196
- count,
197
- indices_01,
198
- values * w_01.unsqueeze(1),
199
- weights * w_01.unsqueeze(1),
200
- )
201
- result, count = scatter_add_nd_with_count(
202
- result,
203
- count,
204
- indices_10,
205
- values * w_10.unsqueeze(1),
206
- weights * w_10.unsqueeze(1),
207
- )
208
- result, count = scatter_add_nd_with_count(
209
- result,
210
- count,
211
- indices_11,
212
- values * w_11.unsqueeze(1),
213
- weights * w_11.unsqueeze(1),
214
- )
215
-
216
- if return_count:
217
- return result, count
218
-
219
- mask = count.squeeze(-1) > 0
220
- result[mask] = result[mask] / count[mask].repeat(1, C)
221
-
222
- return result
223
-
224
-
225
- def meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx):
226
- texture_height, texture_width, texture_channel = texture.shape
227
- vtx_num = vtx_pos.shape[0]
228
-
229
- vtx_mask = np.zeros(vtx_num, dtype=np.float32)
230
- vtx_color = [
231
- np.zeros(texture_channel, dtype=np.float32) for _ in range(vtx_num)
232
- ]
233
- uncolored_vtxs = []
234
- G = [[] for _ in range(vtx_num)]
235
-
236
- for i in range(uv_idx.shape[0]):
237
- for k in range(3):
238
- vtx_uv_idx = uv_idx[i, k]
239
- vtx_idx = pos_idx[i, k]
240
- uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1)))
241
- uv_u = int(
242
- round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))
243
- )
244
- if mask[uv_u, uv_v] > 0:
245
- vtx_mask[vtx_idx] = 1.0
246
- vtx_color[vtx_idx] = texture[uv_u, uv_v]
247
- else:
248
- uncolored_vtxs.append(vtx_idx)
249
- G[pos_idx[i, k]].append(pos_idx[i, (k + 1) % 3])
250
-
251
- smooth_count = 2
252
- last_uncolored_vtx_count = 0
253
- while smooth_count > 0:
254
- uncolored_vtx_count = 0
255
- for vtx_idx in uncolored_vtxs:
256
- sum_color = np.zeros(texture_channel, dtype=np.float32)
257
- total_weight = 0.0
258
- vtx_0 = vtx_pos[vtx_idx]
259
- for connected_idx in G[vtx_idx]:
260
- if vtx_mask[connected_idx] > 0:
261
- vtx1 = vtx_pos[connected_idx]
262
- dist = np.sqrt(np.sum((vtx_0 - vtx1) ** 2))
263
- dist_weight = 1.0 / max(dist, 1e-4)
264
- dist_weight *= dist_weight
265
- sum_color += vtx_color[connected_idx] * dist_weight
266
- total_weight += dist_weight
267
- if total_weight > 0:
268
- vtx_color[vtx_idx] = sum_color / total_weight
269
- vtx_mask[vtx_idx] = 1.0
270
- else:
271
- uncolored_vtx_count += 1
272
-
273
- if last_uncolored_vtx_count == uncolored_vtx_count:
274
- smooth_count -= 1
275
- else:
276
- smooth_count += 1
277
- last_uncolored_vtx_count = uncolored_vtx_count
278
-
279
- new_texture = texture.copy()
280
- new_mask = mask.copy()
281
- for face_idx in range(uv_idx.shape[0]):
282
- for k in range(3):
283
- vtx_uv_idx = uv_idx[face_idx, k]
284
- vtx_idx = pos_idx[face_idx, k]
285
- if vtx_mask[vtx_idx] == 1.0:
286
- uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1)))
287
- uv_u = int(
288
- round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))
289
- )
290
- new_texture[uv_u, uv_v] = vtx_color[vtx_idx]
291
- new_mask[uv_u, uv_v] = 255
292
-
293
- return new_texture, new_mask
294
-
295
-
296
- def mesh_uv_wrap(mesh):
297
- if isinstance(mesh, trimesh.Scene):
298
- mesh = mesh.dump(concatenate=True)
299
-
300
- if len(mesh.faces) > 500000000:
301
- raise ValueError(
302
- "The mesh has more than 500,000,000 faces, which is not supported."
303
- )
304
-
305
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
306
-
307
- mesh.vertices = mesh.vertices[vmapping]
308
- mesh.faces = indices
309
- mesh.visual.uv = uvs
310
-
311
- return mesh
312
-
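- # xatlas returns `vmapping` so that vertices are duplicated along UV seams;
- # faces and UVs are re-indexed consistently, which is why the vertex array
- # is rewritten here rather than only attaching UVs.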
313
-
314
- class MeshRender:
315
- def __init__(
316
- self,
317
- camera_distance=1.45,
318
- default_resolution=1024,
319
- texture_size=1024,
320
- use_antialias=True,
321
- max_mip_level=None,
322
- filter_mode="linear",
323
- bake_mode="linear",
324
- raster_mode="cr",
325
- device="cuda",
326
- ):
327
-
328
- self.device = device
329
-
330
- self.set_default_render_resolution(default_resolution)
331
- self.set_default_texture_resolution(texture_size)
332
-
333
- self.camera_distance = camera_distance
334
- self.use_antialias = use_antialias
335
- self.max_mip_level = max_mip_level
336
- self.filter_mode = filter_mode
337
-
338
- self.bake_angle_thres = 75
339
- self.bake_unreliable_kernel_size = int(
340
- (2 / 512)
341
- * max(self.default_resolution[0], self.default_resolution[1])
342
- )
343
- self.bake_mode = bake_mode
344
-
345
- self.raster_mode = raster_mode
346
- if self.raster_mode == "cr":
347
- import custom_rasterizer as cr
348
-
349
- self.raster = cr
350
- else:
351
- raise ValueError(f"No raster named {self.raster_mode}")
352
-
353
- fov = 30
354
- self.camera_proj_mat = get_perspective_projection_matrix(
355
- fov,
356
- self.default_resolution[1] / self.default_resolution[0],
357
- 0.01,
358
- 100.0,
359
- )
360
-
361
- def raster_rasterize(
362
- self, pos, tri, resolution, ranges=None, grad_db=True
363
- ):
364
-
365
- if self.raster_mode == "cr":
366
- rast_out_db = None
367
- if pos.dim() == 2:
368
- pos = pos.unsqueeze(0)
369
- findices, barycentric = self.raster.rasterize(pos, tri, resolution)
370
- rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1)
371
- rast_out = rast_out.unsqueeze(0)
372
- else:
373
- raise ValueError(f"No raster named {self.raster_mode}")
374
-
375
- return rast_out, rast_out_db
376
-
377
- def raster_interpolate(
378
- self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None
379
- ):
380
-
381
- if self.raster_mode == "cr":
382
- textd = None
383
- barycentric = rast_out[0, ..., :-1]
384
- findices = rast_out[0, ..., -1]
385
- if uv.dim() == 2:
386
- uv = uv.unsqueeze(0)
387
- textc = self.raster.interpolate(uv, findices, barycentric, uv_idx)
388
- else:
389
- raise ValueError(f"No raster named {self.raster_mode}")
390
-
391
- return textc, textd
392
-
393
- def load_mesh(
394
- self,
395
- mesh,
396
- ):
397
- vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh)
398
- self.mesh_copy = mesh
399
- self.set_mesh(
400
- vtx_pos,
401
- pos_idx,
402
- vtx_uv=vtx_uv,
403
- uv_idx=uv_idx,
404
- )
405
- if texture_data is not None:
406
- self.set_texture(texture_data)
407
-
408
- def save_mesh(self):
409
- texture_data = self.get_texture()
410
- texture_data = Image.fromarray((texture_data * 255).astype(np.uint8))
411
- return save_mesh(self.mesh_copy, texture_data)
412
-
413
- def set_mesh(
414
- self,
415
- vtx_pos,
416
- pos_idx,
417
- vtx_uv=None,
418
- uv_idx=None,
419
- ):
420
-
421
- self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float()
422
- self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int)
423
- if (vtx_uv is not None) and (uv_idx is not None):
424
- self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float()
425
- self.uv_idx = (
426
- torch.from_numpy(uv_idx).to(self.device).to(torch.int)
427
- )
428
- else:
429
- self.vtx_uv = None
430
- self.uv_idx = None
431
-
432
- self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]]
433
- self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]]
434
- if (vtx_uv is not None) and (uv_idx is not None):
435
- self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1]
436
-
437
- def set_texture(self, tex):
438
- if isinstance(tex, np.ndarray):
439
- tex = Image.fromarray((tex * 255).astype(np.uint8))
440
- elif isinstance(tex, torch.Tensor):
441
- tex = tex.cpu().numpy()
442
- tex = Image.fromarray((tex * 255).astype(np.uint8))
443
-
444
- tex = tex.resize(self.texture_size).convert("RGB")
445
- tex = np.array(tex) / 255.0
446
- self.tex = torch.from_numpy(tex).to(self.device)
447
- self.tex = self.tex.float()
448
-
449
- def set_default_render_resolution(self, default_resolution):
450
- if isinstance(default_resolution, int):
451
- default_resolution = (default_resolution, default_resolution)
452
- self.default_resolution = default_resolution
453
-
454
- def set_default_texture_resolution(self, texture_size):
455
- if isinstance(texture_size, int):
456
- texture_size = (texture_size, texture_size)
457
- self.texture_size = texture_size
458
-
459
- def get_mesh(self):
460
- vtx_pos = self.vtx_pos.cpu().numpy()
461
- pos_idx = self.pos_idx.cpu().numpy()
462
- vtx_uv = self.vtx_uv.cpu().numpy()
463
- uv_idx = self.uv_idx.cpu().numpy()
464
-
465
- # Invert the coordinate-system transform applied in set_mesh.
466
- vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]]
467
- vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]]
468
-
469
- vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1]
470
- return vtx_pos, pos_idx, vtx_uv, uv_idx
471
-
472
- def get_texture(self):
473
- return self.tex.cpu().numpy()
474
-
475
- def render_sketch_from_depth(self, depth_image):
476
- depth_image_np = depth_image.cpu().numpy()
477
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
478
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
479
- combined_edges = depth_edges
480
- sketch_image = (
481
- torch.from_numpy(combined_edges).to(depth_image.device).float()
482
- / 255.0
483
- )
484
- sketch_image = sketch_image.unsqueeze(-1)
485
- return sketch_image
486
-
487
- def back_project(
488
- self, image, elev, azim, camera_distance=None, center=None, method=None
489
- ):
490
- if isinstance(image, Image.Image):
491
- image = torch.tensor(np.array(image) / 255.0)
492
- elif isinstance(image, np.ndarray):
493
- image = torch.tensor(image)
494
- if image.dim() == 2:
495
- image = image.unsqueeze(-1)
496
- image = image.float().to(self.device)
497
- resolution = image.shape[:2]
498
- channel = image.shape[-1]
499
- texture = torch.zeros(self.texture_size + (channel,)).to(self.device)
500
- cos_map = torch.zeros(self.texture_size + (1,)).to(self.device)
501
-
502
- proj = self.camera_proj_mat
503
- r_mv = get_mv_matrix(
504
- elev=elev,
505
- azim=azim,
506
- camera_distance=(
507
- self.camera_distance
508
- if camera_distance is None
509
- else camera_distance
510
- ),
511
- center=center,
512
- )
513
- pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
514
- pos_clip = transform_pos(proj, pos_camera)
515
- pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
516
- v0 = pos_camera[self.pos_idx[:, 0], :]
517
- v1 = pos_camera[self.pos_idx[:, 1], :]
518
- v2 = pos_camera[self.pos_idx[:, 2], :]
519
- face_normals = F.normalize(
520
- torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
521
- )
522
- vertex_normals = trimesh.geometry.mean_vertex_normals(
523
- vertex_count=self.vtx_pos.shape[0],
524
- faces=self.pos_idx.cpu(),
525
- face_normals=face_normals.cpu(),
526
- )
527
- vertex_normals = (
528
- torch.from_numpy(vertex_normals)
529
- .float()
530
- .to(self.device)
531
- .contiguous()
532
- )
533
- tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()
534
- rast_out, rast_out_db = self.raster_rasterize(
535
- pos_clip, self.pos_idx, resolution=resolution
536
- )
537
- visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]
538
-
539
- normal, _ = self.raster_interpolate(
540
- vertex_normals[None, ...], rast_out, self.pos_idx
541
- )
542
- normal = normal[0, ...]
543
-
544
- uv, _ = self.raster_interpolate(
545
- self.vtx_uv[None, ...], rast_out, self.uv_idx
546
- )
547
- depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)
548
- depth = depth[0, ...]
549
-
550
- depth_max, depth_min = (
551
- depth[visible_mask > 0].max(),
552
- depth[visible_mask > 0].min(),
553
- )
554
- depth_normalized = (depth - depth_min) / (depth_max - depth_min)
555
- depth_image = depth_normalized * visible_mask # Mask out background.
556
-
557
- sketch_image = self.render_sketch_from_depth(depth_image)
558
-
559
- cv2.imwrite("d_depth.png", depth_image.cpu().numpy() * 255)
560
- cv2.imwrite("d_normal.png", normal.cpu().numpy() * 255)
561
- cv2.imwrite(
562
- "d_image.png", image.cpu().numpy()[..., :3][..., ::-1] * 255
563
- )
564
- cv2.imwrite("d_sketch_image.png", sketch_image.cpu().numpy() * 255)
565
- cv2.imwrite("d_uv1.png", uv.cpu().numpy()[0, ..., 0] * 255)
566
- cv2.imwrite("d_uv2.png", uv.cpu().numpy()[0, ..., 1] * 255)
567
- # p uv[0,...,0].mean(axis=0)
568
- # import pdb; pdb.set_trace()
569
-
570
- # depth_image = None
571
- # normal = None
572
- # image = None
573
-
574
- sketch_image = self.render_sketch_from_depth(depth_image)
575
- channel = image.shape[-1]
576
-
577
- lookat = torch.tensor([[0, 0, -1]], device=self.device)
578
- cos_image = torch.nn.functional.cosine_similarity(
579
- lookat, normal.view(-1, 3)
580
- )
581
- cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1)
582
-
583
- cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi)
584
- cos_image[cos_image < cos_thres] = 0
585
-
586
- # shrink
587
- kernel_size = self.bake_unreliable_kernel_size * 2 + 1
588
- kernel = torch.ones(
589
- (1, 1, kernel_size, kernel_size), dtype=torch.float32
590
- ).to(sketch_image.device)
591
-
592
- visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float()
593
- visible_mask = F.conv2d(
594
- 1.0 - visible_mask, kernel, padding=kernel_size // 2
595
- )
596
- visible_mask = 1.0 - (visible_mask > 0).float()  # binarize
597
- visible_mask = visible_mask.squeeze(0).permute(1, 2, 0)
598
-
599
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
600
- sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2)
601
- sketch_image = (sketch_image > 0).float()  # binarize
602
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
603
- visible_mask = visible_mask * (sketch_image < 0.5)
604
-
605
- cos_image[visible_mask == 0] = 0
606
- proj_mask = (visible_mask != 0).view(-1)
607
- uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask]
608
- image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask]
609
- cos_image = cos_image.contiguous().view(-1, 1)[proj_mask]
610
- sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask]
611
- import pdb
612
-
613
- pdb.set_trace()
614
- texture = linear_grid_put_2d(
615
- self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image
616
- )
617
- cos_map = linear_grid_put_2d(
618
- self.texture_size[1],
619
- self.texture_size[0],
620
- uv[..., [1, 0]],
621
- cos_image,
622
- )
623
- boundary_map = linear_grid_put_2d(
624
- self.texture_size[1],
625
- self.texture_size[0],
626
- uv[..., [1, 0]],
627
- sketch_image,
628
- )
629
-
630
- return texture, cos_map, boundary_map
631
-
632
- @torch.no_grad()
633
- def fast_bake_texture(self, textures, cos_maps):
634
-
635
- channel = textures[0].shape[-1]
636
- texture_merge = torch.zeros(self.texture_size + (channel,)).to(
637
- self.device
638
- )
639
- trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device)
640
- for texture, cos_map in zip(textures, cos_maps):
641
- view_sum = (cos_map > 0).sum()
642
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
643
- if painted_sum / view_sum > 0.99:
644
- continue
645
- texture_merge += texture * cos_map
646
- trust_map_merge += cos_map
647
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
648
-
649
- return texture_merge, trust_map_merge > 1e-8
650
-
651
- def uv_inpaint(self, texture, mask):
652
-
653
- if isinstance(texture, torch.Tensor):
654
- texture_np = texture.cpu().numpy()
655
- elif isinstance(texture, np.ndarray):
656
- texture_np = texture
657
- elif isinstance(texture, Image.Image):
658
- texture_np = np.array(texture) / 255.0
659
-
660
- vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh()
661
-
662
- texture_np, mask = meshVerticeInpaint_smooth(
663
- texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx
664
- )
665
-
666
- texture_np = cv2.inpaint(
667
- (texture_np * 255).astype(np.uint8), 255 - mask, 3, cv2.INPAINT_NS
668
- )
669
-
670
- return texture_np
671
-
672
-
673
- def get_images_from_file(img_path: str, img_size: int) -> list[np.ndarray]:
674
- input_image = Image.open(img_path)
675
- view_images = np.array(input_image)
676
- view_images = np.concatenate(
677
- [view_images[:img_size, ...], view_images[img_size:, ...]], axis=1
678
- )
679
- images = np.split(view_images, view_images.shape[1] // img_size, axis=1)
680
-
681
- return images
682
-
683
-
684
- def bake_from_multiview(
685
- render, views, camera_elevs, camera_azims, view_weights, method="fast"
686
- ):
687
- project_textures, project_weighted_cos_maps = [], []
688
- project_boundary_maps = []
689
- for view, camera_elev, camera_azim, weight in zip(
690
- views, camera_elevs, camera_azims, view_weights
691
- ):
692
- project_texture, project_cos_map, project_boundary_map = (
693
- render.back_project(view, camera_elev, camera_azim)
694
- )
695
- project_cos_map = weight * (project_cos_map**4)
696
- project_textures.append(project_texture)
697
- project_weighted_cos_maps.append(project_cos_map)
698
- project_boundary_maps.append(project_boundary_map)
699
-
700
- if method == "fast":
701
- texture, ori_trust_map = render.fast_bake_texture(
702
- project_textures, project_weighted_cos_maps
703
- )
704
- else:
705
- raise ValueError(f"no method {method}")
706
-
707
- return texture, ori_trust_map > 1e-8
708
-
709
-
710
- def post_process(texture: np.ndarray, iter: int = 2) -> np.ndarray:
711
- for _ in range(iter):
712
- texture = cv2.fastNlMeansDenoisingColored(texture, None, 11, 11, 9, 25)
713
- texture = cv2.bilateralFilter(
714
- texture, d=7, sigmaColor=80, sigmaSpace=80
715
- )
716
-
717
- return texture
718
-
719
-
720
- class Image_Super_Net:
721
- def __init__(self, device="cuda"):
722
- from diffusers import StableDiffusionUpscalePipeline
723
-
724
- self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
725
- "stabilityai/stable-diffusion-x4-upscaler",
726
- torch_dtype=torch.float16,
727
- ).to(device)
728
- self.up_pipeline_x4.set_progress_bar_config(disable=True)
729
-
730
- def __call__(self, image, prompt=""):
731
- with torch.no_grad():
732
- upscaled_image = self.up_pipeline_x4(
733
- prompt=[prompt],
734
- image=image,
735
- num_inference_steps=10,
736
- ).images[0]
737
-
738
- return upscaled_image
739
-
740
-
741
- class Image_GANNet:
742
- def __init__(self, outscale: int):
743
- from realesrgan import RealESRGANer
744
- from basicsr.archs.rrdbnet_arch import RRDBNet
745
-
746
- self.outscale = outscale
747
- model = RRDBNet(
748
- num_in_ch=3,
749
- num_out_ch=3,
750
- num_feat=64,
751
- num_block=23,
752
- num_grow_ch=32,
753
- scale=4,
754
- )
755
- self.upsampler = RealESRGANer(
756
- scale=4,
757
- model_path="/home/users/xinjie.wang/xinjie/Real-ESRGAN/weights/RealESRGAN_x4plus.pth",
758
- model=model,
759
- pre_pad=0,
760
- half=True,
761
- )
762
-
763
- def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
764
- if isinstance(image, Image.Image):
765
- image = np.array(image)
766
- output, _ = self.upsampler.enhance(image, outscale=self.outscale)
767
-
768
- return Image.fromarray(output)
769
-
770
-
771
- if __name__ == "__main__":
772
- device = "cuda"
773
-
774
- # super_model = Image_Super_Net(device)
775
- super_model = Image_GANNet(outscale=4)
776
-
777
- selected_camera_elevs = [20, 20, 20, -10, -10, -10]
778
- selected_camera_azims = [-180, -60, 60, -120, 0, 120]
779
- selected_view_weights = [1, 0.2, 0.2, 0.2, 1, 0.2]
780
- # selected_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05]
781
-
782
- multiviews = get_images_from_file(
783
- "scripts/apps/texture_sessions/mfq4e7u4ko/multi_view/color_sample1.png",
784
- 512,
785
- )
786
- target_image_size = (2048, 2048)
787
-
788
- render = MeshRender(
789
- camera_distance=5,
790
- default_resolution=2048,
791
- texture_size=2048,
792
- )
793
-
794
- mesh = trimesh.load("scripts/apps/assets/example_texture/meshes/robot.obj")
795
- from asset3d_gen.data.utils import normalize_vertices_array
796
-
797
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
798
- mesh = mesh_uv_wrap(mesh)
799
- render.load_mesh(mesh)
800
-
801
- # multiviews = [Image.fromarray(img) for img in multiviews]
802
- # multiviews = [Image.fromarray(img).convert("RGB") for img in multiviews]
803
- # for idx, img in enumerate(multiviews):
804
- # img.save(f"robot/raw/res_{idx}.png")
805
-
806
- multiviews = [super_model(img) for img in multiviews]
807
- multiviews = [img.convert("RGB") for img in multiviews]
808
- for idx, img in enumerate(multiviews):
809
- img.save(f"robot/super_gan_res_{idx}.png")
810
-
811
- texture, mask = bake_from_multiview(
812
- render,
813
- multiviews,
814
- selected_camera_elevs,
815
- selected_camera_azims,
816
- selected_view_weights,
817
- )
818
-
819
- texture_np = (texture.cpu().numpy() * 255).astype(np.uint8)[..., :3][
820
- ..., ::-1
821
- ]
822
- cv2.imwrite("robot/raw_texture.png", texture_np)
823
- print("texture done.")
824
-
825
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
826
- texture_np = render.uv_inpaint(texture, mask_np)
827
- cv2.imwrite("robot/inpaint_texture.png", texture_np[..., ::-1])
828
- # texture_np = post_process(texture_np, 2)
829
- # cv2.imwrite("robot/inpaint_conv_texture.png", texture_np[..., ::-1])
830
- print("inpaint done.")
831
-
832
- texture = torch.tensor(texture_np / 255).float().to(texture.device)
833
- render.set_texture(texture)
834
- textured_mesh = render.save_mesh()
835
- _ = textured_mesh.export("robot/robot.obj")
asset3d_gen/data/backup/gpt_qwen.py DELETED
@@ -1,70 +0,0 @@
1
- import torch
2
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
3
- from qwen_vl_utils import process_vision_info
4
- import os
5
- os.environ["https_proxy"] = "10.9.0.31:8838"
6
-
7
-
8
- # # default: Load the model on the available device(s)
9
- # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
10
- # "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
11
- # )
12
-
13
- # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
14
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
15
- "Qwen/Qwen2.5-VL-7B-Instruct",
16
- torch_dtype=torch.bfloat16,
17
- attn_implementation="flash_attention_2",
18
- device_map="auto",
19
- )
20
-
21
-
22
- # default processor
23
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
24
-
25
- # The default range for the number of visual tokens per image in the model is 4-16384.
26
- # You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
27
- # min_pixels = 256*28*28
28
- # max_pixels = 1280*28*28
29
- # processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
30
-
31
- messages = [
32
- {
33
- "role": "user",
34
- "content": [
35
- {
36
- "type": "image",
37
- "image": "outputs/text2image/demo_objects/bed/sample_0.jpg",
38
- },
39
- {
40
- "type": "image",
41
- "image": "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png",
42
- },
43
- {"type": "text", "text": "Describe the second image."},
44
- ],
45
- }
46
- ]
47
-
48
- # Preparation for inference
49
- text = processor.apply_chat_template(
50
- messages, tokenize=False, add_generation_prompt=True
51
- )
52
- image_inputs, video_inputs = process_vision_info(messages)
53
- inputs = processor(
54
- text=[text],
55
- images=image_inputs,
56
- videos=video_inputs,
57
- padding=True,
58
- return_tensors="pt",
59
- )
60
- inputs = inputs.to("cuda")
61
-
62
- # Inference: Generation of the output
63
- generated_ids = model.generate(**inputs, max_new_tokens=128)
64
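- # Strip the prompt tokens from each sequence so only the newly generated
- # tokens are decoded below.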
- generated_ids_trimmed = [
65
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
66
- ]
67
- output_text = processor.batch_decode(
68
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
69
- )
70
- print(output_text)
asset3d_gen/data/backup/quat.py DELETED
@@ -1,49 +0,0 @@
1
- import numpy as np
2
-
3
- def quaternion_rotation_x_counterclockwise(angle_degrees):
4
- angle_radians = np.radians(angle_degrees)
5
- w = np.cos(angle_radians / 2)
6
- x = np.sin(angle_radians / 2)
7
- y, z = 0.0, 0.0
8
- return np.array([x, y, z, w]).round(4).tolist()
9
-
10
-
11
- def quaternion_rotation_y_counterclockwise(angle_degrees):
12
- angle_radians = np.radians(angle_degrees)
13
- w = np.cos(angle_radians / 2)
14
- y = np.sin(angle_radians / 2)
15
- x, z = 0.0, 0.0
16
- return np.array([x, y, z, w]).round(4).tolist()
17
-
18
-
19
- def quaternion_rotation_z_counterclockwise(angle_degrees):
20
- angle_radians = np.radians(angle_degrees)
21
- w = np.cos(angle_radians / 2)
22
- z = np.sin(angle_radians / 2)
23
- x, y = 0.0, 0.0
24
- return np.array([x, y, z, w]).round(4).tolist()
25
-
26
-
27
- def quaternion_multiply(q1, q2):
28
- x1, y1, z1, w1 = q1
29
- x2, y2, z2, w2 = q2
30
- w = w1*w2 - x1*x2 - y1*y2 - z1*z2
31
- x = w1*x2 + x1*w2 + y1*z2 - z1*y2
32
- y = w1*y2 - x1*z2 + y1*w2 + z1*x2
33
- z = w1*z2 + x1*y2 - y1*x2 + z1*w2
34
- return np.array([x, y, z, w])  # keep the [x, y, z, w] order used by the helpers above
35
-
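- # Worked example in the [x, y, z, w] convention used above: a 180 degree
- # rotation about X is [1, 0, 0, 0] and about Z is [0, 0, 1, 0]; their product
- # (Z applied after X) is [0, 1, 0, 0], a 180 degree rotation about Y, which is
- # the combination computed at the bottom of this file.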
36
-
37
-
38
- angle = 180
39
-
40
- print(f"X-axis counterclockwise rotation by {angle} degrees: {quaternion_rotation_x_counterclockwise(angle)}")
41
- print(f"Y-axis counterclockwise rotation by {angle} degrees: {quaternion_rotation_y_counterclockwise(angle)}")
42
- print(f"Z-axis counterclockwise rotation by {angle} degrees: {quaternion_rotation_z_counterclockwise(angle)}")
43
-
44
-
45
- q_1 = np.array([1.0, 0.0, 0.0, 0.0])
46
- q_2 = np.array([0.0, 0.0, 1.0, 0.0])
47
-
48
- q_total = quaternion_multiply(q_2, q_1)
49
- print(q_total.round(4).tolist())
asset3d_gen/data/differentiable_render.py CHANGED
@@ -353,13 +353,11 @@ def parse_args():
353
  "--mesh_path",
354
  type=str,
355
  nargs="+",
356
- required=True,
357
  help="Paths to the mesh files for rendering.",
358
  )
359
  parser.add_argument(
360
  "--output_root",
361
  type=str,
362
- required=True,
363
  help="Root directory for output",
364
  )
365
  parser.add_argument(
@@ -446,7 +444,7 @@ def parse_args():
446
 
447
  args = parser.parse_args()
448
 
449
- if args.uuid is None:
450
  args.uuid = []
451
  for path in args.mesh_path:
452
  uuid = os.path.basename(path).split(".")[0]
@@ -455,8 +453,11 @@ def parse_args():
455
  return args
456
 
457
 
458
- def entrypoint() -> None:
459
  args = parse_args()
 
 
 
460
 
461
  camera_settings = CameraSetting(
462
  num_images=args.num_images,
 
353
  "--mesh_path",
354
  type=str,
355
  nargs="+",
 
356
  help="Paths to the mesh files for rendering.",
357
  )
358
  parser.add_argument(
359
  "--output_root",
360
  type=str,
 
361
  help="Root directory for output",
362
  )
363
  parser.add_argument(
 
444
 
445
  args = parser.parse_args()
446
 
447
+ if args.uuid is None and args.mesh_path is not None:
448
  args.uuid = []
449
  for path in args.mesh_path:
450
  uuid = os.path.basename(path).split(".")[0]
 
453
  return args
454
 
455
 
456
+ def entrypoint(**kwargs) -> None:
457
  args = parse_args()
458
+ for k, v in kwargs.items():
459
+ if hasattr(args, k) and v is not None:
460
+ setattr(args, k, v)
461
 
462
  camera_settings = CameraSetting(
463
  num_images=args.num_images,
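
The differentiable_render.py change above relaxes the required CLI flags and lets entrypoint(**kwargs) override any parsed argument, which is what allows common.py (below) to call the renderer in-process instead of shelling out to drender-cli. A hedged sketch of the call pattern (paths are placeholders):

```python
from asset3d_gen.data.differentiable_render import entrypoint as render_api

# Keyword names mirror the argparse flags; any non-None keyword that matches an
# existing args attribute replaces the parsed default.
render_api(
    mesh_path=["outputs/sample.obj"],  # nargs="+" flag, so a list of paths
    output_root="outputs/condition",
    uuid="sample",
)
```
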
asset3d_gen/data/mesh_operator.py CHANGED
@@ -1,6 +1,6 @@
1
  import logging
2
  from typing import Tuple, Union
3
-
4
  import igraph
5
  import numpy as np
6
  import pyvista as pv
@@ -384,6 +384,7 @@ class MeshFixer(object):
384
  dtype=torch.int32,
385
  )
386
 
 
387
  def __call__(
388
  self,
389
  filter_ratio: float,
 
1
  import logging
2
  from typing import Tuple, Union
3
+ import spaces
4
  import igraph
5
  import numpy as np
6
  import pyvista as pv
 
384
  dtype=torch.int32,
385
  )
386
 
387
+ @spaces.GPU
388
  def __call__(
389
  self,
390
  filter_ratio: float,
asset3d_gen/models/delight_model.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  from typing import Union
3
-
4
  import cv2
5
  import numpy as np
6
  import torch
@@ -102,6 +102,7 @@ class DelightingModel(object):
102
 
103
  return new_image
104
 
 
105
  @torch.no_grad()
106
  def __call__(
107
  self,
 
1
  import os
2
  from typing import Union
3
+ import spaces
4
  import cv2
5
  import numpy as np
6
  import torch
 
102
 
103
  return new_image
104
 
105
+ @spaces.GPU
106
  @torch.no_grad()
107
  def __call__(
108
  self,
asset3d_gen/models/sr_model.py CHANGED
@@ -1,7 +1,7 @@
1
  import logging
2
  import os
3
  from typing import Union
4
-
5
  import numpy as np
6
  import torch
7
  from huggingface_hub import snapshot_download
@@ -35,6 +35,7 @@ class ImageStableSR:
35
  self.up_pipeline_x4.set_progress_bar_config(disable=True)
36
  # self.up_pipeline_x4.enable_model_cpu_offload()
37
 
 
38
  def __call__(
39
  self,
40
  image: Union[Image.Image, np.ndarray],
@@ -105,6 +106,7 @@ class ImageRealESRGAN:
105
  half=True,
106
  )
107
 
 
108
  def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
109
  if isinstance(image, Image.Image):
110
  image = np.array(image)
 
1
  import logging
2
  import os
3
  from typing import Union
4
+ import spaces
5
  import numpy as np
6
  import torch
7
  from huggingface_hub import snapshot_download
 
35
  self.up_pipeline_x4.set_progress_bar_config(disable=True)
36
  # self.up_pipeline_x4.enable_model_cpu_offload()
37
 
38
+ @spaces.GPU
39
  def __call__(
40
  self,
41
  image: Union[Image.Image, np.ndarray],
 
106
  half=True,
107
  )
108
 
109
+ @spaces.GPU
110
  def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
111
  if isinstance(image, Image.Image):
112
  image = np.array(image)
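
The recurring pattern in mesh_operator.py, delight_model.py, and sr_model.py above is the Hugging Face ZeroGPU idiom: import spaces at module level and decorate the GPU-touching callables with @spaces.GPU, so a GPU is attached only while those calls run. A minimal sketch of the shape being applied (the class is illustrative, not from the repo):

```python
import numpy as np
import spaces
import torch


class ExampleGPUStage:
    """Illustrative only; mirrors how the model __call__ methods are wrapped."""

    @spaces.GPU          # ZeroGPU attaches a GPU for the duration of this call
    @torch.no_grad()
    def __call__(self, image: np.ndarray) -> np.ndarray:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tensor = torch.from_numpy(image).float().to(device)
        # A real model would run inference here.
        return tensor.cpu().numpy()
```
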
asset3d_gen/scripts/render_gs.py CHANGED
@@ -2,7 +2,7 @@ import argparse
2
  import logging
3
  import math
4
  import os
5
-
6
  import cv2
7
  import numpy as np
8
  import torch
@@ -94,6 +94,7 @@ def load_gs_model(
94
  return gs_model
95
 
96
 
 
97
  def entrypoint(input_gs: str = None, output_path: str = None) -> None:
98
  args = parse_args()
99
  if isinstance(input_gs, str):
 
2
  import logging
3
  import math
4
  import os
5
+ import spaces
6
  import cv2
7
  import numpy as np
8
  import torch
 
94
  return gs_model
95
 
96
 
97
+ @spaces.GPU
98
  def entrypoint(input_gs: str = None, output_path: str = None) -> None:
99
  args = parse_args()
100
  if isinstance(input_gs, str):
common.py CHANGED
@@ -16,6 +16,7 @@ from easydict import EasyDict as edict
16
  from PIL import Image
17
  from tqdm import tqdm
18
  from asset3d_gen.data.backproject_v2 import entrypoint as backproject_api
 
19
  from asset3d_gen.models.delight_model import DelightingModel
20
  from asset3d_gen.models.gs_model import GaussianOperator
21
  from asset3d_gen.models.segment_model import (
@@ -71,7 +72,7 @@ os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
71
  "~/.cache/torch_extensions"
72
  )
73
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
74
- os.environ['SPCONV_ALGO'] = 'native'
75
 
76
  MAX_SEED = 100000
77
  DELIGHT = DelightingModel()
@@ -82,18 +83,25 @@ def download_kolors_weights() -> None:
82
  logger.info(f"Download kolors weights from huggingface...")
83
  subprocess.run(
84
  [
85
- "huggingface-cli", "download", "--resume-download",
86
- "Kwai-Kolors/Kolors", "--local-dir", "weights/Kolors"
 
 
 
 
87
  ],
88
- check=True
89
  )
90
  subprocess.run(
91
  [
92
- "huggingface-cli", "download", "--resume-download",
93
- "Kwai-Kolors/Kolors-IP-Adapter-Plus", "--local-dir",
94
- "weights/Kolors-IP-Adapter-Plus"
 
 
 
95
  ],
96
- check=True
97
  )
98
 
99
 
@@ -121,9 +129,7 @@ elif os.getenv("GRADIO_APP") == "textto3d":
121
  if not os.path.exists(text_model_dir):
122
  download_kolors_weights()
123
 
124
- PIPELINE_IMG_IP = build_text2img_ip_pipeline(
125
- text_model_dir, ref_scale=0.3
126
- )
127
  PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
128
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
129
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
@@ -156,7 +162,7 @@ os.makedirs(TMP_DIR, exist_ok=True)
156
  lighting_css = """
157
  <style>
158
  #lighter_mesh canvas {
159
- filter: brightness(1.6) !important;
160
  }
161
  </style>
162
  """
@@ -299,7 +305,6 @@ def get_cached_image(image_path: str) -> Image.Image:
299
  return Image.open(image_path).resize((512, 512))
300
 
301
 
302
- @spaces.GPU
303
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
304
  return {
305
  "gaussian": {
@@ -318,7 +323,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
318
 
319
 
320
  @spaces.GPU
321
- def unpack_state(state: dict) -> tuple[Gaussian, edict, str]:
322
  gs = Gaussian(
323
  aabb=state["gaussian"]["aabb"],
324
  sh_degree=state["gaussian"]["sh_degree"],
@@ -327,17 +332,17 @@ def unpack_state(state: dict) -> tuple[Gaussian, edict, str]:
327
  opacity_bias=state["gaussian"]["opacity_bias"],
328
  scaling_activation=state["gaussian"]["scaling_activation"],
329
  )
330
- gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device="cuda")
331
  gs._features_dc = torch.tensor(
332
- state["gaussian"]["_features_dc"], device="cuda"
333
  )
334
- gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device="cuda")
335
- gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device="cuda")
336
- gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device="cuda")
337
 
338
  mesh = edict(
339
- vertices=torch.tensor(state["mesh"]["vertices"], device="cuda"),
340
- faces=torch.tensor(state["mesh"]["faces"], device="cuda"),
341
  )
342
 
343
  return gs, mesh
@@ -484,7 +489,6 @@ def extract_3d_representations(
484
  return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
485
 
486
 
487
- @spaces.GPU
488
  def extract_3d_representations_v2(
489
  state: dict,
490
  enable_delight: bool,
@@ -492,7 +496,7 @@ def extract_3d_representations_v2(
492
  ):
493
  output_root = TMP_DIR
494
  user_dir = os.path.join(output_root, str(req.session_hash))
495
- gs_model, mesh_model = unpack_state(state)
496
 
497
  filename = "sample"
498
  gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
@@ -538,12 +542,9 @@ def extract_3d_representations_v2(
538
  mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
539
  mesh.export(mesh_glb_path)
540
 
541
- torch.cuda.empty_cache()
542
-
543
  return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
544
 
545
 
546
- @spaces.GPU
547
  def extract_urdf(
548
  gs_path: str,
549
  mesh_obj_path: str,
@@ -556,7 +557,8 @@ def extract_urdf(
556
  output_root = TMP_DIR
557
  if req is not None:
558
  output_root = os.path.join(output_root, str(req.session_hash))
559
- # Convert to URDF and recover attrs by gpt4o
 
560
  filename = "sample"
561
  urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
562
  asset_attrs = {
@@ -635,8 +637,6 @@ def extract_urdf(
635
  output_zip=f"{output_root}/urdf_{filename}.zip",
636
  )
637
 
638
- torch.cuda.empty_cache()
639
-
640
  estimated_type = urdf_convertor.estimated_attrs["category"]
641
  estimated_height = urdf_convertor.estimated_attrs["height"]
642
  estimated_mass = urdf_convertor.estimated_attrs["mass"]
@@ -660,7 +660,6 @@ def text2image_fn(
660
  ip_adapt_scale: float = 0.3,
661
  image_wh: int | tuple[int, int] = [1024, 1024],
662
  n_sample: int = 3,
663
- postprocess: bool = True,
664
  req: gr.Request = None,
665
  ):
666
  if isinstance(image_wh, int):
@@ -683,10 +682,10 @@ def text2image_fn(
683
  image_wh=image_wh,
684
  infer_step=infer_step,
685
  )
686
- if postprocess:
687
- for idx in range(len(images)):
688
- image = images[idx]
689
- images[idx] = preprocess_image_fn(image, req)
690
 
691
  save_paths = []
692
  for idx, image in enumerate(images):
@@ -705,18 +704,11 @@ def text2image_fn(
705
  @spaces.GPU
706
  def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
707
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
708
- command = [
709
- "drender-cli",
710
- "--mesh_path",
711
- mesh_path,
712
- "--output_root",
713
- f"{output_root}/condition",
714
- "--uuid",
715
- f"{uuid}",
716
- ]
717
 
718
- _ = subprocess.run(
719
- command, capture_output=True, text=True, encoding="utf-8"
 
 
720
  )
721
 
722
  gc.collect()
@@ -764,7 +756,6 @@ def generate_texture_mvimages(
764
  return img_save_paths + img_save_paths
765
 
766
 
767
- @spaces.GPU
768
  def backproject_texture(
769
  mesh_path: str,
770
  input_image: str,
@@ -864,32 +855,19 @@ def render_result_video(
864
  ) -> str:
865
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
866
  output_dir = os.path.join(output_root, "texture_mesh")
867
- command = [
868
- "drender-cli",
869
- "--mesh_path",
870
- mesh_path,
871
- "--output_root",
872
- output_dir,
873
- "--num_images",
874
- "90",
875
- "--elevation",
876
- "20",
877
- "--with_mtl",
878
- "--pbr_light_factor",
879
- "1.",
880
- "--uuid",
881
- f"{uuid}",
882
- "--gen_color_mp4",
883
- "--gen_glonormal_mp4",
884
- "--distance",
885
- "5.5",
886
- "--resolution_hw",
887
- f"{video_size}",
888
- f"{video_size}",
889
- ]
890
 
891
- _ = subprocess.run(
892
- command, capture_output=True, text=True, encoding="utf-8"
 
893
  )
894
 
895
  gc.collect()
 
16
  from PIL import Image
17
  from tqdm import tqdm
18
  from asset3d_gen.data.backproject_v2 import entrypoint as backproject_api
19
+ from asset3d_gen.data.differentiable_render import entrypoint as render_api
20
  from asset3d_gen.models.delight_model import DelightingModel
21
  from asset3d_gen.models.gs_model import GaussianOperator
22
  from asset3d_gen.models.segment_model import (
 
72
  "~/.cache/torch_extensions"
73
  )
74
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
75
+ os.environ["SPCONV_ALGO"] = "native"
76
 
77
  MAX_SEED = 100000
78
  DELIGHT = DelightingModel()
 
83
  logger.info(f"Download kolors weights from huggingface...")
84
  subprocess.run(
85
  [
86
+ "huggingface-cli",
87
+ "download",
88
+ "--resume-download",
89
+ "Kwai-Kolors/Kolors",
90
+ "--local-dir",
91
+ "weights/Kolors",
92
  ],
93
+ check=True,
94
  )
95
  subprocess.run(
96
  [
97
+ "huggingface-cli",
98
+ "download",
99
+ "--resume-download",
100
+ "Kwai-Kolors/Kolors-IP-Adapter-Plus",
101
+ "--local-dir",
102
+ "weights/Kolors-IP-Adapter-Plus",
103
  ],
104
+ check=True,
105
  )
106
 
107
 
 
129
  if not os.path.exists(text_model_dir):
130
  download_kolors_weights()
131
 
132
+ PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
 
 
133
  PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
134
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
135
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
 
162
  lighting_css = """
163
  <style>
164
  #lighter_mesh canvas {
165
+ filter: brightness(1.8) !important;
166
  }
167
  </style>
168
  """
 
305
  return Image.open(image_path).resize((512, 512))
306
 
307
 
 
308
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
309
  return {
310
  "gaussian": {
 
323
 
324
 
325
  @spaces.GPU
326
+ def unpack_state(state: dict, device: str = "cuda") -> tuple[Gaussian, dict]:
327
  gs = Gaussian(
328
  aabb=state["gaussian"]["aabb"],
329
  sh_degree=state["gaussian"]["sh_degree"],
 
332
  opacity_bias=state["gaussian"]["opacity_bias"],
333
  scaling_activation=state["gaussian"]["scaling_activation"],
334
  )
335
+ gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
336
  gs._features_dc = torch.tensor(
337
+ state["gaussian"]["_features_dc"], device=device
338
  )
339
+ gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
340
+ gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
341
+ gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
342
 
343
  mesh = edict(
344
+ vertices=torch.tensor(state["mesh"]["vertices"], device=device),
345
+ faces=torch.tensor(state["mesh"]["faces"], device=device),
346
  )
347
 
348
  return gs, mesh
 
489
  return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
490
 
491
 
 
492
  def extract_3d_representations_v2(
493
  state: dict,
494
  enable_delight: bool,
 
496
  ):
497
  output_root = TMP_DIR
498
  user_dir = os.path.join(output_root, str(req.session_hash))
499
+ gs_model, mesh_model = unpack_state(state, device="cpu")
500
 
501
  filename = "sample"
502
  gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
 
542
  mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
543
  mesh.export(mesh_glb_path)
544
 
 
 
545
  return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
546
 
547
 
 
548
  def extract_urdf(
549
  gs_path: str,
550
  mesh_obj_path: str,
 
557
  output_root = TMP_DIR
558
  if req is not None:
559
  output_root = os.path.join(output_root, str(req.session_hash))
560
+
561
+ # Convert to URDF and recover attrs by GPT.
562
  filename = "sample"
563
  urdf_convertor = URDFGenerator(GPT_CLIENT, render_view_num=4)
564
  asset_attrs = {
 
637
  output_zip=f"{output_root}/urdf_{filename}.zip",
638
  )
639
 
 
 
640
  estimated_type = urdf_convertor.estimated_attrs["category"]
641
  estimated_height = urdf_convertor.estimated_attrs["height"]
642
  estimated_mass = urdf_convertor.estimated_attrs["mass"]
 
660
  ip_adapt_scale: float = 0.3,
661
  image_wh: int | tuple[int, int] = [1024, 1024],
662
  n_sample: int = 3,
 
663
  req: gr.Request = None,
664
  ):
665
  if isinstance(image_wh, int):
 
682
  image_wh=image_wh,
683
  infer_step=infer_step,
684
  )
685
+
686
+ for idx in range(len(images)):
687
+ image = images[idx]
688
+ images[idx], _ = preprocess_image_fn(image)
689
 
690
  save_paths = []
691
  for idx, image in enumerate(images):
 
704
  @spaces.GPU
705
  def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
706
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
 
707
 
708
+ _ = render_api(
709
+ mesh_path=mesh_path,
710
+ output_root=f"{output_root}/condition",
711
+ uuid=str(uuid),
712
  )
713
 
714
  gc.collect()
 
756
  return img_save_paths + img_save_paths
757
 
758
 
 
759
  def backproject_texture(
760
  mesh_path: str,
761
  input_image: str,
 
855
  ) -> str:
856
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
857
  output_dir = os.path.join(output_root, "texture_mesh")
 
 
858
 
859
+ _ = render_api(
860
+ mesh_path=mesh_path,
861
+ output_root=output_dir,
862
+ num_images=90,
863
+ elevation=[20],
864
+ with_mtl=True,
865
+ pbr_light_factor=1,
866
+ uuid=str(uuid),
867
+ gen_color_mp4=True,
868
+ gen_glonormal_mp4=True,
869
+ distance=5.5,
870
+ resolution_hw=(video_size, video_size),
871
  )
872
 
873
  gc.collect()
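
One more common.py detail worth flagging: unpack_state gained a device argument, so functions that lost their @spaces.GPU decorator (such as extract_3d_representations_v2) now unpack the Gaussian and mesh tensors on CPU instead of assuming CUDA is available. A hedged sketch of the two call shapes inside common.py:

```python
# Under @spaces.GPU, the default device stays "cuda":
gs_model, mesh_model = unpack_state(state)

# Without @spaces.GPU, e.g. in extract_3d_representations_v2:
gs_model, mesh_model = unpack_state(state, device="cpu")
```
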
requirements.txt CHANGED
@@ -3,9 +3,10 @@
3
 
4
  torch==2.4.0
5
  torchvision==0.19.0
6
- xformers==0.0.27.post2
7
  pytorch-lightning==2.4.0
8
  spconv-cu120==2.3.6
 
9
  dataclasses_json
10
  easydict
11
  opencv-python>4.5
 
3
 
4
  torch==2.4.0
5
  torchvision==0.19.0
6
+ xformers==0.0.28.post1
7
  pytorch-lightning==2.4.0
8
  spconv-cu120==2.3.6
9
+ triton==2.1.0
10
  dataclasses_json
11
  easydict
12
  opencv-python>4.5