Spaces:
bla
/
Runtime error

File size: 28,199 Bytes
ea6a5ed
 
 
 
 
9bc4638
6e871ac
ea6a5ed
807f473
6e871ac
ea6a5ed
9bc4638
 
6e871ac
ea6a5ed
6e871ac
ea6a5ed
807f473
9bc4638
ea6a5ed
6e871ac
9bc4638
ea6a5ed
 
 
 
807f473
 
ea6a5ed
 
 
807f473
 
 
 
ea6a5ed
 
 
807f473
ea6a5ed
 
 
 
 
 
 
 
 
9bc4638
807f473
6e60611
 
b950bc5
807f473
ea6a5ed
 
 
 
 
 
807f473
ea6a5ed
807f473
ea6a5ed
807f473
ea6a5ed
807f473
ea6a5ed
 
 
 
5bc3a57
807f473
ea6a5ed
9bc4638
ea6a5ed
 
807f473
9bc4638
0b34400
9bc4638
 
ea6a5ed
807f473
ea6a5ed
 
 
 
 
 
 
807f473
ea6a5ed
 
 
 
 
 
807f473
ea6a5ed
 
 
 
807f473
ea6a5ed
 
 
 
 
 
 
 
807f473
ea6a5ed
 
e508568
ea6a5ed
807f473
 
ea6a5ed
807f473
 
 
 
 
ea6a5ed
 
e508568
628bfb2
9bc4638
ea6a5ed
807f473
ea6a5ed
807f473
 
 
 
 
ea6a5ed
 
 
 
 
6e871ac
807f473
628bfb2
807f473
628bfb2
807f473
ea6a5ed
628bfb2
ea6a5ed
 
807f473
 
 
628bfb2
 
 
 
 
 
 
 
ea6a5ed
 
 
628bfb2
9bc4638
 
ea6a5ed
 
628bfb2
 
807f473
 
 
 
628bfb2
 
807f473
 
628bfb2
807f473
 
628bfb2
807f473
 
 
 
 
628bfb2
ea6a5ed
 
9bc4638
628bfb2
807f473
 
 
 
 
 
 
 
 
 
628bfb2
807f473
628bfb2
ea6a5ed
 
 
 
807f473
 
ea6a5ed
807f473
 
 
 
 
 
 
 
 
628bfb2
ea6a5ed
807f473
 
 
 
 
ea6a5ed
 
 
807f473
 
 
 
 
 
 
 
 
 
ea6a5ed
 
 
 
 
 
 
 
 
807f473
ea6a5ed
 
5dc8194
ea6a5ed
 
807f473
 
ea6a5ed
5b2f03e
 
807f473
 
ea6a5ed
 
 
807f473
 
 
 
ea6a5ed
807f473
ea6a5ed
807f473
 
 
ea6a5ed
5b2f03e
ea6a5ed
 
807f473
b950bc5
807f473
 
 
 
 
 
 
ea6a5ed
 
 
 
 
 
 
807f473
5b2f03e
 
ea6a5ed
5b2f03e
 
 
ea6a5ed
5b2f03e
 
807f473
5b2f03e
 
 
 
 
 
 
628bfb2
 
 
 
 
5b2f03e
 
 
807f473
5b2f03e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b950bc5
ea6a5ed
807f473
 
 
ea6a5ed
9bc4638
5b2f03e
9bc4638
ea6a5ed
807f473
5b2f03e
 
 
 
 
 
 
 
807f473
5b2f03e
807f473
 
 
 
 
5b2f03e
 
 
 
 
 
 
807f473
5b2f03e
 
 
 
 
 
 
ea6a5ed
5b2f03e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea6a5ed
807f473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea6a5ed
807f473
 
 
ea6a5ed
807f473
 
 
 
5b2f03e
628bfb2
807f473
 
5b2f03e
807f473
 
 
ea6a5ed
807f473
ea6a5ed
807f473
ea6a5ed
 
807f473
 
5b2f03e
 
807f473
5b2f03e
807f473
 
5b2f03e
807f473
 
ea6a5ed
807f473
 
ea6a5ed
 
 
 
807f473
 
ea6a5ed
 
807f473
 
 
 
 
 
 
5b2f03e
807f473
 
628bfb2
ea6a5ed
807f473
 
 
 
 
 
ea6a5ed
 
807f473
 
ea6a5ed
807f473
5b2f03e
 
 
 
ea6a5ed
 
807f473
ea6a5ed
807f473
 
 
 
 
 
 
 
 
5b2f03e
 
ea6a5ed
807f473
5b2f03e
 
 
 
 
807f473
5b2f03e
 
807f473
 
5b2f03e
 
 
 
628bfb2
 
 
 
5b2f03e
 
 
 
 
 
 
 
 
 
807f473
5b2f03e
 
807f473
5b2f03e
 
 
 
 
807f473
 
 
5b2f03e
807f473
5b2f03e
807f473
 
ea6a5ed
 
 
807f473
ea6a5ed
807f473
 
ea6a5ed
807f473
ea6a5ed
9bc4638
807f473
ea6a5ed
807f473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea6a5ed
807f473
 
ea6a5ed
807f473
ea6a5ed
807f473
 
ea6a5ed
807f473
 
ea6a5ed
807f473
 
 
 
 
 
 
ea6a5ed
 
 
807f473
 
 
 
 
 
ea6a5ed
 
807f473
 
 
 
 
 
 
ea6a5ed
 
807f473
 
 
ea6a5ed
807f473
9bc4638
807f473
 
 
 
 
 
 
 
 
 
 
 
 
ea6a5ed
 
807f473
ea6a5ed
 
807f473
ea6a5ed
807f473
ea6a5ed
807f473
ea6a5ed
 
807f473
ea6a5ed
807f473
 
 
 
 
 
ea6a5ed
 
 
807f473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea6a5ed
 
807f473
 
 
 
 
 
 
 
ea6a5ed
 
 
807f473
ea6a5ed
 
807f473
 
 
 
 
 
 
 
 
 
 
 
ea6a5ed
 
 
 
 
807f473
 
 
 
 
 
 
 
 
 
 
 
ea6a5ed
 
 
628bfb2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import time
from datetime import datetime
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import torch

from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

# Drop TORCH_CUDNN_SDPA_ENABLED from the process environment if it was set;
# this app only runs on CPU.
os.environ.pop("TORCH_CUDNN_SDPA_ENABLED", None)

# Page header, rendered as HTML through gr.Markdown. Every <font> tag is
# properly closed so the rest of the header does not inherit its size.
title = "<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"

# Usage instructions shown above the input video (Markdown heading + HTML list).
description_p = """# Instructions
                <ol>
                <li> Upload one video or click one example video</li>
                <li> Click 'include' point type, select the object to segment and track</li>
                <li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
                <li> Click the 'Track' button to obtain the masked video </li>
                </ol>
              """

# Example clips offered in the UI; the list is kept short to limit memory use.
examples = [
    [f"examples/{clip_name}"]
    for clip_name in (
        "01_dog.mp4",
        "02_cups.mp4",
        "03_blocks.mp4",
        "04_coffee.mp4",
        "05_default_juggle.mp4",
    )
]

# Object id under which the single tracked object is registered with the predictor.
OBJ_ID = 0

# Checkpoint and config handed to build_sam2_video_predictor (CPU inference).
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"

def check_file_exists(filepath):
    """Report whether *filepath* exists on disk.

    Prints a warning line when the path is missing.

    Returns:
        True if the path exists, otherwise False.
    """
    if os.path.exists(filepath):
        return True
    print(f"WARNING: File not found: {filepath}")
    return False

# Warn about every missing model file. The list (not a generator, and not a
# chained `and`) forces both checks to run, so a missing checkpoint does not
# hide a warning about the config. Loading below is attempted regardless.
# NOTE(review): build_sam2_video_predictor may resolve `model_cfg` as a config
# name rather than a path relative to the working directory; if so, its warning
# can be spurious. Confirm before gating the load on this flag.
model_files_exist = all([check_file_exists(path) for path in (sam2_checkpoint, model_cfg)])

# Module-level predictor shared by all event handlers. It stays None when
# loading fails; the segmentation/tracking handlers check for None before use.
predictor = None
try:
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    print("predictor loaded on CPU")
except Exception as e:
    # Best effort: keep the UI running and surface the full traceback in logs.
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()

def get_video_fps(video_path, default_fps=30.0):
    """Return the frame rate OpenCV reports for the video at *video_path*.

    Falls back to *default_fps* when the file cannot be opened or when the
    reported rate is not a positive number. Some containers carry no FPS
    metadata, and OpenCV then reports 0 (or NaN).

    Args:
        video_path: Path to a video file readable by OpenCV.
        default_fps: Value returned when no usable frame rate is available.

    Returns:
        The frame rate as a float.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Error: Could not open video.")
            return default_fps
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the capture on every path, including the failure path.
        cap.release()

    # `not (fps > 0)` is also True for NaN.
    if not (fps > 0):
        print(f"Warning: invalid FPS {fps!r} reported for {video_path}; using {default_fps}")
        return default_fps
    return fps

def reset(session_state):
    """Forget the loaded video, all click prompts and the tracker state.

    Returns:
        A 7-tuple of UI values/updates ending with the mutated *session_state*;
        its order must match the ``outputs=`` list this handler is wired to.
    """
    session_state.update(input_points=[], input_labels=[])

    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])

    session_state.update(
        first_frame=None,
        all_frames=None,
        inference_state=None,
        progress=0,
    )

    reopened_drawer = gr.update(open=True)
    hidden_video = gr.update(value=None, visible=False)
    hidden_progress = gr.update(value=0, visible=False)
    return (
        None,
        reopened_drawer,
        None,
        None,
        hidden_video,
        hidden_progress,
        session_state,
    )

def clear_points(session_state):
    """Discard all click prompts but keep the decoded video frames.

    The predictor's inference state is reset only when its
    ``tracking_has_started`` flag is set.

    Returns:
        (first frame, None, hidden output-video update, hidden progress update,
        session_state).
    """
    session_state["input_points"], session_state["input_labels"] = [], []

    tracker_state = session_state["inference_state"]
    tracking_started = tracker_state is not None and tracker_state.get("tracking_has_started", False)
    if tracking_started:
        predictor.reset_state(tracker_state)

    hidden_video = gr.update(value=None, visible=False)
    hidden_progress = gr.update(value=0, visible=False)
    return (
        session_state["first_frame"],
        None,
        hidden_video,
        hidden_progress,
        session_state,
    )

def _video_load_failed_outputs(session_state):
    """Build outputs that reopen the upload drawer and clear every result view."""
    return (
        gr.update(open=True),  # video_in_drawer
        None,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        gr.update(value=0, visible=False),  # progress_bar
        session_state,
    )


def preprocess_video_in(video_path, session_state):
    """Decode *video_path* into RGB frames for prompting and initialise the tracker.

    Frames wider than ``target_width`` are downscaled (aspect ratio kept), and a
    sampling stride is chosen so that at most ``max_frames`` frames are kept.
    The decoded frames and sampling parameters are stored in *session_state*.
    A fresh predictor inference state is created from the original file path,
    not from the downscaled/subsampled frames.

    Args:
        video_path: Path of the uploaded/selected video, or None.
        session_state: Per-session dict; mutated in place.

    Returns:
        A 6-tuple for (video_in_drawer, points_map, output_image, output_video,
        progress_bar, session_state).
    """
    if video_path is None or not os.path.exists(video_path):
        return _video_load_failed_outputs(session_state)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video at {video_path}.")
        return _video_load_failed_outputs(session_state)

    target_width = 640  # frames wider than this are downscaled for CPU processing
    max_frames = 150  # upper bound on the number of frames kept

    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames, {fps} FPS")

        scale_factor = 1.0
        target_size = None  # (width, height) passed to cv2.resize, or None to keep size
        if frame_width > target_width:
            scale_factor = target_width / frame_width
            target_size = (int(frame_width * scale_factor), int(frame_height * scale_factor))
            print(
                f"Resizing video for CPU processing: {frame_width}x{frame_height} "
                f"-> {target_size[0]}x{target_size[1]}"
            )

        frame_stride = 1
        if total_frames > max_frames:
            # Ceiling division guarantees at most max_frames sampled frames;
            # floor division would keep e.g. all 299 frames of a 299-frame clip.
            frame_stride = -(-total_frames // max_frames)
            print(f"Video has {total_frames} frames, using stride of {frame_stride} to limit to {max_frames}")

        all_frames = []
        frame_number = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_number % frame_stride == 0:
                try:
                    if target_size is not None:
                        frame = cv2.resize(frame, target_size, interpolation=cv2.INTER_AREA)
                    # OpenCV decodes BGR; everything downstream works in RGB.
                    all_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                except Exception as e:
                    print(f"Error processing frame {frame_number}: {e}")
            frame_number += 1
    finally:
        cap.release()

    if not all_frames:
        print("Error: No frames could be extracted from the video.")
        return _video_load_failed_outputs(session_state)

    first_frame = all_frames[0]
    print(f"Successfully extracted {len(all_frames)} frames from video")

    # Keep an independent copy of the first frame for the prompt views.
    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["frame_stride"] = frame_stride
    session_state["scale_factor"] = scale_factor
    session_state["original_dimensions"] = (frame_width, frame_height)
    session_state["progress"] = 0
    # Prompts from a previously loaded video are meaningless for this one, even
    # if tracker initialisation below fails.
    session_state["input_points"] = []
    session_state["input_labels"] = []

    try:
        session_state["inference_state"] = predictor.init_state(video_path=video_path)
    except Exception as e:
        print(f"Error initializing inference state: {e}")
        import traceback
        traceback.print_exc()
        session_state["inference_state"] = None

    return (
        gr.update(open=False),  # video_in_drawer
        first_frame,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        gr.update(value=0, visible=False),  # progress_bar
        session_state,
    )

def _draw_prompt_points(canvas, points, labels, radius):
    """Draw each prompt point as a filled circle on the RGBA *canvas* (in place) and return it.

    Points labelled 1 ("include") are drawn green, all others red.
    """
    for point, label in zip(points, labels):
        color = (0, 255, 0, 255) if label == 1 else (255, 0, 0, 255)
        cv2.circle(canvas, point, radius, color, -1)
    return canvas


def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    """Record a clicked prompt on the first frame and preview the predicted mask.

    Args:
        point_type: "include" (positive prompt) or "exclude" (negative prompt).
            Any other value is ignored with a warning, so the point and label
            lists never get out of sync.
        session_state: Per-session dict; prompts are appended in place.
        evt: Gradio select event; ``evt.index`` is the clicked pixel.

    Returns:
        (first frame with the prompt points as a PIL RGBA image,
         first frame with mask and points as an RGBA numpy array,
         session_state). If prediction fails, the second item falls back to
        the points-only view.
    """
    if session_state["first_frame"] is None:
        print("Error: No frame available for segmentation")
        return None, None, session_state

    label = {"include": 1, "exclude": 0}.get(point_type)
    if label is None:
        print(f"Warning: ignoring click with unknown point type {point_type!r}")
    else:
        session_state["input_points"].append(evt.index)
        session_state["input_labels"].append(label)
    print(f"TRACKING INPUT POINT: {session_state['input_points']}")
    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

    first_frame = session_state["first_frame"]
    h, w = first_frame.shape[:2]
    # Point radius: 1% of the shorter side, at least 3 px so it stays visible.
    radius = max(3, int(0.01 * min(w, h)))

    points_layer = _draw_prompt_points(
        np.zeros((h, w, 4), dtype=np.uint8),
        session_state["input_points"],
        session_state["input_labels"],
        radius,
    )
    points_layer_img = Image.fromarray(points_layer, "RGBA")
    selected_point_map = Image.alpha_composite(
        Image.fromarray(first_frame).convert("RGBA"), points_layer_img
    )

    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)

    try:
        if predictor is None:
            raise ValueError("Model predictor is not initialized")
        if session_state["inference_state"] is None:
            raise ValueError("Inference state is not initialized")

        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,
            obj_id=OBJ_ID,
            points=points,
            labels=labels,
        )

        out_mask = (out_mask_logits[0] > 0.0).cpu().numpy()
        # Per-object masks can carry leading singleton axes, e.g. (1, H, W).
        # Flatten to (H, W); otherwise the shape check below always mismatches
        # and the PIL conversion fails.
        if out_mask.ndim > 2:
            out_mask = out_mask.reshape(out_mask.shape[-2:])

        overlay_mask = np.zeros((h, w, 3), dtype=np.uint8)
        if out_mask.size > 0:
            if out_mask.shape != (h, w):
                print(f"Resizing mask from {out_mask.shape[:2]} to {h}x{w}")
                # Nearest-neighbour keeps the resized mask binary.
                mask_img = Image.fromarray((out_mask > 0).astype(np.uint8) * 255)
                out_mask = np.array(mask_img.resize((w, h), Image.NEAREST))
            overlay_mask[out_mask > 0] = [0, 120, 255]  # blue-ish (frames are RGB)

        # Add the coloured mask at 50% strength on top of the original frame.
        frame_with_mask = cv2.addWeighted(first_frame, 1, overlay_mask, 0.5, 0)

        first_frame_output = Image.alpha_composite(
            Image.fromarray(frame_with_mask).convert("RGBA"), points_layer_img
        )
    except Exception as e:
        print(f"Error in segmentation: {e}")
        import traceback
        traceback.print_exc()
        # Fall back to showing just the prompt points.
        first_frame_output = selected_point_map

    return selected_point_map, np.array(first_frame_output), session_state

def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Render a binary mask as a coloured, semi-transparent RGBA overlay.

    Args:
        mask: Boolean / 0-1 array of shape (H, W), or with leading singleton
            axes such as (1, H, W).
        obj_id: Index into matplotlib's ``tab10`` colormap (0 when None).
        random_color: Use a random RGB colour instead of the colormap.
        convert_to_image: Return a PIL RGBA image instead of a uint8 array.

    Returns:
        An RGBA overlay of size H x W. Mask pixels get the colour with alpha 0.6;
        other pixels are fully transparent. A transparent 100x100 placeholder is
        returned for empty or dimensionless masks, and a transparent H x W one
        if rendering fails.
    """
    def _blank(height, width):
        if convert_to_image:
            # PIL sizes are (width, height).
            return Image.new("RGBA", (width, height), (0, 0, 0, 0))
        return np.zeros((height, width, 4), dtype=np.uint8)

    if mask is None or mask.size == 0:
        print("Warning: Empty mask provided to show_mask")
        return _blank(100, 100)

    mask = np.asarray(mask)
    if mask.ndim < 2:
        print(f"Warning: Invalid mask dimensions: {mask.shape}")
        return _blank(100, 100)
    h, w = mask.shape[-2:]

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    try:
        # Flatten leading singleton axes, e.g. (1, H, W) -> (H, W). This raises
        # (and falls back below) if the leading axes are not singletons.
        mask2d = mask.reshape(h, w)
        # Broadcast the RGBA colour over the mask: (H, W, 1) * (4,) -> (H, W, 4).
        colored_mask = (mask2d[..., None] * color * 255).astype(np.uint8)
        if convert_to_image:
            return Image.fromarray(colored_mask, "RGBA")
        return colored_mask
    except Exception as e:
        print(f"Error in show_mask: {e}")
        import traceback
        traceback.print_exc()
        return _blank(h, w)

def update_progress(progress_percent, progress_bar):
    """Return a Gradio update that shows the progress bar at *progress_percent*.

    The ``progress_bar`` argument is not used.
    """
    shown_at_value = gr.update(value=progress_percent, visible=True)
    return shown_at_value

def propagate_to_all(
    video_in,
    session_state,
    progress=gr.Progress(),
):
    """Track the prompted object through the video and render a masked clip.

    The work runs in two phases. Phase 1 (0-50% progress) runs the predictor's
    propagation over the whole video and keeps a boolean mask per frame and
    object. Phase 2 (50-100%) blends those masks onto the frames decoded by
    ``preprocess_video_in`` and encodes them into a temporary MP4.

    Args:
        video_in: Path of the input video (used for its frame rate).
        session_state: Per-session dict with prompts, decoded frames and the
            predictor inference state; ``progress`` is updated in place.
        progress: Gradio progress tracker.

    Returns:
        (output_video update, progress_bar update, session_state). On any
        failure the output video is cleared and hidden.
    """
    import gc
    import traceback

    if (
        not session_state["input_points"]
        or video_in is None
        or session_state["inference_state"] is None
        or predictor is None
    ):
        print("Missing required data for tracking")
        return (
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )

    chunk_size = 3  # run gc.collect() every few propagated frames to cap CPU memory

    try:
        inference_state = session_state["inference_state"]

        # ---- Phase 1: propagation -------------------------------------------
        video_segments = {}  # {predictor frame index: {obj_id: bool mask}}
        print("Starting propagate_in_video on CPU")
        # NOTE(review): assumes the inference state may expose "num_frames";
        # otherwise fall back to a rough estimate for the progress bar.
        expected_frames = inference_state.get("num_frames") or 100

        processed = 0
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            inference_state
        ):
            try:
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }

                processed += 1
                progress_percent = min(50, int(processed / expected_frames * 50))
                session_state["progress"] = progress_percent
                progress(progress_percent / 100, desc="Processing frames")

                if out_frame_idx % 10 == 0:
                    print(f"Processed frame {out_frame_idx} ({progress_percent}%)")

                if out_frame_idx % chunk_size == 0:
                    del out_mask_logits
                    gc.collect()
            except Exception as e:
                print(f"Error processing frame {out_frame_idx}: {e}")
                traceback.print_exc()
                continue

        print(f"Total frames processed: {len(video_segments)}")

        # ---- Phase 2: rendering ---------------------------------------------
        session_state["progress"] = 50
        progress(0.5, desc="Rendering video")

        all_frames = session_state["all_frames"]
        if not all_frames:
            raise ValueError("No frames available in session state")

        # all_frames keeps every `frame_stride`-th source frame, but the
        # predictor was initialised from the original file. Decoded frame k is
        # therefore paired with predictor frame k * frame_stride.
        # NOTE(review): assumes both decoders see the same frame sequence.
        frame_stride = max(1, int(session_state.get("frame_stride") or 1))

        max_output_frames = 30  # cap on rendered frames to keep CPU encoding cheap
        vis_frame_stride = max(1, len(all_frames) // max_output_frames)
        print(f"Using stride of {vis_frame_stride} for output video generation")

        frame_indices = range(0, len(all_frames), vis_frame_stride)
        total_output_frames = len(frame_indices)
        output_frames = []

        for i, frame_idx in enumerate(frame_indices):
            source_idx = frame_idx * frame_stride
            frame_masks = video_segments.get(source_idx)
            if not frame_masks or OBJ_ID not in frame_masks:
                continue

            try:
                frame = all_frames[frame_idx]
                out_mask = frame_masks[OBJ_ID]
                # Per-object masks can carry leading singleton axes, e.g.
                # (1, H, W). Flatten to (H, W) for the comparison and indexing below.
                if out_mask.ndim > 2:
                    out_mask = out_mask.reshape(out_mask.shape[-2:])
                if out_mask.size == 0:
                    print(f"Warning: Invalid mask for frame {source_idx}")
                    continue
                out_mask = out_mask > 0

                frame_h, frame_w = frame.shape[:2]
                if out_mask.shape != (frame_h, frame_w):
                    # Decoded frames may be downscaled; nearest-neighbour keeps the mask binary.
                    mask_h, mask_w = out_mask.shape
                    print(f"Resizing mask from {mask_h}x{mask_w} to {frame_h}x{frame_w}")
                    mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
                    out_mask = np.array(mask_img.resize((frame_w, frame_h), Image.NEAREST)) > 0

                overlay = np.zeros_like(frame)
                overlay[out_mask] = [0, 120, 255]  # frames are RGB (converted from BGR when decoded)
                # Add the coloured mask at 50% strength on top of the frame.
                output_frames.append(cv2.addWeighted(frame, 1, overlay, 0.5, 0))

                progress_percent = 50 + min(50, int((i / total_output_frames) * 50))
                session_state["progress"] = progress_percent
                progress(progress_percent / 100, desc=f"Rendering video frames ({i}/{total_output_frames})")

                if len(output_frames) % 10 == 0:
                    gc.collect()
            except Exception as e:
                # Skip just this frame; gr.Progress has no counter to bump here.
                print(f"Error creating output frame {source_idx}: {e}")
                traceback.print_exc()
                continue

        # Cap the output rate for cheaper CPU encoding. This also replaces a
        # missing/invalid (0 or NaN) rate reported by the container.
        fps = get_video_fps(video_in)
        if not (0 < fps <= 15):
            fps = 15

        print(f"Creating video with {len(output_frames)} frames at {fps} FPS")
        session_state["progress"] = 90

        if not output_frames:
            raise ValueError("No output frames were generated")

        # The encoder needs every frame to have the same shape.
        first_shape = output_frames[0].shape
        valid_frames = []
        for i, frame in enumerate(output_frames):
            if frame.shape == first_shape:
                valid_frames.append(frame)
            else:
                print(f"Skipping frame {i} with inconsistent shape: {frame.shape} vs {first_shape}")

        if not valid_frames:
            raise ValueError("No valid frames with consistent shape")

        # mkstemp gives every request its own file, so sessions finishing within
        # the same second cannot overwrite each other's output.
        fd, final_vid_output_path = tempfile.mkstemp(
            prefix=f"output_video_{datetime.now():%Y%m%d%H%M%S}_", suffix=".mp4"
        )
        os.close(fd)

        clip = ImageSequenceClip(valid_frames, fps=fps)
        try:
            # Low bitrate and few threads to keep CPU encoding light.
            clip.write_videofile(
                final_vid_output_path,
                codec="libx264",
                bitrate="800k",
                threads=2,
                logger=None,
            )
        finally:
            clip.close()

        session_state["progress"] = 100

        del video_segments, output_frames, valid_frames
        gc.collect()

        return (
            gr.update(value=final_vid_output_path, visible=True),
            gr.update(value=100, visible=False),
            session_state,
        )

    except Exception as e:
        print(f"Error in propagate_to_all: {e}")
        traceback.print_exc()
        return (
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )

def update_ui():
    """Return two Gradio updates for the start of processing.

    The first update only makes its component visible. The second makes its
    component visible and resets its value to 0; presumably this is the
    progress bar, so verify against the event wiring.
    """
    shown = gr.update(visible=True)
    shown_and_zeroed = gr.update(visible=True, value=0)
    return shown, shown_and_zeroed


# Main Gradio UI setup
with gr.Blocks() as demo:
    session_state = gr.State(
        {
            "first_frame": None,
            "all_frames": None,
            "input_points": [],
            "input_labels": [],
            "inference_state": None,
            "frame_stride": 1,
            "scale_factor": 1.0,
            "original_dimensions": None,
            "progress": 0,
        }
    )

    with gr.Column():
        # Title
        gr.Markdown(title)
        with gr.Row():

            with gr.Column():
                # Instructions
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )
                
                # Add progress bar
                progress_bar = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=1,
                    label="Processing Progress",
                    visible=False,
                    interactive=False
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[
                        video_in,
                    ],
                    examples_per_page=5,
                )
                
                output_image = gr.Image(label="Reference Mask")
                output_video = gr.Video(visible=False)

    # When new video is uploaded
    video_in.upload(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
            points_map,  # Image component where we add new tracking points
            output_image,
            output_video,
            progress_bar,
            session_state,
        ],
        queue=False,
    )

    video_in.change(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
            points_map,  # Image component where we add new tracking points
            output_image,
            output_video,
            progress_bar,
            session_state,
        ],
        queue=False,
    )

    # triggered when we click