Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import json | |
import numpy as np | |
import re | |
import torch | |
import torch as th | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from typing import Dict, Any | |
from visualize.ca_body.utils.quaternion import Quaternion | |
from pytorch3d.transforms import matrix_to_euler_angles | |
from typing import Optional, Tuple | |
import logging | |
logger = logging.getLogger(__name__) | |
class ParameterTransform(nn.Module):
    """Linear map from model parameters to per-joint skeleton channels.

    ``forward`` computes ``pose @ transform.T + transform_offsets`` where
    ``transform`` has one row per (joint, channel) value and one column per
    model parameter.
    """

    def __init__(self, lbs_cfg_dict: Dict[str, Any]):
        """
        Args:
            lbs_cfg_dict: configuration with keys ``channel_names``,
                ``transform_offsets``, ``transform``, ``limits``,
                ``nr_scaling_params`` and ``nr_position_params`` (the layout
                produced by ``load_momentum_cfg``).
        """
        super().__init__()
        self.channel_names = list(lbs_cfg_dict["channel_names"])
        # torch.as_tensor accepts lists, numpy arrays and tensors alike (the
        # legacy FloatTensor constructor is deprecated); clone() so the buffers
        # never alias the caller's arrays.
        transform_offsets = torch.as_tensor(
            lbs_cfg_dict["transform_offsets"], dtype=torch.float32
        ).clone()
        transform = torch.as_tensor(lbs_cfg_dict["transform"], dtype=torch.float32).clone()
        self.limits = lbs_cfg_dict["limits"]
        self.nr_scaling_params = lbs_cfg_dict["nr_scaling_params"]
        self.nr_position_params = lbs_cfg_dict["nr_position_params"]
        self.nr_total_params = self.nr_scaling_params + self.nr_position_params
        self.register_buffer("transform_offsets", transform_offsets)
        self.register_buffer("transform", transform)

    def forward(self, pose: th.Tensor) -> th.Tensor:
        """
        :param pose: raw model parameters, shape (batch_size, number of columns of ``transform``)
        :return: skeleton parameters, shape (batch_size, number of rows of ``transform``)
        """
        return self.transform.mm(pose.t()).t() + self.transform_offsets
class LinearBlendSkinning(nn.Module):
    """Linear blend skinning (LBS) of a template mesh driven by a joint hierarchy.

    The skeleton (joint names, parents, pre-rotations, translation offsets) and
    the skinned mesh (rest positions/normals, skinning weights, faces, UVs) are
    read from ``model_json``; the mapping from model parameters to per-joint
    channels comes from ``lbs_config_dict`` through ``ParameterTransform``.
    """

    def __init__(
        self,
        model_json: Dict[str, Any],
        lbs_config_dict: Dict[str, Any],
        num_max_skin_joints: int = 8,
        scale_path: Optional[str] = None,
    ):
        """
        Args:
            model_json: model description with ``Skeleton`` and ``SkinnedModel`` sections.
            lbs_config_dict: parameter-transform configuration (see ``ParameterTransform``).
            num_max_skin_joints: number of influence slots kept per vertex; a vertex
                with more influences than this keeps only the first ones.
            scale_path: optional text file loaded with ``np.loadtxt`` into the
                ``scale`` buffer.
        """
        super().__init__()
        model = model_json
        self.param_transform = ParameterTransform(lbs_config_dict)

        bones = model["Skeleton"]["Bones"]
        nr_joints = len(bones)
        self.joint_names = []
        joint_parents = torch.zeros((nr_joints, 1), dtype=torch.int64)
        joint_rotation = torch.zeros((nr_joints, 4), dtype=torch.float32)
        joint_offset = torch.zeros((nr_joints, 3), dtype=torch.float32)
        for idx, bone in enumerate(bones):
            self.joint_names.append(bone["Name"])
            parent = bone["Parent"]
            # Any parent index outside [0, nr_joints) marks a root joint (-1).
            joint_parents[idx] = parent if 0 <= parent < nr_joints else -1
            joint_rotation[idx, :] = torch.as_tensor(bone["PreRotation"], dtype=torch.float32)
            joint_offset[idx, :] = torch.as_tensor(bone["TranslationOffset"], dtype=torch.float32)

        skin_model = model["SkinnedModel"]
        mesh_vertices = torch.as_tensor(skin_model["RestPositions"], dtype=torch.float32)
        mesh_normals = torch.as_tensor(skin_model["RestVertexNormals"], dtype=torch.float32)
        # SkinningWeights is a flat list of (joint index, weight) pairs; vertex v
        # owns entries SkinningOffsets[v] .. SkinningOffsets[v + 1] - 1.
        weights = torch.as_tensor(
            [e[1] for e in skin_model["SkinningWeights"]], dtype=torch.float32
        )
        indices = torch.as_tensor(
            [e[0] for e in skin_model["SkinningWeights"]], dtype=torch.int64
        )
        offsets = torch.as_tensor(skin_model["SkinningOffsets"], dtype=torch.int64)
        nr_vertices = len(offsets) - 1
        skin_weights = torch.zeros((nr_vertices, num_max_skin_joints), dtype=torch.float32)
        skin_indices = torch.zeros((nr_vertices, num_max_skin_joints), dtype=torch.int64)
        offset_right = offsets[1:]
        for slot in range(num_max_skin_joints):
            offset_left = offsets[:-1] + slot
            valid = offset_left < offset_right
            skin_weights[valid, slot] = weights[offset_left[valid]]
            skin_indices[valid, slot] = indices[offset_left[valid]]
        # Unused slots keep joint index 0 with weight 0.

        mesh_faces = torch.as_tensor(skin_model["Faces"]["Indices"], dtype=torch.int32).reshape(-1, 3)
        mesh_texture_faces = torch.as_tensor(
            skin_model["Faces"]["TextureIndices"], dtype=torch.int32
        ).reshape(-1, 3)
        mesh_texture_coords = torch.as_tensor(
            skin_model["TextureCoordinates"], dtype=torch.float32
        ).reshape(-1, 2)

        # Skeleton state for all-zero parameters, used as the bind pose.
        zero_pose = torch.zeros((1, self.param_transform.nr_total_params), dtype=torch.float32)
        bind_state = solve_skeleton_state(
            self.param_transform(zero_pose), joint_offset, joint_rotation, joint_parents
        )

        self.register_buffer("mesh_vertices", mesh_vertices)
        self.register_buffer("joint_parents", joint_parents)
        self.register_buffer("joint_rotation", joint_rotation)
        self.register_buffer("joint_offset", joint_offset)
        self.register_buffer("mesh_normals", mesh_normals)
        self.register_buffer("mesh_faces", mesh_faces)
        self.register_buffer("mesh_texture_faces", mesh_texture_faces)
        self.register_buffer("mesh_texture_coords", mesh_texture_coords)
        self.register_buffer("skin_weights", skin_weights)
        self.register_buffer("skin_indices", skin_indices)
        self.register_buffer("bind_state", bind_state)
        # NOTE: registered from the same tensor object as mesh_vertices.
        self.register_buffer("rest_vertices", mesh_vertices)
        # Dense [num_joints, num_verts] weights; needs the skinning buffers above.
        self.register_buffer("joints_weights", self.compute_joints_weights())
        if scale_path is not None:
            scale = np.loadtxt(scale_path).astype(np.float32)[np.newaxis]
            scale = scale[:, 0, :] if len(scale.shape) == 3 else scale
            self.register_buffer("scale", torch.tensor(scale))

    # These are properties because every caller in this module reads them as
    # attributes (e.g. ``torch.arange(self.num_verts)``); as plain methods they
    # made ``compute_joints_weights`` (and therefore ``__init__``) fail.
    @property
    def num_verts(self) -> int:
        """Number of template mesh vertices."""
        return self.mesh_vertices.size(0)

    @property
    def num_joints(self) -> int:
        """Number of skeleton joints."""
        return self.joint_offset.size(0)

    @property
    def num_params(self) -> int:
        """Number of skinning influence slots per vertex (columns of ``skin_weights``)."""
        return self.skin_weights.shape[-1]

    def compute_rigid_transforms(self, global_pose: th.Tensor, local_pose: th.Tensor, scale: th.Tensor):
        """Returns skeleton states [B, NJ, 8] for the concatenated parameters."""
        params = torch.cat([global_pose, local_pose, scale], dim=-1)
        params = self.param_transform(params)
        return solve_skeleton_state(
            params, self.joint_offset, self.joint_rotation, self.joint_parents
        )

    def compute_rigid_transforms_matrix(self, global_pose: th.Tensor, local_pose: th.Tensor, scale: th.Tensor):
        """Returns per-joint [B, NJ, 3, 4] matrices mapping bind pose to the given pose."""
        params = torch.cat([global_pose, local_pose, scale], dim=-1)
        params = self.param_transform(params)
        states = solve_skeleton_state(
            params, self.joint_offset, self.joint_rotation, self.joint_parents
        )
        return states_to_matrix(self.bind_state, states)

    def compute_joints_weights(self, drop_empty: bool = False) -> th.Tensor:
        """Compute weights per joint given flattened weights-indices.

        Args:
            drop_empty: drop rows of joints that influence no vertex.

        Returns:
            float32 tensor [num_joints, num_verts] (fewer rows if ``drop_empty``).
        """
        device = self.skin_weights.device
        idxs_verts = torch.arange(self.num_verts, device=device)[:, None].expand(-1, self.num_params)
        weights_joints = torch.zeros(
            (self.num_joints, self.num_verts),
            dtype=torch.float32,
            device=device,
        )
        # accumulate=True: padded slots repeat joint index 0 with weight 0, and a
        # plain indexed assignment with duplicate indices has an undefined write
        # order that could overwrite a real joint-0 weight with that 0.
        weights_joints.index_put_(
            (self.skin_indices, idxs_verts),
            self.skin_weights.to(weights_joints.dtype),
            accumulate=True,
        )
        if drop_empty:
            weights_joints = weights_joints[weights_joints.sum(axis=-1).abs() > 0]
        return weights_joints

    def compute_root_rigid_transform(self, poses: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Get a transform of the root joint.

        The remaining (scale) parameters are set to zero.

        Returns:
            (translation [B, 3], rotation [B, 3, 3]) of joint index 1.
        """
        scales = torch.zeros(
            (poses.shape[0], self.param_transform.nr_total_params - poses.shape[1]),
            dtype=poses.dtype,
            device=poses.device,
        )
        params = torch.cat((poses, scales), 1)
        states = solve_skeleton_state(
            self.param_transform(params),
            self.joint_offset,
            self.joint_rotation,
            self.joint_parents,
        )
        mat = states_to_matrix(self.bind_state, states)
        return mat[:, 1, :, 3], mat[:, 1, :, :3]

    def compute_relative_rigid_transforms(self, global_pose: th.Tensor, local_pose: th.Tensor, scale: th.Tensor):
        """Returns per-joint local (parent-relative) translation + quaternion, [B, NJ, 7]."""
        params = torch.cat([global_pose, local_pose, scale], dim=-1)
        params = self.param_transform(params)
        batch_size = params.shape[0]
        jp = params.view((batch_size, -1, 7))
        lt = jp[:, :, 0:3] + self.joint_offset.unsqueeze(0)
        lr = Quaternion.batchMul(
            self.joint_rotation.unsqueeze(0), Quaternion.batchFromXYZ(jp[:, :, 3:6])
        )
        return torch.cat([lt, lr], dim=-1)

    def skinning(self, bind_state: th.Tensor, vertices: th.Tensor, target_states: th.Tensor):
        """
        Apply skinning to a set of states
        Args:
            b/bind_state: 1 x nr_joint x 8 bind state
            v/vertices: 1 x nr_vertices x 3 vertices
            t/target_states: batch_size x nr_joint x 8 current states
        Returns:
            batch_size x nr_vertices x 3 skinned vertices
        """
        assert target_states.size()[1:] == bind_state.size()[1:]
        mat = states_to_matrix(bind_state, target_states)
        # Transform each vertex (in homogeneous coordinates) by every influencing
        # joint's matrix, then blend the results with the skinning weights.
        vs = torch.matmul(
            mat[:, self.skin_indices],
            torch.cat((vertices, torch.ones_like(vertices[:, :, 0]).unsqueeze(2)), dim=2)
            .unsqueeze(2)
            .unsqueeze(4),
        )
        ws = self.skin_weights.unsqueeze(2).unsqueeze(3)
        return (vs * ws).sum(dim=2).squeeze(3)

    def unpose(self, poses: th.Tensor, scales: th.Tensor, verts: th.Tensor):
        """Invert skinning of ``verts`` for the given parameters.

        :param poses: [B, NP] pose parameters
        :param scales: [B, NS] scaling parameters
        :param verts: [B, V, 3] posed vertices
        :return: [B, V, 3] vertices in the bind pose
        """
        params = torch.cat((poses, scales), 1)
        states = solve_skeleton_state(
            self.param_transform(params),
            self.joint_offset,
            self.joint_rotation,
            self.joint_parents,
        )
        return self.unskinning(self.bind_state, states, verts)

    def unskinning(self, bind_state: th.Tensor, target_states: th.Tensor, verts: th.Tensor):
        """Invert skinning for a set of states.

        Args:
            bind_state: [B, NJ, 8] - bind state
            target_states: [B, NJ, 8] - current states
            verts: [B, V, 3] - posed vertices
        Returns:
            [B, V, 3] vertices mapped back through the inverse of each vertex's
            blended skinning matrix
        """
        assert target_states.size()[1:] == bind_state.size()[1:]
        mat = states_to_matrix(bind_state, target_states)
        ws = self.skin_weights[None, :, :, None, None]
        # Per-vertex blended 3x4 matrix, promoted to a homogeneous 4x4.
        sum_mat = (mat[:, self.skin_indices] * ws).sum(dim=2)
        sum_mat4x4 = torch.cat((sum_mat, torch.zeros_like(sum_mat[:, :, :1, :])), dim=2)
        sum_mat4x4[:, :, 3, 3] = 1.0
        verts_4d = torch.cat((verts, torch.ones_like(verts[:, :, :1])), dim=2).unsqueeze(3)
        # Batched inverse over all (batch, vertex) matrices at once.
        resmesh = torch.matmul(sum_mat4x4.inverse(), verts_4d)
        return resmesh.squeeze(3)[..., :3].contiguous()

    def forward(self, poses: th.Tensor, scales: th.Tensor, verts_unposed: Optional[th.Tensor] = None) -> th.Tensor:
        """
        Args:
            poses: [B, NP] - pose parameters
            scales: [B, NS] - additional scaling params
            verts_unposed: [B, N, 3] - unposed vertices; defaults to the template mesh
        Returns:
            [B, N, 3] - posed vertices
        """
        params = torch.cat((poses, scales), 1)
        params_transformed = self.param_transform(params)
        states = solve_skeleton_state(
            params_transformed,
            self.joint_offset,
            self.joint_rotation,
            self.joint_parents,
        )
        if verts_unposed is None:
            verts_unposed = self.mesh_vertices.unsqueeze(0)
        return self.skinning(self.bind_state, verts_unposed, states)
def solve_skeleton_state(
    param: th.Tensor,
    joint_offset: th.Tensor,
    joint_rotation: th.Tensor,
    joint_parents: th.Tensor,
):
    """Forward kinematics: turn per-joint local parameters into global joint states.

    :param param: batch_size x (7*nr_skeleton_joints) ParamTransform outputs
        (per joint: tx ty tz, rx ry rz, log2 scale).
    :param joint_offset: nr_joints x 3 rest translation added to tx ty tz.
    :param joint_rotation: nr_joints x 4 pre-rotation quaternion.
    :param joint_parents: parent index per joint, -1 for a root; a parent must
        appear before its children because states are built in index order.
    :return: batch_size x nr_skeleton_joints x 8 skeleton states
        (3 translation + 4 rotation quaternion + 1 scale).
    """
    n_batch = param.shape[0]
    per_joint = param.view((n_batch, -1, 7))

    local_t = per_joint[:, :, 0:3] + joint_offset.unsqueeze(0)
    local_r = Quaternion.batchMul(
        joint_rotation.unsqueeze(0), Quaternion.batchFromXYZ(per_joint[:, :, 3:6])
    )
    # The scale channel is stored as an exponent of two.
    two = torch.tensor([2.0], dtype=torch.float32, device=param.device)
    local_s = torch.pow(two, per_joint[:, :, 6].unsqueeze(2))

    parent_ids = [int(p) for p in joint_parents]
    world = []
    for j, p in enumerate(parent_ids):
        if p == -1:
            root = torch.cat((local_t[:, j, :], local_r[:, j, :], local_s[:, j, :]), dim=1)
            world.append(root.view((n_batch, 1, 8)))
            continue
        parent_state = world[p]
        parent_t = parent_state[:, :, 0:3]
        parent_r = parent_state[:, :, 3:7]
        parent_s = parent_state[:, :, 7].unsqueeze(2)
        # Compose with the parent: rotate, scale and offset the local transform.
        glob_r = Quaternion.batchMul(parent_r, local_r[:, j, :].unsqueeze(1))
        glob_t = Quaternion.batchRot(parent_r, local_t[:, j, :].unsqueeze(1) * parent_s) + parent_t
        glob_s = parent_s * local_s[:, j, :].unsqueeze(1)
        world.append(torch.cat((glob_t, glob_r, glob_s), dim=2))
    return torch.cat(world, dim=1)
def states_to_matrix(bind_state: th.Tensor, target_states: th.Tensor, return_transform: bool = False):
    """Build per-joint affine matrices that map the bind pose to the target states.

    Both inputs are skeleton states laid out as (translation[0:3],
    quaternion[3:7], scale[7]) along the last axis. The result composes the
    inverse bind transform with each target state.

    Returns:
        mat of shape [..., NJ, 3, 4]: columns 0-2 are the scaled rotation, column 3
        the translation. When ``return_transform`` is true, also returns the
        tuple (rotation quaternion, translation, scale) used to build ``mat``.
    """
    # Inverse of the bind transform.
    inv_bind_r = Quaternion.batchInvert(bind_state[:, :, 3:7])
    inv_bind_s = bind_state[:, :, 7].unsqueeze(2).reciprocal()
    inv_bind_t = Quaternion.batchRot(inv_bind_r, -bind_state[:, :, 0:3]) * inv_bind_s

    tgt_t = target_states[:, :, 0:3]
    tgt_r = target_states[:, :, 3:7]
    tgt_s = target_states[:, :, 7].unsqueeze(2)

    rot = Quaternion.batchMul(tgt_r, inv_bind_r)
    scl = tgt_s * inv_bind_s
    trans = Quaternion.batchRot(tgt_r, inv_bind_t * tgt_s) + tgt_t

    # Quaternion -> rotation matrix (components ordered x, y, z, w).
    qx, qy, qz, qw = rot.unbind(dim=2)
    twx = 2.0 * qx * qw
    twy = 2.0 * qy * qw
    twz = 2.0 * qz * qw
    txx = 2.0 * qx * qx
    txy = 2.0 * qy * qx
    txz = 2.0 * qz * qx
    tyy = 2.0 * qy * qy
    tyz = 2.0 * qz * qy
    tzz = 2.0 * qz * qz

    col_x = torch.stack((1.0 - (tyy + tzz), txy + twz, txz - twy), dim=2) * scl
    col_y = torch.stack((txy - twz, 1.0 - (txx + tzz), tyz + twx), dim=2) * scl
    col_z = torch.stack((txz + twy, tyz - twx, 1.0 - (txx + tyy)), dim=2) * scl
    mat = torch.stack((col_x, col_y, col_z, trans), dim=3)

    if not return_transform:
        return mat
    return mat, (rot, trans, scl)
def get_influence_map(
    transform_raw: th.Tensor,
    pose_length: Optional[int] = None,
    num_params_per_joint=7,
    eps=1.0e-6,
):
    """List, for every joint, the parameters that drive any of its channels.

    Args:
        transform_raw: [num_joints * num_params_per_joint, P] parameter transform.
        pose_length: only the first ``pose_length`` columns are considered
            (defaults to all ``P``).
        num_params_per_joint: rows per joint in ``transform_raw``.
        eps: coefficients with magnitude <= eps are ignored.

    Returns:
        One list of column indices per joint, in row-major order of the joint's
        (channel, parameter) entries; an index may repeat across channels.
    """
    total_params = transform_raw.shape[-1]
    pose_length = total_params if pose_length is None else pose_length
    assert pose_length <= total_params

    nr_joints = transform_raw.shape[0] // num_params_per_joint
    per_joint = transform_raw.reshape((nr_joints, num_params_per_joint, total_params))
    significant = torch.abs(per_joint[:, :, :pose_length]) > eps

    influence = []
    for joint_mask in significant:
        _, param_idx = torch.where(joint_mask)
        influence.append(param_idx.tolist())
    return influence
def compute_weights_joints_slow(lbs_weights, lbs_indices, num_joints):
    """Scatter per-vertex skinning weights into a dense [num_joints, num_verts] matrix.

    Loop-based reference implementation.

    Args:
        lbs_weights: [V, K] weight per vertex and influence slot (tensor or array).
        lbs_indices: [V, K] joint index per vertex and influence slot.
        num_joints: number of rows of the result.

    Returns:
        float32 tensor [num_joints, V]; entry (j, v) is the summed weight of
        joint j on vertex v. It is placed on ``lbs_weights``' device when that is
        a tensor, otherwise on the CPU.
    """
    device = lbs_weights.device if torch.is_tensor(lbs_weights) else None
    weights = torch.as_tensor(lbs_weights, dtype=torch.float32, device=device)
    indices = torch.as_tensor(lbs_indices, dtype=torch.long, device=device)
    num_verts = weights.shape[0]
    weights_joints = torch.zeros((num_joints, num_verts), dtype=torch.float32, device=device)
    for i in range(num_verts):
        idx = indices[i]
        # accumulate=True sums duplicates: padded slots that repeat a joint
        # index with weight 0 can no longer overwrite that joint's real weight.
        weights_joints.index_put_((idx, torch.full_like(idx, i)), weights[i], accumulate=True)
    return weights_joints
def load_momentum_cfg(model, lbs_config_txt_fh, nr_scaling_params=None):
    """Load a parameter configuration file.

    Recognized lines (``#`` starts a comment):

    * ``<joint>.<channel> = <coef> * <param> + <coef> * <param> ...`` adds
      coefficients to the dense parameter transform; unseen parameter names are
      appended to the parameter list in order of first appearance.
    * ``limit <name> minmax|minmax_passive [<min>, <max>] <optional weight>``
      records a limit on a joint channel (``<joint>.<channel>``) or on a model
      parameter (``<param>``). A parameter limit is only recorded if that
      parameter was already defined by an earlier line.
    * lines containing ``parameterset`` are ignored.

    Args:
        model: model dict; only ``model["Skeleton"]["Bones"][i]["Name"]`` is read.
        lbs_config_txt_fh: file-like object providing ``readlines()``.
        nr_scaling_params: number of scaling parameters; when None, the
            parameters whose name starts with "scale" are counted.

    Returns:
        dict with ``model_param_names``, ``joint_names``, ``channel_names``,
        ``limits``, ``transform`` (float32 [7 * nr_joints, nr_params]),
        ``transform_offsets`` (float32 zeros [1, 7 * nr_joints]),
        ``nr_scaling_params`` and ``nr_position_params``.
    """

    def find(seq, x):
        """Return the index of ``x`` in ``seq``, or None when absent."""
        try:
            return seq.index(x)
        except ValueError:
            return None

    channelNames = ["tx", "ty", "tz", "rx", "ry", "rz", "sc"]
    nr_channels = len(channelNames)
    paramNames = []
    joint_names = [bone["Name"] for bone in model["Skeleton"]["Bones"]]
    limits = []
    # (row = joint * nr_channels + channel, column = parameter index, coefficient)
    transform_triplets = []

    # Raw strings: "\w", "\s", ... are invalid escapes in plain string literals.
    limit_re = re.compile(r"limit ([\w.]+) (\w+) (.*)")
    # [<float> , <float>] <optional weight>
    minmax_re = re.compile(
        r"\[\s*([-+]?[0-9]*\.?[0-9]+)\s*,\s*([-+]?[0-9]*\.?[0-9]+)\s*\](\s*[-+]?[0-9]*\.?[0-9]+)?"
    )
    # <joint>.<channel> = <expression>
    definition_re = re.compile(r"(\w+)\.(\w+)\s*=\s*(.*)")
    # <coefficient> * <parameter>; whitespace around "*" is optional.
    term_re = re.compile(r"\s*([+-]?[0-9]*\.?[0-9]*)\s*\*\s*(\w+)\s*")

    for line in lbs_config_txt_fh.readlines():
        # Strip comments. Slicing with line.find("#") would drop the last
        # character of every line without a comment (find returns -1).
        line = line.split("#", 1)[0]

        if "limit" in line:
            r = limit_re.search(line)
            if r is None:
                continue
            fullname, limit_type, remaining = r.groups()
            parameterIndex = find(paramNames, fullname)
            jointName = fullname.split(".")
            jointIndex = find(joint_names, jointName[0])
            channelIndex = -1
            if jointIndex is not None and len(jointName) == 2:
                channelIndex = find(channelNames, jointName[1])
                if channelIndex is None:
                    logger.info(
                        "Unknown joint channel name %s in parameter configuration line :\n %s",
                        jointName[1],
                        line,
                    )
                    continue
            # only parse (passive) min/max limits for now
            if limit_type in ("minmax_passive", "minmax"):
                rp = minmax_re.search(remaining)
                if rp is None:
                    logger.info(f"Failed to parse passive limit configuration line :\n {line}")
                    continue
                minVal = float(rp.group(1))
                maxVal = float(rp.group(2))
                weightVal = 1.0 if rp.group(3) is None else float(rp.group(3))
                if channelIndex >= 0:
                    limits.append(
                        {
                            "type": "LimitMinMaxJointValue",
                            "str": fullname,
                            "valueIndex": jointIndex * nr_channels + channelIndex,
                            "limits": [minVal, maxVal],
                            "weight": weightVal,
                        }
                    )
                elif parameterIndex is None:
                    logger.info(f"Unknown parameterIndex : {fullname}\n {line} {paramNames} ")
                else:
                    limits.append(
                        {
                            "type": "LimitMinMaxParameter",
                            "str": fullname,
                            "parameterIndex": parameterIndex,
                            "limits": [minVal, maxVal],
                            "weight": weightVal,
                        }
                    )
            continue

        # parameterset definitions are ignored
        if "parameterset" in line:
            continue

        r = definition_re.search(line)
        if r is None:
            continue
        joint_name, channel_name, expression = r.groups()
        jointIndex = find(joint_names, joint_name)
        if jointIndex is None:
            logger.info(
                "Unknown joint name %s in parameter configuration line :\n %s", joint_name, line
            )
            continue
        channelIndex = find(channelNames, channel_name)
        if channelIndex is None:
            logger.info(
                "Unknown joint channel name %s in parameter configuration line :\n %s",
                channel_name,
                line,
            )
            continue
        valueIndex = jointIndex * nr_channels + channelIndex

        for parameterPair in expression.split("+"):
            parameterPair = parameterPair.strip()
            term = term_re.search(parameterPair)
            coefficient = None
            if term is not None:
                try:
                    coefficient = float(term.group(1))
                except ValueError:  # e.g. an empty or sign-only coefficient
                    coefficient = None
            if coefficient is None:
                logger.info(
                    "Malformed parameter description %s in parameter configuration line :\n %s",
                    parameterPair,
                    line,
                )
                continue
            parameter = term.group(2)
            parameterIndex = find(paramNames, parameter)
            if parameterIndex is None:
                # new parameter: append it to the parameter list
                parameterIndex = len(paramNames)
                paramNames.append(parameter)
            transform_triplets.append((valueIndex, parameterIndex, coefficient))

    # set (dense) parameter_transformation matrix
    transform = np.zeros((nr_channels * len(joint_names), len(paramNames)), dtype=np.float32)
    for i, j, v in transform_triplets:
        transform[i, j] = v

    if nr_scaling_params is None:
        nr_scaling_params = sum(1 for s in paramNames if s.startswith("scale"))

    return {
        "model_param_names": paramNames,
        "joint_names": joint_names,
        "channel_names": channelNames,
        "limits": limits,
        "transform": transform,
        "transform_offsets": np.zeros((1, nr_channels * len(joint_names)), dtype=np.float32),
        "nr_scaling_params": nr_scaling_params,
        "nr_position_params": len(paramNames) - nr_scaling_params,
    }
def compute_normalized_pose_quat(lbs, local_pose, scale):
    """Computes a normalized representation of the pose in quaternion space.

    Each joint's local rotation (its pre-rotation composed with the Euler
    angles produced by the parameter transform, the same formula
    ``solve_skeleton_state`` uses) is multiplied by the inverse of the bind-state
    rotation. The global pose (first 6 parameters) is set to zero.

    Returns:
        [B, NJ, 4] - normalized rotations
    """
    batch = local_pose.shape[0]
    zero_global = th.zeros((batch, 6), dtype=th.float32, device=local_pose.device)
    all_params = th.cat([zero_global, local_pose, scale], dim=-1)
    joint_params = lbs.param_transform(all_params).reshape(batch, -1, 7)

    euler_quat = Quaternion.batchFromXYZ(joint_params[:, :, 3:6])
    local_rot = Quaternion.batchMul(lbs.joint_rotation[np.newaxis], euler_quat)
    inv_bind_rot = Quaternion.batchInvert(lbs.bind_state[:, :, 3:7])
    return Quaternion.batchMul(local_rot, inv_bind_rot)
def compute_root_transform_cuda(lbs_fn, poses, verts=None):
    """Rotation and translation of joint index 1 relative to its zero-pose state.

    NOTE: should be used in conjunction with LBSCuda; ``verts`` is only passed
    through to ``lbs_fn`` and is not really necessary.
    NOTE: scales are zero (!)

    Returns:
        (R_root, t_root) computed from joint 1's state and the module's
        ``joint_state_r_zero`` / ``joint_state_t_zero``.
    """
    batch = poses.shape[0]
    _, _, _, state_t, state_r, state_s = lbs_fn(poses, vertices=verts)
    bind_r = lbs_fn.joint_state_r_zero[np.newaxis, 1].expand(batch, -1, -1)
    bind_t = lbs_fn.joint_state_t_zero[np.newaxis, 1].expand(batch, -1)

    root_r = state_r[:, 1]
    R_root = th.matmul(root_r, bind_r)
    scaled_bind_t = (bind_t * state_s[:, 1])[..., np.newaxis]
    t_root = th.matmul(root_r, scaled_bind_t)[..., 0] + state_t[:, 1]
    return R_root, t_root
# def compute_joints_weights(lbs_fn: LinearBlendSkinningCuda, drop_empty: bool = False) -> th.Tensor: | |
# device = lbs_fn.skin_indices.device | |
# idxs_verts = th.arange(lbs_fn.nr_vertices)[:, np.newaxis].to(device) | |
# weights_joints = th.zeros( | |
# (lbs_fn.nr_joints, lbs_fn.nr_vertices), | |
# dtype=th.float32, | |
# device=lbs_fn.skin_indices.device, | |
# ) | |
# weights_joints[lbs_fn.skin_indices, idxs_verts] = lbs_fn.skin_weights | |
# if drop_empty: | |
# weights_joints = weights_joints[weights_joints.sum(axis=-1).abs() > 0] | |
# return weights_joints | |
# def compute_pose_regions(lbs_fn: LinearBlendSkinningCuda) -> np.ndarray: | |
# """Computes pose regions given a linear blend skinning function. | |
# Returns: | |
# np.ndarray of boolean masks of shape [nr_params, n_rvertices] | |
# """ | |
# weights = compute_joints_weights(lbs_fn).cpu().numpy() | |
# n_pos = lbs_fn.nr_position_params | |
# param_masks = np.zeros((n_pos, weights.shape[-1])) | |
# children = {j: [] for j in range(lbs_fn.nr_joints)} | |
# parents = {j: None for j in range(lbs_fn.nr_joints)} | |
# prec = {j: [] for j in range(lbs_fn.nr_joints)} | |
# for j in range(lbs_fn.nr_joints): | |
# parent_index = int(lbs_fn.joint_parents[j]) | |
# if parent_index == -1: | |
# continue | |
# children[parent_index].append(j) | |
# parents[j] = parent_index | |
# prec[j] = [parent_index, int(lbs_fn.joint_parents[parent_index])] | |
# # get parameters for each joint | |
# # j_to_p = get_influence_map(lbs_fn.param_transform.transform, n_pos) | |
# j_to_p = get_influence_map(lbs_fn.param_transform, n_pos) | |
# # get all the joints | |
# p_to_j = [[] for i in range(n_pos)] | |
# for j, pidx in enumerate(j_to_p): | |
# for p in pidx: | |
# if j not in p_to_j[p]: | |
# p_to_j[p].append(j) | |
# for p, jidx in enumerate(p_to_j): | |
# param_masks[p] = weights[jidx].sum(axis=0) | |
# if not np.any(param_masks[p]): | |
# assert len(jidx) == 1 | |
# jidx_c = children[jidx[0]][:] | |
# for jc in jidx_c[:]: | |
# jidx_c += children[jc] | |
# param_masks[p] = weights[jidx_c].sum(axis=0) | |
# return param_masks > 0.0 | |
def compute_pose_regions_legacy(lbs_fn) -> np.ndarray:
    """Computes pose regions given a linear blend skinning function.

    For each position parameter, the mask covers the vertices skinned to any
    joint the parameter drives (per ``get_influence_map`` on the parameter
    transform). If those joints have no skin weights, the parameter must drive
    exactly one joint, and that joint's children and grandchildren are used
    instead.

    Returns:
        np.ndarray of booleans, shape [nr_position_params, num_verts].
    """
    weights = lbs_fn.joints_weights.cpu().numpy()
    # Derive counts from the weights matrix so this works whether the module
    # exposes ``num_joints`` as a method or as a property.
    nr_joints, nr_verts = weights.shape
    n_pos = lbs_fn.param_transform.nr_position_params
    param_masks = np.zeros((n_pos, nr_verts))

    children = {j: [] for j in range(nr_joints)}
    for j in range(nr_joints):
        parent_index = int(lbs_fn.joint_parents[j, 0])
        if parent_index != -1:
            children[parent_index].append(j)

    # get parameters for each joint, then invert to joints per parameter
    j_to_p = get_influence_map(lbs_fn.param_transform.transform, n_pos)
    p_to_j = [[] for _ in range(n_pos)]
    for j, pidx in enumerate(j_to_p):
        for p in pidx:
            if j not in p_to_j[p]:
                p_to_j[p].append(j)

    for p, jidx in enumerate(p_to_j):
        param_masks[p] = weights[jidx].sum(axis=0)
        if not np.any(param_masks[p]):
            assert len(jidx) == 1
            # Iterating a snapshot means only children and grandchildren are
            # added (legacy behavior, kept as-is).
            jidx_c = children[jidx[0]][:]
            for jc in jidx_c[:]:
                jidx_c += children[jc]
            param_masks[p] = weights[jidx_c].sum(axis=0)
    return param_masks > 0.0
def compute_pose_mask_uv(lbs_fn, geo_fn, uv_size, ksize=25):
    """Build a binary UV-space mask of the regions affected by the pose parameters.

    Args:
        lbs_fn: skinning module accepted by ``compute_pose_regions_legacy``.
        geo_fn: geometry helper providing ``index_image`` (for the device) and ``to_uv``.
        uv_size: side length of the output mask.
        ksize: max-pool kernel size used to dilate the regions in UV space.

    Returns:
        int32 tensor: the dilated UV regions resized to (uv_size, uv_size) and
        thresholded at 0.1.
    """
    device = geo_fn.index_image.device
    # `compute_pose_regions` exists only as commented-out code in this module,
    # so calling it raised NameError; use the legacy implementation instead.
    pose_regions = compute_pose_regions_legacy(lbs_fn)
    # Skip the first 6 parameters (presumably the global pose — TODO confirm)
    # and lay masks out as [1, V, P - 6] for geo_fn.to_uv.
    pose_regions = (
        th.as_tensor(pose_regions[6:], dtype=th.float32).permute(1, 0)[np.newaxis].to(device)
    )
    pose_regions_uv = geo_fn.to_uv(pose_regions)
    pose_regions_uv = F.max_pool2d(pose_regions_uv, ksize, 1, padding=ksize // 2)
    pose_cond_mask = (F.interpolate(pose_regions_uv, size=(uv_size, uv_size)) > 0.1).to(th.int32)
    return pose_cond_mask
def parent_chain(joint_parents, idx, depth):
    """Return up to ``depth`` ancestors of joint ``idx``, nearest first.

    The walk stops after ``depth`` ancestors, on reaching joint 0, or at a
    joint whose parent is negative (a root). The negative sentinel itself is
    never returned, and it is never used as an index. A non-positive ``depth``
    yields an empty list.

    Args:
        joint_parents: indexable of parent indices (ints or 1-element tensors).
        idx: starting joint index.
        depth: maximum number of ancestors to return.
    """
    chain = []
    current = idx
    while len(chain) < depth and current != 0:
        parent = int(joint_parents[current])
        if parent < 0:
            break
        chain.append(parent)
        current = parent
    return chain
def joint_connectivity(nr_joints, joint_parents, chain_depth=2, pad_ancestors=False):
    """Describe the joint hierarchy as children, parent and ancestor maps.

    Args:
        nr_joints: number of joints (keys 0 .. nr_joints - 1).
        joint_parents: parent index per joint, -1 for a root.
        chain_depth: number of ancestors collected per joint via ``parent_chain``.
        pad_ancestors: pad each ancestor list to ``chain_depth`` entries with the
            joint's own index.

    Returns:
        dict with 'children' (joint -> list of children), 'parents'
        (joint -> parent or None) and 'ancestors' (joint -> ancestor list).
    """
    children = {j: [] for j in range(nr_joints)}
    parents = dict.fromkeys(range(nr_joints))
    ancestors = {}
    for joint in range(nr_joints):
        parent = int(joint_parents[joint])
        chain = parent_chain(joint_parents, joint, depth=chain_depth)
        if pad_ancestors:
            chain = chain + [joint] * (chain_depth - len(chain))
        ancestors[joint] = chain
        if parent != -1:
            children[parent].append(joint)
            parents[joint] = parent
    return {
        'children': children,
        'parents': parents,
        'ancestors': ancestors,
    }
# TODO: merge this with LinearBlendSkinning? | |
class LBSModule(nn.Module):
    """Wraps ``LinearBlendSkinning`` with a template offset, fixed scales and a global scale.

    ``pose`` adds the template to the given unposed offsets, skins them with
    ``lbs_scale`` as the scaling parameters and multiplies by
    ``global_scaling``. ``unpose`` undoes these steps in reverse order.
    """

    def __init__(
        self, lbs_model_json, lbs_config_dict, lbs_template_verts, lbs_scale, global_scaling
    ):
        super().__init__()
        self.lbs_fn = LinearBlendSkinning(lbs_model_json, lbs_config_dict)
        self.register_buffer("lbs_scale", th.as_tensor(lbs_scale, dtype=th.float32))
        self.register_buffer(
            "lbs_template_verts", th.as_tensor(lbs_template_verts, dtype=th.float32)
        )
        self.register_buffer("global_scaling", th.as_tensor(global_scaling))

    def pose(self, verts_unposed, motion, template: Optional[th.Tensor] = None):
        """Skin ``verts_unposed + template`` with ``motion``; defaults to the stored template."""
        scale = self.lbs_scale.expand(motion.shape[0], -1)
        if template is None:
            template = self.lbs_template_verts
        return self.lbs_fn(motion, scale, verts_unposed + template) * self.global_scaling

    def unpose(self, verts, motion, template: Optional[th.Tensor] = None):
        """Inverse of ``pose``; pass the same ``template`` that was used for posing."""
        B = motion.shape[0]
        scale = self.lbs_scale.expand(B, -1)
        if template is None:
            template = self.lbs_template_verts
        return self.lbs_fn.unpose(motion, scale, verts / self.global_scaling) - template

    def template_pose(self, motion):
        """Skin the stored template itself with ``motion``."""
        B = motion.shape[0]
        scale = self.lbs_scale.expand(B, -1)
        verts = self.lbs_template_verts[np.newaxis].expand(B, -1, -1)
        return self.lbs_fn(motion, scale, verts) * self.global_scaling[np.newaxis]