File size: 34,365 Bytes
4d1a850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
"""Record3D visualizer

Parse and stream record3d captures. To get the demo data, see `./assets/download_record3d_dance.sh`.
"""

import time
from pathlib import Path

import numpy as onp
import tyro
import cv2
from tqdm.auto import tqdm

import viser
import viser.extras
import viser.transforms as tf

from glob import glob
import numpy as np
import imageio.v3 as iio
import matplotlib.pyplot as plt
import psutil

def log_memory_usage(message=""):
    """Print this process's resident set size (RSS) in megabytes.

    Args:
        message: Label spliced into the printed line so successive calls can be
            told apart (e.g. ``"after loading head1 data"``).
    """
    rss_bytes = psutil.Process().memory_info().rss
    # 2**20 bytes per MB, matching the original 1024 * 1024 divisor exactly.
    rss_mb = rss_bytes / 2 ** 20
    print(f"Memory usage {message}: {rss_mb:.2f} MB")

def _glob_sorted_by_index(traj_path, prefix):
    """Return the ``<prefix>_p*.npy`` files in *traj_path*, ordered by numeric index.

    The index is the text after the last ``_p`` in the file name, up to the next
    ``.``; sorting numerically (not lexically) keeps ``_p10`` after ``_p9``.
    """
    pattern = str(Path(traj_path) / f"{prefix}_p*.npy")
    return sorted(
        glob(pattern),
        key=lambda p: int(Path(p).name.split('_p')[-1].split('.')[0]),
    )


def _stack_npy(paths, use_float16):
    """Load every ``.npy`` file in *paths* and stack them along a new axis 0.

    When *use_float16* is true each array is cast to float16 before stacking
    to reduce peak memory.
    """
    if use_float16:
        return onp.stack([onp.load(p).astype(onp.float16) for p in paths], axis=0)
    return onp.stack([onp.load(p) for p in paths], axis=0)


def _load_masks(mask_folder):
    """Load ``*.jpg`` masks from *mask_folder* as a boolean ``(N, H, W)`` array.

    A pixel is foreground (True) when its value is 0. For colour images that
    means more than two channels are 0 (i.e. every channel of an RGB pixel);
    grayscale images have no channel axis, so the per-pixel test is used
    directly. Returns None when *mask_folder* is None or holds no JPGs.
    """
    masks_paths = (
        sorted(glob(str(Path(mask_folder) / '*.jpg')))
        if mask_folder is not None
        else []
    )
    if not masks_paths:
        print("No masks found. Will create default masks when needed.")
        return None

    masks = np.stack([iio.imread(p) for p in masks_paths], axis=0)
    is_zero = masks < 1
    if is_zero.ndim == 4:
        # (N, H, W, C): foreground only where > 2 channels are zero.
        masks = is_zero.sum(axis=-1) > 2
    else:
        # (N, H, W) grayscale: reducing the last axis would wrongly collapse width.
        masks = is_zero
    print(f"Original masks shape: {masks.shape}")
    return masks


def _resize_masks(masks, h, w):
    """Resize each boolean mask in *masks* to ``(h, w)`` with nearest-neighbour sampling."""
    resized = np.zeros((masks.shape[0], h, w), dtype=bool)
    for i, mask in enumerate(masks):
        # cv2.resize takes (width, height); nearest keeps the mask binary.
        resized[i] = cv2.resize(
            mask.astype(np.uint8),
            (w, h),
            interpolation=cv2.INTER_NEAREST,
        ).astype(bool)
    return resized


def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train', conf_thre_percentile=10):
    """Load trajectory data from files.

    Args:
        traj_path: Directory containing ``pts3d1_p*.npy`` / ``conf1_p*.npy``
            (head1) and ``pts3d2_p*.npy`` / ``conf2_p*.npy`` (head2) files.
        use_float16: Whether to convert point data to float16 to save memory.
            Confidence maps are always loaded as float16.
        max_frames: Maximum number of frames to load (None for all).
        mask_folder: Directory containing ``*.jpg`` mask images, or None to
            skip mask loading (default all-True masks are then created when
            head1 data is present).
        conf_thre_percentile: Percentile (0-100) of confidence below which
            points are dropped. Head1 uses one threshold over the per-pixel
            mean confidence; head2 uses a separate threshold per frame.

    Returns:
        A dictionary with keys ``traj_3d_head1``/``traj_3d_head2`` (reshaped
        to ``(frames, points, 6)``), ``conf_mask_head1``/``conf_mask_head2``
        (boolean ``(frames, points)``), ``masks``, ``raw_video`` (channels
        3:6 of head2 before reshaping) and ``loaded``. Entries whose source
        files are missing stay None.
    """
    log_memory_usage("before loading data")

    data_cache = {
        'traj_3d_head1': None,
        'traj_3d_head2': None,
        'conf_mask_head1': None,
        'conf_mask_head2': None,
        'masks': None,
        'raw_video': None,
        'loaded': False
    }

    masks = _load_masks(mask_folder)
    data_cache['masks'] = masks

    if Path(traj_path).is_dir():
        traj_3d_paths_head1 = _glob_sorted_by_index(traj_path, 'pts3d1')
        conf_paths_head1 = _glob_sorted_by_index(traj_path, 'conf1')
        traj_3d_paths_head2 = _glob_sorted_by_index(traj_path, 'pts3d2')
        conf_paths_head2 = _glob_sorted_by_index(traj_path, 'conf2')

        # Limit number of frames if specified (slicing an empty list is a no-op).
        if max_frames is not None:
            traj_3d_paths_head1 = traj_3d_paths_head1[:max_frames]
            conf_paths_head1 = conf_paths_head1[:max_frames]
            traj_3d_paths_head2 = traj_3d_paths_head2[:max_frames]
            conf_paths_head2 = conf_paths_head2[:max_frames]

        # Process head1
        if traj_3d_paths_head1:
            traj_3d_head1 = _stack_npy(traj_3d_paths_head1, use_float16)
            log_memory_usage("after loading head1 data")

            h, w, _ = traj_3d_head1.shape[1:]
            num_frames = traj_3d_head1.shape[0]

            if masks is None:
                # No masks on disk: treat every pixel of every frame as valid.
                masks = np.ones((num_frames, h, w), dtype=bool)
                print(f"Created default masks with shape: {masks.shape}")
                data_cache['masks'] = masks
            else:
                masks_resized = _resize_masks(masks, h, w)
                print(f"Resized masks shape: {masks_resized.shape}")
                data_cache['masks'] = masks_resized

            # Flatten the image grid: (frames, h, w, 6) -> (frames, h*w, 6).
            traj_3d_head1 = traj_3d_head1.reshape(traj_3d_head1.shape[0], -1, 6)
            data_cache['traj_3d_head1'] = traj_3d_head1

            if conf_paths_head1:
                conf_head1 = _stack_npy(conf_paths_head1, use_float16=True)
                conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1)
                # One confidence per pixel (mean over frames), repeated for every frame.
                conf_head1 = np.tile(conf_head1.mean(axis=0), (num_frames, 1))
                # Percentile in float32 to avoid float16 overflow/precision issues.
                conf_thre = np.percentile(conf_head1.astype(np.float32), conf_thre_percentile)
                data_cache['conf_mask_head1'] = conf_head1 > conf_thre

        # Process head2
        if traj_3d_paths_head2:
            traj_3d_head2 = _stack_npy(traj_3d_paths_head2, use_float16)
            log_memory_usage("after loading head2 data")

            # Channels 3:6 hold colour; keep the unflattened frames as the video.
            raw_video = traj_3d_head2[:, :, :, 3:6]  # [num_frames, h, w, 3]
            data_cache['raw_video'] = raw_video

            traj_3d_head2 = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)
            data_cache['traj_3d_head2'] = traj_3d_head2

            if conf_paths_head2:
                conf_head2 = _stack_npy(conf_paths_head2, use_float16=True)
                conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
                # Separate threshold per frame (percentile along the point axis).
                conf_thre = np.percentile(conf_head2.astype(np.float32), conf_thre_percentile, axis=1)
                data_cache['conf_mask_head2'] = conf_head2 > conf_thre[:, None]

    data_cache['loaded'] = True
    log_memory_usage("after loading all data")
    return data_cache

def visualize_st4rtrack(
    traj_path: str = "results",
    up_dir: str = "-z", # should be +z or -z
    max_frames: int = 100,
    share: bool = False,
    point_size: float = 0.005,
    downsample_factor: int = 3,
    num_traj_points: int = 100,
    conf_thre_percentile: float = 1,
    traj_end_frame: int = 100,
    traj_start_frame: int = 0,
    traj_line_width: float = 3.,
    fixed_length_traj: int = 20,
    server: viser.ViserServer = None,
    use_float16: bool = True,
    preloaded_data: dict = None,  # Add this parameter to accept preloaded data
    color_code: str = "jet",
    # Updated hex colors: #002676 for blue and #FDB515 for red/gold
    blue_rgb: tuple[float, float, float] = (0.0, 0.149, 0.463),  # #002676
    red_rgb: tuple[float, float, float] = (0.769, 0.510, 0.055),   # #FDB515
    blend_ratio: float = 0.7,
    mask_folder: str = None,
    mid_anchor: bool = False,
    video_width: int = 320,   # Video display width
    video_height: int = 180,  # Video display height
    camera_position: tuple[float, float, float] = (1e-3, 1.5, -0.2),
) -> None:
    log_memory_usage("at start of visualization")
    
    if server is None:
        server = viser.ViserServer()
    if share:
        server.request_share_url()

    @server.on_client_connect
    def _(client: viser.ClientHandle) -> None:
        client.camera.position = camera_position
        client.camera.look_at = (0, 0, 0)

    # Configure the GUI panel size and layout
    server.gui.configure_theme(
        control_layout="collapsible",
        control_width="small",
        dark_mode=False,
        show_logo=False,
        show_share_button=True
    )

    # Add video preview to the GUI panel - placed at the top
    video_preview = server.gui.add_image(
        np.zeros((video_height, video_width, 3), dtype=np.uint8),  # Initial blank image
        format="jpeg"
    )
    
    # Use preloaded data if available
    if preloaded_data and preloaded_data.get('loaded', False):
        traj_3d_head1 = preloaded_data.get('traj_3d_head1')
        traj_3d_head2 = preloaded_data.get('traj_3d_head2')
        conf_mask_head1 = preloaded_data.get('conf_mask_head1')
        conf_mask_head2 = preloaded_data.get('conf_mask_head2')
        masks = preloaded_data.get('masks')
        raw_video = preloaded_data.get('raw_video')
        print("Using preloaded data!")
    else:
        # Load data using the shared function
        print("No preloaded data available, loading from files...")
        data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder, conf_thre_percentile)
        traj_3d_head1 = data.get('traj_3d_head1')
        traj_3d_head2 = data.get('traj_3d_head2')
        conf_mask_head1 = data.get('conf_mask_head1')
        conf_mask_head2 = data.get('conf_mask_head2')
        masks = data.get('masks')
        raw_video = data.get('raw_video')

    def process_video_frame(frame_idx):
        if raw_video is None:
            return np.zeros((video_height, video_width, 3), dtype=np.uint8)
        
        # Get the original frame
        raw_frame = raw_video[frame_idx]
        
        # Adjust value range to 0-255
        if raw_frame.max() <= 1.0:
            frame = (raw_frame * 255).astype(np.uint8)
        else:
            frame = raw_frame.astype(np.uint8)
        
        # Resize to fit the preview window
        h, w = frame.shape[:2]
        # Calculate size while maintaining aspect ratio
        if h/w > video_height/video_width:  # Height limited
            new_h = video_height
            new_w = int(w * (new_h / h))
        else:  # Width limited
            new_w = video_width
            new_h = int(h * (new_w / w))
        
        # Resize
        resized_frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
        
        # Create a black background
        display_frame = np.zeros((video_height, video_width, 3), dtype=np.uint8)
        
        # Place the resized frame in the center
        y_offset = (video_height - new_h) // 2
        x_offset = (video_width - new_w) // 2
        display_frame[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = resized_frame
        
        return display_frame

    server.scene.set_up_direction(up_dir)
    print("Setting up visualization!")

    # Add visualization controls
    with server.gui.add_folder("Visualization"):
        gui_show_head1 = server.gui.add_checkbox("Tracking Points", True)
        gui_show_head2 = server.gui.add_checkbox("Recon Points", True)
        gui_show_trajectories = server.gui.add_checkbox("Trajectories", True)
        gui_use_color_tint = server.gui.add_checkbox("Use Color Tint", True)

    # Process and center point clouds
    center_point = None
    if traj_3d_head1 is not None:
        xyz_head1 = traj_3d_head1[:, :, :3]
        rgb_head1 = traj_3d_head1[:, :, 3:6]
        if center_point is None:
            center_point = onp.mean(xyz_head1, axis=(0, 1), keepdims=True)
        xyz_head1 -= center_point
        if rgb_head1.sum(axis=(-1)).max() > 125:
            rgb_head1 /= 255.0

    if traj_3d_head2 is not None:
        xyz_head2 = traj_3d_head2[:, :, :3]
        rgb_head2 = traj_3d_head2[:, :, 3:6]
        if center_point is None:
            center_point = onp.mean(xyz_head2, axis=(0, 1), keepdims=True)
        xyz_head2 -= center_point
        if rgb_head2.sum(axis=(-1)).max() > 125:
            rgb_head2 /= 255.0

    # Determine number of frames
    F = max(
        traj_3d_head1.shape[0] if traj_3d_head1 is not None else 0,
        traj_3d_head2.shape[0] if traj_3d_head2 is not None else 0
    )
    num_frames = min(max_frames, F)
    traj_end_frame = min(traj_end_frame, num_frames)
    print(f"Number of frames: {num_frames}")
    xyz_head1 = xyz_head1[:num_frames]
    xyz_head2 = xyz_head2[:num_frames]
    rgb_head1 = rgb_head1[:num_frames]
    rgb_head2 = rgb_head2[:num_frames]

    # Add playback UI.
    with server.gui.add_folder("Playback"):
        gui_timestep = server.gui.add_slider(
            "Timestep",
            min=0,
            max=num_frames - 1,
            step=1,
            initial_value=0,
            disabled=True,
        )
        gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
        gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
        gui_playing = server.gui.add_checkbox("Playing", True)
        gui_framerate = server.gui.add_slider(
            "FPS", min=1, max=60, step=0.1, initial_value=20
        )
        gui_framerate_options = server.gui.add_button_group(
            "FPS options", ("10", "20", "30")
        )
        gui_show_all_frames = server.gui.add_checkbox("Show all frames", False)
        gui_stride = server.gui.add_slider(
            "Stride",
            min=1,
            max=num_frames,
            step=1,
            initial_value=5,
            disabled=True,  # Initially disabled
        )

    # Frame step buttons.
    @gui_next_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value + 1) % num_frames

    @gui_prev_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value - 1) % num_frames

    # Disable frame controls when we're playing.
    @gui_playing.on_update
    def _(_) -> None:
        gui_timestep.disabled = gui_playing.value or gui_show_all_frames.value
        gui_next_frame.disabled = gui_playing.value or gui_show_all_frames.value
        gui_prev_frame.disabled = gui_playing.value or gui_show_all_frames.value

    # Set the framerate when we click one of the options.
    @gui_framerate_options.on_click
    def _(_) -> None:
        gui_framerate.value = int(gui_framerate_options.value)

    prev_timestep = gui_timestep.value

    # Toggle frame visibility when the timestep slider changes.
    @gui_timestep.on_update
    def _(_) -> None:
        nonlocal prev_timestep
        current_timestep = gui_timestep.value
        if not gui_show_all_frames.value:
            with server.atomic():
                if gui_show_head1.value:
                    frame_nodes_head1[current_timestep].visible = True
                    frame_nodes_head1[prev_timestep].visible = False
                if gui_show_head2.value:
                    frame_nodes_head2[current_timestep].visible = True
                    frame_nodes_head2[prev_timestep].visible = False
        prev_timestep = current_timestep
        server.flush()  # Optional!

    # Show or hide all frames based on the checkbox.
    @gui_show_all_frames.on_update
    def _(_) -> None:
        gui_stride.disabled = not gui_show_all_frames.value  # Enable/disable stride slider
        if gui_show_all_frames.value:
            # Show frames with stride
            stride = gui_stride.value
            with server.atomic():
                for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
                    node1.visible = gui_show_head1.value and (i % stride == 0)
                    node2.visible = gui_show_head2.value and (i % stride == 0)
            # Disable playback controls
            gui_playing.disabled = True
            gui_timestep.disabled = True
            gui_next_frame.disabled = True
            gui_prev_frame.disabled = True
        else:
            # Show only the current frame
            current_timestep = gui_timestep.value
            with server.atomic():
                for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
                    node1.visible = gui_show_head1.value and (i == current_timestep)
                    node2.visible = gui_show_head2.value and (i == current_timestep)
            # Re-enable playback controls
            gui_playing.disabled = False
            gui_timestep.disabled = gui_playing.value
            gui_next_frame.disabled = gui_playing.value
            gui_prev_frame.disabled = gui_playing.value

    # Update frame visibility when the stride changes.
    @gui_stride.on_update
    def _(_) -> None:
        if gui_show_all_frames.value:
            # Update frame visibility based on new stride
            stride = gui_stride.value
            with server.atomic():
                for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
                    node1.visible = gui_show_head1.value and (i % stride == 0)
                    node2.visible = gui_show_head2.value and (i % stride == 0)

    # Load in frames.
    server.scene.add_frame(
        "/frames",
        wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz,
        position=(0, 0, 0),
        show_axes=False,
    )
    frame_nodes_head1: list[viser.FrameHandle] = []
    frame_nodes_head2: list[viser.FrameHandle] = []

    # Extract RGB components for tinting
    blue_r, blue_g, blue_b = blue_rgb
    red_r, red_g, red_b = red_rgb
    
    # Create frames for each timestep
    frame_nodes_head1 = []
    frame_nodes_head2 = []
    for i in tqdm(range(num_frames)):
        # Process head1
        if traj_3d_head1 is not None:
            frame_nodes_head1.append(server.scene.add_frame(f"/frames/t{i}/head1", show_axes=False))
            position = xyz_head1[i]
            color = rgb_head1[i]
            if conf_mask_head1 is not None:
                position = position[conf_mask_head1[i]]
                color = color[conf_mask_head1[i]]
            
            # Add point cloud for head1 with optional blue tint
            color_head1 = color.copy()
            if gui_use_color_tint.value:
                color_head1 *= blend_ratio
                color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1)  # R
                color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1)  # G
                color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1)  # B
            
            server.scene.add_point_cloud(
                name=f"/frames/t{i}/head1/point_cloud",
                points=position[::downsample_factor],
                colors=color_head1[::downsample_factor],
                point_size=point_size,
                point_shape="rounded",
            )

        # Process head2
        if traj_3d_head2 is not None:
            frame_nodes_head2.append(server.scene.add_frame(f"/frames/t{i}/head2", show_axes=False))
            position = xyz_head2[i]
            color = rgb_head2[i]
            if conf_mask_head2 is not None:
                position = position[conf_mask_head2[i]]
                color = color[conf_mask_head2[i]]
            
            # Add point cloud for head2 with optional red tint
            color_head2 = color.copy()
            if gui_use_color_tint.value:
                color_head2 *= blend_ratio
                color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1)  # R
                color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1)  # G
                color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1)  # B
            
            server.scene.add_point_cloud(
                name=f"/frames/t{i}/head2/point_cloud",
                points=position[::downsample_factor],
                colors=color_head2[::downsample_factor],
                point_size=point_size,
                point_shape="rounded",
            )

    # Update visibility based on checkboxes
    @gui_show_head1.on_update
    def _(_) -> None:
        with server.atomic():
            for frame_node in frame_nodes_head1:
                frame_node.visible = gui_show_head1.value and (
                    gui_show_all_frames.value
                    or (not gui_show_all_frames.value )
                )

    @gui_show_head2.on_update
    def _(_) -> None:
        with server.atomic():
            for frame_node in frame_nodes_head2:
                frame_node.visible = gui_show_head2.value and (
                    gui_show_all_frames.value
                    or (not gui_show_all_frames.value )
                )

    # Initial visibility
    for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
        if gui_show_all_frames.value:
            node1.visible = gui_show_head1.value and (i % gui_stride.value == 0)
            node2.visible = gui_show_head2.value and (i % gui_stride.value == 0)
        else:
            node1.visible = gui_show_head1.value and (i == gui_timestep.value)
            node2.visible = gui_show_head2.value and (i == gui_timestep.value)

    # Process and visualize trajectories for head1
    if traj_3d_head1 is not None:
        # Get points over time
        xyz_head1_centered = xyz_head1.copy()
        
        # Select points to visualize
        num_points = xyz_head1.shape[1]
        points_to_visualize = min(num_points, num_traj_points)
        
        # Get the mask for the first frame and reshape it to match point cloud dimensions
        if mid_anchor:
            first_frame_mask = masks[num_frames//2].reshape(-1)
        else:
            first_frame_mask = masks[0].reshape(-1) #[#points, h]
        
        # Calculate trajectory lengths for each point
        trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame]  # Shape: (num_frames, num_points, 3)
        traj_diffs = np.diff(trajectories, axis=0)  # Differences between consecutive frames
        traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0)  # Sum of distances for each point
        
        # Get points that are within the mask
        valid_indices = np.where(first_frame_mask)[0]
        
        if len(valid_indices) > 0:
            # Calculate average trajectory length for masked points
            masked_traj_lengths = traj_lengths[valid_indices]
            avg_traj_length = np.mean(masked_traj_lengths)
            
            if mask_folder is not None:
                # do not filter points by trajectory length
                long_traj_indices = valid_indices
            else:
                # Filter points by trajectory length
                long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]
            
            # Randomly sample from the filtered points
            if len(long_traj_indices) > 0:
                # Random sampling without replacement
                selected_indices = np.random.choice(
                    len(long_traj_indices),
                    min(points_to_visualize, len(long_traj_indices)),
                    replace=False
                )
                # Get the actual indices in their original order
                valid_point_indices = long_traj_indices[np.sort(selected_indices)]
            else:
                valid_point_indices = np.array([])
        else:
            valid_point_indices = np.array([])
        
        if len(valid_point_indices) > 0:
            # Get trajectories for all valid points
            trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
            N_point = trajectories.shape[1]
            if color_code == "rainbow":
                point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
            elif color_code == "jet":
                point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]
            # Modify the loop to handle frames less than fixed_length_traj
            for i in range(traj_end_frame - traj_start_frame):
                # Calculate the actual trajectory length for this frame
                actual_length = min(fixed_length_traj, i + 1)
                
                if actual_length > 1:  # Need at least 2 points to form a line
                    # Get the appropriate slice of trajectory data
                    start_idx = max(0, i - actual_length + 1)
                    end_idx = i + 1
                    
                    # Create line segments between consecutive frames
                    traj_slice = trajectories[start_idx:end_idx]
                    line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
                    line_points = line_points.reshape(-1, 2, 3)
                    
                    # Create corresponding colors
                    line_colors = np.tile(point_colors, (actual_length-1, 1))
                    line_colors = np.stack([line_colors, line_colors], axis=1)
                    
                    # Add line segments
                    server.scene.add_line_segments(
                        name=f"/frames/t{i+traj_start_frame}/head1/trajectory",
                        points=line_points,
                        colors=line_colors,
                        line_width=traj_line_width,
                        visible=gui_show_trajectories.value
                    )

    # Add trajectory controls functionality
    @gui_show_trajectories.on_update
    def _(_) -> None:
        """Rebuild the head1 trajectory overlays when the trajectory checkbox changes.

        Every ``/frames/t{i}/head1/trajectory`` node is removed first. If the
        checkbox is now enabled and head1 data is present, the handler:

        1. Keeps the points that are set in ``masks[traj_end_frame - 1]``.
        2. Keeps only those whose summed step length over frames
           ``[traj_start_frame, traj_end_frame)`` is at least the mean of the
           masked points.
        3. Randomly samples up to ``points_to_visualize`` of them, kept in
           index order.
        4. For each frame in that range, draws the trailing window of at most
           ``fixed_length_traj`` positions as line segments, one colour per
           point.

        ``color_code`` "rainbow" and "jet" use those colormaps. Any other
        value is looked up with ``plt.get_cmap``, which raises ValueError for
        unknown names, so the colour array is always defined.
        """
        with server.atomic():
            # Clear trajectories left over from a previous toggle.
            for i in range(num_frames):
                try:
                    server.scene.remove_by_name(f"/frames/t{i}/head1/trajectory")
                except KeyError:
                    pass

            if not gui_show_trajectories.value or traj_3d_head1 is None:
                return

            # NOTE(review): assumes the flattened mask has one entry per point
            # (axis 1 of xyz_head1_centered) -- confirm where masks is built.
            last_frame_mask = masks[traj_end_frame - 1].reshape(-1)
            valid_indices = np.where(last_frame_mask)[0]
            if len(valid_indices) == 0:
                return

            # Total path length of each masked point across the frame window.
            # Only the masked columns are needed, so compute just those.
            masked_window = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_indices]
            step_lengths = np.sqrt(np.sum(np.diff(masked_window, axis=0) ** 2, axis=-1))
            masked_traj_lengths = np.sum(step_lengths, axis=0)

            # Keep points that move at least as far as the masked average.
            avg_traj_length = np.mean(masked_traj_lengths)
            long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]
            if len(long_traj_indices) == 0:
                return

            # Sample without replacement, then sort so the selection keeps the
            # original point order (and therefore a stable colour ramp).
            n_keep = min(points_to_visualize, len(long_traj_indices))
            picks = np.random.choice(len(long_traj_indices), n_keep, replace=False)
            selected = long_traj_indices[np.sort(picks)]
            if len(selected) == 0:
                return

            trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, selected]
            n_points = trajectories.shape[1]
            ramp = np.linspace(0, 1, n_points)
            if color_code == "rainbow":
                point_colors = plt.cm.rainbow(ramp)[:, :3]
            elif color_code == "jet":
                point_colors = plt.cm.jet(ramp)[:, :3]
            else:
                # Resolve any other name as a Matplotlib colormap so that
                # point_colors is always bound.
                point_colors = plt.get_cmap(color_code)(ramp)[:, :3]

            for offset in range(traj_end_frame - traj_start_frame):
                # Frames near the start have fewer than fixed_length_traj
                # positions of history available.
                window_len = min(fixed_length_traj, offset + 1)
                if window_len < 2:
                    continue  # a line segment needs at least two positions

                traj_slice = trajectories[offset - window_len + 1 : offset + 1]
                # (window_len-1, N, 2, 3) -> ((window_len-1)*N, 2, 3): one
                # segment per (step, point), step-major, matching np.tile below.
                line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2).reshape(-1, 2, 3)
                segment_colors = np.tile(point_colors, (window_len - 1, 1))
                line_colors = np.stack([segment_colors, segment_colors], axis=1)

                server.scene.add_line_segments(
                    name=f"/frames/t{offset + traj_start_frame}/head1/trajectory",
                    points=line_points,
                    colors=line_colors,
                    line_width=traj_line_width,
                    visible=True
                )

    # Update color tinting when the checkbox changes
    @gui_use_color_tint.on_update
    def _(_) -> None:
        """Re-create every per-frame point cloud, applying or removing the head tint.

        With the tint checkbox on, head1 colours are blended toward
        ``(blue_r, blue_g, blue_b)`` and head2 colours toward
        ``(red_r, red_g, red_b)``. Each original colour is scaled by
        ``blend_ratio``, the tint is added with weight ``1 - blend_ratio``,
        and the result is clipped to [0, 1] per channel. With it off, the
        stored colours are used as-is.
        """

        def _rebuild_cloud(frame, head, xyz, rgb, conf_mask, tint_rgb):
            """Replace one frame's point cloud for ``head`` with freshly coloured points."""
            pts = xyz[frame]
            cols = rgb[frame]
            if conf_mask is not None:
                keep = conf_mask[frame]
                pts = pts[keep]
                cols = cols[keep]

            # Work on a copy so the stored per-frame colours are never mutated.
            shaded = cols.copy()
            if gui_use_color_tint.value:
                shaded *= blend_ratio
                for channel, target in enumerate(tint_rgb):
                    shaded[:, channel] = onp.clip(shaded[:, channel] + target * (1 - blend_ratio), 0, 1)

            node = f"/frames/t{frame}/{head}/point_cloud"
            server.scene.remove_by_name(node)
            server.scene.add_point_cloud(
                name=node,
                points=pts[::downsample_factor],
                colors=shaded[::downsample_factor],
                point_size=point_size,
                point_shape="rounded",
            )

        with server.atomic():
            for frame in range(num_frames):
                if traj_3d_head1 is not None:
                    _rebuild_cloud(frame, "head1", xyz_head1, rgb_head1, conf_mask_head1, (blue_r, blue_g, blue_b))
                if traj_3d_head2 is not None:
                    _rebuild_cloud(frame, "head2", xyz_head2, rgb_head2, conf_mask_head2, (red_r, red_g, red_b))

    # Initialize video preview
    if raw_video is not None:
        video_preview.image = process_video_frame(0)
    
    # Update video preview when timestep changes
    @gui_timestep.on_update
    def _(_) -> None:
        """Show the video frame for the newly selected timestep when a video is loaded."""
        selected_step = gui_timestep.value
        if raw_video is None:
            return
        video_preview.image = process_video_frame(selected_step)
    
    # Playback update loop.
    log_memory_usage("before starting playback loop")
    
    prev_timestep = gui_timestep.value
    while True:
        current_timestep = gui_timestep.value
        
        # If timestep changes, update frame visibility
        if current_timestep != prev_timestep:
            with server.atomic():
                # ... existing code ...
                
                # Update video preview
                if raw_video is not None:
                    video_preview.image = process_video_frame(current_timestep)
        
        # Update in playback mode
        if gui_playing.value and not gui_show_all_frames.value:
            gui_timestep.value = (gui_timestep.value + 1) % num_frames
            
            # Update video preview in playback mode
            if raw_video is not None:
                video_preview.image = process_video_frame(gui_timestep.value)
        
        time.sleep(1.0 / gui_framerate.value)


if __name__ == "__main__":
    # Script entry point: hand visualize_st4rtrack to tyro.cli, which parses
    # command-line arguments for it and then calls it.
    tyro.cli(visualize_st4rtrack)