bla committed · Commit e508568 (verified) · 1 Parent(s): fa0b563
Update app.py

Files changed (1): app.py +177 -220
app.py CHANGED
@@ -10,16 +10,16 @@ from datetime import datetime
 
 import gradio as gr
 
-# Removed GPU-specific environment variable setting
-# os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"
-
+# This line might be related to GPU, kept from original
+os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"
 import tempfile
 
 import cv2
 import matplotlib.pyplot as plt
+# spaces import and decorators are for Hugging Face Spaces GPU allocation,
+# if running locally without spaces, these can be removed or will be ignored.
+import spaces
 import numpy as np
-# Removed spaces decorator import for CPU-only demo
-# import spaces
 import torch
 
 from moviepy.editor import ImageSequenceClip
@@ -38,7 +38,7 @@ description_p = """# Instructions
 </ol>
 """
 
-# examples - Keep examples, they are input files
+# examples
 examples = [
     ["examples/01_dog.mp4"],
     ["examples/02_cups.mp4"],
@@ -75,77 +75,133 @@ OBJ_ID = 0
 
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
-# Ensure predictor is explicitly built for CPU
-# The device is set here and with .to("cpu")
+# Model built for CPU but immediately moved to CUDA in original code
 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-predictor.to("cpu")  # Explicitly move to CPU after building
-print("predictor loaded on CPU")
-
-# Removed autocast block for maximum CPU compatibility
-# torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()
-
-# Removed commented-out GPU-specific code
-# if torch.cuda.get_device_properties(0).major >= 8: ...
+# *** Original code moves to CUDA ***
+predictor.to("cuda")
+print("predictor loaded on CUDA")
+
+# use bfloat16 for the entire demo - Original code uses CUDA bfloat16
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+# Original CUDA settings
+if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+elif not torch.cuda.is_available():
+    print("Warning: CUDA not available. The original code is configured for GPU.")
+    # Note: Without a GPU, the .to("cuda") calls will likely cause errors.
 
 
 def get_video_fps(video_path):
-    """Gets the frames per second of a video file."""
-    if video_path is None or not os.path.exists(video_path):
-        print(f"Warning: Video file not found at {video_path}")
-        return None
+    # Open the video file
     cap = cv2.VideoCapture(video_path)
+
     if not cap.isOpened():
-        print(f"Error: Could not open video file {video_path}.")
+        print("Error: Could not open video.")
         return None
+
+    # Get the FPS of the video
     fps = cap.get(cv2.CAP_PROP_FPS)
-    cap.release()
+    cap.release()  # Release the capture object
     return fps
 
-# Removed @spaces.GPU decorator
+
+def reset(session_state):
+    """Resets the UI and session state."""
+    print("Resetting demo.")
+    session_state["input_points"] = []
+    session_state["input_labels"] = []
+    # Reset the predictor state if it exists
+    if session_state["inference_state"] is not None:
+        # Assuming predictor.reset_state handles None or invalid states gracefully
+        # Or you might need to explicitly pass the state object if required
+        try:
+            predictor.reset_state(session_state["inference_state"])
+            # Explicitly delete or re-init the state object if a full reset is intended
+            # This depends on how predictor.reset_state works
+            # session_state["inference_state"] = None # Example if reset_state doesn't fully clear
+        except Exception as e:
+            print(f"Error resetting predictor state: {e}")
+            # If reset fails, perhaps force-clear the state object
+            session_state["inference_state"] = None
+
+    session_state["first_frame"] = None
+    session_state["all_frames"] = None
+    session_state["inference_state"] = None  # Ensure state is None after a full reset
+    # Also reset video path if stored
+    session_state["video_path"] = None
+
+    # Resetting UI components
+    return (
+        None,  # video_in (clears the video player)
+        gr.update(open=True),  # video_in_drawer (opens accordion)
+        None,  # points_map (clears the image)
+        None,  # output_image (clears the image)
+        gr.update(value=None, visible=False),  # output_video (hides and clears)
+        session_state,  # return updated session state
+    )
+
+
+def clear_points(session_state):
+    """Clears selected points and resets segmentation on the first frame."""
+    print("Clearing points.")
+    session_state["input_points"] = []
+    session_state["input_labels"] = []
+
+    # Reset the predictor state to clear internal masks/features
+    # This typically doesn't remove the video context, just the mask predictions
+    if session_state["inference_state"] is not None:
+        try:
+            # Assuming reset_state handles clearing current masks/features
+            predictor.reset_state(session_state["inference_state"])
+            print("Predictor state reset for clearing points.")
+            # If you need to re-initialize the state for the *same* video after clearing points,
+            # you might need to call predictor.init_state again here, using the stored video_path.
+            # session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"], device="cuda")  # Or device="cpu" if modified earlier
+        except Exception as e:
+            print(f"Error resetting predictor state during clear_points: {e}")
+            # If reset fails, this might leave old masks. Depending on SAM2's behavior,
+            # you might need a more aggressive state clear or re-initialization.
+
+    # Return the original first frame image for points_map and clear the output_image
+    first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None
+
+    return (
+        first_frame_img,  # points_map shows original first frame (no points yet)
+        None,  # output_image cleared (no mask)
+        gr.update(value=None, visible=False),  # output_video hidden
+        session_state,  # return updated session state
+    )
+
+
+# Added @spaces.GPU decorator back as it was in the original code
+@spaces.GPU
 def preprocess_video_in(video_path, session_state):
     """Loads video frames and initializes the predictor state."""
     print(f"Processing video: {video_path}")
     if video_path is None or not os.path.exists(video_path):
         print("No video path provided or file not found.")
         # Reset state and UI elements if input is invalid
+        # Need to return updates for the buttons as well
         return (
-            gr.update(open=True),  # video_in_drawer
-            None,  # points_map
-            None,  # output_image
-            gr.update(value=None, visible=False),  # output_video
-            gr.update(interactive=False),  # propagate_btn
-            gr.update(interactive=False),  # clear_points_btn
-            gr.update(interactive=False),  # reset_btn
-            {  # Reset session state
-                "first_frame": None,
-                "all_frames": None,
-                "input_points": [],
-                "input_labels": [],
-                "inference_state": None,
-                "video_path": None,
+            gr.update(open=True), None, None, gr.update(value=None, visible=False),
+            gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
+            {  # Reset session state
+                "first_frame": None, "all_frames": None, "input_points": [],
+                "input_labels": [], "inference_state": None, "video_path": None,
             }
         )
 
-    # Read the first frame and all frames
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         print(f"Error: Could not open video file {video_path}.")
-        # Reset state and UI elements on error
         return (
-            gr.update(open=True),
-            None,
-            None,
-            gr.update(value=None, visible=False),
-            gr.update(interactive=False),  # propagate_btn
-            gr.update(interactive=False),  # clear_points_btn
-            gr.update(interactive=False),  # reset_btn
+            gr.update(open=True), None, None, gr.update(value=None, visible=False),
+            gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
             {  # Reset session state
-                "first_frame": None,
-                "all_frames": None,
-                "input_points": [],
-                "input_labels": [],
-                "inference_state": None,
-                "video_path": None,
+                "first_frame": None, "all_frames": None, "input_points": [],
+                "input_labels": [], "inference_state": None, "video_path": None,
             }
         )
 
@@ -156,139 +212,65 @@ def preprocess_video_in(video_path, session_state):
         ret, frame = cap.read()
         if not ret:
             break
-        # Convert BGR to RGB
         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         all_frames.append(frame)
         if first_frame is None:
-            first_frame = frame  # Store the first frame
+            first_frame = frame
 
     cap.release()
 
     if not all_frames:
         print(f"Error: No frames read from video file {video_path}.")
-        # Reset state and UI elements if no frames are read
         return (
-            gr.update(open=True),
-            None,
-            None,
-            gr.update(value=None, visible=False),
-            gr.update(interactive=False),  # propagate_btn
-            gr.update(interactive=False),  # clear_points_btn
-            gr.update(interactive=False),  # reset_btn
+            gr.update(open=True), None, None, gr.update(value=None, visible=False),
+            gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
             {  # Reset session state
-                "first_frame": None,
-                "all_frames": None,
-                "input_points": [],
-                "input_labels": [],
-                "inference_state": None,
-                "video_path": None,
+                "first_frame": None, "all_frames": None, "input_points": [],
+                "input_labels": [], "inference_state": None, "video_path": None,
            }
        )
 
-    # Update session state with frames and path
-    session_state["first_frame"] = copy.deepcopy(first_frame)  # Store a copy
+    session_state["first_frame"] = copy.deepcopy(first_frame)
     session_state["all_frames"] = all_frames
-    session_state["video_path"] = video_path  # Store the path
+    session_state["video_path"] = video_path  # Store video path
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    # Initialize state *without* the device argument
+    # Original code did NOT pass device here. It uses the device the predictor is on.
     session_state["inference_state"] = predictor.init_state(video_path=video_path)
     print("Video loaded and predictor state initialized.")
 
+    # Enable buttons after successful load
     return [
         gr.update(open=False),  # video_in_drawer
-        first_frame,  # points_map (shows first frame)
-        None,  # output_image (cleared initially)
-        gr.update(value=None, visible=False),  # output_video (hidden initially)
-        gr.update(interactive=True),  # Enable buttons
-        gr.update(interactive=True),  # Enable buttons
-        gr.update(interactive=True),  # Enable buttons
-        session_state,  # Updated state
+        first_frame,  # points_map
+        None,  # output_image
+        gr.update(value=None, visible=False),  # output_video
+        gr.update(interactive=True),  # propagate_btn
+        gr.update(interactive=True),  # clear_points_btn
+        gr.update(interactive=True),  # reset_btn
+        session_state,  # session_state
    ]
 
 
-def reset(session_state):
-    """Resets the UI and session state."""
-    print("Resetting demo.")
-    # Clear points and labels
-    session_state["input_points"] = []
-    session_state["input_labels"] = []
-    # Reset the predictor state if it exists
-    if session_state["inference_state"] is not None:
-        predictor.reset_state(session_state["inference_state"])
-        # After reset, we also discard the state object as a new video might be loaded
-        session_state["inference_state"] = None
-    # Clear frames and video path
-    session_state["first_frame"] = None
-    session_state["all_frames"] = None
-    session_state["video_path"] = None
-
-    # Update UI elements to their initial state
-    return (
-        None,  # video_in
-        gr.update(open=True),  # video_in_drawer open
-        None,  # points_map cleared
-        None,  # output_image cleared
-        gr.update(value=None, visible=False),  # output_video hidden
-        gr.update(interactive=False),  # Disable buttons
-        gr.update(interactive=False),  # Disable buttons
-        gr.update(interactive=False),  # Disable buttons
-        session_state,  # Updated session state
-    )
-
-
-def clear_points(session_state):
-    """Clears selected points and resets segmentation on the first frame."""
-    print("Clearing points.")
-    # Clear points and labels lists
-    session_state["input_points"] = []
-    session_state["input_labels"] = []
-
-    # Reset the predictor state if it exists. This clears internal masks/features
-    # but keeps the video context initialized by preprocess_video_in.
-    if session_state["inference_state"] is not None:
-        predictor.reset_state(session_state["inference_state"])
-        # After resetting the state, if we still have the video path, re-initialize the state
-        # to be ready for new points on the same video.
-        if session_state["video_path"] is not None:
-            # Re-initialize state *without* the device argument
-            session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"])
-            print("Predictor state re-initialized after clearing points.")
-        else:
-            print("Warning: Could not re-initialize state after clear_points (video_path missing).")
-            session_state["inference_state"] = None  # Ensure state is None if video_path is gone
-
-
-    # Re-render the points_map with no points drawn (just the first frame)
-    # Re-render the output_image with no mask (just the first frame)
-    first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None
-
-    return (
-        first_frame_img,  # points_map shows original first frame
-        None,  # output_image cleared
-        gr.update(value=None, visible=False),  # Hide output video
-        session_state,  # Updated session state
-    )
-
-
-# Removed @spaces.GPU decorator
+# Added @spaces.GPU decorator back as it was in the original code
+@spaces.GPU
 def segment_with_points(
     point_type,
     session_state,
     evt: gr.SelectData,
 ):
     """Adds a point prompt and performs segmentation on the first frame."""
-    # Ensure we have a valid first frame and inference state
+    # Ensure we have state and first frame
     if session_state["first_frame"] is None or session_state["inference_state"] is None:
         print("Error: Cannot segment. No video loaded or inference state missing.")
-        # Return current states to avoid errors, without changing UI much
+        # Return current images and state without changes
         return (
-            session_state["first_frame"],  # points_map remains unchanged
-            None,  # output_image remains unchanged or cleared
+            session_state.get("first_frame"),  # points_map (show first frame if exists)
+            None,  # output_image (keep cleared)
            session_state,
        )
 
-    # evt.index gives the (x, y) coordinates of the click
+    # evt.index is the (x, y) coordinate tuple
     click_coords = evt.index
     print(f"Clicked at: {click_coords} ({point_type})")
 
@@ -314,12 +296,11 @@ def segment_with_points
     for index, track in enumerate(session_state["input_points"]):
         # Ensure coordinates are integers for cv2.circle
         point_coords = (int(track[0]), int(track[1]))
+        # Ensure color is RGBA (0-255)
         if session_state["input_labels"][index] == 1:
-            # Green circle for include
-            cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1)
+            cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1)  # Green for include
         else:
-            # Red circle for exclude
-            cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1)
+            cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1)  # Red for exclude
 
     # Convert the transparent layer back to an image and composite onto the first frame
     transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
@@ -329,13 +310,14 @@
         first_frame_pil.copy(), transparent_layer_points_pil
     )
 
-    # Prepare points and labels as tensors on CPU for the predictor
+    # Prepare points and labels as tensors on the correct device (CUDA in original code)
     points = np.array(session_state["input_points"], dtype=np.float32)
     labels = np.array(session_state["input_labels"], np.int32)
 
-    # Ensure tensors are on CPU
-    points_tensor = torch.tensor(points, dtype=torch.float32, device="cpu").unsqueeze(0)  # Add batch dim
-    labels_tensor = torch.tensor(labels, dtype=torch.int32, device="cpu").unsqueeze(0)  # Add batch dim
+    # Ensure tensors are on the correct device (CUDA as per original code setup)
+    device = next(predictor.parameters()).device  # Get the device the model is on
+    points_tensor = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0)  # Add batch dim
+    labels_tensor = torch.tensor(labels, dtype=torch.int32, device=device).unsqueeze(0)  # Add batch dim
 
     # Add new points to the predictor's state and get the mask for the first frame
     # This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
@@ -351,8 +333,9 @@
         )
 
         # Process logits: detach from graph, move to CPU, apply threshold
-        # out_mask_logits is a list of tensors [tensor([H, W])] for the requested obj_id
-        mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0)  # Apply threshold and get the single mask tensor [H, W]
+        # out_mask_logits is a list of tensors [tensor([batch_size, H, W])] for the requested obj_id
+        # Access the result for the first object (index 0) and the first item in batch (index 0)
+        mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0)  # Move to CPU before converting to numpy
         mask_numpy = mask_tensor.numpy()  # Convert to numpy
 
         # Get the mask image (RGBA)
@@ -366,6 +349,9 @@
         print(f"Error during segmentation on first frame: {e}")
         # On error, first_frame_output_img remains None
 
+    # Original code clears CUDA cache here
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     return selected_point_map_img, first_frame_output_img, session_state
 
@@ -416,21 +402,22 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     return colored_mask_uint8
 
 
-# Removed @spaces.GPU decorator
+# Added @spaces.GPU decorator back as it was in the original code
+@spaces.GPU
 def propagate_to_all(
-    # We don't strictly need video_in path here anymore as it's in session_state,
-    # but keeping it is fine. Accessing session_state["video_path"] is more robust.
-    video_in,
+    video_in,  # Keep video_in path as in original
     session_state,
 ):
     """Runs mask propagation through the video and generates the output video."""
     print("Starting propagation...")
     # Ensure state is ready
+    # Using session_state.get("video_path") is safer than video_in directly
+    current_video_path = session_state.get("video_path")
     if (
         len(session_state["input_points"]) == 0  # Need at least one point
         or session_state["all_frames"] is None
         or session_state["inference_state"] is None
-        or session_state["video_path"] is None  # Ensure we have the original video path
+        or current_video_path is None  # Ensure we have the original video path
     ):
         print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
         return (
@@ -439,7 +426,6 @@ def propagate_to_all(
     )
 
     # run propagation throughout the video and collect the results
-    # The generator yields (frame_idx, obj_ids, mask_logits)
     video_segments = {}
     try:
         # This loop performs the core tracking prediction frame by frame
@@ -451,7 +437,7 @@
             video_segments[out_frame_idx] = {
                 # out_mask_logits is a list of tensors (one per object tracked in this frame)
                 # Each tensor is [batch_size, H, W]. Batch size is 1 here.
-                # Access the first element of the batch [0]
+                # Access the result for the first object (index i) and the first item in batch (index 0)
                 out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
                 for i, out_obj_id in enumerate(out_obj_ids)
             }
@@ -492,9 +478,11 @@
 
         output_frames.append(output_frame_np)
 
+    # Original code clears CUDA cache here
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     # Define output path in a temporary directory
-    # Use os.path.join for cross-platform compatibility
     unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f")  # Use microseconds for more uniqueness
     final_vid_filename = f"output_video_{unique_id}.mp4"
     final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
@@ -502,9 +490,8 @@
 
 
     # Create a video clip from the image sequence
-    # Get original FPS or default
-    # Get FPS from the stored video path in session state
-    original_fps = get_video_fps(session_state["video_path"])
+    # Get original FPS from the stored video path
+    original_fps = get_video_fps(current_video_path)
     fps = original_fps if original_fps is not None and original_fps > 0 else 30  # Default to 30 if detection fails or is zero
     print(f"Creating output video with FPS: {fps}")
 
@@ -526,20 +513,11 @@
             session_state,
         )
 
-
     # Write the result to a file. Use 'libx264' codec for broad compatibility.
-    # `preset` and `threads` for CPU optimization.
-    # `logger=None` prevents moviepy from printing progress to stdout/stderr, which can clutter the Gradio logs.
     try:
-        print(f"Writing video file with codec='libx264', fps={fps}, preset='medium', threads='auto'")
-        clip.write_videofile(
-            final_vid_output_path,
-            codec="libx264",
-            fps=fps,  # Ensure correct FPS is used during writing
-            preset="medium",  # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed vs size
-            threads="auto",  # CPU optimization: Use multiple cores
-            logger=None  # Suppress moviepy output
-        )
+        print(f"Writing video file with codec='libx264', fps={fps}")
+        # Added basic moviepy writing parameters back, similar to original intent
+        clip.write_videofile(final_vid_output_path, codec="libx264", fps=fps)
         print("Video writing complete.")
         # Return the path and make the video player visible
         return (
@@ -563,7 +541,7 @@ def propagate_to_all(
     )
 
 
-def update_output_video_visibility():
+def update_ui():
     """Simply returns a Gradio update to make the output video visible."""
     return gr.update(visible=True)
 
@@ -611,12 +589,11 @@ with gr.Blocks() as demo:
             points_map = gr.Image(
                 label="Click on the First Frame to Add Points",  # Clearer label
                 type="numpy",
-                interactive=True,  # Make interactive to capture clicks
+                interactive=True,  # <--- THIS WAS CHANGED FROM False TO True
                 height=400,  # Set a fixed height for better UI
                 width="auto",  # Let width adjust
                 show_share_button=False,
                 show_download_button=False,
-                # show_label=False  # Can hide label if space is tight
             )
 
         with gr.Column():
@@ -627,7 +604,7 @@ with gr.Blocks() as demo:
                 examples_per_page=8,
                 cache_examples=False,  # Do not cache processed examples, as state is involved
             )
-            # Add padding/space
+            # Add padding/space - removed extra lines as they take up a lot of space
             # gr.Markdown("<br>")
 
             # output_image shows the segmentation mask prediction on the *first* frame
@@ -639,7 +616,6 @@ with gr.Blocks() as demo:
                 width="auto",  # Let width adjust
                 show_share_button=False,
                 show_download_button=False,
-                # show_label=False  # Can hide label
             )
 
             # output_video shows the final tracking result
@@ -649,35 +625,25 @@ with gr.Blocks() as demo:
     # --- Event Handlers ---
 
     # When a new video file is uploaded via the file browser
+    # Added postprocess to update button interactivity based on whether video loaded
     video_in.upload(
         fn=preprocess_video_in,
         inputs=[video_in, session_state],
         outputs=[
-            video_in_drawer,  # Close accordion
-            points_map,  # Show first frame in points_map
-            output_image,  # Clear output image
-            output_video,  # Hide output video
-            propagate_btn,  # Enable Track button
-            clear_points_btn,  # Enable Clear Points button
-            reset_btn,  # Enable Reset button
-            session_state,  # Update session state
+            video_in_drawer, points_map, output_image, output_video,
+            propagate_btn, clear_points_btn, reset_btn, session_state,
         ],
         queue=False,  # Process immediately
     )
 
     # When an example video is selected (change event)
+    # Added postprocess to update button interactivity
     video_in.change(
         fn=preprocess_video_in,
         inputs=[video_in, session_state],
-        outputs=[
-            video_in_drawer,  # Close accordion
-            points_map,  # Show first frame in points_map
-            output_image,  # Clear output image
-            output_video,  # Hide output video
-            propagate_btn,  # Enable Track button
-            clear_points_btn,  # Enable Clear Points button
-            reset_btn,  # Enable Reset button
-            session_state,  # Update session state
+        outputs=[
+            video_in_drawer, points_map, output_image, output_video,
+            propagate_btn, clear_points_btn, reset_btn, session_state,
         ],
         queue=False,  # Process immediately
     )
@@ -716,15 +682,8 @@ with gr.Blocks() as demo:
         fn=reset,
         inputs=[session_state],  # Pass session state
         outputs=[
-            video_in,  # Clear video input
-            video_in_drawer,  # Open video accordion
-            points_map,  # Clear points_map
-            output_image,  # Clear output_image
-            output_video,  # Hide output_video
-            propagate_btn,  # Disable buttons
-            clear_points_btn,  # Disable buttons
-            reset_btn,  # Disable buttons
-            session_state,  # Reset session state
+            video_in, video_in_drawer, points_map, output_image, output_video,
+            propagate_btn, clear_points_btn, reset_btn, session_state,
         ],
         queue=False,  # Process immediately
     )
@@ -743,18 +702,16 @@ with gr.Blocks() as demo:
         ],
         outputs=[
             output_video,  # Update output video player with result
-            session_state,  # Update session state (currently, propagate doesn't modify state much, but good practice)
+            session_state,  # Update session state
         ],
-        # CPU Optimization: Limit concurrency to 1 to prevent resource exhaustion.
-        # Queue=True ensures requests wait if another is processing.
-        concurrency_limit=1,
-        queue=True,
+        # concurrency_limit from original code (may need adjustment based on your hardware/GPU)
+        concurrency_limit=10,
+        queue=False,  # queue from original code
     )
 
 
 # Launch the Gradio demo
-demo.queue()  # Enable queuing for sequential processing under concurrency limits
+demo.queue()  # Enable queuing
 print("Gradio demo starting...")
-# Removed share=True for local debugging unless you specifically need a public link
 demo.launch()
 print("Gradio demo launched.")