# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
from datetime import datetime

import gradio as gr

# Removed GPU-specific environment variable setting
# os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"

import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
# Removed spaces decorator import for CPU-only demo
# import spaces
import torch

from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

# Description
title = "<center><strong><font size='8'>EdgeTAM<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"

description_p = """# Instructions
                <ol>
                <li> Upload one video or click one example video</li>
                <li> Click 'include' point type, select the object to segment and track</li>
                <li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
                <li> Click the 'Track' button to obtain the masked video </li>
                </ol>
              """

# examples - Keep examples, they are input files
examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
    ["examples/01_breakdancer.mp4"],
    ["examples/02_hummingbird.mp4"],
    ["examples/03_skateboarder.mp4"],
    ["examples/04_octopus.mp4"],
    ["examples/05_landing_dog_soccer.mp4"],
    ["examples/06_pingpong.mp4"],
    ["examples/07_snowboarder.mp4"],
    ["examples/08_driving.mp4"],
    ["examples/09_birdcartoon.mp4"],
    ["examples/10_cloth_magic.mp4"],
    ["examples/11_polevault.mp4"],
    ["examples/12_hideandseek.mp4"],
    ["examples/13_butterfly.mp4"],
    ["examples/14_social_dog_training.mp4"],
    ["examples/15_cricket.mp4"],
    ["examples/16_robotarm.mp4"],
    ["examples/17_childrendancing.mp4"],
    ["examples/18_threedogs.mp4"],
    ["examples/19_cyclist.mp4"],
    ["examples/20_doughkneading.mp4"],
    ["examples/21_biker.mp4"],
    ["examples/22_dogskateboarder.mp4"],
    ["examples/23_racecar.mp4"],
    ["examples/24_clownfish.mp4"],
]

# Single object id used for all point prompts and propagated masks in this demo
OBJ_ID = 0

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
# Ensure predictor is explicitly built for CPU
# The device is set here and with .to("cpu")
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
predictor.to("cpu") # Explicitly move to CPU after building
print("predictor loaded on CPU")

# Removed autocast block for maximum CPU compatibility
# torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()

# Removed commented-out GPU-specific code
# if torch.cuda.get_device_properties(0).major >= 8: ...


def get_video_fps(video_path):
    """Gets the frames per second of a video file."""
    if video_path is None or not os.path.exists(video_path):
         print(f"Warning: Video file not found at {video_path}")
         return None
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}.")
        return None
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps
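
# Usage sketch for the helper above (the path is taken from the `examples` list; the
# return value shown is hypothetical):
#   fps = get_video_fps("examples/01_dog.mp4")  # e.g. 29.97, or None if the file can't be opened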

# Removed @spaces.GPU decorator
def preprocess_video_in(video_path, session_state):
    """Loads video frames and initializes the predictor state."""
    print(f"Processing video: {video_path}")
    if video_path is None or not os.path.exists(video_path):
        print("No video path provided or file not found.")
        # Reset state and UI elements if input is invalid
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            gr.update(interactive=False), # propagate_btn
            gr.update(interactive=False), # clear_points_btn
            gr.update(interactive=False), # reset_btn
            { # Reset session state
                "first_frame": None,
                "all_frames": None,
                "input_points": [],
                "input_labels": [],
                "inference_state": None,
                "video_path": None,
            }
        )

    # Read the first frame and all frames
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}.")
        # Reset state and UI elements on error
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(interactive=False), # propagate_btn
            gr.update(interactive=False), # clear_points_btn
            gr.update(interactive=False), # reset_btn
            { # Reset session state
                "first_frame": None,
                "all_frames": None,
                "input_points": [],
                "input_labels": [],
                "inference_state": None,
                "video_path": None,
            }
        )

    first_frame = None
    all_frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Convert BGR to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        all_frames.append(frame)
        if first_frame is None:
            first_frame = frame # Store the first frame

    cap.release()

    if not all_frames:
        print(f"Error: No frames read from video file {video_path}.")
        # Reset state and UI elements if no frames are read
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(interactive=False), # propagate_btn
            gr.update(interactive=False), # clear_points_btn
            gr.update(interactive=False), # reset_btn
            { # Reset session state
                "first_frame": None,
                "all_frames": None,
                "input_points": [],
                "input_labels": [],
                "inference_state": None,
                "video_path": None,
            }
        )

    # Update session state with frames and path
    session_state["first_frame"] = copy.deepcopy(first_frame) # Store a copy
    session_state["all_frames"] = all_frames
    session_state["video_path"] = video_path # Store the path
    session_state["input_points"] = []
    session_state["input_labels"] = []
    # Initialize state *without* the device argument
    session_state["inference_state"] = predictor.init_state(video_path=video_path)
    print("Video loaded and predictor state initialized.")

    return [
        gr.update(open=False),  # video_in_drawer
        first_frame,  # points_map (shows first frame)
        None,  # output_image (cleared initially)
        gr.update(value=None, visible=False),  # output_video (hidden initially)
        gr.update(interactive=True), # Enable buttons
        gr.update(interactive=True), # Enable buttons
        gr.update(interactive=True), # Enable buttons
        session_state, # Updated state
    ]


def reset(session_state):
    """Resets the UI and session state."""
    print("Resetting demo.")
    # Clear points and labels
    session_state["input_points"] = []
    session_state["input_labels"] = []
    # Reset the predictor state if it exists
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
        # After reset, we also discard the state object as a new video might be loaded
        session_state["inference_state"] = None
    # Clear frames and video path
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["video_path"] = None

    # Update UI elements to their initial state
    return (
        None, # video_in
        gr.update(open=True), # video_in_drawer open
        None, # points_map cleared
        None, # output_image cleared
        gr.update(value=None, visible=False), # output_video hidden
        gr.update(interactive=False), # Disable buttons
        gr.update(interactive=False), # Disable buttons
        gr.update(interactive=False), # Disable buttons
        session_state, # Updated session state
    )


def clear_points(session_state):
    """Clears selected points and resets segmentation on the first frame."""
    print("Clearing points.")
    # Clear points and labels lists
    session_state["input_points"] = []
    session_state["input_labels"] = []

    # Reset the predictor state if it exists. This clears internal masks/features
    # but keeps the video context initialized by preprocess_video_in.
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
        # After resetting the state, if we still have the video path, re-initialize the state
        # to be ready for new points on the same video.
        if session_state["video_path"] is not None:
            # Re-initialize state *without* the device argument
            session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"])
            print("Predictor state re-initialized after clearing points.")
        else:
            print("Warning: Could not re-initialize state after clear_points (video_path missing).")
            session_state["inference_state"] = None # Ensure state is None if video_path is gone


    # Re-render the points_map with no points drawn (just the first frame)
    # Re-render the output_image with no mask (just the first frame)
    first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None

    return (
        first_frame_img, # points_map shows original first frame
        None, # output_image cleared
        gr.update(value=None, visible=False), # Hide output video
        session_state, # Updated session state
    )


# Removed @spaces.GPU decorator
def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    """Adds a point prompt and performs segmentation on the first frame."""
    # Ensure we have a valid first frame and inference state
    if session_state["first_frame"] is None or session_state["inference_state"] is None:
         print("Error: Cannot segment. No video loaded or inference state missing.")
         # Return current states to avoid errors, without changing UI much
         return (
             session_state["first_frame"], # points_map remains unchanged
             None, # output_image remains unchanged or cleared
             session_state,
         )

    # evt.index gives the (x, y) coordinates of the click
    click_coords = evt.index
    print(f"Clicked at: {click_coords} ({point_type})")

    session_state["input_points"].append(click_coords)

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)

    # Get the first frame as a PIL image for drawing
    first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA")
    w, h = first_frame_pil.size

    # Define the circle radius
    fraction = 0.01
    radius = max(2, int(fraction * min(w, h))) # Ensure minimum radius of 2

    # Create a transparent layer to draw points
    transparent_layer_points = np.zeros((h, w, 4), dtype=np.uint8)

    # Draw points on the transparent layer
    for index, track in enumerate(session_state["input_points"]):
        # Ensure coordinates are integers for cv2.circle
        point_coords = (int(track[0]), int(track[1]))
        if session_state["input_labels"][index] == 1:
            # Green circle for include
            cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1)
        else:
            # Red circle for exclude
            cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1)

    # Convert the transparent layer back to an image and composite onto the first frame
    transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
    # Combine the first frame image with the points layer for the points_map output
    # points_map shows the first frame *with the points you added*.
    selected_point_map_img = Image.alpha_composite(
        first_frame_pil.copy(), transparent_layer_points_pil
    )

    # Prepare points and labels as tensors on CPU for the predictor
    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)

    # Ensure tensors are on CPU
    points_tensor = torch.tensor(points, dtype=torch.float32, device="cpu").unsqueeze(0) # Add batch dim
    labels_tensor = torch.tensor(labels, dtype=torch.int32, device="cpu").unsqueeze(0) # Add batch dim

    # Add new points to the predictor's state and get the mask for the first frame
    # This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
    first_frame_output_img = None # Initialize output mask image as None in case of error
    try:
        # Note: predictor.add_new_points modifies the internal inference_state
        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0, # Always segment on the first frame initially
            obj_id=OBJ_ID,
            points=points_tensor,
            labels=labels_tensor,
        )

        # Process logits: detach from graph, move to CPU, apply threshold.
        # out_mask_logits holds one entry per tracked object; [0][0] selects the [H, W] logits
        # for our single object (OBJ_ID).
        mask_tensor = out_mask_logits[0][0].detach().cpu() > 0.0 # Boolean mask tensor [H, W]
        mask_numpy = mask_tensor.numpy() # Convert to numpy

        # Get the mask image (RGBA)
        mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) # show_mask returns RGBA PIL Image

        # Composite the mask onto the first frame for the output_image
        # output_image shows the first frame *with the segmentation mask result*.
        first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)

    except Exception as e:
        print(f"Error during segmentation on first frame: {e}")
        # On error, first_frame_output_img remains None


    return selected_point_map_img, first_frame_output_img, session_state


def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Helper function to visualize a mask."""
    # Ensure mask is a numpy array (and boolean)
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy() # Ensure it's on CPU and converted to numpy
    # Convert potential float/int mask to boolean mask
    mask = mask.astype(bool)

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) # RGBA with 0.6 alpha
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id % 10 # Use modulo 10 for tab10 colors
        color = np.array([*cmap(cmap_idx)[:3], 0.6]) # RGBA with 0.6 alpha

    # Ensure mask has H, W dimensions
    if mask.ndim == 3:
        mask = mask.squeeze() # Remove singular dimensions like (H, W, 1)
    if mask.ndim != 2:
        print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
        # Create an empty transparent image if mask shape is unexpected
        h, w = mask.shape[:2] if mask.ndim >= 2 else (100, 100) # Use actual shape if possible, otherwise default
        if convert_to_image:
            return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")
        else:
            return np.zeros((h, w, 4), dtype=np.uint8)

    h, w = mask.shape
    # Create an RGBA overlay from the mask and color.
    colored_mask = np.zeros((h, w, 4), dtype=np.float32) # Start with fully transparent black
    # Boolean indexing applies the RGBA color to every pixel where the mask is True,
    # directly creating the colored overlay with transparency.
    colored_mask[mask] = color

    # Convert to uint8 [0-255]
    colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)

    if convert_to_image:
        mask_img = Image.fromarray(colored_mask_uint8, "RGBA")
        return mask_img
    else:
        return colored_mask_uint8
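
# Usage sketch for show_mask (the mask values here are hypothetical, for illustration only):
#   dummy_mask = np.zeros((480, 640), dtype=bool)
#   dummy_mask[100:200, 150:300] = True
#   overlay = show_mask(dummy_mask, obj_id=OBJ_ID)  # RGBA PIL.Image, ready to alpha-composite onto a frame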


# Removed @spaces.GPU decorator
def propagate_to_all(
    # We don't strictly need video_in path here anymore as it's in session_state,
    # but keeping it is fine. Accessing session_state["video_path"] is more robust.
    video_in,
    session_state,
):
    """Runs mask propagation through the video and generates the output video."""
    print("Starting propagation...")
    # Ensure state is ready
    if (
        len(session_state["input_points"]) == 0 # Need at least one point
        or session_state["all_frames"] is None
        or session_state["inference_state"] is None
        or session_state["video_path"] is None # Ensure we have the original video path
    ):
        print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
        return (
            gr.update(value=None, visible=False), # Hide output video on error
            session_state,
        )

    # run propagation throughout the video and collect the results
    # The generator yields (frame_idx, obj_ids, mask_logits)
    video_segments = {}
    try:
        # This loop performs the core tracking prediction frame by frame
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            session_state["inference_state"]
        ):
            # Process logits: detach from graph, move to CPU, convert to numpy boolean mask.
            # Ensure each tensor is on CPU before converting to numpy.
            video_segments[out_frame_idx] = {
                # out_mask_logits is a list of tensors (one per object tracked in this frame).
                # Each tensor is [batch_size, H, W]; batch size is 1 here, so take element [0].
                out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
            # Optional: print progress
            # print(f"Processed frame {out_frame_idx+1}/{len(session_state['all_frames'])}")

        print("Propagation finished.")
    except Exception as e:
        print(f"Error during propagation: {e}")
        return (
            gr.update(value=None, visible=False), # Hide output video on error
            session_state,
        )


    output_frames = []
    # Iterate through all original frames to generate output video
    total_frames = len(session_state["all_frames"])
    for out_frame_idx in range(total_frames):
        original_frame_rgb = session_state["all_frames"][out_frame_idx]
        # Convert original frame to RGBA for compositing
        transparent_background = Image.fromarray(original_frame_rgb).convert("RGBA")

        # Check if we have a mask for this frame and object ID
        if out_frame_idx in video_segments and OBJ_ID in video_segments[out_frame_idx]:
            current_mask_numpy = video_segments[out_frame_idx][OBJ_ID]
            # Get the mask image (RGBA)
            mask_image_pil = show_mask(current_mask_numpy, obj_id=OBJ_ID)
            # Composite the mask onto the frame
            output_frame_img_rgba = Image.alpha_composite(transparent_background, mask_image_pil)
            # Convert back to numpy RGB (moviepy needs RGB or RGBA)
            output_frame_np = np.array(output_frame_img_rgba.convert("RGB"))
        else:
            # No mask for this frame/object: use the original frame as-is
            # (all_frames already holds RGB numpy arrays).
            # print(f"Warning: No mask found for frame {out_frame_idx} and object {OBJ_ID}. Using original frame.")
            output_frame_np = original_frame_rgb # Already an RGB numpy array

        output_frames.append(output_frame_np)


    # Define output path in a temporary directory
    # Use os.path.join for cross-platform compatibility
    unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f") # Use microseconds for more uniqueness
    final_vid_filename = f"output_video_{unique_id}.mp4"
    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
    print(f"Output video path: {final_vid_output_path}")


    # Create a video clip from the image sequence
    # Get original FPS or default
    # Get FPS from the stored video path in session state
    original_fps = get_video_fps(session_state["video_path"])
    fps = original_fps if original_fps is not None and original_fps > 0 else 30 # Default to 30 if detection fails or is zero
    print(f"Creating output video with FPS: {fps}")

    # Check if there are frames to process
    if not output_frames:
         print("No output frames generated.")
         return (
            gr.update(value=None, visible=False), # Hide output video
            session_state,
         )

    # Create ImageSequenceClip from the list of numpy arrays
    try:
        clip = ImageSequenceClip(output_frames, fps=fps)
    except Exception as e:
        print(f"Error creating ImageSequenceClip: {e}")
        return (
            gr.update(value=None, visible=False), # Hide output video on error
            session_state,
        )


    # Write the result to a file. Use 'libx264' codec for broad compatibility.
    # `preset` and `threads` for CPU optimization.
    # `logger=None` prevents moviepy from printing progress to stdout/stderr, which can clutter the Gradio logs.
    try:
        print(f"Writing video file with codec='libx264', fps={fps}, preset='medium', threads='auto'")
        clip.write_videofile(
            final_vid_output_path,
            codec="libx264",
            fps=fps, # Ensure correct FPS is used during writing
            preset="medium", # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed vs size
            threads="auto", # CPU optimization: Use multiple cores
            logger=None # Suppress moviepy output
        )
        print("Video writing complete.")
        # Return the path and make the video player visible
        return (
            gr.update(value=final_vid_output_path, visible=True),
            session_state,
        )
    except Exception as e:
        print(f"Error writing video file: {e}")
        # Clean up potentially created partial file
        if os.path.exists(final_vid_output_path):
            try:
                os.remove(final_vid_output_path)
                print(f"Removed partial video file: {final_vid_output_path}")
            except Exception as clean_e:
                print(f"Error removing partial file: {clean_e}")

        # Return None if writing fails
        return (
            gr.update(value=None, visible=False),
            session_state,
        )


def update_output_video_visibility():
    """Simply returns a Gradio update to make the output video visible."""
    return gr.update(visible=True)


with gr.Blocks() as demo:
    # Session state dictionary to hold video frames, points, labels, and predictor state
    session_state = gr.State(
        {
            "first_frame": None, # numpy array (RGB)
            "all_frames": None,  # list of numpy arrays (RGB)
            "input_points": [],  # list of (x, y) tuples/lists
            "input_labels": [],  # list of 1s and 0s
            "inference_state": None, # EdgeTAM predictor state object
            "video_path": None, # Store the input video path
        }
    )

    with gr.Column():
        # Title
        gr.Markdown(title)
        with gr.Row():

            with gr.Column():
                # Instructions
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4") # Will hold the video file path

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                        interactive=True, # Make interactive
                    )
                    # Buttons are initially disabled until a video is loaded
                    propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False)
                    clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False)
                    reset_btn = gr.Button("Reset", scale=1, interactive=False)

                # points_map is where users click to add points. Needs to be interactive.
                # Shows the first frame with points drawn on it.
                points_map = gr.Image(
                    label="Click on the First Frame to Add Points", # Clearer label
                    type="numpy",
                    interactive=True, # Make interactive to capture clicks
                    height=400, # Set a fixed height for better UI
                    width="auto", # Let width adjust
                    show_share_button=False,
                    show_download_button=False,
                    # show_label=False # Can hide label if space is tight
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[video_in],
                    examples_per_page=8,
                    cache_examples=False, # Do not cache processed examples, as state is involved
                )
                # Add padding/space
                # gr.Markdown("<br>")

                # output_image shows the segmentation mask prediction on the *first* frame
                output_image = gr.Image(
                    label="Segmentation Mask on First Frame", # Clearer label
                    type="numpy",
                    interactive=False, # Not interactive, just displays the mask
                    height=400, # Match height of points_map
                    width="auto", # Let width adjust
                    show_share_button=False,
                    show_download_button=False,
                    # show_label=False # Can hide label
                )

                # output_video shows the final tracking result
                output_video = gr.Video(visible=False, label="Tracking Result")


    # --- Event Handlers ---

    # When a new video file is uploaded via the file browser
    video_in.upload(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[
            video_in_drawer, # Close accordion
            points_map,      # Show first frame in points_map
            output_image,    # Clear output image
            output_video,    # Hide output video
            propagate_btn,   # Enable Track button
            clear_points_btn,# Enable Clear Points button
            reset_btn,       # Enable Reset button
            session_state,   # Update session state
        ],
        queue=False, # Process immediately
    )

    # When an example video is selected (change event)
    video_in.change(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[
            video_in_drawer, # Close accordion
            points_map,      # Show first frame in points_map
            output_image,    # Clear output image
            output_video,    # Hide output video
            propagate_btn,   # Enable Track button
            clear_points_btn,# Enable Clear Points button
            reset_btn,       # Enable Reset button
            session_state,   # Update session state
        ],
        queue=False, # Process immediately
    )


    # Triggered when a user clicks on the points_map image
    points_map.select(
        fn=segment_with_points,
        inputs=[
            point_type,  # "include" or "exclude" radio button value
            session_state, # Pass session state
        ],
        outputs=[
            points_map,      # Updated image with points drawn
            output_image,    # Updated image with first frame segmentation mask
            session_state,   # Updated session state (points/labels added)
        ],
        queue=False, # Process clicks immediately
    )

    # Button to clear all selected points and reset the first frame mask
    clear_points_btn.click(
        fn=clear_points,
        inputs=[session_state], # Pass session state
        outputs=[
            points_map,    # points_map shows original first frame without points
            output_image,  # output_image cleared (or shows original first frame without mask)
            output_video,  # Hide output video
            session_state, # Updated session state (points/labels cleared, inference state reset)
        ],
        queue=False, # Process immediately
    )

    # Button to reset the entire demo state and UI
    reset_btn.click(
        fn=reset,
        inputs=[session_state], # Pass session state
        outputs=[
            video_in,        # Clear video input
            video_in_drawer, # Open video accordion
            points_map,      # Clear points_map
            output_image,    # Clear output_image
            output_video,    # Hide output_video
            propagate_btn,   # Disable buttons
            clear_points_btn,# Disable buttons
            reset_btn,       # Disable buttons
            session_state,   # Reset session state
        ],
        queue=False, # Process immediately
    )

    # Button to start mask propagation through the video
    propagate_btn.click(
        fn=update_output_video_visibility, # First, make the output video player visible
        inputs=[],
        outputs=[output_video],
        queue=False, # Process this UI update immediately
    ).then( # Then, run the propagation function
        fn=propagate_to_all,
        inputs=[
            video_in,      # Get the input video path (can also get from session_state["video_path"])
            session_state, # Pass session state (contains frames, points, inference_state, video_path)
        ],
        outputs=[
            output_video,  # Update output video player with result
            session_state, # Update session state (currently, propagate doesn't modify state much, but good practice)
        ],
        # CPU Optimization: Limit concurrency to 1 to prevent resource exhaustion.
        # Queue=True ensures requests wait if another is processing.
        concurrency_limit=1,
        queue=True,
    )


# Launch the Gradio demo
demo.queue() # Enable queuing for sequential processing under concurrency limits
print("Gradio demo starting...")
# share=True is omitted for local runs; pass share=True to demo.launch() if you need a public link
demo.launch()
print("Gradio demo launched.")