Spaces:
Sleeping
Sleeping
""" | |
Copyright©2024 Max-Planck-Gesellschaft zur Förderung | |
der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
for Intelligent Systems. All rights reserved. | |
Author: Marilyn Keller | |
See https://skel.is.tue.mpg.de/license.html for licensing and contact information. | |
""" | |
import os | |
import torch.nn as nn | |
import torch | |
import numpy as np | |
import pickle as pkl | |
from typing import NewType, Optional | |
from lib.body_models.skel.joints_def import curve_torch_3d, left_scapula, right_scapula | |
from lib.body_models.skel.osim_rot import ConstantCurvatureJoint, CustomJoint, EllipsoidJoint, PinJoint, WalkerKnee | |
from lib.body_models.skel.utils import build_homog_matrix, rotation_matrix_from_vectors, sparce_coo_matrix2tensor, with_zeros, matmul_chain | |
from dataclasses import dataclass, fields | |
from lib.body_models.skel.kin_skel import scaling_keypoints, pose_param_names, smpl_joint_corresp | |
import lib.body_models.skel.config as cg | |
Tensor = NewType('Tensor', torch.Tensor) | |
class ModelOutput: | |
vertices: Optional[Tensor] = None | |
joints: Optional[Tensor] = None | |
full_pose: Optional[Tensor] = None | |
global_orient: Optional[Tensor] = None | |
transl: Optional[Tensor] = None | |
v_shaped: Optional[Tensor] = None | |
def __getitem__(self, key): | |
return getattr(self, key) | |
def get(self, key, default=None): | |
return getattr(self, key, default) | |
def __iter__(self): | |
return self.keys() | |
def keys(self): | |
keys = [t.name for t in fields(self)] | |
return iter(keys) | |
def values(self): | |
values = [getattr(self, t.name) for t in fields(self)] | |
return iter(values) | |
def items(self): | |
data = [(t.name, getattr(self, t.name)) for t in fields(self)] | |
return iter(data) | |
class SKELOutput(ModelOutput): | |
betas: Optional[Tensor] = None | |
body_pose: Optional[Tensor] = None | |
skin_verts: Optional[Tensor] = None | |
skel_verts: Optional[Tensor] = None | |
joints: Optional[Tensor] = None | |
joints_ori: Optional[Tensor] = None | |
betas: Optional[Tensor] = None | |
poses: Optional[Tensor] = None | |
trans : Optional[Tensor] = None | |
pose_offsets : Optional[Tensor] = None | |
joints_tpose : Optional[Tensor] = None | |
v_skin_shaped : Optional[Tensor] = None | |
class SKEL(nn.Module): | |
num_betas = 10 | |
def __init__(self, gender, model_path=None, custom_joint_reg_path=None, **kwargs): | |
super(SKEL, self).__init__() | |
if gender not in ['male', 'female']: | |
raise RuntimeError(f'Invalid Gender, got {gender}') | |
self.gender = gender | |
if model_path is None: | |
# skel_file = f"/Users/mkeller2/Data/skel_models_v1.0/skel_{gender}.pkl" | |
skel_file = os.path.join(cg.skel_folder, f"skel_{gender}.pkl") | |
else: | |
skel_file = os.path.join(model_path, f"skel_{gender}.pkl") | |
assert os.path.exists(skel_file), f"Skel model file {skel_file} does not exist" | |
skel_data = pkl.load(open(skel_file, 'rb')) | |
# Check that the version of the skel model is compatible with this loader | |
assert 'version' in skel_data, f"Expected version 1.1.1 of the SKEL picke. Please download the latest skel pkl versions from https://skel.is.tue.mpg.de/download.html" | |
version = skel_data['version'] | |
assert version == '1.1.1', f"Expected version 1.1.1, got {version}. Please download the latest skel pkl versions from https://skel.is.tue.mpg.de/download.html" | |
self.num_betas = 10 | |
self.num_q_params = 46 | |
self.bone_names = skel_data['bone_names'] | |
self.num_joints = skel_data['J_regressor_osim'].shape[0] | |
self.num_joints_smpl = skel_data['J_regressor'].shape[0] | |
self.joints_name = skel_data['joints_name'] | |
self.pose_params_name = skel_data['pose_params_name'] | |
# register the template meshes | |
self.register_buffer('skin_template_v', torch.FloatTensor(skel_data['skin_template_v'])) | |
self.register_buffer('skin_f', torch.LongTensor(skel_data['skin_template_f'])) | |
self.register_buffer('skel_template_v', torch.FloatTensor(skel_data['skel_template_v'])) | |
self.register_buffer('skel_f', torch.LongTensor(skel_data['skel_template_f'])) | |
# Shape corrective blend shapes | |
self.register_buffer('shapedirs', torch.FloatTensor(np.array(skel_data['shapedirs'][:,:,:self.num_betas]))) | |
self.register_buffer('posedirs', torch.FloatTensor(np.array(skel_data['posedirs']))) | |
# Model sparse joints regressor, regresses joints location from a mesh | |
self.register_buffer('J_regressor', sparce_coo_matrix2tensor(skel_data['J_regressor'])) | |
# Regress the anatomical joint location with a regressor learned from BioAmass | |
if custom_joint_reg_path is not None: | |
J_regressor_skel = pkl.load(open(custom_joint_reg_path, 'rb')) | |
if 'scipy.sparse' in str(type(J_regressor_skel)): | |
J_regressor_skel = J_regressor_skel.todense() | |
self.register_buffer('J_regressor_osim', torch.FloatTensor(J_regressor_skel)) | |
print('WARNING: Using custom joint regressor') | |
else: | |
self.register_buffer('J_regressor_osim', sparce_coo_matrix2tensor(skel_data['J_regressor_osim'], make_dense=True)) | |
self.register_buffer('per_joint_rot', torch.FloatTensor(skel_data['per_joint_rot'])) | |
# Skin model skinning weights | |
self.register_buffer('skin_weights', sparce_coo_matrix2tensor(skel_data['skin_weights'])) | |
# Skeleton model skinning weights | |
self.register_buffer('skel_weights', sparce_coo_matrix2tensor(skel_data['skel_weights'])) | |
self.register_buffer('skel_weights_rigid', sparce_coo_matrix2tensor(skel_data['skel_weights_rigid'])) | |
# Kinematic tree of the model | |
self.register_buffer('kintree_table', torch.from_numpy(skel_data['osim_kintree_table'].astype(np.int64))) | |
self.register_buffer('parameter_mapping', torch.from_numpy(skel_data['parameter_mapping'].astype(np.int64))) | |
# transformation from osim can pose to T pose | |
self.register_buffer('tpose_transfo', torch.FloatTensor(skel_data['tpose_transfo'])) | |
# transformation from osim can pose to A pose | |
self.register_buffer('apose_transfo', torch.FloatTensor(skel_data['apose_transfo'])) | |
self.register_buffer('apose_rel_transfo', torch.FloatTensor(skel_data['apose_rel_transfo'])) | |
# Indices of bones which orientation should not vary with beta in T pose: | |
joint_idx_fixed_beta = [0, 5, 10, 13, 18, 23] | |
self.register_buffer('joint_idx_fixed_beta', torch.IntTensor(joint_idx_fixed_beta)) | |
id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])} | |
self.register_buffer('parent', torch.LongTensor( | |
[id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])])) | |
# child array | |
# TODO create this array in the SKEL creator | |
child_array = [] | |
Nj = self.num_joints | |
for i in range(0, Nj): | |
try: | |
j_array = torch.where(self.kintree_table[0] == i)[0] # candidate child lines | |
if len(j_array) == 0: | |
child_index = 0 | |
else: | |
j = j_array[0] | |
if j>=len(self.kintree_table[1]): | |
child_index = 0 | |
else: | |
child_index = self.kintree_table[1,j].item() | |
child_array.append(child_index) | |
except: | |
import ipdb; ipdb.set_trace() | |
# print(f"child_array: ") | |
# [print(i,child_array[i]) for i in range(0, Nj)] | |
self.register_buffer('child', torch.LongTensor(child_array)) | |
# Instantiate joints | |
self.joints_dict = nn.ModuleList([ | |
CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 0 pelvis | |
CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 1 femur_r | |
WalkerKnee(), # 2 tibia_r | |
PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 3 talus_r Field taken from .osim Joint-> frames -> PhysicalOffsetFrame -> orientation | |
PinJoint(parent_frame_ori = [-1.76818999, 0.906223, 1.8196000]), # 4 calcn_r | |
PinJoint(parent_frame_ori = [-3.141589999, 0.6199010, 0]), # 5 toes_r | |
CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, -1, -1]), # 6 femur_l | |
WalkerKnee(), # 7 tibia_l | |
PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 8 talus_l | |
PinJoint(parent_frame_ori = [1.768189999 ,-0.906223, 1.8196000]), # 9 calcn_l | |
PinJoint(parent_frame_ori = [-3.141589999, -0.6199010, 0]), # 10 toes_l | |
ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 11 lumbar | |
ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 12 thorax | |
ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 13 head | |
EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, -1, -1]), # 14 scapula_r | |
CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 15 humerus_r | |
CustomJoint(axis=[[0.0494, 0.0366, 0.99810825]], axis_flip=[[1]]), # 16 ulna_r | |
CustomJoint(axis=[[-0.01716099, 0.99266564, -0.11966796]], axis_flip=[[1]]), # 17 radius_r | |
CustomJoint(axis=[[1,0,0], [0,0,-1]], axis_flip=[1, 1]), # 18 hand_r | |
EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, 1, 1]), # 19 scapula_l | |
CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 20 humerus_l | |
CustomJoint(axis=[[-0.0494, -0.0366, 0.99810825]], axis_flip=[[1]]), # 21 ulna_l | |
CustomJoint(axis=[[0.01716099, -0.99266564, -0.11966796]], axis_flip=[[1]]), # 22 radius_l | |
CustomJoint(axis=[[-1,0,0], [0,0,-1]], axis_flip=[1, 1]), # 23 hand_l | |
]) | |
def pose_params_to_rot(self, osim_poses): | |
""" Transform the pose parameters to 3x3 rotation matrices | |
Each parameter is mapped to a joint as described in joint_dict. | |
The specific joint object is then used to compute the rotation matrix. | |
""" | |
B = osim_poses.shape[0] | |
Nj = self.num_joints | |
ident = torch.eye(3, dtype=osim_poses.dtype).to(osim_poses.device) | |
Rp = ident.unsqueeze(0).unsqueeze(0).repeat(B, Nj,1,1) | |
tp = torch.zeros(B, Nj, 3).to(osim_poses.device) | |
start_index = 0 | |
for i in range(0, Nj): | |
joint_object = self.joints_dict[i] | |
end_index = start_index + joint_object.nb_dof | |
Rp[:, i] = joint_object.q_to_rot(osim_poses[:, start_index:end_index]) | |
start_index = end_index | |
return Rp, tp | |
def params_name_to_index(self, param_name): | |
assert param_name in pose_param_names | |
param_index = pose_param_names.index(param_name) | |
return param_index | |
def forward(self, poses, betas, trans, poses_type='skel', skelmesh=True, dJ=None, pose_dep_bs=True): | |
""" | |
params | |
poses : B x 46 tensor of pose parameters | |
betas : B x 10 tensor of shape parameters, same as SMPL | |
trans : B x 3 tensor of translation | |
poses_type : str, 'skel', should not be changed | |
skelmesh : bool, if True, returns the skeleton vertices. The skeleton mesh is heavy so to fit on GPU memory, set to False when not needed. | |
dJ : B x 24 x 3 tensor of the offset of the joints location from the anatomical regressor. If None, the offset is set to 0. | |
pose_dep_bs : bool, if True (default), applies the pose dependant blend shapes. If False, the pose dependant blend shapes are not applied. | |
return SKELOutput class with the following fields: | |
betas : Bx10 tensor of shape parameters | |
poses : Bx46 tensor of pose parameters | |
skin_verts : Bx6890x3 tensor of skin vertices | |
skel_verts : tensor of skeleton vertices | |
joints : Bx24x3 tensor of joints location | |
joints_ori : Bx24x3x3 tensor of joints orientation | |
trans : Bx3 pose dependant blend shapes offsets | |
pose_offsets : Bx6080x3 pose dependant blend shapes offsets | |
joints_tpose : Bx24x3 3D joints location in T pose | |
In this function we use the following conventions: | |
B : batch size | |
Ns : skin vertices | |
Nk : skeleton vertices | |
""" | |
Ns = self.skin_template_v.shape[0] # nb skin vertices | |
Nk = self.skel_template_v.shape[0] # nb skeleton vertices | |
Nj = self.num_joints | |
B = poses.shape[0] | |
device = poses.device | |
# Check the shapes of the inputs | |
assert len(betas.shape) == 2, f"Betas should be of shape (B, {self.num_betas}), but got {betas.shape}" | |
assert poses.shape[0] == betas.shape[0], f"Expected poses and betas to have the same batch size, but got {poses.shape[0]} and {betas.shape[0]}" | |
assert poses.shape[0] == trans.shape[0], f"Expected poses and betas to have the same batch size, but got {poses.shape[0]} and {trans.shape[0]}" | |
if dJ is not None: | |
assert len(dJ.shape) == 3, f"Expected dJ to have shape (B, {Nj}, 3), but got {dJ.shape}" | |
assert dJ is None or dJ.shape[0] == B, f"Expected dJ to have the same batch size as poses, but got {dJ.shape[0]} and {poses.shape[0]}" | |
assert dJ.shape[1] == Nj, f"Expected dJ to have the same number of joints as the model, but got {dJ.shape[1]} and {Nj}" | |
# Check the device of the inputs | |
assert betas.device == device, f"Betas should be on device {device}, but got {betas.device}" | |
assert trans.device == device, f"Trans should be on device {device}, but got {trans.device}" | |
skin_v0 = self.skin_template_v[None, :] | |
skel_v0 = self.skel_template_v[None, :] | |
betas = betas[:, :, None] # TODO Name the expanded beta differently | |
# TODO clean this part | |
assert poses_type in ['skel', 'bsm'], f"got {poses_type}" | |
if poses_type == 'bsm': | |
assert poses.shape[1] == self.num_q_params - 3, f'With poses_type bsm, expected parameters of shape (B, {self.num_q_params - 3}, got {poses.shape}' | |
poses_bsm = poses | |
poses_skel = torch.zeros(B, self.num_q_params) | |
poses_skel[:,:3] = poses_bsm[:, :3] | |
trans = poses_bsm[:, 3:6] # In BSM parametrization, the hips translation is given by params 3 to 5 | |
poses_skel[:, 3:] = poses_bsm | |
poses = poses_skel | |
else: | |
assert poses.shape[1] == self.num_q_params, f'With poses_type skel, expected parameters of shape (B, {self.num_q_params}), got {poses.shape}' | |
pass | |
# Load poses as expected | |
# Distinction bsm skel. by default it will be bsm | |
# ------- Shape ---------- | |
# Apply the beta offset to the template | |
shapedirs = self.shapedirs.view(-1, self.num_betas)[None, :].expand(B, -1, -1) # B x D*Ns x num_betas | |
v_shaped = skin_v0 + torch.matmul(shapedirs, betas).view(B, Ns, 3) | |
# ------- Joints ---------- | |
# Regress the anatomical joint location | |
J = torch.einsum('bik,ji->bjk', [v_shaped, self.J_regressor_osim]) # BxJx3 # osim regressor | |
# J = self.apose_transfo[:, :3, -1].view(1, Nj, 3).expand(B, -1, -1) # Osim default pose joints location | |
if dJ is not None: | |
J = J + dJ | |
J_tpose = J.clone() | |
# Local translation | |
J_ = J.clone() # BxJx3 | |
J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :] | |
t = J_[:, :, :, None] # BxJx3x1 | |
# ------- Bones transformation matrix---------- | |
# Bone initial transform to go from unposed to SMPL T pose | |
Rk01 = self.compute_bone_orientation(J, J_) | |
# BSM default pose rotations | |
Ra = self.apose_rel_transfo[:, :3, :3].view(1, Nj, 3,3).expand(B, Nj, 3, 3) | |
# Local bone rotation given by the pose param | |
Rp, tp = self.pose_params_to_rot(poses) # BxNjx3x3 pose params to rotation | |
R = matmul_chain([Rk01, Ra.transpose(2,3), Rp, Ra, Rk01.transpose(2,3)]) | |
###### Compute translation for non pure rotation joints | |
t_posed = t.clone() | |
# Scapula | |
thorax_width = torch.norm(J[:, 19, :] - J[:, 14, :], dim=1) # Distance between the two scapula joints, size B | |
thorax_height = torch.norm(J[:, 12, :] - J[:, 11, :], dim=1) # Distance between the two scapula joints, size B | |
angle_abduction = poses[:,26] | |
angle_elevation = poses[:,27] | |
angle_rot = poses[:,28] | |
angle_zero = torch.zeros_like(angle_abduction) | |
t_posed[:,14] = t_posed[:,14] + \ | |
(right_scapula(angle_abduction, angle_elevation, angle_rot, thorax_width, thorax_height).view(-1,3,1) | |
- right_scapula(angle_zero, angle_zero, angle_zero, thorax_width, thorax_height).view(-1,3,1)) | |
angle_abduction = poses[:,36] | |
angle_elevation = poses[:,37] | |
angle_rot = poses[:,38] | |
angle_zero = torch.zeros_like(angle_abduction) | |
t_posed[:,19] = t_posed[:,19] + \ | |
(left_scapula(angle_abduction, angle_elevation, angle_rot, thorax_width, thorax_height).view(-1,3,1) | |
- left_scapula(angle_zero, angle_zero, angle_zero, thorax_width, thorax_height).view(-1,3,1)) | |
# Knee_r | |
# TODO add the Walker knee offset | |
# bone_scale = self.compute_bone_scale(J_,J, skin_v0, v_shaped) | |
# f1 = poses[:, 2*3+2].clone() | |
# scale_femur = bone_scale[:, 2] | |
# factor = 0.076/0.080 * scale_femur # The template femur medial laterak spacing #66 | |
# f = -f1*180/torch.pi #knee_flexion | |
# varus = (0.12367*f)-0.0009*f**2 | |
# introt = 0.3781*f-0.001781*f**2 | |
# ydis = (-0.0683*f | |
# + 8.804e-4 * f**2 | |
# - 3.750e-06*f**3 | |
# )/1000*factor # up-down | |
# zdis = (-0.1283*f | |
# + 4.796e-4 * f**2)/1000*factor # | |
# import ipdb; ipdb.set_trace() | |
# poses[:, 9] = poses[:, 9] + varus | |
# t_posed[:,2] = t_posed[:,2] + torch.stack([torch.zeros_like(ydis), ydis, zdis], dim=1).view(-1,3,1) | |
# poses[:, 2*3+2]=0 | |
# t_unposed = torch.zeros_like(t_posed) | |
# t_unposed[:,2] = torch.stack([torch.zeros_like(ydis), ydis, zdis], dim=1).view(-1,3,1) | |
# Spine | |
lumbar_bending = poses[:,17] | |
lumbar_extension = poses[:,18] | |
angle_zero = torch.zeros_like(lumbar_bending) | |
interp_t = torch.ones_like(lumbar_bending) | |
l = torch.abs(J[:, 11, 1] - J[:, 0, 1]) # Length of the spine section along y axis | |
t_posed[:,11] = t_posed[:,11] + \ | |
(curve_torch_3d(lumbar_bending, lumbar_extension, t=interp_t, l=l) | |
- curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l)) | |
thorax_bending = poses[:,20] | |
thorax_extension = poses[:,21] | |
angle_zero = torch.zeros_like(thorax_bending) | |
interp_t = torch.ones_like(thorax_bending) | |
l = torch.abs(J[:, 12, 1] - J[:, 11, 1]) # Length of the spine section | |
t_posed[:,12] = t_posed[:,12] + \ | |
(curve_torch_3d(thorax_bending, thorax_extension, t=interp_t, l=l) | |
- curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l)) | |
head_bending = poses[:, 23] | |
head_extension = poses[:,24] | |
angle_zero = torch.zeros_like(head_bending) | |
interp_t = torch.ones_like(head_bending) | |
l = torch.abs(J[:, 13, 1] - J[:, 12, 1]) # Length of the spine section | |
t_posed[:,13] = t_posed[:,13] + \ | |
(curve_torch_3d(head_bending, head_extension, t=interp_t, l=l) | |
- curve_torch_3d(angle_zero, angle_zero, t=interp_t, l=l)) | |
# ------- Body surface transformation matrix---------- | |
G_ = torch.cat([R, t_posed], dim=-1) # BxJx3x4 local transformation matrix | |
pad_row = torch.FloatTensor([0, 0, 0, 1]).to(device).view(1, 1, 1, 4).expand(B, Nj, -1, -1) # BxJx1x4 | |
G_ = torch.cat([G_, pad_row], dim=2) # BxJx4x4 padded to be 4x4 matrix an enable multiplication for the kinematic chain | |
# Global transform | |
G = [G_[:, 0].clone()] | |
for i in range(1, Nj): | |
G.append(torch.matmul(G[self.parent[i - 1]], G_[:, i, :, :])) | |
G = torch.stack(G, dim=1) | |
# ------- Pose dependant blend shapes ---------- | |
if pose_dep_bs is False: | |
v_shaped_pd = v_shaped | |
else: | |
# Note : Those should be retrained for SKEL as the SKEL joints location are different from SMPL. | |
# But the current version lets use get decent pose dependant deformations for the shoulders, belly and knies | |
ident = torch.eye(3, dtype=v_shaped.dtype, device=device) | |
# We need the per SMPL joint bone transform to compute pose dependant blend shapes. | |
# Initialize each joint rotation with identity | |
Rsmpl = ident.unsqueeze(0).unsqueeze(0).expand(B, self.num_joints_smpl, -1, -1).clone() # BxNjx3x3 | |
Rskin = G_[:, :, :3, :3] # BxNjx3x3 | |
Rsmpl[:, smpl_joint_corresp] = Rskin[:] # BxNjx3x3 pose params to rotation | |
pose_feature = Rsmpl[:, 1:].view(B, -1, 3, 3) - ident | |
pose_offsets = torch.matmul(pose_feature.view(B, -1), | |
self.posedirs.view(Ns*3, -1).T).view(B, -1, 3) | |
v_shaped_pd = v_shaped + pose_offsets | |
########################################################################################## | |
#Transform skin mesh | |
############################################################################################ | |
# Apply global transformation to the template mesh | |
rest = torch.cat([J, torch.zeros(B, Nj, 1).to(device)], dim=2).view(B, Nj, 4, 1) # BxJx4x1 | |
zeros = torch.zeros(B, Nj, 4, 3).to(device) # BxJx4x3 | |
rest = torch.cat([zeros, rest], dim=-1) # BxJx4x4 | |
rest = torch.matmul(G, rest) # This is a 4x4 transformation matrix that only contains translation to the rest pose joint location | |
Gskin = G - rest | |
# Compute per vertex transformation matrix (after weighting) | |
T = torch.matmul(self.skin_weights, Gskin.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Ns, B, 4,4).transpose(0, 1) | |
rest_shape_h = torch.cat([v_shaped_pd, torch.ones_like(v_shaped_pd)[:, :, [0]]], dim=-1) | |
v_posed = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0] | |
# translation | |
v_trans = v_posed + trans[:,None,:] | |
########################################################################################## | |
#Transform joints | |
############################################################################################ | |
# import ipdb; ipdb.set_trace() | |
root_transform = with_zeros(torch.cat((R[:,0],J[:,0][:,:,None]),2)) | |
results = [root_transform] | |
for i in range(0, self.parent.shape[0]): | |
transform_i = with_zeros(torch.cat((R[:, i + 1], t_posed[:,i+1]), 2)) | |
curr_res = torch.matmul(results[self.parent[i]],transform_i) | |
results.append(curr_res) | |
results = torch.stack(results, dim=1) | |
posed_joints = results[:, :, :3, 3] | |
J_transformed = posed_joints + trans[:,None,:] | |
########################################################################################## | |
# Transform skeleton | |
############################################################################################ | |
if skelmesh: | |
G_bones = None | |
# Shape the skeleton by scaling its bones | |
skel_rest_shape_h = torch.cat([skel_v0, torch.ones_like(skel_v0)[:, :, [0]]], dim=-1).expand(B, Nk, -1) # (1,Nk,3) | |
# compute the bones scaling from the kinematic tree and skin mesh | |
#with torch.no_grad(): | |
# TODO: when dJ is optimized the shape of the mesh should be affected by the gradients | |
bone_scale = self.compute_bone_scale(J_, v_shaped, skin_v0) | |
# Apply bone meshes scaling: | |
skel_v_shaped = torch.cat([(torch.matmul(bone_scale[:,:,0], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 0])[:, :, None], | |
(torch.matmul(bone_scale[:,:,1], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 1])[:, :, None], | |
(torch.matmul(bone_scale[:,:,2], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, 2])[:, :, None], | |
(torch.ones(B, Nk, 1).to(device)) | |
], dim=-1) | |
# Align the bones with the proper axis | |
Gk01 = build_homog_matrix(Rk01, J.unsqueeze(-1)) # BxJx4x4 | |
T = torch.matmul(self.skel_weights_rigid, Gk01.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Nk, B, 4,4).transpose(0, 1) #[1, 48757, 3, 3] | |
skel_v_align = torch.matmul(T, skel_v_shaped[:, :, :, None])[:, :, :, 0] | |
# This transfo will be applied with weights, effectively unposing the whole skeleton mesh in each joint frame. | |
# Then, per joint weighted transformation can then be applied | |
G_tpose_to_unposed = build_homog_matrix(torch.eye(3).view(1,1,3,3).expand(B, Nj, 3, 3).to(device), -J.unsqueeze(-1)) # BxJx4x4 | |
G_skel = torch.matmul(G, G_tpose_to_unposed) | |
G_bones = torch.matmul(G, Gk01) | |
T = torch.matmul(self.skel_weights, G_skel.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Nk, B, 4,4).transpose(0, 1) | |
skel_v_posed = torch.matmul(T, skel_v_align[:, :, :, None])[:, :, :3, 0] | |
skel_trans = skel_v_posed + trans[:,None,:] | |
else: | |
skel_trans = skel_v0 | |
Gk01 = build_homog_matrix(Rk01, J.unsqueeze(-1)) # BxJx4x4 | |
G_bones = torch.matmul(G, Gk01) | |
joints = J_transformed | |
skin_verts = v_trans | |
skel_verts = skel_trans | |
joints_ori = G_bones[:,:,:3,:3] | |
if skin_verts.max() > 1e3: | |
import ipdb; ipdb.set_trace() | |
output = SKELOutput(skin_verts=skin_verts, | |
skel_verts=skel_verts, | |
joints=joints, | |
joints_ori=joints_ori, | |
betas=betas, | |
poses=poses, | |
trans = trans, | |
pose_offsets = pose_offsets, | |
joints_tpose = J_tpose, | |
v_shaped = v_shaped,) | |
return output | |
def compute_bone_scale(self, J_, v_shaped, skin_v0): | |
# index [0, 1, 2, 3 4, 5, , ...] # todo add last one, figure out bone scale indices | |
# J_ bone vectors [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...] | |
# norm(J) = length of the bone [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...] | |
# self.joints_sockets [j0, j1-j0, j2-j0, j3-j0, j4-j1, j5-j2, ...] | |
# self.skel_weights [j0, j1, j2, j3, j4, j5, ...] | |
B = J_.shape[0] | |
Nj = J_.shape[1] | |
bone_scale = torch.ones(B, Nj).to(J_.device) | |
# BSM template joints location | |
osim_joints_r = self.apose_rel_transfo[:, :3, 3].view(1, Nj, 3).expand(B, Nj, 3).clone() | |
length_bones_bsm = torch.norm(osim_joints_r, dim=-1).expand(B, -1) | |
length_bones_smpl = torch.norm(J_, dim=-1) # (B, Nj) | |
bone_scale_parent = length_bones_smpl / length_bones_bsm | |
non_leaf_node = (self.child != 0) | |
bone_scale[:,non_leaf_node] = (bone_scale_parent[:,self.child])[:,non_leaf_node] | |
# Ulna should have the same scale as radius | |
bone_scale[:,16] = bone_scale[:,17] | |
bone_scale[:,16] = bone_scale[:,17] | |
bone_scale[:,21] = bone_scale[:,22] | |
bone_scale[:,21] = bone_scale[:,22] | |
# Thorax | |
# Thorax scale is defined by the relative position of the thorax to its child joint, not parent joint as for other bones | |
bone_scale[:, 12] = bone_scale[:, 11] | |
# Lumbars | |
# Lumbar scale is defined by the y relative position of the lumbar joint | |
length_bones_bsm = torch.abs(osim_joints_r[:,11, 1]) | |
length_bones_smpl = torch.abs(J_[:, 11, 1]) # (B, Nj) | |
bone_scale_lumbar = length_bones_smpl / length_bones_bsm | |
bone_scale[:, 11] = bone_scale_lumbar | |
# Expand to 3 dimensions and adjest scaling to avoid skin-skeleton intersection and handle the scaling of leaf body parts (hands, feet) | |
bone_scale = bone_scale.reshape(B, Nj, 1).expand(B, Nj, 3).clone() | |
for (ji, doi, dsi), (v1, v2) in scaling_keypoints.items(): | |
bone_scale[:, ji, doi] = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,dsi] # Top over chin | |
#TODO: Add keypoints for feet scaling in scaling_keypoints | |
# Adjust thorax front-back scaling | |
# TODO fix this part | |
v1 = 3027 #thorax back | |
v2 = 3495 #thorax front | |
scale_thorax_up = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,2] # good for large people | |
v2 = 3506 #sternum | |
scale_thorax_sternum = ((v_shaped[:,v1] - v_shaped[:, v2])/ (skin_v0[:,v1] - skin_v0[:, v2]))[:,2] # Good for skinny people | |
bone_scale[:, 12, 0] = torch.min(scale_thorax_up, scale_thorax_sternum) # Avoids super expanded ribcage for large people and sternum outside for skinny people | |
#lumbars, adjust width to be same as thorax | |
bone_scale[:, 11, 0] = bone_scale[:, 12, 0] | |
return bone_scale | |
def compute_bone_orientation(self, J, J_): | |
"""Compute each bone orientation in T pose """ | |
# method = 'unposed' | |
# method = 'learned' | |
method = 'learn_adjust' | |
B = J_.shape[0] | |
Nj = J_.shape[1] | |
# Create an array of bone vectors the bone meshes should be aligned to. | |
bone_vect = torch.zeros_like(J_) # / torch.norm(J_, dim=-1)[:, :, None] # (B, Nj, 3) | |
bone_vect[:] = J_[:, self.child] # Most bones are aligned between their parent and child joint | |
bone_vect[:,16] = bone_vect[:,16]+bone_vect[:,17] # We want to align the ulna to the segment joint 16 to 18 | |
bone_vect[:,21] = bone_vect[:,21]+bone_vect[:,22] # Same other ulna | |
# TODO Check indices here | |
# bone_vect[:,13] = bone_vect[:,12].clone() | |
bone_vect[:,12] = bone_vect.clone()[:,11].clone() # We want to align the thorax on the thorax-lumbar segment | |
# bone_vect[:,11] = bone_vect[:,0].clone() | |
osim_vect = self.apose_rel_transfo[:, :3, 3].clone().view(1, Nj, 3).expand(B, Nj, 3).clone() | |
osim_vect[:] = osim_vect[:,self.child] | |
osim_vect[:,16] = osim_vect[:,16]+osim_vect[:,17] # We want to align the ulna to the segment joint 16 to 18 | |
osim_vect[:,21] = osim_vect[:,21]+osim_vect[:,22] # We want to align the ulna to the segment joint 16 to 18 | |
# TODO: remove when this has been checked | |
# import matplotlib.pyplot as plt | |
# fig = plt.figure() | |
# ax = fig.add_subplot(111, projection='3d') | |
# ax.plot(osim_vect[:,0,0], osim_vect[:,0,1], osim_vect[:,0,2], color='r') | |
# plt.show() | |
Gk = torch.eye(3, device=J_.device).repeat(B, Nj, 1, 1) | |
if method == 'unposed': | |
return Gk | |
elif method == 'learn_adjust': | |
Gk_learned = self.per_joint_rot.view(1, Nj, 3, 3).expand(B, -1, -1, -1) #load learned rotation | |
osim_vect_corr = torch.matmul(Gk_learned, osim_vect.unsqueeze(-1)).squeeze(-1) | |
Gk[:,:] = rotation_matrix_from_vectors(osim_vect_corr, bone_vect) | |
# set nan to zero | |
# TODO: Check again why the following line was required | |
Gk[torch.isnan(Gk)] = 0 | |
# Gk[:,[18,23]] = Gk[:,[16,21]] # hand has same orientation as ulna | |
# Gk[:,[5,10]] = Gk[:,[4,9]] # toe has same orientation as calcaneus | |
# Gk[:,[0,11,12,13,14,19]] = torch.eye(3, device=J_.device).view(1,3,3).expand(B, 6, 3, 3) # pelvis, torso and shoulder blade orientation does not vary with beta, leave it | |
Gk[:, self.joint_idx_fixed_beta] = torch.eye(3, device=J_.device).view(1,3,3).expand(B, len(self.joint_idx_fixed_beta), 3, 3) # pelvis, torso and shoulder blade orientation should not vary with beta, leave it | |
Gk = torch.matmul(Gk, Gk_learned) | |
elif method == 'learned': | |
""" Apply learned transformation""" | |
Gk = self.per_joint_rot.view(1, Nj, 3, 3).expand(B, -1, -1, -1) | |
else: | |
raise NotImplementedError | |
return Gk | |