# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import tempfile
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import torch

from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

# This demo runs on CPU only, so drop the cuDNN SDPA flag if it is set
if "TORCH_CUDNN_SDPA_ENABLED" in os.environ:
    del os.environ["TORCH_CUDNN_SDPA_ENABLED"]

# Description
title = "<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"

description_p = """# Instructions
                <ol>
                <li> Upload a video or pick one of the example videos</li>
                <li> With the 'include' point type selected, click the object you want to segment and track</li>
                <li> Optionally, switch to the 'exclude' point type and click areas that should stay out of the mask</li>
                <li> Click the 'Track' button to generate the masked video</li>
                </ol>
              """

# examples - keeping fewer examples to reduce memory footprint
examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
]

OBJ_ID = 0

# Initialize the model on CPU, with error handling for missing files
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
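# Note (an assumption based on SAM2-style builders, which EdgeTAM follows):
# build_sam2_video_predictor typically resolves model_cfg by config *name*
# through Hydra rather than as a literal filesystem path, so the existence
# check below is only a best-effort sanity check for the config file.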

# Check if model files exist
def check_file_exists(filepath):
    exists = os.path.exists(filepath)
    if not exists:
        print(f"WARNING: File not found: {filepath}")
    return exists

# Verify files exist
model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
predictor = None
try:
    # Load model with careful error handling
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    print("predictor loaded on CPU")
except Exception as e:
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()

# Function to get video frame rate
def get_video_fps(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return 30.0  # Default fallback value
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    # Some containers report 0 FPS; fall back to a sane default
    if not fps or fps <= 0:
        print("Warning: invalid FPS reported, falling back to 30.")
        return 30.0
    return fps

def reset(session_state):
    """Reset all session state variables and UI elements."""
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["inference_state"] = None
    session_state["progress"] = 0
    return (
        None,
        gr.update(open=True),
        None,
        None,
        gr.update(value=None, visible=False),
        gr.update(value=0, visible=False),
        session_state,
    )

def clear_points(session_state):
    """Clear tracking points while keeping the video frames."""
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None and session_state["inference_state"].get("tracking_has_started", False):
        predictor.reset_state(session_state["inference_state"])
    return (
        session_state["first_frame"],
        None,
        gr.update(value=None, visible=False),
        gr.update(value=0, visible=False),
        session_state,
    )

def preprocess_video_in(video_path, session_state):
    """Process input video to extract frames for tracking."""
    if video_path is None or not os.path.exists(video_path):
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            gr.update(value=0, visible=False),  # progress_bar
            session_state,
        )

    # Read the first frame
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video at {video_path}.")
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            gr.update(value=0, visible=False),  # progress_bar
            session_state,
        )

    # Determine video properties (used below for CPU-friendly resizing and striding)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    
    print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames, {fps} FPS")
    
    # Determine if we need to resize for CPU performance
    target_width = 640  # Target width for processing on CPU
    scale_factor = 1.0
    
    if frame_width > target_width:
        scale_factor = target_width / frame_width
        new_width = int(frame_width * scale_factor)
        new_height = int(frame_height * scale_factor)
        print(f"Resizing video for CPU processing: {frame_width}x{frame_height} -> {new_width}x{new_height}")
    
    # Read frames - for CPU we'll be more selective about which frames to keep
    frame_number = 0
    first_frame = None
    all_frames = []
    
    # For CPU optimization, skip frames if video is too long
    frame_stride = 1
    if total_frames > 300:  # If more than 300 frames
        frame_stride = max(1, (total_frames + 299) // 300)  # Ceiling division: keep at most ~300 frames
        print(f"Video has {total_frames} frames, using stride of {frame_stride} to reduce processing load")
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
            
        if frame_number % frame_stride == 0:  # Process every frame_stride frames
            try:
                # Resize the frame if needed
                if scale_factor != 1.0:
                    frame = cv2.resize(
                        frame, 
                        (int(frame_width * scale_factor), int(frame_height * scale_factor)), 
                        interpolation=cv2.INTER_AREA
                    )
                    
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                
                # Store the first frame
                if first_frame is None:
                    first_frame = frame
                all_frames.append(frame)
            except Exception as e:
                print(f"Error processing frame {frame_number}: {e}")
        
        frame_number += 1

    cap.release()
    
    # Ensure we have at least one frame
    if first_frame is None or len(all_frames) == 0:
        print("Error: No frames could be extracted from the video.")
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            gr.update(value=0, visible=False),  # progress_bar
            session_state,
        )
    
    print(f"Successfully extracted {len(all_frames)} frames from video")
    
    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["frame_stride"] = frame_stride
    session_state["scale_factor"] = scale_factor
    session_state["original_dimensions"] = (frame_width, frame_height)
    session_state["progress"] = 0

    try:
        session_state["inference_state"] = predictor.init_state(video_path=video_path)
        session_state["input_points"] = []
        session_state["input_labels"] = []
    except Exception as e:
        print(f"Error initializing inference state: {e}")
        import traceback
        traceback.print_exc()
        session_state["inference_state"] = None
    
    return [
        gr.update(open=False),  # video_in_drawer
        first_frame,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        gr.update(value=0, visible=False),  # progress_bar
        session_state,
    ]

def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    """Add and process tracking points on the first frame."""
    if session_state["first_frame"] is None:
        print("Error: No frame available for segmentation")
        return None, None, session_state
    
    session_state["input_points"].append(evt.index)
    print(f"TRACKING INPUT POINT: {session_state['input_points']}")

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)
    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

    # Open the image and get its dimensions
    first_frame = session_state["first_frame"]
    h, w = first_frame.shape[:2]
    transparent_background = Image.fromarray(first_frame).convert("RGBA")

    # Define the circle radius as a fraction of the smaller dimension
    fraction = 0.01  # You can adjust this value as needed
    radius = int(fraction * min(w, h))

    # Create a transparent layer to draw on
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)

    for index, track in enumerate(session_state["input_points"]):
        if session_state["input_labels"][index] == 1:
            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)  # Green for include
        else:
            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)  # Red for exclude

    # Convert the transparent layer back to an image
    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(
        transparent_background, transparent_layer
    )

    # Convert the accumulated clicks into arrays for the predictor
    points = np.array(session_state["input_points"], dtype=np.float32)
    # for labels, `1` means a positive (include) click and `0` a negative (exclude) click
    labels = np.array(session_state["input_labels"], np.int32)
    
    try:
        if predictor is None:
            raise ValueError("Model predictor is not initialized")
            
        if session_state["inference_state"] is None:
            raise ValueError("Inference state is not initialized")
            
        # For CPU optimization, we'll process with smaller batch size
        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,
            obj_id=OBJ_ID,
            points=points,
            labels=labels,
        )
        
        # Create the mask; squeeze any leading singleton channel dim so it is HxW
        mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
        if mask_array.ndim > 2:
            mask_array = mask_array.squeeze()
        
        # Ensure the mask has the same size as the frame
        if mask_array.shape[:2] != (h, w):
            mask_array = cv2.resize(
                mask_array.astype(np.uint8), 
                (w, h), 
                interpolation=cv2.INTER_NEAREST
            ).astype(bool)
        
        mask_image = show_mask(mask_array)
        
        # Make sure mask_image has the same size as the background
        if mask_image.size != transparent_background.size:
            mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
            
        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
    except Exception as e:
        print(f"Error in segmentation: {e}")
        import traceback
        traceback.print_exc()
        # Return just the points as fallback
        first_frame_output = selected_point_map

    return selected_point_map, first_frame_output, session_state

def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Convert binary mask to RGBA image for visualization."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    
    # Handle different mask shapes properly
    if len(mask.shape) == 2:
        h, w = mask.shape
    else:
        h, w = mask.shape[-2:]
    
    # Ensure correct reshaping based on mask dimensions
    mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    mask_rgba = (mask_reshaped * 255).astype(np.uint8)
    
    if convert_to_image:
        try:
            # Ensure the mask has correct RGBA shape (h, w, 4)
            if mask_rgba.shape[2] != 4:
                # If not RGBA, create a proper RGBA array
                proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
                # Copy available channels
                proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
                mask_rgba = proper_mask
            
            # Create the PIL image
            return Image.fromarray(mask_rgba, "RGBA")
        except Exception as e:
            print(f"Error converting mask to image: {e}")
            # Fallback: create a blank transparent image of correct size
            blank = np.zeros((h, w, 4), dtype=np.uint8)
            return Image.fromarray(blank, "RGBA")
    
    return mask_rgba
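
# A minimal usage sketch of show_mask (it mirrors what segment_with_points
# above and propagate_to_all below actually do): composite the RGBA mask
# overlay onto an RGBA version of the frame.
#
#   overlay = Image.alpha_composite(
#       Image.fromarray(frame).convert("RGBA"),  # frame: HxWx3 uint8 RGB
#       show_mask(mask),                         # mask: HxW bool
#   )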

def update_progress(progress_percent, progress_bar):
    """Update progress bar during processing."""
    return gr.update(value=progress_percent, visible=True)

def propagate_to_all(
    video_in,
    session_state,
    progress=gr.Progress(),
):
    """Process video frames and generate masked video output with progress tracking."""
    if (
        len(session_state["input_points"]) == 0
        or video_in is None
        or session_state["inference_state"] is None
        or predictor is None
    ):
        print("Missing required data for tracking")
        return (
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )

    # For CPU optimization: garbage-collect every few frames
    chunk_size = 3  # Run gc.collect() every 3 frames to limit memory growth
    
    try:
        # run propagation throughout the video and collect the results in a dict
        video_segments = {}  # video_segments contains the per-frame segmentation results
        print("Starting propagate_in_video on CPU")
        
        # Get the frame count for progress reporting. SAM2's init_state stores
        # the number of loaded frames in the inference state, so we read it
        # from there rather than running propagation twice.
        all_frames_count = session_state["inference_state"]["num_frames"]
        print(f"Total frames to process: {all_frames_count}")

        # gr.Progress is driven by calling it with a fraction in [0, 1]
        progress(0, desc="Propagating masks")
        
        # Now do the actual processing with progress updates
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            session_state["inference_state"]
        ):
            try:
                # Store the masks for each object ID
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
                
                # Update progress (gr.Progress is called with a fraction in [0, 1])
                progress_percent = min(100, int((out_frame_idx + 1) / all_frames_count * 100))
                progress((out_frame_idx + 1) / all_frames_count, desc="Propagating masks")
                session_state["progress"] = progress_percent
                
                if out_frame_idx % 10 == 0:
                    print(f"Processed frame {out_frame_idx}/{all_frames_count} ({progress_percent}%)")
                
                # Release memory periodically
                if out_frame_idx % chunk_size == 0:
                    # Explicitly clear any tensors
                    del out_mask_logits
                    import gc
                    gc.collect()
            except Exception as e:
                print(f"Error processing frame {out_frame_idx}: {e}")
                import traceback
                traceback.print_exc()
                continue

        # For CPU optimization: use a more aggressive stride so the output
        # video contains fewer frames
        total_frames = len(video_segments)
        print(f"Total frames processed: {total_frames}")

        # Update progress to show the rendering phase (second half of the bar)
        progress(0.5, desc="Rendering output frames")
        session_state["progress"] = 50
        
        # Limit to max 50 frames for CPU processing
        max_output_frames = 50
        vis_frame_stride = max(1, total_frames // max_output_frames)
        print(f"Using stride of {vis_frame_stride} for output video generation")
        
        # Get dimensions of the frames
        if len(session_state["all_frames"]) == 0:
            raise ValueError("No frames available in session state")
            
        first_frame = session_state["all_frames"][0]
        h, w = first_frame.shape[:2]
        
        # Create output frames
        output_frames = []
        frames_to_render = (total_frames // vis_frame_stride) + 1
        
        for out_frame_idx in range(0, total_frames, vis_frame_stride):
            if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
                continue
                
            try:
                # Map the predictor's original-video frame index onto our
                # (possibly strided) all_frames list
                stride = session_state.get("frame_stride", 1)
                frame_idx = min(out_frame_idx // stride, len(session_state["all_frames"]) - 1)
                    
                frame = session_state["all_frames"][frame_idx]
                transparent_background = Image.fromarray(frame).convert("RGBA")
                
                # Get the mask and ensure it's the right size
                out_mask = video_segments[out_frame_idx][OBJ_ID]
                
                # Ensure the mask is not empty and has the right dimensions
                if out_mask.size == 0:
                    print(f"Warning: Empty mask for frame {out_frame_idx}")
                    # Create an empty mask of the right size
                    out_mask = np.zeros((h, w), dtype=bool)
                
                # Resize mask if dimensions don't match
                mask_h, mask_w = out_mask.shape[:2]
                if mask_h != h or mask_w != w:
                    print(f"Resizing mask from {mask_h}x{mask_w} to {h}x{w}")
                    out_mask = cv2.resize(
                        out_mask.astype(np.uint8), 
                        (w, h), 
                        interpolation=cv2.INTER_NEAREST
                    ).astype(bool)
                
                mask_image = show_mask(out_mask)
                
                # Make sure mask has same dimensions as background
                if mask_image.size != transparent_background.size:
                    mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
                
                output_frame = Image.alpha_composite(transparent_background, mask_image)
                # mp4/libx264 carries no alpha channel, so flatten to RGB; the
                # composited frame is fully opaque, so nothing is lost
                output_frame = np.array(output_frame.convert("RGB"))
                output_frames.append(output_frame)
                
                # Update progress (rendering spans the 50-100% range of the bar)
                progress_percent = 50 + min(50, int((len(output_frames) / frames_to_render) * 50))
                progress(progress_percent / 100, desc="Rendering output frames")
                session_state["progress"] = progress_percent
                
                # Clear memory periodically
                if len(output_frames) % 10 == 0:
                    import gc
                    gc.collect()
                    
            except Exception as e:
                print(f"Error creating output frame {out_frame_idx}: {e}")
                import traceback
                traceback.print_exc()
                continue

        # Create a video clip from the image sequence
        original_fps = get_video_fps(video_in)
        fps = original_fps
        
        # For CPU optimization - lower FPS if original is high
        if fps > 15:
            fps = 15  # Lower fps for CPU processing
        
        print(f"Creating video with {len(output_frames)} frames at {fps} FPS")
        
        # Update progress to show the video encoding phase
        progress(0.9, desc="Encoding video")
        session_state["progress"] = 90
        
        # Check if we have any frames to work with
        if len(output_frames) == 0:
            raise ValueError("No output frames were generated")
        
        # Ensure all frames have the same shape
        first_shape = output_frames[0].shape
        valid_frames = []
        for i, frame in enumerate(output_frames):
            if frame.shape == first_shape:
                valid_frames.append(frame)
            else:
                print(f"Skipping frame {i} with inconsistent shape: {frame.shape} vs {first_shape}")
        
        if len(valid_frames) == 0:
            raise ValueError("No valid frames with consistent shape")
            
        clip = ImageSequenceClip(valid_frames, fps=fps)
        
        # Write the result to a file - use lower quality for CPU
        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
        final_vid_output_path = f"output_video_{unique_id}.mp4"
        final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)

        # Lower bitrate for CPU processing
        clip.write_videofile(
            final_vid_output_path, 
            codec="libx264", 
            bitrate="800k",
            threads=2,  # Use fewer threads for CPU
            logger=None   # Disable logger to reduce console output
        )
        
        # Complete progress
        session_state["progress"] = 100
        
        # Free memory
        del video_segments
        del output_frames
        import gc
        gc.collect()

        return (
            gr.update(value=final_vid_output_path, visible=True),
            gr.update(value=100, visible=False),
            session_state,
        )
    
    except Exception as e:
        print(f"Error in propagate_to_all: {e}")
        import traceback
        traceback.print_exc()
        return (
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )

def update_ui():
    """Show progress bar when starting processing."""
    return gr.update(visible=True), gr.update(visible=True, value=0)


# Main Gradio UI setup
with gr.Blocks() as demo:
    session_state = gr.State(
        {
            "first_frame": None,
            "all_frames": None,
            "input_points": [],
            "input_labels": [],
            "inference_state": None,
            "frame_stride": 1,
            "scale_factor": 1.0,
            "original_dimensions": None,
            "progress": 0,
        }
    )

    with gr.Column():
        # Title
        gr.Markdown(title)
        with gr.Row():

            with gr.Column():
                # Instructions
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )
                
                # Add progress bar
                progress_bar = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=1,
                    label="Processing Progress",
                    visible=False,
                    interactive=False
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[
                        video_in,
                    ],
                    examples_per_page=5,
                )
                
                output_image = gr.Image(label="Reference Mask")
                output_video = gr.Video(visible=False)

    # Preprocess whenever the video input changes. `change` fires for both
    # user uploads and example selections, so a separate `upload` handler
    # would only preprocess the same video twice.
    video_in.change(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
            points_map,  # Image component where we add new tracking points
            output_image,
            output_video,
            progress_bar,
            session_state,
        ],
        queue=False,
    )

    # triggered when we click on image to add new points
    points_map.select(
        fn=segment_with_points,
        inputs=[
            point_type,  # "include" or "exclude"
            session_state,
        ],
        outputs=[
            points_map,  # updated image with points
            output_image,
            session_state,
        ],
        queue=False,
    )

    # Clear every points clicked and added to the map
    clear_points_btn.click(
        fn=clear_points,
        inputs=session_state,
        outputs=[
            points_map,
            output_image,
            output_video,
            progress_bar,
            session_state,
        ],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=session_state,
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            progress_bar,
            session_state,
        ],
        queue=False,
    )

    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=[output_video, progress_bar],
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            output_video,
            progress_bar,
            session_state,
        ],
        queue=True,  # Use queue for CPU processing
    )


demo.queue()
demo.launch()