Spaces:
Running
on
Zero
Running
on
Zero
update et visualization
Browse files- visualization/et_visualizer.py +99 -41
visualization/et_visualizer.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1 |
-
import tempfile
|
2 |
-
import os
|
3 |
-
import spaces
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
@@ -9,6 +6,8 @@ from pathlib import Path
|
|
9 |
import rerun as rr
|
10 |
from typing import Optional, Dict
|
11 |
from visualization.logger import SimulationLogger
|
|
|
|
|
12 |
|
13 |
|
14 |
def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
|
@@ -55,77 +54,136 @@ class ETLogger(SimulationLogger):
|
|
55 |
rr.init("et_visualization")
|
56 |
rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
|
59 |
-
"""Log camera trajectory."""
|
60 |
valid_frames = int(padding_mask.sum())
|
61 |
valid_trajectory = trajectory[:valid_frames]
|
62 |
|
63 |
-
# Log trajectory points
|
64 |
positions = valid_trajectory[:, :3, 3]
|
|
|
|
|
|
|
|
|
|
|
65 |
rr.log(
|
66 |
"world/trajectory/points",
|
67 |
rr.Points3D(
|
68 |
positions,
|
69 |
-
colors=
|
70 |
),
|
71 |
timeless=True
|
72 |
)
|
73 |
|
74 |
-
# Log trajectory line
|
75 |
if len(positions) > 1:
|
76 |
lines = np.stack([positions[:-1], positions[1:]], axis=1)
|
|
|
|
|
|
|
|
|
|
|
77 |
rr.log(
|
78 |
"world/trajectory/line",
|
79 |
rr.LineStrips3D(
|
80 |
lines,
|
81 |
-
colors=
|
82 |
),
|
83 |
timeless=True
|
84 |
)
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
|
87 |
-
"""Log character feature visualization."""
|
88 |
valid_frames = int(padding_mask.sum())
|
89 |
valid_char = char_feature[:, :valid_frames]
|
90 |
|
91 |
if valid_char.shape[0] > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
rr.log(
|
93 |
"world/character",
|
94 |
rr.Points3D(
|
95 |
valid_char.reshape(-1, 3),
|
96 |
-
colors=
|
97 |
-
|
98 |
),
|
99 |
timeless=True
|
100 |
)
|
101 |
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
data["padding_mask"].numpy()
|
123 |
-
)
|
124 |
-
|
125 |
-
# Save visualization
|
126 |
-
rr.save(rrd_path)
|
127 |
-
return rrd_path
|
128 |
-
|
129 |
-
except Exception as e:
|
130 |
-
print(f"Error visualizing E.T. data: {str(e)}")
|
131 |
-
return None
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
|
|
6 |
import rerun as rr
|
7 |
from typing import Optional, Dict
|
8 |
from visualization.logger import SimulationLogger
|
9 |
+
from scipy.spatial.transform import Rotation
|
10 |
+
from rerun.components import Material
|
11 |
|
12 |
|
13 |
def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
|
|
|
54 |
rr.init("et_visualization")
|
55 |
rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
|
56 |
|
57 |
+
# Define default camera parameters
|
58 |
+
self.camera_width = 640 # default width
|
59 |
+
self.camera_height = 480 # default height
|
60 |
+
self.focal_length = 500 # default focal length
|
61 |
+
self.K = np.array([
|
62 |
+
[self.focal_length, 0, self.camera_width/2],
|
63 |
+
[0, self.focal_length, self.camera_height/2],
|
64 |
+
[0, 0, 1]
|
65 |
+
])
|
66 |
+
|
67 |
def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
|
68 |
+
"""Log camera trajectory with enhanced visualization."""
|
69 |
valid_frames = int(padding_mask.sum())
|
70 |
valid_trajectory = trajectory[:valid_frames]
|
71 |
|
72 |
+
# Log trajectory points with rainbow coloring
|
73 |
positions = valid_trajectory[:, :3, 3]
|
74 |
+
colors = np.zeros((len(positions), 4))
|
75 |
+
colors[:, :3] = plt.cm.rainbow(
|
76 |
+
np.linspace(0, 1, len(positions)))[:, :3]
|
77 |
+
colors[:, 3] = 1.0 # Set alpha to 1
|
78 |
+
|
79 |
rr.log(
|
80 |
"world/trajectory/points",
|
81 |
rr.Points3D(
|
82 |
positions,
|
83 |
+
colors=colors
|
84 |
),
|
85 |
timeless=True
|
86 |
)
|
87 |
|
88 |
+
# Log trajectory line with gradient color
|
89 |
if len(positions) > 1:
|
90 |
lines = np.stack([positions[:-1], positions[1:]], axis=1)
|
91 |
+
line_colors = np.zeros((len(lines), 4))
|
92 |
+
line_colors[:, :3] = plt.cm.rainbow(
|
93 |
+
np.linspace(0, 1, len(lines)))[:, :3]
|
94 |
+
line_colors[:, 3] = 1.0
|
95 |
+
|
96 |
rr.log(
|
97 |
"world/trajectory/line",
|
98 |
rr.LineStrips3D(
|
99 |
lines,
|
100 |
+
colors=line_colors
|
101 |
),
|
102 |
timeless=True
|
103 |
)
|
104 |
|
105 |
+
# Log camera frustums
|
106 |
+
for i in range(valid_frames):
|
107 |
+
# Get camera position and rotation
|
108 |
+
translation = valid_trajectory[i, :3, 3]
|
109 |
+
rotation_matrix = valid_trajectory[i, :3, :3]
|
110 |
+
rotation_quat = Rotation.from_matrix(rotation_matrix).as_quat()
|
111 |
+
|
112 |
+
# Set time sequence for animation
|
113 |
+
rr.set_time_sequence("frame_idx", i)
|
114 |
+
|
115 |
+
# Log camera frustum
|
116 |
+
rr.log(
|
117 |
+
f"world/cameras/camera_{i}",
|
118 |
+
rr.Transform3D(
|
119 |
+
translation=translation,
|
120 |
+
rotation=rr.Quaternion(xyzw=rotation_quat),
|
121 |
+
)
|
122 |
+
)
|
123 |
+
|
124 |
+
# Add camera visualization
|
125 |
+
rr.log(
|
126 |
+
f"world/cameras/camera_{i}/frustum",
|
127 |
+
rr.Pinhole(
|
128 |
+
image_from_camera=self.K,
|
129 |
+
width=self.camera_width,
|
130 |
+
height=self.camera_height,
|
131 |
+
focal_length=self.focal_length,
|
132 |
+
),
|
133 |
+
)
|
134 |
+
|
135 |
+
# Add coordinate axes for each camera
|
136 |
+
rr.log(
|
137 |
+
f"world/cameras/camera_{i}/axes",
|
138 |
+
rr.Arrows3D(
|
139 |
+
origins=np.zeros((3, 3)),
|
140 |
+
vectors=np.eye(3) * 0.5, # 0.5 meter long axes
|
141 |
+
colors=[[1, 0, 0, 1], [0, 1, 0, 1], [
|
142 |
+
0, 0, 1, 1]] # RGB colors for XYZ
|
143 |
+
)
|
144 |
+
)
|
145 |
+
|
146 |
def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
|
147 |
+
"""Log character feature visualization with enhanced appearance."""
|
148 |
valid_frames = int(padding_mask.sum())
|
149 |
valid_char = char_feature[:, :valid_frames]
|
150 |
|
151 |
if valid_char.shape[0] > 0:
|
152 |
+
# Create gradient colors for character points
|
153 |
+
num_points = valid_char.reshape(-1, 3).shape[0]
|
154 |
+
colors = np.zeros((num_points, 4))
|
155 |
+
colors[:, 0] = 0.8 # Red component
|
156 |
+
colors[:, 1] = 0.2 # Green component
|
157 |
+
colors[:, 2] = np.linspace(0.2, 0.8, num_points) # Blue gradient
|
158 |
+
colors[:, 3] = 1.0 # Alpha
|
159 |
+
|
160 |
rr.log(
|
161 |
"world/character",
|
162 |
rr.Points3D(
|
163 |
valid_char.reshape(-1, 3),
|
164 |
+
colors=colors,
|
165 |
+
radii=0.05 # Add point size for better visibility
|
166 |
),
|
167 |
timeless=True
|
168 |
)
|
169 |
|
170 |
+
# Add a semi-transparent hull around character points
|
171 |
+
try:
|
172 |
+
from scipy.spatial import ConvexHull
|
173 |
+
points = valid_char.reshape(-1, 3)
|
174 |
+
hull = ConvexHull(points)
|
175 |
+
|
176 |
+
rr.log(
|
177 |
+
"world/character/hull",
|
178 |
+
rr.Mesh3D(
|
179 |
+
vertex_positions=points[hull.vertices],
|
180 |
+
indices=hull.simplices,
|
181 |
+
mesh_material=Material(
|
182 |
+
# Semi-transparent red
|
183 |
+
albedo_factor=[0.8, 0.2, 0.2, 0.3]
|
184 |
+
)
|
185 |
+
),
|
186 |
+
timeless=True
|
187 |
+
)
|
188 |
+
except Exception:
|
189 |
+
pass # Skip hull visualization if it fails
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|