Spaces:
Sleeping
Sleeping
File size: 3,188 Bytes
5ac1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from lib.kits.basic import *
import smplx
from psbody.mesh import Mesh
class MoYoSMPLX(smplx.SMPLX):
    ''' A neutral-gender `smplx.SMPLX` model built on a custom `v_template` mesh, whose `forward` fills in omitted parameters with zeros. '''

    def __init__(
        self,
        model_path      : Union[str, Path],
        v_template_path : Union[str, Path],
        batch_size      = 1,
        n_betas         = 10,
        device          = 'cpu'
    ):
        '''
        Build the body model from the SMPL-X model files and a template mesh.

        ### Args
        - model_path: Union[str, Path]
            - Forwarded unchanged to `smplx.SMPLX` as `model_path`.
        - v_template_path: Union[str, Path]
            - Path of the mesh file whose vertices (`Mesh(...).v`) are used as `v_template`.
        - batch_size: int, default 1
            - Forwarded to `smplx.SMPLX` as `batch_size`.
        - n_betas: int, default 10
            - Forwarded as `num_betas`; also the width of the zero `betas` created in `forward`.
        - device: default 'cpu'
            - Device the template tensor is created on and the module is moved to.
        '''
        # `Mesh` is given a plain string path, so a `Path` is converted first.
        if isinstance(v_template_path, Path):
            v_template_path = str(v_template_path)

        # Load the `v_template`.
        v_template_mesh = Mesh(filename=v_template_path)
        v_template = to_tensor(v_template_mesh.v, device=device)

        self.n_betas = n_betas

        # Create the `body_model_params`.
        body_model_params = {
                'model_path'             : model_path,
                'gender'                 : 'neutral',
                'v_template'             : v_template.float(),
                'batch_size'             : batch_size,
                'create_global_orient'   : True,
                'create_body_pose'       : True,
                'create_betas'           : True,
                'num_betas'              : self.n_betas,  # They actually don't use num_betas.
                'create_left_hand_pose'  : True,
                'create_right_hand_pose' : True,
                'create_expression'      : True,
                'create_jaw_pose'        : True,
                'create_leye_pose'       : True,
                'create_reye_pose'       : True,
                'create_transl'          : True,
                'use_pca'                : False,
                'flat_hand_mean'         : True,
                'dtype'                  : torch.float32,
            }

        super().__init__(**body_model_params)
        # `nn.Module.to` moves the module in place; rebinding the local name `self` is unnecessary.
        self.to(device)

    def forward(self, **kwargs):
        '''
        Run the SMPL-X forward pass, creating zero tensors for any omitted optional parameter.

        Only all parameters are passed, the batch_size will be flexibly adjusted: every
        parameter is handed to the parent `forward` explicitly, with missing ones replaced
        by zeros sized from the batch dimension of `global_orient`.

        ### Args
        - global_orient: required; its first dimension is used as the batch size `B`.
        - body_pose: required; the zero defaults are created with `body_pose.new_zeros`,
          so they share its dtype and device.
        - left_hand_pose / right_hand_pose: optional, default zeros of shape (B, 45); a warning is logged.
        - transl / jaw_pose / leye_pose / reye_pose: optional, default zeros of shape (B, 3).
        - betas: optional, default zeros of shape (B, n_betas).
        - expression: optional, default zeros of shape (B, 10).
        - Any other keyword is forwarded to the parent `forward` unchanged.

        ### Returns
        - Whatever `smplx.SMPLX.forward` returns for the completed keyword arguments.

        ### Raises
        - AssertionError: if `global_orient` or `body_pose` is missing.
        '''
        assert 'global_orient' in kwargs, '`global_orient` is required for the forward pass.'
        assert 'body_pose' in kwargs, '`body_pose` is required for the forward pass.'

        B = kwargs['global_orient'].shape[0]
        body_pose = kwargs['body_pose']

        # Hand poses are expected from the caller, so falling back to zeros is reported.
        # The message names the key that is actually missing.
        for key in ('left_hand_pose', 'right_hand_pose'):
            if key not in kwargs:
                kwargs[key] = body_pose.new_zeros((B, 45))
                get_logger().warning(f"`{key}` is not provided, but it's expected, set to zeros.")

        # Remaining optional parameters silently default to zeros of the listed width.
        # NOTE(review): 10 presumably matches the parent's default number of expression coefficients — confirm.
        silent_defaults = {
                'transl'     : 3,
                'betas'      : self.n_betas,
                'expression' : 10,
                'jaw_pose'   : 3,
                'leye_pose'  : 3,
                'reye_pose'  : 3,
            }
        for key, width in silent_defaults.items():
            if key not in kwargs:
                kwargs[key] = body_pose.new_zeros((B, width))

        return super().forward(**kwargs)