Revert "texture sync"
This reverts commit 55f226f582932e6ec64e096f296c54d47a59de80.
- step1x3d_texture/pipelines/ig2mv_sdxl_pipeline.py +3 -131
- step1x3d_texture/pipelines/step1x_3d_texture_synthesis_pipeline.py +8 -17
- step1x3d_texture/{texture_sync → renderer}/__init__.py +0 -0
- step1x3d_texture/renderer/geometry.py +151 -0
- step1x3d_texture/renderer/project.py +875 -0
- step1x3d_texture/renderer/shader.py +127 -0
- step1x3d_texture/{texture_sync → renderer}/voronoi.py +0 -0
- step1x3d_texture/texture_sync/geometry.py +0 -141
- step1x3d_texture/texture_sync/project.py +0 -521
- step1x3d_texture/texture_sync/shader.py +0 -118
- step1x3d_texture/texture_sync/step_sync.py +0 -125
step1x3d_texture/pipelines/ig2mv_sdxl_pipeline.py
CHANGED
@@ -51,20 +51,6 @@ from ..models.attention_processor import (
     DecoupledMVRowSelfAttnProcessor2_0,
     set_unet_2d_condition_attn_processor,
 )
-import random
-from ..texture_sync.project import UVProjection as UVP
-from ..texture_sync.step_sync import step_tex_sync
-from trimesh import Trimesh
-from torchvision.transforms import Compose, Resize, GaussianBlur, InterpolationMode
-from diffusers.utils import (
-    BaseOutput,
-    numpy_to_pil,
-    pt_to_pil,
-    is_accelerate_available,
-    is_accelerate_version,
-    logging,
-    replace_example_docstring
-)

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -84,27 +70,6 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")


-@torch.no_grad()
-def composite_rendered_view(scheduler, backgrounds, foregrounds, masks, t):
-    composited_images = []
-    for i, (background, foreground, mask) in enumerate(zip(backgrounds, foregrounds, masks)):
-        if t > 0:
-            alphas_cumprod = scheduler.alphas_cumprod[t]
-            noise = torch.normal(0, 1, background.shape, device=background.device)
-            background = (1-alphas_cumprod) * noise + alphas_cumprod * background
-        composited = foreground * mask + background * (1-mask)
-        composited_images.append(composited)
-    composited_tensor = torch.stack(composited_images)
-    return composited_tensor
-
-
-@torch.no_grad()
-def encode_latents(vae, imgs):
-    imgs = (imgs-0.5)*2
-    latents = vae.encode(imgs).latent_dist.sample()
-    latents = vae.config.scaling_factor * latents
-    return latents
-
 class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
     def __init__(
         self,
@@ -344,8 +309,6 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
         # Image condition
         reference_image: Optional[PipelineImageInput] = None,
         reference_conditioning_scale: Optional[float] = 1.0,
-        mesh: Optional[Trimesh] = None,
-        texture_sync_config: Optional[dict] = None,
         **kwargs,
     ):
         r"""
@@ -593,27 +556,6 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
             latents,
         )

-        # texture patams init
-        texture_size = texture_sync_config["texture_size"]
-        latent_size = texture_sync_config["latent_size"]
-        elevations = texture_sync_config["elevations"]
-        azimuths = texture_sync_config["azimuths"]
-        texture_sync_ratio = texture_sync_config["texture_sync_ratio"]
-        camera_poses = [(elv, azim) for elv, azim in zip(elevations, azimuths)]
-        uvp = UVP(texture_size=texture_size, render_size=latent_size, sampling_mode="nearest", channels=4, device=self._execution_device)
-        uvp.load_mesh(mesh, scale_factor=1.0, autouv=True)
-        uvp.set_cameras_and_render_settings(camera_poses, centers=None, camera_distance=texture_sync_config["camera_distance"], scale=((1.0, 1.0, 1.0),))
-
-        latent_tex = uvp.set_noise_texture()
-        noise_views = uvp.render_textured_views()
-        foregrounds = [view[:-1] for view in noise_views]
-        masks = [view[-1:] for view in noise_views]
-
-        if texture_sync_ratio>0:
-            composited_tensor = composite_rendered_view(self.scheduler, latents, foregrounds, masks, int(timesteps[0].cpu().item())+1)
-            latents = composited_tensor.type(latents.dtype)
-        uvp.to("cpu")
-
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

@@ -767,36 +709,6 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
         ).to(device=device, dtype=latents.dtype)

         self._num_timesteps = len(timesteps)
-
-
-        # texture sync params
-        exp_start = texture_sync_config["exp_start"]
-        exp_end = texture_sync_config["exp_end"]
-        shuffle_background_change = texture_sync_config["shuffle_background_change"]
-        shuffle_background_end = texture_sync_config["shuffle_background_end"]
-        num_timesteps = self.scheduler.config.num_train_timesteps
-
-        uvp.to(self._execution_device)
-        color_constants = {"black": [-1, -1, -1], "white": [1, 1, 1], "maroon": [0, -1, -1],
-                           "red": [1, -1, -1], "olive": [0, 0, -1], "yellow": [1, 1, -1],
-                           "green": [-1, 0, -1], "lime": [-1 ,1, -1], "teal": [-1, 0, 0],
-                           "aqua": [-1, 1, 1], "navy": [-1, -1, 0], "blue": [-1, -1, 1],
-                           "purple": [0, -1 , 0], "fuchsia": [1, -1, 1]}
-        color_names = list(color_constants.keys())
-        background_colors = [random.choice(list(color_constants.keys())) for i in range(len(camera_poses))]
-        intermediate_results = []
-        self.upcast_vae()
-        self.vae.config.force_upcast = True
-        color_images = torch.FloatTensor([color_constants[name] for name in color_names]).reshape(-1,3,1,1).to(dtype=torch.float32, device=self._execution_device)
-        color_images = torch.ones(
-            (1,1,latent_size*8, latent_size*8),
-            device=self._execution_device,
-            dtype=torch.float32
-        ) * color_images
-        color_images = ((0.5*color_images)+0.5)
-        color_latents = encode_latents(self.vae, color_images).to(dtype=self.text_encoder_2.dtype)
-        color_latents = {color[0]:color[1] for color in zip(color_names, [latent for latent in color_latents])}
-
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -856,49 +768,9 @@ class IG2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-
-
-
-                if t > (1-texture_sync_ratio)*num_timesteps:
-                    step_results = step_tex_sync(
-                        scheduler=self.scheduler,
-                        uvp=uvp,
-                        model_output=noise_pred,
-                        timestep=t,
-                        sample=latents,
-                        texture=latent_tex,
-                        return_dict=True,
-                        main_views=[],
-                        exp= current_exp,
-                        **extra_step_kwargs
-                    )
-
-                    pred_original_sample = step_results["pred_original_sample"]
-                    latents = step_results["prev_sample"]
-                    latent_tex = step_results["prev_tex"]
-
-                    # Composit latent foreground with random color background
-                    background_latents = [color_latents[color] for color in background_colors]
-                    composited_tensor = composite_rendered_view(self.scheduler, background_latents, latents, masks, t)
-                    latents = composited_tensor.type(latents.dtype)
-
-                    intermediate_results.append((latents.to("cpu"), pred_original_sample.to("cpu")))
-                else:
-                    step_results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
-                    pred_original_sample = step_results["pred_original_sample"]
-                    latents = step_results["prev_sample"]
-                    latent_tex = None
-                    intermediate_results.append((latents.to("cpu"), pred_original_sample.to("cpu")))
-
-                # 2. Shuffle background colors; only black and white used after certain timestep
-                if (1-t/num_timesteps) < shuffle_background_change:
-                    background_colors = [random.choice(list(color_constants.keys())) for i in range(len(camera_poses))]
-                elif (1-t/num_timesteps) < shuffle_background_end:
-                    background_colors = [random.choice(["black","white"]) for i in range(len(camera_poses))]
-                else:
-                    background_colors = background_colors
-                del noise_pred
-
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+                )[0]
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
step1x3d_texture/pipelines/step1x_3d_texture_synthesis_pipeline.py
CHANGED
@@ -24,6 +24,7 @@ import trimesh
 import xatlas
 import scipy.sparse
 from scipy.sparse.linalg import spsolve
+
 from step1x3d_geometry.models.pipelines.pipeline_utils import smart_load_model


@@ -35,7 +36,7 @@ class Step1X3DTextureConfig:
         self.unet_model = None
         self.lora_model = None
         self.adapter_path = "stepfun-ai/Step1X-3D"
-        self.scheduler =
+        self.scheduler = None
         self.num_views = 6
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.dtype = torch.float16
@@ -60,20 +61,6 @@ class Step1X3DTextureConfig:
         self.bake_exp = 4
         self.merge_method = "fast"

-        # texture sync params
-        self.texture_sync_config = {
-            "texture_size": 1536,
-            "latent_size": 768//8,
-            "elevations": [0, 0, 0, 0, 90, -90],
-            "azimuths": [0, 90, 180, 270, 0, 0],
-            "texture_sync_ratio": 0.5,
-            "exp_end": 6.0,
-            "exp_start": 0,
-            "shuffle_background_change": 0.4,
-            "shuffle_background_end": 0.99,
-            "camera_distance": 1.8
-        }
-

 class Step1X3DTexturePipeline:
     def __init__(self, config):
@@ -133,9 +120,11 @@ class Step1X3DTexturePipeline:
         if unet_model is not None:
             pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

+        print('VAE Loaded!')
         # Prepare pipeline
         pipe = IG2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)

+        print('Base model Loaded!')
         # Load scheduler if provided
         scheduler_class = None
         if scheduler == "ddpm":
@@ -149,11 +138,14 @@ class Step1X3DTexturePipeline:
             shift_scale=8.0,
             scheduler_class=scheduler_class,
         )
+        print('Scheduler Loaded!')
         pipe.init_custom_adapter(
             num_views=num_views,
             self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0,
         )
+        print(f'Load adapter from {adapter_path}/step1x-3d-ig2v.safetensors')
         pipe.load_custom_adapter(adapter_path, "step1x-3d-ig2v.safetensors")
+        print(f'Load adapter successed!')
         pipe.to(device=device, dtype=dtype)
         pipe.cond_encoder.to(device=device, dtype=dtype)

@@ -290,7 +282,6 @@ class Step1X3DTexturePipeline:
             negative_prompt=negative_prompt,
             cross_attention_kwargs={"scale": lora_scale},
             mesh=mesh_bp,
-            texture_sync_config=self.config.texture_sync_config,
             **pipe_kwargs,
         ).images

@@ -368,7 +359,7 @@ class Step1X3DTexturePipeline:
             width=768,
             num_inference_steps=self.config.num_inference_steps,
             guidance_scale=self.config.guidance_scale,
-            seed=
+            seed=seed if seed is not None else self.config.seed,
             lora_scale=self.config.lora_scale,
             reference_conditioning_scale=self.config.reference_conditioning_scale,
             negative_prompt=self.config.negative_prompt,
step1x3d_texture/{texture_sync → renderer}/__init__.py
RENAMED
File without changes
step1x3d_texture/renderer/geometry.py
ADDED
@@ -0,0 +1,151 @@
import torch
import pytorch3d
import torch.nn.functional as F

from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
    Materials,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.mesh.shader import ShaderBase


def get_cos_angle(points, normals, camera_position):
    """
    calculate cosine similarity between view->surface and surface normal.
    """

    if points.shape != normals.shape:
        msg = "Expected points and normals to have the same shape: got %r, %r"
        raise ValueError(msg % (points.shape, normals.shape))

    # Ensure all inputs have same batch dimension as points
    matched_tensors = convert_to_tensors_and_broadcast(
        points, camera_position, device=points.device
    )
    _, camera_position = matched_tensors

    # Reshape direction and color so they have all the arbitrary intermediate
    # dimensions as points. Assume first dim = batch dim and last dim = 3.
    points_dims = points.shape[1:-1]
    expand_dims = (-1,) + (1,) * len(points_dims)

    if camera_position.shape != normals.shape:
        camera_position = camera_position.view(expand_dims + (3,))

    normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)

    # Calculate the cosine value.
    view_direction = camera_position - points
    view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
    cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True)
    cos_angle = cos_angle.clamp(0, 1)

    # Cosine of the angle between the reflected light ray and the viewer
    return cos_angle


def _geometry_shading_with_pixels(
    meshes, fragments, lights, cameras, materials, texels
):
    """
    Render pixel space vertex position, normal(world), depth, and cos angle

    Args:
        meshes: Batch of meshes
        fragments: Fragments named tuple with the outputs of rasterization
        lights: Lights class containing a batch of lights
        cameras: Cameras class containing a batch of cameras
        materials: Materials class containing a batch of material properties
        texels: texture per pixel of shape (N, H, W, K, 3)

    Returns:
        colors: (N, H, W, K, 3)
        pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection.
    """
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_verts = verts[faces]
    faces_normals = vertex_normals[faces]
    pixel_coords_in_camera = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )

    cos_angles = get_cos_angle(
        pixel_coords_in_camera, pixel_normals, cameras.get_camera_center()
    )

    return pixel_coords_in_camera, pixel_normals, fragments.zbuf[..., None], cos_angles


class HardGeometryShader(ShaderBase):
    """
    renders common geometric informations.


    """

    def forward(self, fragments, meshes, **kwargs):
        cameras = super()._get_cameras(**kwargs)
        texels = self.texel_from_uv(fragments, meshes)

        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        texels = meshes.sample_textures(fragments)
        verts = hard_rgb_blend(verts, fragments, blend_params)
        normals = hard_rgb_blend(normals, fragments, blend_params)
        depths = hard_rgb_blend(depths, fragments, blend_params)
        cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
        from IPython import embed

        embed()
        texels = hard_rgb_blend(texels, fragments, blend_params)
        return verts, normals, depths, cos_angles, texels, fragments

    def texel_from_uv(self, fragments, meshes):
        texture_tmp = meshes.textures
        maps_tmp = texture_tmp.maps_padded()
        uv_color = [[[1, 0], [1, 1]], [[0, 0], [0, 1]]]
        uv_color = (
            torch.FloatTensor(uv_color).to(maps_tmp[0].device).type(maps_tmp[0].dtype)
        )
        uv_texture = TexturesUV(
            [uv_color.clone() for t in maps_tmp],
            texture_tmp.faces_uvs_padded(),
            texture_tmp.verts_uvs_padded(),
            sampling_mode="bilinear",
        )
        meshes.textures = uv_texture
        texels = meshes.sample_textures(fragments)
        meshes.textures = texture_tmp
        texels = torch.cat((texels, texels[..., -1:] * 0), dim=-1)
        return texels
step1x3d_texture/renderer/project.py
ADDED
@@ -0,0 +1,875 @@
import torch
import pytorch3d


from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj, IO

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    FoVOrthographicCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    TexturesUV,
)

from .geometry import HardGeometryShader
from .shader import HardNChannelFlatShader
from .voronoi import voronoi_solve
import torch.nn.functional as F
import open3d as o3d
import pdb
import kaolin as kal
import numpy as np


import torch
from pytorch3d.renderer.cameras import FoVOrthographicCameras
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pytorch3d.common.datatypes import Device
import math
import torch.nn.functional as F
from trimesh import Trimesh
from pytorch3d.structures import Meshes
import os

LIST_TYPE = Union[list, np.ndarray, torch.Tensor]

_R = torch.eye(3)[None]  # (1, 3, 3)
_T = torch.zeros(1, 3)  # (1, 3)
_BatchFloatType = Union[float, Sequence[float], torch.Tensor]


class CustomOrthographicCameras(FoVOrthographicCameras):
    def compute_projection_matrix(
        self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz
    ) -> torch.Tensor:
        """
        Custom orthographic projection matrix computation; inherits the parent
        class logic but modifies the depth-channel parameters.
        Parameter shapes:
        - znear/zfar: (N,)
        - max_x/min_x: (N,)
        - max_y/min_y: (N,)
        - scale_xyz: (N, 3)
        """
        K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)

        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
        # NOTE: OpenGL flips handedness of coordinate system between camera
        # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
        # right handed coordinate system throughout.
        z_sign = +1.0

        K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
        K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
        K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
        K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
        K[:, 3, 3] = ones

        # NOTE: This maps the z coordinate to the range [0, 1] and replaces the
        # the OpenGL z normalization to [-1, 1]
        K[:, 2, 2] = -2 * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
        K[:, 2, 3] = -(znear + zfar) / (zfar - znear)

        return K

    def __init__(
        self,
        znear: _BatchFloatType = 1.0,
        zfar: _BatchFloatType = 100.0,
        max_y: _BatchFloatType = 1.0,
        min_y: _BatchFloatType = -1.0,
        max_x: _BatchFloatType = 1.0,
        min_x: _BatchFloatType = -1.0,
        scale_xyz=((1.0, 1.0, 1.0),),  # (N, 3)
        R: torch.Tensor = _R,
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
    ):
        # Reuse the parent-class initialization logic
        super().__init__(
            znear=znear,
            zfar=zfar,
            max_y=max_y,
            min_y=min_y,
            max_x=max_x,
            min_x=min_x,
            scale_xyz=scale_xyz,
            R=R,
            T=T,
            K=K,
            device=device,
        )


def erode_torch_batch(binary_img_batch, kernel_size):
    pad = (kernel_size - 1) // 2
    bin_img = F.pad(
        binary_img_batch.unsqueeze(1), pad=[pad, pad, pad, pad], mode="reflect"
    )
    out = -F.max_pool2d(-bin_img, kernel_size=kernel_size, stride=1, padding=0)
    out = out.squeeze(1)
    return out


def dilate_torch_batch(binary_img_batch, kernel_size):
    pad = (kernel_size - 1) // 2
    bin_img = F.pad(binary_img_batch, pad=[pad, pad, pad, pad], mode="reflect")
    out = F.max_pool2d(bin_img, kernel_size=kernel_size, stride=1, padding=0)
    out = out.squeeze()
    return out


# Pytorch3D based renderering functions, managed in a class
# Render size is recommended to be the same as your latent view size
# DO NOT USE "bilinear" sampling when you are handling latents.
# Stable Diffusion has 4 latent channels so use channels=4


class UVProjection:
    def __init__(
        self,
        texture_size=96,
        render_size=64,
        sampling_mode="nearest",
        channels=3,
        device=None,
    ):
        self.channels = channels
        self.device = device or torch.device("cpu")
        self.lights = AmbientLights(
            ambient_color=((1.0,) * channels,), device=self.device
        )
        self.target_size = (texture_size, texture_size)
        self.render_size = render_size
        self.sampling_mode = sampling_mode

    # Load obj mesh, rescale the mesh to fit into the bounding box
    def load_mesh(self, mesh, scale_factor=2.0, auto_center=True, autouv=False):
        if isinstance(mesh, Trimesh):
            vertices = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
            faces = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
            mesh = Meshes(verts=[vertices], faces=[faces])
            verts = mesh.verts_packed()
            mesh = mesh.update_padded(verts[None, :, :])
        elif isinstance(mesh, str) and os.path.isfile(mesh):
            mesh = load_objs_as_meshes([mesh_path], device=self.device)
        if auto_center:
            verts = mesh.verts_packed()
            max_bb = (verts - 0).max(0)[0]
            min_bb = (verts - 0).min(0)[0]
            scale = (max_bb - min_bb).max() / 2
            center = (max_bb + min_bb) / 2
            mesh.offset_verts_(-center)
            mesh.scale_verts_((scale_factor / float(scale)))
        else:
            mesh.scale_verts_((scale_factor))

        if autouv or (mesh.textures is None):
            mesh = self.uv_unwrap(mesh)
        self.mesh = mesh

    def load_glb_mesh(
        self, mesh_path, trimesh, scale_factor=1.0, auto_center=True, autouv=False
    ):
        from pytorch3d.io.experimental_gltf_io import MeshGlbFormat

        io = IO()
        io.register_meshes_format(MeshGlbFormat())
        with open(mesh_path, "rb") as f:
            mesh = io.load_mesh(f, include_textures=True, device=self.device)
        if auto_center:
            verts = mesh.verts_packed()

            max_bb = (verts - 0).max(0)[0]
            min_bb = (verts - 0).min(0)[0]
            scale = (max_bb - min_bb).max() / 2
            center = (max_bb + min_bb) / 2
            mesh.offset_verts_(-center)
            mesh.scale_verts_((scale_factor / float(scale)))
            verts = mesh.verts_packed()
            # T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=verts.device, dtype=verts.dtype)
            # T = torch.tensor([[0, 0, 1], [0, 1, 0], [-1, 0, 0]], device=verts.device, dtype=verts.dtype)
            # verts = verts @ T
            mesh = mesh.update_padded(verts[None, :, :])
        else:
            mesh.scale_verts_((scale_factor))
        if autouv or (mesh.textures is None):
            mesh = self.uv_unwrap(mesh)
        self.mesh = mesh

    # Save obj mesh
    def save_mesh(self, mesh_path, texture):
        save_obj(
            mesh_path,
            self.mesh.verts_list()[0],
            self.mesh.faces_list()[0],
            verts_uvs=self.mesh.textures.verts_uvs_list()[0],
            faces_uvs=self.mesh.textures.faces_uvs_list()[0],
            texture_map=texture,
        )

    # Code referred to TEXTure code (https://github.com/TEXTurePaper/TEXTurePaper.git)
    def uv_unwrap(self, mesh):
        verts_list = mesh.verts_list()[0]
        faces_list = mesh.faces_list()[0]

        import xatlas
        import numpy as np

        v_np = verts_list.cpu().numpy()
        f_np = faces_list.int().cpu().numpy()
        atlas = xatlas.Atlas()
        atlas.add_mesh(v_np, f_np)
        chart_options = xatlas.ChartOptions()
        chart_options.max_iterations = 4
        atlas.generate(chart_options=chart_options)
        vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

        vt = (
            torch.from_numpy(vt_np.astype(np.float32))
            .type(verts_list.dtype)
            .to(mesh.device)
        )
        ft = (
            torch.from_numpy(ft_np.astype(np.int64))
            .type(faces_list.dtype)
            .to(mesh.device)
        )

        new_map = torch.zeros(self.target_size + (self.channels,), device=mesh.device)
        new_tex = TexturesUV([new_map], [ft], [vt], sampling_mode=self.sampling_mode)

        mesh.textures = new_tex
        return mesh

    """
    A functions that disconnect faces in the mesh according to
    its UV seams. The number of vertices are made equal to the
    number of unique vertices its UV layout, while the faces list
    is intact.
    """

    def disconnect_faces(self):
        mesh = self.mesh
        verts_list = mesh.verts_list()
        faces_list = mesh.faces_list()
        verts_uvs_list = mesh.textures.verts_uvs_list()
        faces_uvs_list = mesh.textures.faces_uvs_list()
        packed_list = [v[f] for v, f in zip(verts_list, faces_list)]
        verts_disconnect_list = [
            torch.zeros(
                (verts_uvs_list[i].shape[0], 3),
                dtype=verts_list[0].dtype,
                device=verts_list[0].device,
            )
            for i in range(len(verts_list))
        ]
        for i in range(len(verts_list)):
            verts_disconnect_list[i][faces_uvs_list] = packed_list[i]
        assert not mesh.has_verts_normals(), "Not implemented for vertex normals"
        self.mesh_d = Meshes(verts_disconnect_list, faces_uvs_list, mesh.textures)
        return self.mesh_d

    """
    A function that construct a temp mesh for back-projection.
    Take a disconnected mesh and a rasterizer, the function calculates
    the projected faces as the UV, as use its original UV with pseudo
    z value as world space geometry.
    """

    def construct_uv_mesh(self):
        mesh = self.mesh_d
        verts_list = mesh.verts_list()
        verts_uvs_list = mesh.textures.verts_uvs_list()
        # faces_list = [torch.flip(faces, [-1]) for faces in mesh.faces_list()]
        new_verts_list = []
        for i, (verts, verts_uv) in enumerate(zip(verts_list, verts_uvs_list)):
            verts = verts.clone()
            verts_uv = verts_uv.clone()
            verts[..., 0:2] = verts_uv[..., :]
            verts = (verts - 0.5) * 2
            verts[..., 2] *= 1
            new_verts_list.append(verts)
        textures_uv = mesh.textures.clone()
        self.mesh_uv = Meshes(new_verts_list, mesh.faces_list(), textures_uv)
        return self.mesh_uv

    # Set texture for the current mesh.
    def set_texture_map(self, texture):
        new_map = texture.permute(1, 2, 0)
        new_map = new_map.to(self.device)
        new_tex = TexturesUV(
            [new_map],
            self.mesh.textures.faces_uvs_padded(),
            self.mesh.textures.verts_uvs_padded(),
            sampling_mode=self.sampling_mode,
        )
        self.mesh.textures = new_tex

    # Set the initial normal noise texture
    # No generator here for replication of the experiment result. Add one as you wish
    def set_noise_texture(self, channels=None):
        if not channels:
            channels = self.channels
        noise_texture = torch.normal(
            0, 1, (channels,) + self.target_size, device=self.device
        )
        self.set_texture_map(noise_texture)
        return noise_texture

    # Set the cameras given the camera poses and centers
    def set_cameras(self, camera_poses, centers=None, camera_distance=2.7, scale=None):
        elev = torch.FloatTensor([pose[0] for pose in camera_poses])
        azim = torch.FloatTensor([pose[1] for pose in camera_poses])
        print("camera_distance:{}".format(camera_distance))
        R, T = look_at_view_transform(
            dist=camera_distance, elev=elev, azim=azim, at=centers or ((0, 0, 0),)
        )
        # flip_mat = torch.from_numpy(np.diag([-1.0, 1.0, -1.0]) ).type(torch.FloatTensor).to(R.device)
        # R = R@flip_mat
        # R = R.permute(0, 2, 1)
        # T = T*torch.from_numpy(np.array([-1.0, 1.0, -1.0])).type(torch.FloatTensor).to(R.device)
        # print("v R size:{}, v T size:{}".format(R.size(), T.size()))
        # c2w = self.get_c2w(elev, [camera_distance]*len(elev), azim)
        # w2c = torch.linalg.inv(c2w)
        # R, T= w2c[:, :3, :3], w2c[:, :3, 3]
        print("R size:{}, T size:{}".format(R.size(), T.size()))
        # self.cameras = CustomOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),), znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)
        self.cameras = FoVOrthographicCameras(
            device=self.device, R=R, T=T, scale_xyz=scale or ((1, 1, 1),)
        )

    # Set all necessary internal data for rendering and texture baking
    # Can be used to refresh after changing camera positions
    def set_cameras_and_render_settings(
        self,
        camera_poses,
        centers=None,
        camera_distance=2.7,
        render_size=None,
        scale=None,
    ):
        self.set_cameras(camera_poses, centers, camera_distance, scale=scale)
        if render_size is None:
            render_size = self.render_size
        if not hasattr(self, "renderer"):
            self.setup_renderer(size=render_size)
        if not hasattr(self, "mesh_d"):
            self.disconnect_faces()
        if not hasattr(self, "mesh_uv"):
            self.construct_uv_mesh()
        self.calculate_tex_gradient()
        self.calculate_visible_triangle_mask()
        _, _, _, cos_maps, _, _ = self.render_geometry()
        self.calculate_cos_angle_weights(cos_maps)

    # Setup renderers for rendering
    # max faces per bin set to 30000 to avoid overflow in many test cases.
    # You can use default value to let pytorch3d handle that for you.
    def setup_renderer(
        self,
        size=64,
        blur=0.0,
        face_per_pix=1,
        perspective_correct=False,
        channels=None,
    ):
        if not channels:
            channels = self.channels

        self.raster_settings = RasterizationSettings(
            image_size=size,
            blur_radius=blur,
            faces_per_pixel=face_per_pix,
            perspective_correct=perspective_correct,
            cull_backfaces=True,
            max_faces_per_bin=30000,
        )

        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras,
                raster_settings=self.raster_settings,
            ),
            shader=HardNChannelFlatShader(
                device=self.device,
                cameras=self.cameras,
                lights=self.lights,
                channels=channels,
                # materials=materials
            ),
        )

    # Bake screen-space cosine weights to UV space
    # May be able to reimplement using the generic "bake_texture" function, but it works so leave it here for now
    @torch.enable_grad()
    def calculate_cos_angle_weights(self, cos_angles, fill=True, channels=None):
        if not channels:
            channels = self.channels
        cos_maps = []
        tmp_mesh = self.mesh.clone()
        for i in range(len(self.cameras)):

            zero_map = torch.zeros(
                self.target_size + (channels,), device=self.device, requires_grad=True
            )
            optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
            optimizer.zero_grad()
            zero_tex = TexturesUV(
                [zero_map],
                self.mesh.textures.faces_uvs_padded(),
                self.mesh.textures.verts_uvs_padded(),
                sampling_mode=self.sampling_mode,
            )
            tmp_mesh.textures = zero_tex

            images_predicted = self.renderer(
                tmp_mesh, cameras=self.cameras[i], lights=self.lights
            )

            loss = torch.sum((cos_angles[i, :, :, 0:1] ** 1 - images_predicted) ** 2)
            loss.backward()
            optimizer.step()

            if fill:
                zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8)
                zero_map = voronoi_solve(
                    zero_map, self.gradient_maps[i][..., 0], self.device
                )
            else:
                zero_map = zero_map.detach() / (self.gradient_maps[i] + 1e-8)
            cos_maps.append(zero_map)
        self.cos_maps = cos_maps

    # Get geometric info from fragment shader
    # Can be used for generating conditioning image and cosine weights
    # Returns some information you may not need, remember to release them for memory saving
    @torch.no_grad()
    def render_geometry(self, image_size=None):
        if image_size:
            size = self.renderer.rasterizer.raster_settings.image_size
            self.renderer.rasterizer.raster_settings.image_size = image_size
        shader = self.renderer.shader
        self.renderer.shader = HardGeometryShader(
            device=self.device, cameras=self.cameras[0], lights=self.lights
        )
        tmp_mesh = self.mesh.clone()

        verts, normals, depths, cos_angles, texels, fragments = self.renderer(
            tmp_mesh.extend(len(self.cameras)), cameras=self.cameras, lights=self.lights
        )
        self.renderer.shader = shader

        if image_size:
            self.renderer.rasterizer.raster_settings.image_size = size

        return verts, normals, depths, cos_angles, texels, fragments

    # Project world normal to view space and normalize
    @torch.no_grad()
    def decode_view_normal(self, normals):
        w2v_mat = self.cameras.get_full_projection_transform()
        normals_view = torch.clone(normals)[:, :, :, 0:3]
        normals_view = normals_view.reshape(normals_view.shape[0], -1, 3)
        normals_view = w2v_mat.transform_normals(normals_view)
        normals_view = normals_view.reshape(normals.shape[0:3] + (3,))
        normals_view[:, :, :, 2] *= -1
        normals = (normals_view[..., 0:3] + 1) * normals[
            ..., 3:
        ] / 2 + torch.FloatTensor(((((0.5, 0.5, 1))))).to(self.device) * (
            1 - normals[..., 3:]
        )
        # normals = torch.cat([normal for normal in normals], dim=1)
        normals = normals.clamp(0, 1)
        return normals

    # Normalize absolute depth to inverse depth
    @torch.no_grad()
    def decode_normalized_depth(self, depths, batched_norm=False):
        view_z, mask = depths.unbind(-1)
        view_z = view_z * mask + 100 * (1 - mask)
        inv_z = 1 / view_z
        inv_z_min = inv_z * mask + 100 * (1 - mask)
        if not batched_norm:
            max_ = torch.max(inv_z, 1, keepdim=True)
            max_ = torch.max(max_[0], 2, keepdim=True)[0]

            min_ = torch.min(inv_z_min, 1, keepdim=True)
            min_ = torch.min(min_[0], 2, keepdim=True)[0]
        else:
            max_ = torch.max(inv_z)
            min_ = torch.min(inv_z_min)
        inv_z = (inv_z - min_) / (max_ - min_)
        inv_z = inv_z.clamp(0, 1)
        inv_z = inv_z[..., None].repeat(1, 1, 1, 3)

        return inv_z

    # Multiple screen pixels could pass gradient to a same texel
    # We can precalculate this gradient strength and use it to normalize gradients when we bake textures
    @torch.enable_grad()
    def calculate_tex_gradient(self, channels=None):
        if not channels:
            channels = self.channels
        tmp_mesh = self.mesh.clone()
        gradient_maps = []
        for i in range(len(self.cameras)):
            zero_map = torch.zeros(
                self.target_size + (channels,), device=self.device, requires_grad=True
            )
            optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
            optimizer.zero_grad()
            zero_tex = TexturesUV(
                [zero_map],
                self.mesh.textures.faces_uvs_padded(),
                self.mesh.textures.verts_uvs_padded(),
                sampling_mode=self.sampling_mode,
            )
            tmp_mesh.textures = zero_tex
            images_predicted = self.renderer(
                tmp_mesh, cameras=self.cameras[i], lights=self.lights
            )
            loss = torch.sum((1 - images_predicted) ** 2)
            loss.backward()
            optimizer.step()

            gradient_maps.append(zero_map.detach())

        self.gradient_maps = gradient_maps

    # Get the UV space masks of triangles visible in each view
    # First get face ids from each view, then filter pixels on UV space to generate masks

    @torch.no_grad()
    def get_c2w(
        self,
        elevation_deg: LIST_TYPE,
        distance: LIST_TYPE,
        azimuth_deg: Optional[LIST_TYPE],
        num_views: Optional[int] = 1,
        device: Optional[str] = None,
    ) -> torch.FloatTensor:
        if azimuth_deg is None:
            assert (
                num_views is not None
            ), "num_views must be provided if azimuth_deg is None."
            azimuth_deg = torch.linspace(
                0, 360, num_views + 1, dtype=torch.float32, device=device
            )[:-1]
        else:
            num_views = len(azimuth_deg)

        def list_to_pt(
            x: LIST_TYPE,
            dtype: Optional[torch.dtype] = None,
            device: Optional[str] = None,
        ) -> torch.Tensor:
            if isinstance(x, list) or isinstance(x, np.ndarray):
                return torch.tensor(x, dtype=dtype, device=device)
            return x.to(dtype=dtype)

        azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
        elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
        camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)
        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180
        camera_positions = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )
        center = torch.zeros_like(camera_positions)
        up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[
            None, :
        ].repeat(num_views, 1)
        lookat = F.normalize(center - camera_positions, dim=-1)
        right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4 = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
        c2w[:, 3, 3] = 1.0
        return c2w

    @torch.no_grad()
    def calculate_visible_triangle_mask(self, channels=None, image_size=(512, 512)):
        if not channels:
            channels = self.channels

        pix2face_list = []
        for i in range(len(self.cameras)):
            self.renderer.rasterizer.raster_settings.image_size = image_size
            pix2face = self.renderer.rasterizer(
                self.mesh_d, cameras=self.cameras[i]
            ).pix_to_face
            self.renderer.rasterizer.raster_settings.image_size = self.render_size
            pix2face_list.append(pix2face)

        if not hasattr(self, "mesh_uv"):
            self.construct_uv_mesh()

        raster_settings = RasterizationSettings(
            image_size=self.target_size,
            blur_radius=0,
            faces_per_pixel=1,
            perspective_correct=False,
            cull_backfaces=False,
            max_faces_per_bin=30000,
        )

        R, T = look_at_view_transform(dist=2, elev=0, azim=0)
        # flip_mat = torch.from_numpy(np.diag([-1.0, 1.0, -1.0]) ).type(torch.FloatTensor).to(R.device)
        # R = R@flip_mat
        # T = T*torch.tensor(np.array([-1.0, 1.0, -1.0])).type(torch.FloatTensor).to(R.device)
        # c2w = self.get_c2w([0], [1.8], [0])
        # w2c = torch.linalg.inv(c2w)[:, :3,:]
        # R, T= w2c[:, :3,:3], w2c[:, :3, 3]
        # print("R size:{}, T size:{}".format(R.size(), T.size()))
        cameras = FoVOrthographicCameras(device=self.device, R=R, T=T)
        # cameras = CustomOrthographicCameras(device=self.device, R=R, T=T)

        # cameras = CustomOrthographicCameras(device=self.device, R=R, T=T, znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)

        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
        uv_pix2face = rasterizer(self.mesh_uv).pix_to_face

        visible_triangles = []
        for i in range(len(pix2face_list)):
            valid_faceid = torch.unique(pix2face_list[i])
            valid_faceid = valid_faceid[1:] if valid_faceid[0] == -1 else valid_faceid
            mask = torch.isin(uv_pix2face[0], valid_faceid, assume_unique=False)
            # uv_pix2face[0][~mask] = -1
            triangle_mask = torch.ones(self.target_size + (1,), device=self.device)
            triangle_mask[~mask] = 0

            triangle_mask[:, 1:][triangle_mask[:, :-1] > 0] = 1
            triangle_mask[:, :-1][triangle_mask[:, 1:] > 0] = 1
            triangle_mask[1:, :][triangle_mask[:-1, :] > 0] = 1
            triangle_mask[:-1, :][triangle_mask[1:, :] > 0] = 1
            visible_triangles.append(triangle_mask)

        self.visible_triangles = visible_triangles

    # Render the current mesh and texture from current cameras
    def render_textured_views(self):
        meshes = self.mesh.extend(len(self.cameras))
        images_predicted = self.renderer(
            meshes, cameras=self.cameras, lights=self.lights
        )

        return [image.permute(2, 0, 1) for image in images_predicted]

    @torch.no_grad()
    def get_point_validation_by_o3d(
        self, points, eye_position, hidden_point_removal_radius=200
    ):
        point_visibility = torch.zeros((points.shape[0]), device=points.device).bool()

        pcd = o3d.geometry.PointCloud(
            points=o3d.utility.Vector3dVector(points.cpu().numpy())
        )
        camera_pose = (
            eye_position.get_camera_center().squeeze().cpu().numpy().astype(np.float64)
        )
        # o3d_camera = [0, 0, diameter]
        diameter = np.linalg.norm(
            np.asarray(pcd.get_max_bound()) - np.asarray(pcd.get_min_bound())
        )
        radius = diameter * 200  # The radius of the sperical projection
        _, pt_map = pcd.hidden_point_removal(camera_pose, radius)

        visible_point_ids = np.array(pt_map)

        point_visibility[visible_point_ids] = True
        return point_visibility

    @torch.no_grad()
    def hidden_judge(self, camera, texture_dim):
        mesh = self.mesh

        verts = mesh.verts_packed()
        faces = mesh.faces_packed()
        verts_uv = mesh.textures.verts_uvs_padded()[0]  # get the packed UV coordinates (V, 2)
        faces_uv = mesh.textures.faces_uvs_padded()[0]
        uv_face_attr = torch.index_select(
            verts_uv, 0, faces_uv.view(-1)
        )  # select the UV coordinates of the corresponding vertices
        uv_face_attr = uv_face_attr.view(
            faces.shape[0], faces_uv.shape[1], 2
        ).unsqueeze(0)
        x, y, z = verts[:, 0], verts[:, 1], verts[:, 2]
        mesh_out_of_range = False
        if (
            x.min() < -1
            or x.max() > 1
            or y.min() < -1
            or y.max() > 1
            or z.min() < -1
            or z.max() > 1
        ):
            mesh_out_of_range = True
        face_vertices_world = kal.ops.mesh.index_vertices_by_faces(
            verts.unsqueeze(0), faces
        )
        face_vertices_z = torch.zeros_like(
            face_vertices_world[:, :, :, -1], device=verts.device
        )
        uv_position, face_idx = kal.render.mesh.rasterize(
            texture_dim,
            texture_dim,
            face_vertices_z,
            uv_face_attr * 2 - 1,
            face_features=face_vertices_world,
        )
        uv_position = torch.clamp(uv_position, -1, 1)
        uv_position[face_idx == -1] = 0

        points = uv_position.reshape(-1, 3)
        mask = points[:, 0] != 0
        valid_points = points[mask]
        # np.save("tmp/pcd.npy", valid_points.cpu().numpy())
        # print(camera.get_camera_center())

        points_visibility = self.get_point_validation_by_o3d(
            valid_points, camera
        ).float()
        visibility_map = torch.zeros((texture_dim * texture_dim,)).to(self.device)
        visibility_map[mask] = points_visibility
        visibility_map = visibility_map.reshape((texture_dim, texture_dim))
        return visibility_map

    @torch.enable_grad()
    def bake_texture(
        self,
        views=None,
        main_views=[],
        cos_weighted=True,
        channels=None,
        exp=None,
        noisy=False,
        generator=None,
        smooth_colorize=False,
    ):
        if not exp:
            exp = 1
        if not channels:
            channels = self.channels
        views = [view.permute(1, 2, 0) for view in views]

        tmp_mesh = self.mesh
        bake_maps = [
            torch.zeros(
                self.target_size + (views[0].shape[2],),
                device=self.device,
                requires_grad=True,
            )
            for view in views
        ]
        optimizer = torch.optim.SGD(bake_maps, lr=1, momentum=0)
        optimizer.zero_grad()
        loss = 0
        for i in range(len(self.cameras)):
            bake_tex = TexturesUV(
                [bake_maps[i]],
                tmp_mesh.textures.faces_uvs_padded(),
                tmp_mesh.textures.verts_uvs_padded(),
                sampling_mode=self.sampling_mode,
            )
            tmp_mesh.textures = bake_tex
            images_predicted = self.renderer(
                tmp_mesh,
                cameras=self.cameras[i],
                lights=self.lights,
                device=self.device,
            )
            predicted_rgb = images_predicted[..., :-1]
            loss += (((predicted_rgb[...] - views[i])) ** 2).sum()
        loss.backward(retain_graph=False)
        optimizer.step()

        total_weights = 0
        baked = 0
        for i in range(len(bake_maps)):
            normalized_baked_map = bake_maps[i].detach() / (
                self.gradient_maps[i] + 1e-8
            )
            bake_map = voronoi_solve(
                normalized_baked_map, self.gradient_maps[i][..., 0], self.device
            )
            # bake_map = voronoi_solve(normalized_baked_map, self.visible_triangles[i].squeeze())

            weight = self.visible_triangles[i] * (self.cos_maps[i]) ** exp
            if smooth_colorize:
                visibility_map = self.hidden_judge(
                    self.cameras[i], self.target_size[0]
                ).unsqueeze(-1)
                weight *= visibility_map
            if noisy:
                noise = (
                    torch.rand(weight.shape[:-1] + (1,), generator=generator)
                    .type(weight.dtype)
                    .to(weight.device)
                )
                weight *= noise
            total_weights += weight

            baked += bake_map * weight
        baked /= total_weights + 1e-8

        whole_visible_mask = None
        if not smooth_colorize:
            baked = voronoi_solve(baked, total_weights[..., 0], self.device)
            tmp_mesh.textures = TexturesUV(
                [baked],
                tmp_mesh.textures.faces_uvs_padded(),
                tmp_mesh.textures.verts_uvs_padded(),
                sampling_mode=self.sampling_mode,
            )
        else:  # smooth colorize
            baked = voronoi_solve(baked, total_weights[..., 0], self.device)
            whole_visible_mask = self.visible_triangles[0].to(torch.int32)
            for tensor in self.visible_triangles[1:]:
                whole_visible_mask = torch.bitwise_or(
                    whole_visible_mask, tensor.to(torch.int32)
                )

            baked *= whole_visible_mask
            tmp_mesh.textures = TexturesUV(
                [baked],
                tmp_mesh.textures.faces_uvs_padded(),
                tmp_mesh.textures.verts_uvs_padded(),
                sampling_mode=self.sampling_mode,
            )

        extended_mesh = tmp_mesh.extend(len(self.cameras))
        images_predicted = self.renderer(
            extended_mesh, cameras=self.cameras, lights=self.lights
        )
        learned_views = [image.permute(2, 0, 1) for image in images_predicted]

        return learned_views, baked.permute(2, 0, 1), total_weights.permute(2, 0, 1)
|
864 |
+
|
865 |
+
# Move the internel data to a specific device
|
866 |
+
def to(self, device):
|
867 |
+
for mesh_name in ["mesh", "mesh_d", "mesh_uv"]:
|
868 |
+
if hasattr(self, mesh_name):
|
869 |
+
mesh = getattr(self, mesh_name)
|
870 |
+
setattr(self, mesh_name, mesh.to(device))
|
871 |
+
for list_name in ["visible_triangles", "visibility_maps", "cos_maps"]:
|
872 |
+
if hasattr(self, list_name):
|
873 |
+
map_list = getattr(self, list_name)
|
874 |
+
for i in range(len(map_list)):
|
875 |
+
map_list[i] = map_list[i].to(device)
|
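At its core, bake_texture above computes a visibility- and cosine-weighted average of the per-view UV bakes. A minimal standalone sketch of that blend (function name and shapes are illustrative only, not part of this commit):

import torch

def blend_uv_bakes(bake_maps, visibility, cos_maps, exp=6, eps=1e-8):
    # bake_maps: list of (H, W, C) per-view UV bakes
    # visibility, cos_maps: lists of (H, W, 1) masks / cosine maps, one per view
    total_weight = torch.zeros_like(visibility[0])
    baked = torch.zeros_like(bake_maps[0])
    for bake, vis, cos in zip(bake_maps, visibility, cos_maps):
        w = vis * cos ** exp  # favour views that see the texel head-on
        baked = baked + bake * w
        total_weight = total_weight + w
    return baked / (total_weight + eps)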
step1x3d_texture/renderer/shader.py
ADDED
@@ -0,0 +1,127 @@
from collections.abc import Sequence
from typing import Optional

import torch
import pytorch3d


from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
    Materials,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties

from pytorch3d.renderer.lighting import AmbientLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
from pytorch3d.renderer.mesh.rasterizer import Fragments


"""
Customized the original pytorch3d hard flat shader to support N channel flat shading
"""


class HardNChannelFlatShader(ShaderBase):
    """
    Per face lighting - the lighting model is applied using the average face
    position and the face normal. The blending function hard assigns
    the color of the closest face for each pixel.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = HardFlatShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        self.channels = channels
        ones = ((1.0,) * channels,)
        zeros = ((0.0,) * channels,)

        if (
            not isinstance(lights, AmbientLights)
            or not lights.ambient_color.shape[-1] == channels
        ):
            lights = AmbientLights(
                ambient_color=ones,
                device=device,
            )

        if not materials or not materials.ambient_color.shape[-1] == channels:
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        # Fall back to channel-matched blend params when the provided
        # background color does not have `channels` components.
        blend_params_new = BlendParams(background_color=(1.0,) * channels)
        if not isinstance(blend_params, BlendParams):
            blend_params = blend_params_new
        else:
            background_color_ = blend_params.background_color
            if (
                isinstance(background_color_, Sequence)
                and not len(background_color_) == channels
            ):
                blend_params = blend_params_new
            if (
                isinstance(background_color_, torch.Tensor)
                and not background_color_.shape[-1] == channels
            ):
                blend_params = blend_params_new

        super().__init__(
            device,
            cameras,
            lights,
            materials,
            blend_params,
        )

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        cameras = super()._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        images = hard_rgb_blend(colors, fragments, blend_params)
        return images
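This shader is meant to be wired into a PyTorch3D MeshRenderer the way UVProjection.setup_renderer does it; a hedged sketch of equivalent standalone wiring (variable names are illustrative, not part of this commit):

import torch
from pytorch3d.renderer import (
    AmbientLights, MeshRasterizer, MeshRenderer, RasterizationSettings,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
channels = 4  # e.g. Stable Diffusion latents
lights = AmbientLights(ambient_color=((1.0,) * channels,), device=device)
raster_settings = RasterizationSettings(image_size=64, blur_radius=0.0, faces_per_pixel=1)
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(raster_settings=raster_settings),
    shader=HardNChannelFlatShader(device=device, lights=lights, channels=channels),
)
# images = renderer(mesh.extend(len(cameras)), cameras=cameras, lights=lights)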
step1x3d_texture/{texture_sync → renderer}/voronoi.py
RENAMED
File without changes
step1x3d_texture/texture_sync/geometry.py
DELETED
@@ -1,141 +0,0 @@
import torch
import pytorch3d
import torch.nn.functional as F

from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
    Materials,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.mesh.shader import ShaderBase


def get_cos_angle(
    points, normals, camera_position
):
    '''
    calculate cosine similarity between view->surface and surface normal.
    '''

    if points.shape != normals.shape:
        msg = "Expected points and normals to have the same shape: got %r, %r"
        raise ValueError(msg % (points.shape, normals.shape))

    # Ensure all inputs have same batch dimension as points
    matched_tensors = convert_to_tensors_and_broadcast(
        points, camera_position, device=points.device
    )
    _, camera_position = matched_tensors

    # Reshape direction and color so they have all the arbitrary intermediate
    # dimensions as points. Assume first dim = batch dim and last dim = 3.
    points_dims = points.shape[1:-1]
    expand_dims = (-1,) + (1,) * len(points_dims)

    if camera_position.shape != normals.shape:
        camera_position = camera_position.view(expand_dims + (3,))

    normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)

    # Calculate the cosine value.
    view_direction = camera_position - points
    view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
    cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True)
    cos_angle = cos_angle.clamp(0, 1)

    # Cosine of the angle between the reflected light ray and the viewer
    return cos_angle


def _geometry_shading_with_pixels(
    meshes, fragments, lights, cameras, materials, texels
):
    """
    Render pixel space vertex position, normal(world), depth, and cos angle

    Args:
        meshes: Batch of meshes
        fragments: Fragments named tuple with the outputs of rasterization
        lights: Lights class containing a batch of lights
        cameras: Cameras class containing a batch of cameras
        materials: Materials class containing a batch of material properties
        texels: texture per pixel of shape (N, H, W, K, 3)

    Returns:
        colors: (N, H, W, K, 3)
        pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection.
    """
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_verts = verts[faces]
    faces_normals = vertex_normals[faces]
    pixel_coords_in_camera = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )

    cos_angles = get_cos_angle(pixel_coords_in_camera, pixel_normals, cameras.get_camera_center())

    return pixel_coords_in_camera, pixel_normals, fragments.zbuf[...,None], cos_angles


class HardGeometryShader(ShaderBase):
    """
    renders common geometric informations.

    """

    def forward(self, fragments, meshes, **kwargs):
        cameras = super()._get_cameras(**kwargs)
        texels = self.texel_from_uv(fragments, meshes)

        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        verts = hard_rgb_blend(verts, fragments, blend_params)
        normals = hard_rgb_blend(normals, fragments, blend_params)
        depths = hard_rgb_blend(depths, fragments, blend_params)
        cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
        texels = hard_rgb_blend(texels, fragments, blend_params)
        return verts, normals, depths, cos_angles, texels, fragments

    def texel_from_uv(self, fragments, meshes):
        texture_tmp = meshes.textures
        maps_tmp = texture_tmp.maps_padded()
        uv_color = [[[1, 0], [1, 1]], [[0, 0], [0, 1]]]
        uv_color = torch.FloatTensor(uv_color).to(maps_tmp[0].device).type(maps_tmp[0].dtype)
        uv_texture = TexturesUV([uv_color.clone() for t in maps_tmp], texture_tmp.faces_uvs_padded(), texture_tmp.verts_uvs_padded(), sampling_mode="bilinear")
        meshes.textures = uv_texture
        texels = meshes.sample_textures(fragments)
        meshes.textures = texture_tmp
        texels = torch.cat((texels, texels[...,-1:]*0), dim=-1)
        return texels
step1x3d_texture/texture_sync/project.py
DELETED
@@ -1,521 +0,0 @@
import torch
import pytorch3d


from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj, IO

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    FoVOrthographicCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    TexturesUV
)

from .geometry import HardGeometryShader
from .shader import HardNChannelFlatShader
from .voronoi import voronoi_solve
from trimesh import Trimesh

# Pytorch3D based renderering functions, managed in a class
# Render size is recommended to be the same as your latent view size
# DO NOT USE "bilinear" sampling when you are handling latents.
# Stable Diffusion has 4 latent channels so use channels=4

class UVProjection():
    def __init__(self, texture_size=96, render_size=64, sampling_mode="nearest", channels=3, device=None):
        self.channels = channels
        self.device = device or torch.device("cpu")
        self.lights = AmbientLights(ambient_color=((1.0,)*channels,), device=self.device)
        self.target_size = (texture_size,texture_size)
        self.render_size = render_size
        self.sampling_mode = sampling_mode


    # # Load obj mesh, rescale the mesh to fit into the bounding box
    # def load_mesh(self, mesh_path, scale_factor=2.0, auto_center=True, autouv=False):
    #     mesh = load_objs_as_meshes([mesh_path], device=self.device)
    #     if auto_center:
    #         verts = mesh.verts_packed()
    #         max_bb = (verts - 0).max(0)[0]
    #         min_bb = (verts - 0).min(0)[0]
    #         scale = (max_bb - min_bb).max()/2
    #         center = (max_bb+min_bb) /2
    #         mesh.offset_verts_(-center)
    #         mesh.scale_verts_((scale_factor / float(scale)))
    #     else:
    #         mesh.scale_verts_((scale_factor))

    #     if autouv or (mesh.textures is None):
    #         mesh = self.uv_unwrap(mesh)
    #     self.mesh = mesh

    # Load obj mesh, rescale the mesh to fit into the bounding box
    def load_mesh(self, mesh, scale_factor=2.0, auto_center=True, autouv=False, normals=None):
        if isinstance(mesh, Trimesh):
            vertices = torch.tensor(mesh.vertices, dtype=torch.float32).to(self.device)
            faces = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
            if faces.ndim == 1:
                faces = faces.unsqueeze(0)
            mesh = Meshes(
                verts=[vertices],
                faces=[faces]
            )
            verts = mesh.verts_packed()
            mesh = mesh.update_padded(verts[None,:, :])
            # from pytorch3d.renderer.mesh.textures import TexturesVertex
            # if normals is None:
            #     normals = mesh.verts_normals_packed()
            # # set normals as vertext colors
            # mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
        elif isinstance(mesh, str) and os.path.isfile(mesh):
            mesh = load_objs_as_meshes([mesh_path], device=self.device)
        if auto_center:
            verts = mesh.verts_packed()
            max_bb = (verts - 0).max(0)[0]
            min_bb = (verts - 0).min(0)[0]
            scale = (max_bb - min_bb).max()/2
            center = (max_bb+min_bb) /2
            mesh.offset_verts_(-center)
            mesh.scale_verts_((scale_factor / float(scale)))
        else:
            mesh.scale_verts_((scale_factor))

        if autouv or (mesh.textures is None):
            mesh = self.uv_unwrap(mesh)
        self.mesh = mesh

    def load_glb_mesh(self, mesh_path, scale_factor=2.0, auto_center=True, autouv=False):
        from pytorch3d.io.experimental_gltf_io import MeshGlbFormat
        io = IO()
        io.register_meshes_format(MeshGlbFormat())
        with open(mesh_path, "rb") as f:
            mesh = io.load_mesh(f, include_textures=True, device=self.device)
        if auto_center:
            verts = mesh.verts_packed()
            max_bb = (verts - 0).max(0)[0]
            min_bb = (verts - 0).min(0)[0]
            scale = (max_bb - min_bb).max()/2
            center = (max_bb+min_bb) /2
            mesh.offset_verts_(-center)
            mesh.scale_verts_((scale_factor / float(scale)))
        else:
            mesh.scale_verts_((scale_factor))
        if autouv or (mesh.textures is None):
            mesh = self.uv_unwrap(mesh)
        self.mesh = mesh


    # Save obj mesh
    def save_mesh(self, mesh_path, texture):
        save_obj(mesh_path,
                 self.mesh.verts_list()[0],
                 self.mesh.faces_list()[0],
                 verts_uvs=self.mesh.textures.verts_uvs_list()[0],
                 faces_uvs=self.mesh.textures.faces_uvs_list()[0],
                 texture_map=texture)

    # Code referred to TEXTure code (https://github.com/TEXTurePaper/TEXTurePaper.git)
    def uv_unwrap(self, mesh):
        verts_list = mesh.verts_list()[0]
        faces_list = mesh.faces_list()[0]


        import xatlas
        import numpy as np
        v_np = verts_list.cpu().numpy()
        f_np = faces_list.int().cpu().numpy()
        atlas = xatlas.Atlas()
        atlas.add_mesh(v_np, f_np)
        chart_options = xatlas.ChartOptions()
        chart_options.max_iterations = 4
        atlas.generate(chart_options=chart_options)
        vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

        vt = torch.from_numpy(vt_np.astype(np.float32)).type(verts_list.dtype).to(mesh.device)
        ft = torch.from_numpy(ft_np.astype(np.int64)).type(faces_list.dtype).to(mesh.device)

        new_map = torch.zeros(self.target_size+(self.channels,), device=mesh.device)
        new_tex = TexturesUV(
            [new_map],
            [ft],
            [vt],
            sampling_mode=self.sampling_mode
        )

        mesh.textures = new_tex
        return mesh


    '''
    A functions that disconnect faces in the mesh according to
    its UV seams. The number of vertices are made equal to the
    number of unique vertices its UV layout, while the faces list
    is intact.
    '''
    def disconnect_faces(self):
        mesh = self.mesh
        verts_list = mesh.verts_list()
        faces_list = mesh.faces_list()
        verts_uvs_list = mesh.textures.verts_uvs_list()
        faces_uvs_list = mesh.textures.faces_uvs_list()
        packed_list = [v[f] for v,f in zip(verts_list, faces_list)]
        verts_disconnect_list = [
            torch.zeros(
                (verts_uvs_list[i].shape[0], 3),
                dtype=verts_list[0].dtype,
                device=verts_list[0].device
            )
            for i in range(len(verts_list))]
        for i in range(len(verts_list)):
            verts_disconnect_list[i][faces_uvs_list] = packed_list[i]
        assert not mesh.has_verts_normals(), "Not implemented for vertex normals"
        self.mesh_d = Meshes(verts_disconnect_list, faces_uvs_list, mesh.textures)
        return self.mesh_d


    '''
    A function that construct a temp mesh for back-projection.
    Take a disconnected mesh and a rasterizer, the function calculates
    the projected faces as the UV, as use its original UV with pseudo
    z value as world space geometry.
    '''
    def construct_uv_mesh(self):
        mesh = self.mesh_d
        verts_list = mesh.verts_list()
        verts_uvs_list = mesh.textures.verts_uvs_list()
        # faces_list = [torch.flip(faces, [-1]) for faces in mesh.faces_list()]
        new_verts_list = []
        for i, (verts, verts_uv) in enumerate(zip(verts_list, verts_uvs_list)):
            verts = verts.clone()
            verts_uv = verts_uv.clone()
            verts[...,0:2] = verts_uv[...,:]
            verts = (verts - 0.5) * 2
            verts[...,2] *= 1
            new_verts_list.append(verts)
        textures_uv = mesh.textures.clone()
        self.mesh_uv = Meshes(new_verts_list, mesh.faces_list(), textures_uv)
        return self.mesh_uv


    # Set texture for the current mesh.
    def set_texture_map(self, texture):
        new_map = texture.permute(1, 2, 0)
        new_map = new_map.to(self.device)
        new_tex = TexturesUV(
            [new_map],
            self.mesh.textures.faces_uvs_padded(),
            self.mesh.textures.verts_uvs_padded(),
            sampling_mode=self.sampling_mode
        )
        self.mesh.textures = new_tex


    # Set the initial normal noise texture
    # No generator here for replication of the experiment result. Add one as you wish
    def set_noise_texture(self, channels=None):
        if not channels:
            channels = self.channels
        noise_texture = torch.normal(0, 1, (channels,) + self.target_size, device=self.device)
        self.set_texture_map(noise_texture)
        return noise_texture


    # Set the cameras given the camera poses and centers
    def set_cameras(self, camera_poses, centers=None, camera_distance=2.7, scale=None):
        elev = torch.FloatTensor([pose[0] for pose in camera_poses])
        azim = torch.FloatTensor([pose[1] for pose in camera_poses])
        R, T = look_at_view_transform(dist=camera_distance, elev=elev, azim=azim, at=centers or ((0,0,0),))
        # self.cameras = FoVOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),))
        self.cameras = FoVOrthographicCameras(device=self.device, R=R, T=T, scale_xyz=scale or ((1,1,1),), znear=0.1, min_x=-0.55, max_x=0.55, min_y=-0.55, max_y=0.55)

    # Set all necessary internal data for rendering and texture baking
    # Can be used to refresh after changing camera positions
    def set_cameras_and_render_settings(self, camera_poses, centers=None, camera_distance=2.7, render_size=None, scale=None):
        self.set_cameras(camera_poses, centers, camera_distance, scale=scale)
        if render_size is None:
            render_size = self.render_size
        if not hasattr(self, "renderer"):
            self.setup_renderer(size=render_size)
        if not hasattr(self, "mesh_d"):
            self.disconnect_faces()
        if not hasattr(self, "mesh_uv"):
            self.construct_uv_mesh()
        self.calculate_tex_gradient()
        self.calculate_visible_triangle_mask()
        _,_,_,cos_maps,_, _ = self.render_geometry()
        self.calculate_cos_angle_weights(cos_maps)


    # Setup renderers for rendering
    # max faces per bin set to 30000 to avoid overflow in many test cases.
    # You can use default value to let pytorch3d handle that for you.
    def setup_renderer(self, size=64, blur=0.0, face_per_pix=1, perspective_correct=False, channels=None):
        if not channels:
            channels = self.channels

        self.raster_settings = RasterizationSettings(
            image_size=size,
            blur_radius=blur,
            faces_per_pixel=face_per_pix,
            perspective_correct=perspective_correct,
            cull_backfaces=True,
            max_faces_per_bin=30000,
        )

        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras,
                raster_settings=self.raster_settings,
            ),
            shader=HardNChannelFlatShader(
                device=self.device,
                cameras=self.cameras,
                lights=self.lights,
                channels=channels
                # materials=materials
            )
        )


    # Bake screen-space cosine weights to UV space
    # May be able to reimplement using the generic "bake_texture" function, but it works so leave it here for now
    @torch.enable_grad()
    def calculate_cos_angle_weights(self, cos_angles, fill=True, channels=None):
        if not channels:
            channels = self.channels
        cos_maps = []
        tmp_mesh = self.mesh.clone()
        for i in range(len(self.cameras)):

            zero_map = torch.zeros(self.target_size+(channels,), device=self.device, requires_grad=True)
            optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
            optimizer.zero_grad()
            zero_tex = TexturesUV([zero_map], self.mesh.textures.faces_uvs_padded(), self.mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
            tmp_mesh.textures = zero_tex

            images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights)

            loss = torch.sum((cos_angles[i,:,:,0:1]**1 - images_predicted)**2)
            loss.backward()
            optimizer.step()

            if fill:
                zero_map = zero_map.detach() / (self.gradient_maps[i] + 1E-8)
                zero_map = voronoi_solve(zero_map, self.gradient_maps[i][...,0])
            else:
                zero_map = zero_map.detach() / (self.gradient_maps[i]+1E-8)
            cos_maps.append(zero_map)
        self.cos_maps = cos_maps


    # Get geometric info from fragment shader
    # Can be used for generating conditioning image and cosine weights
    # Returns some information you may not need, remember to release them for memory saving
    @torch.no_grad()
    def render_geometry(self, image_size=None):
        if image_size:
            size = self.renderer.rasterizer.raster_settings.image_size
            self.renderer.rasterizer.raster_settings.image_size = image_size
        shader = self.renderer.shader
        self.renderer.shader = HardGeometryShader(device=self.device, cameras=self.cameras[0], lights=self.lights)
        tmp_mesh = self.mesh.clone()

        verts, normals, depths, cos_angles, texels, fragments = self.renderer(tmp_mesh.extend(len(self.cameras)), cameras=self.cameras, lights=self.lights)
        self.renderer.shader = shader

        if image_size:
            self.renderer.rasterizer.raster_settings.image_size = size

        return verts, normals, depths, cos_angles, texels, fragments


    # Project world normal to view space and normalize
    @torch.no_grad()
    def decode_view_normal(self, normals):
        w2v_mat = self.cameras.get_full_projection_transform()
        normals_view = torch.clone(normals)[:,:,:,0:3]
        normals_view = normals_view.reshape(normals_view.shape[0], -1, 3)
        normals_view = w2v_mat.transform_normals(normals_view)
        normals_view = normals_view.reshape(normals.shape[0:3]+(3,))
        normals_view[:,:,:,2] *= -1
        normals = (normals_view[...,0:3]+1) * normals[...,3:] / 2 + torch.FloatTensor(((((0.5,0.5,1))))).to(self.device) * (1 - normals[...,3:])
        # normals = torch.cat([normal for normal in normals], dim=1)
        normals = normals.clamp(0, 1)
        return normals


    # Normalize absolute depth to inverse depth
    @torch.no_grad()
    def decode_normalized_depth(self, depths, batched_norm=False):
        view_z, mask = depths.unbind(-1)
        view_z = view_z * mask + 100 * (1-mask)
        inv_z = 1 / view_z
        inv_z_min = inv_z * mask + 100 * (1-mask)
        if not batched_norm:
            max_ = torch.max(inv_z, 1, keepdim=True)
            max_ = torch.max(max_[0], 2, keepdim=True)[0]

            min_ = torch.min(inv_z_min, 1, keepdim=True)
            min_ = torch.min(min_[0], 2, keepdim=True)[0]
        else:
            max_ = torch.max(inv_z)
            min_ = torch.min(inv_z_min)
        inv_z = (inv_z - min_) / (max_ - min_)
        inv_z = inv_z.clamp(0,1)
        inv_z = inv_z[...,None].repeat(1,1,1,3)

        return inv_z


    # Multiple screen pixels could pass gradient to a same texel
    # We can precalculate this gradient strength and use it to normalize gradients when we bake textures
    @torch.enable_grad()
    def calculate_tex_gradient(self, channels=None):
        if not channels:
            channels = self.channels
        tmp_mesh = self.mesh.clone()
        gradient_maps = []
        for i in range(len(self.cameras)):
            zero_map = torch.zeros(self.target_size+(channels,), device=self.device, requires_grad=True)
            optimizer = torch.optim.SGD([zero_map], lr=1, momentum=0)
            optimizer.zero_grad()
            zero_tex = TexturesUV([zero_map], self.mesh.textures.faces_uvs_padded(), self.mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
            tmp_mesh.textures = zero_tex
            images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights)
            loss = torch.sum((1 - images_predicted)**2)
            loss.backward()
            optimizer.step()

            gradient_maps.append(zero_map.detach())

        self.gradient_maps = gradient_maps


    # Get the UV space masks of triangles visible in each view
    # First get face ids from each view, then filter pixels on UV space to generate masks
    @torch.no_grad()
    def calculate_visible_triangle_mask(self, channels=None, image_size=(512,512)):
        if not channels:
            channels = self.channels

        pix2face_list = []
        for i in range(len(self.cameras)):
            self.renderer.rasterizer.raster_settings.image_size=image_size
            pix2face = self.renderer.rasterizer(self.mesh_d, cameras=self.cameras[i]).pix_to_face
            self.renderer.rasterizer.raster_settings.image_size=self.render_size
            pix2face_list.append(pix2face)

        if not hasattr(self, "mesh_uv"):
            self.construct_uv_mesh()

        raster_settings = RasterizationSettings(
            image_size=self.target_size,
            blur_radius=0,
            faces_per_pixel=1,
            perspective_correct=False,
            cull_backfaces=False,
            max_faces_per_bin=30000,
        )

        R, T = look_at_view_transform(dist=2, elev=0, azim=0)
        cameras = FoVOrthographicCameras(device=self.device, R=R, T=T)

        rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        )
        uv_pix2face = rasterizer(self.mesh_uv).pix_to_face

        visible_triangles = []
        for i in range(len(pix2face_list)):
            valid_faceid = torch.unique(pix2face_list[i])
            valid_faceid = valid_faceid[1:] if valid_faceid[0]==-1 else valid_faceid
            mask = torch.isin(uv_pix2face[0], valid_faceid, assume_unique=False)
            # uv_pix2face[0][~mask] = -1
            triangle_mask = torch.ones(self.target_size+(1,), device=self.device)
            triangle_mask[~mask] = 0

            triangle_mask[:,1:][triangle_mask[:,:-1] > 0] = 1
            triangle_mask[:,:-1][triangle_mask[:,1:] > 0] = 1
            triangle_mask[1:,:][triangle_mask[:-1,:] > 0] = 1
            triangle_mask[:-1,:][triangle_mask[1:,:] > 0] = 1
            visible_triangles.append(triangle_mask)

        self.visible_triangles = visible_triangles


    # Render the current mesh and texture from current cameras
    def render_textured_views(self):
        meshes = self.mesh.extend(len(self.cameras))
        images_predicted = self.renderer(meshes, cameras=self.cameras, lights=self.lights)

        return [image.permute(2, 0, 1) for image in images_predicted]


    # Bake views into a texture
    # First bake into individual textures then combine based on cosine weight
    @torch.enable_grad()
    def bake_texture(self, views=None, main_views=[], cos_weighted=True, channels=None, exp=None, noisy=False, generator=None):
        if not exp:
            exp=1
        if not channels:
            channels = self.channels
        views = [view.permute(1, 2, 0) for view in views]

        tmp_mesh = self.mesh
        bake_maps = [torch.zeros(self.target_size+(views[0].shape[2],), device=self.device, requires_grad=True) for view in views]
        optimizer = torch.optim.SGD(bake_maps, lr=1, momentum=0)
        optimizer.zero_grad()
        loss = 0
        for i in range(len(self.cameras)):
            bake_tex = TexturesUV([bake_maps[i]], tmp_mesh.textures.faces_uvs_padded(), tmp_mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
            tmp_mesh.textures = bake_tex
            images_predicted = self.renderer(tmp_mesh, cameras=self.cameras[i], lights=self.lights, device=self.device)
            predicted_rgb = images_predicted[..., :-1]
            loss += (((predicted_rgb[...] - views[i]))**2).sum()
        loss.backward(retain_graph=False)
        optimizer.step()

        total_weights = 0
        baked = 0
        for i in range(len(bake_maps)):
            normalized_baked_map = bake_maps[i].detach() / (self.gradient_maps[i] + 1E-8)
            bake_map = voronoi_solve(normalized_baked_map, self.gradient_maps[i][...,0])
            weight = self.visible_triangles[i] * (self.cos_maps[i]) ** exp
            if noisy:
                noise = torch.rand(weight.shape[:-1]+(1,), generator=generator).type(weight.dtype).to(weight.device)
                weight *= noise
            total_weights += weight
            baked += bake_map * weight
        baked /= total_weights + 1E-8
        baked = voronoi_solve(baked, total_weights[...,0])

        bake_tex = TexturesUV([baked], tmp_mesh.textures.faces_uvs_padded(), tmp_mesh.textures.verts_uvs_padded(), sampling_mode=self.sampling_mode)
        tmp_mesh.textures = bake_tex
        extended_mesh = tmp_mesh.extend(len(self.cameras))
        images_predicted = self.renderer(extended_mesh, cameras=self.cameras, lights=self.lights)
        learned_views = [image.permute(2, 0, 1) for image in images_predicted]

        return learned_views, baked.permute(2, 0, 1), total_weights.permute(2, 0, 1)


    # Move the internel data to a specific device
    def to(self, device):
        for mesh_name in ["mesh", "mesh_d", "mesh_uv"]:
            if hasattr(self, mesh_name):
                mesh = getattr(self, mesh_name)
                setattr(self, mesh_name, mesh.to(device))
        for list_name in ["visible_triangles", "visibility_maps", "cos_maps"]:
            if hasattr(self, list_name):
                map_list = getattr(self, list_name)
                for i in range(len(map_list)):
                    map_list[i] = map_list[i].to(device)
step1x3d_texture/texture_sync/shader.py
DELETED
@@ -1,118 +0,0 @@
from typing import Optional

import torch
import pytorch3d


from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
    Materials,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties

from pytorch3d.renderer.lighting import AmbientLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
from pytorch3d.renderer.mesh.rasterizer import Fragments


'''
Customized the original pytorch3d hard flat shader to support N channel flat shading
'''
class HardNChannelFlatShader(ShaderBase):
    """
    Per face lighting - the lighting model is applied using the average face
    position and the face normal. The blending function hard assigns
    the color of the closest face for each pixel.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = HardFlatShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self,
        device = "cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        self.channels = channels
        ones = ((1.0,)*channels,)
        zeros = ((0.0,)*channels,)

        if not isinstance(lights, AmbientLights) or not lights.ambient_color.shape[-1] == channels:
            lights = AmbientLights(
                ambient_color=ones,
                device=device,
            )

        if not materials or not materials.ambient_color.shape[-1] == channels:
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        blend_params_new = BlendParams(background_color=(1.0,)*channels)
        if not isinstance(blend_params, BlendParams):
            blend_params = blend_params_new
        else:
            background_color_ = blend_params.background_color
            if isinstance(background_color_, Sequence[float]) and not len(background_color_) == channels:
                blend_params = blend_params_new
            if isinstance(background_color_, torch.Tensor) and not background_color_.shape[-1] == channels:
                blend_params = blend_params_new

        super().__init__(
            device,
            cameras,
            lights,
            materials,
            blend_params,
        )

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        cameras = super()._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        images = hard_rgb_blend(colors, fragments, blend_params)
        return images
step1x3d_texture/texture_sync/step_sync.py
DELETED
@@ -1,125 +0,0 @@
import torch
from diffusers.utils.torch_utils import randn_tensor

'''

Customized Step Function
step on texture
'''
@torch.no_grad()
def step_tex_sync(
    scheduler,
    uvp,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    texture: None,
    generator=None,
    return_dict: bool = True,
    guidance_scale=1,
    main_views=[],
    hires_original_views=True,
    exp=None,
    cos_weighted=True
):
    t = timestep

    prev_t = scheduler.previous_timestep(t)

    if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[t]
    alpha_prod_t_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )
    # 3. Clip or threshold "predicted x_0"
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
        )

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    '''
    Add multidiffusion here
    '''

    if texture is None:
        sample_views = [view for view in sample]
        sample_views, texture, _ = uvp.bake_texture(views=sample_views, main_views=main_views, exp=exp)
        sample_views = torch.stack(sample_views, axis=0)[:,:-1,...]

    original_views = [view for view in pred_original_sample]
    original_views, original_tex, visibility_weights = uvp.bake_texture(views=original_views, main_views=main_views, exp=exp)
    uvp.set_texture_map(original_tex)
    original_views = uvp.render_textured_views()
    original_views = torch.stack(original_views, axis=0)[:,:-1,...]

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    # pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
    prev_tex = pred_original_sample_coeff * original_tex + current_sample_coeff * texture

    # 6. Add noise
    variance = 0

    if predicted_variance is not None:
        variance_views = [view for view in predicted_variance]
        variance_views, variance_tex, visibility_weights = uvp.bake_texture(views=variance_views, main_views=main_views, cos_weighted=cos_weighted, exp=exp)
        variance_views = torch.stack(variance_views, axis=0)[:,:-1,...]
    else:
        variance_tex = None

    if t > 0:
        device = texture.device
        variance_noise = randn_tensor(
            texture.shape, generator=generator, device=device, dtype=texture.dtype
        )
        if scheduler.variance_type == "fixed_small_log":
            variance = scheduler._get_variance(t, predicted_variance=variance_tex) * variance_noise
        elif scheduler.variance_type == "learned_range":
            variance = scheduler._get_variance(t, predicted_variance=variance_tex)
            variance = torch.exp(0.5 * variance) * variance_noise
        else:
            variance = (scheduler._get_variance(t, predicted_variance=variance_tex) ** 0.5) * variance_noise
        prev_tex = prev_tex + variance

    uvp.set_texture_map(prev_tex)
    prev_views = uvp.render_textured_views()
    pred_prev_sample = torch.clone(sample)
    for i, view in enumerate(prev_views):
        pred_prev_sample[i] = view[:-1]
    masks = [view[-1:] for view in prev_views]

    return {"prev_sample": pred_prev_sample, "pred_original_sample": pred_original_sample, "prev_tex": prev_tex}

    if not return_dict:
        return pred_prev_sample, pred_original_sample
    pass