File size: 3,398 Bytes
98a8f68
 
 
161b8e5
 
 
 
 
 
 
 
289635e
161b8e5
 
c065381
161b8e5
 
 
 
 
 
 
 
 
c065381
 
161b8e5
 
 
 
 
 
 
 
44f6fd9
289635e
98a8f68
 
289635e
 
 
c065381
 
161b8e5
 
 
 
98a8f68
161b8e5
 
 
 
 
 
 
 
 
 
98a8f68
161b8e5
 
 
 
c065381
98a8f68
289635e
c065381
98a8f68
c065381
289635e
 
98a8f68
289635e
 
98a8f68
 
289635e
 
 
98a8f68
289635e
 
98a8f68
 
289635e
 
 
c065381
 
 
 
 
 
 
 
 
 
161b8e5
98a8f68
 
 
 
44f6fd9
98a8f68
 
 
 
 
 
c065381
 
98a8f68
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import tempfile
import os
import shutil
import traceback
import spaces
import numpy as np
import torch
import torch.nn.functional as F
from evo.tools.file_interface import read_kitti_poses_file
from pathlib import Path
import rerun as rr
from typing import Optional, Dict
from visualization.logger import SimulationLogger
from scipy.spatial.transform import Rotation


def load_trajectory_data(traj_file: str, char_file: str) -> Dict:
    """Load a KITTI-format pose file and a NumPy character-feature file.

    Args:
        traj_file: Path to a pose file readable by evo's
            ``read_kitti_poses_file``.
        char_file: Path to an array file readable by ``np.load``.

    Returns:
        A dict with, in this order:
          - ``"traj_filename"``: base name of ``traj_file``.
          - ``"char_filename"``: base name of ``char_file``.
          - ``"char_feat"``: the loaded feature array as a float32 tensor.
          - ``"matrix_trajectory"``: ``poses_se3`` stacked into one array and
            converted to a float32 tensor (presumably shape (N, 4, 4) —
            TODO confirm against evo's pose representation).
    """

    def to_float_tensor(array: np.ndarray) -> torch.Tensor:
        # Wrap the NumPy buffer, then cast to float32.
        return torch.from_numpy(array).to(torch.float32)

    kitti_traj = read_kitti_poses_file(traj_file)
    poses = to_float_tensor(np.array(kitti_traj.poses_se3))
    features = to_float_tensor(np.load(char_file))

    return dict(
        traj_filename=Path(traj_file).name,
        char_filename=Path(char_file).name,
        char_feat=features,
        matrix_trajectory=poses,
    )


class ETLogger(SimulationLogger):
    """Rerun logger for E.T. camera trajectories and character point features.

    Constructing an instance calls ``rr.init("et_visualization")`` and logs a
    right-handed, Y-up view-coordinate system on the ``world`` entity.

    NOTE(review): newer Rerun releases deprecate ``timeless=`` in favour of
    ``static=``; confirm against the pinned rerun version before changing.
    """

    # Size of the pinhole image logged for every frame. The principal point
    # in ``K`` (320, 240) is the centre of this image.
    _IMAGE_WIDTH = 640
    _IMAGE_HEIGHT = 480

    # RGBA colours (components in [0, 1]).
    _TRAJECTORY_COLOR = (0.0, 0.8, 0.8, 1.0)
    _CHARACTER_COLOR = (0.8, 0.2, 0.2, 1.0)

    def __init__(self):
        super().__init__()
        rr.init("et_visualization")
        rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)

        # Pinhole intrinsics: fx = fy = 500 px, principal point (320, 240).
        self.K = np.array([
            [500, 0, 320],
            [0, 500, 240],
            [0, 0, 1],
        ])

    def log_trajectory(self, trajectory: np.ndarray):
        """Log a camera trajectory to Rerun.

        Args:
            trajectory: Stacked homogeneous poses, indexed as
                ``trajectory[k, :3, :3]`` (rotation) and
                ``trajectory[k, :3, 3]`` (translation); presumably shape
                (N, 4, 4) — TODO confirm.

        The camera positions are logged once (timeless) as points and as
        two-point segments between consecutive positions. Then, for each frame
        ``k`` on the ``frame_idx`` timeline, the camera transform and a pinhole
        using ``self.K`` are logged.
        """
        positions = trajectory[:, :3, 3]
        num_frames = len(positions)

        rr.log(
            "world/trajectory/points",
            rr.Points3D(
                positions,
                colors=np.full((num_frames, 4), self._TRAJECTORY_COLOR),
            ),
            timeless=True,
        )

        if num_frames > 1:
            # One two-point strip per consecutive pair of positions.
            segments = np.stack([positions[:-1], positions[1:]], axis=1)
            rr.log(
                "world/trajectory/line",
                rr.LineStrips3D(segments, colors=[self._TRAJECTORY_COLOR]),
                timeless=True,
            )

        if num_frames == 0:
            # Nothing to animate; also keeps the batched conversion below from
            # being called with zero matrices.
            return

        # Convert every rotation block in one batched scipy call instead of
        # constructing a Rotation per frame. as_quat() is scalar-last (x, y, z, w),
        # which is what rr.Quaternion(xyzw=...) expects.
        quaternions = Rotation.from_matrix(trajectory[:, :3, :3]).as_quat()

        for frame_idx, (translation, quat) in enumerate(zip(positions, quaternions)):
            rr.set_time_sequence("frame_idx", frame_idx)

            rr.log(
                "world/camera",
                rr.Transform3D(
                    translation=translation,
                    rotation=rr.Quaternion(xyzw=quat),
                ),
            )

            rr.log(
                "world/camera/image",
                rr.Pinhole(
                    image_from_camera=self.K,
                    width=self._IMAGE_WIDTH,
                    height=self._IMAGE_HEIGHT,
                ),
            )

    def log_character(self, char_feature: np.ndarray):
        """Log character features as a timeless red point cloud.

        Args:
            char_feature: Array whose size is a multiple of 3; it is reshaped to
                (-1, 3) and each row is treated as one 3D point (presumably
                flattened xyz coordinates — TODO confirm the feature layout).
        """
        points = char_feature.reshape(-1, 3)
        rr.log(
            "world/character",
            rr.Points3D(
                points,
                colors=np.full((len(points), 4), self._CHARACTER_COLOR),
            ),
            timeless=True,
        )


@spaces.GPU
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
    """Render an E.T. trajectory and its character features to a Rerun file.

    Args:
        traj_file: KITTI-format pose file passed to ``load_trajectory_data``.
        char_file: Character-feature file passed to ``load_trajectory_data``.

    Returns:
        Path to ``et_visualization.rrd`` inside a fresh temporary directory, or
        ``None`` if any step raised. On success the directory is not removed
        here, so the returned file stays readable.
    """
    temp_dir = None
    try:
        data = load_trajectory_data(traj_file, char_file)

        temp_dir = tempfile.mkdtemp()
        rrd_path = os.path.join(temp_dir, "et_visualization.rrd")

        # The constructor starts a new Rerun recording (rr.init). Attach the
        # file sink right away so the bulk data streams to disk instead of
        # being held in memory until the end.
        logger = ETLogger()
        rr.save(rrd_path)

        logger.log_trajectory(data["matrix_trajectory"].numpy())
        logger.log_character(data["char_feat"].numpy())

        return rrd_path

    except Exception as e:
        # Deliberate catch-all boundary: report the failure (with traceback for
        # debugging) and signal it with None instead of propagating.
        print(f"Error visualizing E.T. data: {str(e)}")
        traceback.print_exc()
        if temp_dir is not None:
            # Don't leak the (partial) output directory on failure.
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None