Spaces: bla /
Runtime error

bla committed on
Commit 0b34400 · verified · 1 Parent(s): 5bc3a57

Update app.py

Files changed (1)
  1. app.py +261 -506
app.py CHANGED (pre-commit version shown first, removed lines marked with '-'; post-commit version follows, added lines marked with '+')
@@ -7,27 +7,24 @@
7
  import copy
8
  import os
9
  from datetime import datetime
10
-
11
- import gradio as gr
12
-
13
- # Removed GPU-specific environment variable setting
14
- # os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"
15
-
16
  import tempfile
17
 
18
  import cv2
19
  import matplotlib.pyplot as plt
20
  import numpy as np
21
- # Removed spaces decorator import for CPU-only demo
22
- # import spaces # Removed spaces import
23
  import torch
24
 
25
  from moviepy.editor import ImageSequenceClip
26
  from PIL import Image
27
  from sam2.build_sam import build_sam2_video_predictor
28
 
29
  # Description
30
- title = "<center><strong><font size='8'>EdgeTAM<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"
31
 
32
  description_p = """# Instructions
33
  <ol>
@@ -38,535 +35,314 @@ description_p = """# Instructions
38
  </ol>
39
  """
40
 
41
- # examples - Keep examples, they are input files
42
  examples = [
43
  ["examples/01_dog.mp4"],
44
  ["examples/02_cups.mp4"],
45
  ["examples/03_blocks.mp4"],
46
  ["examples/04_coffee.mp4"],
47
  ["examples/05_default_juggle.mp4"],
48
- ["examples/01_breakdancer.mp4"],
49
- ["examples/02_hummingbird.mp4"],
50
- ["examples/03_skateboarder.mp4"],
51
- ["examples/04_octopus.mp4"],
52
- ["examples/05_landing_dog_soccer.mp4"],
53
- ["examples/06_pingpong.mp4"],
54
- ["examples/07_snowboarder.mp4"],
55
- ["examples/08_driving.mp4"],
56
- ["examples/09_birdcartoon.mp4"],
57
- ["examples/10_cloth_magic.mp4"],
58
- ["examples/11_polevault.mp4"],
59
- ["examples/12_hideandseek.mp4"],
60
- ["examples/13_butterfly.mp4"],
61
- ["examples/14_social_dog_training.mp4"],
62
- ["examples/15_cricket.mp4"],
63
- ["examples/16_robotarm.mp4"],
64
- ["examples/17_childrendancing.mp4"],
65
- ["examples/18_threedogs.mp4"],
66
- ["examples/19_cyclist.mp4"],
67
- ["examples/20_doughkneading.mp4"],
68
- ["examples/21_biker.mp4"],
69
- ["examples/22_dogskateboarder.mp4"],
70
- ["examples/23_racecar.mp4"],
71
- ["examples/24_clownfish.mp4"],
72
  ]
73
 
74
  OBJ_ID = 0
75
 
 
76
  sam2_checkpoint = "checkpoints/edgetam.pt"
77
  model_cfg = "edgetam.yaml"
78
- # Ensure predictor is explicitly built for CPU
79
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
80
- # Removed .to("cuda") - predictor is already on CPU from build_sam2_video_predictor
81
- # predictor.to("cuda")
82
  print("predictor loaded on CPU")
83
 
84
- # Removed CUDA specific autocast and backend settings
85
- # torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
86
- # if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
87
- # torch.backends.cuda.matmul.allow_tf32 = True
88
- # torch.backends.cudnn.allow_tf32 = True
89
- # elif not torch.cuda.is_available():
90
- # print("Warning: CUDA not available. Running on CPU.")
91
-
92
-
93
  def get_video_fps(video_path):
94
- """Gets the frames per second of a video file."""
95
- if video_path is None or not os.path.exists(video_path):
96
- print(f"Warning: Video file not found at {video_path}")
97
- return None
98
  cap = cv2.VideoCapture(video_path)
99
  if not cap.isOpened():
100
- print(f"Error: Could not open video file {video_path}.")
101
- return None
102
  fps = cap.get(cv2.CAP_PROP_FPS)
103
- cap.release() # Release the capture object
104
  return fps
105
 
106
 
107
  def reset(session_state):
108
- """Resets the UI and session state."""
109
- print("Resetting demo.")
110
  session_state["input_points"] = []
111
  session_state["input_labels"] = []
112
- # Reset the predictor state if it exists
113
  if session_state["inference_state"] is not None:
114
- try:
115
- # Assuming predictor.reset_state handles clearing current masks/features
116
- predictor.reset_state(session_state["inference_state"])
117
- # Explicitly delete or re-init the state object if a full reset is intended
118
- # This depends on how predictor.reset_state works. Setting to None is safest for a full reset.
119
- session_state["inference_state"] = None
120
- except Exception as e:
121
- print(f"Error resetting predictor state: {e}")
122
- session_state["inference_state"] = None # Force-clear on error
123
-
124
  session_state["first_frame"] = None
125
  session_state["all_frames"] = None
126
- session_state["inference_state"] = None # Ensure state is None after a full reset
127
- # Also reset video path if stored
128
- session_state["video_path"] = None
129
-
130
- # Resetting UI components and disabling buttons
131
  return (
132
- None, # video_in (clears the video player)
133
- gr.update(open=True), # video_in_drawer (opens accordion)
134
- None, # points_map (clears the image)
135
- None, # output_image (clears the image)
136
- gr.update(value=None, visible=False), # output_video (hides and clears)
137
- gr.update(interactive=False), # propagate_btn disabled
138
- gr.update(interactive=False), # clear_points_btn disabled
139
- gr.update(interactive=False), # reset_btn disabled
140
- session_state, # return updated session state
141
  )
142
 
143
 
144
  def clear_points(session_state):
145
- """Clears selected points and resets segmentation on the first frame."""
146
- print("Clearing points.")
147
  session_state["input_points"] = []
148
  session_state["input_labels"] = []
149
-
150
- # Reset the predictor state to clear internal masks/features
151
- # This typically doesn't remove the video context, just the mask predictions
152
- if session_state["inference_state"] is not None:
153
- try:
154
- # Assuming reset_state handles clearing current masks/features
155
- predictor.reset_state(session_state["inference_state"])
156
- print("Predictor state reset for clearing points.")
157
- # If you need to re-initialize the state for the *same* video after clearing points,
158
- # you might need to call predictor.init_state again here, using the stored video_path.
159
- # Since we are on CPU, device="cpu" is implicit now.
160
- if session_state["video_path"] is not None:
161
- session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"])
162
- print("Predictor state re-initialized after clearing points.")
163
- else:
164
- print("Warning: Could not re-initialize state after clear_points (video_path missing).")
165
- session_state["inference_state"] = None # Ensure state is None if video_path is gone
166
-
167
-
168
- except Exception as e:
169
- print(f"Error resetting predictor state during clear_points: {e}")
170
- # If reset fails, this might leave old masks. Force-clear state on error.
171
- session_state["inference_state"] = None
172
-
173
-
174
- # Return the original first frame image for points_map and clear the output_image
175
- first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None
176
-
177
  return (
178
- first_frame_img, # points_map shows original first frame (no points yet)
179
- None, # output_image cleared (no mask)
180
- gr.update(value=None, visible=False), # output_video hidden
181
- session_state, # return updated session state
182
  )
183
 
184
 
185
- # Removed @spaces.GPU decorator
186
  def preprocess_video_in(video_path, session_state):
187
- """Loads video frames and initializes the predictor state."""
188
- print(f"Processing video: {video_path}")
189
- if video_path is None or not os.path.exists(video_path):
190
- print("No video path provided or file not found.")
191
- # Reset state and UI elements if input is invalid
192
- # Need to return updates for the buttons as well
193
  return (
194
- gr.update(open=True), None, None, gr.update(value=None, visible=False),
195
- gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
196
- { # Reset session state
197
- "first_frame": None, "all_frames": None, "input_points": [],
198
- "input_labels": [], "inference_state": None, "video_path": None,
199
- }
200
  )
201
 
 
202
  cap = cv2.VideoCapture(video_path)
203
  if not cap.isOpened():
204
- print(f"Error: Could not open video file {video_path}.")
205
  return (
206
- gr.update(open=True), None, None, gr.update(value=None, visible=False),
207
- gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
208
- { # Reset session state
209
- "first_frame": None, "all_frames": None, "input_points": [],
210
- "input_labels": [], "inference_state": None, "video_path": None,
211
- }
212
  )
213
 
 
214
  first_frame = None
215
  all_frames = []
216
-
217
  while True:
218
  ret, frame = cap.read()
219
  if not ret:
220
  break
221
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
222
- all_frames.append(frame)
223
- if first_frame is None:
224
- first_frame = frame
225
 
226
  cap.release()
227
-
228
- if not all_frames:
229
- print(f"Error: No frames read from video file {video_path}.")
230
- return (
231
- gr.update(open=True), None, None, gr.update(value=None, visible=False),
232
- gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
233
- { # Reset session state
234
- "first_frame": None, "all_frames": None, "input_points": [],
235
- "input_labels": [], "inference_state": None, "video_path": None,
236
- }
237
- )
238
-
239
- # Update session state with frames and path
240
- session_state["first_frame"] = copy.deepcopy(first_frame) # Store a copy
241
  session_state["all_frames"] = all_frames
242
- session_state["video_path"] = video_path # Store video path
243
  session_state["input_points"] = []
244
  session_state["input_labels"] = []
245
- # Initialize state WITHOUT the device argument (uses predictor's device, which is CPU)
246
- session_state["inference_state"] = predictor.init_state(video_path=video_path)
247
- print("Video loaded and predictor state initialized on CPU.")
248
 
249
- # Enable buttons after successful load
250
  return [
251
  gr.update(open=False), # video_in_drawer
252
- first_frame, # points_map (shows first frame)
253
- None, # output_image (cleared initially)
254
- gr.update(value=None, visible=False), # output_video (hidden initially)
255
- gr.update(interactive=True), # propagate_btn enabled
256
- gr.update(interactive=True), # clear_points_btn enabled
257
- gr.update(interactive=True), # reset_btn enabled
258
- session_state, # session_state
259
  ]
260
 
261
 
262
- # Removed @spaces.GPU decorator
263
  def segment_with_points(
264
  point_type,
265
  session_state,
266
  evt: gr.SelectData,
267
  ):
268
- """Adds a point prompt and performs segmentation on the first frame."""
269
- # Ensure we have state and first frame
270
- if session_state["first_frame"] is None or session_state["inference_state"] is None:
271
- print("Error: Cannot segment. No video loaded or inference state missing.")
272
- # Return current states to avoid errors, without changing UI much
273
- return (
274
- session_state.get("first_frame"), # points_map (show first frame if exists)
275
- None, # output_image (keep cleared)
276
- session_state,
277
- )
278
-
279
- # evt.index is the (x, y) coordinate tuple
280
- click_coords = evt.index
281
- print(f"Clicked at: {click_coords} ({point_type})")
282
-
283
- session_state["input_points"].append(click_coords)
284
 
285
  if point_type == "include":
286
  session_state["input_labels"].append(1)
287
  elif point_type == "exclude":
288
  session_state["input_labels"].append(0)
 
289
 
290
- # Get the first frame as a PIL image for drawing
291
- first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA")
292
- w, h = first_frame_pil.size
293
 
294
- # Define the circle radius
295
- fraction = 0.01
296
- radius = max(2, int(fraction * min(w, h))) # Ensure minimum radius of 2
297
 
298
- # Create a transparent layer to draw points
299
- transparent_layer_points = np.zeros((h, w, 4), dtype=np.uint8)
300
 
301
- # Draw points on the transparent layer
302
  for index, track in enumerate(session_state["input_points"]):
303
- # Ensure coordinates are integers for cv2.circle
304
- point_coords = (int(track[0]), int(track[1]))
305
- # Ensure color is RGBA (0-255)
306
  if session_state["input_labels"][index] == 1:
307
- cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1) # Green for include
308
  else:
309
- cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1) # Red for exclude
310
-
311
- # Convert the transparent layer back to an image and composite onto the first frame
312
- transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
313
- # Combine the first frame image with the points layer for the points_map output
314
- # points_map shows the first frame *with the points you added*.
315
- selected_point_map_img = Image.alpha_composite(
316
- first_frame_pil.copy(), transparent_layer_points_pil
317
  )
318
 
319
- # Prepare points and labels as tensors on the correct device (CPU in this version)
320
  points = np.array(session_state["input_points"], dtype=np.float32)
 
321
  labels = np.array(session_state["input_labels"], np.int32)
322
 
323
- # Ensure tensors are on the correct device (CPU)
324
- device = next(predictor.parameters()).device # Get the device the model is on (should be "cpu")
325
- points_tensor = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0) # Add batch dim
326
- labels_tensor = torch.tensor(labels, dtype=torch.int32, device=device).unsqueeze(0) # Add batch dim
327
-
328
-
329
- first_frame_output_img = None # Initialize output mask image as None in case of error
330
- try:
331
- # Note: predictor.add_new_points modifies the internal inference_state
332
- _, _, out_mask_logits = predictor.add_new_points(
333
- inference_state=session_state["inference_state"],
334
- frame_idx=0, # Always segment on the first frame initially
335
- obj_id=OBJ_ID,
336
- points=points_tensor,
337
- labels=labels_tensor,
338
- )
339
-
340
- # Process logits: detach from graph, move to CPU, apply threshold
341
- # out_mask_logits is a list of tensors [tensor([batch_size, H, W])] for the requested obj_id
342
- # Access the result for the first object (index 0) and the first item in batch (index 0)
343
- mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Move to CPU before converting to numpy
344
- mask_numpy = mask_tensor.numpy() # Convert to numpy
345
-
346
- # Get the mask image (RGBA)
347
- mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) # show_mask returns RGBA PIL Image
348
-
349
- # Composite the mask onto the first frame for the output_image
350
- # output_image shows the first frame *with the segmentation mask result*.
351
- first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)
352
-
353
- except Exception as e:
354
- print(f"Error during segmentation on first frame: {e}")
355
- # On error, first_frame_output_img remains None
356
-
357
- # Removed CUDA cache clearing call
358
- # if torch.cuda.is_available():
359
- # torch.cuda.empty_cache()
360
 
361
- return selected_point_map_img, first_frame_output_img, session_state
362
 
363
 
364
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
365
- """Helper function to visualize a mask."""
366
- # Ensure mask is a numpy array (and boolean)
367
- if isinstance(mask, torch.Tensor):
368
- mask = mask.detach().cpu().numpy() # Ensure it's on CPU and converted to numpy
369
- # Convert potential float/int mask to boolean mask
370
- mask = mask.astype(bool)
371
-
372
  if random_color:
373
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) # RGBA with 0.6 alpha
374
  else:
375
  cmap = plt.get_cmap("tab10")
376
- cmap_idx = 0 if obj_id is None else obj_id % 10 # Use modulo 10 for tab10 colors
377
- color = np.array([*cmap(cmap_idx)[:3], 0.6]) # RGBA with 0.6 alpha
378
-
379
- # Ensure mask has H, W dimensions
380
- if mask.ndim == 3:
381
- mask = mask.squeeze() # Remove singular dimensions like (H, W, 1)
382
- if mask.ndim != 2:
383
- print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
384
- # Create an empty transparent image if mask shape is unexpected
385
- h, w = mask.shape[:2] if mask.ndim >= 2 else (100, 100) # Use actual shape if possible, otherwise default
386
- if convert_to_image:
387
- return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")
388
- else:
389
- return np.zeros((h, w, 4), dtype=np.uint8)
390
-
391
- h, w = mask.shape
392
- # Create an RGBA image from the mask and color
393
- # Apply color where mask is True
394
- # Need to reshape color to be broadcastable [1, 1, 4]
395
- colored_mask = np.zeros((h, w, 4), dtype=np.float32) # Start with fully transparent black
396
- # Apply the color only where the mask is True.
397
- # This directly creates the colored overlay with transparency.
398
- colored_mask[mask] = color
399
-
400
- # Convert to uint8 [0-255]
401
- colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)
402
-
403
  if convert_to_image:
404
- mask_img = Image.fromarray(colored_mask_uint8, "RGBA")
405
- return mask_img
406
- else:
407
- return colored_mask_uint8
408
 
409
 
410
- # Removed @spaces.GPU decorator
411
  def propagate_to_all(
412
- video_in, # Keep video_in path as in original
413
  session_state,
414
  ):
415
- """Runs mask propagation through the video and generates the output video."""
416
- print("Starting propagation...")
417
- # Ensure state is ready
418
- # Using session_state.get("video_path") is safer than video_in directly
419
- current_video_path = session_state.get("video_path")
420
  if (
421
- len(session_state["input_points"]) == 0 # Need at least one point
422
- or session_state["all_frames"] is None
423
  or session_state["inference_state"] is None
424
- or current_video_path is None # Ensure we have the original video path
425
  ):
426
- print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
427
- return (
428
- gr.update(value=None, visible=False), # Hide output video on error
429
- session_state,
430
- )
431
-
432
- # run propagation throughout the video and collect the results
433
- video_segments = {}
434
- try:
435
- # This loop performs the core tracking prediction frame by frame
436
- for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
437
- session_state["inference_state"]
438
- ):
439
- # Process logits: detach from graph, move to CPU, convert to numpy boolean mask
440
- # Ensure tensor is on CPU before converting to numpy
441
- video_segments[out_frame_idx] = {
442
- # out_mask_logits is a list of tensors (one per object tracked in this frame)
443
- # Each tensor is [batch_size, H, W]. Batch size is 1 here.
444
- # Access the result for the first object (index i) and the first item in batch (index 0)
445
- out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
446
- for i, out_obj_id in enumerate(out_obj_ids)
447
- }
448
- # Optional: print progress
449
- # print(f"Processed frame {out_frame_idx+1}/{len(session_state['all_frames'])}")
450
-
451
- print("Propagation finished.")
452
- except Exception as e:
453
- print(f"Error during propagation: {e}")
454
  return (
455
- gr.update(value=None, visible=False), # Hide output video on error
456
  session_state,
457
  )
458
 
459
-
460
  output_frames = []
461
- # Iterate through all original frames to generate output video
462
- total_frames = len(session_state["all_frames"])
463
- for out_frame_idx in range(total_frames):
464
- original_frame_rgb = session_state["all_frames"][out_frame_idx]
465
- # Convert original frame to RGBA for compositing
466
- transparent_background = Image.fromarray(original_frame_rgb).convert("RGBA")
467
-
468
- # Check if we have a mask for this frame and object ID
469
- if out_frame_idx in video_segments and OBJ_ID in video_segments[out_frame_idx]:
470
- current_mask_numpy = video_segments[out_frame_idx][OBJ_ID]
471
- # Get the mask image (RGBA)
472
- mask_image_pil = show_mask(current_mask_numpy, obj_id=OBJ_ID)
473
- # Composite the mask onto the frame
474
- output_frame_img_rgba = Image.alpha_composite(transparent_background, mask_image_pil)
475
- # Convert back to numpy RGB (moviepy needs RGB or RGBA)
476
- output_frame_np = np.array(output_frame_img_rgba.convert("RGB"))
477
- else:
478
- # If no mask for this frame/object, just use the original frame (converted to RGB)
479
- # Note: all_frames are already RGB numpy arrays, so just use them directly.
480
- # print(f"Warning: No mask found for frame {out_frame_idx} and object {OBJ_ID}. Using original frame.")
481
- output_frame_np = original_frame_rgb # Already RGB numpy array
482
-
483
- output_frames.append(output_frame_np)
484
-
485
- # Removed CUDA cache clearing call
486
- # if torch.cuda.is_available():
487
- # torch.cuda.empty_cache()
488
-
489
- # Define output path in a temporary directory
490
- unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f") # Use microseconds for more uniqueness
491
- final_vid_filename = f"output_video_{unique_id}.mp4"
492
- final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
493
- print(f"Output video path: {final_vid_output_path}")
494
-
495
 
496
  # Create a video clip from the image sequence
497
- # Get original FPS from the stored video path
498
- original_fps = get_video_fps(current_video_path)
499
- fps = original_fps if original_fps is not None and original_fps > 0 else 30 # Default to 30 if detection fails or is zero
500
- print(f"Creating output video with FPS: {fps}")
501
-
502
- # Check if there are frames to process
503
- if not output_frames:
504
- print("No output frames generated.")
505
- return (
506
- gr.update(value=None, visible=False), # Hide output video
507
- session_state,
508
- )
509
 
510
- # Create ImageSequenceClip from the list of numpy arrays
511
- try:
512
- clip = ImageSequenceClip(output_frames, fps=fps)
513
- except Exception as e:
514
- print(f"Error creating ImageSequenceClip: {e}")
515
- return (
516
- gr.update(value=None, visible=False), # Hide output video on error
517
- session_state,
518
- )
519
-
520
- # Write the result to a file. Use 'libx264' codec for broad compatibility.
521
- # Added CPU optimization parameters for moviepy write
522
- try:
523
- print(f"Writing video file with codec='libx264', fps={fps}, preset='medium', threads='auto'")
524
- clip.write_videofile(
525
- final_vid_output_path,
526
- codec="libx264",
527
- fps=fps, # Ensure correct FPS is used during writing
528
- preset="medium", # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed vs size
529
- threads="auto", # CPU optimization: Use multiple cores
530
- logger=None # Suppress moviepy output
531
- )
532
- print("Video writing complete.")
533
- # Return the path and make the video player visible
534
- return (
535
- gr.update(value=final_vid_output_path, visible=True),
536
- session_state,
537
- )
538
- except Exception as e:
539
- print(f"Error writing video file: {e}")
540
- # Clean up potentially created partial file
541
- if os.path.exists(final_vid_output_path):
542
- try:
543
- os.remove(final_vid_output_path)
544
- print(f"Removed partial video file: {final_vid_output_path}")
545
- except Exception as clean_e:
546
- print(f"Error removing partial file: {clean_e}")
547
-
548
- # Return None if writing fails
549
- return (
550
- gr.update(value=None, visible=False),
551
- session_state,
552
- )
553
 
554
 
555
- def update_output_video_visibility():
556
- """Simply returns a Gradio update to make the output video visible."""
557
  return gr.update(visible=True)
558
 
559
 
560
  with gr.Blocks() as demo:
561
- # Session state dictionary to hold video frames, points, labels, and predictor state
562
  session_state = gr.State(
563
  {
564
- "first_frame": None, # numpy array (RGB)
565
- "all_frames": None, # list of numpy arrays (RGB)
566
- "input_points": [], # list of (x, y) tuples/lists
567
- "input_labels": [], # list of 1s and 0s
568
- "inference_state": None, # EdgeTAM predictor state object
569
- "video_path": None, # Store the input video path
570
  }
571
  )
572
 
@@ -580,7 +356,7 @@ with gr.Blocks() as demo:
580
  gr.Markdown(description_p)
581
 
582
  with gr.Accordion("Input Video", open=True) as video_in_drawer:
583
- video_in = gr.Video(label="Input Video", format="mp4") # Will hold the video file path
584
 
585
  with gr.Row():
586
  point_type = gr.Radio(
@@ -588,142 +364,121 @@ with gr.Blocks() as demo:
588
  choices=["include", "exclude"],
589
  value="include",
590
  scale=2,
591
- interactive=True, # Make interactive
592
  )
593
- # Buttons are initially disabled until a video is loaded
594
- propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False)
595
- clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False)
596
- reset_btn = gr.Button("Reset", scale=1, interactive=False)
597
 
598
- # points_map is where users click to add points. Needs to be interactive.
599
- # Shows the first frame with points drawn on it.
600
  points_map = gr.Image(
601
- label="Click on the First Frame to Add Points", # Clearer label
602
- type="numpy",
603
- interactive=True, # <--- CHANGED TO True to enable clicking
604
- height=400, # Set a fixed height for better UI
605
- width="auto", # Let width adjust
606
- show_share_button=False,
607
- show_download_button=False,
608
  )
609
 
610
  with gr.Column():
611
  gr.Markdown("# Try some of the examples below ⬇️")
612
  gr.Examples(
613
  examples=examples,
614
- inputs=[video_in],
615
- examples_per_page=8,
616
- cache_examples=False, # Do not cache processed examples, as state is involved
617
- )
618
- # Removed extra blank lines
619
-
620
- # output_image shows the segmentation mask prediction on the *first* frame
621
- output_image = gr.Image(
622
- label="Segmentation Mask on First Frame", # Clearer label
623
- type="numpy",
624
- interactive=False, # Not interactive, just displays the mask
625
- height=400, # Match height of points_map
626
- width="auto", # Let width adjust
627
- show_share_button=False,
628
- show_download_button=False,
629
  )
630
 
631
- # output_video shows the final tracking result
632
- output_video = gr.Video(visible=False, label="Tracking Result")
633
-
634
-
635
- # --- Event Handlers ---
636
-
637
- # When a new video file is uploaded via the file browser
638
- # Added postprocess to update button interactivity based on whether video loaded
639
  video_in.upload(
640
  fn=preprocess_video_in,
641
- inputs=[video_in, session_state],
642
  outputs=[
643
- video_in_drawer, points_map, output_image, output_video,
644
- propagate_btn, clear_points_btn, reset_btn, session_state,
645
  ],
646
- queue=False, # Process immediately
647
  )
648
 
649
- # When an example video is selected (change event)
650
- # Added postprocess to update button interactivity
651
  video_in.change(
652
  fn=preprocess_video_in,
653
- inputs=[video_in, session_state],
654
  outputs=[
655
- video_in_drawer, points_map, output_image, output_video,
656
- propagate_btn, clear_points_btn, reset_btn, session_state,
657
  ],
658
- queue=False, # Process immediately
659
  )
660
 
661
-
662
- # Triggered when a user clicks on the points_map image
663
  points_map.select(
664
  fn=segment_with_points,
665
  inputs=[
666
- point_type, # "include" or "exclude" radio button value
667
- session_state, # Pass session state
668
  ],
669
  outputs=[
670
- points_map, # Updated image with points drawn
671
- output_image, # Updated image with first frame segmentation mask
672
- session_state, # Updated session state (points/labels added)
673
  ],
674
- queue=False, # Process clicks immediately
675
  )
676
 
677
- # Button to clear all selected points and reset the first frame mask
678
  clear_points_btn.click(
679
  fn=clear_points,
680
- inputs=[session_state], # Pass session state
681
  outputs=[
682
- points_map, # points_map shows original first frame without points
683
- output_image, # output_image cleared (or shows original first frame without mask)
684
- output_video, # Hide output video
685
- session_state, # Updated session state (points/labels cleared, inference state reset)
686
  ],
687
- queue=False, # Process immediately
688
  )
689
 
690
- # Button to reset the entire demo state and UI
691
  reset_btn.click(
692
  fn=reset,
693
- inputs=[session_state], # Pass session state
694
  outputs=[
695
- video_in, video_in_drawer, points_map, output_image, output_video,
696
- propagate_btn, clear_points_btn, reset_btn, session_state,
697
  ],
698
- queue=False, # Process immediately
699
  )
700
 
701
- # Button to start mask propagation through the video
702
  propagate_btn.click(
703
- fn=update_output_video_visibility, # First, make the output video player visible
704
  inputs=[],
705
- outputs=[output_video],
706
- queue=False, # Process this UI update immediately
707
- ).then( # Then, run the propagation function
708
  fn=propagate_to_all,
709
  inputs=[
710
- video_in, # Get the input video path (can also get from session_state["video_path"])
711
- session_state, # Pass session state (contains frames, points, inference_state, video_path)
712
  ],
713
  outputs=[
714
- output_video, # Update output video player with result
715
- session_state, # Update session state
716
  ],
717
- # CPU Optimization: Limit concurrency to 1 to prevent resource exhaustion.
718
- # Queue=True ensures requests wait if another is processing.
719
- concurrency_limit=1,
720
- queue=True,
721
  )
722
 
723
 
724
- # Launch the Gradio demo
725
- demo.queue() # Enable queuing for sequential processing under concurrency limits
726
- print("Gradio demo starting...")
727
- # Removed share=True for local debugging unless you specifically need a public link
728
- demo.launch()
729
- print("Gradio demo launched.")

app.py (post-commit version; added lines marked with '+')

7
  import copy
8
  import os
9
  from datetime import datetime
10
  import tempfile
11
 
12
  import cv2
13
  import matplotlib.pyplot as plt
14
  import numpy as np
15
+ import gradio as gr
 
16
  import torch
17
 
18
  from moviepy.editor import ImageSequenceClip
19
  from PIL import Image
20
  from sam2.build_sam import build_sam2_video_predictor
21
 
22
+ # Remove CUDA environment variables
23
+ if 'TORCH_CUDNN_SDPA_ENABLED' in os.environ:
24
+ del os.environ["TORCH_CUDNN_SDPA_ENABLED"]
25
+
26
  # Description
27
+ title = "<center><strong><font size='8'>EdgeTAM CPU<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"
28
 
29
  description_p = """# Instructions
30
  <ol>
 
35
  </ol>
36
  """
37
 
38
+ # examples - keeping fewer examples to reduce memory footprint
39
  examples = [
40
  ["examples/01_dog.mp4"],
41
  ["examples/02_cups.mp4"],
42
  ["examples/03_blocks.mp4"],
43
  ["examples/04_coffee.mp4"],
44
  ["examples/05_default_juggle.mp4"],
45
  ]
46
 
47
  OBJ_ID = 0
48
 
49
+ # Initialize model on CPU
50
  sam2_checkpoint = "checkpoints/edgetam.pt"
51
  model_cfg = "edgetam.yaml"
 
52
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
53
  print("predictor loaded on CPU")
54
 
55
+ # Function to get video frame rate
56
  def get_video_fps(video_path):
57
  cap = cv2.VideoCapture(video_path)
58
  if not cap.isOpened():
59
+ print("Error: Could not open video.")
60
+ return 30.0 # Default fallback value
61
  fps = cap.get(cv2.CAP_PROP_FPS)
62
+ cap.release()
63
  return fps
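One hedged refinement: cv2 can report an fps of 0 (or NaN) for some containers even when the file opens, and the fallback above only covers the failed-open case. A small guard, assuming 30 fps stays the default (get_video_fps_safe is an illustrative name, not part of this commit):

def get_video_fps_safe(video_path):
    # Same as get_video_fps, but also falls back when cv2 reports 0/NaN fps.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return 30.0
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    if not fps or fps != fps:  # 0, None, or NaN
        return 30.0
    return fps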
64
 
65
 
66
  def reset(session_state):
67
  session_state["input_points"] = []
68
  session_state["input_labels"] = []
 
69
  if session_state["inference_state"] is not None:
70
+ predictor.reset_state(session_state["inference_state"])
71
  session_state["first_frame"] = None
72
  session_state["all_frames"] = None
73
+ session_state["inference_state"] = None
74
  return (
75
+ None,
76
+ gr.update(open=True),
77
+ None,
78
+ None,
79
+ gr.update(value=None, visible=False),
80
+ session_state,
81
  )
82
 
83
 
84
  def clear_points(session_state):
85
  session_state["input_points"] = []
86
  session_state["input_labels"] = []
87
+ if session_state["inference_state"] is not None and session_state["inference_state"].get("tracking_has_started", False):
88
+ predictor.reset_state(session_state["inference_state"])
89
  return (
90
+ session_state["first_frame"],
91
+ None,
92
+ gr.update(value=None, visible=False),
93
+ session_state,
94
  )
95
 
96
 
 
97
  def preprocess_video_in(video_path, session_state):
98
+ if video_path is None:
99
  return (
100
+ gr.update(open=True), # video_in_drawer
101
+ None, # points_map
102
+ None, # output_image
103
+ gr.update(value=None, visible=False), # output_video
104
+ session_state,
 
105
  )
106
 
107
+ # Open the video and read its frames
108
  cap = cv2.VideoCapture(video_path)
109
  if not cap.isOpened():
110
+ print("Error: Could not open video.")
111
  return (
112
+ gr.update(open=True), # video_in_drawer
113
+ None, # points_map
114
+ None, # output_image
115
+ gr.update(value=None, visible=False), # output_video
116
+ session_state,
 
117
  )
118
 
119
+ # For CPU optimization - determine video properties
120
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
121
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
122
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
123
+
124
+ # Determine if we need to resize for CPU performance
125
+ target_width = 640 # Target width for processing on CPU
126
+ scale_factor = 1.0
127
+
128
+ if frame_width > target_width:
129
+ scale_factor = target_width / frame_width
130
+ frame_width = target_width
131
+ frame_height = int(frame_height * scale_factor)
132
+
133
+ # Read frames - for CPU we'll be more selective about which frames to keep
134
+ frame_number = 0
135
  first_frame = None
136
  all_frames = []
137
+
138
+ # For CPU optimization, skip frames if video is too long
139
+ frame_stride = 1
140
+ if total_frames > 300: # If more than 300 frames
141
+ frame_stride = max(1, int(total_frames / 300)) # Process at most ~300 frames
142
+
143
  while True:
144
  ret, frame = cap.read()
145
  if not ret:
146
  break
147
+
148
+ if frame_number % frame_stride == 0: # Process every frame_stride frames
149
+ # Resize the frame if needed
150
+ if scale_factor != 1.0:
151
+ frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
152
+
153
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
154
+ frame = np.array(frame)
155
+
156
+ # Store the first frame
157
+ if first_frame is None:
158
+ first_frame = frame
159
+ all_frames.append(frame)
160
+
161
+ frame_number += 1
162
 
163
  cap.release()
164
+ session_state["first_frame"] = copy.deepcopy(first_frame)
165
  session_state["all_frames"] = all_frames
166
+ session_state["frame_stride"] = frame_stride
167
+ session_state["scale_factor"] = scale_factor
168
+ session_state["original_dimensions"] = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
169
+ int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
170
+
171
+ session_state["inference_state"] = predictor.init_state(video_path=video_path)
172
  session_state["input_points"] = []
173
  session_state["input_labels"] = []
174
 
 
175
  return [
176
  gr.update(open=False), # video_in_drawer
177
+ first_frame, # points_map
178
+ None, # output_image
179
+ gr.update(value=None, visible=False), # output_video
180
+ session_state,
181
  ]
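One detail to watch in the block above: session_state["original_dimensions"] is read from cap after cap.release(), and a released VideoCapture usually reports 0 for its properties. A minimal sketch of the intended bookkeeping, capturing the values before resizing or release (variable names are illustrative, not from this commit):

cap = cv2.VideoCapture(video_path)
orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # capture before release
orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# ... read / resize frames exactly as above ...
cap.release()
session_state["original_dimensions"] = (orig_width, orig_height)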
182
 
183
 
 
184
  def segment_with_points(
185
  point_type,
186
  session_state,
187
  evt: gr.SelectData,
188
  ):
189
+ session_state["input_points"].append(evt.index)
190
+ print(f"TRACKING INPUT POINT: {session_state['input_points']}")
191
 
192
  if point_type == "include":
193
  session_state["input_labels"].append(1)
194
  elif point_type == "exclude":
195
  session_state["input_labels"].append(0)
196
+ print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
197
 
198
+ # Open the image and get its dimensions
199
+ transparent_background = Image.fromarray(session_state["first_frame"]).convert(
200
+ "RGBA"
201
+ )
202
+ w, h = transparent_background.size
203
 
204
+ # Define the circle radius as a fraction of the smaller dimension
205
+ fraction = 0.01 # You can adjust this value as needed
206
+ radius = int(fraction * min(w, h))
207
 
208
+ # Create a transparent layer to draw on
209
+ transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
210
 
 
211
  for index, track in enumerate(session_state["input_points"]):
212
  if session_state["input_labels"][index] == 1:
213
+ cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
214
  else:
215
+ cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
216
+
217
+ # Convert the transparent layer back to an image
218
+ transparent_layer = Image.fromarray(transparent_layer, "RGBA")
219
+ selected_point_map = Image.alpha_composite(
220
+ transparent_background, transparent_layer
221
  )
222
 
223
+ # Gather the clicked points for the predictor
224
  points = np.array(session_state["input_points"], dtype=np.float32)
225
+ # for labels, `1` means positive click and `0` means negative click
226
  labels = np.array(session_state["input_labels"], np.int32)
227
+
228
+ # For CPU optimization, we'll process with smaller batch size
229
+ _, _, out_mask_logits = predictor.add_new_points(
230
+ inference_state=session_state["inference_state"],
231
+ frame_idx=0,
232
+ obj_id=OBJ_ID,
233
+ points=points,
234
+ labels=labels,
235
+ )
236
 
237
+ mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
238
+ first_frame_output = Image.alpha_composite(transparent_background, mask_image)
239
 
240
+ return selected_point_map, first_frame_output, session_state
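The clicks above are collected on the resized preview frame, while predictor.init_state() was given the original video path. If the predictor works at the full original resolution (an assumption; it depends on how init_state loads frames), the point coordinates would need mapping back before add_new_points. A hedged sketch:

# Map clicks from the resized preview back to the original resolution.
# Assumes session_state["scale_factor"] is the preview/original width ratio
# stored in preprocess_video_in.
scale = session_state.get("scale_factor", 1.0)
points = np.array(session_state["input_points"], dtype=np.float32)
labels = np.array(session_state["input_labels"], np.int32)
if scale != 1.0:
    points = points / scale  # undo the preview downscaling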
241
 
242
 
243
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
244
  if random_color:
245
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
246
  else:
247
  cmap = plt.get_cmap("tab10")
248
+ cmap_idx = 0 if obj_id is None else obj_id
249
+ color = np.array([*cmap(cmap_idx)[:3], 0.6])
250
+ h, w = mask.shape[-2:]
251
+ mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
252
+ mask = (mask * 255).astype(np.uint8)
253
  if convert_to_image:
254
+ mask = Image.fromarray(mask, "RGBA")
255
+ return mask
256
 
257
 
 
258
  def propagate_to_all(
259
+ video_in,
260
  session_state,
261
  ):
262
  if (
263
+ len(session_state["input_points"]) == 0
264
+ or video_in is None
265
  or session_state["inference_state"] is None
 
266
  ):
 
267
  return (
268
+ None,
269
  session_state,
270
  )
271
 
272
+ # For CPU optimization: periodically reclaim memory during propagation
273
+ chunk_size = 5  # Reclaim memory every 5 processed frames (used in the loop below)
274
+
275
+ # run propagation throughout the video and collect the results in a dict
276
+ video_segments = {} # video_segments contains the per-frame segmentation results
277
+ print("starting propagate_in_video on CPU")
278
+
279
+ # Get the frames in chunks for CPU memory optimization
280
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
281
+ session_state["inference_state"]
282
+ ):
283
+ video_segments[out_frame_idx] = {
284
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
285
+ for i, out_obj_id in enumerate(out_obj_ids)
286
+ }
287
+
288
+ # Free up memory after processing each frame
289
+ if len(video_segments) % chunk_size == 0:
290
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None  # no-op on a CPU-only Space
291
+
292
+ # obtain the segmentation results every few frames
293
+ # For CPU optimization: increase stride to reduce processing
294
+ vis_frame_stride = max(1, len(video_segments) // 100) # Limit to ~100 frames in output
295
+
296
  output_frames = []
297
+ for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
298
+ transparent_background = Image.fromarray(
299
+ session_state["all_frames"][out_frame_idx]
300
+ ).convert("RGBA")
301
+ out_mask = video_segments[out_frame_idx][OBJ_ID]
302
+ mask_image = show_mask(out_mask)
303
+ output_frame = Image.alpha_composite(transparent_background, mask_image)
304
+ output_frame = np.array(output_frame)
305
+ output_frames.append(output_frame)
306
 
307
  # Create a video clip from the image sequence
308
+ original_fps = get_video_fps(video_in)
309
+ fps = original_fps # Frames per second
310
+
311
+ # For CPU optimization - lower FPS if original is high
312
+ if fps > 24:
313
+ fps = 24
314
+
315
+ clip = ImageSequenceClip(output_frames, fps=fps)
316
+
317
+ # Write the result to a file - use lower quality for CPU
318
+ unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
319
+ final_vid_output_path = f"output_video_{unique_id}.mp4"
320
+ final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
321
+
322
+ # Lower bitrate for CPU processing
323
+ clip.write_videofile(final_vid_output_path, codec="libx264", bitrate="1000k")
324
 
325
+ return (
326
+ gr.update(value=final_vid_output_path),
327
+ session_state,
328
+ )
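Two CPU-specific details in the loop above are worth flagging: torch.cuda.empty_cache() is a no-op without a GPU (gc.collect() is the closest CPU-side equivalent), and all_frames was subsampled with frame_stride while the predictor was initialized from the full video, so out_frame_idx can run past the stored frames on long clips. A hedged sketch of the mapping, assuming the indices from propagate_in_video refer to the full, non-subsampled video and that masks come out at the original resolution:

import gc

stride = session_state.get("frame_stride", 1)
all_frames = session_state["all_frames"]
output_frames = []
for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
    # Map the predictor's frame index onto the subsampled all_frames list.
    kept_idx = min(out_frame_idx // stride, len(all_frames) - 1)
    background = Image.fromarray(all_frames[kept_idx]).convert("RGBA")
    mask_image = show_mask(video_segments[out_frame_idx][OBJ_ID])
    # The preview frames may have been resized; bring the mask to their size.
    mask_image = mask_image.resize(background.size)
    output_frames.append(np.array(Image.alpha_composite(background, mask_image)))
    if len(output_frames) % chunk_size == 0:
        gc.collect()  # reclaim memory on CPU instead of the CUDA cache call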
329
 
330
 
331
+ def update_ui():
 
332
  return gr.update(visible=True)
333
 
334
 
335
  with gr.Blocks() as demo:
 
336
  session_state = gr.State(
337
  {
338
+ "first_frame": None,
339
+ "all_frames": None,
340
+ "input_points": [],
341
+ "input_labels": [],
342
+ "inference_state": None,
343
+ "frame_stride": 1,
344
+ "scale_factor": 1.0,
345
+ "original_dimensions": None,
346
  }
347
  )
348
 
 
356
  gr.Markdown(description_p)
357
 
358
  with gr.Accordion("Input Video", open=True) as video_in_drawer:
359
+ video_in = gr.Video(label="Input Video", format="mp4")
360
 
361
  with gr.Row():
362
  point_type = gr.Radio(
 
364
  choices=["include", "exclude"],
365
  value="include",
366
  scale=2,
 
367
  )
368
+ propagate_btn = gr.Button("Track", scale=1, variant="primary")
369
+ clear_points_btn = gr.Button("Clear Points", scale=1)
370
+ reset_btn = gr.Button("Reset", scale=1)
 
371
 
372
  points_map = gr.Image(
373
+ label="Frame with Point Prompt", type="numpy", interactive=False
374
  )
375
 
376
  with gr.Column():
377
  gr.Markdown("# Try some of the examples below ⬇️")
378
  gr.Examples(
379
  examples=examples,
380
+ inputs=[
381
+ video_in,
382
+ ],
383
+ examples_per_page=5,
384
  )
385
+
386
+ output_image = gr.Image(label="Reference Mask")
387
+ output_video = gr.Video(visible=False)
388
 
389
+ # When a new video is uploaded
390
  video_in.upload(
391
  fn=preprocess_video_in,
392
+ inputs=[
393
+ video_in,
394
+ session_state,
395
+ ],
396
  outputs=[
397
+ video_in_drawer, # Accordion to hide uploaded video player
398
+ points_map, # Image component where we add new tracking points
399
+ output_image,
400
+ output_video,
401
+ session_state,
402
  ],
403
+ queue=False,
404
  )
405
 
406
  video_in.change(
407
  fn=preprocess_video_in,
408
+ inputs=[
409
+ video_in,
410
+ session_state,
411
+ ],
412
  outputs=[
413
+ video_in_drawer, # Accordion to hide uploaded video player
414
+ points_map, # Image component where we add new tracking points
415
+ output_image,
416
+ output_video,
417
+ session_state,
418
  ],
419
+ queue=False,
420
  )
421
 
422
+ # Triggered when we click on the image to add new points
 
423
  points_map.select(
424
  fn=segment_with_points,
425
  inputs=[
426
+ point_type, # "include" or "exclude"
427
+ session_state,
428
  ],
429
  outputs=[
430
+ points_map, # updated image with points
431
+ output_image,
432
+ session_state,
433
  ],
434
+ queue=False,
435
  )
436
 
437
+ # Clear every point clicked and added to the map
438
  clear_points_btn.click(
439
  fn=clear_points,
440
+ inputs=session_state,
441
  outputs=[
442
+ points_map,
443
+ output_image,
444
+ output_video,
445
+ session_state,
446
  ],
447
+ queue=False,
448
  )
449
 
 
450
  reset_btn.click(
451
  fn=reset,
452
+ inputs=session_state,
453
  outputs=[
454
+ video_in,
455
+ video_in_drawer,
456
+ points_map,
457
+ output_image,
458
+ output_video,
459
+ session_state,
460
  ],
461
+ queue=False,
462
  )
463
 
 
464
  propagate_btn.click(
465
+ fn=update_ui,
466
  inputs=[],
467
+ outputs=output_video,
468
+ queue=False,
469
+ ).then(
470
  fn=propagate_to_all,
471
  inputs=[
472
+ video_in,
473
+ session_state,
474
  ],
475
  outputs=[
476
+ output_video,
477
+ session_state,
478
  ],
479
+ queue=True, # Use queue for CPU processing
480
  )
481
 
482
 
483
+ demo.queue()
484
+ demo.launch()
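The removed version throttled the Track handler with concurrency_limit=1 so a single CPU Space never runs two propagations at once; the rewrite above drops that. A hedged sketch of re-adding it, assuming the Space's Gradio version accepts a per-event concurrency_limit (the pre-commit code used exactly that keyword):

propagate_btn.click(
    fn=update_ui,
    inputs=[],
    outputs=output_video,
    queue=False,
).then(
    fn=propagate_to_all,
    inputs=[video_in, session_state],
    outputs=[output_video, session_state],
    queue=True,
    concurrency_limit=1,  # one propagation at a time on CPU
)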