Spaces:
Sleeping
Sleeping
File size: 1,352 Bytes
5ac1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
def compute_scapula_loss(poses):
    """Return the L2 norm of the scapula pose parameters.

    Penalises any deviation of the selected scapula parameters from zero.

    Args:
        poses: Pose tensor indexed as ``poses[:, i]``; presumably shaped
            (num_frames, num_pose_params) -- TODO confirm against callers.
            Columns 26-28 and 36-38 are treated as the scapula parameters
            (assumed to be the two scapulae of the body model -- confirm).

    Returns:
        0-d tensor: the element-wise (Frobenius) L2 norm taken over every
        selected entry of every row.
    """
    scapula_indices = [26, 27, 28, 36, 37, 38]
    scapula_poses = poses[:, scapula_indices]
    # vector_norm flattens its input, so this is the element-wise L2 norm.
    # torch.linalg.norm(..., ord=2) on a 2-D tensor is instead the spectral
    # norm (largest singular value): it only penalises the dominant
    # direction across rows and needs an SVD on every call.
    return torch.linalg.vector_norm(scapula_poses, ord=2)
def compute_spine_loss(poses):
    """Return the L2 norm of the spine pose parameters.

    Penalises any deviation of pose columns 17..24 (inclusive) from zero.

    Args:
        poses: Pose tensor indexed as ``poses[:, i]``; presumably shaped
            (num_frames, num_pose_params) -- TODO confirm against callers.

    Returns:
        0-d tensor: the element-wise (Frobenius) L2 norm taken over every
        selected entry of every row.
    """
    # NOTE(review): range(17, 25) stops at column 24 and excludes column 25;
    # confirm this upper bound is intentional.
    spine_indices = range(17, 25)
    spine_poses = poses[:, spine_indices]
    # Element-wise L2 norm; torch.linalg.norm(..., ord=2) on a 2-D tensor
    # would be the spectral norm (largest singular value) instead.
    return torch.linalg.vector_norm(spine_poses, ord=2)
def compute_pose_loss(poses, pose_init):
    """Return the L2 norm of all pose parameters except the first three.

    Args:
        poses: Pose tensor indexed as ``poses[:, i]``; presumably shaped
            (num_frames, num_pose_params) -- TODO confirm against callers.
        pose_init: Unused. It is kept only so existing callers that pass
            it keep working.

    Returns:
        0-d tensor: the element-wise (Frobenius) L2 norm of ``poses[:, 3:]``.
    """
    # The global rotation (first three columns) should not be constrained.
    body_poses = poses[:, 3:]
    # Element-wise L2 norm; torch.linalg.norm(..., ord=2) on a 2-D tensor
    # would be the spectral norm (largest singular value) instead.
    return torch.linalg.vector_norm(body_poses, ord=2)
def compute_anchor_pose(poses, pose_init):
    """Anchor the first three pose columns to their initial values.

    Only ``[:, :3]`` of each tensor is compared; every other column is
    ignored. These three columns are presumably the global rotation.

    Args:
        poses: Current pose tensor, indexed as ``poses[:, :3]``.
        pose_init: Reference pose tensor with matching leading columns.

    Returns:
        Mean-squared error between the two three-column slices.
    """
    current_root = poses[:, :3]
    initial_root = pose_init[:, :3]
    return torch.nn.functional.mse_loss(current_root, initial_root)
def compute_anchor_trans(trans, trans_init):
    """Anchor a translation tensor to its initial value.

    Args:
        trans: Current translation tensor.
        trans_init: Reference translation tensor with the same shape as
            ``trans`` (or one that broadcasts to it).

    Returns:
        Mean-squared error (mean over all elements) between the two.
    """
    return torch.nn.functional.mse_loss(input=trans, target=trans_init, reduction='mean')
def compute_time_loss(poses):
    """Return the L2 norm of the differences between consecutive poses.

    This acts as a smoothness penalty: it is zero when all rows along the
    first axis are equal.

    Args:
        poses: Pose tensor whose first axis indexes consecutive entries;
            presumably (num_frames, num_pose_params) -- TODO confirm.

    Returns:
        0-d tensor: the element-wise (Frobenius) L2 norm of
        ``poses[1:] - poses[:-1]``. A single-frame input gives an empty
        difference tensor, so the result is 0.
    """
    pose_delta = poses[1:] - poses[:-1]
    # Element-wise L2 norm over all deltas. torch.linalg.norm(..., ord=2)
    # on a 2-D tensor would be the spectral norm (largest singular value).
    # That form only penalises the dominant motion direction and does not
    # accept inputs of rank > 2.
    return torch.linalg.vector_norm(pose_delta, ord=2)
def pretty_loss_print(loss_dict):
    """Format named losses as one log line.

    The result has the form ``"<total> | <name1> <val1> | <name2> <val2>"``.
    ``<total>`` is the plain sum of all values in ``loss_dict``. Every number
    is rendered with four decimal places, and entries appear in dict
    iteration order. An empty dict yields ``"0.0000"``.

    Args:
        loss_dict: Mapping of loss name to a float or 0-d tensor.

    Returns:
        The formatted string.
    """
    total = sum(loss_dict.values())
    segments = [f'{total:.4f}']
    segments.extend(f'{name} {value:.4f}' for name, value in loss_dict.items())
    return ' | '.join(segments)
|