Spaces:
bla
/
Runtime error

bla committed on
Commit
5b2f03e
·
verified ·
1 Parent(s): 807f473

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -87
app.py CHANGED
@@ -264,6 +264,8 @@ def segment_with_points(
264
  # Define the circle radius as a fraction of the smaller dimension
265
  fraction = 0.01 # You can adjust this value as needed
266
  radius = int(fraction * min(w, h))
 
 
267
 
268
  # Create a transparent layer to draw on
269
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
@@ -280,9 +282,8 @@ def segment_with_points(
280
  transparent_background, transparent_layer
281
  )
282
 
283
- # Let's add a positive click at (x, y) = (210, 350) to get started
284
  points = np.array(session_state["input_points"], dtype=np.float32)
285
- # for labels, `1` means positive click and `0` means negative click
286
  labels = np.array(session_state["input_labels"], np.int32)
287
 
288
  try:
@@ -301,24 +302,52 @@ def segment_with_points(
301
  labels=labels,
302
  )
303
 
304
- # Create the mask
305
- mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
306
 
307
- # Ensure the mask has the same size as the frame
308
- if mask_array.shape[:2] != (h, w):
309
- mask_array = cv2.resize(
310
- mask_array.astype(np.uint8),
311
- (w, h),
312
- interpolation=cv2.INTER_NEAREST
313
- ).astype(bool)
314
 
315
- mask_image = show_mask(mask_array)
 
316
 
317
- # Make sure mask_image has the same size as the background
318
- if mask_image.size != transparent_background.size:
319
- mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
 
 
 
 
 
 
 
320
 
321
- first_frame_output = Image.alpha_composite(transparent_background, mask_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  except Exception as e:
323
  print(f"Error in segmentation: {e}")
324
  import traceback
@@ -326,46 +355,66 @@ def segment_with_points(
326
  # Return just the points as fallback
327
  first_frame_output = selected_point_map
328
 
329
- return selected_point_map, first_frame_output, session_state
330
 
331
  def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
332
  """Convert binary mask to RGBA image for visualization."""
333
- if random_color:
334
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
335
- else:
336
- cmap = plt.get_cmap("tab10")
337
- cmap_idx = 0 if obj_id is None else obj_id
338
- color = np.array([*cmap(cmap_idx)[:3], 0.6])
 
 
339
 
340
- # Handle different mask shapes properly
341
  if len(mask.shape) == 2:
342
  h, w = mask.shape
343
  else:
344
  h, w = mask.shape[-2:]
345
 
346
- # Ensure correct reshaping based on mask dimensions
347
- mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
348
- mask_rgba = (mask_reshaped * 255).astype(np.uint8)
 
 
 
 
349
 
350
- if convert_to_image:
351
- try:
352
- # Ensure the mask has correct RGBA shape (h, w, 4)
353
- if mask_rgba.shape[2] != 4:
354
- # If not RGBA, create a proper RGBA array
355
- proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
356
- # Copy available channels
357
- proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
358
- mask_rgba = proper_mask
359
-
360
- # Create the PIL image
361
- return Image.fromarray(mask_rgba, "RGBA")
362
- except Exception as e:
363
- print(f"Error converting mask to image: {e}")
364
- # Fallback: create a blank transparent image of correct size
365
- blank = np.zeros((h, w, 4), dtype=np.uint8)
366
- return Image.fromarray(blank, "RGBA")
367
 
368
- return mask_rgba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  def update_progress(progress_percent, progress_bar):
371
  """Update progress bar during processing."""
@@ -398,17 +447,11 @@ def propagate_to_all(
398
  video_segments = {} # video_segments contains the per-frame segmentation results
399
  print("Starting propagate_in_video on CPU")
400
 
401
- progress.tqdm.reset()
402
-
403
- # Get the count for progress reporting
404
- all_frames_count = 0
405
- for _ in predictor.propagate_in_video(session_state["inference_state"], count_only=True):
406
- all_frames_count += 1
407
-
408
- print(f"Total frames to process: {all_frames_count}")
409
- progress.tqdm.total = all_frames_count
410
 
411
  # Now do the actual processing with progress updates
 
412
  for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
413
  session_state["inference_state"]
414
  ):
@@ -420,12 +463,13 @@ def propagate_to_all(
420
  }
421
 
422
  # Update progress
423
- progress.tqdm.update(1)
424
- progress_percent = min(100, int((out_frame_idx + 1) / all_frames_count * 100))
425
  session_state["progress"] = progress_percent
 
426
 
427
  if out_frame_idx % 10 == 0:
428
- print(f"Processed frame {out_frame_idx}/{all_frames_count} ({progress_percent}%)")
429
 
430
  # Release memory periodically
431
  if out_frame_idx % chunk_size == 0:
@@ -445,10 +489,8 @@ def propagate_to_all(
445
  print(f"Total frames processed: {total_frames}")
446
 
447
  # Update progress to show rendering phase
448
- progress.tqdm.reset()
449
- progress.tqdm.total = 2 # Two phases: rendering and video creation
450
- progress.tqdm.update(1)
451
  session_state["progress"] = 50
 
452
 
453
  # Limit to max 50 frames for CPU processing
454
  max_output_frames = 50
@@ -464,12 +506,12 @@ def propagate_to_all(
464
 
465
  # Create output frames
466
  output_frames = []
467
- progress.tqdm.reset()
468
- progress.tqdm.total = (total_frames // vis_frame_stride) + 1
469
 
470
- for out_frame_idx in range(0, total_frames, vis_frame_stride):
 
 
 
471
  if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
472
- progress.tqdm.update(1)
473
  continue
474
 
475
  try:
@@ -481,41 +523,50 @@ def propagate_to_all(
481
  frame_idx = out_frame_idx
482
 
483
  frame = session_state["all_frames"][frame_idx]
484
- transparent_background = Image.fromarray(frame).convert("RGBA")
485
 
486
- # Get the mask and ensure it's the right size
 
487
  out_mask = video_segments[out_frame_idx][OBJ_ID]
488
 
489
- # Ensure the mask is not empty and has the right dimensions
490
- if out_mask.size == 0:
491
- print(f"Warning: Empty mask for frame {out_frame_idx}")
492
- # Create an empty mask of the right size
493
- out_mask = np.zeros((h, w), dtype=bool)
494
 
495
- # Resize mask if dimensions don't match
 
496
  mask_h, mask_w = out_mask.shape[:2]
497
- if mask_h != h or mask_w != w:
498
- print(f"Resizing mask from {mask_h}x{mask_w} to {h}x{w}")
499
- out_mask = cv2.resize(
500
- out_mask.astype(np.uint8),
501
- (w, h),
502
- interpolation=cv2.INTER_NEAREST
503
- ).astype(bool)
504
 
505
- mask_image = show_mask(out_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
- # Make sure mask has same dimensions as background
508
- if mask_image.size != transparent_background.size:
509
- mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
510
 
511
- output_frame = Image.alpha_composite(transparent_background, mask_image)
512
- output_frame = np.array(output_frame)
 
 
 
513
  output_frames.append(output_frame)
514
 
515
  # Update progress
516
- progress.tqdm.update(1)
517
- progress_percent = 50 + min(50, int((len(output_frames) / (total_frames // vis_frame_stride)) * 50))
518
  session_state["progress"] = progress_percent
 
519
 
520
  # Clear memory periodically
521
  if len(output_frames) % 10 == 0:
 
264
  # Define the circle radius as a fraction of the smaller dimension
265
  fraction = 0.01 # You can adjust this value as needed
266
  radius = int(fraction * min(w, h))
267
+ if radius < 3:
268
+ radius = 3 # Ensure minimum visibility
269
 
270
  # Create a transparent layer to draw on
271
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
 
282
  transparent_background, transparent_layer
283
  )
284
 
285
+ # Use the clicked points and labels
286
  points = np.array(session_state["input_points"], dtype=np.float32)
 
287
  labels = np.array(session_state["input_labels"], np.int32)
288
 
289
  try:
 
302
  labels=labels,
303
  )
304
 
305
+ # Create the mask and check dimensions first
306
+ out_mask = (out_mask_logits[0] > 0.0).cpu().numpy()
307
 
308
+ # Convert to RGB for visualization
309
+ # Create an overlay with semi-transparent color
310
+ overlay = np.zeros((h, w, 3), dtype=np.uint8)
 
 
 
 
311
 
312
+ # Create a colored mask - blue with opacity
313
+ overlay_mask = np.zeros_like(overlay)
314
 
315
+ # Resize mask carefully if needed - handle empty dimensions
316
+ if out_mask.shape[0] > 0 and out_mask.shape[1] > 0:
317
+ # Check if dimensions differ
318
+ if out_mask.shape[:2] != (h, w):
319
+ print(f"Resizing mask from {out_mask.shape[:2]} to {h}x{w}")
320
+ # Use numpy/PIL for resizing to avoid OpenCV issues
321
+ from PIL import Image
322
+ mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
323
+ mask_img = mask_img.resize((w, h), Image.NEAREST)
324
+ out_mask = np.array(mask_img) > 0
325
 
326
+ # Apply mask color
327
+ overlay_mask[out_mask] = [0, 120, 255] # Blue color for mask
328
+
329
+ # Blend original frame with mask
330
+ alpha = 0.5 # Opacity
331
+ frame_with_mask = cv2.addWeighted(
332
+ first_frame, 1, overlay_mask, alpha, 0
333
+ )
334
+
335
+ # Add points on top of mask
336
+ points_overlay = np.zeros((h, w, 4), dtype=np.uint8)
337
+ for index, track in enumerate(session_state["input_points"]):
338
+ if session_state["input_labels"][index] == 1:
339
+ cv2.circle(points_overlay, track, radius, (0, 255, 0, 255), -1) # Green
340
+ else:
341
+ cv2.circle(points_overlay, track, radius, (255, 0, 0, 255), -1) # Red
342
+
343
+ # Convert to PIL for overlay
344
+ frame_with_mask_pil = Image.fromarray(frame_with_mask)
345
+ points_overlay_pil = Image.fromarray(points_overlay, "RGBA")
346
+
347
+ # Final composite
348
+ first_frame_output = Image.alpha_composite(
349
+ frame_with_mask_pil.convert("RGBA"), points_overlay_pil
350
+ )
351
  except Exception as e:
352
  print(f"Error in segmentation: {e}")
353
  import traceback
 
355
  # Return just the points as fallback
356
  first_frame_output = selected_point_map
357
 
358
+ return selected_point_map, np.array(first_frame_output), session_state
359
 
360
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Convert a binary segmentation mask to an RGBA visualization.

    Args:
        mask: Boolean/0-1 array of shape (h, w) or (..., h, w); a leading
            singleton dimension (e.g. (1, h, w) from model logits) is accepted.
        obj_id: Optional object id used to pick a color from the "tab10"
            colormap when random_color is False.
        random_color: If True, use a random RGB color (alpha fixed at 0.6).
        convert_to_image: If True return a PIL RGBA Image, else a uint8
            numpy array of shape (h, w, 4).

    Returns:
        PIL.Image.Image or np.ndarray: transparent RGBA visualization; on
        invalid input, a fully transparent 100x100 fallback is returned.
    """
    # Guard against missing/empty input up front.
    if mask is None or mask.size == 0:
        print("Warning: Empty mask provided to show_mask")
        # Return an empty transparent mask
        if convert_to_image:
            return Image.new('RGBA', (100, 100), (0, 0, 0, 0))
        else:
            return np.zeros((100, 100, 4), dtype=np.uint8)

    # Get mask dimensions (last two axes are spatial).
    if len(mask.shape) == 2:
        h, w = mask.shape
    else:
        h, w = mask.shape[-2:]

    if h == 0 or w == 0:
        print(f"Warning: Invalid mask dimensions: {h}x{w}")
        # Return an empty transparent mask
        if convert_to_image:
            return Image.new('RGBA', (100, 100), (0, 0, 0, 0))
        else:
            return np.zeros((100, 100, 4), dtype=np.uint8)

    # FIX: collapse any leading singleton dims (e.g. (1, h, w) logits output).
    # Without this, (mask * color[i]) has shape (1, h, w) and cannot be
    # assigned into colored_mask[:, :, i] (shape (h, w)), so every 3-D mask
    # used to fall through to the transparent fallback.
    mask_2d = np.asarray(mask).reshape(h, w)

    # Set the color for visualization.
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    try:
        # Create a colored visualization of the mask.
        colored_mask = np.zeros((h, w, 4), dtype=np.uint8)

        # Apply color to mask areas (where mask is True).
        for i in range(3):  # RGB channels
            colored_mask[:, :, i] = (mask_2d * color[i] * 255).astype(np.uint8)

        # Set alpha channel (0.6 opacity inside the mask, 0 outside).
        colored_mask[:, :, 3] = (mask_2d * color[3] * 255).astype(np.uint8)

        if convert_to_image:
            return Image.fromarray(colored_mask, "RGBA")
        else:
            return colored_mask
    except Exception as e:
        # Deliberate best-effort: visualization failure must not crash the app.
        print(f"Error in show_mask: {e}")
        import traceback
        traceback.print_exc()

        # Return a fallback transparent image.
        # FIX: PIL's size argument is (width, height), not (height, width).
        if convert_to_image:
            return Image.new('RGBA', (w, h), (0, 0, 0, 0))
        else:
            return np.zeros((h, w, 4), dtype=np.uint8)
418
 
419
  def update_progress(progress_percent, progress_bar):
420
  """Update progress bar during processing."""
 
447
  video_segments = {} # video_segments contains the per-frame segmentation results
448
  print("Starting propagate_in_video on CPU")
449
 
450
+ # Get the count for progress reporting (estimate)
451
+ all_frames_count = 300 # Reasonable estimate
 
 
 
 
 
 
 
452
 
453
  # Now do the actual processing with progress updates
454
+ current_frame = 0
455
  for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
456
  session_state["inference_state"]
457
  ):
 
463
  }
464
 
465
  # Update progress
466
+ current_frame += 1
467
+ progress_percent = min(50, int((current_frame / all_frames_count) * 50))
468
  session_state["progress"] = progress_percent
469
+ progress(progress_percent/100, desc="Processing frames")
470
 
471
  if out_frame_idx % 10 == 0:
472
+ print(f"Processed frame {out_frame_idx} ({progress_percent}%)")
473
 
474
  # Release memory periodically
475
  if out_frame_idx % chunk_size == 0:
 
489
  print(f"Total frames processed: {total_frames}")
490
 
491
  # Update progress to show rendering phase
 
 
 
492
  session_state["progress"] = 50
493
+ progress(0.5, desc="Rendering video")
494
 
495
  # Limit to max 50 frames for CPU processing
496
  max_output_frames = 50
 
506
 
507
  # Create output frames
508
  output_frames = []
 
 
509
 
510
+ frame_indices = list(range(0, total_frames, vis_frame_stride))
511
+ total_output_frames = len(frame_indices)
512
+
513
+ for i, out_frame_idx in enumerate(frame_indices):
514
  if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
 
515
  continue
516
 
517
  try:
 
523
  frame_idx = out_frame_idx
524
 
525
  frame = session_state["all_frames"][frame_idx]
 
526
 
527
+ # Create a colored overlay rather than using transparency
528
+ # Get the mask
529
  out_mask = video_segments[out_frame_idx][OBJ_ID]
530
 
531
+ # Ensure the mask is not empty and has valid dimensions
532
+ if out_mask.size == 0 or 0 in out_mask.shape:
533
+ print(f"Warning: Invalid mask for frame {out_frame_idx}")
534
+ # Skip this frame
535
+ continue
536
 
537
+ # Get dimensions
538
+ frame_h, frame_w = frame.shape[:2]
539
  mask_h, mask_w = out_mask.shape[:2]
 
 
 
 
 
 
 
540
 
541
+ # Resize mask using PIL if dimensions don't match (avoid OpenCV)
542
+ if mask_h != frame_h or mask_w != frame_w:
543
+ print(f"Resizing mask from {mask_h}x{mask_w} to {frame_h}x{frame_w}")
544
+ try:
545
+ mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
546
+ mask_img = mask_img.resize((frame_w, frame_h), Image.NEAREST)
547
+ out_mask = np.array(mask_img) > 0
548
+ except Exception as e:
549
+ print(f"Error resizing mask: {e}")
550
+ # Skip this frame if resize fails
551
+ continue
552
+
553
+ # Create an overlay with semi-transparent color
554
+ overlay = np.zeros_like(frame)
555
 
556
+ # Set blue color for mask area
557
+ overlay[out_mask] = [0, 120, 255] # BGR format for OpenCV
 
558
 
559
+ # Blend with original frame
560
+ alpha = 0.5
561
+ output_frame = cv2.addWeighted(frame, 1, overlay, alpha, 0)
562
+
563
+ # Add to output frames
564
  output_frames.append(output_frame)
565
 
566
  # Update progress
567
+ progress_percent = 50 + min(50, int((i / total_output_frames) * 50))
 
568
  session_state["progress"] = progress_percent
569
+ progress(progress_percent/100, desc=f"Rendering video frames ({i}/{total_output_frames})")
570
 
571
  # Clear memory periodically
572
  if len(output_frames) % 10 == 0: