File size: 1,352 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch

def compute_scapula_loss(poses):
    """Return the L2 magnitude of the scapula pose parameters.

    Selects columns 26-28 and 36-38 of ``poses`` and returns the Euclidean
    norm of all selected entries, so minimising it pulls them toward zero.
    (NOTE(review): these columns are presumably the right/left scapula
    degrees of freedom of the body model -- confirm against its parameter
    layout.)

    Args:
        poses: Pose-parameter tensor indexed as ``poses[:, j]``; assumes
            shape (num_frames, num_params) -- TODO confirm with callers.

    Returns:
        Scalar tensor: square root of the sum of squares of the selected
        entries.
    """
    scapula_indices = [26, 27, 28, 36, 37, 38]

    scapula_poses = poses[:, scapula_indices]
    # No ``ord``: the tensor is flattened and the plain elementwise 2-norm is
    # returned. ``ord=2`` on a 2-D tensor would instead compute the spectral
    # norm (largest singular value), which only penalises the dominant
    # singular direction and requires an SVD.
    scapula_loss = torch.linalg.norm(scapula_poses)
    return scapula_loss

def compute_spine_loss(poses):
    """Return the L2 magnitude of the spine pose parameters.

    Selects columns 17 through 24 (inclusive) of ``poses`` and returns the
    Euclidean norm of all selected entries. (NOTE(review): these columns are
    presumably the spine degrees of freedom; column 25 is not included --
    confirm the upper bound against the body model's parameter layout.)

    Args:
        poses: Pose-parameter tensor indexed as ``poses[:, j]``; assumes
            shape (num_frames, num_params) -- TODO confirm with callers.

    Returns:
        Scalar tensor: square root of the sum of squares of the selected
        entries.
    """
    # Half-open slice == the original range(17, 25); a slice is a view and
    # avoids an advanced-indexing copy.
    spine_indices = slice(17, 25)

    spine_poses = poses[:, spine_indices]
    # No ``ord``: flattened elementwise 2-norm. ``ord=2`` on a 2-D tensor
    # would compute the spectral norm (largest singular value) instead.
    spine_loss = torch.linalg.norm(spine_poses)
    return spine_loss

def compute_pose_loss(poses, pose_init):
    """Return the L2 magnitude of all pose parameters except the first three.

    The first three columns are skipped because the global rotation should
    not be constrained.

    Args:
        poses: Pose-parameter tensor indexed as ``poses[:, j]``; assumes
            shape (num_frames, num_params) -- TODO confirm with callers.
        pose_init: Unused. Kept so existing call sites keep working.

    Returns:
        Scalar tensor: square root of the sum of squares of ``poses[:, 3:]``.
    """
    body_poses = poses[:, 3:]  # The global rotation should not be constrained
    # No ``ord``: flattened elementwise 2-norm. ``ord=2`` on a 2-D tensor
    # would compute the spectral norm (largest singular value) instead.
    pose_loss = torch.linalg.norm(body_poses)
    return pose_loss

def compute_anchor_pose(poses, pose_init):
    """Penalise drift of the first three pose parameters from their start.

    Only columns 0-2 are compared. These are the columns that
    ``compute_pose_loss`` deliberately leaves unconstrained as the global
    rotation.

    Args:
        poses: Current pose-parameter tensor, indexed as ``poses[:, j]``.
        pose_init: Reference pose tensor, indexed the same way.

    Returns:
        Scalar tensor: mean squared error over the selected entries.
    """
    current_head = poses[:, :3]
    reference_head = pose_init[:, :3]
    return torch.nn.functional.mse_loss(current_head, reference_head)

def compute_anchor_trans(trans, trans_init):
    """Penalise drift of ``trans`` away from ``trans_init``.

    Args:
        trans: Current translation tensor.
        trans_init: Reference translation tensor with a compatible shape.

    Returns:
        Scalar tensor: mean squared error over all entries.
    """
    return torch.nn.functional.mse_loss(trans, trans_init)

def compute_time_loss(poses):
    """Return the L2 magnitude of consecutive differences along dim 0.

    Computes ``poses[i + 1] - poses[i]`` for every adjacent pair of rows and
    returns the Euclidean norm of all differences. Minimising it encourages
    adjacent rows to stay close to each other.

    Args:
        poses: Tensor whose first dimension is stepped over; presumably
            (num_frames, num_params) with frames in time order -- TODO
            confirm with callers.

    Returns:
        Scalar tensor: square root of the sum of squared row-to-row
        differences. It is 0 when ``poses`` has a single row, because the
        difference tensor is empty.
    """
    pose_delta = poses[1:] - poses[:-1]
    # No ``ord``: flattened elementwise 2-norm, which is well defined for an
    # empty difference tensor. ``ord=2`` on a 2-D tensor would compute the
    # spectral norm (largest singular value) instead.
    time_loss = torch.linalg.norm(pose_delta)
    return time_loss

def pretty_loss_print(loss_dict):
    """Format a dict of named losses as one line of text.

    The result has the shape ``'<total> | <name1> <val1> | <name2> <val2>'``.
    ``<total>`` is the sum of all values. Every number is rendered with four
    decimal places, and entries appear in the dict's iteration order. An
    empty dict yields ``'0.0000'``.

    Args:
        loss_dict: Mapping of loss name to value. Values must be summable
            and support the ``.4f`` format spec.

    Returns:
        The formatted string (nothing is printed).
    """
    total = sum(loss_dict.values())
    segments = [f'{total:.4f}']
    segments.extend(f'{name} {value:.4f}' for name, value in loss_dict.items())
    return ' | '.join(segments)