import copy
import os
import tempfile
from datetime import datetime

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor
|
|
title = "<center><strong><font size='8'>EdgeTAM</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"

description_p = """# Instructions
<ol>
<li> Upload a video or pick one of the example videos</li>
<li> With the 'include' point type selected, click the object you want to segment and track</li>
<li> Optionally switch to the 'exclude' point type and click areas that should not be segmented or tracked</li>
<li> Click the 'Track' button to obtain the masked video</li>
</ol>
"""
|
examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
    ["examples/01_breakdancer.mp4"],
    ["examples/02_hummingbird.mp4"],
    ["examples/03_skateboarder.mp4"],
    ["examples/04_octopus.mp4"],
    ["examples/05_landing_dog_soccer.mp4"],
    ["examples/06_pingpong.mp4"],
    ["examples/07_snowboarder.mp4"],
    ["examples/08_driving.mp4"],
    ["examples/09_birdcartoon.mp4"],
    ["examples/10_cloth_magic.mp4"],
    ["examples/11_polevault.mp4"],
    ["examples/12_hideandseek.mp4"],
    ["examples/13_butterfly.mp4"],
    ["examples/14_social_dog_training.mp4"],
    ["examples/15_cricket.mp4"],
    ["examples/16_robotarm.mp4"],
    ["examples/17_childrendancing.mp4"],
    ["examples/18_threedogs.mp4"],
    ["examples/19_cyclist.mp4"],
    ["examples/20_doughkneading.mp4"],
    ["examples/21_biker.mp4"],
    ["examples/22_dogskateboarder.mp4"],
    ["examples/23_racecar.mp4"],
    ["examples/24_clownfish.mp4"],
]
|
|
|
OBJ_ID = 0 |
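# OBJ_ID identifies the single object tracked in this demo; it is passed to every predictor call.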
|
|
|
sam2_checkpoint = "checkpoints/edgetam.pt" |
|
model_cfg = "edgetam.yaml" |
|
|
|
|
|
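# Build the EdgeTAM video predictor once at startup; the whole demo runs on CPU.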
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
predictor.to("cpu")
print("predictor loaded on CPU")
|
|
def get_video_fps(video_path): |
|
"""Gets the frames per second of a video file.""" |
|
if video_path is None or not os.path.exists(video_path): |
|
print(f"Warning: Video file not found at {video_path}") |
|
return None |
|
cap = cv2.VideoCapture(video_path) |
|
if not cap.isOpened(): |
|
print(f"Error: Could not open video file {video_path}.") |
|
return None |
|
fps = cap.get(cv2.CAP_PROP_FPS) |
|
cap.release() |
|
return fps |
|
|
|
|
|
def preprocess_video_in(video_path, session_state): |
|
"""Loads video frames and initializes the predictor state.""" |
|
print(f"Processing video: {video_path}") |
|
if video_path is None or not os.path.exists(video_path): |
|
print("No video path provided or file not found.") |
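        # Keep the UI in its initial state: accordion open, image panels and output video
        # cleared, buttons disabled, and a fresh session dict (order matches the event outputs).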
|
|
|
return ( |
|
gr.update(open=True), |
|
None, |
|
None, |
|
gr.update(value=None, visible=False), |
|
gr.update(interactive=False), |
|
gr.update(interactive=False), |
|
gr.update(interactive=False), |
|
{ |
|
"first_frame": None, |
|
"all_frames": None, |
|
"input_points": [], |
|
"input_labels": [], |
|
"inference_state": None, |
|
"video_path": None, |
|
} |
|
) |
|
|
|
|
|
cap = cv2.VideoCapture(video_path) |
|
if not cap.isOpened(): |
|
print(f"Error: Could not open video file {video_path}.") |
|
|
|
return ( |
|
gr.update(open=True), |
|
None, |
|
None, |
|
gr.update(value=None, visible=False), |
|
gr.update(interactive=False), |
|
gr.update(interactive=False), |
|
gr.update(interactive=False), |
|
{ |
|
"first_frame": None, |
|
"all_frames": None, |
|
"input_points": [], |
|
"input_labels": [], |
|
"inference_state": None, |
|
"video_path": None, |
|
} |
|
) |
|
|
|
first_frame = None |
|
all_frames = [] |
|
|
|
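    # Read every frame, converting from OpenCV's BGR to RGB for display and compositing.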
while True: |
|
ret, frame = cap.read() |
|
if not ret: |
|
break |
|
|
|
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
|
all_frames.append(frame) |
|
if first_frame is None: |
|
first_frame = frame |
|
|
|
cap.release() |
|
|
|
if not all_frames: |
|
print(f"Error: No frames read from video file {video_path}.") |
|
|
|
return ( |
|
gr.update(open=True), |
|
None, |
|
None, |
|
gr.update(value=None, visible=False), |
|
gr.update(interactive=False), |
|
gr.update(interactive=False), |
|
gr.update(interactive=False), |
|
{ |
|
"first_frame": None, |
|
"all_frames": None, |
|
"input_points": [], |
|
"input_labels": [], |
|
"inference_state": None, |
|
"video_path": None, |
|
} |
|
) |
|
|
|
|
|
session_state["first_frame"] = copy.deepcopy(first_frame) |
|
session_state["all_frames"] = all_frames |
|
session_state["video_path"] = video_path |
|
session_state["input_points"] = [] |
|
session_state["input_labels"] = [] |
|
|
|
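    # init_state loads and caches the video frames internally, which can take a while on CPU
    # for long or high-resolution videos.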
session_state["inference_state"] = predictor.init_state(video_path=video_path) |
|
print("Video loaded and predictor state initialized.") |
|
|
|
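    # Close the accordion, show the first frame in the points map, and enable the control buttons.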
return [ |
|
gr.update(open=False), |
|
first_frame, |
|
None, |
|
gr.update(value=None, visible=False), |
|
gr.update(interactive=True), |
|
gr.update(interactive=True), |
|
gr.update(interactive=True), |
|
session_state, |
|
] |
|
|
|
|
|
def reset(session_state): |
|
"""Resets the UI and session state.""" |
|
print("Resetting demo.") |
|
|
|
session_state["input_points"] = [] |
|
session_state["input_labels"] = [] |
|
|
|
if session_state["inference_state"] is not None: |
|
predictor.reset_state(session_state["inference_state"]) |
|
|
|
session_state["inference_state"] = None |
|
|
|
session_state["first_frame"] = None |
|
session_state["all_frames"] = None |
|
session_state["video_path"] = None |
|
|
|
|
|
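    # Clear the video input, reopen the accordion, clear both image panels, hide the output
    # video, and disable the control buttons.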
return ( |
|
None, |
|
gr.update(open=True), |
|
None, |
|
None, |
|
gr.update(value=None, visible=False), |
|
gr.update(interactive=False), |
|
gr.update(interactive=False), |
|
gr.update(interactive=False), |
|
session_state, |
|
) |
|
|
|
|
|
def clear_points(session_state): |
|
"""Clears selected points and resets segmentation on the first frame.""" |
|
print("Clearing points.") |
|
|
|
session_state["input_points"] = [] |
|
session_state["input_labels"] = [] |
|
|
|
|
|
|
|
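    # Rather than removing individual prompts from the existing state, drop it and rebuild
    # the predictor state from the original video so new clicks start from a clean slate.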
if session_state["inference_state"] is not None: |
|
predictor.reset_state(session_state["inference_state"]) |
|
|
|
|
|
if session_state["video_path"] is not None: |
|
|
|
session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"]) |
|
print("Predictor state re-initialized after clearing points.") |
|
else: |
|
print("Warning: Could not re-initialize state after clear_points (video_path missing).") |
|
session_state["inference_state"] = None |
|
|
|
|
|
|
|
|
|
first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None |
|
|
|
return ( |
|
first_frame_img, |
|
None, |
|
gr.update(value=None, visible=False), |
|
session_state, |
|
) |
|
|
|
|
|
|
|
def segment_with_points( |
|
point_type, |
|
session_state, |
|
evt: gr.SelectData, |
|
): |
|
"""Adds a point prompt and performs segmentation on the first frame.""" |
|
|
|
if session_state["first_frame"] is None or session_state["inference_state"] is None: |
|
print("Error: Cannot segment. No video loaded or inference state missing.") |
|
|
|
return ( |
|
session_state["first_frame"], |
|
None, |
|
session_state, |
|
) |
|
|
|
|
|
click_coords = evt.index |
|
print(f"Clicked at: {click_coords} ({point_type})") |
|
|
|
session_state["input_points"].append(click_coords) |
|
|
|
if point_type == "include": |
|
session_state["input_labels"].append(1) |
|
elif point_type == "exclude": |
|
session_state["input_labels"].append(0) |
|
|
|
|
|
first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA") |
|
w, h = first_frame_pil.size |
|
|
|
|
|
fraction = 0.01 |
|
radius = max(2, int(fraction * min(w, h))) |
|
|
|
|
|
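    # Draw the clicked points on a transparent RGBA layer: green for 'include', red for 'exclude'.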
transparent_layer_points = np.zeros((h, w, 4), dtype=np.uint8) |
|
|
|
|
|
for index, track in enumerate(session_state["input_points"]): |
|
|
|
point_coords = (int(track[0]), int(track[1])) |
|
if session_state["input_labels"][index] == 1: |
|
|
|
cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1) |
|
else: |
|
|
|
cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1) |
|
|
|
|
|
transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA") |
|
|
|
|
|
selected_point_map_img = Image.alpha_composite( |
|
first_frame_pil.copy(), transparent_layer_points_pil |
|
) |
|
|
|
|
|
points = np.array(session_state["input_points"], dtype=np.float32) |
|
labels = np.array(session_state["input_labels"], np.int32) |
|
|
|
|
|
points_tensor = torch.tensor(points, dtype=torch.float32, device="cpu").unsqueeze(0) |
|
labels_tensor = torch.tensor(labels, dtype=torch.int32, device="cpu").unsqueeze(0) |
|
|
|
|
|
|
|
first_frame_output_img = None |
|
try: |
|
|
|
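        # Register all accumulated point prompts on frame 0 and get updated mask logits for the object.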
_, _, out_mask_logits = predictor.add_new_points( |
|
inference_state=session_state["inference_state"], |
|
frame_idx=0, |
|
obj_id=OBJ_ID, |
|
points=points_tensor, |
|
labels=labels_tensor, |
|
) |
|
|
|
|
|
|
|
mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) |
|
mask_numpy = mask_tensor.numpy() |
|
|
|
|
|
mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) |
|
|
|
|
|
|
|
first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil) |
|
|
|
except Exception as e: |
|
print(f"Error during segmentation on first frame: {e}") |
|
|
|
|
|
|
|
return selected_point_map_img, first_frame_output_img, session_state |
|
|
|
|
|
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True): |
|
"""Helper function to visualize a mask.""" |
|
|
|
if isinstance(mask, torch.Tensor): |
|
mask = mask.detach().cpu().numpy() |
|
|
|
mask = mask.astype(bool) |
|
|
|
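    # Choose an RGBA color: random, or a stable per-object color from the 'tab10' colormap,
    # with 0.6 alpha so the underlying frame stays visible.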
if random_color: |
|
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) |
|
else: |
|
cmap = plt.get_cmap("tab10") |
|
cmap_idx = 0 if obj_id is None else obj_id % 10 |
|
color = np.array([*cmap(cmap_idx)[:3], 0.6]) |
|
|
|
|
|
if mask.ndim == 3: |
|
mask = mask.squeeze() |
|
if mask.ndim != 2: |
|
print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.") |
|
|
|
h, w = mask.shape[:2] if mask.ndim >= 2 else (100, 100) |
|
if convert_to_image: |
|
return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA") |
|
else: |
|
return np.zeros((h, w, 4), dtype=np.uint8) |
|
|
|
h, w = mask.shape |
|
|
|
|
|
|
|
colored_mask = np.zeros((h, w, 4), dtype=np.float32) |
|
|
|
|
|
colored_mask[mask] = color |
|
|
|
|
|
colored_mask_uint8 = (colored_mask * 255).astype(np.uint8) |
|
|
|
if convert_to_image: |
|
mask_img = Image.fromarray(colored_mask_uint8, "RGBA") |
|
return mask_img |
|
else: |
|
return colored_mask_uint8 |
|
|
|
|
|
|
|
def propagate_to_all( |
|
|
|
|
|
video_in, |
|
session_state, |
|
): |
|
"""Runs mask propagation through the video and generates the output video.""" |
|
print("Starting propagation...") |
|
|
|
if ( |
|
len(session_state["input_points"]) == 0 |
|
or session_state["all_frames"] is None |
|
or session_state["inference_state"] is None |
|
or session_state["video_path"] is None |
|
): |
|
print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.") |
|
return ( |
|
gr.update(value=None, visible=False), |
|
session_state, |
|
) |
|
|
|
|
|
|
|
video_segments = {} |
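    # Maps frame_idx -> {obj_id: boolean mask}; filled incrementally as propagation streams results.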
|
try: |
|
|
|
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( |
|
session_state["inference_state"] |
|
): |
|
|
|
|
|
video_segments[out_frame_idx] = { |
|
|
|
|
|
|
|
out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy() |
|
for i, out_obj_id in enumerate(out_obj_ids) |
|
} |
|
|
|
|
|
|
|
print("Propagation finished.") |
|
except Exception as e: |
|
print(f"Error during propagation: {e}") |
|
return ( |
|
gr.update(value=None, visible=False), |
|
session_state, |
|
) |
|
|
|
|
|
output_frames = [] |
|
|
|
total_frames = len(session_state["all_frames"]) |
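    # Composite each predicted mask over its original RGB frame; frames without a mask pass through unchanged.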
|
for out_frame_idx in range(total_frames): |
|
original_frame_rgb = session_state["all_frames"][out_frame_idx] |
|
|
|
transparent_background = Image.fromarray(original_frame_rgb).convert("RGBA") |
|
|
|
|
|
if out_frame_idx in video_segments and OBJ_ID in video_segments[out_frame_idx]: |
|
current_mask_numpy = video_segments[out_frame_idx][OBJ_ID] |
|
|
|
mask_image_pil = show_mask(current_mask_numpy, obj_id=OBJ_ID) |
|
|
|
output_frame_img_rgba = Image.alpha_composite(transparent_background, mask_image_pil) |
|
|
|
output_frame_np = np.array(output_frame_img_rgba.convert("RGB")) |
|
else: |
|
|
|
|
|
|
|
output_frame_np = original_frame_rgb |
|
|
|
output_frames.append(output_frame_np) |
|
|
|
|
|
|
|
|
|
unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f") |
|
final_vid_filename = f"output_video_{unique_id}.mp4" |
|
final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename) |
|
print(f"Output video path: {final_vid_output_path}") |
|
|
|
|
|
|
|
|
|
|
|
original_fps = get_video_fps(session_state["video_path"]) |
|
fps = original_fps if original_fps is not None and original_fps > 0 else 30 |
|
print(f"Creating output video with FPS: {fps}") |
|
|
|
|
|
if not output_frames: |
|
print("No output frames generated.") |
|
return ( |
|
gr.update(value=None, visible=False), |
|
session_state, |
|
) |
|
|
|
|
|
try: |
|
clip = ImageSequenceClip(output_frames, fps=fps) |
|
except Exception as e: |
|
print(f"Error creating ImageSequenceClip: {e}") |
|
return ( |
|
gr.update(value=None, visible=False), |
|
session_state, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
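    # Encode with H.264 (libx264) so the result plays back in the browser video component.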
try: |
|
print(f"Writing video file with codec='libx264', fps={fps}, preset='medium', threads='auto'") |
|
clip.write_videofile( |
|
final_vid_output_path, |
|
codec="libx264", |
|
fps=fps, |
|
preset="medium", |
|
threads="auto", |
|
logger=None |
|
) |
|
print("Video writing complete.") |
|
|
|
return ( |
|
gr.update(value=final_vid_output_path, visible=True), |
|
session_state, |
|
) |
|
except Exception as e: |
|
print(f"Error writing video file: {e}") |
|
|
|
if os.path.exists(final_vid_output_path): |
|
try: |
|
os.remove(final_vid_output_path) |
|
print(f"Removed partial video file: {final_vid_output_path}") |
|
except Exception as clean_e: |
|
print(f"Error removing partial file: {clean_e}") |
|
|
|
|
|
return ( |
|
gr.update(value=None, visible=False), |
|
session_state, |
|
) |
|
|
|
|
|
def update_output_video_visibility(): |
|
"""Simply returns a Gradio update to make the output video visible.""" |
|
return gr.update(visible=True) |
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
|
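    # Per-session state kept server-side by Gradio: cached frames, point prompts, and the
    # predictor's inference state.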
session_state = gr.State( |
|
{ |
|
"first_frame": None, |
|
"all_frames": None, |
|
"input_points": [], |
|
"input_labels": [], |
|
"inference_state": None, |
|
"video_path": None, |
|
} |
|
) |
|
|
|
with gr.Column(): |
|
|
|
gr.Markdown(title) |
|
with gr.Row(): |
|
|
|
with gr.Column(): |
|
|
|
gr.Markdown(description_p) |
|
|
|
with gr.Accordion("Input Video", open=True) as video_in_drawer: |
|
video_in = gr.Video(label="Input Video", format="mp4") |
|
|
|
with gr.Row(): |
|
point_type = gr.Radio( |
|
label="point type", |
|
choices=["include", "exclude"], |
|
value="include", |
|
scale=2, |
|
interactive=True, |
|
) |
|
|
|
propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False) |
|
clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False) |
|
reset_btn = gr.Button("Reset", scale=1, interactive=False) |
|
|
|
|
|
|
|
points_map = gr.Image( |
|
label="Click on the First Frame to Add Points", |
|
type="numpy", |
|
interactive=True, |
|
height=400, |
|
width="auto", |
|
show_share_button=False, |
|
show_download_button=False, |
|
|
|
) |
|
|
|
with gr.Column(): |
|
gr.Markdown("# Try some of the examples below ⬇️") |
|
gr.Examples( |
|
examples=examples, |
|
inputs=[video_in], |
|
examples_per_page=8, |
|
cache_examples=False, |
|
) |
|
|
|
|
|
|
|
|
|
output_image = gr.Image( |
|
label="Segmentation Mask on First Frame", |
|
type="numpy", |
|
interactive=False, |
|
height=400, |
|
width="auto", |
|
show_share_button=False, |
|
show_download_button=False, |
|
|
|
) |
|
|
|
|
|
output_video = gr.Video(visible=False, label="Tracking Result") |
|
|
|
|
|
|
|
|
|
|
|
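    # Load frames and initialize the predictor state when the user uploads a video.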
video_in.upload( |
|
fn=preprocess_video_in, |
|
inputs=[video_in, session_state], |
|
outputs=[ |
|
video_in_drawer, |
|
points_map, |
|
output_image, |
|
output_video, |
|
propagate_btn, |
|
clear_points_btn, |
|
reset_btn, |
|
session_state, |
|
], |
|
queue=False, |
|
) |
|
|
|
|
|
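    # The change event also fires when an example video is selected, so examples go through
    # the same preprocessing.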
video_in.change( |
|
fn=preprocess_video_in, |
|
inputs=[video_in, session_state], |
|
outputs=[ |
|
video_in_drawer, |
|
points_map, |
|
output_image, |
|
output_video, |
|
propagate_btn, |
|
clear_points_btn, |
|
reset_btn, |
|
session_state, |
|
], |
|
queue=False, |
|
) |
|
|
|
|
|
|
|
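    # Each click on the first frame adds a point prompt and refreshes the point overlay and mask preview.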
points_map.select( |
|
fn=segment_with_points, |
|
inputs=[ |
|
point_type, |
|
session_state, |
|
], |
|
outputs=[ |
|
points_map, |
|
output_image, |
|
session_state, |
|
], |
|
queue=False, |
|
) |
|
|
|
|
|
clear_points_btn.click( |
|
fn=clear_points, |
|
inputs=[session_state], |
|
outputs=[ |
|
points_map, |
|
output_image, |
|
output_video, |
|
session_state, |
|
], |
|
queue=False, |
|
) |
|
|
|
|
|
reset_btn.click( |
|
fn=reset, |
|
inputs=[session_state], |
|
outputs=[ |
|
video_in, |
|
video_in_drawer, |
|
points_map, |
|
output_image, |
|
output_video, |
|
propagate_btn, |
|
clear_points_btn, |
|
reset_btn, |
|
session_state, |
|
], |
|
queue=False, |
|
) |
|
|
|
|
|
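    # Reveal the output video component first, then run mask propagation as a queued job (one at a time).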
propagate_btn.click( |
|
fn=update_output_video_visibility, |
|
inputs=[], |
|
outputs=[output_video], |
|
queue=False, |
|
).then( |
|
fn=propagate_to_all, |
|
inputs=[ |
|
video_in, |
|
session_state, |
|
], |
|
outputs=[ |
|
output_video, |
|
session_state, |
|
], |
|
|
|
|
|
concurrency_limit=1, |
|
queue=True, |
|
) |
|
|
|
|
|
|
|
demo.queue() |
|
print("Gradio demo starting...")
demo.launch()