Update app.py
Browse files
app.py
CHANGED
@@ -281,6 +281,7 @@ def segment_with_points(
|
|
281 |
# Ensure we have a valid first frame and inference state
|
282 |
if session_state["first_frame"] is None or session_state["inference_state"] is None:
|
283 |
print("Error: Cannot segment. No video loaded or inference state missing.")
|
|
|
284 |
return (
|
285 |
session_state["first_frame"], # points_map remains unchanged
|
286 |
None, # output_image remains unchanged or cleared
|
@@ -323,6 +324,7 @@ def segment_with_points(
|
|
323 |
# Convert the transparent layer back to an image and composite onto the first frame
|
324 |
transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
|
325 |
# Combine the first frame image with the points layer for the points_map output
|
|
|
326 |
selected_point_map_img = Image.alpha_composite(
|
327 |
first_frame_pil.copy(), transparent_layer_points_pil
|
328 |
)
|
@@ -337,7 +339,9 @@ def segment_with_points(
|
|
337 |
|
338 |
# Add new points to the predictor's state and get the mask for the first frame
|
339 |
# This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
|
|
|
340 |
try:
|
|
|
341 |
_, _, out_mask_logits = predictor.add_new_points(
|
342 |
inference_state=session_state["inference_state"],
|
343 |
frame_idx=0, # Always segment on the first frame initially
|
@@ -347,8 +351,7 @@ def segment_with_points(
|
|
347 |
)
|
348 |
|
349 |
# Process logits: detach from graph, move to CPU, apply threshold
|
350 |
-
# out_mask_logits is [
|
351 |
-
# out_mask_logits[0] is the tensor for obj_id=OBJ_ID
|
352 |
mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Apply threshold and get the single mask tensor [H, W]
|
353 |
mask_numpy = mask_tensor.numpy() # Convert to numpy
|
354 |
|
@@ -356,12 +359,12 @@ def segment_with_points(
|
|
356 |
mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) # show_mask returns RGBA PIL Image
|
357 |
|
358 |
# Composite the mask onto the first frame for the output_image
|
|
|
359 |
first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)
|
360 |
|
361 |
except Exception as e:
|
362 |
print(f"Error during segmentation on first frame: {e}")
|
363 |
-
# On error,
|
364 |
-
first_frame_output_img = None
|
365 |
|
366 |
|
367 |
return selected_point_map_img, first_frame_output_img, session_state
|
@@ -399,7 +402,9 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
|
|
399 |
# Apply color where mask is True
|
400 |
# Need to reshape color to be broadcastable [1, 1, 4]
|
401 |
colored_mask = np.zeros((h, w, 4), dtype=np.float32) # Start with fully transparent black
|
402 |
-
|
|
|
|
|
403 |
|
404 |
# Convert to uint8 [0-255]
|
405 |
colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)
|
@@ -446,6 +451,7 @@ def propagate_to_all(
|
|
446 |
video_segments[out_frame_idx] = {
|
447 |
# out_mask_logits is a list of tensors (one per object tracked in this frame)
|
448 |
# Each tensor is [batch_size, H, W]. Batch size is 1 here.
|
|
|
449 |
out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
|
450 |
for i, out_obj_id in enumerate(out_obj_ids)
|
451 |
}
|
@@ -603,7 +609,7 @@ with gr.Blocks() as demo:
|
|
603 |
# points_map is where users click to add points. Needs to be interactive.
|
604 |
# Shows the first frame with points drawn on it.
|
605 |
points_map = gr.Image(
|
606 |
-
label="Frame
|
607 |
type="numpy",
|
608 |
interactive=True, # Make interactive to capture clicks
|
609 |
height=400, # Set a fixed height for better UI
|
@@ -626,7 +632,7 @@ with gr.Blocks() as demo:
|
|
626 |
|
627 |
# output_image shows the segmentation mask prediction on the *first* frame
|
628 |
output_image = gr.Image(
|
629 |
-
label="
|
630 |
type="numpy",
|
631 |
interactive=False, # Not interactive, just displays the mask
|
632 |
height=400, # Match height of points_map
|
@@ -749,5 +755,6 @@ with gr.Blocks() as demo:
|
|
749 |
# Launch the Gradio demo
|
750 |
demo.queue() # Enable queuing for sequential processing under concurrency limits
|
751 |
print("Gradio demo starting...")
|
|
|
752 |
demo.launch()
|
753 |
print("Gradio demo launched.")
|
|
|
281 |
# Ensure we have a valid first frame and inference state
|
282 |
if session_state["first_frame"] is None or session_state["inference_state"] is None:
|
283 |
print("Error: Cannot segment. No video loaded or inference state missing.")
|
284 |
+
# Bail out early, returning the current values so the callback does not fail and the UI is left largely as it was
|
285 |
return (
|
286 |
session_state["first_frame"], # points_map remains unchanged
|
287 |
None, # output_image remains unchanged or cleared
|
|
|
324 |
# Convert the transparent layer back to an image and composite onto the first frame
|
325 |
transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
|
326 |
# Combine the first frame image with the points layer for the points_map output
|
327 |
+
# points_map shows the first frame *with the points you added*.
|
328 |
selected_point_map_img = Image.alpha_composite(
|
329 |
first_frame_pil.copy(), transparent_layer_points_pil
|
330 |
)
|
|
|
339 |
|
340 |
# Add new points to the predictor's state and get the mask for the first frame
|
341 |
# This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
|
342 |
+
first_frame_output_img = None # Initialize output mask image as None in case of error
|
343 |
try:
|
344 |
+
# Note: predictor.add_new_points modifies the internal inference_state
|
345 |
_, _, out_mask_logits = predictor.add_new_points(
|
346 |
inference_state=session_state["inference_state"],
|
347 |
frame_idx=0, # Always segment on the first frame initially
|
|
|
351 |
)
|
352 |
|
353 |
# Process logits: detach from graph, move to CPU, apply threshold
|
354 |
+
# out_mask_logits[0] is the logits tensor [1, H, W] for the requested obj_id; indexing [0] again yields the [H, W] mask
|
|
|
355 |
mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Apply threshold and get the single mask tensor [H, W]
|
356 |
mask_numpy = mask_tensor.numpy() # Convert to numpy
|
357 |
|
|
|
359 |
mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) # show_mask returns RGBA PIL Image
|
360 |
|
361 |
# Composite the mask onto the first frame for the output_image
|
362 |
+
# output_image shows the first frame *with the segmentation mask result*.
|
363 |
first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)
|
364 |
|
365 |
except Exception as e:
|
366 |
print(f"Error during segmentation on first frame: {e}")
|
367 |
+
# On error, first_frame_output_img remains None
|
|
|
368 |
|
369 |
|
370 |
return selected_point_map_img, first_frame_output_img, session_state
|
|
|
402 |
# Apply color where mask is True
|
403 |
# Need to reshape color to be broadcastable [1, 1, 4]
|
404 |
colored_mask = np.zeros((h, w, 4), dtype=np.float32) # Start with fully transparent black
|
405 |
+
# Apply the color only where the mask is True.
|
406 |
+
# This directly creates the colored overlay with transparency.
|
407 |
+
colored_mask[mask] = color
|
408 |
|
409 |
# Convert to uint8 [0-255]
|
410 |
colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)
|
|
|
451 |
video_segments[out_frame_idx] = {
|
452 |
# out_mask_logits is a list of tensors (one per object tracked in this frame)
|
453 |
# Each tensor is [batch_size, H, W]. Batch size is 1 here.
|
454 |
+
# Access the first element of the batch [0]
|
455 |
out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
|
456 |
for i, out_obj_id in enumerate(out_obj_ids)
|
457 |
}
|
|
|
609 |
# points_map is where users click to add points. Needs to be interactive.
|
610 |
# Shows the first frame with points drawn on it.
|
611 |
points_map = gr.Image(
|
612 |
+
label="Click on the First Frame to Add Points", # Clearer label
|
613 |
type="numpy",
|
614 |
interactive=True, # Make interactive to capture clicks
|
615 |
height=400, # Set a fixed height for better UI
|
|
|
632 |
|
633 |
# output_image shows the segmentation mask prediction on the *first* frame
|
634 |
output_image = gr.Image(
|
635 |
+
label="Segmentation Mask on First Frame", # Clearer label
|
636 |
type="numpy",
|
637 |
interactive=False, # Not interactive, just displays the mask
|
638 |
height=400, # Match height of points_map
|
|
|
755 |
# Launch the Gradio demo
|
756 |
demo.queue() # Enable queuing for sequential processing under concurrency limits
|
757 |
print("Gradio demo starting...")
|
758 |
+
# share=True was removed so the demo runs locally; add it back only if you specifically need a public link
|
759 |
demo.launch()
|
760 |
print("Gradio demo launched.")
|