""" Copyright©2023 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute for Intelligent Systems. All rights reserved. Author: Soyong Shin, Marilyn Keller See https://skel.is.tue.mpg.de/license.html for licensing and contact information. """ import traceback import math import os import pickle import torch import smplx import omegaconf import torch.nn.functional as F from psbody.mesh import Mesh, MeshViewer, MeshViewers from tqdm import trange from pathlib import Path import lib.body_models.skel.config as cg from lib.body_models.skel.skel_model import SKEL from .losses import compute_anchor_pose, compute_anchor_trans, compute_pose_loss, compute_scapula_loss, compute_spine_loss, compute_time_loss, pretty_loss_print from .utils import location_to_spheres, to_numpy, to_params, to_torch from .align_config import config from .align_config_joint import config as config_joint class SkelFitter(object): def __init__(self, gender, device, num_betas=10, export_meshes=False, joint_optim=False) -> None: self.smpl = smplx.create(cg.smpl_folder, model_type='smpl', gender=gender, num_betas=num_betas, batch_size=1, export_meshes=False).to(device) self.skel = SKEL(gender).to(device) self.gender = gender self.device = device self.num_betas = num_betas # Instanciate masks used for the vertex to vertex fitting fitting_mask_file = Path(__file__).parent / 'riggid_parts_mask.pkl' fitting_indices = pickle.load(open(fitting_mask_file, 'rb')) fitting_mask = torch.zeros(6890, dtype=torch.bool, device=self.device) fitting_mask[fitting_indices] = 1 self.fitting_mask = fitting_mask.reshape(1, -1, 1).to(self.device) # 1xVx1 to be applied to verts that are BxVx3 smpl_torso_joints = [0,3] verts_mask = (self.smpl.lbs_weights[:,smpl_torso_joints]>0.5).sum(dim=-1)>0 self.torso_verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) # Because verts are of shape BxVx3 self.export_meshes = export_meshes # make the cfg being an object using omegaconf if joint_optim: self.cfg = omegaconf.OmegaConf.create(config_joint) else: self.cfg = omegaconf.OmegaConf.create(config) # Instanciate the mesh viewer to visualize the fitting if('DISABLE_VIEWER' in os.environ): self.mv = None print("\n DISABLE_VIEWER flag is set, running in headless mode") else: self.mv = MeshViewers((1,2), keepalive=self.cfg.keepalive_meshviewer) def run_fit(self, trans_in, betas_in, poses_in, batch_size=20, skel_data_init=None, force_recompute=False, debug=False, watch_frame=0, freevert_mesh=None, opt_sequence=False, fix_poses=False, variant_exp=''): """Align SKEL to a SMPL sequence.""" self.nb_frames = poses_in.shape[0] self.watch_frame = watch_frame self.is_skel_data_init = skel_data_init is not None self.force_recompute = force_recompute print('Fitting {} frames'.format(self.nb_frames)) print('Watching frame: {}'.format(watch_frame)) # Initialize SKEL torch params body_params = self._init_params(betas_in, poses_in, trans_in, skel_data_init, variant_exp) # We cut the whole sequence in batches for parallel optimization if batch_size > self.nb_frames: batch_size = self.nb_frames print('Batch size is larger than the number of frames. Setting batch size to {}'.format(batch_size)) n_batch = math.ceil(self.nb_frames/batch_size) pbar = trange(n_batch, desc='Running batch optimization') # Initialize the res dict to store the per frame result skel parameters out_keys = ['poses', 'betas', 'trans'] if self.export_meshes: out_keys += ['skel_v', 'skin_v', 'smpl_v'] res_dict = {key: [] for key in out_keys} res_dict['gender'] = self.gender if self.export_meshes: res_dict['skel_f'] = self.skel.skel_f.cpu().numpy().copy() res_dict['skin_f'] = self.skel.skin_f.cpu().numpy().copy() res_dict['smpl_f'] = self.smpl.faces # Iterate over the batches to fit the whole sequence for i in pbar: if debug: # Only run the first batch to test, ignore the rest if i > 1: continue # Get batch start and end indices i_start = i * batch_size i_end = min((i+1) * batch_size, self.nb_frames) # Fit the batch betas, poses, trans, verts = self._fit_batch(body_params, i, i_start, i_end, enable_time=opt_sequence, fix_poses=fix_poses) # if torch.isnan(betas).any() \ # or torch.isnan(poses).any() \ # or torch.isnan(trans).any(): # print(f'Nan values detected.') # raise ValueError('Nan values detected in the output.') # Store ethe results res_dict['poses'].append(poses) res_dict['betas'].append(betas) res_dict['trans'].append(trans) if self.export_meshes: # Store the meshes vertices skel_output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', skelmesh=True) res_dict['skel_v'].append(skel_output.skel_verts) res_dict['skin_v'].append(skel_output.skin_verts) res_dict['smpl_v'].append(verts) if opt_sequence: # Initialize the next frames with current frame body_params['poses_skel'][i_end:] = poses[-1:].detach() body_params['trans_skel'][i_end:] = trans[-1].detach() body_params['betas_skel'][i_end:] = betas[-1:].detach() # Concatenate the batches and convert to numpy for key, val in res_dict.items(): if isinstance(val, list): res_dict[key] = torch.cat(val, dim=0).detach().cpu().numpy() return res_dict def _init_params(self, betas_smpl, poses_smpl, trans_smpl, skel_data_init=None, variant_exp=''): """ Return initial SKEL parameters from SMPL data dictionary and an optional SKEL data dictionary.""" if skel_data_init is None or self.force_recompute: poses_skel = torch.zeros((self.nb_frames, self.skel.num_q_params), device=self.device) if variant_exp == '' or variant_exp == '_official_old': poses_skel[:, :3] = poses_smpl[:, :3] # Global orient are similar between SMPL and SKEL, so init with SMPL angles elif variant_exp == '_official_fix': # https://github.com/MarilynKeller/SKEL/commit/d1f6ff62235c142ba010158e00e21fd4fe25807f#diff-09188717a56a42e9589e9bd289f9ddb4fb53160e03c81a7ced70b3a84c1d9d0bR157 pass elif variant_exp == '_my_fix': gt_orient_aa = poses_smpl[:, :3] # IMPORTANT: The alignment comes from `exp/inspect_skel/archive/orientation.py`. from lib.utils.geometry.rotation import axis_angle_to_matrix, matrix_to_euler_angles gt_orient_mat = axis_angle_to_matrix(gt_orient_aa) gt_orient_ea = matrix_to_euler_angles(gt_orient_mat, 'YXZ') flip = torch.tensor([-1, 1, 1], device=self.device) poses_skel[:, :3] = gt_orient_ea[:, [2, 1, 0]] * flip else: raise ValueError(f'Unknown variant_exp {variant_exp}') betas_skel = torch.zeros((self.nb_frames, 10), device=self.device) betas_skel[:] = betas_smpl[..., :10] trans_skel = trans_smpl # Translation is similar between SMPL and SKEL, so init with SMPL translation else: # Load from previous alignment betas_skel = to_torch(skel_data_init['betas'], self.device) poses_skel = to_torch(skel_data_init['poses'], self.device) trans_skel = to_torch(skel_data_init['trans'], self.device) # Make a dictionary out of the necessary body parameters body_params = { 'betas_skel': betas_skel, 'poses_skel': poses_skel, 'trans_skel': trans_skel, 'betas_smpl': betas_smpl, 'poses_smpl': poses_smpl, 'trans_smpl': trans_smpl } return body_params def _fit_batch(self, body_params, i, i_start, i_end, enable_time=False, fix_poses=False): """ Create parameters for the batch and run the optimization.""" # Sample a batch ver body_params = { key: val[i_start:i_end] for key, val in body_params.items()} # SMPL params betas_smpl = body_params['betas_smpl'] poses_smpl = body_params['poses_smpl'] trans_smpl = body_params['trans_smpl'] # SKEL params betas = to_params(body_params['betas_skel'], device=self.device) poses = to_params(body_params['poses_skel'], device=self.device) trans = to_params(body_params['trans_skel'], device=self.device) if 'verts' in body_params: verts = body_params['verts'] else: # Run a SMPL forward pass to get the SMPL body vertices smpl_output = self.smpl(betas=betas_smpl, body_pose=poses_smpl[:,3:], transl=trans_smpl, global_orient=poses_smpl[:,:3]) verts = smpl_output.vertices # Optimize config = self.cfg.optim_steps current_cfg = config[0] # from lib.kits.debug import set_trace # set_trace() try: if fix_poses: # for ci, cfg in enumerate(config[1:]): for ci, cfg in enumerate([config[-1]]): # To debug, only run the last step current_cfg.update(cfg) print(f'Step {ci+1}: {current_cfg.description}') self._optim([trans,betas], poses, betas, trans, verts, current_cfg, enable_time) else: if not enable_time or not self.is_skel_data_init: # Optimize the global rotation and translation for the initial fitting print(f'Step 0: {current_cfg.description}') self._optim([trans,poses], poses, betas, trans, verts, current_cfg, enable_time) for ci, cfg in enumerate(config[1:]): # for ci, cfg in enumerate([config[-1]]): # To debug, only run the last step current_cfg.update(cfg) print(f'Step {ci+1}: {current_cfg.description}') self._optim([poses], poses, betas, trans, verts, current_cfg, enable_time) # # Refine by optimizing the whole body # cfg.update(self.cfg_optim[]) # cfg.update({'mode' : 'free', 'tolerance_change': 0.0001, 'l_joint': 0.2e4}) # self._optim([trans, poses], poses, betas, trans, verts, cfg) except Exception as e: print(e) traceback.print_exc() # from lib.kits.debug import set_trace # set_trace() return betas, poses, trans, verts def _optim(self, params, poses, betas, trans, verts, cfg, enable_time=False, ): # regress anatomical joints from SMPL's vertices anat_joints = torch.einsum('bik,ji->bjk', [verts, self.skel.J_regressor_osim]) dJ=torch.zeros((poses.shape[0], 24, 3), device=betas.device) # Create the optimizer optimizer = torch.optim.LBFGS(params, lr=cfg.lr, max_iter=cfg.max_iter, line_search_fn=cfg.line_search_fn, tolerance_change=cfg.tolerance_change) poses_init = poses.detach().clone() trans_init = trans.detach().clone() def closure(): optimizer.zero_grad() # fi = self.watch_frame #frame of the batch to display # output = self.skel.forward(poses=poses[fi:fi+1], # betas=betas[fi:fi+1], # trans=trans[fi:fi+1], # poses_type='skel', # dJ=dJ[fi:fi+1], # skelmesh=True) # self._fstep_plot(output, cfg, verts[fi:fi+1], anat_joints[fi:fi+1], ) loss_dict = self._fitting_loss(poses, poses_init, betas, trans, trans_init, dJ, anat_joints, verts, cfg, enable_time) # print(pretty_loss_print(loss_dict)) loss = sum(loss_dict.values()) loss.backward() return loss for step_i in range(cfg.num_steps): loss = optimizer.step(closure).item() def _get_masks(self, cfg): pose_mask = torch.ones((self.skel.num_q_params)).to(self.device).unsqueeze(0) verts_mask = torch.ones_like(self.fitting_mask) joint_mask = torch.ones((self.skel.num_joints, 3)).to(self.device).unsqueeze(0).bool() # Mask vertices if cfg.mode=='root_only': # Only optimize the global rotation of the body, i.e. the first 3 angles of the pose pose_mask[:] = 0 # Only optimize for the global rotation pose_mask[:,:3] = 1 # Only fit the thorax vertices to recover the proper body orientation and translation verts_mask = self.torso_verts_mask elif cfg.mode=='fixed_upper_limbs': upper_limbs_joints = [0,1,2,3,6,9,12,15,17] verts_mask = (self.smpl.lbs_weights[:,upper_limbs_joints]>0.5).sum(dim=-1)>0 verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) joint_mask[:, [3,4,5,8,9,10,18,23], :] = 0 # Do not try to match the joints of the upper limbs pose_mask[:] = 1 pose_mask[:,:3] = 0 # Block the global rotation pose_mask[:,19] = 0 # block the lumbar twist # pose_mask[:, 36:39] = 0 # pose_mask[:, 43:46] = 0 # pose_mask[:, 62:65] = 0 # pose_mask[:, 62:65] = 0 elif cfg.mode=='fixed_root': pose_mask[:] = 1 pose_mask[:,:3] = 0 # Block the global rotation # pose_mask[:,19] = 0 # block the lumbar twist # The orientation of the upper limbs is often wrong in SMPL so ignore these vertices for the finale step upper_limbs_joints = [1,2,16,17] verts_mask = (self.smpl.lbs_weights[:,upper_limbs_joints]>0.5).sum(dim=-1)>0 verts_mask = torch.logical_not(verts_mask) verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1) elif cfg.mode=='free': verts_mask = torch.ones_like(self.fitting_mask ) joint_mask[:]=0 joint_mask[:, [19,14], :] = 1 # Only fir the scapula join to avoid collapsing shoulders else: raise ValueError(f'Unknown mode {cfg.mode}') return pose_mask, verts_mask, joint_mask def _fitting_loss(self, poses, poses_init, betas, trans, trans_init, dJ, anat_joints, verts, cfg, enable_time=False): loss_dict = {} pose_mask, verts_mask, joint_mask = self._get_masks(cfg) poses = poses * pose_mask + poses_init * (1-pose_mask) # Mask joints to not optimize before computing the losses output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', dJ=dJ, skelmesh=False) # Fit the SMPL vertices # We know the skinning of the forearm and the neck are not perfect, # so we create a mask of the SMPL vertices that are important to fit, like the hands and the head loss_dict['verts_loss_loose'] = cfg.l_verts_loose * (verts_mask * (output.skin_verts - verts)**2).sum() / (((verts_mask).sum()*self.nb_frames)) # Fit the regressed joints, this avoids collapsing shoulders # loss_dict['joint_loss'] = cfg.l_joint * F.mse_loss(output.joints, anat_joints) loss_dict['joint_loss'] = cfg.l_joint * (joint_mask * (output.joints - anat_joints)**2).mean() # Time consistancy if poses.shape[0] > 1 and enable_time: # This avoids unstable hips orientationZ loss_dict['time_loss'] = cfg.l_time_loss * F.mse_loss(poses[1:], poses[:-1]) loss_dict['pose_loss'] = cfg.l_pose_loss * compute_pose_loss(poses, poses_init) if cfg.use_basic_loss is False: # These losses can be used to regularize the optimization but are not always necessary loss_dict['anch_rot'] = cfg.l_anch_pose * compute_anchor_pose(poses, poses_init) loss_dict['anch_trans'] = cfg.l_anch_trans * compute_anchor_trans(trans, trans_init) loss_dict['verts_loss'] = cfg.l_verts * (verts_mask * self.fitting_mask * (output.skin_verts - verts)**2).sum() / (self.fitting_mask*verts_mask).sum() # Regularize the pose loss_dict['scapula_loss'] = cfg.l_scapula_loss * compute_scapula_loss(poses) loss_dict['spine_loss'] = cfg.l_spine_loss * compute_spine_loss(poses) # Adjust the losses of all the pose regularizations sub losses with the pose_reg_factor value for key in ['scapula_loss', 'spine_loss', 'pose_loss']: loss_dict[key] = cfg.pose_reg_factor * loss_dict[key] return loss_dict def _fstep_plot(self, output, cfg, verts, anat_joints): "Function to plot each step" if('DISABLE_VIEWER' in os.environ): return pose_mask, verts_mask, joint_mask = self._get_masks(cfg) skin_err_value = ((output.skin_verts[0] - verts[0])**2).sum(dim=-1).sqrt() skin_err_value = skin_err_value / 0.05 skin_err_value = to_numpy(skin_err_value) skin_mesh = Mesh(v=to_numpy(output.skin_verts[0]), f=[], vc='white') skel_mesh = Mesh(v=to_numpy(output.skel_verts[0]), f=self.skel.skel_f.cpu().numpy(), vc='white') # Display vertex distance on SMPL smpl_verts = to_numpy(verts[0]) smpl_mesh = Mesh(v=smpl_verts, f=self.smpl.faces) smpl_mesh.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False) smpl_mesh_masked = Mesh(v=smpl_verts[to_numpy(verts_mask[0,:,0])], f=[], vc='green') smpl_mesh_pc = Mesh(v=smpl_verts, f=[], vc='green') skin_mesh_err = Mesh(v=to_numpy(output.skin_verts[0]), f=self.skel.skin_f.cpu().numpy(), vc='white') skin_mesh_err.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False) # List the meshes to display meshes_left = [skin_mesh_err, smpl_mesh_pc] meshes_right = [smpl_mesh_masked, skin_mesh, skel_mesh] if cfg.l_joint > 0: # Plot the joints meshes_right += location_to_spheres(to_numpy(output.joints[joint_mask[:,:,0]]), color=(1,0,0), radius=0.02) meshes_right += location_to_spheres(to_numpy(anat_joints[joint_mask[:,:,0]]), color=(0,1,0), radius=0.02) \ self.mv[0][0].set_dynamic_meshes(meshes_left) self.mv[0][1].set_dynamic_meshes(meshes_right) # print(poses[frame_to_watch, :3]) # print(trans[frame_to_watch]) # print(betas[frame_to_watch, :3]) # mv.get_keypress()