Spaces:
Running
on
L4
Running
on
L4
""" | |
Copyright©2023 Max-Planck-Gesellschaft zur Förderung | |
der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
for Intelligent Systems. All rights reserved. | |
Author: Soyong Shin, Marilyn Keller | |
See https://skel.is.tue.mpg.de/license.html for licensing and contact information. | |
""" | |
import traceback | |
import math | |
import os | |
import pickle | |
import torch | |
import smplx | |
import omegaconf | |
import torch.nn.functional as F | |
from psbody.mesh import Mesh, MeshViewer, MeshViewers | |
from tqdm import trange | |
from pathlib import Path | |
import lib.body_models.skel.config as cg | |
from lib.body_models.skel.skel_model import SKEL | |
from .losses import compute_anchor_pose, compute_anchor_trans, compute_pose_loss, compute_scapula_loss, compute_spine_loss, compute_time_loss, pretty_loss_print | |
from .utils import location_to_spheres, to_numpy, to_params, to_torch | |
from .align_config import config | |
from .align_config_joint import config as config_joint | |
class SkelFitter(object): | |
def __init__(self, gender, device, num_betas=10, export_meshes=False, joint_optim=False) -> None:
    """Build the SMPL and SKEL models, the vertex masks and the fitting configuration.

    Args:
        gender: Gender identifier forwarded to both ``smplx.create`` and ``SKEL``.
        device: Torch device on which the models and the masks are placed.
        num_betas: Number of shape coefficients requested from the SMPL model.
        export_meshes: When True, ``run_fit`` additionally returns the fitted
            mesh vertices and faces.
        joint_optim: When True, use the configuration from ``align_config_joint``
            instead of the one from ``align_config``.
    """
    self.smpl = smplx.create(
        cg.smpl_folder,
        model_type='smpl',
        gender=gender,
        num_betas=num_betas,
        batch_size=1,
        export_meshes=False,
    ).to(device)
    self.skel = SKEL(gender).to(device)
    self.gender = gender
    self.device = device
    self.num_betas = num_betas

    # Instantiate the masks used for the vertex-to-vertex fitting.
    # NOTE: the mask is unpickled from a file shipped next to this module; pickle
    # must never be pointed at an untrusted file.
    fitting_mask_file = Path(__file__).parent / 'riggid_parts_mask.pkl'
    with open(fitting_mask_file, 'rb') as mask_file:  # close the handle deterministically
        fitting_indices = pickle.load(mask_file)

    # 6890 entries, presumably one per SMPL vertex -- TODO confirm against the model
    fitting_mask = torch.zeros(6890, dtype=torch.bool, device=self.device)
    fitting_mask[fitting_indices] = 1
    # 1xVx1 so that it broadcasts against vertices of shape BxVx3
    self.fitting_mask = fitting_mask.reshape(1, -1, 1).to(self.device)

    # Vertices whose skinning weight for joint 0 or joint 3 exceeds 0.5
    smpl_torso_joints = [0, 3]
    verts_mask = (self.smpl.lbs_weights[:, smpl_torso_joints] > 0.5).sum(dim=-1) > 0
    self.torso_verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1)  # 1xVx1, same reason as above

    self.export_meshes = export_meshes

    # Turn the plain config dictionary into an attribute-accessible omegaconf object
    if joint_optim:
        self.cfg = omegaconf.OmegaConf.create(config_joint)
    else:
        self.cfg = omegaconf.OmegaConf.create(config)

    # Instantiate the mesh viewer used to visualize the fitting
    if 'DISABLE_VIEWER' in os.environ:
        self.mv = None
        print("\n DISABLE_VIEWER flag is set, running in headless mode")
    else:
        self.mv = MeshViewers((1, 2), keepalive=self.cfg.keepalive_meshviewer)
def run_fit(self, | |
trans_in, | |
betas_in, | |
poses_in, | |
batch_size=20, | |
skel_data_init=None, | |
force_recompute=False, | |
debug=False, | |
watch_frame=0, | |
freevert_mesh=None, | |
opt_sequence=False, | |
fix_poses=False, | |
variant_exp=''): | |
"""Align SKEL to a SMPL sequence.""" | |
self.nb_frames = poses_in.shape[0] | |
self.watch_frame = watch_frame | |
self.is_skel_data_init = skel_data_init is not None | |
self.force_recompute = force_recompute | |
print('Fitting {} frames'.format(self.nb_frames)) | |
print('Watching frame: {}'.format(watch_frame)) | |
# Initialize SKEL torch params | |
body_params = self._init_params(betas_in, poses_in, trans_in, skel_data_init, variant_exp) | |
# We cut the whole sequence in batches for parallel optimization | |
if batch_size > self.nb_frames: | |
batch_size = self.nb_frames | |
print('Batch size is larger than the number of frames. Setting batch size to {}'.format(batch_size)) | |
n_batch = math.ceil(self.nb_frames/batch_size) | |
pbar = trange(n_batch, desc='Running batch optimization') | |
# Initialize the res dict to store the per frame result skel parameters | |
out_keys = ['poses', 'betas', 'trans'] | |
if self.export_meshes: | |
out_keys += ['skel_v', 'skin_v', 'smpl_v'] | |
res_dict = {key: [] for key in out_keys} | |
res_dict['gender'] = self.gender | |
if self.export_meshes: | |
res_dict['skel_f'] = self.skel.skel_f.cpu().numpy().copy() | |
res_dict['skin_f'] = self.skel.skin_f.cpu().numpy().copy() | |
res_dict['smpl_f'] = self.smpl.faces | |
# Iterate over the batches to fit the whole sequence | |
for i in pbar: | |
if debug: | |
# Only run the first batch to test, ignore the rest | |
if i > 1: | |
continue | |
# Get batch start and end indices | |
i_start = i * batch_size | |
i_end = min((i+1) * batch_size, self.nb_frames) | |
# Fit the batch | |
betas, poses, trans, verts = self._fit_batch(body_params, i, i_start, i_end, enable_time=opt_sequence, fix_poses=fix_poses) | |
# if torch.isnan(betas).any() \ | |
# or torch.isnan(poses).any() \ | |
# or torch.isnan(trans).any(): | |
# print(f'Nan values detected.') | |
# raise ValueError('Nan values detected in the output.') | |
# Store ethe results | |
res_dict['poses'].append(poses) | |
res_dict['betas'].append(betas) | |
res_dict['trans'].append(trans) | |
if self.export_meshes: | |
# Store the meshes vertices | |
skel_output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', skelmesh=True) | |
res_dict['skel_v'].append(skel_output.skel_verts) | |
res_dict['skin_v'].append(skel_output.skin_verts) | |
res_dict['smpl_v'].append(verts) | |
if opt_sequence: | |
# Initialize the next frames with current frame | |
body_params['poses_skel'][i_end:] = poses[-1:].detach() | |
body_params['trans_skel'][i_end:] = trans[-1].detach() | |
body_params['betas_skel'][i_end:] = betas[-1:].detach() | |
# Concatenate the batches and convert to numpy | |
for key, val in res_dict.items(): | |
if isinstance(val, list): | |
res_dict[key] = torch.cat(val, dim=0).detach().cpu().numpy() | |
return res_dict | |
def _init_params(self, betas_smpl, poses_smpl, trans_smpl, skel_data_init=None, variant_exp=''): | |
""" Return initial SKEL parameters from SMPL data dictionary and an optional SKEL data dictionary.""" | |
if skel_data_init is None or self.force_recompute: | |
poses_skel = torch.zeros((self.nb_frames, self.skel.num_q_params), device=self.device) | |
if variant_exp == '' or variant_exp == '_official_old': | |
poses_skel[:, :3] = poses_smpl[:, :3] # Global orient are similar between SMPL and SKEL, so init with SMPL angles | |
elif variant_exp == '_official_fix': | |
# https://github.com/MarilynKeller/SKEL/commit/d1f6ff62235c142ba010158e00e21fd4fe25807f#diff-09188717a56a42e9589e9bd289f9ddb4fb53160e03c81a7ced70b3a84c1d9d0bR157 | |
pass | |
elif variant_exp == '_my_fix': | |
gt_orient_aa = poses_smpl[:, :3] | |
# IMPORTANT: The alignment comes from `exp/inspect_skel/archive/orientation.py`. | |
from lib.utils.geometry.rotation import axis_angle_to_matrix, matrix_to_euler_angles | |
gt_orient_mat = axis_angle_to_matrix(gt_orient_aa) | |
gt_orient_ea = matrix_to_euler_angles(gt_orient_mat, 'YXZ') | |
flip = torch.tensor([-1, 1, 1], device=self.device) | |
poses_skel[:, :3] = gt_orient_ea[:, [2, 1, 0]] * flip | |
else: | |
raise ValueError(f'Unknown variant_exp {variant_exp}') | |
betas_skel = torch.zeros((self.nb_frames, 10), device=self.device) | |
betas_skel[:] = betas_smpl[..., :10] | |
trans_skel = trans_smpl # Translation is similar between SMPL and SKEL, so init with SMPL translation | |
else: | |
# Load from previous alignment | |
betas_skel = to_torch(skel_data_init['betas'], self.device) | |
poses_skel = to_torch(skel_data_init['poses'], self.device) | |
trans_skel = to_torch(skel_data_init['trans'], self.device) | |
# Make a dictionary out of the necessary body parameters | |
body_params = { | |
'betas_skel': betas_skel, | |
'poses_skel': poses_skel, | |
'trans_skel': trans_skel, | |
'betas_smpl': betas_smpl, | |
'poses_smpl': poses_smpl, | |
'trans_smpl': trans_smpl | |
} | |
return body_params | |
def _fit_batch(self, body_params, i, i_start, i_end, enable_time=False, fix_poses=False):
    """Create the parameters of one batch and run the optimization steps on them.

    Args:
        body_params: Dictionary of per-frame tensors as built by ``_init_params``
            (optionally with a 'verts' entry holding precomputed target vertices).
        i: Batch index. Not used by this method.
        i_start: Index of the first frame of the batch.
        i_end: Index one past the last frame of the batch.
        enable_time: Forwarded to ``_optim``; also, together with
            ``self.is_skel_data_init``, decides whether step 0 is run.
        fix_poses: When True, only the last configured step is run and it
            optimizes the translation and the betas instead of the poses.

    Returns:
        Tuple ``(betas, poses, trans, verts)``: the optimized SKEL parameters of
        the batch and the SMPL vertices they were fitted to.
    """
    import copy  # local import: only needed to snapshot the first optimization step

    # Slice every entry down to the frames of this batch
    batch = {key: val[i_start:i_end] for key, val in body_params.items()}

    # SMPL params
    betas_smpl = batch['betas_smpl']
    poses_smpl = batch['poses_smpl']
    trans_smpl = batch['trans_smpl']

    # SKEL params (presumably wrapped as optimizable tensors by `to_params` -- see lib utils)
    betas = to_params(batch['betas_skel'], device=self.device)
    poses = to_params(batch['poses_skel'], device=self.device)
    trans = to_params(batch['trans_skel'], device=self.device)

    if 'verts' in batch:
        verts = batch['verts']
    else:
        # Run a SMPL forward pass to get the SMPL body vertices
        smpl_output = self.smpl(betas=betas_smpl,
                                body_pose=poses_smpl[:, 3:],
                                transl=trans_smpl,
                                global_orient=poses_smpl[:, :3])
        verts = smpl_output.vertices

    # Each step overrides some entries of the running configuration. Work on a copy of
    # the first step: updating `self.cfg.optim_steps[0]` itself would leave it polluted
    # with the settings of the last step, so the next batch would start from them.
    steps = self.cfg.optim_steps
    current_cfg = copy.deepcopy(steps[0])

    try:
        if fix_poses:
            # Only run the last step, on the translation and the shape
            for ci, cfg in enumerate([steps[-1]], start=1):
                current_cfg.update(cfg)
                print(f'Step {ci}: {current_cfg.description}')
                self._optim([trans, betas], poses, betas, trans, verts, current_cfg, enable_time)
        else:
            if not enable_time or not self.is_skel_data_init:
                # Optimize the global rotation and translation for the initial fitting
                print(f'Step 0: {current_cfg.description}')
                self._optim([trans, poses], poses, betas, trans, verts, current_cfg, enable_time)

            for ci, cfg in enumerate(steps[1:], start=1):
                current_cfg.update(cfg)
                print(f'Step {ci}: {current_cfg.description}')
                self._optim([poses], poses, betas, trans, verts, current_cfg, enable_time)
    except Exception as e:
        # Deliberately best-effort: report the failure and return the parameters in
        # whatever state the optimization left them, so the remaining batches still run.
        print(e)
        traceback.print_exc()

    return betas, poses, trans, verts
def _optim(self,
           params,
           poses,
           betas,
           trans,
           verts,
           cfg,
           enable_time=False,
           ):
    """Minimize the fitting loss over ``params`` with L-BFGS.

    Args:
        params: Tensors handed to the optimizer (a subset of poses/betas/trans).
        poses, betas, trans: Current SKEL parameters of the batch; they are
            updated in place by the optimizer when listed in ``params``.
        verts: Target SMPL vertices of the batch.
        cfg: Step configuration providing ``lr``, ``max_iter``, ``line_search_fn``,
            ``tolerance_change``, ``num_steps`` and the loss settings.
        enable_time: Forwarded to ``_fitting_loss``.
    """
    # Anatomical joint targets, regressed from the SMPL vertices
    target_joints = torch.einsum('bik,ji->bjk', verts, self.skel.J_regressor_osim)
    # Zero joint offsets handed to the SKEL forward pass
    dJ = torch.zeros((poses.shape[0], 24, 3), device=betas.device)

    lbfgs = torch.optim.LBFGS(
        params,
        lr=cfg.lr,
        max_iter=cfg.max_iter,
        line_search_fn=cfg.line_search_fn,
        tolerance_change=cfg.tolerance_change,
    )

    # Snapshots of the starting point, used as references by the loss terms
    start_poses = poses.detach().clone()
    start_trans = trans.detach().clone()

    def evaluate():
        # L-BFGS re-evaluates the objective several times per step, hence the closure.
        # (`_fstep_plot` can be called from here to visualize the watched frame.)
        lbfgs.zero_grad()
        terms = self._fitting_loss(poses,
                                   start_poses,
                                   betas,
                                   trans,
                                   start_trans,
                                   dJ,
                                   target_joints,
                                   verts,
                                   cfg,
                                   enable_time)
        total = sum(terms.values())
        total.backward()
        return total

    for _ in range(cfg.num_steps):
        _ = lbfgs.step(evaluate).item()
def _get_masks(self, cfg): | |
pose_mask = torch.ones((self.skel.num_q_params)).to(self.device).unsqueeze(0) | |
verts_mask = torch.ones_like(self.fitting_mask) | |
joint_mask = torch.ones((self.skel.num_joints, 3)).to(self.device).unsqueeze(0).bool() | |
# Mask vertices | |
if cfg.mode=='root_only': | |
# Only optimize the global rotation of the body, i.e. the first 3 angles of the pose | |
pose_mask[:] = 0 # Only optimize for the global rotation | |
pose_mask[:,:3] = 1 | |
# Only fit the thorax vertices to recover the proper body orientation and translation | |
verts_mask = self.torso_verts_mask | |
elif cfg.mode=='fixed_upper_limbs': | |
upper_limbs_joints = [0,1,2,3,6,9,12,15,17] | |
verts_mask = (self.smpl.lbs_weights[:,upper_limbs_joints]>0.5).sum(dim=-1)>0 | |
verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) | |
joint_mask[:, [3,4,5,8,9,10,18,23], :] = 0 # Do not try to match the joints of the upper limbs | |
pose_mask[:] = 1 | |
pose_mask[:,:3] = 0 # Block the global rotation | |
pose_mask[:,19] = 0 # block the lumbar twist | |
# pose_mask[:, 36:39] = 0 | |
# pose_mask[:, 43:46] = 0 | |
# pose_mask[:, 62:65] = 0 | |
# pose_mask[:, 62:65] = 0 | |
elif cfg.mode=='fixed_root': | |
pose_mask[:] = 1 | |
pose_mask[:,:3] = 0 # Block the global rotation | |
# pose_mask[:,19] = 0 # block the lumbar twist | |
# The orientation of the upper limbs is often wrong in SMPL so ignore these vertices for the finale step | |
upper_limbs_joints = [1,2,16,17] | |
verts_mask = (self.smpl.lbs_weights[:,upper_limbs_joints]>0.5).sum(dim=-1)>0 | |
verts_mask = torch.logical_not(verts_mask) | |
verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) | |
elif cfg.mode=='free': | |
verts_mask = torch.ones_like(self.fitting_mask ) | |
joint_mask[:]=0 | |
joint_mask[:, [19,14], :] = 1 # Only fir the scapula join to avoid collapsing shoulders | |
else: | |
raise ValueError(f'Unknown mode {cfg.mode}') | |
return pose_mask, verts_mask, joint_mask | |
def _fitting_loss(self,
                  poses,
                  poses_init,
                  betas,
                  trans,
                  trans_init,
                  dJ,
                  anat_joints,
                  verts,
                  cfg,
                  enable_time=False):
    """Compute the individual terms of the fitting objective.

    Args:
        poses: SKEL poses being optimized.
        poses_init: Poses at the start of the optimization step; used for the
            masked-out pose entries and as reference of the pose terms.
        betas: SKEL shape coefficients.
        trans: SKEL translations.
        trans_init: Translations at the start of the optimization step.
        dJ: Joint offsets forwarded to the SKEL forward pass.
        anat_joints: Target anatomical joints.
        verts: Target SMPL vertices.
        cfg: Step configuration holding the mode, the loss weights (``l_*``),
            ``use_basic_loss`` and ``pose_reg_factor``.
        enable_time: When True and the batch has more than one frame, a
            temporal smoothness term is added.

    Returns:
        Dictionary mapping a loss name to its weighted value; the caller sums them.
    """
    pose_mask, verts_mask, joint_mask = self._get_masks(cfg)

    # Pose entries that are masked out keep their initial value, so they
    # receive no gradient from the terms below
    poses = poses * pose_mask + poses_init * (1 - pose_mask)

    output = self.skel.forward(poses=poses,
                               betas=betas,
                               trans=trans,
                               poses_type='skel',
                               dJ=dJ,
                               skelmesh=False)

    losses = {}

    # Fit the SMPL vertices selected by the mode-dependent mask.
    # NOTE(review): the normalization uses self.nb_frames (length of the whole
    # sequence, set in run_fit), not the number of frames of this batch.
    loose_sq_err = (verts_mask * (output.skin_verts - verts) ** 2).sum()
    loose_norm = verts_mask.sum() * self.nb_frames
    losses['verts_loss_loose'] = cfg.l_verts_loose * loose_sq_err / loose_norm

    # Fit the regressed joints, this avoids collapsing shoulders
    joint_sq_err = (joint_mask * (output.joints - anat_joints) ** 2).mean()
    losses['joint_loss'] = cfg.l_joint * joint_sq_err

    # Time consistency between consecutive frames, this avoids unstable hips orientation
    if enable_time and poses.shape[0] > 1:
        losses['time_loss'] = cfg.l_time_loss * F.mse_loss(poses[1:], poses[:-1])

    losses['pose_loss'] = cfg.l_pose_loss * compute_pose_loss(poses, poses_init)

    if cfg.use_basic_loss is not False:
        return losses

    # The remaining terms regularize the optimization but are not always necessary
    losses['anch_rot'] = cfg.l_anch_pose * compute_anchor_pose(poses, poses_init)
    losses['anch_trans'] = cfg.l_anch_trans * compute_anchor_trans(trans, trans_init)

    # Vertex term restricted to the vertices selected by both masks
    strict_mask_sum = (self.fitting_mask * verts_mask).sum()
    strict_sq_err = (verts_mask * self.fitting_mask * (output.skin_verts - verts) ** 2).sum()
    losses['verts_loss'] = cfg.l_verts * strict_sq_err / strict_mask_sum

    # Regularize the pose
    losses['scapula_loss'] = cfg.l_scapula_loss * compute_scapula_loss(poses)
    losses['spine_loss'] = cfg.l_spine_loss * compute_spine_loss(poses)

    # Scale all the pose regularization terms by the pose_reg_factor value
    for name in ('scapula_loss', 'spine_loss', 'pose_loss'):
        losses[name] = cfg.pose_reg_factor * losses[name]

    return losses
def _fstep_plot(self, output, cfg, verts, anat_joints): | |
"Function to plot each step" | |
if('DISABLE_VIEWER' in os.environ): | |
return | |
pose_mask, verts_mask, joint_mask = self._get_masks(cfg) | |
skin_err_value = ((output.skin_verts[0] - verts[0])**2).sum(dim=-1).sqrt() | |
skin_err_value = skin_err_value / 0.05 | |
skin_err_value = to_numpy(skin_err_value) | |
skin_mesh = Mesh(v=to_numpy(output.skin_verts[0]), f=[], vc='white') | |
skel_mesh = Mesh(v=to_numpy(output.skel_verts[0]), f=self.skel.skel_f.cpu().numpy(), vc='white') | |
# Display vertex distance on SMPL | |
smpl_verts = to_numpy(verts[0]) | |
smpl_mesh = Mesh(v=smpl_verts, f=self.smpl.faces) | |
smpl_mesh.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False) | |
smpl_mesh_masked = Mesh(v=smpl_verts[to_numpy(verts_mask[0,:,0])], f=[], vc='green') | |
smpl_mesh_pc = Mesh(v=smpl_verts, f=[], vc='green') | |
skin_mesh_err = Mesh(v=to_numpy(output.skin_verts[0]), f=self.skel.skin_f.cpu().numpy(), vc='white') | |
skin_mesh_err.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False) | |
# List the meshes to display | |
meshes_left = [skin_mesh_err, smpl_mesh_pc] | |
meshes_right = [smpl_mesh_masked, skin_mesh, skel_mesh] | |
if cfg.l_joint > 0: | |
# Plot the joints | |
meshes_right += location_to_spheres(to_numpy(output.joints[joint_mask[:,:,0]]), color=(1,0,0), radius=0.02) | |
meshes_right += location_to_spheres(to_numpy(anat_joints[joint_mask[:,:,0]]), color=(0,1,0), radius=0.02) \ | |
self.mv[0][0].set_dynamic_meshes(meshes_left) | |
self.mv[0][1].set_dynamic_meshes(meshes_right) | |
# print(poses[frame_to_watch, :3]) | |
# print(trans[frame_to_watch]) | |
# print(betas[frame_to_watch, :3]) | |
# mv.get_keypress() | |