import copy
import os
import time
from datetime import datetime
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import torch

from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

# TORCH_CUDNN_SDPA_ENABLED only affects the CUDA cuDNN attention backend; clear it
# (if set) so this CPU-only run does not pick it up.
if "TORCH_CUDNN_SDPA_ENABLED" in os.environ:
    del os.environ["TORCH_CUDNN_SDPA_ENABLED"]


title = "<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"

description_p = """# Instructions
<ol>
<li> Upload a video or click one of the example videos</li>
<li> With the 'include' point type selected, click the object to segment and track</li>
<li> Optionally switch to the 'exclude' point type and click areas you want to keep out of the mask</li>
<li> Click the 'Track' button to obtain the masked video</li>
</ol>
"""

examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
]

# Single-object tracking: every prompt and mask in this demo uses this one object id.
OBJ_ID = 0

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
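# Assumed layout: the EdgeTAM checkpoint lives under ./checkpoints and the config
# name is resolvable by build_sam2_video_predictor; adjust these if your setup differs.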


def check_file_exists(filepath):
    """Return True if the file exists; warn on the console otherwise."""
    exists = os.path.exists(filepath)
    if not exists:
        print(f"WARNING: File not found: {filepath}")
    return exists


model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
predictor = None
try:
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    print("predictor loaded on CPU")
except Exception as e:
    print(f"Error loading model: {e}")
    import traceback

    traceback.print_exc()
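# If loading fails, `predictor` remains None; the tracking callbacks below check for
# this before calling into the model rather than crashing the app.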


def get_video_fps(video_path):
    """Return the FPS of a video, falling back to 30.0 if it cannot be read."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return 30.0
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps


def reset(session_state):
    """Reset all session state variables and UI elements."""
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["inference_state"] = None
    session_state["progress"] = 0
    return (
        None,
        gr.update(open=True),
        None,
        None,
        gr.update(value=None, visible=False),
        gr.update(value=0, visible=False),
        session_state,
    )


def clear_points(session_state):
    """Clear tracking points while keeping the video frames."""
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None and session_state["inference_state"].get(
        "tracking_has_started", False
    ):
        predictor.reset_state(session_state["inference_state"])
    return (
        session_state["first_frame"],
        None,
        gr.update(value=None, visible=False),
        gr.update(value=0, visible=False),
        session_state,
    )


def preprocess_video_in(video_path, session_state):
    """Process the input video to extract frames for tracking."""
    if video_path is None or not os.path.exists(video_path):
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video at {video_path}.")
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames, {fps} FPS")

    # Downscale wide videos to keep CPU processing tractable.
    target_width = 640
    scale_factor = 1.0
    if frame_width > target_width:
        scale_factor = target_width / frame_width
        new_width = int(frame_width * scale_factor)
        new_height = int(frame_height * scale_factor)
        print(
            f"Resizing video for CPU processing: {frame_width}x{frame_height} -> {new_width}x{new_height}"
        )

    # Subsample long videos so roughly max_frames preview frames are kept.
    frame_stride = 1
    max_frames = 150
    if total_frames > max_frames:
        frame_stride = max(1, int(total_frames / max_frames))
        print(f"Video has {total_frames} frames, using stride of {frame_stride} to limit to {max_frames}")
    frame_number = 0
    first_frame = None
    all_frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        if frame_number % frame_stride == 0:
            try:
                if scale_factor != 1.0:
                    frame = cv2.resize(
                        frame,
                        (int(frame_width * scale_factor), int(frame_height * scale_factor)),
                        interpolation=cv2.INTER_AREA,
                    )

                # OpenCV decodes frames as BGR; convert to RGB for display and masking.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                if first_frame is None:
                    first_frame = frame
                all_frames.append(frame)
            except Exception as e:
                print(f"Error processing frame {frame_number}: {e}")

        frame_number += 1

    cap.release()

    if first_frame is None or len(all_frames) == 0:
        print("Error: No frames could be extracted from the video.")
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )

    print(f"Successfully extracted {len(all_frames)} frames from video")

    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["frame_stride"] = frame_stride
    session_state["scale_factor"] = scale_factor
    session_state["original_dimensions"] = (frame_width, frame_height)
    session_state["progress"] = 0

    try:
        # Note: the predictor is initialized from the original video path, not from
        # the downsampled preview frames extracted above.
        session_state["inference_state"] = predictor.init_state(video_path=video_path)
        session_state["input_points"] = []
        session_state["input_labels"] = []
    except Exception as e:
        print(f"Error initializing inference state: {e}")
        import traceback

        traceback.print_exc()
        session_state["inference_state"] = None

    return [
        gr.update(open=False),
        first_frame,
        None,
        gr.update(value=None, visible=False),
        gr.update(value=0, visible=False),
        session_state,
    ]


def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    """Add a tracking point on the first frame and preview the resulting mask."""
    if session_state["first_frame"] is None:
        print("Error: No frame available for segmentation")
        return None, None, session_state

    session_state["input_points"].append(evt.index)
    print(f"TRACKING INPUT POINT: {session_state['input_points']}")

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)
    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

    first_frame = session_state["first_frame"]
    h, w = first_frame.shape[:2]
    transparent_background = Image.fromarray(first_frame).convert("RGBA")

    # Point marker radius: 1% of the shorter image side, but at least 3 px.
    fraction = 0.01
    radius = max(3, int(fraction * min(w, h)))

    # Draw the clicked points (green = include, red = exclude) on a transparent layer.
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
    for index, track in enumerate(session_state["input_points"]):
        if session_state["input_labels"][index] == 1:
            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)

    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)

    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)

    try:
        if predictor is None:
            raise ValueError("Model predictor is not initialized")
        if session_state["inference_state"] is None:
            raise ValueError("Inference state is not initialized")

        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,
            obj_id=OBJ_ID,
            points=points,
            labels=labels,
        )

        out_mask = (out_mask_logits[0] > 0.0).cpu().numpy()
        # Drop any leading singleton dimension so the mask is a plain (H, W) array.
        out_mask = np.squeeze(out_mask)

        overlay_mask = np.zeros((h, w, 3), dtype=np.uint8)

        if out_mask.shape[0] > 0 and out_mask.shape[1] > 0:
            # The predictor works at its own resolution, so the mask may need resizing.
            if out_mask.shape[:2] != (h, w):
                print(f"Resizing mask from {out_mask.shape[:2]} to {h}x{w}")
                if out_mask.dtype != np.bool_:
                    out_mask = out_mask > 0
                mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
                mask_img = mask_img.resize((w, h), Image.NEAREST)
                out_mask = np.array(mask_img) > 0

            # Blue-ish overlay wherever the mask is positive.
            overlay_mask[out_mask] = [0, 120, 255]

        alpha = 0.5
        frame_with_mask = cv2.addWeighted(first_frame, 1, overlay_mask, alpha, 0)

        # Re-draw the clicked points on top of the masked frame.
        points_overlay = np.zeros((h, w, 4), dtype=np.uint8)
        for index, track in enumerate(session_state["input_points"]):
            if session_state["input_labels"][index] == 1:
                cv2.circle(points_overlay, track, radius, (0, 255, 0, 255), -1)
            else:
                cv2.circle(points_overlay, track, radius, (255, 0, 0, 255), -1)

        frame_with_mask_pil = Image.fromarray(frame_with_mask)
        points_overlay_pil = Image.fromarray(points_overlay, "RGBA")

        first_frame_output = Image.alpha_composite(
            frame_with_mask_pil.convert("RGBA"), points_overlay_pil
        )
    except Exception as e:
        print(f"Error in segmentation: {e}")
        import traceback

        traceback.print_exc()
        first_frame_output = selected_point_map

    return selected_point_map, np.array(first_frame_output), session_state


def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Convert a binary mask to an RGBA image (or array) for visualization."""
    if mask is None or mask.size == 0:
        print("Warning: Empty mask provided to show_mask")
        if convert_to_image:
            return Image.new("RGBA", (100, 100), (0, 0, 0, 0))
        return np.zeros((100, 100, 4), dtype=np.uint8)

    if len(mask.shape) == 2:
        h, w = mask.shape
    else:
        h, w = mask.shape[-2:]

    if h == 0 or w == 0:
        print(f"Warning: Invalid mask dimensions: {h}x{w}")
        if convert_to_image:
            return Image.new("RGBA", (100, 100), (0, 0, 0, 0))
        return np.zeros((100, 100, 4), dtype=np.uint8)

    # Pick a display color: random, or a stable per-object color from the tab10 colormap.
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    try:
        # Collapse any leading singleton dimensions so the mask broadcasts as 2-D.
        mask = mask.reshape(h, w)
        colored_mask = np.zeros((h, w, 4), dtype=np.uint8)
        for i in range(3):
            colored_mask[:, :, i] = (mask * color[i] * 255).astype(np.uint8)
        colored_mask[:, :, 3] = (mask * color[3] * 255).astype(np.uint8)

        if convert_to_image:
            return Image.fromarray(colored_mask, "RGBA")
        return colored_mask
    except Exception as e:
        print(f"Error in show_mask: {e}")
        import traceback

        traceback.print_exc()
        if convert_to_image:
            return Image.new("RGBA", (w, h), (0, 0, 0, 0))
        return np.zeros((h, w, 4), dtype=np.uint8)


def update_progress(progress_percent, progress_bar):
    """Update the progress bar during processing."""
    return gr.update(value=progress_percent, visible=True)


def propagate_to_all(
    video_in,
    session_state,
    progress=gr.Progress(),
):
    """Propagate the mask through the video and render the masked output, reporting progress."""
    if (
        len(session_state["input_points"]) == 0
        or video_in is None
        or session_state["inference_state"] is None
        or predictor is None
    ):
        print("Missing required data for tracking")
        return (
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )

    # Run garbage collection every `chunk_size` tracked frames to keep CPU memory bounded.
    chunk_size = 3

    try:
        video_segments = {}
        print("Starting propagate_in_video on CPU")

        # Frame-count estimate used only for the progress bar; the SAM2-style inference
        # state is expected to expose "num_frames", otherwise fall back to 100.
        all_frames_count = session_state["inference_state"].get("num_frames", 100)

        current_frame = 0
        # First pass: run the tracker over the whole video and store per-frame masks.
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            session_state["inference_state"]
        ):
            try:
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }

                # Propagation accounts for the first half of the progress bar.
                current_frame += 1
                progress_percent = min(50, int((current_frame / all_frames_count) * 50))
                session_state["progress"] = progress_percent
                progress(progress_percent / 100, desc="Processing frames")

                if out_frame_idx % 10 == 0:
                    print(f"Processed frame {out_frame_idx} ({progress_percent}%)")

                # Periodically free memory while tracking on CPU.
                if out_frame_idx % chunk_size == 0:
                    del out_mask_logits
                    import gc

                    gc.collect()
            except Exception as e:
                print(f"Error processing frame {out_frame_idx}: {e}")
                import traceback

                traceback.print_exc()
                continue

        total_frames = len(video_segments)
        print(f"Total frames processed: {total_frames}")

        session_state["progress"] = 50
        progress(0.5, desc="Rendering video")

        # Second pass: render at most roughly `max_output_frames` frames into the output video.
        max_output_frames = 30
        vis_frame_stride = max(1, total_frames // max_output_frames)
        print(f"Using stride of {vis_frame_stride} for output video generation")
        if len(session_state["all_frames"]) == 0:
            raise ValueError("No frames available in session state")

        first_frame = session_state["all_frames"][0]
        h, w = first_frame.shape[:2]

        output_frames = []
        frame_indices = list(range(0, total_frames, vis_frame_stride))
        total_output_frames = len(frame_indices)

        for i, out_frame_idx in enumerate(frame_indices):
            if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
                continue

            try:
                # The preview frames may be fewer than the tracked frames; clamp the index.
                if out_frame_idx >= len(session_state["all_frames"]):
                    print(
                        f"Warning: Frame index {out_frame_idx} exceeds available frames {len(session_state['all_frames'])}"
                    )
                    frame_idx = min(out_frame_idx, len(session_state["all_frames"]) - 1)
                else:
                    frame_idx = out_frame_idx

                frame = session_state["all_frames"][frame_idx]

                out_mask = video_segments[out_frame_idx][OBJ_ID]
                # Drop any leading singleton dimension so the mask is (H, W).
                out_mask = np.squeeze(out_mask)

                if out_mask.size == 0 or 0 in out_mask.shape:
                    print(f"Warning: Invalid mask for frame {out_frame_idx}")
                    continue

                frame_h, frame_w = frame.shape[:2]
                mask_h, mask_w = out_mask.shape[:2]

                # Resize the mask to the (possibly downscaled) preview frame if needed.
                if mask_h != frame_h or mask_w != frame_w:
                    print(f"Resizing mask from {mask_h}x{mask_w} to {frame_h}x{frame_w}")
                    try:
                        if out_mask.dtype != np.bool_:
                            out_mask = out_mask > 0
                        mask_img = Image.fromarray(out_mask.astype(np.uint8) * 255)
                        mask_img = mask_img.resize((frame_w, frame_h), Image.NEAREST)
                        out_mask = np.array(mask_img) > 0
                    except Exception as e:
                        print(f"Error resizing mask: {e}")
                        continue

                # Blend a blue-ish overlay onto the frame wherever the mask is positive.
                overlay = np.zeros_like(frame)
                overlay[out_mask] = [0, 120, 255]
                alpha = 0.5
                output_frame = cv2.addWeighted(frame, 1, overlay, alpha, 0)
                output_frames.append(output_frame)

                # Rendering accounts for the second half of the progress bar.
                progress_percent = 50 + min(50, int((i / total_output_frames) * 50))
                session_state["progress"] = progress_percent
                progress(
                    progress_percent / 100,
                    desc=f"Rendering video frames ({i}/{total_output_frames})",
                )

                if len(output_frames) % 10 == 0:
                    import gc

                    gc.collect()

            except Exception as e:
                print(f"Error creating output frame {out_frame_idx}: {e}")
                import traceback

                traceback.print_exc()
                continue
        # Cap the output frame rate to keep encoding fast on CPU.
        original_fps = get_video_fps(video_in)
        fps = min(original_fps, 15)

        print(f"Creating video with {len(output_frames)} frames at {fps} FPS")

        session_state["progress"] = 90

        if len(output_frames) == 0:
            raise ValueError("No output frames were generated")

        # Drop frames whose shape differs from the first frame; ImageSequenceClip
        # requires all frames to share the same dimensions.
        first_shape = output_frames[0].shape
        valid_frames = []
        for i, frame in enumerate(output_frames):
            if frame.shape == first_shape:
                valid_frames.append(frame)
            else:
                print(f"Skipping frame {i} with inconsistent shape: {frame.shape} vs {first_shape}")

        if len(valid_frames) == 0:
            raise ValueError("No valid frames with consistent shape")

        clip = ImageSequenceClip(valid_frames, fps=fps)

        # Write the result to a unique file in the system temp directory.
        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
        final_vid_output_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")

        clip.write_videofile(
            final_vid_output_path,
            codec="libx264",
            bitrate="800k",
            threads=2,
            logger=None,
        )

        session_state["progress"] = 100

        # Free the per-frame masks and rendered frames before returning.
        del video_segments
        del output_frames
        import gc

        gc.collect()

        return (
            gr.update(value=final_vid_output_path, visible=True),
            gr.update(value=100, visible=False),
            session_state,
        )

    except Exception as e:
        print(f"Error in propagate_to_all: {e}")
        import traceback

        traceback.print_exc()
        return (
            gr.update(value=None, visible=False),
            gr.update(value=0, visible=False),
            session_state,
        )


def update_ui():
    """Show the progress bar when processing starts."""
    return gr.update(visible=True), gr.update(visible=True, value=0)
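# Gradio UI: per-session state, an input/controls column (video, point type, action
# buttons, prompt preview, progress bar), and an examples/output column.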


with gr.Blocks() as demo:
    session_state = gr.State(
        {
            "first_frame": None,
            "all_frames": None,
            "input_points": [],
            "input_labels": [],
            "inference_state": None,
            "frame_stride": 1,
            "scale_factor": 1.0,
            "original_dimensions": None,
            "progress": 0,
        }
    )

    with gr.Column():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )

                progress_bar = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=1,
                    label="Processing Progress",
                    visible=False,
                    interactive=False,
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[
                        video_in,
                    ],
                    examples_per_page=5,
                )

                output_image = gr.Image(label="Reference Mask")
                output_video = gr.Video(visible=False)

    # Preprocess the video on upload and whenever the value changes (e.g. when an
    # example is selected).
    video_in.upload(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            progress_bar,
            session_state,
        ],
        queue=False,
    )

    video_in.change(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            progress_bar,
            session_state,
        ],
        queue=False,
    )