bla committed on
Commit 807f473 · verified · 1 Parent(s): ea6a5ed

Update app.py

Files changed (1)
  1. app.py +475 -129
app.py CHANGED
@@ -6,6 +6,7 @@
6
 
7
  import copy
8
  import os
 
9
  from datetime import datetime
10
  import tempfile
11
 
@@ -14,6 +15,7 @@ import matplotlib.pyplot as plt
14
  import numpy as np
15
  import gradio as gr
16
  import torch
 
17
  from moviepy.editor import ImageSequenceClip
18
  from PIL import Image
19
  from sam2.build_sam import build_sam2_video_predictor
@@ -22,19 +24,19 @@ from sam2.build_sam import build_sam2_video_predictor
22
  if 'TORCH_CUDNN_SDPA_ENABLED' in os.environ:
23
  del os.environ["TORCH_CUDNN_SDPA_ENABLED"]
24
 
25
- # UI Description
26
- title = "<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"
27
 
28
  description_p = """# Instructions
29
  <ol>
30
- <li>Upload one video or click one example video</li>
31
- <li>Click 'include' point type, select the object to segment and track</li>
32
- <li>Click 'exclude' point type (optional), select the area to avoid segmenting</li>
33
- <li>Click the 'Track' button to obtain the masked video</li>
34
  </ol>
35
  """
36
 
37
- # Example videos
38
  examples = [
39
  ["examples/01_dog.mp4"],
40
  ["examples/02_cups.mp4"],
@@ -45,38 +47,41 @@ examples = [
45
 
46
  OBJ_ID = 0
47
 
48
- # Initialize model on CPU
49
  sam2_checkpoint = "checkpoints/edgetam.pt"
50
  model_cfg = "edgetam.yaml"
51
 
 
52
  def check_file_exists(filepath):
53
  exists = os.path.exists(filepath)
54
  if not exists:
55
  print(f"WARNING: File not found: {filepath}")
56
  return exists
57
 
58
- # Verify model files
59
  model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
 
60
  try:
 
61
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
62
- print("Predictor loaded on CPU")
63
  except Exception as e:
64
  print(f"Error loading model: {e}")
65
  import traceback
66
  traceback.print_exc()
67
- predictor = None
68
 
69
- # Utility Functions
70
  def get_video_fps(video_path):
71
  cap = cv2.VideoCapture(video_path)
72
  if not cap.isOpened():
73
  print("Error: Could not open video.")
74
- return 30.0
75
  fps = cap.get(cv2.CAP_PROP_FPS)
76
  cap.release()
77
  return fps
78
 
79
  def reset(session_state):
 
80
  session_state["input_points"] = []
81
  session_state["input_labels"] = []
82
  if session_state["inference_state"] is not None:
@@ -84,16 +89,19 @@ def reset(session_state):
84
  session_state["first_frame"] = None
85
  session_state["all_frames"] = None
86
  session_state["inference_state"] = None
 
87
  return (
88
  None,
89
  gr.update(open=True),
90
  None,
91
  None,
92
  gr.update(value=None, visible=False),
 
93
  session_state,
94
  )
95
 
96
  def clear_points(session_state):
 
97
  session_state["input_points"] = []
98
  session_state["input_labels"] = []
99
  if session_state["inference_state"] is not None and session_state["inference_state"].get("tracking_has_started", False):
@@ -102,82 +110,143 @@ def clear_points(session_state):
102
  session_state["first_frame"],
103
  None,
104
  gr.update(value=None, visible=False),
 
105
  session_state,
106
  )
107
 
108
  def preprocess_video_in(video_path, session_state):
109
- if video_path is None:
 
110
  return (
111
- gr.update(open=True),
112
- None,
113
- None,
114
- gr.update(value=None, visible=False),
 
115
  session_state,
116
  )
117
 
 
118
  cap = cv2.VideoCapture(video_path)
119
  if not cap.isOpened():
120
- print("Error: Could not open video.")
121
  return (
122
- gr.update(open=True),
123
- None,
124
- None,
125
- gr.update(value=None, visible=False),
 
126
  session_state,
127
  )
128
 
129
- # Video properties
130
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
131
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
132
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
133
-
134
- # Resize for CPU performance
135
- target_width = 640
 
 
 
136
  scale_factor = 1.0
 
137
  if frame_width > target_width:
138
  scale_factor = target_width / frame_width
139
- frame_width = target_width
140
- frame_height = int(frame_height * scale_factor)
141
-
142
- # Read frames with stride for CPU optimization
 
143
  frame_number = 0
144
  first_frame = None
145
  all_frames = []
146
- frame_stride = max(1, total_frames // 300) # Limit to ~300 frames
147
-
148
  while True:
149
  ret, frame = cap.read()
150
  if not ret:
151
  break
152
- if frame_number % frame_stride == 0:
153
- if scale_factor != 1.0:
154
- frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
155
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
156
- if first_frame is None:
157
- first_frame = frame
158
- all_frames.append(frame)
159
  frame_number += 1
160
 
161
  cap.release()
162
  session_state["first_frame"] = copy.deepcopy(first_frame)
163
  session_state["all_frames"] = all_frames
164
  session_state["frame_stride"] = frame_stride
165
  session_state["scale_factor"] = scale_factor
166
- session_state["original_dimensions"] = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
167
-
168
- session_state["inference_state"] = predictor.init_state(video_path=video_path)
169
- session_state["input_points"] = []
170
- session_state["input_labels"] = []
171
172
  return [
173
- gr.update(open=False),
174
- first_frame,
175
- None,
176
- gr.update(value=None, visible=False),
 
177
  session_state,
178
  ]
179
 
180
- def segment_with_points(point_type, session_state, evt: gr.SelectData):
181
  session_state["input_points"].append(evt.index)
182
  print(f"TRACKING INPUT POINT: {session_state['input_points']}")
183
 
@@ -187,26 +256,43 @@ def segment_with_points(point_type, session_state, evt: gr.SelectData):
187
  session_state["input_labels"].append(0)
188
  print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
189
 
 
190
  first_frame = session_state["first_frame"]
191
  h, w = first_frame.shape[:2]
192
  transparent_background = Image.fromarray(first_frame).convert("RGBA")
193
 
194
- # Draw points
195
- fraction = 0.01
196
  radius = int(fraction * min(w, h))
 
 
197
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
198
 
199
  for index, track in enumerate(session_state["input_points"]):
200
- color = (0, 255, 0, 255) if session_state["input_labels"][index] == 1 else (255, 0, 0, 255)
201
- cv2.circle(transparent_layer, track, radius, color, -1)
 
 
202
 
 
203
  transparent_layer = Image.fromarray(transparent_layer, "RGBA")
204
- selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
 
 
205
 
 
206
  points = np.array(session_state["input_points"], dtype=np.float32)
 
207
  labels = np.array(session_state["input_labels"], np.int32)
208
-
209
  try:
 
 
 
 
 
 
 
210
  _, _, out_mask_logits = predictor.add_new_points(
211
  inference_state=session_state["inference_state"],
212
  frame_idx=0,
@@ -214,215 +300,475 @@ def segment_with_points(point_type, session_state, evt: gr.SelectData):
214
  points=points,
215
  labels=labels,
216
  )
 
 
217
  mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
218
 
219
- # Ensure mask matches frame size
220
  if mask_array.shape[:2] != (h, w):
221
- mask_array = cv2.resize(mask_array.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
 
 
 
 
222
 
223
  mask_image = show_mask(mask_array)
 
 
224
  if mask_image.size != transparent_background.size:
225
  mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
226
-
227
  first_frame_output = Image.alpha_composite(transparent_background, mask_image)
228
  except Exception as e:
229
  print(f"Error in segmentation: {e}")
 
 
 
230
  first_frame_output = selected_point_map
231
 
232
  return selected_point_map, first_frame_output, session_state
233
 
234
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
 
235
  if random_color:
236
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
237
  else:
238
  cmap = plt.get_cmap("tab10")
239
  cmap_idx = 0 if obj_id is None else obj_id
240
  color = np.array([*cmap(cmap_idx)[:3], 0.6])
241
-
242
- h, w = mask.shape[-2:] if len(mask.shape) > 2 else mask.shape
243
  mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
244
  mask_rgba = (mask_reshaped * 255).astype(np.uint8)
245
-
246
  if convert_to_image:
247
  try:
 
248
  if mask_rgba.shape[2] != 4:
 
249
  proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
 
250
  proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
251
  mask_rgba = proper_mask
 
 
252
  return Image.fromarray(mask_rgba, "RGBA")
253
  except Exception as e:
254
  print(f"Error converting mask to image: {e}")
255
- return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")
 
 
256
 
257
  return mask_rgba
258
 
259
- def propagate_to_all(video_in, session_state, progress=gr.Progress()):
260
- if len(session_state["input_points"]) == 0 or video_in is None or session_state["inference_state"] is None:
261
- return gr.update(value=None, visible=False), session_state
262
 
263
- chunk_size = 3
 
 
264
  try:
265
- video_segments = {}
266
- total_frames = len(session_state["all_frames"])
267
- progress(0, desc="Propagating segmentation through video...")
268
-
269
- for i, (out_frame_idx, out_obj_ids, out_mask_logit) in enumerate(predictor.propagate_in_video(session_state["inference_state"])):
270
  try:
 
271
  video_segments[out_frame_idx] = {
272
- out_obj_id: (out_mask_logit[i] > 0.0).cpu().numpy()
273
  for i, out_obj_id in enumerate(out_obj_ids)
274
  }
275
- progress((i + 1) / total_frames, desc=f"Processed frame {out_frame_idx}/{total_frames}")
276
  if out_frame_idx % chunk_size == 0:
277
- del out_mask_logit
 
278
  import gc
279
  gc.collect()
280
  except Exception as e:
281
  print(f"Error processing frame {out_frame_idx}: {e}")
 
 
282
  continue
283
284
  max_output_frames = 50
285
  vis_frame_stride = max(1, total_frames // max_output_frames)
286
  first_frame = session_state["all_frames"][0]
287
  h, w = first_frame.shape[:2]
 
 
288
  output_frames = []
289
-
 
 
290
  for out_frame_idx in range(0, total_frames, vis_frame_stride):
291
  if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
 
292
  continue
 
293
  try:
294
- frame = session_state["all_frames"][out_frame_idx]
295
  transparent_background = Image.fromarray(frame).convert("RGBA")
 
 
296
  out_mask = video_segments[out_frame_idx][OBJ_ID]
297
-
298
- # Validate mask dimensions
299
- if out_mask.shape[:2] != (h, w):
300
- if out_mask.size == 0: # Skip empty masks
301
- print(f"Skipping empty mask for frame {out_frame_idx}")
302
- continue
303
- out_mask = cv2.resize(out_mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
304
-
305
  mask_image = show_mask(out_mask)
 
 
306
  if mask_image.size != transparent_background.size:
307
  mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
308
-
309
  output_frame = Image.alpha_composite(transparent_background, mask_image)
310
- output_frames.append(np.array(output_frame))
311
-
312
  if len(output_frames) % 10 == 0:
313
  import gc
314
  gc.collect()
 
315
  except Exception as e:
316
- print(f"Error creating output frame {out_frame_idx}: {e}")
 
317
  traceback.print_exc()
 
318
  continue
319
 
 
320
  original_fps = get_video_fps(video_in)
321
- fps = min(original_fps, 15) # Cap at 15 FPS for CPU
322
-
323
- clip = ImageSequenceClip(output_frames, fps=fps)
324
  unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
325
- final_vid_output_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")
 
326
 
 
327
  clip.write_videofile(
328
- final_vid_output_path,
329
- codec="libx264",
330
  bitrate="800k",
331
- threads=2,
332
- logger=None
333
  )
334
-
335
- del video_segments, output_frames
336
  import gc
337
  gc.collect()
338
 
339
- return gr.update(value=final_vid_output_path, visible=True), session_state
340
-
 
 
 
 
341
  except Exception as e:
342
  print(f"Error in propagate_to_all: {e}")
343
- return gr.update(value=None, visible=False), session_state
344
 
345
  def update_ui():
346
- return gr.update(visible=True)
 
 
347
 
348
- # Gradio Interface
349
  with gr.Blocks() as demo:
350
- session_state = gr.State({
351
- "first_frame": None,
352
- "all_frames": None,
353
- "input_points": [],
354
- "input_labels": [],
355
- "inference_state": None,
356
- "frame_stride": 1,
357
- "scale_factor": 1.0,
358
- "original_dimensions": None,
359
- })
 
 
 
360
 
361
  with gr.Column():
 
362
  gr.Markdown(title)
363
  with gr.Row():
 
364
  with gr.Column():
 
365
  gr.Markdown(description_p)
 
366
  with gr.Accordion("Input Video", open=True) as video_in_drawer:
367
  video_in = gr.Video(label="Input Video", format="mp4")
 
368
  with gr.Row():
369
- point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include", scale=2)
370
  propagate_btn = gr.Button("Track", scale=1, variant="primary")
371
  clear_points_btn = gr.Button("Clear Points", scale=1)
372
  reset_btn = gr.Button("Reset", scale=1)
373
- points_map = gr.Image(label="Frame with Point Prompt", type="numpy", interactive=False)
374
  with gr.Column():
375
  gr.Markdown("# Try some of the examples below ⬇️")
376
- gr.Examples(examples=examples, inputs=[video_in], examples_per_page=5)
377
  output_image = gr.Image(label="Reference Mask")
378
  output_video = gr.Video(visible=False)
379
 
 
380
  video_in.upload(
381
  fn=preprocess_video_in,
382
- inputs=[video_in, session_state],
383
- outputs=[video_in_drawer, points_map, output_image, output_video, session_state],
384
  queue=False,
385
  )
386
 
387
  video_in.change(
388
  fn=preprocess_video_in,
389
- inputs=[video_in, session_state],
390
- outputs=[video_in_drawer, points_map, output_image, output_video, session_state],
391
  queue=False,
392
  )
393
 
 
394
  points_map.select(
395
  fn=segment_with_points,
396
- inputs=[point_type, session_state],
397
- outputs=[points_map, output_image, session_state],
398
  queue=False,
399
  )
400
 
 
401
  clear_points_btn.click(
402
  fn=clear_points,
403
  inputs=session_state,
404
- outputs=[points_map, output_image, output_video, session_state],
405
  queue=False,
406
  )
407
 
408
  reset_btn.click(
409
  fn=reset,
410
  inputs=session_state,
411
- outputs=[video_in, video_in_drawer, points_map, output_image, output_video, session_state],
412
  queue=False,
413
  )
414
 
415
  propagate_btn.click(
416
  fn=update_ui,
417
  inputs=[],
418
- outputs=output_video,
419
  queue=False,
420
  ).then(
421
  fn=propagate_to_all,
422
- inputs=[video_in, session_state],
423
- outputs=[output_video, session_state],
424
- queue=True,
425
  )
426
 
 
427
  demo.queue()
428
  demo.launch()
 
6
 
7
  import copy
8
  import os
9
+ import time
10
  from datetime import datetime
11
  import tempfile
12
 
 
15
  import numpy as np
16
  import gradio as gr
17
  import torch
18
+
19
  from moviepy.editor import ImageSequenceClip
20
  from PIL import Image
21
  from sam2.build_sam import build_sam2_video_predictor
 
24
  if 'TORCH_CUDNN_SDPA_ENABLED' in os.environ:
25
  del os.environ["TORCH_CUDNN_SDPA_ENABLED"]
26
 
27
+ # Description
28
+ title = "<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"
29
 
30
  description_p = """# Instructions
31
  <ol>
32
+ <li> Upload one video or click one example video</li>
33
+ <li> Click 'include' point type, select the object to segment and track</li>
34
+ <li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
35
+ <li> Click the 'Track' button to obtain the masked video </li>
36
  </ol>
37
  """
38
 
39
+ # examples - keeping fewer examples to reduce memory footprint
40
  examples = [
41
  ["examples/01_dog.mp4"],
42
  ["examples/02_cups.mp4"],
 
47
 
48
  OBJ_ID = 0
49
 
50
+ # Initialize model on CPU - add error handling for file paths
51
  sam2_checkpoint = "checkpoints/edgetam.pt"
52
  model_cfg = "edgetam.yaml"
53
 
54
+ # Check if model files exist
55
  def check_file_exists(filepath):
56
  exists = os.path.exists(filepath)
57
  if not exists:
58
  print(f"WARNING: File not found: {filepath}")
59
  return exists
60
 
61
+ # Verify files exist
62
  model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
63
+ predictor = None
64
  try:
65
+ # Load model with careful error handling
66
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
67
+ print("predictor loaded on CPU")
68
  except Exception as e:
69
  print(f"Error loading model: {e}")
70
  import traceback
71
  traceback.print_exc()
 
72
 
73
+ # Function to get video frame rate
74
  def get_video_fps(video_path):
75
  cap = cv2.VideoCapture(video_path)
76
  if not cap.isOpened():
77
  print("Error: Could not open video.")
78
+ return 30.0 # Default fallback value
79
  fps = cap.get(cv2.CAP_PROP_FPS)
80
  cap.release()
81
  return fps
82
 
83
  def reset(session_state):
84
+ """Reset all session state variables and UI elements."""
85
  session_state["input_points"] = []
86
  session_state["input_labels"] = []
87
  if session_state["inference_state"] is not None:
 
89
  session_state["first_frame"] = None
90
  session_state["all_frames"] = None
91
  session_state["inference_state"] = None
92
+ session_state["progress"] = 0
93
  return (
94
  None,
95
  gr.update(open=True),
96
  None,
97
  None,
98
  gr.update(value=None, visible=False),
99
+ gr.update(value=0, visible=False),
100
  session_state,
101
  )
102
 
103
  def clear_points(session_state):
104
+ """Clear tracking points while keeping the video frames."""
105
  session_state["input_points"] = []
106
  session_state["input_labels"] = []
107
  if session_state["inference_state"] is not None and session_state["inference_state"].get("tracking_has_started", False):
 
110
  session_state["first_frame"],
111
  None,
112
  gr.update(value=None, visible=False),
113
+ gr.update(value=0, visible=False),
114
  session_state,
115
  )
116
 
117
  def preprocess_video_in(video_path, session_state):
118
+ """Process input video to extract frames for tracking."""
119
+ if video_path is None or not os.path.exists(video_path):
120
  return (
121
+ gr.update(open=True), # video_in_drawer
122
+ None, # points_map
123
+ None, # output_image
124
+ gr.update(value=None, visible=False), # output_video
125
+ gr.update(value=0, visible=False), # progress_bar
126
  session_state,
127
  )
128
 
129
+ # Read the first frame
130
  cap = cv2.VideoCapture(video_path)
131
  if not cap.isOpened():
132
+ print(f"Error: Could not open video at {video_path}.")
133
  return (
134
+ gr.update(open=True), # video_in_drawer
135
+ None, # points_map
136
+ None, # output_image
137
+ gr.update(value=None, visible=False), # output_video
138
+ gr.update(value=0, visible=False), # progress_bar
139
  session_state,
140
  )
141
 
142
+ # For CPU optimization - determine video properties
143
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
144
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
145
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
146
+ fps = cap.get(cv2.CAP_PROP_FPS)
147
+
148
+ print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames, {fps} FPS")
149
+
150
+ # Determine if we need to resize for CPU performance
151
+ target_width = 640 # Target width for processing on CPU
152
  scale_factor = 1.0
153
+
154
  if frame_width > target_width:
155
  scale_factor = target_width / frame_width
156
+ new_width = int(frame_width * scale_factor)
157
+ new_height = int(frame_height * scale_factor)
158
+ print(f"Resizing video for CPU processing: {frame_width}x{frame_height} -> {new_width}x{new_height}")
159
+
160
+ # Read frames - for CPU we'll be more selective about which frames to keep
161
  frame_number = 0
162
  first_frame = None
163
  all_frames = []
164
+
165
+ # For CPU optimization, skip frames if video is too long
166
+ frame_stride = 1
167
+ if total_frames > 300: # If more than 300 frames
168
+ frame_stride = max(1, int(total_frames / 300)) # Process at most ~300 frames
169
+ print(f"Video has {total_frames} frames, using stride of {frame_stride} to reduce processing load")
170
+
171
  while True:
172
  ret, frame = cap.read()
173
  if not ret:
174
  break
175
+
176
+ if frame_number % frame_stride == 0: # Process every frame_stride frames
177
+ try:
178
+ # Resize the frame if needed
179
+ if scale_factor != 1.0:
180
+ frame = cv2.resize(
181
+ frame,
182
+ (int(frame_width * scale_factor), int(frame_height * scale_factor)),
183
+ interpolation=cv2.INTER_AREA
184
+ )
185
+
186
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
187
+ frame = np.array(frame)
188
+
189
+ # Store the first frame
190
+ if first_frame is None:
191
+ first_frame = frame
192
+ all_frames.append(frame)
193
+ except Exception as e:
194
+ print(f"Error processing frame {frame_number}: {e}")
195
+
196
  frame_number += 1
197
 
198
  cap.release()
199
+
200
+ # Ensure we have at least one frame
201
+ if first_frame is None or len(all_frames) == 0:
202
+ print("Error: No frames could be extracted from the video.")
203
+ return (
204
+ gr.update(open=True), # video_in_drawer
205
+ None, # points_map
206
+ None, # output_image
207
+ gr.update(value=None, visible=False), # output_video
208
+ gr.update(value=0, visible=False), # progress_bar
209
+ session_state,
210
+ )
211
+
212
+ print(f"Successfully extracted {len(all_frames)} frames from video")
213
+
214
  session_state["first_frame"] = copy.deepcopy(first_frame)
215
  session_state["all_frames"] = all_frames
216
  session_state["frame_stride"] = frame_stride
217
  session_state["scale_factor"] = scale_factor
218
+ session_state["original_dimensions"] = (frame_width, frame_height)
219
+ session_state["progress"] = 0
 
 
 
220
 
221
+ try:
222
+ session_state["inference_state"] = predictor.init_state(video_path=video_path)
223
+ session_state["input_points"] = []
224
+ session_state["input_labels"] = []
225
+ except Exception as e:
226
+ print(f"Error initializing inference state: {e}")
227
+ import traceback
228
+ traceback.print_exc()
229
+ session_state["inference_state"] = None
230
+
231
  return [
232
+ gr.update(open=False), # video_in_drawer
233
+ first_frame, # points_map
234
+ None, # output_image
235
+ gr.update(value=None, visible=False), # output_video
236
+ gr.update(value=0, visible=False), # progress_bar
237
  session_state,
238
  ]
239
 
240
+ def segment_with_points(
241
+ point_type,
242
+ session_state,
243
+ evt: gr.SelectData,
244
+ ):
245
+ """Add and process tracking points on the first frame."""
246
+ if session_state["first_frame"] is None:
247
+ print("Error: No frame available for segmentation")
248
+ return None, None, session_state
249
+
250
  session_state["input_points"].append(evt.index)
251
  print(f"TRACKING INPUT POINT: {session_state['input_points']}")
252
 
 
256
  session_state["input_labels"].append(0)
257
  print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
258
 
259
+ # Open the image and get its dimensions
260
  first_frame = session_state["first_frame"]
261
  h, w = first_frame.shape[:2]
262
  transparent_background = Image.fromarray(first_frame).convert("RGBA")
263
 
264
+ # Define the circle radius as a fraction of the smaller dimension
265
+ fraction = 0.01 # You can adjust this value as needed
266
  radius = int(fraction * min(w, h))
267
+
268
+ # Create a transparent layer to draw on
269
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
270
 
271
  for index, track in enumerate(session_state["input_points"]):
272
+ if session_state["input_labels"][index] == 1:
273
+ cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1) # Green for include
274
+ else:
275
+ cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1) # Red for exclude
276
 
277
+ # Convert the transparent layer back to an image
278
  transparent_layer = Image.fromarray(transparent_layer, "RGBA")
279
+ selected_point_map = Image.alpha_composite(
280
+ transparent_background, transparent_layer
281
+ )
282
 
283
+ # Convert the clicks collected so far into point prompts for the predictor
284
  points = np.array(session_state["input_points"], dtype=np.float32)
285
+ # for labels, `1` means positive click and `0` means negative click
286
  labels = np.array(session_state["input_labels"], np.int32)
287
+
288
  try:
289
+ if predictor is None:
290
+ raise ValueError("Model predictor is not initialized")
291
+
292
+ if session_state["inference_state"] is None:
293
+ raise ValueError("Inference state is not initialized")
294
+
295
+ # For CPU optimization, we'll process with smaller batch size
296
  _, _, out_mask_logits = predictor.add_new_points(
297
  inference_state=session_state["inference_state"],
298
  frame_idx=0,
 
300
  points=points,
301
  labels=labels,
302
  )
303
+
304
+ # Create the mask
305
  mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
306
 
307
+ # Ensure the mask has the same size as the frame
308
  if mask_array.shape[:2] != (h, w):
309
+ mask_array = cv2.resize(
310
+ mask_array.astype(np.uint8),
311
+ (w, h),
312
+ interpolation=cv2.INTER_NEAREST
313
+ ).astype(bool)
314
 
315
  mask_image = show_mask(mask_array)
316
+
317
+ # Make sure mask_image has the same size as the background
318
  if mask_image.size != transparent_background.size:
319
  mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
320
+
321
  first_frame_output = Image.alpha_composite(transparent_background, mask_image)
322
  except Exception as e:
323
  print(f"Error in segmentation: {e}")
324
+ import traceback
325
+ traceback.print_exc()
326
+ # Return just the points as fallback
327
  first_frame_output = selected_point_map
328
 
329
  return selected_point_map, first_frame_output, session_state
330
 
331
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
332
+ """Convert binary mask to RGBA image for visualization."""
333
  if random_color:
334
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
335
  else:
336
  cmap = plt.get_cmap("tab10")
337
  cmap_idx = 0 if obj_id is None else obj_id
338
  color = np.array([*cmap(cmap_idx)[:3], 0.6])
339
+
340
+ # Handle different mask shapes properly
341
+ if len(mask.shape) == 2:
342
+ h, w = mask.shape
343
+ else:
344
+ h, w = mask.shape[-2:]
345
+
346
+ # Ensure correct reshaping based on mask dimensions
347
  mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
348
  mask_rgba = (mask_reshaped * 255).astype(np.uint8)
349
+
350
  if convert_to_image:
351
  try:
352
+ # Ensure the mask has correct RGBA shape (h, w, 4)
353
  if mask_rgba.shape[2] != 4:
354
+ # If not RGBA, create a proper RGBA array
355
  proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
356
+ # Copy available channels
357
  proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
358
  mask_rgba = proper_mask
359
+
360
+ # Create the PIL image
361
  return Image.fromarray(mask_rgba, "RGBA")
362
  except Exception as e:
363
  print(f"Error converting mask to image: {e}")
364
+ # Fallback: create a blank transparent image of correct size
365
+ blank = np.zeros((h, w, 4), dtype=np.uint8)
366
+ return Image.fromarray(blank, "RGBA")
367
 
368
  return mask_rgba
369
 
370
+ def update_progress(progress_percent, progress_bar):
371
+ """Update progress bar during processing."""
372
+ return gr.update(value=progress_percent, visible=True)
373
+
374
+ def propagate_to_all(
375
+ video_in,
376
+ session_state,
377
+ progress=gr.Progress(),
378
+ ):
379
+ """Process video frames and generate masked video output with progress tracking."""
380
+ if (
381
+ len(session_state["input_points"]) == 0
382
+ or video_in is None
383
+ or session_state["inference_state"] is None
384
+ or predictor is None
385
+ ):
386
+ print("Missing required data for tracking")
387
+ return (
388
+ gr.update(value=None, visible=False),
389
+ gr.update(value=0, visible=False),
390
+ session_state,
391
+ )
392
 
393
+ # For CPU optimization: process in smaller batches
394
+ chunk_size = 3 # Process 3 frames at a time to avoid memory issues on CPU
395
+
396
  try:
397
+ # run propagation throughout the video and collect the results in a dict
398
+ video_segments = {} # video_segments contains the per-frame segmentation results
399
+ print("Starting propagate_in_video on CPU")
400
+
401
+ progress.tqdm.reset()
402
+
403
+ # Get the count for progress reporting
404
+ all_frames_count = 0
405
+ for _ in predictor.propagate_in_video(session_state["inference_state"], count_only=True):
406
+ all_frames_count += 1
407
+
408
+ print(f"Total frames to process: {all_frames_count}")
409
+ progress.tqdm.total = all_frames_count
410
+
411
+ # Now do the actual processing with progress updates
412
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
413
+ session_state["inference_state"]
414
+ ):
415
  try:
416
+ # Store the masks for each object ID
417
  video_segments[out_frame_idx] = {
418
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
419
  for i, out_obj_id in enumerate(out_obj_ids)
420
  }
421
+
422
+ # Update progress
423
+ progress.tqdm.update(1)
424
+ progress_percent = min(100, int((out_frame_idx + 1) / all_frames_count * 100))
425
+ session_state["progress"] = progress_percent
426
+
427
+ if out_frame_idx % 10 == 0:
428
+ print(f"Processed frame {out_frame_idx}/{all_frames_count} ({progress_percent}%)")
429
+
430
+ # Release memory periodically
431
  if out_frame_idx % chunk_size == 0:
432
+ # Explicitly clear any tensors
433
+ del out_mask_logits
434
  import gc
435
  gc.collect()
436
  except Exception as e:
437
  print(f"Error processing frame {out_frame_idx}: {e}")
438
+ import traceback
439
+ traceback.print_exc()
440
  continue
441
 
442
+ # For CPU optimization: increase stride to reduce processing
443
+ # Create a more aggressive stride to limit to fewer frames in output
444
+ total_frames = len(video_segments)
445
+ print(f"Total frames processed: {total_frames}")
446
+
447
+ # Update progress to show rendering phase
448
+ progress.tqdm.reset()
449
+ progress.tqdm.total = 2 # Two phases: rendering and video creation
450
+ progress.tqdm.update(1)
451
+ session_state["progress"] = 50
452
+
453
+ # Limit to max 50 frames for CPU processing
454
  max_output_frames = 50
455
  vis_frame_stride = max(1, total_frames // max_output_frames)
456
+ print(f"Using stride of {vis_frame_stride} for output video generation")
457
+
458
+ # Get dimensions of the frames
459
+ if len(session_state["all_frames"]) == 0:
460
+ raise ValueError("No frames available in session state")
461
+
462
  first_frame = session_state["all_frames"][0]
463
  h, w = first_frame.shape[:2]
464
+
465
+ # Create output frames
466
  output_frames = []
467
+ progress.tqdm.reset()
468
+ progress.tqdm.total = (total_frames // vis_frame_stride) + 1
469
+
470
  for out_frame_idx in range(0, total_frames, vis_frame_stride):
471
  if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
472
+ progress.tqdm.update(1)
473
  continue
474
+
475
  try:
476
+ # Get corresponding frame from all_frames
477
+ if out_frame_idx >= len(session_state["all_frames"]):
478
+ print(f"Warning: Frame index {out_frame_idx} exceeds available frames {len(session_state['all_frames'])}")
479
+ frame_idx = min(out_frame_idx, len(session_state["all_frames"])-1)
480
+ else:
481
+ frame_idx = out_frame_idx
482
+
483
+ frame = session_state["all_frames"][frame_idx]
484
  transparent_background = Image.fromarray(frame).convert("RGBA")
485
+
486
+ # Get the mask and ensure it's the right size
487
  out_mask = video_segments[out_frame_idx][OBJ_ID]
488
+
489
+ # Ensure the mask is not empty and has the right dimensions
490
+ if out_mask.size == 0:
491
+ print(f"Warning: Empty mask for frame {out_frame_idx}")
492
+ # Create an empty mask of the right size
493
+ out_mask = np.zeros((h, w), dtype=bool)
494
+
495
+ # Resize mask if dimensions don't match
496
+ mask_h, mask_w = out_mask.shape[:2]
497
+ if mask_h != h or mask_w != w:
498
+ print(f"Resizing mask from {mask_h}x{mask_w} to {h}x{w}")
499
+ out_mask = cv2.resize(
500
+ out_mask.astype(np.uint8),
501
+ (w, h),
502
+ interpolation=cv2.INTER_NEAREST
503
+ ).astype(bool)
504
+
505
  mask_image = show_mask(out_mask)
506
+
507
+ # Make sure mask has same dimensions as background
508
  if mask_image.size != transparent_background.size:
509
  mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
510
+
511
  output_frame = Image.alpha_composite(transparent_background, mask_image)
512
+ output_frame = np.array(output_frame)
513
+ output_frames.append(output_frame)
514
+
515
+ # Update progress
516
+ progress.tqdm.update(1)
517
+ progress_percent = 50 + min(50, int((len(output_frames) / (total_frames // vis_frame_stride)) * 50))
518
+ session_state["progress"] = progress_percent
519
+
520
+ # Clear memory periodically
521
  if len(output_frames) % 10 == 0:
522
  import gc
523
  gc.collect()
524
+
525
  except Exception as e:
526
+ print(f"Error creating output frame {out_frame_idx}: {e}")
527
+ import traceback
528
  traceback.print_exc()
529
+ progress.tqdm.update(1)
530
  continue
531
 
532
+ # Create a video clip from the image sequence
533
  original_fps = get_video_fps(video_in)
534
+ fps = original_fps
535
+
536
+ # For CPU optimization - lower FPS if original is high
537
+ if fps > 15:
538
+ fps = 15 # Lower fps for CPU processing
539
+
540
+ print(f"Creating video with {len(output_frames)} frames at {fps} FPS")
541
+
542
+ # Update progress to show video creation phase
543
+ session_state["progress"] = 90
544
+
545
+ # Check if we have any frames to work with
546
+ if len(output_frames) == 0:
547
+ raise ValueError("No output frames were generated")
548
+
549
+ # Ensure all frames have the same shape
550
+ first_shape = output_frames[0].shape
551
+ valid_frames = []
552
+ for i, frame in enumerate(output_frames):
553
+ if frame.shape == first_shape:
554
+ valid_frames.append(frame)
555
+ else:
556
+ print(f"Skipping frame {i} with inconsistent shape: {frame.shape} vs {first_shape}")
557
+
558
+ if len(valid_frames) == 0:
559
+ raise ValueError("No valid frames with consistent shape")
560
+
561
+ clip = ImageSequenceClip(valid_frames, fps=fps)
562
+
563
+ # Write the result to a file - use lower quality for CPU
564
  unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
565
+ final_vid_output_path = f"output_video_{unique_id}.mp4"
566
+ final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
567
 
568
+ # Lower bitrate for CPU processing
569
  clip.write_videofile(
570
+ final_vid_output_path,
571
+ codec="libx264",
572
  bitrate="800k",
573
+ threads=2, # Use fewer threads for CPU
574
+ logger=None # Disable logger to reduce console output
575
  )
576
+
577
+ # Complete progress
578
+ session_state["progress"] = 100
579
+
580
+ # Free memory
581
+ del video_segments
582
+ del output_frames
583
  import gc
584
  gc.collect()
585
 
586
+ return (
587
+ gr.update(value=final_vid_output_path, visible=True),
588
+ gr.update(value=100, visible=False),
589
+ session_state,
590
+ )
591
+
592
  except Exception as e:
593
  print(f"Error in propagate_to_all: {e}")
594
+ import traceback
595
+ traceback.print_exc()
596
+ return (
597
+ gr.update(value=None, visible=False),
598
+ gr.update(value=0, visible=False),
599
+ session_state,
600
+ )
601
 
602
  def update_ui():
603
+ """Show progress bar when starting processing."""
604
+ return gr.update(visible=True), gr.update(visible=True, value=0)
605
+
606
 
607
+ # Main Gradio UI setup
608
  with gr.Blocks() as demo:
609
+ session_state = gr.State(
610
+ {
611
+ "first_frame": None,
612
+ "all_frames": None,
613
+ "input_points": [],
614
+ "input_labels": [],
615
+ "inference_state": None,
616
+ "frame_stride": 1,
617
+ "scale_factor": 1.0,
618
+ "original_dimensions": None,
619
+ "progress": 0,
620
+ }
621
+ )
622
 
623
  with gr.Column():
624
+ # Title
625
  gr.Markdown(title)
626
  with gr.Row():
627
+
628
  with gr.Column():
629
+ # Instructions
630
  gr.Markdown(description_p)
631
+
632
  with gr.Accordion("Input Video", open=True) as video_in_drawer:
633
  video_in = gr.Video(label="Input Video", format="mp4")
634
+
635
  with gr.Row():
636
+ point_type = gr.Radio(
637
+ label="point type",
638
+ choices=["include", "exclude"],
639
+ value="include",
640
+ scale=2,
641
+ )
642
  propagate_btn = gr.Button("Track", scale=1, variant="primary")
643
  clear_points_btn = gr.Button("Clear Points", scale=1)
644
  reset_btn = gr.Button("Reset", scale=1)
645
+
646
+ points_map = gr.Image(
647
+ label="Frame with Point Prompt", type="numpy", interactive=False
648
+ )
649
+
650
+ # Add progress bar
651
+ progress_bar = gr.Slider(
652
+ minimum=0,
653
+ maximum=100,
654
+ value=0,
655
+ step=1,
656
+ label="Processing Progress",
657
+ visible=False,
658
+ interactive=False
659
+ )
660
+
661
  with gr.Column():
662
  gr.Markdown("# Try some of the examples below ⬇️")
663
+ gr.Examples(
664
+ examples=examples,
665
+ inputs=[
666
+ video_in,
667
+ ],
668
+ examples_per_page=5,
669
+ )
670
+
671
  output_image = gr.Image(label="Reference Mask")
672
  output_video = gr.Video(visible=False)
673
 
674
+ # When new video is uploaded
675
  video_in.upload(
676
  fn=preprocess_video_in,
677
+ inputs=[
678
+ video_in,
679
+ session_state,
680
+ ],
681
+ outputs=[
682
+ video_in_drawer, # Accordion to hide uploaded video player
683
+ points_map, # Image component where we add new tracking points
684
+ output_image,
685
+ output_video,
686
+ progress_bar,
687
+ session_state,
688
+ ],
689
  queue=False,
690
  )
691
 
692
  video_in.change(
693
  fn=preprocess_video_in,
694
+ inputs=[
695
+ video_in,
696
+ session_state,
697
+ ],
698
+ outputs=[
699
+ video_in_drawer, # Accordion to hide uploaded video player
700
+ points_map, # Image component where we add new tracking points
701
+ output_image,
702
+ output_video,
703
+ progress_bar,
704
+ session_state,
705
+ ],
706
  queue=False,
707
  )
708
 
709
+ # triggered when we click on image to add new points
710
  points_map.select(
711
  fn=segment_with_points,
712
+ inputs=[
713
+ point_type, # "include" or "exclude"
714
+ session_state,
715
+ ],
716
+ outputs=[
717
+ points_map, # updated image with points
718
+ output_image,
719
+ session_state,
720
+ ],
721
  queue=False,
722
  )
723
 
724
+ # Clear every points clicked and added to the map
725
  clear_points_btn.click(
726
  fn=clear_points,
727
  inputs=session_state,
728
+ outputs=[
729
+ points_map,
730
+ output_image,
731
+ output_video,
732
+ progress_bar,
733
+ session_state,
734
+ ],
735
  queue=False,
736
  )
737
 
738
  reset_btn.click(
739
  fn=reset,
740
  inputs=session_state,
741
+ outputs=[
742
+ video_in,
743
+ video_in_drawer,
744
+ points_map,
745
+ output_image,
746
+ output_video,
747
+ progress_bar,
748
+ session_state,
749
+ ],
750
  queue=False,
751
  )
752
 
753
  propagate_btn.click(
754
  fn=update_ui,
755
  inputs=[],
756
+ outputs=[output_video, progress_bar],
757
  queue=False,
758
  ).then(
759
  fn=propagate_to_all,
760
+ inputs=[
761
+ video_in,
762
+ session_state,
763
+ ],
764
+ outputs=[
765
+ output_video,
766
+ progress_bar,
767
+ session_state,
768
+ ],
769
+ queue=True, # Use queue for CPU processing
770
  )
771
 
772
+
773
  demo.queue()
774
  demo.launch()
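
For readers skimming the diff: both before and after this change, the core visual step is the same - show_mask() tints a boolean segmentation mask with an RGBA colour and the app alpha-composites it over the frame with PIL. Below is a minimal standalone sketch of that overlay pattern, not part of the commit; the random frame, the hard-coded rectangle mask, the overlay_mask helper name and the output filename are illustrative placeholders only.

    import numpy as np
    from PIL import Image

    def overlay_mask(frame_rgb, mask, color=(31, 119, 180), alpha=0.6):
        # Tint only the masked pixels, leave the rest fully transparent,
        # then composite the tinted layer over the RGB frame.
        h, w = mask.shape[:2]
        layer = np.zeros((h, w, 4), dtype=np.uint8)
        layer[mask] = (*color, int(alpha * 255))
        background = Image.fromarray(frame_rgb).convert("RGBA")
        return Image.alpha_composite(background, Image.fromarray(layer, "RGBA"))

    frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # stand-in RGB frame
    mask = np.zeros((480, 640), dtype=bool)
    mask[100:300, 200:400] = True                                     # stand-in segmentation mask
    overlay_mask(frame, mask).save("overlay_preview.png")

In app.py the same composite is produced per frame and the frames are then assembled into the output clip with ImageSequenceClip. Running python app.py, with the packages imported at the top of the file installed and the EdgeTAM checkpoint and config in place, launches the Gradio demo via demo.launch().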