HSMR / lib /body_models /skel /osim_rot.py
IsshikiHugh's picture
feat: CPU demo
5ac1897
import torch
from lib.body_models.skel.joints_def import curve_torch_3d
from lib.body_models.skel.utils import axis_angle_to_matrix, euler_angles_to_matrix, rodrigues
class OsimJoint(torch.nn.Module):
    """Base class for the joint types defined in this module.

    Subclasses implement ``q_to_rot``; this base only provides a default
    ``q_to_translation`` that yields zero translation.
    """

    def __init__(self) -> None:
        super().__init__()

    def q_to_translation(self, q, **kwargs):
        """Return a zero translation for every row of ``q``.

        Args:
            q: Batched joint coordinates. Only ``q.shape[0]`` (batch size),
                ``q.device`` and ``q.dtype`` are used.
            **kwargs: Ignored; accepted so every joint shares one call signature.

        Returns:
            A ``(q.shape[0], 3)`` tensor of zeros on ``q``'s device with
            ``q``'s dtype.
        """
        # Allocate directly on the target device and in q's dtype. This avoids a
        # CPU allocation plus a host->device copy, and float64 inputs no longer
        # get a float32 translation back.
        return torch.zeros(q.shape[0], 3, dtype=q.dtype, device=q.device)
class CustomJoint(OsimJoint):
    """Joint whose rotation is the ordered product of one rotation per DoF.

    Args:
        axis: Sequence with one axis per degree of freedom (presumably
            3-vectors; TODO confirm against callers). Its length sets ``nb_dof``.
        axis_flip: Sequence with one multiplier per degree of freedom, applied
            to that DoF's coordinate.
    """

    def __init__(self, axis, axis_flip) -> None:
        super().__init__()
        axis_tensor = torch.FloatTensor(axis)
        self.register_buffer('axis', axis_tensor)
        flip_tensor = torch.FloatTensor(axis_flip)
        self.register_buffer('axis_flip', flip_tensor)
        dof_count = torch.tensor(len(axis))
        self.register_buffer('nb_dof', dof_count)

    def q_to_rot(self, q, **kwargs):
        """Compose the per-DoF rotations for a batch of joint coordinates.

        Args:
            q: Batched joint coordinates; column ``i`` drives DoF ``i``.
            **kwargs: Ignored.

        Returns:
            A ``(q.shape[0], 3, 3)`` tensor equal to ``R_{n-1} @ ... @ R_0``,
            where ``R_i = axis_angle_to_matrix(q[:, i:i+1] * axis_flip[i] * axis[i])``.
            With zero DoFs the identity, expanded over the batch, is returned.
        """
        batch_size = q.shape[0]
        identity = torch.eye(3, dtype=q.dtype).to(q.device)
        rotation = identity.unsqueeze(0).expand(batch_size, 3, 3)
        for dof in range(int(self.nb_dof)):
            dof_axis = self.axis[dof].to(q.device)
            dof_flip = self.axis_flip[dof].to(q.device)
            dof_rotation = axis_angle_to_matrix(q[:, dof:dof + 1] * dof_flip * dof_axis)
            # Each new DoF rotation is left-multiplied onto the running product.
            rotation = torch.matmul(dof_rotation, rotation)
        return rotation
class CustomJoint1D(OsimJoint):
    """Single-DoF joint rotating about a fixed axis normalized at construction.

    Args:
        axis: Rotation axis (presumably a 3-vector; TODO confirm). It is divided
            by its Euclidean norm here.
        axis_flip: Multiplier applied to the coordinate; broadcast against ``axis``.
    """

    def __init__(self, axis, axis_flip) -> None:
        super().__init__()
        axis_tensor = torch.FloatTensor(axis)
        # NOTE(review): a zero-length axis makes this division produce NaNs.
        axis_tensor = axis_tensor / torch.linalg.norm(axis_tensor)
        # Register as buffers (like the sibling joints) so Module.to()/.cuda()
        # moves them along with the module, instead of copying them to the
        # device on every q_to_rot call. persistent=False keeps them out of
        # state_dict, so checkpoint keys are unchanged.
        self.register_buffer('axis', axis_tensor, persistent=False)
        self.register_buffer('axis_flip', torch.FloatTensor(axis_flip), persistent=False)
        # Kept as a plain int (siblings use a tensor buffer) for compatibility.
        self.nb_dof = 1

    def q_to_rot(self, q, **kwargs):
        """Return the rotation for coordinate ``q[:, 0]``.

        Args:
            q: Batched joint coordinates; only column 0 is used.
            **kwargs: Ignored.

        Returns:
            ``axis_angle_to_matrix(q[:, 0:1] * axis_flip * axis)``.
        """
        # .to() is a no-op once the module lives on q's device; it is kept as a
        # safeguard for when the module was not moved.
        axis = self.axis.to(q.device)
        angle_axis = q[:, 0:1] * self.axis_flip.to(q.device) * axis
        return axis_angle_to_matrix(angle_axis)
class WalkerKnee(OsimJoint):
    """Simplified single-DoF knee rotating about the z axis by ``-q[:, 0]``."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer('nb_dof', torch.tensor(1))

    def q_to_rot(self, q, **kwargs):
        """Return the knee rotation for a batch of coordinates.

        Args:
            q: Batched joint coordinates; only column 0 is used.
            **kwargs: Ignored.

        Returns:
            ``axis_angle_to_matrix`` of the axis-angle vectors ``(0, 0, -q[:, 0])``.
        """
        # TODO: basic placeholder knee model for now.
        # Allocate in q's dtype/device: a default float32 buffer would silently
        # truncate float64 angles on the assignment below.
        theta = torch.zeros(q.shape[0], 3, dtype=q.dtype, device=q.device)
        theta[:, 2] = -q[:, 0]
        return axis_angle_to_matrix(theta)
class PinJoint(OsimJoint):
    """Single-DoF hinge about the z axis of a rotated parent frame.

    Args:
        parent_frame_ori: Euler angles of the parent frame orientation,
            interpreted with the 'XYZ' convention by ``euler_angles_to_matrix``.
    """

    def __init__(self, parent_frame_ori) -> None:
        super().__init__()
        self.register_buffer('parent_frame_ori', torch.FloatTensor(parent_frame_ori))
        self.register_buffer('nb_dof', torch.tensor(1))

    def q_to_rot(self, q, **kwargs):
        """Return the hinge rotation for a batch of coordinates.

        Args:
            q: Batched joint coordinates; only column 0 is used.
            **kwargs: Ignored.

        Returns:
            ``axis_angle_to_matrix(q[:, 0:1] * axis)``, where ``axis`` is the
            parent-frame rotation applied to ``(0, 0, 1)``.
        """
        frame_euler = self.parent_frame_ori.to(q.device)
        frame_rot = euler_angles_to_matrix(frame_euler, 'XYZ')
        # Build z in frame_rot's dtype and device. torch.matmul does not
        # type-promote, so a hard-coded float32 vector fails once the module's
        # buffers are cast (e.g. after .double()).
        z_axis = frame_rot.new_tensor([0.0, 0.0, 1.0])
        axis = torch.matmul(frame_rot, z_axis).to(q.device)
        return axis_angle_to_matrix(q[:, 0:1] * axis)
class ConstantCurvatureJoint(CustomJoint):
    """Named variant of ``CustomJoint``; currently adds no behavior of its own.

    Only keyword arguments are accepted; they are forwarded to ``CustomJoint``.
    """

    def __init__(self, **joint_kwargs) -> None:
        super(ConstantCurvatureJoint, self).__init__(**joint_kwargs)
class EllipsoidJoint(CustomJoint):
    """Named variant of ``CustomJoint``; currently adds no behavior of its own.

    Only keyword arguments are accepted; they are forwarded to ``CustomJoint``.
    """

    def __init__(self, **joint_kwargs) -> None:
        super(EllipsoidJoint, self).__init__(**joint_kwargs)