import copy
import gc
import os
import tempfile
import traceback
from datetime import datetime

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor

# The cuDNN scaled-dot-product-attention path only applies to CUDA; drop the
# flag so it cannot interfere with this CPU-only app.
if "TORCH_CUDNN_SDPA_ENABLED" in os.environ:
    del os.environ["TORCH_CUDNN_SDPA_ENABLED"]

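# UI text and example videos shown in the Gradio interface.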
title = "<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"

description_p = """# Instructions
<ol>
<li>Upload a video or pick one of the example videos</li>
<li>With the 'include' point type selected, click the object to segment and track</li>
<li>Optionally switch to the 'exclude' point type and click areas to keep out of the mask</li>
<li>Click the 'Track' button to obtain the masked video</li>
</ol>
"""

examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
]

OBJ_ID = 0

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"


def check_file_exists(filepath):
    exists = os.path.exists(filepath)
    if not exists:
        print(f"WARNING: File not found: {filepath}")
    return exists


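# Build the EdgeTAM video predictor on CPU. If loading fails, the app still
# starts, but with predictor left as None so the failure shows up in the logs.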
model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
try:
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    print("Predictor loaded on CPU")
except Exception as e:
    print(f"Error loading model: {e}")
    traceback.print_exc()
    predictor = None


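# Read the frame rate of the input video, falling back to 30 fps when the
# file cannot be opened.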
def get_video_fps(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return 30.0
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    # Some containers report 0 fps; fall back to a sane default.
    if not fps or fps <= 0:
        return 30.0
    return fps


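# Clear all prompts, cached frames, and predictor state, and restore the UI
# to its initial layout.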
def reset(session_state):
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["inference_state"] = None
    return (
        None,
        gr.update(open=True),
        None,
        None,
        gr.update(value=None, visible=False),
        session_state,
    )


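# Drop the click prompts (and any in-progress tracking state) while keeping
# the loaded video and its first frame.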
def clear_points(session_state):
    session_state["input_points"] = []
    session_state["input_labels"] = []
    inference_state = session_state["inference_state"]
    if inference_state is not None and inference_state.get("tracking_has_started", False):
        predictor.reset_state(inference_state)
    return (
        session_state["first_frame"],
        None,
        gr.update(value=None, visible=False),
        session_state,
    )


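# Load the uploaded video, downscale and subsample its frames for
# CPU-friendly visualization, and initialize the predictor state.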
def preprocess_video_in(video_path, session_state):
    if video_path is None:
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            session_state,
        )

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            session_state,
        )

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Remember the original dimensions before rescaling (and before the
    # capture is released, after which these properties read back as 0).
    original_dimensions = (frame_width, frame_height)

    # Downscale wide videos so CPU inference and visualization stay manageable.
    target_width = 640
    scale_factor = 1.0
    if frame_width > target_width:
        scale_factor = target_width / frame_width
        frame_width = target_width
        frame_height = int(frame_height * scale_factor)

    # Keep at most ~300 frames by striding through long videos.
    frame_number = 0
    first_frame = None
    all_frames = []
    frame_stride = max(1, total_frames // 300)

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_number % frame_stride == 0:
            if scale_factor != 1.0:
                frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if first_frame is None:
                first_frame = frame
            all_frames.append(frame)
        frame_number += 1

    cap.release()
    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["frame_stride"] = frame_stride
    session_state["scale_factor"] = scale_factor
    session_state["original_dimensions"] = original_dimensions

    session_state["inference_state"] = predictor.init_state(video_path=video_path)
    session_state["input_points"] = []
    session_state["input_labels"] = []

    return [
        gr.update(open=False),
        first_frame,
        None,
        gr.update(value=None, visible=False),
        session_state,
    ]


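# Record a clicked point as an include/exclude prompt, draw the prompt
# markers, and preview the resulting mask on the first frame.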
def segment_with_points(point_type, session_state, evt: gr.SelectData):
    session_state["input_points"].append(evt.index)
    print(f"TRACKING INPUT POINT: {session_state['input_points']}")

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)
    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

    first_frame = session_state["first_frame"]
    h, w = first_frame.shape[:2]
    transparent_background = Image.fromarray(first_frame).convert("RGBA")

    # Marker radius scales with the frame size.
    fraction = 0.01
    radius = int(fraction * min(w, h))
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)

    for index, track in enumerate(session_state["input_points"]):
        color = (0, 255, 0, 255) if session_state["input_labels"][index] == 1 else (255, 0, 0, 255)
        cv2.circle(transparent_layer, track, radius, color, -1)

    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)

    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)

    try:
        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,
            obj_id=OBJ_ID,
            points=points,
            labels=labels,
        )
        # Squeeze away the leading singleton channel so the mask is (H, W)
        # before any resize.
        mask_array = np.squeeze((out_mask_logits[0] > 0.0).cpu().numpy())

        if mask_array.shape[:2] != (h, w):
            mask_array = cv2.resize(mask_array.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)

        mask_image = show_mask(mask_array)
        if mask_image.size != transparent_background.size:
            mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)

        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
    except Exception as e:
        print(f"Error in segmentation: {e}")
        first_frame_output = selected_point_map

    return selected_point_map, first_frame_output, session_state


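# Convert a boolean mask into a semi-transparent RGBA overlay, returned as a
# PIL image by default or as a raw numpy array otherwise.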
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    h, w = mask.shape[-2:] if len(mask.shape) > 2 else mask.shape
    mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    mask_rgba = (mask_reshaped * 255).astype(np.uint8)

    if convert_to_image:
        try:
            if mask_rgba.shape[2] != 4:
                proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
                proper_mask[:, :, : min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, : min(mask_rgba.shape[2], 4)]
                mask_rgba = proper_mask
            return Image.fromarray(mask_rgba, "RGBA")
        except Exception as e:
            print(f"Error converting mask to image: {e}")
            return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")

    return mask_rgba


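# Propagate the segmentation through the video, render a subsampled set of
# masked frames, and encode them into a temporary MP4 file.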
def propagate_to_all(video_in, session_state, progress=gr.Progress()):
    if len(session_state["input_points"]) == 0 or video_in is None or session_state["inference_state"] is None:
        return gr.update(value=None, visible=False), session_state

    chunk_size = 3
    try:
        video_segments = {}
        total_frames = len(session_state["all_frames"])
        progress(0, desc="Propagating segmentation through video...")

        for i, (out_frame_idx, out_obj_ids, out_mask_logits) in enumerate(
            predictor.propagate_in_video(session_state["inference_state"])
        ):
            try:
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[obj_idx] > 0.0).cpu().numpy()
                    for obj_idx, out_obj_id in enumerate(out_obj_ids)
                }
                progress(min(1.0, (i + 1) / total_frames), desc=f"Processed frame {out_frame_idx}/{total_frames}")
                # Periodically drop the logits and collect garbage to keep CPU memory in check.
                if out_frame_idx % chunk_size == 0:
                    del out_mask_logits
                    gc.collect()
            except Exception as e:
                print(f"Error processing frame {out_frame_idx}: {e}")
                continue

        # Render at most max_output_frames frames into the output video.
        max_output_frames = 50
        vis_frame_stride = max(1, total_frames // max_output_frames)
        first_frame = session_state["all_frames"][0]
        h, w = first_frame.shape[:2]
        output_frames = []

        for out_frame_idx in range(0, total_frames, vis_frame_stride):
            if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
                continue
            try:
                frame = session_state["all_frames"][out_frame_idx]
                transparent_background = Image.fromarray(frame).convert("RGBA")
                # Squeeze away the leading singleton channel so the mask is (H, W).
                out_mask = np.squeeze(video_segments[out_frame_idx][OBJ_ID])

                if out_mask.shape[:2] != (h, w):
                    if out_mask.size == 0:
                        print(f"Skipping empty mask for frame {out_frame_idx}")
                        continue
                    out_mask = cv2.resize(out_mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)

                mask_image = show_mask(out_mask)
                if mask_image.size != transparent_background.size:
                    mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)

                output_frame = Image.alpha_composite(transparent_background, mask_image)
                output_frames.append(np.array(output_frame))

                if len(output_frames) % 10 == 0:
                    gc.collect()
            except Exception as e:
                print(f"Error creating output frame {out_frame_idx}: {e}")
                traceback.print_exc()
                continue

        if not output_frames:
            print("No output frames were produced.")
            return gr.update(value=None, visible=False), session_state

        original_fps = get_video_fps(video_in)
        fps = min(original_fps, 15)  # Cap at 15 FPS for CPU.

        clip = ImageSequenceClip(output_frames, fps=fps)
        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
        final_vid_output_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")

        clip.write_videofile(
            final_vid_output_path,
            codec="libx264",
            bitrate="800k",
            threads=2,
            logger=None,
        )

        del video_segments, output_frames
        gc.collect()

        return gr.update(value=final_vid_output_path, visible=True), session_state

    except Exception as e:
        print(f"Error in propagate_to_all: {e}")
        return gr.update(value=None, visible=False), session_state


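# Make the output video component visible before tracking starts.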
def update_ui():
    return gr.update(visible=True)


# Gradio Interface
with gr.Blocks() as demo:
    session_state = gr.State({
        "first_frame": None,
        "all_frames": None,
        "input_points": [],
        "input_labels": [],
        "inference_state": None,
        "frame_stride": 1,
        "scale_factor": 1.0,
        "original_dimensions": None,
    })

    with gr.Column():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_p)
                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")
                with gr.Row():
                    point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include", scale=2)
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)
                points_map = gr.Image(label="Frame with Point Prompt", type="numpy", interactive=False)
            with gr.Column():
                gr.Markdown("Try the example videos below:")
                gr.Examples(examples=examples, inputs=[video_in], examples_per_page=5)
                output_image = gr.Image(label="Reference Mask")
                output_video = gr.Video(visible=False)

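    # Wire the UI events to their handlers.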
    video_in.upload(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[video_in_drawer, points_map, output_image, output_video, session_state],
        queue=False,
    )

    video_in.change(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[video_in_drawer, points_map, output_image, output_video, session_state],
        queue=False,
    )

    points_map.select(
        fn=segment_with_points,
        inputs=[point_type, session_state],
        outputs=[points_map, output_image, session_state],
        queue=False,
    )

    clear_points_btn.click(
        fn=clear_points,
        inputs=session_state,
        outputs=[points_map, output_image, output_video, session_state],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=session_state,
        outputs=[video_in, video_in_drawer, points_map, output_image, output_video, session_state],
        queue=False,
    )

    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=output_video,
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[video_in, session_state],
        outputs=[output_video, session_state],
        queue=True,
    )

demo.queue()
demo.launch()