
bla committed
Commit b950bc5 · verified · Parent: 0b34400

Update app.py

Files changed (1): app.py (+206, -75)
app.py CHANGED
@@ -46,11 +46,30 @@ examples = [
 
 OBJ_ID = 0
 
-# Initialize model on CPU
+# Initialize model on CPU - add error handling for file paths
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
-predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-print("predictor loaded on CPU")
+
+# Check if model files exist
+def check_file_exists(filepath):
+    import os
+    exists = os.path.exists(filepath)
+    if not exists:
+        print(f"WARNING: File not found: {filepath}")
+    return exists
+
+# Verify files exist
+model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
+try:
+    # Load model with more careful error handling
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+    print("predictor loaded on CPU")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    import traceback
+    traceback.print_exc()
+    # Still create a predictor variable to avoid NameError
+    predictor = None
 
 # Function to get video frame rate
 def get_video_fps(video_path):
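The new startup block guards model loading so the Space can print a useful error instead of crashing at import time. A minimal standalone sketch of the same pattern, assuming the checkpoint/config paths and `build_sam2_video_predictor` call from the diff; the `sam2.build_sam` import path and the `load_predictor` wrapper are illustrative assumptions:

    import os

    def load_predictor(checkpoint="checkpoints/edgetam.pt", cfg="edgetam.yaml"):
        # Warn early about missing files so a later load failure is explainable
        for path in (checkpoint, cfg):
            if not os.path.exists(path):
                print(f"WARNING: File not found: {path}")
        try:
            # Import path assumed from the SAM 2 codebase layout
            from sam2.build_sam import build_sam2_video_predictor
            return build_sam2_video_predictor(cfg, checkpoint, device="cpu")
        except Exception as e:
            print(f"Error loading model: {e}")
            return None  # callers must check for None before use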
@@ -196,10 +215,9 @@ def segment_with_points(
     print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
 
     # Open the image and get its dimensions
-    transparent_background = Image.fromarray(session_state["first_frame"]).convert(
-        "RGBA"
-    )
-    w, h = transparent_background.size
+    first_frame = session_state["first_frame"]
+    h, w = first_frame.shape[:2]
+    transparent_background = Image.fromarray(first_frame).convert("RGBA")
 
     # Define the circle radius as a fraction of the smaller dimension
     fraction = 0.01 # You can adjust this value as needed
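Worth noting about this hunk: NumPy arrays index as (height, width) while PIL's `Image.size` is (width, height), which is why the new code reads `h, w` from the array instead of unpacking `transparent_background.size`. A tiny sanity check:

    import numpy as np
    from PIL import Image

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # NumPy shape: (height, width, channels)
    h, w = frame.shape[:2]                           # h=480, w=640
    assert Image.fromarray(frame).size == (w, h)     # PIL size: (width, height)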
@@ -225,17 +243,38 @@ def segment_with_points(
     # for labels, `1` means positive click and `0` means negative click
     labels = np.array(session_state["input_labels"], np.int32)
 
-    # For CPU optimization, we'll process with smaller batch size
-    _, _, out_mask_logits = predictor.add_new_points(
-        inference_state=session_state["inference_state"],
-        frame_idx=0,
-        obj_id=OBJ_ID,
-        points=points,
-        labels=labels,
-    )
-
-    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
-    first_frame_output = Image.alpha_composite(transparent_background, mask_image)
+    try:
+        # For CPU optimization, we'll process with smaller batch size
+        _, _, out_mask_logits = predictor.add_new_points(
+            inference_state=session_state["inference_state"],
+            frame_idx=0,
+            obj_id=OBJ_ID,
+            points=points,
+            labels=labels,
+        )
+
+        # Create the mask
+        mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
+
+        # Ensure the mask has the same size as the frame
+        if mask_array.shape[:2] != (h, w):
+            mask_array = cv2.resize(
+                mask_array.astype(np.uint8),
+                (w, h),
+                interpolation=cv2.INTER_NEAREST
+            ).astype(bool)
+
+        mask_image = show_mask(mask_array)
+
+        # Make sure mask_image has the same size as the background
+        if mask_image.size != transparent_background.size:
+            mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
+
+        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
+    except Exception as e:
+        print(f"Error in segmentation: {e}")
+        # Return just the points as fallback
+        first_frame_output = selected_point_map
 
     return selected_point_map, first_frame_output, session_state
 
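The resize guard added in this hunk can be factored into a helper. A sketch assuming a 2-D boolean mask (if the logits carry a leading singleton dimension, squeeze it first; `match_mask_to_frame` is a hypothetical name):

    import cv2
    import numpy as np

    def match_mask_to_frame(mask, h, w):
        # cv2.resize needs a numeric dtype and takes (width, height) order
        if mask.shape[:2] != (h, w):
            mask = cv2.resize(
                mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST
            ).astype(bool)
        return mask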
@@ -247,12 +286,36 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     cmap = plt.get_cmap("tab10")
     cmap_idx = 0 if obj_id is None else obj_id
     color = np.array([*cmap(cmap_idx)[:3], 0.6])
-    h, w = mask.shape[-2:]
-    mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-    mask = (mask * 255).astype(np.uint8)
+
+    # Handle different mask shapes properly
+    if len(mask.shape) == 2:
+        h, w = mask.shape
+    else:
+        h, w = mask.shape[-2:]
+
+    # Ensure correct reshaping based on mask dimensions
+    mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+    mask_rgba = (mask_reshaped * 255).astype(np.uint8)
+
     if convert_to_image:
-        mask = Image.fromarray(mask, "RGBA")
-    return mask
+        try:
+            # Ensure the mask has correct RGBA shape (h, w, 4)
+            if mask_rgba.shape[2] != 4:
+                # If not RGBA, create a proper RGBA array
+                proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
+                # Copy available channels
+                proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
+                mask_rgba = proper_mask
+
+            # Create the PIL image
+            return Image.fromarray(mask_rgba, "RGBA")
+        except Exception as e:
+            print(f"Error converting mask to image: {e}")
+            # Fallback: create a blank transparent image of correct size
+            blank = np.zeros((h, w, 4), dtype=np.uint8)
+            return Image.fromarray(blank, "RGBA")
+
+    return mask_rgba
 
 
 def propagate_to_all(
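Condensed, the happy path of `show_mask` plus compositing looks like this. A sketch under the same assumptions as the diff, a 2-D boolean mask and a frame of matching size; `overlay` is a hypothetical helper:

    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image

    def overlay(frame, mask, obj_id=0):
        # Tint the mask with a tab10 colour at 60% alpha, then composite onto the frame
        color = np.array([*plt.get_cmap("tab10")(obj_id)[:3], 0.6])
        h, w = mask.shape
        rgba = (mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255).astype(np.uint8)
        background = Image.fromarray(frame).convert("RGBA")
        return Image.alpha_composite(background, Image.fromarray(rgba, "RGBA"))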
@@ -270,62 +333,130 @@ def propagate_to_all(
     )
 
     # For CPU optimization: process in smaller batches
-    chunk_size = 5 # Process 5 frames at a time to avoid memory issues
+    chunk_size = 3 # Process 3 frames at a time to avoid memory issues on CPU
 
-    # run propagation throughout the video and collect the results in a dict
-    video_segments = {} # video_segments contains the per-frame segmentation results
-    print("starting propagate_in_video on CPU")
-
-    # Get the frames in chunks for CPU memory optimization
-    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-        session_state["inference_state"]
-    ):
-        video_segments[out_frame_idx] = {
-            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-            for i, out_obj_id in enumerate(out_obj_ids)
-        }
+    try:
+        # run propagation throughout the video and collect the results in a dict
+        video_segments = {} # video_segments contains the per-frame segmentation results
+        print("starting propagate_in_video on CPU")
+
+        # Get the frames in chunks for CPU memory optimization
+        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+            session_state["inference_state"]
+        ):
+            try:
+                # Store the masks for each object ID
+                video_segments[out_frame_idx] = {
+                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                    for i, out_obj_id in enumerate(out_obj_ids)
+                }
+
+                print(f"Processed frame {out_frame_idx}")
+
+                # Release memory periodically
+                if out_frame_idx % chunk_size == 0:
+                    # Explicitly clear any tensors
+                    del out_mask_logits
+                    import gc
+                    gc.collect()
+            except Exception as e:
+                print(f"Error processing frame {out_frame_idx}: {e}")
+                continue
+
+        # For CPU optimization: increase stride to reduce processing
+        # Create a more aggressive stride to limit to fewer frames in output
+        total_frames = len(video_segments)
+        print(f"Total frames processed: {total_frames}")
+
+        # Limit to max 50 frames for CPU processing
+        max_output_frames = 50
+        vis_frame_stride = max(1, total_frames // max_output_frames)
+
+        # Get dimensions of the frames
+        first_frame = session_state["all_frames"][0]
+        h, w = first_frame.shape[:2]
+
+        output_frames = []
+        for out_frame_idx in range(0, total_frames, vis_frame_stride):
+            if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
+                continue
+
+            try:
+                frame = session_state["all_frames"][out_frame_idx]
+                transparent_background = Image.fromarray(frame).convert("RGBA")
+
+                # Get the mask and ensure it's the right size
+                out_mask = video_segments[out_frame_idx][OBJ_ID]
+
+                # Resize mask if dimensions don't match
+                if out_mask.shape[:2] != (h, w):
+                    out_mask = cv2.resize(
+                        out_mask.astype(np.uint8),
+                        (w, h),
+                        interpolation=cv2.INTER_NEAREST
+                    ).astype(bool)
+
+                mask_image = show_mask(out_mask)
+
+                # Make sure mask has same dimensions as background
+                if mask_image.size != transparent_background.size:
+                    mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
+
+                output_frame = Image.alpha_composite(transparent_background, mask_image)
+                output_frame = np.array(output_frame)
+                output_frames.append(output_frame)
+
+                # Clear memory periodically
+                if len(output_frames) % 10 == 0:
+                    import gc
+                    gc.collect()
+
+            except Exception as e:
+                print(f"Error creating output frame {out_frame_idx}: {e}")
+                continue
+
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        fps = original_fps
 
-        # Free up memory after processing each frame
-        if len(video_segments) % chunk_size == 0:
-            torch.cuda.empty_cache() if torch.cuda.is_available() else None
+        # For CPU optimization - lower FPS if original is high
+        if fps > 15:
+            fps = 15 # Lower fps for CPU processing
+
+        print(f"Creating video with {len(output_frames)} frames at {fps} FPS")
+        clip = ImageSequenceClip(output_frames, fps=fps)
+
+        # Write the result to a file - use lower quality for CPU
+        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+        final_vid_output_path = f"output_video_{unique_id}.mp4"
+        final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
+
+        # Lower bitrate for CPU processing
+        clip.write_videofile(
+            final_vid_output_path,
+            codec="libx264",
+            bitrate="800k",
+            threads=2, # Use fewer threads for CPU
+            logger=None # Disable logger to reduce console output
+        )
+
+        # Free memory
+        del video_segments
+        del output_frames
+        import gc
+        gc.collect()
 
-    # obtain the segmentation results every few frames
-    # For CPU optimization: increase stride to reduce processing
-    vis_frame_stride = max(1, len(video_segments) // 100) # Limit to ~100 frames in output
-
-    output_frames = []
-    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-        transparent_background = Image.fromarray(
-            session_state["all_frames"][out_frame_idx]
-        ).convert("RGBA")
-        out_mask = video_segments[out_frame_idx][OBJ_ID]
-        mask_image = show_mask(out_mask)
-        output_frame = Image.alpha_composite(transparent_background, mask_image)
-        output_frame = np.array(output_frame)
-        output_frames.append(output_frame)
-
-    # Create a video clip from the image sequence
-    original_fps = get_video_fps(video_in)
-    fps = original_fps # Frames per second
-
-    # For CPU optimization - lower FPS if original is high
-    if fps > 24:
-        fps = 24
-
-    clip = ImageSequenceClip(output_frames, fps=fps)
+        return (
+            gr.update(value=final_vid_output_path, visible=True),
+            session_state,
+        )
 
-    # Write the result to a file - use lower quality for CPU
-    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-    final_vid_output_path = f"output_video_{unique_id}.mp4"
-    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
-
-    # Lower bitrate for CPU processing
-    clip.write_videofile(final_vid_output_path, codec="libx264", bitrate="1000k")
-
-    return (
-        gr.update(value=final_vid_output_path),
-        session_state,
-    )
+    except Exception as e:
+        print(f"Error in propagate_to_all: {e}")
+        return (
+            gr.update(value=None, visible=False),
+            session_state,
+        )
 
 
 def update_ui():
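The video-writing tail of this hunk is standard MoviePy usage and can be exercised on its own. A sketch with the diff's settings; `write_preview` is a hypothetical wrapper, and frames are expected as a list of same-sized uint8 arrays:

    import os
    import tempfile
    from datetime import datetime
    from moviepy.editor import ImageSequenceClip

    def write_preview(frames, fps=15):
        # Timestamped name in the temp dir, matching the diff's convention
        path = os.path.join(
            tempfile.gettempdir(),
            f"output_video_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4",
        )
        clip = ImageSequenceClip(frames, fps=fps)
        # Modest bitrate and thread count to stay within CPU-only Space limits
        clip.write_videofile(path, codec="libx264", bitrate="800k", threads=2, logger=None)
        return path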
 