# Spaces: Running on Zero
from dataclasses import dataclass | |
import torch | |
import torch.nn as nn | |
import os | |
import numpy as np | |
from .saving import SaverMixin | |
from ..utils.mesh import Mesh | |
from ..utils.general_utils import scale_tensor | |
@dataclass
class ExporterOutput:
    """Record describing one artifact to export: its name, type and extra params."""

    save_name: str  # name the artifact is saved under
    save_type: str  # NOTE(review): presumably the export format/kind — confirm with callers
    params: dict  # NOTE(review): presumably extra keyword arguments for the saver — confirm
class IsosurfaceHelper(nn.Module):
    """Base class for isosurface extractors that sample a field on a fixed grid."""

    # Coordinate range covered by the sample grid along every axis.
    points_range = (0, 1)

    @property
    def grid_vertices(self):
        """Grid sample points (expected shape ``(N, 3)``); provided by subclasses.

        This is a property because callers read it as an attribute, e.g.
        ``helper.grid_vertices.to(device)``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
class DiffMarchingCubeHelper(IsosurfaceHelper):
    """Differentiable marching cubes over a dense ``resolution**3`` grid (``diso.DiffMC``).

    Both the sample grid and the extracted mesh vertices are expressed in
    ``points_range`` coordinates.
    """

    def __init__(
        self,
        resolution,
        point_range=(0, 1),
    ):
        """
        Args:
            resolution: Number of grid samples along each axis.
            point_range: ``(min, max)`` coordinate range covered by the grid along
                every axis; stored as ``self.points_range``.
        """
        super().__init__()
        self.resolution = resolution
        self.points_range = point_range
        # Imported here so the `diso` dependency is only needed when this helper
        # is actually constructed.
        from diso import DiffMC

        self.mc_func = DiffMC(dtype=torch.float32)
        # Built lazily and cached by the `grid_vertices` property.
        self._grid_vertices = None
        # Empty, non-persistent buffer that is never read in this class.
        # NOTE(review): presumably lets the module's device be tracked via
        # `.to()` / `buffers()` — confirm.
        self.register_buffer(
            "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
        )

    @property
    def grid_vertices(self):
        """Grid sample points in ``points_range``, shape ``(resolution**3, 3)``.

        Uses ``indexing="ij"``: row ``i * R**2 + j * R + k`` is grid cell
        ``(i, j, k)``, the same layout ``forward`` assumes for ``level.view(R, R, R)``.
        The tensor is built once and then cached.
        """
        if self._grid_vertices is None:
            lo, hi = self.points_range
            # Sample a unit lattice and map it into points_range exactly once.
            # Sampling linspace over points_range *and* rescaling afterwards would
            # apply the affine map twice whenever points_range != (0, 1).
            # The vertices stay on CPU so that very large resolutions are supported.
            axis = torch.linspace(0.0, 1.0, self.resolution)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            unit = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
            self._grid_vertices = unit * (hi - lo) + lo
        return self._grid_vertices

    def forward(
        self,
        level,
        deformation=None,
        isovalue=0.0,
    ):
        """Extract the ``isovalue`` level set of a field sampled on the grid.

        Args:
            level: Field values at ``grid_vertices``; must hold ``resolution**3``
                elements (reshaped to ``(R, R, R)``).
            deformation: Optional per-vertex offsets with ``resolution**3 * 3``
                elements (reshaped to ``(R, R, R, 3)``), passed through to ``DiffMC``.
            isovalue: Level at which the surface is extracted.

        Returns:
            Mesh: vertex positions mapped into ``points_range`` plus face indices.
        """
        res = self.resolution
        level = level.view(res, res, res)
        if deformation is not None:
            deformation = deformation.view(res, res, res, 3)
        v_pos, t_pos_idx = self.mc_func(level, deformation, isovalue=isovalue)
        # NOTE(review): assumes DiffMC returns vertices normalized to [0, 1]
        # (diso's default); the affine map below places them in points_range.
        lo, hi = self.points_range
        v_pos = v_pos * (hi - lo) + lo
        # TODO: validate the extracted mesh (e.g. empty or degenerate output).
        return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
def _batch_size_and_device(space_cache):
    """Return ``(batch_size, device)`` for a tensor or dict-of-tensor-lists space cache."""
    if torch.is_tensor(space_cache):  # space cache
        return space_cache.shape[0], space_cache.device
    if isinstance(space_cache, dict) and space_cache:  # hyper net
        # Dict[str, List[Float[Tensor, "B ..."]]]: the first tensor of the first
        # entry carries the batch dimension and the device.
        first = next(iter(space_cache.values()))[0]
        return first.shape[0], first.device
    raise TypeError(
        "space_cache must be a tensor or a non-empty dict of tensor lists, "
        f"got {type(space_cache).__name__}"
    )


def isosurface(
    space_cache,
    forward_field,
    isosurface_helper,
):
    """Extract one mesh per batch element of ``space_cache``.

    Args:
        space_cache: Either a tensor whose first dim is the batch, or (hyper-net
            case) a dict mapping names to lists of tensors with a leading batch dim.
        forward_field: Callable ``(points[B, N, 3], space_cache) -> (sdf, deformation)``
            evaluated on the helper's grid; ``deformation`` may be ``None``.
        isosurface_helper: Module exposing ``grid_vertices``, ``points_range`` and
            ``__call__(sdf, deformation) -> Mesh``.

    Returns:
        list: one mesh per batch element, with ``v_pos`` rescaled from
        ``points_range`` to ``[-1, 1]`` via ``scale_tensor``.

    Raises:
        TypeError: if ``space_cache`` is neither a tensor nor a non-empty dict.
    """
    batch_size, device = _batch_size_and_device(space_cache)

    grid = isosurface_helper.grid_vertices
    # Tolerate helpers that expose grid_vertices as a plain method instead of a property.
    if callable(grid):
        grid = grid()
    # Map the grid from points_range into [-1, 1] (hard-coded isosurface bbox).
    points = scale_tensor(grid.to(device), isosurface_helper.points_range, [-1, 1])

    sdf_batch, deformation_batch = forward_field(
        points[None, ...].expand(batch_size, -1, -1),
        space_cache,
    )

    mesh_list = []
    for index in range(sdf_batch.shape[0]):
        sdf = sdf_batch[index]
        # The field may not predict deformations.
        deformation = None if deformation_batch is None else deformation_batch[index]
        # Without a sign change there is no surface; substitute |p| - 1 (a sphere of
        # radius 1 in the scaled coordinates) so extraction still has a surface.
        if torch.all(sdf > 0) or torch.all(sdf < 0):
            print("All sdf values are positive or negative, no isosurface")
            sdf = torch.norm(points, dim=-1) - 1
        mesh = isosurface_helper(sdf, deformation)
        mesh.v_pos = scale_tensor(
            mesh.v_pos,
            isosurface_helper.points_range,
            [-1, 1],  # hard coded isosurface_bbox
        )
        # TODO: implement outlier removal
        # (cfg.isosurface_remove_outliers / cfg.isosurface_outlier_n_faces_threshold).
        mesh_list.append(mesh)
    return mesh_list
def colorize_mesh(
    space_cache,
    export_fn,
    mesh_list,
    activation,
):
    """Colorize meshes in place using the geometry's export function.

    For mesh ``i``, batch element ``i`` of ``space_cache`` is sliced out and
    ``export_fn`` is evaluated at the mesh vertices. If the result contains
    ``"features"``, the activated features are stored on the mesh as vertex colors.

    Args:
        space_cache: Tensor with a leading batch dim, or a dict mapping names to
            lists of tensors with a leading batch dim (hyper-net case).
        export_fn: Callable ``(points[1, N, 3], space_cache_slice) -> dict``; the
            optional ``"features"`` entry is expected as ``[1, N, C]``.
        mesh_list: Meshes to colorize; mesh ``i`` pairs with batch element ``i``.
        activation: Callable mapping raw ``[N, C]`` features to vertex colors.

    Returns:
        list: ``mesh_list`` itself (the meshes are modified in place).

    Raises:
        TypeError: if ``space_cache`` is neither a tensor nor a dict.
    """
    for i, mesh in enumerate(mesh_list):
        points = mesh.v_pos[None, ...]  # add batch dim: [1, N, 3]

        # Keep batch element i as a length-1 batch.
        if torch.is_tensor(space_cache):
            space_cache_slice = space_cache[i:i + 1]
        elif isinstance(space_cache, dict):
            space_cache_slice = {
                key: [weight[i:i + 1] for weight in weights]
                for key, weights in space_cache.items()
            }
        else:
            raise TypeError(
                "space_cache must be a tensor or a dict of tensor lists, "
                f"got {type(space_cache).__name__}"
            )

        out = export_fn(points, space_cache_slice)
        if "features" in out:
            features = out["features"].squeeze(0)  # drop batch dim: [N, C]
            # NOTE(review): writes the private attribute directly, presumably the
            # backing store of Mesh.v_rgb — confirm against Mesh.
            mesh._v_rgb = activation(features)
    return mesh_list
class MeshExporter(SaverMixin):
    """Minimal ``SaverMixin`` host that writes files under one output directory."""

    def __init__(self, save_dir="outputs"):
        """
        Args:
            save_dir: Output directory, created if missing. An empty string means
                the current working directory.
        """
        self.save_dir = save_dir
        # os.makedirs("") raises FileNotFoundError; an empty save_dir occurs when
        # it is derived from a bare filename via os.path.dirname.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    def get_save_dir(self):
        """Return the output directory given to the constructor."""
        return self.save_dir

    def get_save_path(self, filename):
        """Return ``filename`` joined onto the output directory.

        Parent directories of the result (e.g. for ``"sub/mesh.obj"``) are created
        so that a subsequent write to the returned path can succeed.
        """
        save_path = os.path.join(self.save_dir, filename)
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        return save_path

    def convert_data(self, x):
        """Return tensors as detached CPU numpy arrays; pass anything else through."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x
def export_obj(
    mesh,
    save_path,
    save_normal=False,
):
    """Export a mesh to an OBJ file.

    Args:
        mesh: Mesh to write; ``mesh.v_nrm`` / ``mesh.v_rgb`` being non-``None``
            gates whether normals / vertex colors are requested from the saver.
        save_path: Destination path of the OBJ file. A bare filename is written
            relative to the current working directory.
        save_normal: Request vertex normals (only honored when ``mesh.v_nrm`` is
            not ``None``).

    Returns:
        Whatever ``save_obj`` (from ``SaverMixin``) returns — presumably the saved
        file path(s); TODO confirm.
    """
    # os.path.dirname("mesh.obj") == "", which os.makedirs rejects; use the CWD.
    exporter = MeshExporter(os.path.dirname(save_path) or ".")
    return exporter.save_obj(
        os.path.basename(save_path),
        mesh,
        save_mat=None,
        # Short-circuit so v_nrm is only inspected when normals are requested.
        save_normal=save_normal and mesh.v_nrm is not None,
        save_uv=False,
        save_vertex_color=mesh.v_rgb is not None,
    )