File size: 2,935 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
from lib.body_models.skel.joints_def import curve_torch_3d
from lib.body_models.skel.utils import axis_angle_to_matrix, euler_angles_to_matrix, rodrigues

class OsimJoint(torch.nn.Module):
    """Common base class for the joints defined in this module.

    Supplies a default translation of zero; the subclasses in this module
    add a ``q_to_rot`` method describing the joint rotation.
    """

    def __init__(self) -> None:
        super().__init__()

    def q_to_translation(self, q, **kwargs):
        """Return an all-zero translation for each row of ``q``.

        Args:
            q: tensor whose first dimension is the batch size.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            Tensor of shape ``(q.shape[0], 3)`` filled with zeros, in the
            default floating dtype, on ``q``'s device.
        """
        n_rows = q.shape[0]
        return torch.zeros(n_rows, 3, device=q.device)


class CustomJoint(OsimJoint):
    """Joint whose rotation is a chain of per-DoF axis-angle rotations.

    Args:
        axis: one rotation axis per degree of freedom; stored as a float32
            buffer. The axes are used as given (no normalisation).
            NOTE(review): presumably callers pass unit vectors, otherwise the
            effective angle is scaled by the axis length — confirm.
        axis_flip: one multiplier per degree of freedom applied to the
            coordinate before it is turned into an angle; stored as a float32
            buffer.
    """

    def __init__(self, axis, axis_flip) -> None:
        super().__init__()
        # Buffers travel with module.to()/.cuda() and are saved in state_dict.
        self.register_buffer('axis', torch.FloatTensor(axis))
        self.register_buffer('axis_flip', torch.FloatTensor(axis_flip))
        self.register_buffer('nb_dof', torch.tensor(len(axis)))

    def q_to_rot(self, q, **kwargs):
        """Compose the per-DoF rotations for coordinates ``q``.

        For each DoF ``i`` the rotation ``R_i`` is
        ``axis_angle_to_matrix(q[:, i:i+1] * axis_flip[i] * axis[i])``, and
        the result is ``R_{n-1} @ ... @ R_1 @ R_0`` (each new rotation is
        multiplied on the left).

        Args:
            q: tensor indexed as ``q[:, i]`` for ``i < nb_dof``; presumably
                ``(batch, nb_dof)`` — TODO confirm.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            Batched 3x3 rotation matrices, starting from an identity of
            ``q``'s dtype expanded to ``(batch, 3, 3)``.
        """
        batch_size = q.shape[0]
        eye = torch.eye(3, dtype=q.dtype).to(q.device)
        rot = eye.unsqueeze(0).expand(batch_size, 3, 3)
        for dof in range(self.nb_dof):
            flipped_angle = q[:, dof:dof + 1] * self.axis_flip[dof].to(q.device)
            rot_dof = axis_angle_to_matrix(flipped_angle * self.axis[dof].to(q.device))
            # Left-multiply: later DoFs are applied after earlier ones.
            rot = torch.matmul(rot_dof, rot)
        return rot



class CustomJoint1D(OsimJoint):
    """One-DoF joint rotating about a single fixed axis.

    Args:
        axis: 3-vector giving the rotation axis; it is normalised to unit
            length here.
        axis_flip: multiplier applied to the coordinate, e.g. ``[1]`` or
            ``[-1]``; a plain number is accepted as well.

    Raises:
        ValueError: if ``axis`` has zero length (normalising it would
            produce NaNs).
    """

    def __init__(self, axis, axis_flip) -> None:
        super().__init__()
        # torch.tensor instead of the legacy torch.FloatTensor constructor:
        # FloatTensor(n) with a bare integer n allocates an *uninitialised*
        # tensor of size n rather than wrapping the value n.
        axis = torch.tensor(axis, dtype=torch.float32)
        norm = torch.linalg.norm(axis)
        if norm == 0:
            raise ValueError("CustomJoint1D axis must have non-zero length")
        # Plain attributes (not buffers): they are not in state_dict and are
        # not moved by module.to(); q_to_rot moves them to q's device per call.
        self.axis = axis / norm
        self.axis_flip = torch.tensor(axis_flip, dtype=torch.float32)
        self.nb_dof = 1

    def q_to_rot(self, q, **kwargs):
        """Rotate by ``q[:, 0] * axis_flip`` about the unit ``axis``.

        Args:
            q: tensor whose first column is the joint coordinate; presumably
                ``(batch, n)`` with ``n >= 1`` — TODO confirm.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``axis_angle_to_matrix`` of the axis-angle vectors
            ``q[:, 0:1] * axis_flip * axis`` (assumed ``(batch, 3, 3)``).
        """
        axis = self.axis.to(q.device)
        angle_axis = q[:, 0:1] * self.axis_flip.to(q.device) * axis
        return axis_angle_to_matrix(angle_axis)


class WalkerKnee(OsimJoint):
    """Simplified one-DoF knee: a rotation by ``-q[:, 0]`` about the z axis.

    Only the rotation is modelled here; ``q_to_translation`` is inherited
    unchanged from the base class.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registered as a buffer (persistent, so it is part of state_dict).
        self.register_buffer('nb_dof', torch.tensor(1))

    def q_to_rot(self, q, **kwargs):
        """Return the knee rotation for coordinate ``q[:, 0]``.

        Args:
            q: tensor whose first column is the knee coordinate; presumably
                ``(batch, n)`` with ``n >= 1`` — TODO confirm.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``axis_angle_to_matrix`` applied to the axis-angle vectors
            ``(0, 0, -q[:, 0])`` (assumed ``(batch, 3, 3)``).
        """
        # TODO: for now this implements a basic knee (single rotation about -z).
        # Build the axis-angle vectors in q's floating dtype so float64 input is
        # not silently downcast to the default dtype (CustomJoint.q_to_rot also
        # follows q.dtype). Non-floating q keeps the default-dtype behaviour.
        dtype = q.dtype if q.is_floating_point() else torch.get_default_dtype()
        theta_i = torch.zeros(q.shape[0], 3, dtype=dtype, device=q.device)
        theta_i[:, 2] = -q[:, 0]
        return axis_angle_to_matrix(theta_i)

class PinJoint(OsimJoint):
    """One-DoF hinge about the z axis of a fixed, rotated parent frame.

    Args:
        parent_frame_ori: three Euler angles, converted with the 'XYZ'
            convention, giving the parent frame orientation; stored as a
            float32 buffer. Presumably radians — TODO confirm.
    """

    def __init__(self, parent_frame_ori) -> None:
        super().__init__()
        self.register_buffer('parent_frame_ori', torch.FloatTensor(parent_frame_ori))
        self.register_buffer('nb_dof', torch.tensor(1))

    def q_to_rot(self, q, **kwargs):
        """Rotate by ``q[:, 0]`` about the parent frame's z axis.

        Args:
            q: tensor whose first column is the joint angle; presumably
                ``(batch, n)`` with ``n >= 1`` — TODO confirm.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``axis_angle_to_matrix(q[:, 0:1] * axis)``, where ``axis`` is the
            parent frame's z axis (assumed output ``(batch, 3, 3)``).
        """
        parent_ori = self.parent_frame_ori.to(q.device)
        Ra_i = euler_angles_to_matrix(parent_ori, 'XYZ')

        # Ra_i @ [0, 0, 1] is exactly the third column of Ra_i. Taking the
        # column keeps the axis in Ra_i's dtype; multiplying by a hard-coded
        # float32 z vector fails once the buffers are float64 (e.g. after
        # module.double()), because torch.matmul does not type-promote.
        axis = Ra_i[..., 2]

        axis_angle = q[:, 0:1] * axis
        return axis_angle_to_matrix(axis_angle)


class ConstantCurvatureJoint(CustomJoint):
    """Joint type that currently behaves exactly like ``CustomJoint``.

    Every keyword argument is forwarded unchanged to ``CustomJoint.__init__``
    (i.e. ``axis`` and ``axis_flip``). No curvature-specific logic is
    implemented in this class.
    """

    def __init__(self, **kwargs) -> None:
        super(ConstantCurvatureJoint, self).__init__(**kwargs)



class EllipsoidJoint(CustomJoint):
    """Joint type that currently behaves exactly like ``CustomJoint``.

    Every keyword argument is forwarded unchanged to ``CustomJoint.__init__``
    (i.e. ``axis`` and ``axis_flip``). No ellipsoid-specific logic is
    implemented in this class.
    """

    def __init__(self, **kwargs) -> None:
        super(EllipsoidJoint, self).__init__(**kwargs)