File size: 4,925 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120

import os
import pickle
import torch
import numpy as np
from psbody.mesh.sphere import Sphere

# to_params = lambda x: torch.from_numpy(x).float().to(self.device).requires_grad_(True)
# to_torch = lambda x: torch.from_numpy(x).float().to(self.device)

def to_params(x, device):
    """Move tensor ``x`` onto ``device`` and mark it as requiring gradients.

    Args:
        x (torch.Tensor): Tensor to turn into an optimisable parameter.
        device: Target device (anything accepted by ``Tensor.to``).

    Returns:
        torch.Tensor: The tensor on ``device`` with ``requires_grad`` set.
    """
    on_device = x.to(device)
    # requires_grad_ flips the flag in place and returns the same tensor.
    on_device.requires_grad_(True)
    return on_device

def to_torch(x, device):
    """Convert numpy array ``x`` into a float32 torch tensor on ``device``.

    Args:
        x (numpy.ndarray): Array to convert.
        device: Target device (anything accepted by ``Tensor.to``).

    Returns:
        torch.Tensor: float32 tensor on ``device``.
    """
    as_tensor = torch.from_numpy(x)
    as_float = as_tensor.float()
    return as_float.to(device)

def to_numpy(x):
    """Return a numpy copy of tensor ``x``, detached from the autograd graph and moved to CPU."""
    detached = x.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

def _read_smpl_file(smpl_seq_path):
    """Read a .pkl or .npz SMPL sequence file into a plain Python dict.

    Raises:
        Exception: if the extension is neither .pkl nor .npz.
    """
    if smpl_seq_path.endswith('.pkl'):
        # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
        with open(smpl_seq_path, 'rb') as f:
            return pickle.load(f)

    if smpl_seq_path.endswith('.npz'):
        # allow_pickle is required for object arrays such as 'pred_smpl_parms'.
        # The context manager closes the archive instead of leaking the handle.
        with np.load(smpl_seq_path, allow_pickle=True) as npz:
            if set(npz.files) == {'pred_smpl_parms', 'verts', 'pred_cam_t'}:
                # The parameters are stored as a pickled dict, e.g. with keys 'global_orient',
                # 'body_pose', 'body_pose_axis_angle', 'global_orient_axis_angle', 'betas'.
                return npz['pred_smpl_parms'].item()
            # Read every array now, while the archive is still open.
            return {key: npz[key] for key in npz.files}

    raise Exception('Unknown file format: {}. Supported formats are .pkl and .npz'.format(smpl_seq_path))


def _resolve_gender(data_dict, gender):
    """Return the gender stored in ``data_dict``, falling back to ``gender`` when it is absent."""
    if 'gender' not in data_dict:
        assert gender is not None, "The provided SMPL data dictionary does not contain gender, you need to pass it in command line"
        return gender

    file_gender = data_dict['gender']
    # In some npz files gender is a 0-d array such as array('male', dtype='<U4').
    if isinstance(file_gender, np.ndarray) and file_gender.ndim == 0:
        file_gender = file_gender.item()
    if isinstance(file_gender, bytes):
        file_gender = file_gender.decode('utf-8')
    # A plain-string gender stored in the file is kept. Previously it was
    # overwritten by the ``gender`` argument, which is often None.
    return str(file_gender)


def _extract_poses(data_dict, smpl_seq_path):
    """Return the pose array found in ``data_dict``, trying the supported key layouts in order."""
    if 'poses' in data_dict:
        return data_dict['poses']

    if 'body_pose_axis_angle' in data_dict and 'global_orient_axis_angle' in data_dict:
        poses = np.concatenate([data_dict['global_orient_axis_angle'], data_dict['body_pose_axis_angle']], axis=1)
        return poses.reshape(-1, 72)

    if 'body_pose' in data_dict and 'global_orient' in data_dict:
        # This branch previously read the *_axis_angle keys, which cannot exist
        # here (the branch above would have matched), so it always raised KeyError.
        global_orient = np.atleast_2d(np.asarray(data_dict['global_orient']))
        body_pose = np.atleast_2d(np.asarray(data_dict['body_pose']))
        n_frames = global_orient.shape[0]
        poses = np.concatenate([global_orient.reshape(n_frames, -1), body_pose.reshape(n_frames, -1)], axis=1)
        if poses.shape[1] != 72:
            raise ValueError(
                f"Expected axis-angle 'global_orient' and 'body_pose' totalling 72 values per frame "
                f"in {smpl_seq_path}, got {poses.shape[1]}. Rotation-matrix inputs are not supported.")
        return poses

    raise Exception(f"Could not find poses in {smpl_seq_path}. Available keys: {data_dict.keys()})")


def load_smpl_seq(smpl_seq_path, gender=None, straighten_hands=False):
    """Load an SMPL sequence from a .pkl or .npz file into a normalised dict.

    Args:
        smpl_seq_path (str): Path to a ``.pkl`` or ``.npz`` file.
        gender (str, optional): Gender to use when the file does not store one.
        straighten_hands (bool, optional): If True, zero the pose columns from
            index 66 onward (SMPL joints 22 and 23), since SKEL has rigid hands.

    Returns:
        dict: Keys ``'trans'``, ``'poses'``, ``'betas'`` and ``'gender'``.
            ``poses`` with 156 columns (SMPL+H) is reduced to 72 columns.
            ``trans`` defaults to zeros of shape (n_frames, 3) when absent.
            ``betas`` keeps at most the first 10 coefficients; a batch axis
            is added when betas are 1-D and poses are 2-D.

    Raises:
        Exception: if the path does not exist, the format is unsupported or
            no pose data is found.
        AssertionError: if neither the file nor ``gender`` provides a gender.
        ValueError: if ``global_orient``/``body_pose`` do not total 72 values per frame.
        KeyError: if the file has no ``betas``.
    """
    if not os.path.exists(smpl_seq_path):
        raise Exception('Path does not exist: {}'.format(smpl_seq_path))

    data_dict = _read_smpl_file(smpl_seq_path)
    out_gender = _resolve_gender(data_dict, gender)

    # Convert any torch tensors to numpy arrays so the rest is numpy-only.
    data_dict = {
        key: val.detach().cpu().numpy() if isinstance(val, torch.Tensor) else val
        for key, val in data_dict.items()
    }

    poses = np.asarray(_extract_poses(data_dict, smpl_seq_path))

    if poses.shape[1] == 156:
        # SMPL+H poses: keep only the body part. Params for SMPL joints 22 and 23
        # stay zero because these DOF are not present in SMPL+H.
        smpl_poses = np.zeros((poses.shape[0], 72))
        smpl_poses[:, :72 - 2 * 3] = poses[:, :72 - 2 * 3]
        poses = smpl_poses

    # Set SMPL joints 22 and 23 to zero as SKEL has rigid hands.
    if straighten_hands:
        poses[:, 72 - 2 * 3:] = 0

    if 'trans' in data_dict:
        trans = data_dict['trans']
    else:
        trans = np.zeros((poses.shape[0], 3))

    if 'betas' not in data_dict:
        raise KeyError(f"Could not find betas in {smpl_seq_path}. Available keys: {data_dict.keys()})")
    betas = np.asarray(data_dict['betas'])[..., :10]  # Keep only the 10 first betas
    if betas.ndim == 1 and poses.ndim == 2:
        betas = betas[None, :]  # Add a batch dimension

    return {
        'trans': trans,
        'poses': poses,
        'betas': betas,
        'gender': out_gender,
    }
        
        
def location_to_spheres(loc, color=(1,0,0), radius=0.02):
    """Build one coloured sphere mesh per 3D point.

    Args:
        loc (numpy.array): Nx3 array of 3D positions; only the first three
            columns of each row are used.
        color (tuple, optional): RGB float colour applied to every sphere. Defaults to (1,0,0).
        radius (float, optional): Sphere radius in meters. Defaults to 0.02.

    Returns:
        list: One sphere Mesh per row of ``loc``.
    """
    sphere_meshes = []
    for row_idx in range(loc.shape[0]):
        # Build a fresh 3-element centre array from the row's x, y, z values.
        center = np.asarray([loc[row_idx, axis] for axis in range(3)])
        sphere_meshes.append(Sphere(center, radius).to_mesh())

    # Colour in a second pass; each mesh receives its own colour array.
    for mesh in sphere_meshes:
        mesh.set_vertex_colors(np.array(color))
    return sphere_meshes