Spaces:
Sleeping
Sleeping
import torch | |
from lib.body_models.skel.joints_def import curve_torch_3d | |
from lib.body_models.skel.utils import axis_angle_to_matrix, euler_angles_to_matrix, rodrigues | |
class OsimJoint(torch.nn.Module):
    """Base joint module whose only behavior is a zero translation."""

    def __init__(self) -> None:
        """Initialize the underlying ``torch.nn.Module``; no state is added."""
        super().__init__()

    def q_to_translation(self, q, **kwargs):
        """Return an all-zero translation with one row per row of ``q``.

        Args:
            q: Tensor whose first dimension is used as the batch size.
            **kwargs: Accepted and ignored.

        Returns:
            A ``(q.shape[0], 3)`` tensor of zeros on ``q.device``. The dtype
            is torch's default dtype; it does not follow ``q.dtype``.
        """
        batch_size = q.shape[0]
        return torch.zeros(batch_size, 3, device=q.device)
class CustomJoint(OsimJoint):
    """Joint whose rotation is a composition of per-DOF axis rotations.

    Buffers:
        axis: float tensor built from ``axis``; row ``i`` is the axis of DOF ``i``.
        axis_flip: float tensor built from ``axis_flip``; entry ``i`` scales
            the angle of DOF ``i`` (presumably +1/-1 -- confirm with callers).
        nb_dof: integer tensor holding ``len(axis)``.
    """

    def __init__(self, axis, axis_flip) -> None:
        """Store the axes, the per-axis multipliers and the DOF count as buffers.

        Args:
            axis: Sequence with one rotation axis per degree of freedom.
            axis_flip: Sequence with one multiplier per degree of freedom.
        """
        super().__init__()
        axis_tensor = torch.FloatTensor(axis)
        flip_tensor = torch.FloatTensor(axis_flip)
        dof_count = torch.tensor(len(axis))
        self.register_buffer('axis', axis_tensor)
        self.register_buffer('axis_flip', flip_tensor)
        self.register_buffer('nb_dof', dof_count)

    def q_to_rot(self, q, **kwargs):
        """Compose the rotations of all degrees of freedom for each batch row.

        Args:
            q: Tensor indexed as ``q[:, i]`` for DOF ``i``; the first dimension
                is the batch.
            **kwargs: Accepted and ignored.

        Returns:
            The accumulated product ``R_{n-1} @ ... @ R_1 @ R_0`` where ``R_i``
            is ``axis_angle_to_matrix(q[:, i:i+1] * axis_flip[i] * axis[i])``.
            With zero DOFs this is a batch of 3x3 identities.
        """
        device = q.device
        batch_size = q.shape[0]
        # Start from a (batch, 3, 3) view of the identity in q's dtype.
        rotation = torch.eye(3, dtype=q.dtype).to(device).unsqueeze(0).expand(batch_size, 3, 3)
        for dof in range(self.nb_dof):
            direction = self.axis[dof].to(device)
            sign = self.axis_flip[dof].to(device)
            rot_vec = q[:, dof:dof + 1] * sign * direction
            # Left-multiply: each later DOF is applied on top of the earlier ones.
            rotation = torch.matmul(axis_angle_to_matrix(rot_vec), rotation)
        return rotation
class CustomJoint1D(OsimJoint):
    """Single-DOF joint rotating about one fixed, normalized axis.

    ``axis`` and ``axis_flip`` are registered as non-persistent buffers so
    they move with the module on ``.to()`` / ``.cuda()`` (like the buffers of
    ``CustomJoint``) instead of being copied to the input device on every
    call. ``persistent=False`` keeps them out of ``state_dict``, so existing
    checkpoints load unchanged.

    Attributes:
        axis: Float tensor, the input axis divided by its L2 norm.
        axis_flip: Float tensor built from ``axis_flip``; multiplies the angle.
        nb_dof: Plain Python int, always 1.
    """

    def __init__(self, axis, axis_flip) -> None:
        """Normalize and store the rotation axis and its multiplier.

        Args:
            axis: Sequence of floats giving the rotation axis. A zero-length
                axis is not checked and yields NaNs after normalization.
            axis_flip: Sequence of floats multiplied into the angle
                (presumably +1/-1 -- confirm with callers).
        """
        super().__init__()
        raw_axis = torch.FloatTensor(axis)
        unit_axis = raw_axis / torch.linalg.norm(raw_axis)
        self.register_buffer('axis', unit_axis, persistent=False)
        self.register_buffer('axis_flip', torch.FloatTensor(axis_flip), persistent=False)
        self.nb_dof = 1

    def q_to_rot(self, q, **kwargs):
        """Return the rotation for the first column of ``q``.

        Args:
            q: Tensor indexed as ``q[:, 0:1]``; the first dimension is the batch.
            **kwargs: Accepted and ignored.

        Returns:
            ``axis_angle_to_matrix(q[:, 0:1] * axis_flip * axis)``.
        """
        # .to() is a no-op once the module lives on q's device; it is kept so
        # a module left on another device than its input keeps working.
        axis = self.axis.to(q.device)
        flip = self.axis_flip.to(q.device)
        angle_axis = q[:, 0:1] * flip * axis
        return axis_angle_to_matrix(angle_axis)
class WalkerKnee(OsimJoint):
    """Knee joint, currently modeled as a single rotation about the z axis.

    Buffers:
        nb_dof: integer tensor holding 1.
    """

    def __init__(self) -> None:
        """Register the DOF count (1) as a buffer."""
        super().__init__()
        self.register_buffer('nb_dof', torch.tensor(1))

    def q_to_rot(self, q, **kwargs):
        """Return the rotation for angle ``-q[:, 0]`` about the z axis.

        Args:
            q: Tensor indexed as ``q[:, 0]``; the first dimension is the batch.
            **kwargs: Accepted and ignored.

        Returns:
            ``axis_angle_to_matrix`` applied to the ``(q.shape[0], 3)``
            axis-angle vectors ``[0, 0, -q[:, 0]]``.
        """
        # TODO: for now implement a basic knee
        # Allocate in q's floating dtype so a float64 (or half) input is not
        # silently cast to the default dtype; non-floating inputs keep the
        # previous behavior of using the default dtype.
        dtype = q.dtype if q.is_floating_point() else torch.get_default_dtype()
        theta = torch.zeros(q.shape[0], 3, dtype=dtype, device=q.device)
        theta[:, 2] = -q[:, 0]
        return axis_angle_to_matrix(theta)
class PinJoint(OsimJoint):
    """Single-DOF joint rotating about the z axis of an oriented parent frame.

    Buffers:
        parent_frame_ori: float tensor built from ``parent_frame_ori``; passed
            to ``euler_angles_to_matrix`` with convention ``'XYZ'``.
        nb_dof: integer tensor holding 1.
    """

    def __init__(self, parent_frame_ori) -> None:
        """Store the parent frame orientation and the DOF count as buffers.

        Args:
            parent_frame_ori: Sequence of floats giving the parent frame
                orientation (presumably three Euler angles -- confirm against
                ``euler_angles_to_matrix``).
        """
        super().__init__()
        orientation = torch.FloatTensor(parent_frame_ori)
        self.register_buffer('parent_frame_ori', orientation)
        self.register_buffer('nb_dof', torch.tensor(1))

    def q_to_rot(self, q, **kwargs):
        """Return the rotation of angle ``q[:, 0]`` about the pin axis.

        Args:
            q: Tensor indexed as ``q[:, 0:1]``; the first dimension is the batch.
            **kwargs: Accepted and ignored.

        Returns:
            ``axis_angle_to_matrix(q[:, 0:1] * axis)`` where ``axis`` is the
            parent-frame rotation matrix applied to ``[0, 0, 1]``.
        """
        device = q.device
        frame_rot = euler_angles_to_matrix(self.parent_frame_ori.to(device), 'XYZ')
        local_z = torch.FloatTensor([0, 0, 1]).to(device)
        # Pin axis = the frame's rotation applied to the unit z vector.
        pin_axis = torch.matmul(frame_rot, local_z)
        return axis_angle_to_matrix(q[:, 0:1] * pin_axis)
class ConstantCurvatureJoint(CustomJoint):
    """Named variant of ``CustomJoint``; adds no behavior of its own."""

    def __init__(self, **kwargs) -> None:
        """Forward all keyword arguments unchanged to ``CustomJoint``."""
        super().__init__(**kwargs)
class EllipsoidJoint(CustomJoint):
    """Named variant of ``CustomJoint``; adds no behavior of its own."""

    def __init__(self, **kwargs) -> None:
        """Forward all keyword arguments unchanged to ``CustomJoint``."""
        super().__init__(**kwargs)