Update app.py
app.py
CHANGED
@@ -264,6 +264,8 @@ def segment_with_points(
     # Define the circle radius as a fraction of the smaller dimension
     fraction = 0.01  # You can adjust this value as needed
     radius = int(fraction * min(w, h))
+    if radius < 3:
+        radius = 3  # Ensure minimum visibility
 
     # Create a transparent layer to draw on
     transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
@@ -280,9 +282,8 @@ def segment_with_points(
         transparent_background, transparent_layer
     )
 
-    #
+    # Use the clicked points and labels
     points = np.array(session_state["input_points"], dtype=np.float32)
-    # for labels, `1` means positive click and `0` means negative click
     labels = np.array(session_state["input_labels"], np.int32)
 
     try:
@@ -301,24 +302,52 @@ def segment_with_points(
             labels=labels,
         )
 
-        # Create the mask
-
+        # Create the mask and check dimensions first
+        out_mask = (out_mask_logits[0] > 0.0).cpu().numpy()
 
-        #
-
-
-            mask_array.astype(np.uint8),
-            (w, h),
-            interpolation=cv2.INTER_NEAREST
-        ).astype(bool)
+        # Convert to RGB for visualization
+        # Create an overlay with semi-transparent color
+        overlay = np.zeros((h, w, 3), dtype=np.uint8)
 
-
+        # Create a colored mask - blue with opacity
+        overlay_mask = np.zeros_like(overlay)
 
-        #
-        if
-
+        # Resize mask carefully if needed - handle empty dimensions
+        if out_mask.shape[0] > 0 and out_mask.shape[1] > 0:
+            # Check if dimensions differ
+            if out_mask.shape[:2] != (h, w):
+                print(f"Resizing mask from {out_mask.shape[:2]} to {h}x{w}")
+                # Use numpy/PIL for resizing to avoid OpenCV issues
+                from PIL import Image
+                mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
+                mask_img = mask_img.resize((w, h), Image.NEAREST)
+                out_mask = np.array(mask_img) > 0
 
-
+            # Apply mask color
+            overlay_mask[out_mask] = [0, 120, 255]  # Blue color for mask
+
+        # Blend original frame with mask
+        alpha = 0.5  # Opacity
+        frame_with_mask = cv2.addWeighted(
+            first_frame, 1, overlay_mask, alpha, 0
+        )
+
+        # Add points on top of mask
+        points_overlay = np.zeros((h, w, 4), dtype=np.uint8)
+        for index, track in enumerate(session_state["input_points"]):
+            if session_state["input_labels"][index] == 1:
+                cv2.circle(points_overlay, track, radius, (0, 255, 0, 255), -1)  # Green
+            else:
+                cv2.circle(points_overlay, track, radius, (255, 0, 0, 255), -1)  # Red
+
+        # Convert to PIL for overlay
+        frame_with_mask_pil = Image.fromarray(frame_with_mask)
+        points_overlay_pil = Image.fromarray(points_overlay, "RGBA")
+
+        # Final composite
+        first_frame_output = Image.alpha_composite(
+            frame_with_mask_pil.convert("RGBA"), points_overlay_pil
+        )
     except Exception as e:
         print(f"Error in segmentation: {e}")
         import traceback
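For reference, the overlay technique introduced in this hunk can be reduced to a small standalone sketch. This is only an illustration under the assumption of an RGB uint8 frame and a boolean mask of matching height and width; the helper name blend_mask is not part of app.py:

    import numpy as np
    import cv2

    def blend_mask(frame_rgb: np.ndarray, mask: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """Blend a boolean mask onto an RGB frame as a semi-transparent blue overlay."""
        overlay = np.zeros_like(frame_rgb)
        overlay[mask] = [0, 120, 255]  # same mask color used in the commit
        # output = frame * 1.0 + overlay * alpha, saturated to uint8
        return cv2.addWeighted(frame_rgb, 1, overlay, alpha, 0)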
@@ -326,46 +355,66 @@ def segment_with_points(
         # Return just the points as fallback
         first_frame_output = selected_point_map
 
-    return selected_point_map, first_frame_output, session_state
+    return selected_point_map, np.array(first_frame_output), session_state
 
 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     """Convert binary mask to RGBA image for visualization."""
-    if
-
-
-
-
-
+    # Check if mask is valid
+    if mask is None or mask.size == 0:
+        print("Warning: Empty mask provided to show_mask")
+        # Return an empty transparent mask
+        if convert_to_image:
+            return Image.new('RGBA', (100, 100), (0, 0, 0, 0))
+        else:
+            return np.zeros((100, 100, 4), dtype=np.uint8)
 
-    #
+    # Get mask dimensions
     if len(mask.shape) == 2:
         h, w = mask.shape
     else:
         h, w = mask.shape[-2:]
 
-
-
-
+    if h == 0 or w == 0:
+        print(f"Warning: Invalid mask dimensions: {h}x{w}")
+        # Return an empty transparent mask
+        if convert_to_image:
+            return Image.new('RGBA', (100, 100), (0, 0, 0, 0))
+        else:
+            return np.zeros((100, 100, 4), dtype=np.uint8)
 
-
-
-
-
-
-
-
-        proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
-        mask_rgba = proper_mask
-
-        # Create the PIL image
-        return Image.fromarray(mask_rgba, "RGBA")
-    except Exception as e:
-        print(f"Error converting mask to image: {e}")
-        # Fallback: create a blank transparent image of correct size
-        blank = np.zeros((h, w, 4), dtype=np.uint8)
-        return Image.fromarray(blank, "RGBA")
+    # Set the color for visualization
+    if random_color:
+        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+    else:
+        cmap = plt.get_cmap("tab10")
+        cmap_idx = 0 if obj_id is None else obj_id
+        color = np.array([*cmap(cmap_idx)[:3], 0.6])
 
-
+    try:
+        # Create a colored visualization of the mask
+        colored_mask = np.zeros((h, w, 4), dtype=np.uint8)
+
+        # Apply color to mask areas (where mask is True)
+        for i in range(3):  # RGB channels
+            colored_mask[:, :, i] = (mask * color[i] * 255).astype(np.uint8)
+
+        # Set alpha channel
+        colored_mask[:, :, 3] = (mask * color[3] * 255).astype(np.uint8)
+
+        if convert_to_image:
+            return Image.fromarray(colored_mask, "RGBA")
+        else:
+            return colored_mask
+    except Exception as e:
+        print(f"Error in show_mask: {e}")
+        import traceback
+        traceback.print_exc()
+
+        # Return a fallback transparent image
+        if convert_to_image:
+            return Image.new('RGBA', (h, w), (0, 0, 0, 0))
+        else:
+            return np.zeros((h, w, 4), dtype=np.uint8)
 
 def update_progress(progress_percent, progress_bar):
     """Update progress bar during processing."""
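The rewritten show_mask builds the RGBA output channel by channel from a matplotlib colormap entry. A minimal sketch of just that coloring step, assuming a 2-D boolean mask (the toy shapes and values here are illustrative):

    import numpy as np
    import matplotlib.pyplot as plt

    mask = np.zeros((4, 4), dtype=bool)
    mask[1:3, 1:3] = True
    # RGB from the tab10 colormap plus a fixed 0.6 alpha, all in [0, 1]
    color = np.array([*plt.get_cmap("tab10")(0)[:3], 0.6])
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    for i in range(4):  # R, G, B, A channels
        rgba[:, :, i] = (mask * color[i] * 255).astype(np.uint8)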
@@ -398,17 +447,11 @@ def propagate_to_all(
    video_segments = {}  # video_segments contains the per-frame segmentation results
    print("Starting propagate_in_video on CPU")
 
-    progress
-
-    # Get the count for progress reporting
-    all_frames_count = 0
-    for _ in predictor.propagate_in_video(session_state["inference_state"], count_only=True):
-        all_frames_count += 1
-
-    print(f"Total frames to process: {all_frames_count}")
-    progress.tqdm.total = all_frames_count
+    # Get the count for progress reporting (estimate)
+    all_frames_count = 300  # Reasonable estimate
 
    # Now do the actual processing with progress updates
+    current_frame = 0
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        session_state["inference_state"]
    ):
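This hunk drops the extra counting pass and instead calls the progress object directly with a fraction and a desc keyword, which matches Gradio's gr.Progress API. A hedged sketch of that pattern in isolation (the function body and loop are illustrative, not the app's actual processing):

    import gradio as gr

    def process(video_path, progress=gr.Progress()):
        total_estimate = 300  # fixed estimate, as in the commit
        for i in range(total_estimate):
            # ... one unit of work per frame ...
            progress(min(1.0, i / total_estimate), desc="Processing frames")
        return "done"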
@@ -420,12 +463,13 @@ def propagate_to_all(
        }
 
        # Update progress
-
-        progress_percent = min(
+        current_frame += 1
+        progress_percent = min(50, int((current_frame / all_frames_count) * 50))
        session_state["progress"] = progress_percent
+        progress(progress_percent/100, desc="Processing frames")
 
        if out_frame_idx % 10 == 0:
-            print(f"Processed frame {out_frame_idx}
+            print(f"Processed frame {out_frame_idx} ({progress_percent}%)")
 
        # Release memory periodically
        if out_frame_idx % chunk_size == 0:
@@ -445,10 +489,8 @@ def propagate_to_all(
    print(f"Total frames processed: {total_frames}")
 
    # Update progress to show rendering phase
-    progress.tqdm.reset()
-    progress.tqdm.total = 2  # Two phases: rendering and video creation
-    progress.tqdm.update(1)
    session_state["progress"] = 50
+    progress(0.5, desc="Rendering video")
 
    # Limit to max 50 frames for CPU processing
    max_output_frames = 50
@@ -464,12 +506,12 @@ def propagate_to_all(
 
    # Create output frames
    output_frames = []
-    progress.tqdm.reset()
-    progress.tqdm.total = (total_frames // vis_frame_stride) + 1
 
-
+    frame_indices = list(range(0, total_frames, vis_frame_stride))
+    total_output_frames = len(frame_indices)
+
+    for i, out_frame_idx in enumerate(frame_indices):
        if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
-            progress.tqdm.update(1)
            continue
 
        try:
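The new loop precomputes the output frame indices so rendering progress can be derived from the loop counter rather than from tqdm state. A worked example of the index and progress arithmetic with made-up numbers:

    total_frames = 120
    vis_frame_stride = 3
    frame_indices = list(range(0, total_frames, vis_frame_stride))  # 0, 3, 6, ..., 117
    total_output_frames = len(frame_indices)                        # 40
    # For output item i, rendering progress spans 50% -> 100%:
    # progress_percent = 50 + min(50, int((i / total_output_frames) * 50))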
@@ -481,41 +523,50 @@ def propagate_to_all(
            frame_idx = out_frame_idx
 
            frame = session_state["all_frames"][frame_idx]
-            transparent_background = Image.fromarray(frame).convert("RGBA")
 
-            #
+            # Create a colored overlay rather than using transparency
+            # Get the mask
            out_mask = video_segments[out_frame_idx][OBJ_ID]
 
-            # Ensure the mask is not empty and has
-            if out_mask.size == 0:
-                print(f"Warning:
-                #
-
+            # Ensure the mask is not empty and has valid dimensions
+            if out_mask.size == 0 or 0 in out_mask.shape:
+                print(f"Warning: Invalid mask for frame {out_frame_idx}")
+                # Skip this frame
+                continue
 
-            #
+            # Get dimensions
+            frame_h, frame_w = frame.shape[:2]
            mask_h, mask_w = out_mask.shape[:2]
-            if mask_h != h or mask_w != w:
-                print(f"Resizing mask from {mask_h}x{mask_w} to {h}x{w}")
-                out_mask = cv2.resize(
-                    out_mask.astype(np.uint8),
-                    (w, h),
-                    interpolation=cv2.INTER_NEAREST
-                ).astype(bool)
 
-
+            # Resize mask using PIL if dimensions don't match (avoid OpenCV)
+            if mask_h != frame_h or mask_w != frame_w:
+                print(f"Resizing mask from {mask_h}x{mask_w} to {frame_h}x{frame_w}")
+                try:
+                    mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
+                    mask_img = mask_img.resize((frame_w, frame_h), Image.NEAREST)
+                    out_mask = np.array(mask_img) > 0
+                except Exception as e:
+                    print(f"Error resizing mask: {e}")
+                    # Skip this frame if resize fails
+                    continue
+
+            # Create an overlay with semi-transparent color
+            overlay = np.zeros_like(frame)
 
-            #
-
-            mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
+            # Set blue color for mask area
+            overlay[out_mask] = [0, 120, 255]  # BGR format for OpenCV
 
-
-
+            # Blend with original frame
+            alpha = 0.5
+            output_frame = cv2.addWeighted(frame, 1, overlay, alpha, 0)
+
+            # Add to output frames
            output_frames.append(output_frame)
 
            # Update progress
-
-            progress_percent = 50 + min(50, int((len(output_frames) / (total_frames // vis_frame_stride)) * 50))
+            progress_percent = 50 + min(50, int((i / total_output_frames) * 50))
            session_state["progress"] = progress_percent
+            progress(progress_percent/100, desc=f"Rendering video frames ({i}/{total_output_frames})")
 
            # Clear memory periodically
            if len(output_frames) % 10 == 0:
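Both rewritten code paths now resize boolean masks through PIL with nearest-neighbor interpolation instead of cv2.resize. A self-contained sketch of that round trip, with illustrative array shapes (note PIL's resize takes a (width, height) tuple, the reverse of NumPy's (h, w) shape order):

    import numpy as np
    from PIL import Image

    mask = np.zeros((240, 320), dtype=bool)        # (h, w) boolean mask
    mask[60:180, 80:240] = True
    target_w, target_h = 640, 480
    mask_img = Image.fromarray(mask.astype(np.uint8) * 255)
    mask_img = mask_img.resize((target_w, target_h), Image.NEAREST)
    resized = np.array(mask_img) > 0               # back to boolean, shape (480, 640)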