abreza commited on
Commit
289635e
·
1 Parent(s): 161b8e5

update et visualization

Browse files
Files changed (1) hide show
  1. visualization/et_visualizer.py +99 -41
visualization/et_visualizer.py CHANGED
@@ -1,6 +1,3 @@
1
- import tempfile
2
- import os
3
- import spaces
4
  import numpy as np
5
  import torch
6
  import torch.nn.functional as F
@@ -9,6 +6,8 @@ from pathlib import Path
9
  import rerun as rr
10
  from typing import Optional, Dict
11
  from visualization.logger import SimulationLogger
 
 
12
 
13
 
14
  def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
@@ -55,77 +54,136 @@ class ETLogger(SimulationLogger):
55
  rr.init("et_visualization")
56
  rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
57
 
 
 
 
 
 
 
 
 
 
 
58
  def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
59
- """Log camera trajectory."""
60
  valid_frames = int(padding_mask.sum())
61
  valid_trajectory = trajectory[:valid_frames]
62
 
63
- # Log trajectory points
64
  positions = valid_trajectory[:, :3, 3]
 
 
 
 
 
65
  rr.log(
66
  "world/trajectory/points",
67
  rr.Points3D(
68
  positions,
69
- colors=np.full((len(positions), 4), [0.0, 0.8, 0.8, 1.0])
70
  ),
71
  timeless=True
72
  )
73
 
74
- # Log trajectory line
75
  if len(positions) > 1:
76
  lines = np.stack([positions[:-1], positions[1:]], axis=1)
 
 
 
 
 
77
  rr.log(
78
  "world/trajectory/line",
79
  rr.LineStrips3D(
80
  lines,
81
- colors=[(0.0, 0.8, 0.8, 1.0)]
82
  ),
83
  timeless=True
84
  )
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
87
- """Log character feature visualization."""
88
  valid_frames = int(padding_mask.sum())
89
  valid_char = char_feature[:, :valid_frames]
90
 
91
  if valid_char.shape[0] > 0:
 
 
 
 
 
 
 
 
92
  rr.log(
93
  "world/character",
94
  rr.Points3D(
95
  valid_char.reshape(-1, 3),
96
- colors=np.full(
97
- (valid_char.reshape(-1, 3).shape[0], 4), [0.8, 0.2, 0.2, 1.0])
98
  ),
99
  timeless=True
100
  )
101
 
102
-
103
- @spaces.GPU
104
- def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
105
- """Visualize E.T. dataset using Rerun."""
106
- try:
107
- # Load data
108
- data = load_trajectory_data(traj_file, char_file)
109
-
110
- # Create temporary file for RRD
111
- temp_dir = tempfile.mkdtemp()
112
- rrd_path = os.path.join(temp_dir, "et_visualization.rrd")
113
-
114
- # Initialize logger and log data
115
- logger = ETLogger()
116
- logger.log_trajectory(
117
- data["raw_matrix_trajectory"].numpy(),
118
- data["padding_mask"].numpy()
119
- )
120
- logger.log_character(
121
- data["char_feat"].numpy(),
122
- data["padding_mask"].numpy()
123
- )
124
-
125
- # Save visualization
126
- rr.save(rrd_path)
127
- return rrd_path
128
-
129
- except Exception as e:
130
- print(f"Error visualizing E.T. data: {str(e)}")
131
- return None
 
 
 
 
1
  import numpy as np
2
  import torch
3
  import torch.nn.functional as F
 
6
  import rerun as rr
7
  from typing import Optional, Dict
8
  from visualization.logger import SimulationLogger
9
+ from scipy.spatial.transform import Rotation
10
+ from rerun.components import Material
11
 
12
 
13
  def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
 
54
  rr.init("et_visualization")
55
  rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
56
 
57
+ # Define default camera parameters
58
+ self.camera_width = 640 # default width
59
+ self.camera_height = 480 # default height
60
+ self.focal_length = 500 # default focal length
61
+ self.K = np.array([
62
+ [self.focal_length, 0, self.camera_width/2],
63
+ [0, self.focal_length, self.camera_height/2],
64
+ [0, 0, 1]
65
+ ])
66
+
67
  def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
68
+ """Log camera trajectory with enhanced visualization."""
69
  valid_frames = int(padding_mask.sum())
70
  valid_trajectory = trajectory[:valid_frames]
71
 
72
+ # Log trajectory points with rainbow coloring
73
  positions = valid_trajectory[:, :3, 3]
74
+ colors = np.zeros((len(positions), 4))
75
+ colors[:, :3] = plt.cm.rainbow(
76
+ np.linspace(0, 1, len(positions)))[:, :3]
77
+ colors[:, 3] = 1.0 # Set alpha to 1
78
+
79
  rr.log(
80
  "world/trajectory/points",
81
  rr.Points3D(
82
  positions,
83
+ colors=colors
84
  ),
85
  timeless=True
86
  )
87
 
88
+ # Log trajectory line with gradient color
89
  if len(positions) > 1:
90
  lines = np.stack([positions[:-1], positions[1:]], axis=1)
91
+ line_colors = np.zeros((len(lines), 4))
92
+ line_colors[:, :3] = plt.cm.rainbow(
93
+ np.linspace(0, 1, len(lines)))[:, :3]
94
+ line_colors[:, 3] = 1.0
95
+
96
  rr.log(
97
  "world/trajectory/line",
98
  rr.LineStrips3D(
99
  lines,
100
+ colors=line_colors
101
  ),
102
  timeless=True
103
  )
104
 
105
+ # Log camera frustums
106
+ for i in range(valid_frames):
107
+ # Get camera position and rotation
108
+ translation = valid_trajectory[i, :3, 3]
109
+ rotation_matrix = valid_trajectory[i, :3, :3]
110
+ rotation_quat = Rotation.from_matrix(rotation_matrix).as_quat()
111
+
112
+ # Set time sequence for animation
113
+ rr.set_time_sequence("frame_idx", i)
114
+
115
+ # Log camera frustum
116
+ rr.log(
117
+ f"world/cameras/camera_{i}",
118
+ rr.Transform3D(
119
+ translation=translation,
120
+ rotation=rr.Quaternion(xyzw=rotation_quat),
121
+ )
122
+ )
123
+
124
+ # Add camera visualization
125
+ rr.log(
126
+ f"world/cameras/camera_{i}/frustum",
127
+ rr.Pinhole(
128
+ image_from_camera=self.K,
129
+ width=self.camera_width,
130
+ height=self.camera_height,
131
+ focal_length=self.focal_length,
132
+ ),
133
+ )
134
+
135
+ # Add coordinate axes for each camera
136
+ rr.log(
137
+ f"world/cameras/camera_{i}/axes",
138
+ rr.Arrows3D(
139
+ origins=np.zeros((3, 3)),
140
+ vectors=np.eye(3) * 0.5, # 0.5 meter long axes
141
+ colors=[[1, 0, 0, 1], [0, 1, 0, 1], [
142
+ 0, 0, 1, 1]] # RGB colors for XYZ
143
+ )
144
+ )
145
+
146
  def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
147
+ """Log character feature visualization with enhanced appearance."""
148
  valid_frames = int(padding_mask.sum())
149
  valid_char = char_feature[:, :valid_frames]
150
 
151
  if valid_char.shape[0] > 0:
152
+ # Create gradient colors for character points
153
+ num_points = valid_char.reshape(-1, 3).shape[0]
154
+ colors = np.zeros((num_points, 4))
155
+ colors[:, 0] = 0.8 # Red component
156
+ colors[:, 1] = 0.2 # Green component
157
+ colors[:, 2] = np.linspace(0.2, 0.8, num_points) # Blue gradient
158
+ colors[:, 3] = 1.0 # Alpha
159
+
160
  rr.log(
161
  "world/character",
162
  rr.Points3D(
163
  valid_char.reshape(-1, 3),
164
+ colors=colors,
165
+ radii=0.05 # Add point size for better visibility
166
  ),
167
  timeless=True
168
  )
169
 
170
+ # Add a semi-transparent hull around character points
171
+ try:
172
+ from scipy.spatial import ConvexHull
173
+ points = valid_char.reshape(-1, 3)
174
+ hull = ConvexHull(points)
175
+
176
+ rr.log(
177
+ "world/character/hull",
178
+ rr.Mesh3D(
179
+ vertex_positions=points[hull.vertices],
180
+ indices=hull.simplices,
181
+ mesh_material=Material(
182
+ # Semi-transparent red
183
+ albedo_factor=[0.8, 0.2, 0.2, 0.3]
184
+ )
185
+ ),
186
+ timeless=True
187
+ )
188
+ except Exception:
189
+ pass # Skip hull visualization if it fails