ReubenSun committed on
Commit 5c326b3 · 1 Parent(s): 691c14e

Revert "texture sync"


This reverts commit 55f226f582932e6ec64e096f296c54d47a59de80.

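In short, the revert removes the texture-synchronization path from IG2MVSDXLPipeline.__call__: the UV projector (UVP), the step_tex_sync call, and the random color-background compositing are deleted, and every denoising step goes back to a plain scheduler update. The snippet below is a minimal, self-contained sketch of the restored per-step update, not the pipeline's verbatim code; the shapes and step count are illustrative, chosen to match the 6-view, 768-pixel (96-latent) setup that appears elsewhere in this diff.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(30)

latents = torch.randn(6, 4, 96, 96)          # 6 views, 4 SDXL latent channels, 768 / 8 = 96
for t in scheduler.timesteps:
    noise_pred = torch.randn_like(latents)   # stand-in for the UNet noise prediction
    # Restored behaviour (the `+` lines in the first file below). Before the revert,
    # steps inside the texture-sync window were instead routed through step_tex_sync()
    # and a shared UV latent texture, then composited over random color backgrounds.
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
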
step1x3d_texture/pipelines/ig2mv_sdxl_pipeline.py CHANGED
@@ -51,20 +51,6 @@ from ..models.attention_processor import (
51
  DecoupledMVRowSelfAttnProcessor2_0,
52
  set_unet_2d_condition_attn_processor,
53
  )
54
- import random
55
- from ..texture_sync.project import UVProjection as UVP
56
- from ..texture_sync.step_sync import step_tex_sync
57
- from trimesh import Trimesh
58
- from torchvision.transforms import Compose, Resize, GaussianBlur, InterpolationMode
59
- from diffusers.utils import (
60
- BaseOutput,
61
- numpy_to_pil,
62
- pt_to_pil,
63
- is_accelerate_available,
64
- is_accelerate_version,
65
- logging,
66
- replace_example_docstring
67
- )
68
 
69
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
70
 
@@ -84,27 +70,6 @@ def retrieve_latents(
84
  raise AttributeError("Could not access latents of provided encoder_output")
85
 
86
 
87
- @torch.no_grad()
88
- def composite_rendered_view(scheduler, backgrounds, foregrounds, masks, t):
89
- composited_images = []
90
- for i, (background, foreground, mask) in enumerate(zip(backgrounds, foregrounds, masks)):
91
- if t > 0:
92
- alphas_cumprod = scheduler.alphas_cumprod[t]
93
- noise = torch.normal(0, 1, background.shape, device=background.device)
94
- background = (1-alphas_cumprod) * noise + alphas_cumprod * background
95
- composited = foreground * mask + background * (1-mask)
96
- composited_images.append(composited)
97
- composited_tensor = torch.stack(composited_images)
98
- return composited_tensor
99
-
100
-
101
- @torch.no_grad()
102
- def encode_latents(vae, imgs):
103
- imgs = (imgs-0.5)*2
104
- latents = vae.encode(imgs).latent_dist.sample()
105
- latents = vae.config.scaling_factor * latents
106
- return latents
107
-
108
  class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
109
  def __init__(
110
  self,
@@ -344,8 +309,6 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
344
  # Image condition
345
  reference_image: Optional[PipelineImageInput] = None,
346
  reference_conditioning_scale: Optional[float] = 1.0,
347
- mesh: Optional[Trimesh] = None,
348
- texture_sync_config: Optional[dict] = None,
349
  **kwargs,
350
  ):
351
  r"""
@@ -593,27 +556,6 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
593
  latents,
594
  )
595
 
596
- # texture params init
597
- texture_size = texture_sync_config["texture_size"]
598
- latent_size = texture_sync_config["latent_size"]
599
- elevations = texture_sync_config["elevations"]
600
- azimuths = texture_sync_config["azimuths"]
601
- texture_sync_ratio = texture_sync_config["texture_sync_ratio"]
602
- camera_poses = [(elv, azim) for elv, azim in zip(elevations, azimuths)]
603
- uvp = UVP(texture_size=texture_size, render_size=latent_size, sampling_mode="nearest", channels=4, device=self._execution_device)
604
- uvp.load_mesh(mesh, scale_factor=1.0, autouv=True)
605
- uvp.set_cameras_and_render_settings(camera_poses, centers=None, camera_distance=texture_sync_config["camera_distance"], scale=((1.0, 1.0, 1.0),))
606
-
607
- latent_tex = uvp.set_noise_texture()
608
- noise_views = uvp.render_textured_views()
609
- foregrounds = [view[:-1] for view in noise_views]
610
- masks = [view[-1:] for view in noise_views]
611
-
612
- if texture_sync_ratio>0:
613
- composited_tensor = composite_rendered_view(self.scheduler, latents, foregrounds, masks, int(timesteps[0].cpu().item())+1)
614
- latents = composited_tensor.type(latents.dtype)
615
- uvp.to("cpu")
616
-
617
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
618
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
619
 
@@ -767,36 +709,6 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
767
  ).to(device=device, dtype=latents.dtype)
768
 
769
  self._num_timesteps = len(timesteps)
770
-
771
-
772
- # texture sync params
773
- exp_start = texture_sync_config["exp_start"]
774
- exp_end = texture_sync_config["exp_end"]
775
- shuffle_background_change = texture_sync_config["shuffle_background_change"]
776
- shuffle_background_end = texture_sync_config["shuffle_background_end"]
777
- num_timesteps = self.scheduler.config.num_train_timesteps
778
-
779
- uvp.to(self._execution_device)
780
- color_constants = {"black": [-1, -1, -1], "white": [1, 1, 1], "maroon": [0, -1, -1],
781
- "red": [1, -1, -1], "olive": [0, 0, -1], "yellow": [1, 1, -1],
782
- "green": [-1, 0, -1], "lime": [-1 ,1, -1], "teal": [-1, 0, 0],
783
- "aqua": [-1, 1, 1], "navy": [-1, -1, 0], "blue": [-1, -1, 1],
784
- "purple": [0, -1 , 0], "fuchsia": [1, -1, 1]}
785
- color_names = list(color_constants.keys())
786
- background_colors = [random.choice(list(color_constants.keys())) for i in range(len(camera_poses))]
787
- intermediate_results = []
788
- self.upcast_vae()
789
- self.vae.config.force_upcast = True
790
- color_images = torch.FloatTensor([color_constants[name] for name in color_names]).reshape(-1,3,1,1).to(dtype=torch.float32, device=self._execution_device)
791
- color_images = torch.ones(
792
- (1,1,latent_size*8, latent_size*8),
793
- device=self._execution_device,
794
- dtype=torch.float32
795
- ) * color_images
796
- color_images = ((0.5*color_images)+0.5)
797
- color_latents = encode_latents(self.vae, color_images).to(dtype=self.text_encoder_2.dtype)
798
- color_latents = {color[0]:color[1] for color in zip(color_names, [latent for latent in color_latents])}
799
-
800
  with self.progress_bar(total=num_inference_steps) as progress_bar:
801
  for i, t in enumerate(timesteps):
802
  if self.interrupt:
@@ -856,49 +768,9 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
856
 
857
  # compute the previous noisy sample x_t -> x_t-1
858
  latents_dtype = latents.dtype
859
-
860
- # texture sync
861
- current_exp = ((exp_end-exp_start) * i / num_inference_steps) + exp_start
862
- if t > (1-texture_sync_ratio)*num_timesteps:
863
- step_results = step_tex_sync(
864
- scheduler=self.scheduler,
865
- uvp=uvp,
866
- model_output=noise_pred,
867
- timestep=t,
868
- sample=latents,
869
- texture=latent_tex,
870
- return_dict=True,
871
- main_views=[],
872
- exp= current_exp,
873
- **extra_step_kwargs
874
- )
875
-
876
- pred_original_sample = step_results["pred_original_sample"]
877
- latents = step_results["prev_sample"]
878
- latent_tex = step_results["prev_tex"]
879
-
880
- # Composit latent foreground with random color background
881
- background_latents = [color_latents[color] for color in background_colors]
882
- composited_tensor = composite_rendered_view(self.scheduler, background_latents, latents, masks, t)
883
- latents = composited_tensor.type(latents.dtype)
884
-
885
- intermediate_results.append((latents.to("cpu"), pred_original_sample.to("cpu")))
886
- else:
887
- step_results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
888
- pred_original_sample = step_results["pred_original_sample"]
889
- latents = step_results["prev_sample"]
890
- latent_tex = None
891
- intermediate_results.append((latents.to("cpu"), pred_original_sample.to("cpu")))
892
-
893
- # 2. Shuffle background colors; only black and white are used after a certain timestep
894
- if (1-t/num_timesteps) < shuffle_background_change:
895
- background_colors = [random.choice(list(color_constants.keys())) for i in range(len(camera_poses))]
896
- elif (1-t/num_timesteps) < shuffle_background_end:
897
- background_colors = [random.choice(["black","white"]) for i in range(len(camera_poses))]
898
- else:
899
- background_colors = background_colors
900
- del noise_pred
901
-
902
  if latents.dtype != latents_dtype:
903
  if torch.backends.mps.is_available():
904
  # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
 
51
  DecoupledMVRowSelfAttnProcessor2_0,
52
  set_unet_2d_condition_attn_processor,
53
  )
54
 
55
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
 
 
70
  raise AttributeError("Could not access latents of provided encoder_output")
71
 
72
 
73
  class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
74
  def __init__(
75
  self,
 
309
  # Image condition
310
  reference_image: Optional[PipelineImageInput] = None,
311
  reference_conditioning_scale: Optional[float] = 1.0,
 
 
312
  **kwargs,
313
  ):
314
  r"""
 
556
  latents,
557
  )
558
 
559
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
560
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
561
 
 
709
  ).to(device=device, dtype=latents.dtype)
710
 
711
  self._num_timesteps = len(timesteps)
712
  with self.progress_bar(total=num_inference_steps) as progress_bar:
713
  for i, t in enumerate(timesteps):
714
  if self.interrupt:
 
768
 
769
  # compute the previous noisy sample x_t -> x_t-1
770
  latents_dtype = latents.dtype
771
+ latents = self.scheduler.step(
772
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
773
+ )[0]
774
  if latents.dtype != latents_dtype:
775
  if torch.backends.mps.is_available():
776
  # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
step1x3d_texture/pipelines/step1x_3d_texture_synthesis_pipeline.py CHANGED
@@ -24,6 +24,7 @@ import trimesh
24
  import xatlas
25
  import scipy.sparse
26
  from scipy.sparse.linalg import spsolve
 
27
  from step1x3d_geometry.models.pipelines.pipeline_utils import smart_load_model
28
 
29
 
@@ -35,7 +36,7 @@ class Step1X3DTextureConfig:
35
  self.unet_model = None
36
  self.lora_model = None
37
  self.adapter_path = "stepfun-ai/Step1X-3D"
38
- self.scheduler = "ddpm"
39
  self.num_views = 6
40
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
41
  self.dtype = torch.float16
@@ -60,20 +61,6 @@ class Step1X3DTextureConfig:
60
  self.bake_exp = 4
61
  self.merge_method = "fast"
62
 
63
- # texture sync params
64
- self.texture_sync_config = {
65
- "texture_size": 1536,
66
- "latent_size": 768//8,
67
- "elevations": [0, 0, 0, 0, 90, -90],
68
- "azimuths": [0, 90, 180, 270, 0, 0],
69
- "texture_sync_ratio": 0.5,
70
- "exp_end": 6.0,
71
- "exp_start": 0,
72
- "shuffle_background_change": 0.4,
73
- "shuffle_background_end": 0.99,
74
- "camera_distance": 1.8
75
- }
76
-
77
 
78
  class Step1X3DTexturePipeline:
79
  def __init__(self, config):
@@ -133,9 +120,11 @@ class Step1X3DTexturePipeline:
133
  if unet_model is not None:
134
  pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
135
 
 
136
  # Prepare pipeline
137
  pipe = IG2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
138
 
 
139
  # Load scheduler if provided
140
  scheduler_class = None
141
  if scheduler == "ddpm":
@@ -149,11 +138,14 @@ class Step1X3DTexturePipeline:
149
  shift_scale=8.0,
150
  scheduler_class=scheduler_class,
151
  )
 
152
  pipe.init_custom_adapter(
153
  num_views=num_views,
154
  self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0,
155
  )
 
156
  pipe.load_custom_adapter(adapter_path, "step1x-3d-ig2v.safetensors")
 
157
  pipe.to(device=device, dtype=dtype)
158
  pipe.cond_encoder.to(device=device, dtype=dtype)
159
 
@@ -290,7 +282,6 @@ class Step1X3DTexturePipeline:
290
  negative_prompt=negative_prompt,
291
  cross_attention_kwargs={"scale": lora_scale},
292
  mesh=mesh_bp,
293
- texture_sync_config=self.config.texture_sync_config,
294
  **pipe_kwargs,
295
  ).images
296
 
@@ -368,7 +359,7 @@ class Step1X3DTexturePipeline:
368
  width=768,
369
  num_inference_steps=self.config.num_inference_steps,
370
  guidance_scale=self.config.guidance_scale,
371
- seed= seed if seed is not None else self.config.seed,
372
  lora_scale=self.config.lora_scale,
373
  reference_conditioning_scale=self.config.reference_conditioning_scale,
374
  negative_prompt=self.config.negative_prompt,
 
24
  import xatlas
25
  import scipy.sparse
26
  from scipy.sparse.linalg import spsolve
27
+
28
  from step1x3d_geometry.models.pipelines.pipeline_utils import smart_load_model
29
 
30
 
 
36
  self.unet_model = None
37
  self.lora_model = None
38
  self.adapter_path = "stepfun-ai/Step1X-3D"
39
+ self.scheduler = None
40
  self.num_views = 6
41
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
42
  self.dtype = torch.float16
 
61
  self.bake_exp = 4
62
  self.merge_method = "fast"
63
 
64
 
65
  class Step1X3DTexturePipeline:
66
  def __init__(self, config):
 
120
  if unet_model is not None:
121
  pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
122
 
123
+ print('VAE Loaded!')
124
  # Prepare pipeline
125
  pipe = IG2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
126
 
127
+ print('Base model Loaded!')
128
  # Load scheduler if provided
129
  scheduler_class = None
130
  if scheduler == "ddpm":
 
138
  shift_scale=8.0,
139
  scheduler_class=scheduler_class,
140
  )
141
+ print('Scheduler Loaded!')
142
  pipe.init_custom_adapter(
143
  num_views=num_views,
144
  self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0,
145
  )
146
+ print(f'Load adapter from {adapter_path}/step1x-3d-ig2v.safetensors')
147
  pipe.load_custom_adapter(adapter_path, "step1x-3d-ig2v.safetensors")
148
+ print(f'Load adapter succeeded!')
149
  pipe.to(device=device, dtype=dtype)
150
  pipe.cond_encoder.to(device=device, dtype=dtype)
151
 
 
282
  negative_prompt=negative_prompt,
283
  cross_attention_kwargs={"scale": lora_scale},
284
  mesh=mesh_bp,
 
285
  **pipe_kwargs,
286
  ).images
287
 
 
359
  width=768,
360
  num_inference_steps=self.config.num_inference_steps,
361
  guidance_scale=self.config.guidance_scale,
362
+ seed=seed if seed is not None else self.config.seed,
363
  lora_scale=self.config.lora_scale,
364
  reference_conditioning_scale=self.config.reference_conditioning_scale,
365
  negative_prompt=self.config.negative_prompt,
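
On the synthesis-pipeline side, the revert drops the texture_sync_config defaults and the texture_sync_config= call argument, and Step1X3DTextureConfig.scheduler now defaults to None instead of "ddpm" (the DDPM branch in the loading code is still selectable). A hedged sketch of how this surfaces to a caller, assuming the pipeline can be built directly from a config as its __init__(self, config) signature suggests:

from step1x3d_texture.pipelines.step1x_3d_texture_synthesis_pipeline import (
    Step1X3DTextureConfig,
    Step1X3DTexturePipeline,
)

config = Step1X3DTextureConfig()
config.scheduler = "ddpm"   # default is now None; "ddpm" re-enables the DDPM scheduler branch
# config.texture_sync_config no longer exists after this revert
pipeline = Step1X3DTexturePipeline(config)   # loads the base model, scheduler, and IG2MV adapter
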
step1x3d_texture/{texture_sync → renderer}/__init__.py RENAMED
File without changes
step1x3d_texture/renderer/geometry.py ADDED
@@ -0,0 +1,151 @@
1
+ import torch
2
+ import pytorch3d
3
+ import torch.nn.functional as F
4
+
5
+ from pytorch3d.ops import interpolate_face_attributes
6
+
7
+ from pytorch3d.renderer import (
8
+ look_at_view_transform,
9
+ FoVPerspectiveCameras,
10
+ AmbientLights,
11
+ PointLights,
12
+ DirectionalLights,
13
+ Materials,
14
+ RasterizationSettings,
15
+ MeshRenderer,
16
+ MeshRasterizer,
17
+ SoftPhongShader,
18
+ SoftSilhouetteShader,
19
+ HardPhongShader,
20
+ TexturesVertex,
21
+ TexturesUV,
22
+ Materials,
23
+ )
24
+ from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
25
+ from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
26
+ from pytorch3d.renderer.mesh.shader import ShaderBase
27
+
28
+
29
+ def get_cos_angle(points, normals, camera_position):
30
+ """
31
+ calculate cosine similarity between view->surface and surface normal.
32
+ """
33
+
34
+ if points.shape != normals.shape:
35
+ msg = "Expected points and normals to have the same shape: got %r, %r"
36
+ raise ValueError(msg % (points.shape, normals.shape))
37
+
38
+ # Ensure all inputs have same batch dimension as points
39
+ matched_tensors = convert_to_tensors_and_broadcast(
40
+ points, camera_position, device=points.device
41
+ )
42
+ _, camera_position = matched_tensors
43
+
44
+ # Reshape direction and color so they have all the arbitrary intermediate
45
+ # dimensions as points. Assume first dim = batch dim and last dim = 3.
46
+ points_dims = points.shape[1:-1]
47
+ expand_dims = (-1,) + (1,) * len(points_dims)
48
+
49
+ if camera_position.shape != normals.shape:
50
+ camera_position = camera_position.view(expand_dims + (3,))
51
+
52
+ normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
53
+
54
+ # Calculate the cosine value.
55
+ view_direction = camera_position - points
56
+ view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
57
+ cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True)
58
+ cos_angle = cos_angle.clamp(0, 1)
59
+
60
+ # Cosine of the angle between the reflected light ray and the viewer
61
+ return cos_angle
62
+
63
+
64
+ def _geometry_shading_with_pixels(
65
+ meshes, fragments, lights, cameras, materials, texels
66
+ ):
67
+ """
68
+ Render pixel space vertex position, normal(world), depth, and cos angle
69
+
70
+ Args:
71
+ meshes: Batch of meshes
72
+ fragments: Fragments named tuple with the outputs of rasterization
73
+ lights: Lights class containing a batch of lights
74
+ cameras: Cameras class containing a batch of cameras
75
+ materials: Materials class containing a batch of material properties
76
+ texels: texture per pixel of shape (N, H, W, K, 3)
77
+
78
+ Returns:
79
+ colors: (N, H, W, K, 3)
80
+ pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection.
81
+ """
82
+ verts = meshes.verts_packed() # (V, 3)
83
+ faces = meshes.faces_packed() # (F, 3)
84
+ vertex_normals = meshes.verts_normals_packed() # (V, 3)
85
+ faces_verts = verts[faces]
86
+ faces_normals = vertex_normals[faces]
87
+ pixel_coords_in_camera = interpolate_face_attributes(
88
+ fragments.pix_to_face, fragments.bary_coords, faces_verts
89
+ )
90
+ pixel_normals = interpolate_face_attributes(
91
+ fragments.pix_to_face, fragments.bary_coords, faces_normals
92
+ )
93
+
94
+ cos_angles = get_cos_angle(
95
+ pixel_coords_in_camera, pixel_normals, cameras.get_camera_center()
96
+ )
97
+
98
+ return pixel_coords_in_camera, pixel_normals, fragments.zbuf[..., None], cos_angles
99
+
100
+
101
+ class HardGeometryShader(ShaderBase):
102
+ """
103
+ Renders common geometric information.
104
+
105
+
106
+ """
107
+
108
+ def forward(self, fragments, meshes, **kwargs):
109
+ cameras = super()._get_cameras(**kwargs)
110
+ texels = self.texel_from_uv(fragments, meshes)
111
+
112
+ lights = kwargs.get("lights", self.lights)
113
+ materials = kwargs.get("materials", self.materials)
114
+ blend_params = kwargs.get("blend_params", self.blend_params)
115
+ verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
116
+ meshes=meshes,
117
+ fragments=fragments,
118
+ texels=texels,
119
+ lights=lights,
120
+ cameras=cameras,
121
+ materials=materials,
122
+ )
123
+ texels = meshes.sample_textures(fragments)
124
+ verts = hard_rgb_blend(verts, fragments, blend_params)
125
+ normals = hard_rgb_blend(normals, fragments, blend_params)
126
+ depths = hard_rgb_blend(depths, fragments, blend_params)
127
+ cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
128
+ from IPython import embed
129
+
130
+ embed()
131
+ texels = hard_rgb_blend(texels, fragments, blend_params)
132
+ return verts, normals, depths, cos_angles, texels, fragments
133
+
134
+ def texel_from_uv(self, fragments, meshes):
135
+ texture_tmp = meshes.textures
136
+ maps_tmp = texture_tmp.maps_padded()
137
+ uv_color = [[[1, 0], [1, 1]], [[0, 0], [0, 1]]]
138
+ uv_color = (
139
+ torch.FloatTensor(uv_color).to(maps_tmp[0].device).type(maps_tmp[0].dtype)
140
+ )
141
+ uv_texture = TexturesUV(
142
+ [uv_color.clone() for t in maps_tmp],
143
+ texture_tmp.faces_uvs_padded(),
144
+ texture_tmp.verts_uvs_padded(),
145
+ sampling_mode="bilinear",
146
+ )
147
+ meshes.textures = uv_texture
148
+ texels = meshes.sample_textures(fragments)
149
+ meshes.textures = texture_tmp
150
+ texels = torch.cat((texels, texels[..., -1:] * 0), dim=-1)
151
+ return texels
step1x3d_texture/renderer/project.py ADDED
@@ -0,0 +1,875 @@
1
+ import torch
2
+ import pytorch3d
3
+
4
+
5
+ from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj, IO
6
+
7
+ from pytorch3d.structures import Meshes
8
+ from pytorch3d.renderer import (
9
+ look_at_view_transform,
10
+ FoVPerspectiveCameras,
11
+ FoVOrthographicCameras,
12
+ AmbientLights,
13
+ PointLights,
14
+ DirectionalLights,
15
+ Materials,
16
+ RasterizationSettings,
17
+ MeshRenderer,
18
+ MeshRasterizer,
19
+ TexturesUV,
20
+ )
21
+
22
+ from .geometry import HardGeometryShader
23
+ from .shader import HardNChannelFlatShader
24
+ from .voronoi import voronoi_solve
25
+ import torch.nn.functional as F
26
+ import open3d as o3d
27
+ import pdb
28
+ import kaolin as kal
29
+ import numpy as np
30
+
31
+
32
+ import torch
33
+ from pytorch3d.renderer.cameras import FoVOrthographicCameras
34
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
35
+ from pytorch3d.common.datatypes import Device
36
+ import math
37
+ import torch.nn.functional as F
38
+ from trimesh import Trimesh
39
+ from pytorch3d.structures import Meshes
40
+ import os
41
+
42
+ LIST_TYPE = Union[list, np.ndarray, torch.Tensor]
43
+
44
+ _R = torch.eye(3)[None] # (1, 3, 3)
45
+ _T = torch.zeros(1, 3) # (1, 3)
46
+ _BatchFloatType = Union[float, Sequence[float], torch.Tensor]
47
+
48
+
49
+ class CustomOrthographicCameras(FoVOrthographicCameras):
50
+ def compute_projection_matrix(
51
+ self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz
52
+ ) -> torch.Tensor:
53
+ """
54
+ Custom orthographic projection-matrix computation; inherits from the parent and adjusts the depth-channel parameters.
55
+ Parameter shapes:
56
+ - znear/zfar: (N,)
57
+ - max_x/min_x: (N,)
58
+ - max_y/min_y: (N,)
59
+ - scale_xyz: (N, 3)
60
+ """
61
+ K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
62
+
63
+ ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
64
+ # NOTE: OpenGL flips handedness of coordinate system between camera
65
+ # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
66
+ # right handed coordinate system throughout.
67
+ z_sign = +1.0
68
+
69
+ K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
70
+ K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
71
+ K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
72
+ K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
73
+ K[:, 3, 3] = ones
74
+
75
+ # NOTE: This maps the z coordinate to the range [0, 1] and replaces the
76
+ # OpenGL z normalization to [-1, 1]
77
+ K[:, 2, 2] = -2 * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
78
+ K[:, 2, 3] = -(znear + zfar) / (zfar - znear)
79
+
80
+ return K
81
+
82
+ def __init__(
83
+ self,
84
+ znear: _BatchFloatType = 1.0,
85
+ zfar: _BatchFloatType = 100.0,
86
+ max_y: _BatchFloatType = 1.0,
87
+ min_y: _BatchFloatType = -1.0,
88
+ max_x: _BatchFloatType = 1.0,
89
+ min_x: _BatchFloatType = -1.0,
90
+ scale_xyz=((1.0, 1.0, 1.0),), # (N, 3)
91
+ R: torch.Tensor = _R,
92
+ T: torch.Tensor = _T,
93
+ K: Optional[torch.Tensor] = None,
94
+ device: Device = "cpu",
95
+ ):
96
+ # Reuse the parent-class initialization logic
97
+ super().__init__(
98
+ znear=znear,
99
+ zfar=zfar,
100
+ max_y=max_y,
101
+ min_y=min_y,
102
+ max_x=max_x,
103
+ min_x=min_x,
104
+ scale_xyz=scale_xyz,
105
+ R=R,
106
+ T=T,
107
+ K=K,
108
+ device=device,
109
+ )
110
+
111
+
112
+ def erode_torch_batch(binary_img_batch, kernel_size):
113
+ pad = (kernel_size - 1) // 2
114
+ bin_img = F.pad(
115
+ binary_img_batch.unsqueeze(1), pad=[pad, pad, pad, pad], mode="reflect"
116
+ )
117
+ out = -F.max_pool2d(-bin_img, kernel_size=kernel_size, stride=1, padding=0)
118
+ out = out.squeeze(1)
119
+ return out
120
+
121
+
122
+ def dilate_torch_batch(binary_img_batch, kernel_size):
123
+ pad = (kernel_size - 1) // 2
124
+ bin_img = F.pad(binary_img_batch, pad=[pad, pad, pad, pad], mode="reflect")
125
+ out = F.max_pool2d(bin_img, kernel_size=kernel_size, stride=1, padding=0)
126
+ out = out.squeeze()
127
+ return out
128
+
129
+
130
+ # Pytorch3D based renderering functions, managed in a class
131
+ # Render size is recommended to be the same as your latent view size
132
+ # DO NOT USE "bilinear" sampling when you are handling latents.
133
+ # Stable Diffusion has 4 latent channels so use channels=4
134
+
135
+
136
+ class UVProjection:
137
+ def __init__(
138
+ self,
139
+ texture_size=96,
140
+ render_size=64,
141
+ sampling_mode="nearest",
142
+ channels=3,
143
+ device=None,
144
+ ):
145
+ self.channels = channels
146
+ self.device = device or torch.device("cpu")
147
+ self.lights = AmbientLights(
148
+ ambient_color=((1.0,) * channels,), device=self.device
149
+ )
150
+ self.target_size = (texture_size, texture_size)
151
+ self.render_size = render_size
152
+ self.sampling_mode = sampling_mode
153
+
154
+ # Load obj mesh, rescale the mesh to fit into the bounding box
155
+ def load_mesh(self, mesh, scale_factor=2.0, auto_center=True, autouv=False):
156
+ if isinstance(mesh, Trimesh):
157
+ vertices = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
158
+ faces = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
159
+ mesh = Meshes(verts=[vertices], faces=[faces])
160
+ verts = mesh.verts_packed()
161
+ mesh = mesh.update_padded(verts[None, :, :])
162
+ elif isinstance(mesh, str) and os.path.isfile(mesh):
163
+ mesh = load_objs_as_meshes([mesh], device=self.device)  # `mesh` is the file path in this branch
164
+ if auto_center:
165
+ verts = mesh.verts_packed()
166
+ max_bb = (verts - 0).max(0)[0]
167
+ min_bb = (verts - 0).min(0)[0]
168
+ scale = (max_bb - min_bb).max() / 2
169
+ center = (max_bb + min_bb) / 2
170
+ mesh.offset_verts_(-center)
171
+ mesh.scale_verts_((scale_factor / float(scale)))
172
+ else:
173
+ mesh.scale_verts_((scale_factor))
174
+
175
+ if autouv or (mesh.textures is None):
176
+ mesh = self.uv_unwrap(mesh)
177
+ self.mesh = mesh
178
+
179
+ def load_glb_mesh(
180
+ self, mesh_path, trimesh, scale_factor=1.0, auto_center=True, autouv=False
181
+ ):
182
+ from pytorch3d.io.experimental_gltf_io import MeshGlbFormat
183
+
184
+ io = IO()
185
+ io.register_meshes_format(MeshGlbFormat())
186
+ with open(mesh_path, "rb") as f:
187
+ mesh = io.load_mesh(f, include_textures=True, device=self.device)
188
+ if auto_center:
189
+ verts = mesh.verts_packed()
190
+
191
+ max_bb = (verts - 0).max(0)[0]
192
+ min_bb = (verts - 0).min(0)[0]
193
+ scale = (max_bb - min_bb).max() / 2
194
+ center = (max_bb + min_bb) / 2
195
+ mesh.offset_verts_(-center)
196
+ mesh.scale_verts_((scale_factor / float(scale)))
197
+ verts = mesh.verts_packed()
198
+ # T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=verts.device, dtype=verts.dtype)
199
+ # T = torch.tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]], device=verts.device, dtype=verts.dtype)
200
+ # verts = verts @ T
201
+ mesh = mesh.update_padded(verts[None, :, :])
202
+ else:
203
+ mesh.scale_verts_((scale_factor))
204
+ if autouv or (mesh.textures is None):
205
+ mesh = self.uv_unwrap(mesh)
206
+ self.mesh = mesh
207
+
208
+ # Save obj mesh
209
+ def save_mesh(self, mesh_path, texture):
210
+ save_obj(
211
+ mesh_path,
212
+ self.mesh.verts_list()[0],
213
+ self.mesh.faces_list()[0],
214
+ verts_uvs=self.mesh.textures.verts_uvs_list()[0],
215
+ faces_uvs=self.mesh.textures.faces_uvs_list()[0],
216
+ texture_map=texture,
217
+ )
218
+
219
+ # Code referred to TEXTure code (https://github.com/TEXTurePaper/TEXTurePaper.git)
220
+ def uv_unwrap(self, mesh):
221
+ verts_list = mesh.verts_list()[0]
222
+ faces_list = mesh.faces_list()[0]
223
+
224
+ import xatlas
225
+ import numpy as np
226
+
227
+ v_np = verts_list.cpu().numpy()
228
+ f_np = faces_list.int().cpu().numpy()
229
+ atlas = xatlas.Atlas()
230
+ atlas.add_mesh(v_np, f_np)
231
+ chart_options = xatlas.ChartOptions()
232
+ chart_options.max_iterations = 4
233
+ atlas.generate(chart_options=chart_options)
234
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
235
+
236
+ vt = (
237
+ torch.from_numpy(vt_np.astype(np.float32))
238
+ .type(verts_list.dtype)
239
+ .to(mesh.device)
240
+ )
241
+ ft = (
242
+ torch.from_numpy(ft_np.astype(np.int64))
243
+ .type(faces_list.dtype)
244
+ .to(mesh.device)
245
+ )
246
+
247
+ new_map = torch.zeros(self.target_size + (self.channels,), device=mesh.device)
248
+ new_tex = TexturesUV([new_map], [ft], [vt], sampling_mode=self.sampling_mode)
249
+
250
+ mesh.textures = new_tex
251
+ return mesh
252
+
253
+ """
254
+ A function that disconnects faces in the mesh according to
255
+ its UV seams. The number of vertices is made equal to the
256
+ number of unique vertices in its UV layout, while the faces list
257
+ is intact.
258
+ """
259
+
260
+ def disconnect_faces(self):
261
+ mesh = self.mesh
262
+ verts_list = mesh.verts_list()
263
+ faces_list = mesh.faces_list()
264
+ verts_uvs_list = mesh.textures.verts_uvs_list()
265
+ faces_uvs_list = mesh.textures.faces_uvs_list()
266
+ packed_list = [v[f] for v, f in zip(verts_list, faces_list)]
267
+ verts_disconnect_list = [
268
+ torch.zeros(
269
+ (verts_uvs_list[i].shape[0], 3),
270
+ dtype=verts_list[0].dtype,
271
+ device=verts_list[0].device,
272
+ )
273
+ for i in range(len(verts_list))
274
+ ]
275
+ for i in range(len(verts_list)):
276
+ verts_disconnect_list[i][faces_uvs_list] = packed_list[i]
277
+ assert not mesh.has_verts_normals(), "Not implemented for vertex normals"
278
+ self.mesh_d = Meshes(verts_disconnect_list, faces_uvs_list, mesh.textures)
279
+ return self.mesh_d
280
+
281
+ """
282
+ A function that constructs a temp mesh for back-projection.
283
+ Given a disconnected mesh and a rasterizer, the function uses
284
+ the projected faces as the UV layout and its original UV with a pseudo
285
+ z value as the world-space geometry.
286
+ """
287
+
288
+ def construct_uv_mesh(self):
289
+ mesh = self.mesh_d
290
+ verts_list = mesh.verts_list()
291
+ verts_uvs_list = mesh.textures.verts_uvs_list()
292
+ # faces_list = [torch.flip(faces, [-1]) for faces in mesh.faces_list()]
293
+ new_verts_list = []
294
+ for i, (verts, verts_uv) in enumerate(zip(verts_list, verts_uvs_list)):
295
+ verts = verts.clone()
296
+ verts_uv = verts_uv.clone()
297
+ verts[..., 0:2] = verts_uv[..., :]
298
+ verts = (verts - 0.5) * 2
299
+ verts[..., 2] *= 1
300
+ new_verts_list.append(verts)
301
+ textures_uv = mesh.textures.clone()
302
+ self.mesh_uv = Meshes(new_verts_list, mesh.faces_list(), textures_uv)
303
+ return self.mesh_uv
304
+
305
+ # Set texture for the current mesh.
306
+ def set_texture_map(self, texture):
307
+ new_map = texture.permute(1, 2, 0)
308
+ new_map = new_map.to(self.device)
309
+ new_tex = TexturesUV(
310
+ [new_map],
311
+ self.mesh.textures.faces_uvs_padded(),
312
+ self.mesh.textures.verts_uvs_padded(),
313
+ sampling_mode=self.sampling_mode,
314
+ )
315
+ self.mesh.textures = new_tex
316
+
317
+ # Set the initial normal noise texture
318
+ # No generator here for replication of the experiment result. Add one as you wish
319
+ def set_noise_texture(self, channels=None):
320
+ if not channels:
321
+ channels = self.channels
322
+ noise_texture = torch.normal(
323
+ 0, 1, (channels,) + self.target_size, device=self.device
324
+ )
325
+ self.set_texture_map(noise_texture)
326
+ return noise_texture
327
+
328
+ # Set the cameras given the camera poses and centers
329
+ def set_cameras(self, camera_poses, centers=None, camera_distance=2.7, scale=None):
330
+ elev = torch.FloatTensor([pose[0] for pose in camera_poses])
331
+ azim = torch.FloatTensor([pose[1] for pose in camera_poses])
332
+ print("camera_distance:{}".format(camera_distance))
333
+ R, T = look_at_view_transform(
334
+ dist=camera_distance, elev=elev, azim=azim, at=centers or ((0, 0, 0),)
335
+ )
336
+ # flip_mat = torch.from_numpy(np.diag([-1.0, 1.0, -1.0]) ).type(torch.FloatTensor).to(R.device)
337
+ # R = R@flip_mat
338
+ # R = R.permute(0, 2, 1)
339
+ # T = T*torch.from_numpy(np.array([-1.0, 1.0, -1.0])).type(torch.FloatTensor).to(R.device)
340
+ # print("v R size:{}, v T size:{}".format(R.size(), T.size()))
341
+ # c2w = self.get_c2w(elev, [camera_distance]*len(elev), azim)
342
+ # w2c = torch.linalg.inv(c2w)
343
+ # R, T= w2c[:, :3, :3], w2c[:, :3, 3]
344
+ print("R size:{}, T size:{}".format(R.size(), T.size()))
345
+ # self.cameras = CustomOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),), znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)
346
+ self.cameras = FoVOrthographicCameras(
347
+ device=self.device, R=R, T=T, scale_xyz=scale or ((1, 1, 1),)
348
+ )
349
+
350
+ # Set all necessary internal data for rendering and texture baking
351
+ # Can be used to refresh after changing camera positions
352
+ def set_cameras_and_render_settings(
353
+ self,
354
+ camera_poses,
355
+ centers=None,
356
+ camera_distance=2.7,
357
+ render_size=None,
358
+ scale=None,
359
+ ):
360
+ self.set_cameras(camera_poses, centers, camera_distance, scale=scale)
361
+ if render_size is None:
362
+ render_size = self.render_size
363
+ if not hasattr(self, "renderer"):
364
+ self.setup_renderer(size=render_size)
365
+ if not hasattr(self, "mesh_d"):
366
+ self.disconnect_faces()
367
+ if not hasattr(self, "mesh_uv"):
368
+ self.construct_uv_mesh()
369
+ self.calculate_tex_gradient()
370
+ self.calculate_visible_triangle_mask()
371
+ _, _, _, cos_maps, _, _ = self.render_geometry()
372
+ self.calculate_cos_angle_weights(cos_maps)
373
+
374
+ # Setup renderers for rendering
375
+ # max faces per bin set to 30000 to avoid overflow in many test cases.
376
+ # You can use default value to let pytorch3d handle that for you.
377
+ def setup_renderer(
378
+ self,
379
+ size=64,
380
+ blur=0.0,
381
+ face_per_pix=1,
382
+ perspective_correct=False,
383
+ channels=None,
384
+ ):
385
+ if not channels:
386
+ channels = self.channels
387
+
388
+ self.raster_settings = RasterizationSettings(
389
+ image_size=size,
390
+ blur_radius=blur,
391
+ faces_per_pixel=face_per_pix,
392
+ perspective_correct=perspective_correct,
393
+ cull_backfaces=True,
394
+ max_faces_per_bin=30000,
395
+ )
396
+
397
+ self.renderer = MeshRenderer(
398
+ rasterizer=MeshRasterizer(
399
+ cameras=self.cameras,
400
+ raster_settings=self.raster_settings,
401
+ ),
402
+ shader=HardNChannelFlatShader(
403
+ device=self.device,
404
+ cameras=self.cameras,
405
+ lights=self.lights,
406
+ channels=channels,
407
+ # materials=materials
408
+ ),
409
+ )
410
+
411
+ # Bake screen-space cosine weights to UV space
412
+ # May be able to reimplement using the generic "bake_texture" function, but it works so leave it here for now
413
+ @torch.enable_grad()
414
+ def calculate_cos_angle_weights(self, cos_angles, fill=True, channels=None):
415
+ if not channels:
416
+ channels = self.channels
417
+ cos_maps = []
418
+ tmp_mesh = self.mesh.clone()
419
+ for i in range(len(self.cameras)):
420
+
421
+ zero_map = torch.zeros(
422
+ self.target_size + (channels,), device=self.device, requires_grad=True
423
+ )
424
+ optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
425
+ optimizer.zero_grad()
426
+ zero_tex = TexturesUV(
427
+ [zero_map],
428
+ self.mesh.textures.faces_uvs_padded(),
429
+ self.mesh.textures.verts_uvs_padded(),
430
+ sampling_mode=self.sampling_mode,
431
+ )
432
+ tmp_mesh.textures = zero_tex
433
+
434
+ images_predicted = self.renderer(
435
+ tmp_mesh, cameras=self.cameras[i], lights=self.lights
436
+ )
437
+
438
+ loss = torch.sum((cos_angles[i, :, :, 0:1] ** 1 - images_predicted) ** 2)
439
+ loss.backward()
440
+ optimizer.step()
441
+
442
+ if fill:
443
+ zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8)
444
+ zero_map = voronoi_solve(
445
+ zero_map, self.gradient_maps[i][..., 0], self.device
446
+ )
447
+ else:
448
+ zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8)
449
+ cos_maps.append(zero_map)
450
+ self.cos_maps = cos_maps
451
+
452
+ # Get geometric info from fragment shader
453
+ # Can be used for generating conditioning image and cosine weights
454
+ # Returns some information you may not need, remember to release them for memory saving
455
+ @torch.no_grad()
456
+ def render_geometry(self, image_size=None):
457
+ if image_size:
458
+ size = self.renderer.rasterizer.raster_settings.image_size
459
+ self.renderer.rasterizer.raster_settings.image_size = image_size
460
+ shader = self.renderer.shader
461
+ self.renderer.shader = HardGeometryShader(
462
+ device=self.device, cameras=self.cameras[0], lights=self.lights
463
+ )
464
+ tmp_mesh = self.mesh.clone()
465
+
466
+ verts, normals, depths, cos_angles, texels, fragments = self.renderer(
467
+ tmp_mesh.extend(len(self.cameras)), cameras=self.cameras, lights=self.lights
468
+ )
469
+ self.renderer.shader = shader
470
+
471
+ if image_size:
472
+ self.renderer.rasterizer.raster_settings.image_size = size
473
+
474
+ return verts, normals, depths, cos_angles, texels, fragments
475
+
476
+ # Project world normal to view space and normalize
477
+ @torch.no_grad()
478
+ def decode_view_normal(self, normals):
479
+ w2v_mat = self.cameras.get_full_projection_transform()
480
+ normals_view = torch.clone(normals)[:, :, :, 0:3]
481
+ normals_view = normals_view.reshape(normals_view.shape[0], -1, 3)
482
+ normals_view = w2v_mat.transform_normals(normals_view)
483
+ normals_view = normals_view.reshape(normals.shape[0:3] + (3,))
484
+ normals_view[:, :, :, 2] *= -1
485
+ normals = (normals_view[..., 0:3] + 1) * normals[
486
+ ..., 3:
487
+ ] / 2 + torch.FloatTensor(((((0.5, 0.5, 1))))).to(self.device) * (
488
+ 1 - normals[..., 3:]
489
+ )
490
+ # normals = torch.cat([normal for normal in normals], dim=1)
491
+ normals = normals.clamp(0, 1)
492
+ return normals
493
+
494
+ # Normalize absolute depth to inverse depth
495
+ @torch.no_grad()
496
+ def decode_normalized_depth(self, depths, batched_norm=False):
497
+ view_z, mask = depths.unbind(-1)
498
+ view_z = view_z * mask + 100 * (1 - mask)
499
+ inv_z = 1 / view_z
500
+ inv_z_min = inv_z * mask + 100 * (1 - mask)
501
+ if not batched_norm:
502
+ max_ = torch.max(inv_z, 1, keepdim=True)
503
+ max_ = torch.max(max_[0], 2, keepdim=True)[0]
504
+
505
+ min_ = torch.min(inv_z_min, 1, keepdim=True)
506
+ min_ = torch.min(min_[0], 2, keepdim=True)[0]
507
+ else:
508
+ max_ = torch.max(inv_z)
509
+ min_ = torch.min(inv_z_min)
510
+ inv_z = (inv_z - min_) / (max_ - min_)
511
+ inv_z = inv_z.clamp(0, 1)
512
+ inv_z = inv_z[..., None].repeat(1, 1, 1, 3)
513
+
514
+ return inv_z
515
+
516
+ # Multiple screen pixels could pass gradient to a same texel
517
+ # We can precalculate this gradient strength and use it to normalize gradients when we bake textures
518
+ @torch.enable_grad()
519
+ def calculate_tex_gradient(self, channels=None):
520
+ if not channels:
521
+ channels = self.channels
522
+ tmp_mesh = self.mesh.clone()
523
+ gradient_maps = []
524
+ for i in range(len(self.cameras)):
525
+ zero_map = torch.zeros(
526
+ self.target_size + (channels,), device=self.device, requires_grad=True
527
+ )
528
+ optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
529
+ optimizer.zero_grad()
530
+ zero_tex = TexturesUV(
531
+ [zero_map],
532
+ self.mesh.textures.faces_uvs_padded(),
533
+ self.mesh.textures.verts_uvs_padded(),
534
+ sampling_mode=self.sampling_mode,
535
+ )
536
+ tmp_mesh.textures = zero_tex
537
+ images_predicted = self.renderer(
538
+ tmp_mesh, cameras=self.cameras[i], lights=self.lights
539
+ )
540
+ loss = torch.sum((1 - images_predicted) ** 2)
541
+ loss.backward()
542
+ optimizer.step()
543
+
544
+ gradient_maps.append(zero_map.detach())
545
+
546
+ self.gradient_maps = gradient_maps
547
+
548
+ # Get the UV space masks of triangles visible in each view
549
+ # First get face ids from each view, then filter pixels on UV space to generate masks
550
+
551
+ @torch.no_grad()
552
+ def get_c2w(
553
+ self,
554
+ elevation_deg: LIST_TYPE,
555
+ distance: LIST_TYPE,
556
+ azimuth_deg: Optional[LIST_TYPE],
557
+ num_views: Optional[int] = 1,
558
+ device: Optional[str] = None,
559
+ ) -> torch.FloatTensor:
560
+ if azimuth_deg is None:
561
+ assert (
562
+ num_views is not None
563
+ ), "num_views must be provided if azimuth_deg is None."
564
+ azimuth_deg = torch.linspace(
565
+ 0, 360, num_views + 1, dtype=torch.float32, device=device
566
+ )[:-1]
567
+ else:
568
+ num_views = len(azimuth_deg)
569
+
570
+ def list_to_pt(
571
+ x: LIST_TYPE,
572
+ dtype: Optional[torch.dtype] = None,
573
+ device: Optional[str] = None,
574
+ ) -> torch.Tensor:
575
+ if isinstance(x, list) or isinstance(x, np.ndarray):
576
+ return torch.tensor(x, dtype=dtype, device=device)
577
+ return x.to(dtype=dtype)
578
+
579
+ azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
580
+ elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
581
+ camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)
582
+ elevation = elevation_deg * math.pi / 180
583
+ azimuth = azimuth_deg * math.pi / 180
584
+ camera_positions = torch.stack(
585
+ [
586
+ camera_distances * torch.cos(elevation) * torch.cos(azimuth),
587
+ camera_distances * torch.cos(elevation) * torch.sin(azimuth),
588
+ camera_distances * torch.sin(elevation),
589
+ ],
590
+ dim=-1,
591
+ )
592
+ center = torch.zeros_like(camera_positions)
593
+ up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[
594
+ None, :
595
+ ].repeat(num_views, 1)
596
+ lookat = F.normalize(center - camera_positions, dim=-1)
597
+ right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
598
+ up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
599
+ c2w3x4 = torch.cat(
600
+ [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
601
+ dim=-1,
602
+ )
603
+ c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
604
+ c2w[:, 3, 3] = 1.0
605
+ return c2w
606
+
607
+ @torch.no_grad()
608
+ def calculate_visible_triangle_mask(self, channels=None, image_size=(512, 512)):
609
+ if not channels:
610
+ channels = self.channels
611
+
612
+ pix2face_list = []
613
+ for i in range(len(self.cameras)):
614
+ self.renderer.rasterizer.raster_settings.image_size = image_size
615
+ pix2face = self.renderer.rasterizer(
616
+ self.mesh_d, cameras=self.cameras[i]
617
+ ).pix_to_face
618
+ self.renderer.rasterizer.raster_settings.image_size = self.render_size
619
+ pix2face_list.append(pix2face)
620
+
621
+ if not hasattr(self, "mesh_uv"):
622
+ self.construct_uv_mesh()
623
+
624
+ raster_settings = RasterizationSettings(
625
+ image_size=self.target_size,
626
+ blur_radius=0,
627
+ faces_per_pixel=1,
628
+ perspective_correct=False,
629
+ cull_backfaces=False,
630
+ max_faces_per_bin=30000,
631
+ )
632
+
633
+ R, T = look_at_view_transform(dist=2, elev=0, azim=0)
634
+ # flip_mat = torch.from_numpy(np.diag([-1.0, 1.0, -1.0]) ).type(torch.FloatTensor).to(R.device)
635
+ # R = R@flip_mat
636
+ # T = T*torch.tensor(np.array([-1.0, 1.0, -1.0])).type(torch.FloatTensor).to(R.device)
637
+ # c2w = self.get_c2w([0], [1.8], [0])
638
+ # w2c = torch.linalg.inv(c2w)[:, :3,:]
639
+ # R, T= w2c[:, :3,:3], w2c[:, :3, 3]
640
+ # print("R size:{}, T size:{}".format(R.size(), T.size()))
641
+ cameras = FoVOrthographicCameras(device=self.device, R=R, T=T)
642
+ # cameras = CustomOrthographicCameras(device=self.device, R=R, T=T)
643
+
644
+ # cameras = CustomOrthographicCameras(device=self.device, R=R, T=T, znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)
645
+
646
+ rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
647
+ uv_pix2face = rasterizer(self.mesh_uv).pix_to_face
648
+
649
+ visible_triangles = []
650
+ for i in range(len(pix2face_list)):
651
+ valid_faceid = torch.unique(pix2face_list[i])
652
+ valid_faceid = valid_faceid[1:] if valid_faceid[0] == -1 else valid_faceid
653
+ mask = torch.isin(uv_pix2face[0], valid_faceid, assume_unique=False)
654
+ # uv_pix2face[0][~mask] = -1
655
+ triangle_mask = torch.ones(self.target_size + (1,), device=self.device)
656
+ triangle_mask[~mask] = 0
657
+
658
+ triangle_mask[:, 1:][triangle_mask[:, :-1] > 0] = 1
659
+ triangle_mask[:, :-1][triangle_mask[:, 1:] > 0] = 1
660
+ triangle_mask[1:, :][triangle_mask[:-1, :] > 0] = 1
661
+ triangle_mask[:-1, :][triangle_mask[1:, :] > 0] = 1
662
+ visible_triangles.append(triangle_mask)
663
+
664
+ self.visible_triangles = visible_triangles
665
+
666
+ # Render the current mesh and texture from current cameras
667
+ def render_textured_views(self):
668
+ meshes = self.mesh.extend(len(self.cameras))
669
+ images_predicted = self.renderer(
670
+ meshes, cameras=self.cameras, lights=self.lights
671
+ )
672
+
673
+ return [image.permute(2, 0, 1) for image in images_predicted]
674
+
675
+ @torch.no_grad()
676
+ def get_point_validation_by_o3d(
677
+ self, points, eye_position, hidden_point_removal_radius=200
678
+ ):
679
+ point_visibility = torch.zeros((points.shape[0]), device=points.device).bool()
680
+
681
+ pcd = o3d.geometry.PointCloud(
682
+ points=o3d.utility.Vector3dVector(points.cpu().numpy())
683
+ )
684
+ camera_pose = (
685
+ eye_position.get_camera_center().squeeze().cpu().numpy().astype(np.float64)
686
+ )
687
+ # o3d_camera = [0, 0, diameter]
688
+ diameter = np.linalg.norm(
689
+ np.asarray(pcd.get_max_bound()) - np.asarray(pcd.get_min_bound())
690
+ )
691
+ radius = diameter * 200  # The radius of the spherical projection
692
+ _, pt_map = pcd.hidden_point_removal(camera_pose, radius)
693
+
694
+ visible_point_ids = np.array(pt_map)
695
+
696
+ point_visibility[visible_point_ids] = True
697
+ return point_visibility
698
+
699
+ @torch.no_grad()
700
+ def hidden_judge(self, camera, texture_dim):
701
+ mesh = self.mesh
702
+
703
+ verts = mesh.verts_packed()
704
+ faces = mesh.faces_packed()
705
+ verts_uv = mesh.textures.verts_uvs_padded()[0]  # padded per-vertex UV coordinates (V, 2)
706
+ faces_uv = mesh.textures.faces_uvs_padded()[0]
707
+ uv_face_attr = torch.index_select(
708
+ verts_uv, 0, faces_uv.view(-1)
709
+ )  # gather the UV coordinates of the corresponding face vertices
710
+ uv_face_attr = uv_face_attr.view(
711
+ faces.shape[0], faces_uv.shape[1], 2
712
+ ).unsqueeze(0)
713
+ x, y, z = verts[:, 0], verts[:, 1], verts[:, 2]
714
+ mesh_out_of_range = False
715
+ if (
716
+ x.min() < -1
717
+ or x.max() > 1
718
+ or y.min() < -1
719
+ or y.max() > 1
720
+ or z.min() < -1
721
+ or z.max() > 1
722
+ ):
723
+ mesh_out_of_range = True
724
+ face_vertices_world = kal.ops.mesh.index_vertices_by_faces(
725
+ verts.unsqueeze(0), faces
726
+ )
727
+ face_vertices_z = torch.zeros_like(
728
+ face_vertices_world[:, :, :, -1], device=verts.device
729
+ )
730
+ uv_position, face_idx = kal.render.mesh.rasterize(
731
+ texture_dim,
732
+ texture_dim,
733
+ face_vertices_z,
734
+ uv_face_attr * 2 - 1,
735
+ face_features=face_vertices_world,
736
+ )
737
+ uv_position = torch.clamp(uv_position, -1, 1)
738
+ uv_position[face_idx == -1] = 0
739
+
740
+ points = uv_position.reshape(-1, 3)
741
+ mask = points[:, 0] != 0
742
+ valid_points = points[mask]
743
+ # np.save("tmp/pcd.npy", valid_points.cpu().numpy())
744
+ # print(camera.get_camera_center())
745
+
746
+ points_visibility = self.get_point_validation_by_o3d(
747
+ valid_points, camera
748
+ ).float()
749
+ visibility_map = torch.zeros((texture_dim * texture_dim,)).to(self.device)
750
+ visibility_map[mask] = points_visibility
751
+ visibility_map = visibility_map.reshape((texture_dim, texture_dim))
752
+ return visibility_map
753
+
754
+ @torch.enable_grad()
755
+ def bake_texture(
756
+ self,
757
+ views=None,
758
+ main_views=[],
759
+ cos_weighted=True,
760
+ channels=None,
761
+ exp=None,
762
+ noisy=False,
763
+ generator=None,
764
+ smooth_colorize=False,
765
+ ):
766
+ if not exp:
767
+ exp = 1
768
+ if not channels:
769
+ channels = self.channels
770
+ views = [view.permute(1, 2, 0) for view in views]
771
+
772
+ tmp_mesh = self.mesh
773
+ bake_maps = [
774
+ torch.zeros(
775
+ self.target_size + (views[0].shape[2],),
776
+ device=self.device,
777
+ requires_grad=True,
778
+ )
779
+ for view in views
780
+ ]
781
+ optimizer = torch.optim.SGD(bake_maps, lr=1, momentum=0)
782
+ optimizer.zero_grad()
783
+ loss = 0
784
+ for i in range(len(self.cameras)):
785
+ bake_tex = TexturesUV(
786
+ [bake_maps[i]],
787
+ tmp_mesh.textures.faces_uvs_padded(),
788
+ tmp_mesh.textures.verts_uvs_padded(),
789
+ sampling_mode=self.sampling_mode,
790
+ )
791
+ tmp_mesh.textures = bake_tex
792
+ images_predicted = self.renderer(
793
+ tmp_mesh,
794
+ cameras=self.cameras[i],
795
+ lights=self.lights,
796
+ device=self.device,
797
+ )
798
+ predicted_rgb = images_predicted[..., :-1]
799
+ loss += (((predicted_rgb[...] - views[i])) ** 2).sum()
800
+ loss.backward(retain_graph=False)
801
+ optimizer.step()
802
+
803
+ total_weights = 0
804
+ baked = 0
805
+ for i in range(len(bake_maps)):
806
+ normalized_baked_map = bake_maps[i].detach() / (
807
+ self.gradient_maps[i] + 1e-8
808
+ )
809
+ bake_map = voronoi_solve(
810
+ normalized_baked_map, self.gradient_maps[i][..., 0], self.device
811
+ )
812
+ # bake_map = voronoi_solve(normalized_baked_map, self.visible_triangles[i].squeeze())
813
+
814
+ weight = self.visible_triangles[i] * (self.cos_maps[i]) ** exp
815
+ if smooth_colorize:
816
+ visibility_map = self.hidden_judge(
817
+ self.cameras[i], self.target_size[0]
818
+ ).unsqueeze(-1)
819
+ weight *= visibility_map
820
+ if noisy:
821
+ noise = (
822
+ torch.rand(weight.shape[:-1] + (1,), generator=generator)
823
+ .type(weight.dtype)
824
+ .to(weight.device)
825
+ )
826
+ weight *= noise
827
+ total_weights += weight
828
+
829
+ baked += bake_map * weight
830
+ baked /= total_weights + 1e-8
831
+
832
+ whole_visible_mask = None
833
+ if not smooth_colorize:
834
+ baked = voronoi_solve(baked, total_weights[..., 0], self.device)
835
+ tmp_mesh.textures = TexturesUV(
836
+ [baked],
837
+ tmp_mesh.textures.faces_uvs_padded(),
838
+ tmp_mesh.textures.verts_uvs_padded(),
839
+ sampling_mode=self.sampling_mode,
840
+ )
841
+ else: # smooth colorize
842
+ baked = voronoi_solve(baked, total_weights[..., 0], self.device)
843
+ whole_visible_mask = self.visible_triangles[0].to(torch.int32)
844
+ for tensor in self.visible_triangles[1:]:
845
+ whole_visible_mask = torch.bitwise_or(
846
+ whole_visible_mask, tensor.to(torch.int32)
847
+ )
848
+
849
+ baked *= whole_visible_mask
850
+ tmp_mesh.textures = TexturesUV(
851
+ [baked],
852
+ tmp_mesh.textures.faces_uvs_padded(),
853
+ tmp_mesh.textures.verts_uvs_padded(),
854
+ sampling_mode=self.sampling_mode,
855
+ )
856
+
857
+ extended_mesh = tmp_mesh.extend(len(self.cameras))
858
+ images_predicted = self.renderer(
859
+ extended_mesh, cameras=self.cameras, lights=self.lights
860
+ )
861
+ learned_views = [image.permute(2, 0, 1) for image in images_predicted]
862
+
863
+ return learned_views, baked.permute(2, 0, 1), total_weights.permute(2, 0, 1)
864
+
865
+ # Move the internel data to a specific device
866
+ def to(self, device):
867
+ for mesh_name in ["mesh", "mesh_d", "mesh_uv"]:
868
+ if hasattr(self, mesh_name):
869
+ mesh = getattr(self, mesh_name)
870
+ setattr(self, mesh_name, mesh.to(device))
871
+ for list_name in ["visible_triangles", "visibility_maps", "cos_maps"]:
872
+ if hasattr(self, list_name):
873
+ map_list = getattr(self, list_name)
874
+ for i in range(len(map_list)):
875
+ map_list[i] = map_list[i].to(device)
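
The UVProjection class above (moved from texture_sync/ to renderer/) wraps PyTorch3D mesh loading, orthographic cameras, and UV texture baking. Below is a hedged usage sketch that follows the call sequence the reverted pipeline code used; the mesh path is hypothetical, and the repo's rendering dependencies (pytorch3d, kaolin, open3d, xatlas) must be installed:

import torch
import trimesh
from step1x3d_texture.renderer.project import UVProjection

device = "cuda" if torch.cuda.is_available() else "cpu"
mesh = trimesh.load("example_mesh.glb", force="mesh")   # hypothetical input mesh

uvp = UVProjection(texture_size=1536, render_size=96,
                   sampling_mode="nearest", channels=4, device=device)
uvp.load_mesh(mesh, scale_factor=1.0, autouv=True)

# Six views: front/right/back/left plus top and bottom, as in the old texture_sync_config
camera_poses = list(zip([0, 0, 0, 0, 90, -90], [0, 90, 180, 270, 0, 0]))
uvp.set_cameras_and_render_settings(camera_poses, centers=None,
                                    camera_distance=1.8, scale=((1.0, 1.0, 1.0),))

latent_tex = uvp.set_noise_texture()      # (4, 1536, 1536) Gaussian noise texture
views = uvp.render_textured_views()       # list of (channels + alpha, 96, 96) tensors
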
step1x3d_texture/renderer/shader.py ADDED
@@ -0,0 +1,127 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import pytorch3d
5
+
6
+
7
+ from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
8
+ from pytorch3d.ops import interpolate_face_attributes
9
+
10
+ from pytorch3d.structures import Meshes
11
+ from pytorch3d.renderer import (
12
+ look_at_view_transform,
13
+ FoVPerspectiveCameras,
14
+ AmbientLights,
15
+ PointLights,
16
+ DirectionalLights,
17
+ Materials,
18
+ RasterizationSettings,
19
+ MeshRenderer,
20
+ MeshRasterizer,
21
+ SoftPhongShader,
22
+ SoftSilhouetteShader,
23
+ HardPhongShader,
24
+ TexturesVertex,
25
+ TexturesUV,
26
+ Materials,
27
+ )
28
+ from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
29
+ from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
30
+
31
+ from pytorch3d.renderer.lighting import AmbientLights
32
+ from pytorch3d.renderer.materials import Materials
33
+ from pytorch3d.renderer.mesh.shader import ShaderBase
34
+ from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
35
+ from pytorch3d.renderer.mesh.rasterizer import Fragments
36
+
37
+
38
+ """
39
+ Customized the original pytorch3d hard flat shader to support N channel flat shading
40
+ """
41
+
42
+
43
+ class HardNChannelFlatShader(ShaderBase):
44
+ """
45
+ Per face lighting - the lighting model is applied using the average face
46
+ position and the face normal. The blending function hard assigns
47
+ the color of the closest face for each pixel.
48
+
49
+ To use the default values, simply initialize the shader with the desired
50
+ device e.g.
51
+
52
+ .. code-block::
53
+
54
+ shader = HardFlatShader(device=torch.device("cuda:0"))
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ device="cpu",
60
+ cameras: Optional[TensorProperties] = None,
61
+ lights: Optional[TensorProperties] = None,
62
+ materials: Optional[Materials] = None,
63
+ blend_params: Optional[BlendParams] = None,
64
+ channels: int = 3,
65
+ ):
66
+ self.channels = channels
67
+ ones = ((1.0,) * channels,)
68
+ zeros = ((0.0,) * channels,)
69
+
70
+ if (
71
+ not isinstance(lights, AmbientLights)
72
+ or not lights.ambient_color.shape[-1] == channels
73
+ ):
74
+ lights = AmbientLights(
75
+ ambient_color=ones,
76
+ device=device,
77
+ )
78
+
79
+ if not materials or not materials.ambient_color.shape[-1] == channels:
80
+ materials = Materials(
81
+ device=device,
82
+ diffuse_color=zeros,
83
+ ambient_color=ones,
84
+ specular_color=zeros,
85
+ shininess=0.0,
86
+ )
87
+
88
+ blend_params_new = BlendParams(background_color=(1.0,) * channels)
89
+ if not isinstance(blend_params, BlendParams):
90
+ blend_params = blend_params_new
91
+ else:
92
+ background_color_ = blend_params.background_color
93
+ if (
94
+ isinstance(background_color_, (list, tuple))  # a subscripted Sequence cannot be used with isinstance
95
+ and not len(background_color_) == channels
96
+ ):
97
+ blend_params = blend_params_new
98
+ if (
99
+ isinstance(background_color_, torch.Tensor)
100
+ and not background_color_.shape[-1] == channels
101
+ ):
102
+ blend_params = blend_params_new
103
+
104
+ super().__init__(
105
+ device,
106
+ cameras,
107
+ lights,
108
+ materials,
109
+ blend_params,
110
+ )
111
+
112
+ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
113
+ cameras = super()._get_cameras(**kwargs)
114
+ texels = meshes.sample_textures(fragments)
115
+ lights = kwargs.get("lights", self.lights)
116
+ materials = kwargs.get("materials", self.materials)
117
+ blend_params = kwargs.get("blend_params", self.blend_params)
118
+ colors = flat_shading(
119
+ meshes=meshes,
120
+ fragments=fragments,
121
+ texels=texels,
122
+ lights=lights,
123
+ cameras=cameras,
124
+ materials=materials,
125
+ )
126
+ images = hard_rgb_blend(colors, fragments, blend_params)
127
+ return images
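For orientation, the sketch below shows one way the restored HardNChannelFlatShader can be dropped into a standard PyTorch3D MeshRenderer to rasterize a 4-channel (latent) texture. It is an illustrative sketch rather than code from this commit; `my_mesh` is a placeholder for a Meshes object carrying a 4-channel TexturesUV, and the camera and raster settings are arbitrary example values.

import torch
from pytorch3d.renderer import (
    AmbientLights,
    FoVOrthographicCameras,
    MeshRasterizer,
    MeshRenderer,
    RasterizationSettings,
    look_at_view_transform,
)

device = torch.device("cuda:0")
channels = 4  # e.g. Stable Diffusion latent channels

R, T = look_at_view_transform(dist=2.7, elev=0, azim=0)
cameras = FoVOrthographicCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(image_size=64, faces_per_pixel=1)

renderer = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
    shader=HardNChannelFlatShader(
        device=device,
        cameras=cameras,
        lights=AmbientLights(ambient_color=((1.0,) * channels,), device=device),
        channels=channels,
    ),
)
# images = renderer(my_mesh)  # -> (N, H, W, channels + 1); the last channel is the hard alpha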
step1x3d_texture/{texture_sync → renderer}/voronoi.py RENAMED
File without changes
step1x3d_texture/texture_sync/geometry.py DELETED
@@ -1,141 +0,0 @@
1
- import torch
2
- import pytorch3d
3
- import torch.nn.functional as F
4
-
5
- from pytorch3d.ops import interpolate_face_attributes
6
-
7
- from pytorch3d.renderer import (
8
- look_at_view_transform,
9
- FoVPerspectiveCameras,
10
- AmbientLights,
11
- PointLights,
12
- DirectionalLights,
13
- Materials,
14
- RasterizationSettings,
15
- MeshRenderer,
16
- MeshRasterizer,
17
- SoftPhongShader,
18
- SoftSilhouetteShader,
19
- HardPhongShader,
20
- TexturesVertex,
21
- TexturesUV,
22
- Materials,
23
-
24
- )
25
- from pytorch3d.renderer.blending import BlendParams,hard_rgb_blend
26
- from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
27
- from pytorch3d.renderer.mesh.shader import ShaderBase
28
-
29
-
30
- def get_cos_angle(
31
- points, normals, camera_position
32
- ):
33
- '''
34
- calculate cosine similarity between view->surface and surface normal.
35
- '''
36
-
37
- if points.shape != normals.shape:
38
- msg = "Expected points and normals to have the same shape: got %r, %r"
39
- raise ValueError(msg % (points.shape, normals.shape))
40
-
41
- # Ensure all inputs have same batch dimension as points
42
- matched_tensors = convert_to_tensors_and_broadcast(
43
- points, camera_position, device=points.device
44
- )
45
- _, camera_position = matched_tensors
46
-
47
- # Reshape direction and color so they have all the arbitrary intermediate
48
- # dimensions as points. Assume first dim = batch dim and last dim = 3.
49
- points_dims = points.shape[1:-1]
50
- expand_dims = (-1,) + (1,) * len(points_dims)
51
-
52
- if camera_position.shape != normals.shape:
53
- camera_position = camera_position.view(expand_dims + (3,))
54
-
55
- normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
56
-
57
- # Calculate the cosine value.
58
- view_direction = camera_position - points
59
- view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
60
- cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True)
61
- cos_angle = cos_angle.clamp(0, 1)
62
-
63
- # Cosine of the angle between the reflected light ray and the viewer
64
- return cos_angle
65
-
66
-
67
- def _geometry_shading_with_pixels(
68
- meshes, fragments, lights, cameras, materials, texels
69
- ):
70
- """
71
- Render pixel space vertex position, normal(world), depth, and cos angle
72
-
73
- Args:
74
- meshes: Batch of meshes
75
- fragments: Fragments named tuple with the outputs of rasterization
76
- lights: Lights class containing a batch of lights
77
- cameras: Cameras class containing a batch of cameras
78
- materials: Materials class containing a batch of material properties
79
- texels: texture per pixel of shape (N, H, W, K, 3)
80
-
81
- Returns:
82
- colors: (N, H, W, K, 3)
83
- pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection.
84
- """
85
- verts = meshes.verts_packed() # (V, 3)
86
- faces = meshes.faces_packed() # (F, 3)
87
- vertex_normals = meshes.verts_normals_packed() # (V, 3)
88
- faces_verts = verts[faces]
89
- faces_normals = vertex_normals[faces]
90
- pixel_coords_in_camera = interpolate_face_attributes(
91
- fragments.pix_to_face, fragments.bary_coords, faces_verts
92
- )
93
- pixel_normals = interpolate_face_attributes(
94
- fragments.pix_to_face, fragments.bary_coords, faces_normals
95
- )
96
-
97
- cos_angles = get_cos_angle(pixel_coords_in_camera, pixel_normals, cameras.get_camera_center())
98
-
99
- return pixel_coords_in_camera, pixel_normals, fragments.zbuf[...,None], cos_angles
100
-
101
-
102
- class HardGeometryShader(ShaderBase):
103
- """
104
- renders common geometric informations.
105
-
106
-
107
- """
108
-
109
- def forward(self, fragments, meshes, **kwargs):
110
- cameras = super()._get_cameras(**kwargs)
111
- texels = self.texel_from_uv(fragments, meshes)
112
-
113
- lights = kwargs.get("lights", self.lights)
114
- materials = kwargs.get("materials", self.materials)
115
- blend_params = kwargs.get("blend_params", self.blend_params)
116
- verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
117
- meshes=meshes,
118
- fragments=fragments,
119
- texels=texels,
120
- lights=lights,
121
- cameras=cameras,
122
- materials=materials,
123
- )
124
- verts = hard_rgb_blend(verts, fragments, blend_params)
125
- normals = hard_rgb_blend(normals, fragments, blend_params)
126
- depths = hard_rgb_blend(depths, fragments, blend_params)
127
- cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
128
- texels = hard_rgb_blend(texels, fragments, blend_params)
129
- return verts, normals, depths, cos_angles, texels, fragments
130
-
131
- def texel_from_uv(self, fragments, meshes):
132
- texture_tmp = meshes.textures
133
- maps_tmp = texture_tmp.maps_padded()
134
- uv_color = [ [[1,0],[1,1]],[[0,0],[0,1]] ]
135
- uv_color = torch.FloatTensor(uv_color).to(maps_tmp[0].device).type(maps_tmp[0].dtype)
136
- uv_texture = TexturesUV([uv_color.clone() for t in maps_tmp], texture_tmp.faces_uvs_padded(), texture_tmp.verts_uvs_padded(), sampling_mode="bilinear")
137
- meshes.textures = uv_texture
138
- texels = meshes.sample_textures(fragments)
139
- meshes.textures = texture_tmp
140
- texels = torch.cat((texels, texels[...,-1:]*0), dim=-1)
141
- return texels
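For reference, the HardGeometryShader deleted above was not used on its own; the deleted project.py (further below) swapped it into an existing renderer to rasterize per-pixel geometry instead of colors, then restored the original shader. A minimal sketch of that pattern, assuming `renderer`, `cameras`, `lights`, `mesh`, and `device` are already set up as in that module:

# Swap the shader temporarily so the renderer returns geometry buffers.
color_shader = renderer.shader
renderer.shader = HardGeometryShader(device=device, cameras=cameras[0], lights=lights)

verts, normals, depths, cos_angles, texels, fragments = renderer(
    mesh.extend(len(cameras)), cameras=cameras, lights=lights
)
renderer.shader = color_shader

# cos_angles[..., 0:1] is the per-pixel cosine between the view direction and the
# surface normal; project.py baked it into UV space as a per-view blending weight.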
 
step1x3d_texture/texture_sync/project.py DELETED
@@ -1,521 +0,0 @@
1
- import torch
2
- import pytorch3d
3
-
4
-
5
- from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj, IO
6
-
7
- from pytorch3d.structures import Meshes
8
- from pytorch3d.renderer import (
9
- look_at_view_transform,
10
- FoVPerspectiveCameras,
11
- FoVOrthographicCameras,
12
- AmbientLights,
13
- PointLights,
14
- DirectionalLights,
15
- Materials,
16
- RasterizationSettings,
17
- MeshRenderer,
18
- MeshRasterizer,
19
- TexturesUV
20
- )
21
-
22
- from .geometry import HardGeometryShader
23
- from .shader import HardNChannelFlatShader
24
- from .voronoi import voronoi_solve
25
- from trimesh import Trimesh
26
-
27
- # Pytorch3D based renderering functions, managed in a class
28
- # Render size is recommended to be the same as your latent view size
29
- # DO NOT USE "bilinear" sampling when you are handling latents.
30
- # Stable Diffusion has 4 latent channels so use channels=4
31
-
32
- class UVProjection():
33
- def __init__(self, texture_size=96, render_size=64, sampling_mode="nearest", channels=3, device=None):
34
- self.channels = channels
35
- self.device = device or torch.device("cpu")
36
- self.lights = AmbientLights(ambient_color=((1.0,)*channels,), device=self.device)
37
- self.target_size = (texture_size,texture_size)
38
- self.render_size = render_size
39
- self.sampling_mode = sampling_mode
40
-
41
-
42
- # # Load obj mesh, rescale the mesh to fit into the bounding box
43
- # def load_mesh(self, mesh_path, scale_factor=2.0, auto_center=True, autouv=False):
44
- # mesh = load_objs_as_meshes([mesh_path], device=self.device)
45
- # if auto_center:
46
- # verts = mesh.verts_packed()
47
- # max_bb = (verts - 0).max(0)[0]
48
- # min_bb = (verts - 0).min(0)[0]
49
- # scale = (max_bb - min_bb).max()/2
50
- # center = (max_bb+min_bb) /2
51
- # mesh.offset_verts_(-center)
52
- # mesh.scale_verts_((scale_factor / float(scale)))
53
- # else:
54
- # mesh.scale_verts_((scale_factor))
55
-
56
- # if autouv or (mesh.textures is None):
57
- # mesh = self.uv_unwrap(mesh)
58
- # self.mesh = mesh
59
- # Load obj mesh, rescale the mesh to fit into the bounding box
60
- def load_mesh(self, mesh, scale_factor=2.0, auto_center=True, autouv=False, normals=None):
61
- if isinstance(mesh, Trimesh):
62
- vertices = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
63
- faces = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
64
- if faces.ndim == 1:
65
- faces = faces.unsqueeze(0)
66
- mesh = Meshes(
67
- verts=[vertices],
68
- faces=[faces]
69
- )
70
- verts = mesh.verts_packed()
71
- mesh = mesh.update_padded(verts[None,:, :])
72
- # from pytorch3d.renderer.mesh.textures import TexturesVertex
73
- # if normals is None:
74
- # normals = mesh.verts_normals_packed()
75
- # # set normals as vertext colors
76
- # mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
77
- elif isinstance(mesh, str) and os.path.isfile(mesh):
78
- mesh = load_objs_as_meshes([mesh_path], device=self.device)
79
- if auto_center:
80
- verts = mesh.verts_packed()
81
- max_bb = (verts - 0).max(0)[0]
82
- min_bb = (verts - 0).min(0)[0]
83
- scale = (max_bb - min_bb).max()/2
84
- center = (max_bb+min_bb) /2
85
- mesh.offset_verts_(-center)
86
- mesh.scale_verts_((scale_factor / float(scale)))
87
- else:
88
- mesh.scale_verts_((scale_factor))
89
-
90
- if autouv or (mesh.textures is None):
91
- mesh = self.uv_unwrap(mesh)
92
- self.mesh = mesh
93
-
94
- def load_glb_mesh(self, mesh_path, scale_factor=2.0, auto_center=True, autouv=False):
95
- from pytorch3d.io.experimental_gltf_io import MeshGlbFormat
96
- io = IO()
97
- io.register_meshes_format(MeshGlbFormat())
98
- with open(mesh_path, "rb") as f:
99
- mesh = io.load_mesh(f, include_textures=True, device=self.device)
100
- if auto_center:
101
- verts = mesh.verts_packed()
102
- max_bb = (verts - 0).max(0)[0]
103
- min_bb = (verts - 0).min(0)[0]
104
- scale = (max_bb - min_bb).max()/2
105
- center = (max_bb+min_bb) /2
106
- mesh.offset_verts_(-center)
107
- mesh.scale_verts_((scale_factor / float(scale)))
108
- else:
109
- mesh.scale_verts_((scale_factor))
110
- if autouv or (mesh.textures is None):
111
- mesh = self.uv_unwrap(mesh)
112
- self.mesh = mesh
113
-
114
-
115
- # Save obj mesh
116
- def save_mesh(self, mesh_path, texture):
117
- save_obj(mesh_path,
118
- self.mesh.verts_list()[0],
119
- self.mesh.faces_list()[0],
120
- verts_uvs= self.mesh.textures.verts_uvs_list()[0],
121
- faces_uvs= self.mesh.textures.faces_uvs_list()[0],
122
- texture_map=texture)
123
-
124
- # Code referred to TEXTure code (https://github.com/TEXTurePaper/TEXTurePaper.git)
125
- def uv_unwrap(self, mesh):
126
- verts_list = mesh.verts_list()[0]
127
- faces_list = mesh.faces_list()[0]
128
-
129
-
130
- import xatlas
131
- import numpy as np
132
- v_np = verts_list.cpu().numpy()
133
- f_np = faces_list.int().cpu().numpy()
134
- atlas = xatlas.Atlas()
135
- atlas.add_mesh(v_np, f_np)
136
- chart_options = xatlas.ChartOptions()
137
- chart_options.max_iterations = 4
138
- atlas.generate(chart_options=chart_options)
139
- vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
140
-
141
- vt = torch.from_numpy(vt_np.astype(np.float32)).type(verts_list.dtype).to(mesh.device)
142
- ft = torch.from_numpy(ft_np.astype(np.int64)).type(faces_list.dtype).to(mesh.device)
143
-
144
- new_map = torch.zeros(self.target_size+(self.channels,), device=mesh.device)
145
- new_tex = TexturesUV(
146
- [new_map],
147
- [ft],
148
- [vt],
149
- sampling_mode=self.sampling_mode
150
- )
151
-
152
- mesh.textures = new_tex
153
- return mesh
154
-
155
-
156
- '''
157
- A functions that disconnect faces in the mesh according to
158
- its UV seams. The number of vertices are made equal to the
159
- number of unique vertices its UV layout, while the faces list
160
- is intact.
161
- '''
162
- def disconnect_faces(self):
163
- mesh = self.mesh
164
- verts_list = mesh.verts_list()
165
- faces_list = mesh.faces_list()
166
- verts_uvs_list = mesh.textures.verts_uvs_list()
167
- faces_uvs_list = mesh.textures.faces_uvs_list()
168
- packed_list = [v[f] for v,f in zip(verts_list, faces_list)]
169
- verts_disconnect_list = [
170
- torch.zeros(
171
- (verts_uvs_list[i].shape[0], 3),
172
- dtype=verts_list[0].dtype,
173
- device=verts_list[0].device
174
- )
175
- for i in range(len(verts_list))]
176
- for i in range(len(verts_list)):
177
- verts_disconnect_list[i][faces_uvs_list] = packed_list[i]
178
- assert not mesh.has_verts_normals(), "Not implemented for vertex normals"
179
- self.mesh_d = Meshes(verts_disconnect_list, faces_uvs_list, mesh.textures)
180
- return self.mesh_d
181
-
182
-
183
- '''
184
- A function that construct a temp mesh for back-projection.
185
- Take a disconnected mesh and a rasterizer, the function calculates
186
- the projected faces as the UV, as use its original UV with pseudo
187
- z value as world space geometry.
188
- '''
189
- def construct_uv_mesh(self):
190
- mesh = self.mesh_d
191
- verts_list = mesh.verts_list()
192
- verts_uvs_list = mesh.textures.verts_uvs_list()
193
- # faces_list = [torch.flip(faces, [-1]) for faces in mesh.faces_list()]
194
- new_verts_list = []
195
- for i, (verts, verts_uv) in enumerate(zip(verts_list, verts_uvs_list)):
196
- verts = verts.clone()
197
- verts_uv = verts_uv.clone()
198
- verts[...,0:2] = verts_uv[...,:]
199
- verts = (verts - 0.5) * 2
200
- verts[...,2] *= 1
201
- new_verts_list.append(verts)
202
- textures_uv = mesh.textures.clone()
203
- self.mesh_uv = Meshes(new_verts_list, mesh.faces_list(), textures_uv)
204
- return self.mesh_uv
205
-
206
-
207
- # Set texture for the current mesh.
208
- def set_texture_map(self, texture):
209
- new_map = texture.permute(1, 2, 0)
210
- new_map = new_map.to(self.device)
211
- new_tex = TexturesUV(
212
- [new_map],
213
- self.mesh.textures.faces_uvs_padded(),
214
- self.mesh.textures.verts_uvs_padded(),
215
- sampling_mode=self.sampling_mode
216
- )
217
- self.mesh.textures = new_tex
218
-
219
-
220
- # Set the initial normal noise texture
221
- # No generator here for replication of the experiment result. Add one as you wish
222
- def set_noise_texture(self, channels=None):
223
- if not channels:
224
- channels = self.channels
225
- noise_texture = torch.normal(0, 1, (channels,) + self.target_size, device=self.device)
226
- self.set_texture_map(noise_texture)
227
- return noise_texture
228
-
229
-
230
- # Set the cameras given the camera poses and centers
231
- def set_cameras(self, camera_poses, centers=None, camera_distance=2.7, scale=None):
232
- elev = torch.FloatTensor([pose[0] for pose in camera_poses])
233
- azim = torch.FloatTensor([pose[1] for pose in camera_poses])
234
- R, T = look_at_view_transform(dist=camera_distance, elev=elev, azim=azim, at=centers or ((0,0,0),))
235
- # self.cameras = FoVOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),))
236
- self.cameras = FoVOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),), znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)
237
-
238
- # Set all necessary internal data for rendering and texture baking
239
- # Can be used to refresh after changing camera positions
240
- def set_cameras_and_render_settings(self, camera_poses, centers=None, camera_distance=2.7, render_size=None, scale=None):
241
- self.set_cameras(camera_poses, centers, camera_distance, scale=scale)
242
- if render_size is None:
243
- render_size = self.render_size
244
- if not hasattr(self, "renderer"):
245
- self.setup_renderer(size=render_size)
246
- if not hasattr(self, "mesh_d"):
247
- self.disconnect_faces()
248
- if not hasattr(self, "mesh_uv"):
249
- self.construct_uv_mesh()
250
- self.calculate_tex_gradient()
251
- self.calculate_visible_triangle_mask()
252
- _,_,_,cos_maps,_, _ = self.render_geometry()
253
- self.calculate_cos_angle_weights(cos_maps)
254
-
255
-
256
- # Setup renderers for rendering
257
- # max faces per bin set to 30000 to avoid overflow in many test cases.
258
- # You can use default value to let pytorch3d handle that for you.
259
- def setup_renderer(self, size=64, blur=0.0, face_per_pix=1, perspective_correct=False, channels=None):
260
- if not channels:
261
- channels = self.channels
262
-
263
- self.raster_settings = RasterizationSettings(
264
- image_size=size,
265
- blur_radius=blur,
266
- faces_per_pixel=face_per_pix,
267
- perspective_correct=perspective_correct,
268
- cull_backfaces=True,
269
- max_faces_per_bin=30000,
270
- )
271
-
272
- self.renderer = MeshRenderer(
273
- rasterizer=MeshRasterizer(
274
- cameras=self.cameras,
275
- raster_settings=self.raster_settings,
276
-
277
- ),
278
- shader=HardNChannelFlatShader(
279
- device=self.device,
280
- cameras=self.cameras,
281
- lights=self.lights,
282
- channels=channels
283
- # materials=materials
284
- )
285
- )
286
-
287
-
288
- # Bake screen-space cosine weights to UV space
289
- # May be able to reimplement using the generic "bake_texture" function, but it works so leave it here for now
290
- @torch.enable_grad()
291
- def calculate_cos_angle_weights(self, cos_angles, fill=True, channels=None):
292
- if not channels:
293
- channels = self.channels
294
- cos_maps = []
295
- tmp_mesh = self.mesh.clone()
296
- for i in range(len(self.cameras)):
297
-
298
- zero_map = torch.zeros(self.target_size+(channels,), device=self.device, requires_grad=True)
299
- optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
300
- optimizer.zero_grad()
301
- zero_tex = TexturesUV([zero_map], self.mesh.textures.faces_uvs_padded(), self.mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
302
- tmp_mesh.textures = zero_tex
303
-
304
- images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights)
305
-
306
- loss = torch.sum((cos_angles[i,:,:,0:1]**1 - images_predicted)**2)
307
- loss.backward()
308
- optimizer.step()
309
-
310
- if fill:
311
- zero_map = zero_map.detach() / (self.gradient_maps[i] + 1E-8)
312
- zero_map = voronoi_solve(zero_map, self.gradient_maps[i][...,0])
313
- else:
314
- zero_map = zero_map.detach() / (self.gradient_maps[i]+1E-8)
315
- cos_maps.append(zero_map)
316
- self.cos_maps = cos_maps
317
-
318
-
319
- # Get geometric info from fragment shader
320
- # Can be used for generating conditioning image and cosine weights
321
- # Returns some information you may not need, remember to release them for memory saving
322
- @torch.no_grad()
323
- def render_geometry(self, image_size=None):
324
- if image_size:
325
- size = self.renderer.rasterizer.raster_settings.image_size
326
- self.renderer.rasterizer.raster_settings.image_size = image_size
327
- shader = self.renderer.shader
328
- self.renderer.shader = HardGeometryShader(device=self.device, cameras=self.cameras[0], lights=self.lights)
329
- tmp_mesh = self.mesh.clone()
330
-
331
- verts, normals, depths, cos_angles, texels, fragments = self.renderer(tmp_mesh.extend(len(self.cameras)), cameras=self.cameras, lights=self.lights)
332
- self.renderer.shader = shader
333
-
334
- if image_size:
335
- self.renderer.rasterizer.raster_settings.image_size = size
336
-
337
- return verts, normals, depths, cos_angles, texels, fragments
338
-
339
-
340
- # Project world normal to view space and normalize
341
- @torch.no_grad()
342
- def decode_view_normal(self, normals):
343
- w2v_mat = self.cameras.get_full_projection_transform()
344
- normals_view = torch.clone(normals)[:,:,:,0:3]
345
- normals_view = normals_view.reshape(normals_view.shape[0], -1, 3)
346
- normals_view = w2v_mat.transform_normals(normals_view)
347
- normals_view = normals_view.reshape(normals.shape[0:3]+(3,))
348
- normals_view[:,:,:,2] *= -1
349
- normals = (normals_view[...,0:3]+1) * normals[...,3:] / 2 + torch.FloatTensor(((((0.5,0.5,1))))).to(self.device) * (1 - normals[...,3:])
350
- # normals = torch.cat([normal for normal in normals], dim=1)
351
- normals = normals.clamp(0, 1)
352
- return normals
353
-
354
-
355
- # Normalize absolute depth to inverse depth
356
- @torch.no_grad()
357
- def decode_normalized_depth(self, depths, batched_norm=False):
358
- view_z, mask = depths.unbind(-1)
359
- view_z = view_z * mask + 100 * (1-mask)
360
- inv_z = 1 / view_z
361
- inv_z_min = inv_z * mask + 100 * (1-mask)
362
- if not batched_norm:
363
- max_ = torch.max(inv_z, 1, keepdim=True)
364
- max_ = torch.max(max_[0], 2, keepdim=True)[0]
365
-
366
- min_ = torch.min(inv_z_min, 1, keepdim=True)
367
- min_ = torch.min(min_[0], 2, keepdim=True)[0]
368
- else:
369
- max_ = torch.max(inv_z)
370
- min_ = torch.min(inv_z_min)
371
- inv_z = (inv_z - min_) / (max_ - min_)
372
- inv_z = inv_z.clamp(0,1)
373
- inv_z = inv_z[...,None].repeat(1,1,1,3)
374
-
375
- return inv_z
376
-
377
-
378
- # Multiple screen pixels could pass gradient to a same texel
379
- # We can precalculate this gradient strength and use it to normalize gradients when we bake textures
380
- @torch.enable_grad()
381
- def calculate_tex_gradient(self, channels=None):
382
- if not channels:
383
- channels = self.channels
384
- tmp_mesh = self.mesh.clone()
385
- gradient_maps = []
386
- for i in range(len(self.cameras)):
387
- zero_map = torch.zeros(self.target_size+(channels,), device=self.device, requires_grad=True)
388
- optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
389
- optimizer.zero_grad()
390
- zero_tex = TexturesUV([zero_map], self.mesh.textures.faces_uvs_padded(), self.mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
391
- tmp_mesh.textures = zero_tex
392
- images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights)
393
- loss = torch.sum((1 - images_predicted)**2)
394
- loss.backward()
395
- optimizer.step()
396
-
397
- gradient_maps.append(zero_map.detach())
398
-
399
- self.gradient_maps = gradient_maps
400
-
401
-
402
- # Get the UV space masks of triangles visible in each view
403
- # First get face ids from each view, then filter pixels on UV space to generate masks
404
- @torch.no_grad()
405
- def calculate_visible_triangle_mask(self, channels=None, image_size=(512,512)):
406
- if not channels:
407
- channels = self.channels
408
-
409
- pix2face_list = []
410
- for i in range(len(self.cameras)):
411
- self.renderer.rasterizer.raster_settings.image_size=image_size
412
- pix2face = self.renderer.rasterizer(self.mesh_d, cameras=self.cameras[i]).pix_to_face
413
- self.renderer.rasterizer.raster_settings.image_size=self.render_size
414
- pix2face_list.append(pix2face)
415
-
416
- if not hasattr(self, "mesh_uv"):
417
- self.construct_uv_mesh()
418
-
419
- raster_settings = RasterizationSettings(
420
- image_size=self.target_size,
421
- blur_radius=0,
422
- faces_per_pixel=1,
423
- perspective_correct=False,
424
- cull_backfaces=False,
425
- max_faces_per_bin=30000,
426
- )
427
-
428
- R, T = look_at_view_transform(dist=2, elev=0, azim=0)
429
- cameras = FoVOrthographicCameras(device=self.device, R=R, T=T)
430
-
431
- rasterizer=MeshRasterizer(
432
- cameras=cameras,
433
- raster_settings=raster_settings
434
- )
435
- uv_pix2face = rasterizer(self.mesh_uv).pix_to_face
436
-
437
- visible_triangles = []
438
- for i in range(len(pix2face_list)):
439
- valid_faceid = torch.unique(pix2face_list[i])
440
- valid_faceid = valid_faceid[1:] if valid_faceid[0]==-1 else valid_faceid
441
- mask = torch.isin(uv_pix2face[0], valid_faceid, assume_unique=False)
442
- # uv_pix2face[0][~mask] = -1
443
- triangle_mask = torch.ones(self.target_size+(1,), device=self.device)
444
- triangle_mask[~mask] = 0
445
-
446
- triangle_mask[:,1:][triangle_mask[:,:-1] > 0] = 1
447
- triangle_mask[:,:-1][triangle_mask[:,1:] > 0] = 1
448
- triangle_mask[1:,:][triangle_mask[:-1,:] > 0] = 1
449
- triangle_mask[:-1,:][triangle_mask[1:,:] > 0] = 1
450
- visible_triangles.append(triangle_mask)
451
-
452
- self.visible_triangles = visible_triangles
453
-
454
-
455
-
456
- # Render the current mesh and texture from current cameras
457
- def render_textured_views(self):
458
- meshes = self.mesh.extend(len(self.cameras))
459
- images_predicted = self.renderer(meshes, cameras=self.cameras, lights=self.lights)
460
-
461
- return [image.permute(2, 0, 1) for image in images_predicted]
462
-
463
-
464
- # Bake views into a texture
465
- # First bake into individual textures then combine based on cosine weight
466
- @torch.enable_grad()
467
- def bake_texture(self, views=None, main_views=[], cos_weighted=True, channels=None, exp=None, noisy=False, generator=None):
468
- if not exp:
469
- exp=1
470
- if not channels:
471
- channels = self.channels
472
- views = [view.permute(1, 2, 0) for view in views]
473
-
474
- tmp_mesh = self.mesh
475
- bake_maps = [torch.zeros(self.target_size+(views[0].shape[2],), device=self.device, requires_grad=True) for view in views]
476
- optimizer = torch.optim.SGD(bake_maps, lr=1, momentum=0)
477
- optimizer.zero_grad()
478
- loss = 0
479
- for i in range(len(self.cameras)):
480
- bake_tex = TexturesUV([bake_maps[i]], tmp_mesh.textures.faces_uvs_padded(), tmp_mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
481
- tmp_mesh.textures = bake_tex
482
- images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights, device=self.device)
483
- predicted_rgb = images_predicted[..., :-1]
484
- loss += (((predicted_rgb[...] - views[i]))**2).sum()
485
- loss.backward(retain_graph=False)
486
- optimizer.step()
487
-
488
- total_weights = 0
489
- baked = 0
490
- for i in range(len(bake_maps)):
491
- normalized_baked_map = bake_maps[i].detach() / (self.gradient_maps[i] + 1E-8)
492
- bake_map = voronoi_solve(normalized_baked_map, self.gradient_maps[i][...,0])
493
- weight = self.visible_triangles[i] * (self.cos_maps[i]) ** exp
494
- if noisy:
495
- noise = torch.rand(weight.shape[:-1]+(1,), generator=generator).type(weight.dtype).to(weight.device)
496
- weight *= noise
497
- total_weights += weight
498
- baked += bake_map * weight
499
- baked /= total_weights + 1E-8
500
- baked = voronoi_solve(baked, total_weights[...,0])
501
-
502
- bake_tex = TexturesUV([baked], tmp_mesh.textures.faces_uvs_padded(), tmp_mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
503
- tmp_mesh.textures = bake_tex
504
- extended_mesh = tmp_mesh.extend(len(self.cameras))
505
- images_predicted = self.renderer(extended_mesh, cameras=self.cameras, lights=self.lights)
506
- learned_views = [image.permute(2, 0, 1) for image in images_predicted]
507
-
508
- return learned_views, baked.permute(2, 0, 1), total_weights.permute(2, 0, 1)
509
-
510
-
511
- # Move the internel data to a specific device
512
- def to(self, device):
513
- for mesh_name in ["mesh", "mesh_d", "mesh_uv"]:
514
- if hasattr(self, mesh_name):
515
- mesh = getattr(self, mesh_name)
516
- setattr(self, mesh_name, mesh.to(device))
517
- for list_name in ["visible_triangles", "visibility_maps", "cos_maps"]:
518
- if hasattr(self, list_name):
519
- map_list = getattr(self, list_name)
520
- for i in range(len(map_list)):
521
- map_list[i] = map_list[i].to(device)
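To make the removed workflow concrete, here is a rough usage sketch of the UVProjection class deleted above, following its own comments (nearest sampling and channels=4 when operating on latents, with render_size matching the latent resolution). The mesh `my_trimesh`, the camera poses, and the numeric values are illustrative placeholders, not taken from the repository:

import torch

device = torch.device("cuda:0")

uvp = UVProjection(texture_size=96, render_size=64,
                   sampling_mode="nearest", channels=4, device=device)
uvp.load_mesh(my_trimesh, scale_factor=1.0, autouv=True)          # my_trimesh: a trimesh.Trimesh
uvp.set_cameras_and_render_settings(
    [(0, azim) for azim in (0, 90, 180, 270)],                    # (elevation, azimuth) pairs
    camera_distance=2.7,
)

latent_tex = uvp.set_noise_texture()       # (4, 96, 96) noise texture baked into UV space
views = uvp.render_textured_views()        # list of (5, 64, 64) tensors: 4 channels + alpha
learned_views, baked_tex, weights = uvp.bake_texture(
    views=[v[:-1] for v in views]          # drop the alpha channel before baking
)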
 
step1x3d_texture/texture_sync/shader.py DELETED
@@ -1,118 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- import pytorch3d
5
-
6
-
7
- from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
8
- from pytorch3d.ops import interpolate_face_attributes
9
-
10
- from pytorch3d.structures import Meshes
11
- from pytorch3d.renderer import (
12
- look_at_view_transform,
13
- FoVPerspectiveCameras,
14
- AmbientLights,
15
- PointLights,
16
- DirectionalLights,
17
- Materials,
18
- RasterizationSettings,
19
- MeshRenderer,
20
- MeshRasterizer,
21
- SoftPhongShader,
22
- SoftSilhouetteShader,
23
- HardPhongShader,
24
- TexturesVertex,
25
- TexturesUV,
26
- Materials,
27
-
28
- )
29
- from pytorch3d.renderer.blending import BlendParams,hard_rgb_blend
30
- from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
31
-
32
- from pytorch3d.renderer.lighting import AmbientLights
33
- from pytorch3d.renderer.materials import Materials
34
- from pytorch3d.renderer.mesh.shader import ShaderBase
35
- from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
36
- from pytorch3d.renderer.mesh.rasterizer import Fragments
37
-
38
-
39
- '''
40
- Customized the original pytorch3d hard flat shader to support N channel flat shading
41
- '''
42
- class HardNChannelFlatShader(ShaderBase):
43
- """
44
- Per face lighting - the lighting model is applied using the average face
45
- position and the face normal. The blending function hard assigns
46
- the color of the closest face for each pixel.
47
-
48
- To use the default values, simply initialize the shader with the desired
49
- device e.g.
50
-
51
- .. code-block::
52
-
53
- shader = HardFlatShader(device=torch.device("cuda:0"))
54
- """
55
-
56
- def __init__(
57
- self,
58
- device = "cpu",
59
- cameras: Optional[TensorProperties] = None,
60
- lights: Optional[TensorProperties] = None,
61
- materials: Optional[Materials] = None,
62
- blend_params: Optional[BlendParams] = None,
63
- channels: int = 3,
64
- ):
65
- self.channels = channels
66
- ones = ((1.0,)*channels,)
67
- zeros = ((0.0,)*channels,)
68
-
69
- if not isinstance(lights, AmbientLights) or not lights.ambient_color.shape[-1] == channels:
70
- lights = AmbientLights(
71
- ambient_color=ones,
72
- device=device,
73
- )
74
-
75
- if not materials or not materials.ambient_color.shape[-1] == channels:
76
- materials = Materials(
77
- device=device,
78
- diffuse_color=zeros,
79
- ambient_color=ones,
80
- specular_color=zeros,
81
- shininess=0.0,
82
- )
83
-
84
- blend_params_new = BlendParams(background_color=(1.0,)*channels)
85
- if not isinstance(blend_params, BlendParams):
86
- blend_params = blend_params_new
87
- else:
88
- background_color_ = blend_params.background_color
89
- if isinstance(background_color_, Sequence[float]) and not len(background_color_) == channels:
90
- blend_params = blend_params_new
91
- if isinstance(background_color_, torch.Tensor) and not background_color_.shape[-1] == channels:
92
- blend_params = blend_params_new
93
-
94
- super().__init__(
95
- device,
96
- cameras,
97
- lights,
98
- materials,
99
- blend_params,
100
- )
101
-
102
-
103
- def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
104
- cameras = super()._get_cameras(**kwargs)
105
- texels = meshes.sample_textures(fragments)
106
- lights = kwargs.get("lights", self.lights)
107
- materials = kwargs.get("materials", self.materials)
108
- blend_params = kwargs.get("blend_params", self.blend_params)
109
- colors = flat_shading(
110
- meshes=meshes,
111
- fragments=fragments,
112
- texels=texels,
113
- lights=lights,
114
- cameras=cameras,
115
- materials=materials,
116
- )
117
- images = hard_rgb_blend(colors, fragments, blend_params)
118
- return images
 
step1x3d_texture/texture_sync/step_sync.py DELETED
@@ -1,125 +0,0 @@
1
- import torch
2
- from diffusers.utils.torch_utils import randn_tensor
3
-
4
- '''
5
-
6
- Customized Step Function
7
- step on texture
8
- '''
9
- @torch.no_grad()
10
- def step_tex_sync(
11
- scheduler,
12
- uvp,
13
- model_output: torch.FloatTensor,
14
- timestep: int,
15
- sample: torch.FloatTensor,
16
- texture: None,
17
- generator=None,
18
- return_dict: bool = True,
19
- guidance_scale = 1,
20
- main_views = [],
21
- hires_original_views = True,
22
- exp=None,
23
- cos_weighted=True
24
- ):
25
- t = timestep
26
-
27
- prev_t = scheduler.previous_timestep(t)
28
-
29
- if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
30
- model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
31
- else:
32
- predicted_variance = None
33
-
34
- # 1. compute alphas, betas
35
- alpha_prod_t = scheduler.alphas_cumprod[t]
36
- alpha_prod_t_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
37
- beta_prod_t = 1 - alpha_prod_t
38
- beta_prod_t_prev = 1 - alpha_prod_t_prev
39
- current_alpha_t = alpha_prod_t / alpha_prod_t_prev
40
- current_beta_t = 1 - current_alpha_t
41
-
42
- # 2. compute predicted original sample from predicted noise also called
43
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
44
- if scheduler.config.prediction_type == "epsilon":
45
- pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
46
- elif scheduler.config.prediction_type == "sample":
47
- pred_original_sample = model_output
48
- elif scheduler.config.prediction_type == "v_prediction":
49
- pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
50
- else:
51
- raise ValueError(
52
- f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
53
- " `v_prediction` for the DDPMScheduler."
54
- )
55
- # 3. Clip or threshold "predicted x_0"
56
- if scheduler.config.thresholding:
57
- pred_original_sample = scheduler._threshold_sample(pred_original_sample)
58
- elif scheduler.config.clip_sample:
59
- pred_original_sample = pred_original_sample.clamp(
60
- -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
61
- )
62
-
63
- # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
64
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
65
- pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
66
- current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
67
-
68
- '''
69
- Add multidiffusion here
70
- '''
71
-
72
- if texture is None:
73
- sample_views = [view for view in sample]
74
- sample_views, texture, _ = uvp.bake_texture(views=sample_views, main_views=main_views, exp=exp)
75
- sample_views = torch.stack(sample_views, axis=0)[:,:-1,...]
76
-
77
-
78
- original_views = [view for view in pred_original_sample]
79
- original_views, original_tex, visibility_weights = uvp.bake_texture(views=original_views, main_views=main_views, exp=exp)
80
- uvp.set_texture_map(original_tex)
81
- original_views = uvp.render_textured_views()
82
- original_views = torch.stack(original_views, axis=0)[:,:-1,...]
83
-
84
- # 5. Compute predicted previous sample µ_t
85
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
86
- # pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
87
- prev_tex = pred_original_sample_coeff * original_tex + current_sample_coeff * texture
88
-
89
- # 6. Add noise
90
- variance = 0
91
-
92
- if predicted_variance is not None:
93
- variance_views = [view for view in predicted_variance]
94
- variance_views, variance_tex, visibility_weights = uvp.bake_texture(views=variance_views, main_views=main_views, cos_weighted=cos_weighted, exp=exp)
95
- variance_views = torch.stack(variance_views, axis=0)[:,:-1,...]
96
- else:
97
- variance_tex = None
98
-
99
- if t > 0:
100
- device = texture.device
101
- variance_noise = randn_tensor(
102
- texture.shape, generator=generator, device=device, dtype=texture.dtype
103
- )
104
- if scheduler.variance_type == "fixed_small_log":
105
- variance = scheduler._get_variance(t, predicted_variance=variance_tex) * variance_noise
106
- elif scheduler.variance_type == "learned_range":
107
- variance = scheduler._get_variance(t, predicted_variance=variance_tex)
108
- variance = torch.exp(0.5 * variance) * variance_noise
109
- else:
110
- variance = (scheduler._get_variance(t, predicted_variance=variance_tex) ** 0.5) * variance_noise
111
- prev_tex = prev_tex + variance
112
-
113
- uvp.set_texture_map(prev_tex)
114
- prev_views = uvp.render_textured_views()
115
- pred_prev_sample = torch.clone(sample)
116
- for i, view in enumerate(prev_views):
117
- pred_prev_sample[i] = view[:-1]
118
- masks = [view[-1:] for view in prev_views]
119
-
120
- return {"prev_sample": pred_prev_sample, "pred_original_sample":pred_original_sample, "prev_tex": prev_tex}
121
-
122
- if not return_dict:
123
- return pred_prev_sample, pred_original_sample
124
- pass
125
-
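Finally, the deleted step_tex_sync above stood in for scheduler.step inside the multiview denoising loop: each DDPM step was computed on the baked UV texture and the per-view latents were re-rendered from it. A rough sketch of that call pattern, where `unet`, `latents`, `prompt_embeds`, `timesteps`, `scheduler`, and `uvp` are assumed to come from the surrounding pipeline (illustrative only, not code from the repository):

texture = None                                   # UV-space latent texture, carried across steps
for t in timesteps:
    noise_pred = unet(latents, t, encoder_hidden_states=prompt_embeds).sample
    out = step_tex_sync(scheduler, uvp, noise_pred, t, latents, texture=texture)
    latents = out["prev_sample"]                 # views re-rendered from the updated texture
    texture = out["prev_tex"]                    # becomes the input texture for the next step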