Spaces:
bla
/
Runtime error

bla commited on
Commit
1affb38
·
verified ·
1 Parent(s): cac3a2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +487 -232
app.py CHANGED
@@ -10,13 +10,16 @@ from datetime import datetime
10
 
11
  import gradio as gr
12
 
13
- os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"
 
 
14
  import tempfile
15
 
16
  import cv2
17
  import matplotlib.pyplot as plt
18
  import numpy as np
19
- import spaces
 
20
  import torch
21
 
22
  from moviepy.editor import ImageSequenceClip
@@ -35,7 +38,7 @@ description_p = """# Instructions
35
  </ol>
36
  """
37
 
38
- # examples
39
  examples = [
40
  ["examples/01_dog.mp4"],
41
  ["examples/02_cups.mp4"],
@@ -70,90 +73,79 @@ examples = [
70
 
71
  OBJ_ID = 0
72
 
73
-
74
  sam2_checkpoint = "checkpoints/edgetam.pt"
75
  model_cfg = "edgetam.yaml"
 
76
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
77
- predictor.to("cpu")
78
- print("predictor loaded")
 
 
 
79
 
80
- # use bfloat16 for the entire demo
81
- torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()
82
- # if torch.cuda.get_device_properties(0).major >= 8:
83
- # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
84
- # torch.backends.cuda.matmul.allow_tf32 = True
85
- # torch.backends.cudnn.allow_tf32 = True
86
 
87
 
88
  def get_video_fps(video_path):
89
- # Open the video file
 
 
 
90
  cap = cv2.VideoCapture(video_path)
91
-
92
  if not cap.isOpened():
93
- print("Error: Could not open video.")
94
  return None
95
-
96
- # Get the FPS of the video
97
  fps = cap.get(cv2.CAP_PROP_FPS)
98
-
99
  return fps
100
 
101
-
102
- def reset(session_state):
103
- session_state["input_points"] = []
104
- session_state["input_labels"] = []
105
- if session_state["inference_state"] is not None:
106
- predictor.reset_state(session_state["inference_state"])
107
- session_state["first_frame"] = None
108
- session_state["all_frames"] = None
109
- session_state["inference_state"] = None
110
- return (
111
- None,
112
- gr.update(open=True),
113
- None,
114
- None,
115
- gr.update(value=None, visible=False),
116
- session_state,
117
- )
118
-
119
-
120
- def clear_points(session_state):
121
- session_state["input_points"] = []
122
- session_state["input_labels"] = []
123
- if session_state["inference_state"]["tracking_has_started"]:
124
- predictor.reset_state(session_state["inference_state"])
125
- return (
126
- session_state["first_frame"],
127
- None,
128
- gr.update(value=None, visible=False),
129
- session_state,
130
- )
131
-
132
-
133
- @spaces.GPU
134
  def preprocess_video_in(video_path, session_state):
135
- if video_path is None:
 
 
 
 
136
  return (
137
  gr.update(open=True), # video_in_drawer
138
  None, # points_map
139
  None, # output_image
140
  gr.update(value=None, visible=False), # output_video
141
- session_state,
 
 
 
 
 
 
 
 
 
142
  )
143
 
144
- # Read the first frame
145
  cap = cv2.VideoCapture(video_path)
146
  if not cap.isOpened():
147
- print("Error: Could not open video.")
 
148
  return (
149
- gr.update(open=True), # video_in_drawer
150
- None, # points_map
151
- None, # output_image
152
- gr.update(value=None, visible=False), # output_video
153
- session_state,
 
 
 
 
 
 
 
 
 
154
  )
155
 
156
- frame_number = 0
157
  first_frame = None
158
  all_frames = []
159
 
@@ -161,180 +153,407 @@ def preprocess_video_in(video_path, session_state):
161
  ret, frame = cap.read()
162
  if not ret:
163
  break
164
-
165
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
166
- frame = np.array(frame)
167
-
168
- # Store the first frame
169
- if frame_number == 0:
170
- first_frame = frame
171
  all_frames.append(frame)
172
-
173
- frame_number += 1
174
 
175
  cap.release()
176
- session_state["first_frame"] = copy.deepcopy(first_frame)
177
- session_state["all_frames"] = all_frames
178
 
179
- session_state["inference_state"] = predictor.init_state(video_path=video_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  session_state["input_points"] = []
181
  session_state["input_labels"] = []
 
 
 
182
 
183
  return [
184
  gr.update(open=False), # video_in_drawer
185
- first_frame, # points_map
186
- None, # output_image
187
- gr.update(value=None, visible=False), # output_video
188
- session_state,
 
 
 
189
  ]
190
 
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  def segment_with_points(
194
  point_type,
195
  session_state,
196
  evt: gr.SelectData,
197
  ):
198
- session_state["input_points"].append(evt.index)
199
- print(f"TRACKING INPUT POINT: {session_state['input_points']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  if point_type == "include":
202
  session_state["input_labels"].append(1)
203
  elif point_type == "exclude":
204
  session_state["input_labels"].append(0)
205
- print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
206
 
207
- # Open the image and get its dimensions
208
- transparent_background = Image.fromarray(session_state["first_frame"]).convert(
209
- "RGBA"
210
- )
211
- w, h = transparent_background.size
212
 
213
- # Define the circle radius as a fraction of the smaller dimension
214
- fraction = 0.01 # You can adjust this value as needed
215
- radius = int(fraction * min(w, h))
216
 
217
- # Create a transparent layer to draw on
218
- transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
219
 
 
220
  for index, track in enumerate(session_state["input_points"]):
 
 
221
  if session_state["input_labels"][index] == 1:
222
- cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
 
223
  else:
224
- cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
225
-
226
- # Convert the transparent layer back to an image
227
- transparent_layer = Image.fromarray(transparent_layer, "RGBA")
228
- selected_point_map = Image.alpha_composite(
229
- transparent_background, transparent_layer
 
 
230
  )
231
 
232
- # Let's add a positive click at (x, y) = (210, 350) to get started
233
  points = np.array(session_state["input_points"], dtype=np.float32)
234
- # for labels, `1` means positive click and `0` means negative click
235
  labels = np.array(session_state["input_labels"], np.int32)
236
- _, _, out_mask_logits = predictor.add_new_points(
237
- inference_state=session_state["inference_state"],
238
- frame_idx=0,
239
- obj_id=OBJ_ID,
240
- points=points,
241
- labels=labels,
242
- )
243
 
244
- mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
245
- first_frame_output = Image.alpha_composite(transparent_background, mask_image)
 
 
 
 
 
 
 
 
 
 
 
246
 
247
- # torch.cuda.empty_cache()
248
- return selected_point_map, first_frame_output, session_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
 
251
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
 
 
 
 
 
 
252
  if random_color:
253
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
254
  else:
255
  cmap = plt.get_cmap("tab10")
256
- cmap_idx = 0 if obj_id is None else obj_id
257
- color = np.array([*cmap(cmap_idx)[:3], 0.6])
258
- h, w = mask.shape[-2:]
259
- mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
260
- mask = (mask * 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  if convert_to_image:
262
- mask = Image.fromarray(mask, "RGBA")
263
- return mask
 
 
264
 
265
 
266
- @spaces.GPU
267
  def propagate_to_all(
268
- video_in,
269
  session_state,
270
  ):
 
 
 
271
  if (
272
- len(session_state["input_points"]) == 0
273
- or video_in is None
274
  or session_state["inference_state"] is None
275
  ):
 
276
  return (
277
- None,
278
  session_state,
279
  )
280
 
281
- # run propagation throughout the video and collect the results in a dict
282
- video_segments = {} # video_segments contains the per-frame segmentation results
283
- print("starting propagate_in_video")
284
- for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
285
- session_state["inference_state"]
286
- ):
287
- video_segments[out_frame_idx] = {
288
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
289
- for i, out_obj_id in enumerate(out_obj_ids)
290
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
- # obtain the segmentation results every few frames
293
- vis_frame_stride = 1
294
 
295
  output_frames = []
296
- for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
297
- transparent_background = Image.fromarray(
298
- session_state["all_frames"][out_frame_idx]
299
- ).convert("RGBA")
300
- out_mask = video_segments[out_frame_idx][OBJ_ID]
301
- mask_image = show_mask(out_mask)
302
- output_frame = Image.alpha_composite(transparent_background, mask_image)
303
- output_frame = np.array(output_frame)
304
- output_frames.append(output_frame)
305
-
306
- # torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  # Create a video clip from the image sequence
309
- original_fps = get_video_fps(video_in)
310
- fps = original_fps # Frames per second
311
- clip = ImageSequenceClip(output_frames, fps=fps)
312
- # Write the result to a file
313
- unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
314
- final_vid_output_path = f"output_video_{unique_id}.mp4"
315
- final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
 
 
 
 
 
316
 
317
- # Write the result to a file
318
- clip.write_videofile(final_vid_output_path, codec="libx264")
 
 
 
 
 
 
 
319
 
320
- return (
321
- gr.update(value=final_vid_output_path),
322
- session_state,
323
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
 
326
- def update_ui():
 
327
  return gr.update(visible=True)
328
 
329
 
330
  with gr.Blocks() as demo:
 
331
  session_state = gr.State(
332
  {
333
- "first_frame": None,
334
- "all_frames": None,
335
- "input_points": [],
336
- "input_labels": [],
337
- "inference_state": None,
 
338
  }
339
  )
340
 
@@ -348,7 +567,7 @@ with gr.Blocks() as demo:
348
  gr.Markdown(description_p)
349
 
350
  with gr.Accordion("Input Video", open=True) as video_in_drawer:
351
- video_in = gr.Video(label="Input Video", format="mp4")
352
 
353
  with gr.Row():
354
  point_type = gr.Radio(
@@ -356,125 +575,161 @@ with gr.Blocks() as demo:
356
  choices=["include", "exclude"],
357
  value="include",
358
  scale=2,
 
359
  )
360
- propagate_btn = gr.Button("Track", scale=1, variant="primary")
361
- clear_points_btn = gr.Button("Clear Points", scale=1)
362
- reset_btn = gr.Button("Reset", scale=1)
 
363
 
 
 
364
  points_map = gr.Image(
365
- label="Frame with Point Prompt", type="numpy", interactive=False
 
 
 
 
 
 
 
366
  )
367
 
368
  with gr.Column():
369
  gr.Markdown("# Try some of the examples below ⬇️")
370
  gr.Examples(
371
  examples=examples,
372
- inputs=[
373
- video_in,
374
- ],
375
  examples_per_page=8,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  )
377
- gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
378
- gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
379
- gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
380
- output_image = gr.Image(label="Reference Mask")
381
 
382
- output_video = gr.Video(visible=False)
 
 
 
 
383
 
384
- # When new video is uploaded
385
  video_in.upload(
386
  fn=preprocess_video_in,
387
- inputs=[
388
- video_in,
389
- session_state,
390
- ],
391
  outputs=[
392
- video_in_drawer, # Accordion to hide uploaded video player
393
- points_map, # Image component where we add new tracking points
394
- output_image,
395
- output_video,
396
- session_state,
 
 
 
397
  ],
398
- queue=False,
399
  )
400
 
 
401
  video_in.change(
402
  fn=preprocess_video_in,
403
- inputs=[
404
- video_in,
405
- session_state,
406
- ],
407
- outputs=[
408
- video_in_drawer, # Accordion to hide uploaded video player
409
- points_map, # Image component where we add new tracking points
410
- output_image,
411
- output_video,
412
- session_state,
413
  ],
414
- queue=False,
415
  )
416
 
417
- # triggered when we click on image to add new points
 
418
  points_map.select(
419
  fn=segment_with_points,
420
  inputs=[
421
- point_type, # "include" or "exclude"
422
- session_state,
423
  ],
424
  outputs=[
425
- points_map, # updated image with points
426
- output_image,
427
- session_state,
428
  ],
429
- queue=False,
430
  )
431
 
432
- # Clear every points clicked and added to the map
433
  clear_points_btn.click(
434
  fn=clear_points,
435
- inputs=session_state,
436
  outputs=[
437
- points_map,
438
- output_image,
439
- output_video,
440
- session_state,
441
  ],
442
- queue=False,
443
  )
444
 
 
445
  reset_btn.click(
446
  fn=reset,
447
- inputs=session_state,
448
  outputs=[
449
- video_in,
450
- video_in_drawer,
451
- points_map,
452
- output_image,
453
- output_video,
454
- session_state,
 
 
 
455
  ],
456
- queue=False,
457
  )
458
 
 
459
  propagate_btn.click(
460
- fn=update_ui,
461
  inputs=[],
462
- outputs=output_video,
463
- queue=False,
464
- ).then(
465
  fn=propagate_to_all,
466
  inputs=[
467
- video_in,
468
- session_state,
469
  ],
470
  outputs=[
471
- output_video,
472
- session_state,
473
  ],
474
- concurrency_limit=10,
475
- queue=False,
 
 
476
  )
477
 
478
 
479
- demo.queue()
 
 
480
  demo.launch()
 
 
10
 
11
  import gradio as gr
12
 
13
+ # Removed GPU-specific environment variable setting
14
+ # os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"
15
+
16
  import tempfile
17
 
18
  import cv2
19
  import matplotlib.pyplot as plt
20
  import numpy as np
21
+ # Removed spaces decorator import for CPU-only demo
22
+ # import spaces
23
  import torch
24
 
25
  from moviepy.editor import ImageSequenceClip
 
38
  </ol>
39
  """
40
 
41
+ # examples - Keep examples, they are input files
42
  examples = [
43
  ["examples/01_dog.mp4"],
44
  ["examples/02_cups.mp4"],
 
73
 
74
  OBJ_ID = 0
75
 
 
76
  sam2_checkpoint = "checkpoints/edgetam.pt"
77
  model_cfg = "edgetam.yaml"
78
+ # Ensure predictor is explicitly built for CPU
79
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
80
+ predictor.to("cpu") # Explicitly move to CPU, though device="cpu" should handle it
81
+ print("predictor loaded on CPU")
82
+
83
+ # Removed autocast block for maximum CPU compatibility
84
+ # torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()
85
 
86
+ # Removed commented-out GPU-specific code
87
+ # if torch.cuda.get_device_properties(0).major >= 8: ...
 
 
 
 
88
 
89
 
90
  def get_video_fps(video_path):
91
+ """Gets the frames per second of a video file."""
92
+ if video_path is None or not os.path.exists(video_path):
93
+ print(f"Warning: Video file not found at {video_path}")
94
+ return None
95
  cap = cv2.VideoCapture(video_path)
 
96
  if not cap.isOpened():
97
+ print(f"Error: Could not open video file {video_path}.")
98
  return None
 
 
99
  fps = cap.get(cv2.CAP_PROP_FPS)
100
+ cap.release()
101
  return fps
102
 
103
+ # Removed @spaces.GPU decorator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def preprocess_video_in(video_path, session_state):
105
+ """Loads video frames and initializes the predictor state."""
106
+ print(f"Processing video: {video_path}")
107
+ if video_path is None or not os.path.exists(video_path):
108
+ print("No video path provided or file not found.")
109
+ # Reset state and UI elements if input is invalid
110
  return (
111
  gr.update(open=True), # video_in_drawer
112
  None, # points_map
113
  None, # output_image
114
  gr.update(value=None, visible=False), # output_video
115
+ gr.update(interactive=False), # propagate_btn
116
+ gr.update(interactive=False), # clear_points_btn
117
+ gr.update(interactive=False), # reset_btn
118
+ { # Reset session state
119
+ "first_frame": None,
120
+ "all_frames": None,
121
+ "input_points": [],
122
+ "input_labels": [],
123
+ "inference_state": None,
124
+ }
125
  )
126
 
127
+ # Read the first frame and all frames
128
  cap = cv2.VideoCapture(video_path)
129
  if not cap.isOpened():
130
+ print(f"Error: Could not open video file {video_path}.")
131
+ # Reset state and UI elements on error
132
  return (
133
+ gr.update(open=True),
134
+ None,
135
+ None,
136
+ gr.update(value=None, visible=False),
137
+ gr.update(interactive=False), # propagate_btn
138
+ gr.update(interactive=False), # clear_points_btn
139
+ gr.update(interactive=False), # reset_btn
140
+ { # Reset session state
141
+ "first_frame": None,
142
+ "all_frames": None,
143
+ "input_points": [],
144
+ "input_labels": [],
145
+ "inference_state": None,
146
+ }
147
  )
148
 
 
149
  first_frame = None
150
  all_frames = []
151
 
 
153
  ret, frame = cap.read()
154
  if not ret:
155
  break
156
+ # Convert BGR to RGB
157
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
 
 
 
 
158
  all_frames.append(frame)
159
+ if first_frame is None:
160
+ first_frame = frame # Store the first frame
161
 
162
  cap.release()
 
 
163
 
164
+ if not all_frames:
165
+ print(f"Error: No frames read from video file {video_path}.")
166
+ # Reset state and UI elements if no frames are read
167
+ return (
168
+ gr.update(open=True),
169
+ None,
170
+ None,
171
+ gr.update(value=None, visible=False),
172
+ gr.update(interactive=False), # propagate_btn
173
+ gr.update(interactive=False), # clear_points_btn
174
+ gr.update(interactive=False), # reset_btn
175
+ { # Reset session state
176
+ "first_frame": None,
177
+ "all_frames": None,
178
+ "input_points": [],
179
+ "input_labels": [],
180
+ "inference_state": None,
181
+ }
182
+ )
183
+
184
+
185
+ session_state["first_frame"] = copy.deepcopy(first_frame) # Store a copy
186
+ session_state["all_frames"] = all_frames
187
  session_state["input_points"] = []
188
  session_state["input_labels"] = []
189
+ # Initialize state explicitly for CPU
190
+ session_state["inference_state"] = predictor.init_state(video_path=video_path, device="cpu")
191
+ print("Video loaded and predictor state initialized.")
192
 
193
  return [
194
  gr.update(open=False), # video_in_drawer
195
+ first_frame, # points_map (shows first frame)
196
+ None, # output_image (cleared initially)
197
+ gr.update(value=None, visible=False), # output_video (hidden initially)
198
+ gr.update(interactive=True), # Enable buttons
199
+ gr.update(interactive=True), # Enable buttons
200
+ gr.update(interactive=True), # Enable buttons
201
+ session_state, # Updated state
202
  ]
203
 
204
 
205
+ def reset(session_state):
206
+ """Resets the UI and session state."""
207
+ print("Resetting demo.")
208
+ # Clear points and labels
209
+ session_state["input_points"] = []
210
+ session_state["input_labels"] = []
211
+ # Reset the predictor state if it exists
212
+ if session_state["inference_state"] is not None:
213
+ predictor.reset_state(session_state["inference_state"])
214
+ # After reset, we also discard the state object as a new video might be loaded
215
+ session_state["inference_state"] = None
216
+ # Clear frames
217
+ session_state["first_frame"] = None
218
+ session_state["all_frames"] = None
219
+
220
+ # Update UI elements to their initial state
221
+ return (
222
+ None, # video_in
223
+ gr.update(open=True), # video_in_drawer open
224
+ None, # points_map cleared
225
+ None, # output_image cleared
226
+ gr.update(value=None, visible=False), # output_video hidden
227
+ gr.update(interactive=False), # Disable buttons
228
+ gr.update(interactive=False), # Disable buttons
229
+ gr.update(interactive=False), # Disable buttons
230
+ session_state, # Updated session state
231
+ )
232
+
233
+
234
+ def clear_points(session_state):
235
+ """Clears selected points and resets segmentation on the first frame."""
236
+ print("Clearing points.")
237
+ # Clear points and labels lists
238
+ session_state["input_points"] = []
239
+ session_state["input_labels"] = []
240
+
241
+ # If inference state exists, reset it. This clears internal masks/features
242
+ # but keeps the video context initialized by preprocess_video_in.
243
+ if session_state["inference_state"] is not None:
244
+ predictor.reset_state(session_state["inference_state"])
245
+ # After resetting the state, we need to re-initialize it to be ready for new points.
246
+ # Pass the original video path stored in the state.
247
+ if "video_path" in session_state["inference_state"] and session_state["inference_state"]["video_path"] is not None:
248
+ session_state["inference_state"] = predictor.init_state(video_path=session_state["inference_state"]["video_path"], device="cpu")
249
+ else:
250
+ # This case should ideally not happen if preprocess_video_in ran correctly
251
+ print("Warning: Could not re-initialize state after clear_points (video_path missing).")
252
+ session_state["inference_state"] = None
253
+
254
 
255
+ # Re-render the points_map with no points drawn (just the first frame)
256
+ # Re-render the output_image with no mask (just the first frame)
257
+ first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None
258
+
259
+ return (
260
+ first_frame_img, # points_map shows original first frame
261
+ None, # output_image cleared
262
+ gr.update(value=None, visible=False), # Hide output video
263
+ session_state, # Updated session state
264
+ )
265
+
266
+
267
+ # Removed @spaces.GPU decorator
268
  def segment_with_points(
269
  point_type,
270
  session_state,
271
  evt: gr.SelectData,
272
  ):
273
+ """Adds a point prompt and performs segmentation on the first frame."""
274
+ # Ensure we have a valid first frame and inference state
275
+ if session_state["first_frame"] is None or session_state["inference_state"] is None:
276
+ print("Error: Cannot segment. No video loaded or inference state missing.")
277
+ return (
278
+ session_state["first_frame"], # points_map remains unchanged
279
+ None, # output_image remains unchanged or cleared
280
+ session_state,
281
+ )
282
+
283
+ # evt.index gives the (x, y) coordinates of the click
284
+ click_coords = evt.index
285
+ print(f"Clicked at: {click_coords} ({point_type})")
286
+
287
+ session_state["input_points"].append(click_coords)
288
 
289
  if point_type == "include":
290
  session_state["input_labels"].append(1)
291
  elif point_type == "exclude":
292
  session_state["input_labels"].append(0)
 
293
 
294
+ # Get the first frame as a PIL image for drawing
295
+ first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA")
296
+ w, h = first_frame_pil.size
 
 
297
 
298
+ # Define the circle radius
299
+ fraction = 0.01
300
+ radius = max(2, int(fraction * min(w, h))) # Ensure minimum radius of 2
301
 
302
+ # Create a transparent layer to draw points
303
+ transparent_layer_points = np.zeros((h, w, 4), dtype=np.uint8)
304
 
305
+ # Draw points on the transparent layer
306
  for index, track in enumerate(session_state["input_points"]):
307
+ # Ensure coordinates are integers for cv2.circle
308
+ point_coords = (int(track[0]), int(track[1]))
309
  if session_state["input_labels"][index] == 1:
310
+ # Green circle for include
311
+ cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1)
312
  else:
313
+ # Red circle for exclude
314
+ cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1)
315
+
316
+ # Convert the transparent layer back to an image and composite onto the first frame
317
+ transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
318
+ # Combine the first frame image with the points layer for the points_map output
319
+ selected_point_map_img = Image.alpha_composite(
320
+ first_frame_pil.copy(), transparent_layer_points_pil
321
  )
322
 
323
+ # Prepare points and labels as tensors on CPU for the predictor
324
  points = np.array(session_state["input_points"], dtype=np.float32)
 
325
  labels = np.array(session_state["input_labels"], np.int32)
 
 
 
 
 
 
 
326
 
327
+ points_tensor = torch.tensor(points, dtype=torch.float32, device="cpu").unsqueeze(0) # Add batch dim
328
+ labels_tensor = torch.tensor(labels, dtype=torch.int32, device="cpu").unsqueeze(0) # Add batch dim
329
+
330
+ # Add new points to the predictor's state and get the mask for the first frame
331
+ # This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
332
+ try:
333
+ _, _, out_mask_logits = predictor.add_new_points(
334
+ inference_state=session_state["inference_state"],
335
+ frame_idx=0, # Always segment on the first frame initially
336
+ obj_id=OBJ_ID,
337
+ points=points_tensor,
338
+ labels=labels_tensor,
339
+ )
340
 
341
+ # Process logits: detach from graph, move to CPU, apply threshold
342
+ # out_mask_logits is [batch_size, H, W] (batch_size=1 here)
343
+ mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Apply threshold and get the single mask tensor [H, W]
344
+ mask_numpy = mask_tensor.numpy() # Convert to numpy
345
+
346
+ # Get the mask image (RGBA)
347
+ mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) # show_mask returns RGBA PIL Image
348
+
349
+ # Composite the mask onto the first frame for the output_image
350
+ first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)
351
+
352
+ except Exception as e:
353
+ print(f"Error during segmentation on first frame: {e}")
354
+ # On error, return the points_map but clear the output_image
355
+ first_frame_output_img = None
356
+
357
+
358
+ return selected_point_map_img, first_frame_output_img, session_state
359
 
360
 
361
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
362
+ """Helper function to visualize a mask."""
363
+ # Ensure mask is a numpy array (and boolean)
364
+ if isinstance(mask, torch.Tensor):
365
+ mask = mask.detach().cpu().numpy() # Ensure it's on CPU and converted to numpy
366
+ mask = mask.astype(bool) # Ensure mask is boolean
367
+
368
  if random_color:
369
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) # RGBA with 0.6 alpha
370
  else:
371
  cmap = plt.get_cmap("tab10")
372
+ cmap_idx = 0 if obj_id is None else obj_id % 10 # Use modulo 10 for tab10 colors
373
+ color = np.array([*cmap(cmap_idx)[:3], 0.6]) # RGBA with 0.6 alpha
374
+
375
+ # Ensure mask has H, W dimensions
376
+ if mask.ndim == 3:
377
+ mask = mask.squeeze() # Remove singular dimensions
378
+ if mask.ndim != 2:
379
+ print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
380
+ # Create an empty transparent image if mask shape is unexpected
381
+ if convert_to_image:
382
+ return Image.fromarray(np.zeros((*mask.shape[:2], 4), dtype=np.uint8), "RGBA")
383
+ else:
384
+ return np.zeros((*mask.shape[:2], 4), dtype=np.uint8)
385
+
386
+
387
+ h, w = mask.shape
388
+ # Create an RGBA image from the mask and color
389
+ # Apply color where mask is True
390
+ # Need to reshape color to be broadcastable [1, 1, 4]
391
+ colored_mask = np.zeros((h, w, 4), dtype=np.float32) # Start with fully transparent black
392
+ colored_mask[mask] = color # Apply color where mask is True
393
+
394
+ # Convert to uint8 [0-255]
395
+ colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)
396
+
397
  if convert_to_image:
398
+ mask_img = Image.fromarray(colored_mask_uint8, "RGBA")
399
+ return mask_img
400
+ else:
401
+ return colored_mask_uint8
402
 
403
 
404
+ # Removed @spaces.GPU decorator
405
  def propagate_to_all(
406
+ video_in, # Keep video_in path to potentially get FPS again if needed
407
  session_state,
408
  ):
409
+ """Runs mask propagation through the video and generates the output video."""
410
+ print("Starting propagation...")
411
+ # Ensure state is ready
412
  if (
413
+ len(session_state["input_points"]) == 0 # Need at least one point
414
+ or session_state["all_frames"] is None
415
  or session_state["inference_state"] is None
416
  ):
417
+ print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
418
  return (
419
+ gr.update(value=None, visible=False), # Hide output video on error
420
  session_state,
421
  )
422
 
423
+ # run propagation throughout the video and collect the results
424
+ # The generator yields (frame_idx, obj_ids, mask_logits)
425
+ video_segments = {}
426
+ try:
427
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
428
+ session_state["inference_state"]
429
+ ):
430
+ # Process logits: detach from graph, move to CPU, convert to numpy boolean mask
431
+ # Ensure tensor is on CPU before converting to numpy
432
+ video_segments[out_frame_idx] = {
433
+ out_obj_id: (out_mask_logits[i].detach().cpu() > 0.0).numpy()
434
+ for i, out_obj_id in enumerate(out_obj_ids)
435
+ }
436
+ # Optional: print progress
437
+ # print(f"Processed frame {out_frame_idx+1}/{len(session_state['all_frames'])}")
438
+
439
+ print("Propagation finished.")
440
+ except Exception as e:
441
+ print(f"Error during propagation: {e}")
442
+ return (
443
+ gr.update(value=None, visible=False), # Hide output video on error
444
+ session_state,
445
+ )
446
 
 
 
447
 
448
  output_frames = []
449
+ # Iterate through all original frames to generate output video
450
+ for out_frame_idx in range(len(session_state["all_frames"])):
451
+ original_frame_rgb = session_state["all_frames"][out_frame_idx]
452
+ # Convert original frame to RGBA for compositing
453
+ transparent_background = Image.fromarray(original_frame_rgb).convert("RGBA")
454
+
455
+ # Check if we have a mask for this frame and object ID
456
+ if out_frame_idx in video_segments and OBJ_ID in video_segments[out_frame_idx]:
457
+ current_mask_numpy = video_segments[out_frame_idx][OBJ_ID]
458
+ # Get the mask image (RGBA)
459
+ mask_image_pil = show_mask(current_mask_numpy, obj_id=OBJ_ID)
460
+ # Composite the mask onto the frame
461
+ output_frame_img_rgba = Image.alpha_composite(transparent_background, mask_image_pil)
462
+ # Convert back to numpy RGB (moviepy needs RGB or RGBA)
463
+ output_frame_np = np.array(output_frame_img_rgba.convert("RGB"))
464
+ else:
465
+ # If no mask for this frame/object, just use the original frame (converted to RGB)
466
+ # Note: all_frames are already RGB numpy arrays, so just use them directly.
467
+ # print(f"Warning: No mask found for frame {out_frame_idx} and object {OBJ_ID}. Using original frame.")
468
+ output_frame_np = original_frame_rgb # Already RGB numpy array
469
+
470
+ output_frames.append(output_frame_np)
471
+
472
+
473
+ # Define output path in a temporary directory
474
+ unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f") # Use microseconds for more uniqueness
475
+ final_vid_filename = f"output_video_{unique_id}.mp4"
476
+ # Use os.path.join for cross-platform compatibility
477
+ final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
478
+ print(f"Output video path: {final_vid_output_path}")
479
+
480
 
481
  # Create a video clip from the image sequence
482
+ # Get original FPS or default
483
+ original_fps = get_video_fps(video_in) # Re-get FPS from the input file path
484
+ fps = original_fps if original_fps is not None and original_fps > 0 else 30 # Default to 30 if detection fails or is zero
485
+ print(f"Creating output video with FPS: {fps}")
486
+
487
+ # Check if there are frames to process
488
+ if not output_frames:
489
+ print("No output frames generated.")
490
+ return (
491
+ gr.update(value=None, visible=False), # Hide output video
492
+ session_state,
493
+ )
494
 
495
+ # Create ImageSequenceClip from the list of numpy arrays
496
+ try:
497
+ clip = ImageSequenceClip(output_frames, fps=fps)
498
+ except Exception as e:
499
+ print(f"Error creating ImageSequenceClip: {e}")
500
+ return (
501
+ gr.update(value=None, visible=False), # Hide output video on error
502
+ session_state,
503
+ )
504
 
505
+
506
+ # Write the result to a file. Use 'libx264' codec for broad compatibility.
507
+ # `preset` and `threads` for CPU optimization.
508
+ # `logger=None` prevents moviepy from printing progress to stdout/stderr, which can clutter the Gradio logs.
509
+ try:
510
+ print(f"Writing video file with codec='libx264', fps={fps}, preset='medium', threads='auto'")
511
+ clip.write_videofile(
512
+ final_vid_output_path,
513
+ codec="libx264",
514
+ fps=fps, # Ensure correct FPS is used during writing
515
+ preset="medium", # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed
516
+ threads="auto", # CPU optimization: Use multiple cores
517
+ logger=None # Suppress moviepy output
518
+ )
519
+ print("Video writing complete.")
520
+ # Return the path and make the video player visible
521
+ return (
522
+ gr.update(value=final_vid_output_path, visible=True),
523
+ session_state,
524
+ )
525
+ except Exception as e:
526
+ print(f"Error writing video file: {e}")
527
+ # Clean up potentially created partial file
528
+ if os.path.exists(final_vid_output_path):
529
+ try:
530
+ os.remove(final_vid_output_path)
531
+ print(f"Removed partial video file: {final_vid_output_path}")
532
+ except Exception as clean_e:
533
+ print(f"Error removing partial file: {clean_e}")
534
+
535
+ # Return None if writing fails
536
+ return (
537
+ gr.update(value=None, visible=False),
538
+ session_state,
539
+ )
540
 
541
 
542
+ def update_output_video_visibility():
543
+ """Simply returns a Gradio update to make the output video visible."""
544
  return gr.update(visible=True)
545
 
546
 
547
  with gr.Blocks() as demo:
548
+ # Session state dictionary to hold video frames, points, labels, and predictor state
549
  session_state = gr.State(
550
  {
551
+ "first_frame": None, # numpy array (RGB)
552
+ "all_frames": None, # list of numpy arrays (RGB)
553
+ "input_points": [], # list of (x, y) tuples/lists
554
+ "input_labels": [], # list of 1s and 0s
555
+ "inference_state": None, # EdgeTAM predictor state object
556
+ "video_path": None, # Store the input video path
557
  }
558
  )
559
 
 
567
  gr.Markdown(description_p)
568
 
569
  with gr.Accordion("Input Video", open=True) as video_in_drawer:
570
+ video_in = gr.Video(label="Input Video", format="mp4") # Will hold the video file path
571
 
572
  with gr.Row():
573
  point_type = gr.Radio(
 
575
  choices=["include", "exclude"],
576
  value="include",
577
  scale=2,
578
+ interactive=True, # Make interactive
579
  )
580
+ # Buttons are initially disabled until a video is loaded
581
+ propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False)
582
+ clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False)
583
+ reset_btn = gr.Button("Reset", scale=1, interactive=False)
584
 
585
+ # points_map is where users click to add points. Needs to be interactive.
586
+ # Shows the first frame with points drawn on it.
587
  points_map = gr.Image(
588
+ label="Frame with Point Prompt",
589
+ type="numpy",
590
+ interactive=True, # Make interactive to capture clicks
591
+ height=400, # Set a fixed height for better UI
592
+ width="auto", # Let width adjust
593
+ show_share_button=False,
594
+ show_download_button=False,
595
+ # show_label=False # Can hide label if space is tight
596
  )
597
 
598
  with gr.Column():
599
  gr.Markdown("# Try some of the examples below ⬇️")
600
  gr.Examples(
601
  examples=examples,
602
+ inputs=[video_in],
 
 
603
  examples_per_page=8,
604
+ cache_examples=False, # Do not cache processed examples, as state is involved
605
+ )
606
+ # Add padding/space
607
+ # gr.Markdown("<br>")
608
+
609
+ # output_image shows the segmentation mask prediction on the *first* frame
610
+ output_image = gr.Image(
611
+ label="Reference Mask (First Frame)",
612
+ type="numpy",
613
+ interactive=False, # Not interactive, just displays the mask
614
+ height=400, # Match height of points_map
615
+ width="auto", # Let width adjust
616
+ show_share_button=False,
617
+ show_download_button=False,
618
+ # show_label=False # Can hide label
619
  )
 
 
 
 
620
 
621
+ # output_video shows the final tracking result
622
+ output_video = gr.Video(visible=False, label="Tracking Result")
623
+
624
+
625
+ # --- Event Handlers ---
626
 
627
+ # When a new video file is uploaded via the file browser
628
  video_in.upload(
629
  fn=preprocess_video_in,
630
+ inputs=[video_in, session_state],
 
 
 
631
  outputs=[
632
+ video_in_drawer, # Close accordion
633
+ points_map, # Show first frame in points_map
634
+ output_image, # Clear output image
635
+ output_video, # Hide output video
636
+ propagate_btn, # Enable Track button
637
+ clear_points_btn,# Enable Clear Points button
638
+ reset_btn, # Enable Reset button
639
+ session_state, # Update session state
640
  ],
641
+ queue=False, # Process immediately
642
  )
643
 
644
+ # When an example video is selected (change event)
645
  video_in.change(
646
  fn=preprocess_video_in,
647
+ inputs=[video_in, session_state],
648
+ outputs=[
649
+ video_in_drawer, # Close accordion
650
+ points_map, # Show first frame in points_map
651
+ output_image, # Clear output image
652
+ output_video, # Hide output video
653
+ propagate_btn, # Enable Track button
654
+ clear_points_btn,# Enable Clear Points button
655
+ reset_btn, # Enable Reset button
656
+ session_state, # Update session state
657
  ],
658
+ queue=False, # Process immediately
659
  )
660
 
661
+
662
+ # Triggered when a user clicks on the points_map image
663
  points_map.select(
664
  fn=segment_with_points,
665
  inputs=[
666
+ point_type, # "include" or "exclude" radio button value
667
+ session_state, # Pass session state
668
  ],
669
  outputs=[
670
+ points_map, # Updated image with points drawn
671
+ output_image, # Updated image with first frame segmentation mask
672
+ session_state, # Updated session state (points/labels added)
673
  ],
674
+ queue=False, # Process clicks immediately
675
  )
676
 
677
+ # Button to clear all selected points and reset the first frame mask
678
  clear_points_btn.click(
679
  fn=clear_points,
680
+ inputs=[session_state], # Pass session state
681
  outputs=[
682
+ points_map, # points_map shows original first frame without points
683
+ output_image, # output_image cleared (or shows original first frame without mask)
684
+ output_video, # Hide output video
685
+ session_state, # Updated session state (points/labels cleared, inference state reset)
686
  ],
687
+ queue=False, # Process immediately
688
  )
689
 
690
+ # Button to reset the entire demo state and UI
691
  reset_btn.click(
692
  fn=reset,
693
+ inputs=[session_state], # Pass session state
694
  outputs=[
695
+ video_in, # Clear video input
696
+ video_in_drawer, # Open video accordion
697
+ points_map, # Clear points_map
698
+ output_image, # Clear output_image
699
+ output_video, # Hide output_video
700
+ propagate_btn, # Disable buttons
701
+ clear_points_btn,# Disable buttons
702
+ reset_btn, # Disable buttons
703
+ session_state, # Reset session state
704
  ],
705
+ queue=False, # Process immediately
706
  )
707
 
708
+ # Button to start mask propagation through the video
709
  propagate_btn.click(
710
+ fn=update_output_video_visibility, # First, make the output video player visible
711
  inputs=[],
712
+ outputs=[output_video],
713
+ queue=False, # Process this UI update immediately
714
+ ).then( # Then, run the propagation function
715
  fn=propagate_to_all,
716
  inputs=[
717
+ video_in, # Get the input video path
718
+ session_state, # Pass session state (contains frames, points, inference_state)
719
  ],
720
  outputs=[
721
+ output_video, # Update output video player with result
722
+ session_state, # Update session state (currently, propagate doesn't modify state much, but good practice)
723
  ],
724
+ # CPU Optimization: Limit concurrency to 1 to prevent resource exhaustion.
725
+ # Queue=True ensures requests wait if another is processing.
726
+ concurrency_limit=1,
727
+ queue=True,
728
  )
729
 
730
 
731
+ # Launch the Gradio demo
732
+ demo.queue() # Enable queuing for sequential processing under concurrency limits
733
+ print("Gradio demo starting...")
734
  demo.launch()
735
+ print("Gradio demo launched.")