Spaces:
bla
/
Runtime error

bla committed on
Commit
fa0b563
·
verified ·
1 Parent(s): 2a466e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -281,6 +281,7 @@ def segment_with_points(
281
  # Ensure we have a valid first frame and inference state
282
  if session_state["first_frame"] is None or session_state["inference_state"] is None:
283
  print("Error: Cannot segment. No video loaded or inference state missing.")
 
284
  return (
285
  session_state["first_frame"], # points_map remains unchanged
286
  None, # output_image remains unchanged or cleared
@@ -323,6 +324,7 @@ def segment_with_points(
323
  # Convert the transparent layer back to an image and composite onto the first frame
324
  transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
325
  # Combine the first frame image with the points layer for the points_map output
 
326
  selected_point_map_img = Image.alpha_composite(
327
  first_frame_pil.copy(), transparent_layer_points_pil
328
  )
@@ -337,7 +339,9 @@ def segment_with_points(
337
 
338
  # Add new points to the predictor's state and get the mask for the first frame
339
  # This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
 
340
  try:
 
341
  _, _, out_mask_logits = predictor.add_new_points(
342
  inference_state=session_state["inference_state"],
343
  frame_idx=0, # Always segment on the first frame initially
@@ -347,8 +351,7 @@ def segment_with_points(
347
  )
348
 
349
  # Process logits: detach from graph, move to CPU, apply threshold
350
- # out_mask_logits is [batch_size, H, W] (batch_size=1 here)
351
- # out_mask_logits[0] is the tensor for obj_id=OBJ_ID
352
  mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Apply threshold and get the single mask tensor [H, W]
353
  mask_numpy = mask_tensor.numpy() # Convert to numpy
354
 
@@ -356,12 +359,12 @@ def segment_with_points(
356
  mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) # show_mask returns RGBA PIL Image
357
 
358
  # Composite the mask onto the first frame for the output_image
 
359
  first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)
360
 
361
  except Exception as e:
362
  print(f"Error during segmentation on first frame: {e}")
363
- # On error, return the points_map but clear the output_image
364
- first_frame_output_img = None
365
 
366
 
367
  return selected_point_map_img, first_frame_output_img, session_state
@@ -399,7 +402,9 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
399
  # Apply color where mask is True
400
  # Need to reshape color to be broadcastable [1, 1, 4]
401
  colored_mask = np.zeros((h, w, 4), dtype=np.float32) # Start with fully transparent black
402
- colored_mask[mask] = color # Apply color where mask is True
 
 
403
 
404
  # Convert to uint8 [0-255]
405
  colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)
@@ -446,6 +451,7 @@ def propagate_to_all(
446
  video_segments[out_frame_idx] = {
447
  # out_mask_logits is a list of tensors (one per object tracked in this frame)
448
  # Each tensor is [batch_size, H, W]. Batch size is 1 here.
 
449
  out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
450
  for i, out_obj_id in enumerate(out_obj_ids)
451
  }
@@ -603,7 +609,7 @@ with gr.Blocks() as demo:
603
  # points_map is where users click to add points. Needs to be interactive.
604
  # Shows the first frame with points drawn on it.
605
  points_map = gr.Image(
606
- label="Frame with Point Prompt",
607
  type="numpy",
608
  interactive=True, # Make interactive to capture clicks
609
  height=400, # Set a fixed height for better UI
@@ -626,7 +632,7 @@ with gr.Blocks() as demo:
626
 
627
  # output_image shows the segmentation mask prediction on the *first* frame
628
  output_image = gr.Image(
629
- label="Reference Mask (First Frame)",
630
  type="numpy",
631
  interactive=False, # Not interactive, just displays the mask
632
  height=400, # Match height of points_map
@@ -749,5 +755,6 @@ with gr.Blocks() as demo:
749
  # Launch the Gradio demo
750
  demo.queue() # Enable queuing for sequential processing under concurrency limits
751
  print("Gradio demo starting...")
 
752
  demo.launch()
753
  print("Gradio demo launched.")
 
281
  # Ensure we have a valid first frame and inference state
282
  if session_state["first_frame"] is None or session_state["inference_state"] is None:
283
  print("Error: Cannot segment. No video loaded or inference state missing.")
284
+ # Return current states to avoid errors, without changing UI much
285
  return (
286
  session_state["first_frame"], # points_map remains unchanged
287
  None, # output_image remains unchanged or cleared
 
324
  # Convert the transparent layer back to an image and composite onto the first frame
325
  transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
326
  # Combine the first frame image with the points layer for the points_map output
327
+ # points_map shows the first frame *with the points you added*.
328
  selected_point_map_img = Image.alpha_composite(
329
  first_frame_pil.copy(), transparent_layer_points_pil
330
  )
 
339
 
340
  # Add new points to the predictor's state and get the mask for the first frame
341
  # This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
342
+ first_frame_output_img = None # Initialize output mask image as None in case of error
343
  try:
344
+ # Note: predictor.add_new_points modifies the internal inference_state
345
  _, _, out_mask_logits = predictor.add_new_points(
346
  inference_state=session_state["inference_state"],
347
  frame_idx=0, # Always segment on the first frame initially
 
351
  )
352
 
353
  # Process logits: detach from graph, move to CPU, apply threshold
354
+ # out_mask_logits is a list of tensors [tensor([H, W])] for the requested obj_id
 
355
  mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Apply threshold and get the single mask tensor [H, W]
356
  mask_numpy = mask_tensor.numpy() # Convert to numpy
357
 
 
359
  mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) # show_mask returns RGBA PIL Image
360
 
361
  # Composite the mask onto the first frame for the output_image
362
+ # output_image shows the first frame *with the segmentation mask result*.
363
  first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)
364
 
365
  except Exception as e:
366
  print(f"Error during segmentation on first frame: {e}")
367
+ # On error, first_frame_output_img remains None
 
368
 
369
 
370
  return selected_point_map_img, first_frame_output_img, session_state
 
402
  # Apply color where mask is True
403
  # Need to reshape color to be broadcastable [1, 1, 4]
404
  colored_mask = np.zeros((h, w, 4), dtype=np.float32) # Start with fully transparent black
405
+ # Apply the color only where the mask is True.
406
+ # This directly creates the colored overlay with transparency.
407
+ colored_mask[mask] = color
408
 
409
  # Convert to uint8 [0-255]
410
  colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)
 
451
  video_segments[out_frame_idx] = {
452
  # out_mask_logits is a list of tensors (one per object tracked in this frame)
453
  # Each tensor is [batch_size, H, W]. Batch size is 1 here.
454
+ # Access the first element of the batch [0]
455
  out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
456
  for i, out_obj_id in enumerate(out_obj_ids)
457
  }
 
609
  # points_map is where users click to add points. Needs to be interactive.
610
  # Shows the first frame with points drawn on it.
611
  points_map = gr.Image(
612
+ label="Click on the First Frame to Add Points", # Clearer label
613
  type="numpy",
614
  interactive=True, # Make interactive to capture clicks
615
  height=400, # Set a fixed height for better UI
 
632
 
633
  # output_image shows the segmentation mask prediction on the *first* frame
634
  output_image = gr.Image(
635
+ label="Segmentation Mask on First Frame", # Clearer label
636
  type="numpy",
637
  interactive=False, # Not interactive, just displays the mask
638
  height=400, # Match height of points_map
 
755
  # Launch the Gradio demo
756
  demo.queue() # Enable queuing for sequential processing under concurrency limits
757
  print("Gradio demo starting...")
758
+ # Removed share=True for local debugging unless you specifically need a public link
759
  demo.launch()
760
  print("Gradio demo launched.")