ZhiyuanthePony's picture
remove_type_annotator
fc44d4b
import numpy as np
import torch
import torch.nn.functional as F
import trimesh
def dot(x, y):
    """Return the inner product of ``x`` and ``y`` over their last axis.

    The reduced axis is kept with size 1 so the result broadcasts back
    against the inputs (e.g. ``v - dot(v, n) * n``).
    """
    return (x * y).sum(dim=-1, keepdim=True)
class Mesh:
    """Triangle mesh stored as torch tensors, with lazily computed attributes.

    Vertex normals, tangents, UV coordinates and unique edges are computed on
    first access through the matching properties and cached on the instance.

    Args:
        v_pos: Vertex positions. Rows are gathered by face index and crossed
            as 3-vectors, so presumably shaped (num_vertices, 3) — TODO confirm.
        t_pos_idx: Per-face vertex indices. Columns 0..2 are read, so
            presumably an integer tensor shaped (num_faces, 3).
        material: Optional opaque material object, stored as-is.
    """

    def __init__(self, v_pos, t_pos_idx, material=None):
        self.v_pos = v_pos
        self.t_pos_idx = t_pos_idx
        self.material = material
        # Lazily populated caches; see the properties of the same name.
        self._v_nrm = None
        self._v_tng = None
        self._v_tex = None
        self._t_tex_idx = None
        self._v_rgb = None
        self._edges = None
        # Free-form per-mesh metadata, written through add_extra().
        self.extras = {}

    def add_extra(self, k, v) -> None:
        """Store ``v`` under key ``k`` in ``self.extras``."""
        self.extras[k] = v

    def remove_outlier(self, n_face_threshold=5):
        """Drop connected components that have at most ``n_face_threshold`` faces.

        Returns:
            A new Mesh built from the surviving components. It uses the same
            device and dtypes as this mesh and keeps ``material``. Cached
            attributes (normals, UVs, colors, extras) are not copied. If no
            component survives, ``self`` is returned unchanged.
        """
        trimesh_mesh = self.as_trimesh()
        components = trimesh_mesh.split(only_watertight=False)
        # Keep only components with strictly more faces than the threshold.
        valid_components = [
            c for c in components if len(c.faces) > n_face_threshold
        ]
        if not valid_components:
            # Nothing would survive; keep the original rather than return an empty mesh.
            return self
        combined = trimesh.util.concatenate(valid_components)
        return Mesh(
            torch.tensor(
                combined.vertices,
                dtype=self.v_pos.dtype,
                device=self.v_pos.device,
            ),
            torch.tensor(
                combined.faces,
                dtype=self.t_pos_idx.dtype,
                device=self.t_pos_idx.device,
            ),
            material=self.material,
        )

    @property
    def requires_grad(self):
        """Whether the vertex positions track gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex unit normals, computed on first access."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex unit tangents, computed on first access (requires UVs)."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """UV coordinates; triggers an xatlas unwrap on first access."""
        if self._v_tex is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._v_tex

    @property
    def t_tex_idx(self):
        """Per-face indices into ``v_tex``; triggers an xatlas unwrap on first access."""
        if self._t_tex_idx is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._t_tex_idx

    @property
    def v_rgb(self):
        """Vertex colors set via set_vertex_color(), or None."""
        return self._v_rgb

    @property
    def edges(self):
        """Unique undirected edges as sorted (i, j) index pairs, cached."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Compute area-weighted vertex normals.

        Each face's unnormalized cross product (length = twice the face area)
        is accumulated onto its three vertices. The sums are then normalized.
        Vertices with a (near-)zero accumulated normal get +Z.
        """
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]
        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]
        # dim=-1 is required: without it torch.cross picks the FIRST axis of
        # size 3, which is the face axis when the mesh has exactly 3 faces.
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to their vertices.
        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Replace degenerate (zero-length) normals with +Z, then normalize.
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20,
            v_nrm,
            torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm),
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))
        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute per-vertex tangents from UV gradients.

        Tangents are averaged over incident faces and then orthogonalized
        against ``v_nrm``. Reading ``v_tex`` / ``t_tex_idx`` triggers a UV
        unwrap if none has been done yet.
        """
        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            tex[i] = self.v_tex[self.t_tex_idx[:, i]]
            # t_nrm_idx is always the same as t_pos_idx
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute the tangent direction for each triangle.
        uve1 = tex[1] - tex[0]
        uve2 = tex[2] - tex[0]
        pe1 = pos[1] - pos[0]
        pe2 = pos[2] - pos[0]

        nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
        denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]

        # Avoid division by zero for degenerate texture coordinates while
        # preserving the sign of the denominator.
        tang = nom / torch.where(
            denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
        )

        # Accumulate onto all 3 vertices of each face.
        for i in range(3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] += tang
            tansum.scatter_add_(0, idx, torch.ones_like(tang))  # tansum[n_i] += 1
        # Vertices referenced by no face have tansum == 0; clamping keeps their
        # tangent at zero instead of NaN (no effect on referenced vertices).
        tangents = tangents / tansum.clamp(min=1.0)

        # Normalize, then make the tangent perpendicular to the normal.
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def _unwrap_uv(self, xatlas_chart_options=None, xatlas_pack_options=None):
        """Generate a UV atlas with xatlas.

        Args:
            xatlas_chart_options: Attribute overrides applied to
                ``xatlas.ChartOptions`` (None means no overrides).
            xatlas_pack_options: Attribute overrides applied to
                ``xatlas.PackOptions`` (None means no overrides).

        Returns:
            ``(uvs, indices)``: float UV coordinates and long per-face indices
            into them, both on ``v_pos``'s device.
        """
        import xatlas

        atlas = xatlas.Atlas()
        atlas.add_mesh(
            self.v_pos.detach().cpu().numpy(),
            self.t_pos_idx.cpu().numpy(),
        )
        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        for k, v in (xatlas_chart_options or {}).items():
            setattr(co, k, v)
        for k, v in (xatlas_pack_options or {}).items():
            setattr(po, k, v)
        atlas.generate(co, po)
        # The first value (vertex mapping) is not needed by callers here.
        _, indices, uvs = atlas.get_mesh(0)
        uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
        # Reinterpret the unsigned indices as int64 so torch can index with them.
        indices = (
            torch.from_numpy(
                indices.astype(np.uint64, casting="same_kind").view(np.int64)
            )
            .to(self.v_pos.device)
            .long()
        )
        return uvs, indices

    def unwrap_uv(self, xatlas_chart_options=None, xatlas_pack_options=None):
        """Unwrap UVs now with the given xatlas options, replacing any cached UVs."""
        self._v_tex, self._t_tex_idx = self._unwrap_uv(
            xatlas_chart_options, xatlas_pack_options
        )

    def set_vertex_color(self, v_rgb):
        """Attach per-vertex colors; the first dim must equal the vertex count."""
        assert v_rgb.shape[0] == self.v_pos.shape[0]
        self._v_rgb = v_rgb

    def _compute_edges(self):
        """Return the unique undirected edges as sorted (min, max) index pairs."""
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Sort each pair so (i, j) and (j, i) collapse to one row under unique.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges

    def normal_consistency(self):
        """Mean (1 - cosine similarity) of vertex normals across each edge."""
        edge_nrm = self.v_nrm[self.edges]
        nc = (
            1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
        ).mean()
        return nc

    def _laplacian_uniform(self):
        """Build the uniform graph Laplacian (degree - adjacency) as a sparse tensor."""
        # from stable-dreamfusion
        # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
        verts, faces = self.v_pos, self.t_pos_idx

        V = verts.shape[0]

        # Neighbor indices: both directions of every face edge, deduplicated.
        ii = faces[:, [1, 2, 0]].flatten()
        jj = faces[:, [2, 0, 1]].flatten()
        adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(
            dim=1
        )
        adj_values = torch.ones(adj.shape[1]).to(verts)

        # Diagonal indices
        diag_idx = adj[0]

        # Build the sparse matrix
        idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
        values = torch.cat((-adj_values, adj_values))

        # The coalesce operation sums the duplicate indices, resulting in the
        # correct diagonal (each vertex's degree).
        return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()

    def laplacian(self):
        """Mean L2 norm of the uniform-Laplacian coordinates of ``v_pos``.

        The Laplacian matrix is built without grad; gradients still flow
        through ``v_pos`` in the matrix product.
        """
        with torch.no_grad():
            L = self._laplacian_uniform()
        loss = L.mm(self.v_pos)
        loss = loss.norm(dim=1)
        loss = loss.mean()
        return loss

    def to(self, device):
        """Return a new Mesh with positions and faces moved to ``device``.

        ``material`` is carried over as-is (not moved). Cached attributes are
        not copied and will be recomputed lazily.
        """
        v_pos = self.v_pos.to(device)
        t_pos_idx = self.t_pos_idx.to(device)
        return Mesh(v_pos, t_pos_idx, material=self.material)

    def as_trimesh(self):
        """Convert to an unprocessed ``trimesh.Trimesh`` (vertex order preserved)."""
        vertices = self.v_pos.detach().cpu().numpy()
        faces = self.t_pos_idx.detach().cpu().numpy()
        mesh = trimesh.Trimesh(
            vertices=vertices,
            faces=faces,
            process=False,
        )
        # Add texture if available.
        # NOTE(review): this class never sets ``albedo_map``; presumably it is
        # attached externally — TODO confirm. Also, ``v_tex`` comes from
        # xatlas and is indexed by ``t_tex_idx``, so its row count can differ
        # from ``vertices``; trimesh expects one UV per vertex — verify.
        if hasattr(self, 'albedo_map') and self.albedo_map is not None:
            uv = self.v_tex.detach().cpu().numpy()
            # Create texture visuals
            visual = trimesh.visual.texture.TextureVisuals(
                uv=uv,
                material=trimesh.visual.material.SimpleMaterial(),
            )
            mesh.visual = visual
        return mesh
def scale_tensor(x, input_range, target_range):
    """Linearly remap ``x`` from ``input_range`` onto ``target_range``.

    Both ranges are indexable as ``(low, high)``. Values outside the input
    range are extrapolated, not clamped.
    """
    in_lo, in_hi = input_range[0], input_range[1]
    out_lo, out_hi = target_range[0], target_range[1]
    normalized = (x - in_lo) / (in_hi - in_lo)
    return normalized * (out_hi - out_lo) + out_lo