# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
from datetime import datetime

import gradio as gr

# Removed GPU-specific environment variable setting
# os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"

import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np

# Removed spaces decorator import for CPU-only demo
# import spaces
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor

# Description
title = "# EdgeTAM [GitHub]"

description_p = """# Instructions
  1. Upload a video or click one of the example videos.
  2. Select the 'include' point type and click on the object you want to segment and track.
  3. Optionally, select the 'exclude' point type and click on areas that should stay out of the mask.
  4. Click the 'Track' button to obtain the masked video.
""" # examples - Keep examples, they are input files examples = [ ["examples/01_dog.mp4"], ["examples/02_cups.mp4"], ["examples/03_blocks.mp4"], ["examples/04_coffee.mp4"], ["examples/05_default_juggle.mp4"], ["examples/01_breakdancer.mp4"], ["examples/02_hummingbird.mp4"], ["examples/03_skateboarder.mp4"], ["examples/04_octopus.mp4"], ["examples/05_landing_dog_soccer.mp4"], ["examples/06_pingpong.mp4"], ["examples/07_snowboarder.mp4"], ["examples/08_driving.mp4"], ["examples/09_birdcartoon.mp4"], ["examples/10_cloth_magic.mp4"], ["examples/11_polevault.mp4"], ["examples/12_hideandseek.mp4"], ["examples/13_butterfly.mp4"], ["examples/14_social_dog_training.mp4"], ["examples/15_cricket.mp4"], ["examples/16_robotarm.mp4"], ["examples/17_childrendancing.mp4"], ["examples/18_threedogs.mp4"], ["examples/19_cyclist.mp4"], ["examples/20_doughkneading.mp4"], ["examples/21_biker.mp4"], ["examples/22_dogskateboarder.mp4"], ["examples/23_racecar.mp4"], ["examples/24_clownfish.mp4"], ] OBJ_ID = 0 sam2_checkpoint = "checkpoints/edgetam.pt" model_cfg = "edgetam.yaml" # Ensure predictor is explicitly built for CPU # The device is set here and with .to("cpu") predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu") predictor.to("cpu") # Explicitly move to CPU after building print("predictor loaded on CPU") # Removed autocast block for maximum CPU compatibility # torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__() # Removed commented-out GPU-specific code # if torch.cuda.get_device_properties(0).major >= 8: ... def get_video_fps(video_path): """Gets the frames per second of a video file.""" if video_path is None or not os.path.exists(video_path): print(f"Warning: Video file not found at {video_path}") return None cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print(f"Error: Could not open video file {video_path}.") return None fps = cap.get(cv2.CAP_PROP_FPS) cap.release() return fps # Removed @spaces.GPU decorator def preprocess_video_in(video_path, session_state): """Loads video frames and initializes the predictor state.""" print(f"Processing video: {video_path}") if video_path is None or not os.path.exists(video_path): print("No video path provided or file not found.") # Reset state and UI elements if input is invalid return ( gr.update(open=True), # video_in_drawer None, # points_map None, # output_image gr.update(value=None, visible=False), # output_video gr.update(interactive=False), # propagate_btn gr.update(interactive=False), # clear_points_btn gr.update(interactive=False), # reset_btn { # Reset session state "first_frame": None, "all_frames": None, "input_points": [], "input_labels": [], "inference_state": None, "video_path": None, } ) # Read the first frame and all frames cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print(f"Error: Could not open video file {video_path}.") # Reset state and UI elements on error return ( gr.update(open=True), None, None, gr.update(value=None, visible=False), gr.update(interactive=False), # propagate_btn gr.update(interactive=False), # clear_points_btn gr.update(interactive=False), # reset_btn { # Reset session state "first_frame": None, "all_frames": None, "input_points": [], "input_labels": [], "inference_state": None, "video_path": None, } ) first_frame = None all_frames = [] while True: ret, frame = cap.read() if not ret: break # Convert BGR to RGB frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) all_frames.append(frame) if first_frame is None: first_frame = frame # Store the first frame 
    cap.release()

    if not all_frames:
        print(f"Error: No frames read from video file {video_path}.")
        # Reset state and UI elements if no frames are read
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(interactive=False),  # propagate_btn
            gr.update(interactive=False),  # clear_points_btn
            gr.update(interactive=False),  # reset_btn
            {  # Reset session state
                "first_frame": None,
                "all_frames": None,
                "input_points": [],
                "input_labels": [],
                "inference_state": None,
                "video_path": None,
            },
        )

    # Update session state with frames and path
    session_state["first_frame"] = copy.deepcopy(first_frame)  # Store a copy
    session_state["all_frames"] = all_frames
    session_state["video_path"] = video_path  # Store the path
    session_state["input_points"] = []
    session_state["input_labels"] = []
    # Initialize state *without* the device argument
    session_state["inference_state"] = predictor.init_state(video_path=video_path)
    print("Video loaded and predictor state initialized.")

    return [
        gr.update(open=False),  # video_in_drawer
        first_frame,  # points_map (shows first frame)
        None,  # output_image (cleared initially)
        gr.update(value=None, visible=False),  # output_video (hidden initially)
        gr.update(interactive=True),  # Enable buttons
        gr.update(interactive=True),  # Enable buttons
        gr.update(interactive=True),  # Enable buttons
        session_state,  # Updated state
    ]


def reset(session_state):
    """Resets the UI and session state."""
    print("Resetting demo.")
    # Clear points and labels
    session_state["input_points"] = []
    session_state["input_labels"] = []

    # Reset the predictor state if it exists
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
        # After reset, we also discard the state object as a new video might be loaded
        session_state["inference_state"] = None

    # Clear frames and video path
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["video_path"] = None

    # Update UI elements to their initial state
    return (
        None,  # video_in
        gr.update(open=True),  # video_in_drawer open
        None,  # points_map cleared
        None,  # output_image cleared
        gr.update(value=None, visible=False),  # output_video hidden
        gr.update(interactive=False),  # Disable buttons
        gr.update(interactive=False),  # Disable buttons
        gr.update(interactive=False),  # Disable buttons
        session_state,  # Updated session state
    )


def clear_points(session_state):
    """Clears selected points and resets segmentation on the first frame."""
    print("Clearing points.")
    # Clear points and labels lists
    session_state["input_points"] = []
    session_state["input_labels"] = []

    # Reset the predictor state if it exists. This clears internal masks/features
    # but keeps the video context initialized by preprocess_video_in.
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])

    # After resetting the state, if we still have the video path, re-initialize the state
    # to be ready for new points on the same video.
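    # If re-initialization is not possible (e.g. the uploaded file is gone), the state is
    # dropped entirely, so segment_with_points will refuse new clicks until a video is
    # loaded again.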
    if session_state["video_path"] is not None:
        # Re-initialize state *without* the device argument
        session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"])
        print("Predictor state re-initialized after clearing points.")
    else:
        print("Warning: Could not re-initialize state after clear_points (video_path missing).")
        session_state["inference_state"] = None  # Ensure state is None if video_path is gone

    # Re-render the points_map with no points drawn (just the first frame)
    # Re-render the output_image with no mask (just the first frame)
    first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None

    return (
        first_frame_img,  # points_map shows original first frame
        None,  # output_image cleared
        gr.update(value=None, visible=False),  # Hide output video
        session_state,  # Updated session state
    )


# Removed @spaces.GPU decorator
def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    """Adds a point prompt and performs segmentation on the first frame."""
    # Ensure we have a valid first frame and inference state
    if session_state["first_frame"] is None or session_state["inference_state"] is None:
        print("Error: Cannot segment. No video loaded or inference state missing.")
        # Return current states to avoid errors, without changing UI much
        return (
            session_state["first_frame"],  # points_map remains unchanged
            None,  # output_image remains unchanged or cleared
            session_state,
        )

    # evt.index gives the (x, y) coordinates of the click
    click_coords = evt.index
    print(f"Clicked at: {click_coords} ({point_type})")
    session_state["input_points"].append(click_coords)

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)

    # Get the first frame as a PIL image for drawing
    first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA")
    w, h = first_frame_pil.size

    # Define the circle radius
    fraction = 0.01
    radius = max(2, int(fraction * min(w, h)))  # Ensure minimum radius of 2

    # Create a transparent layer to draw points
    transparent_layer_points = np.zeros((h, w, 4), dtype=np.uint8)

    # Draw points on the transparent layer
    for index, track in enumerate(session_state["input_points"]):
        # Ensure coordinates are integers for cv2.circle
        point_coords = (int(track[0]), int(track[1]))
        if session_state["input_labels"][index] == 1:
            # Green circle for include
            cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1)
        else:
            # Red circle for exclude
            cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1)

    # Convert the transparent layer back to an image and composite onto the first frame
    transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")

    # Combine the first frame image with the points layer for the points_map output.
    # points_map shows the first frame *with the points you added*.
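    # Image.alpha_composite requires two RGBA images of identical size; both the copied
    # first frame (converted above) and the points layer satisfy that here.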
    selected_point_map_img = Image.alpha_composite(
        first_frame_pil.copy(), transparent_layer_points_pil
    )

    # Prepare points and labels as tensors on CPU for the predictor
    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)

    # Ensure tensors are on CPU
    points_tensor = torch.tensor(points, dtype=torch.float32, device="cpu").unsqueeze(0)  # Add batch dim
    labels_tensor = torch.tensor(labels, dtype=torch.int32, device="cpu").unsqueeze(0)  # Add batch dim

    # Add new points to the predictor's state and get the mask for the first frame.
    # This call performs segmentation on the current frame (frame_idx=0) using all accumulated points.
    first_frame_output_img = None  # Initialize output mask image as None in case of error
    try:
        # Note: predictor.add_new_points modifies the internal inference_state
        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,  # Always segment on the first frame initially
            obj_id=OBJ_ID,
            points=points_tensor,
            labels=labels_tensor,
        )

        # Process logits: detach from graph, move to CPU, apply threshold.
        # out_mask_logits is a list of tensors [tensor([H, W])] for the requested obj_id
        mask_tensor = out_mask_logits[0][0].detach().cpu() > 0.0  # Apply threshold and get the single mask tensor [H, W]
        mask_numpy = mask_tensor.numpy()  # Convert to numpy

        # Get the mask image (RGBA)
        mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID)  # show_mask returns an RGBA PIL Image

        # Composite the mask onto the first frame for the output_image.
        # output_image shows the first frame *with the segmentation mask result*.
        first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)

    except Exception as e:
        print(f"Error during segmentation on first frame: {e}")
        # On error, first_frame_output_img remains None

    return selected_point_map_img, first_frame_output_img, session_state


def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Helper function to visualize a mask."""
    # Ensure mask is a numpy array (and boolean)
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()  # Ensure it's on CPU and converted to numpy
    # Convert potential float/int mask to boolean mask
    mask = mask.astype(bool)

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)  # RGBA with 0.6 alpha
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id % 10  # Use modulo 10 for tab10 colors
        color = np.array([*cmap(cmap_idx)[:3], 0.6])  # RGBA with 0.6 alpha

    # Ensure mask has H, W dimensions
    if mask.ndim == 3:
        mask = mask.squeeze()  # Remove singular dimensions like (H, W, 1)
    if mask.ndim != 2:
        print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
        # Create an empty transparent image if mask shape is unexpected
        h, w = mask.shape[:2] if mask.ndim >= 2 else (100, 100)  # Use actual shape if possible, otherwise default
        if convert_to_image:
            return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")
        else:
            return np.zeros((h, w, 4), dtype=np.uint8)

    h, w = mask.shape

    # Create an RGBA image from the mask and color.
    colored_mask = np.zeros((h, w, 4), dtype=np.float32)  # Start with fully transparent black
    # Apply the color only where the mask is True.
    # This directly creates the colored overlay with transparency.
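    # NumPy boolean indexing: colored_mask[mask] selects the True pixels as an (N, 4) view,
    # and the 4-element RGBA color broadcasts across those rows; all other pixels stay
    # fully transparent.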
    colored_mask[mask] = color

    # Convert to uint8 [0-255]
    colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)

    if convert_to_image:
        mask_img = Image.fromarray(colored_mask_uint8, "RGBA")
        return mask_img
    else:
        return colored_mask_uint8


# Removed @spaces.GPU decorator
def propagate_to_all(
    # We don't strictly need the video_in path here anymore as it's in session_state,
    # but keeping it is fine. Accessing session_state["video_path"] is more robust.
    video_in,
    session_state,
):
    """Runs mask propagation through the video and generates the output video."""
    print("Starting propagation...")

    # Ensure state is ready
    if (
        len(session_state["input_points"]) == 0  # Need at least one point
        or session_state["all_frames"] is None
        or session_state["inference_state"] is None
        or session_state["video_path"] is None  # Ensure we have the original video path
    ):
        print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
        return (
            gr.update(value=None, visible=False),  # Hide output video on error
            session_state,
        )

    # Run propagation throughout the video and collect the results.
    # The generator yields (frame_idx, obj_ids, mask_logits).
    video_segments = {}
    try:
        # This loop performs the core tracking prediction frame by frame
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            session_state["inference_state"]
        ):
            # Process logits: detach from graph, move to CPU, convert to numpy boolean mask.
            # Ensure the tensor is on CPU before converting to numpy.
            video_segments[out_frame_idx] = {
                # out_mask_logits is a list of tensors (one per object tracked in this frame).
                # Each tensor is [batch_size, H, W]; batch size is 1 here, so take element [0].
                out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
            # Optional: print progress
            # print(f"Processed frame {out_frame_idx + 1}/{len(session_state['all_frames'])}")
        print("Propagation finished.")
    except Exception as e:
        print(f"Error during propagation: {e}")
        return (
            gr.update(value=None, visible=False),  # Hide output video on error
            session_state,
        )

    output_frames = []
    # Iterate through all original frames to generate the output video
    total_frames = len(session_state["all_frames"])
    for out_frame_idx in range(total_frames):
        original_frame_rgb = session_state["all_frames"][out_frame_idx]
        # Convert original frame to RGBA for compositing
        transparent_background = Image.fromarray(original_frame_rgb).convert("RGBA")

        # Check if we have a mask for this frame and object ID
        if out_frame_idx in video_segments and OBJ_ID in video_segments[out_frame_idx]:
            current_mask_numpy = video_segments[out_frame_idx][OBJ_ID]
            # Get the mask image (RGBA)
            mask_image_pil = show_mask(current_mask_numpy, obj_id=OBJ_ID)
            # Composite the mask onto the frame
            output_frame_img_rgba = Image.alpha_composite(transparent_background, mask_image_pil)
            # Convert back to numpy RGB (moviepy needs RGB or RGBA)
            output_frame_np = np.array(output_frame_img_rgba.convert("RGB"))
        else:
            # If no mask for this frame/object, just use the original frame.
            # Note: all_frames are already RGB numpy arrays, so use them directly.
            # print(f"Warning: No mask found for frame {out_frame_idx} and object {OBJ_ID}. Using original frame.")
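            # Fallback: keep the untouched frame so the output video stays frame-aligned
            # with the input even when tracking produced no mask here.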
            output_frame_np = original_frame_rgb  # Already an RGB numpy array

        output_frames.append(output_frame_np)

    # Define output path in a temporary directory.
    # Use os.path.join for cross-platform compatibility.
    unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f")  # Use microseconds for more uniqueness
    final_vid_filename = f"output_video_{unique_id}.mp4"
    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
    print(f"Output video path: {final_vid_output_path}")

    # Create a video clip from the image sequence.
    # Get the original FPS from the stored video path in session state, or fall back to a default.
    original_fps = get_video_fps(session_state["video_path"])
    fps = original_fps if original_fps is not None and original_fps > 0 else 30  # Default to 30 if detection fails or is zero
    print(f"Creating output video with FPS: {fps}")

    # Check if there are frames to process
    if not output_frames:
        print("No output frames generated.")
        return (
            gr.update(value=None, visible=False),  # Hide output video
            session_state,
        )

    # Create ImageSequenceClip from the list of numpy arrays
    try:
        clip = ImageSequenceClip(output_frames, fps=fps)
    except Exception as e:
        print(f"Error creating ImageSequenceClip: {e}")
        return (
            gr.update(value=None, visible=False),  # Hide output video on error
            session_state,
        )

    # Write the result to a file. Use 'libx264' codec for broad compatibility.
    # `preset` and `threads` are CPU optimizations.
    # `logger=None` prevents moviepy from printing progress to stdout/stderr, which can clutter the Gradio logs.
    try:
        print(f"Writing video file with codec='libx264', fps={fps}, preset='medium', threads='auto'")
        clip.write_videofile(
            final_vid_output_path,
            codec="libx264",
            fps=fps,  # Ensure correct FPS is used during writing
            preset="medium",  # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed vs size
            threads="auto",  # CPU optimization: use multiple cores
            logger=None,  # Suppress moviepy output
        )
        print("Video writing complete.")

        # Return the path and make the video player visible
        return (
            gr.update(value=final_vid_output_path, visible=True),
            session_state,
        )
    except Exception as e:
        print(f"Error writing video file: {e}")
        # Clean up potentially created partial file
        if os.path.exists(final_vid_output_path):
            try:
                os.remove(final_vid_output_path)
                print(f"Removed partial video file: {final_vid_output_path}")
            except Exception as clean_e:
                print(f"Error removing partial file: {clean_e}")
        # Return None if writing fails
        return (
            gr.update(value=None, visible=False),
            session_state,
        )


def update_output_video_visibility():
    """Simply returns a Gradio update to make the output video visible."""
    return gr.update(visible=True)


with gr.Blocks() as demo:
    # Session state dictionary to hold video frames, points, labels, and predictor state
    session_state = gr.State(
        {
            "first_frame": None,  # numpy array (RGB)
            "all_frames": None,  # list of numpy arrays (RGB)
            "input_points": [],  # list of (x, y) tuples/lists
            "input_labels": [],  # list of 1s and 0s
            "inference_state": None,  # EdgeTAM predictor state object
            "video_path": None,  # Store the input video path
        }
    )

    with gr.Column():
        # Title
        gr.Markdown(title)

        with gr.Row():
            with gr.Column():
                # Instructions
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")  # Will hold the video file path

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                        interactive=True,  # Make interactive
                    )
                    # Buttons are initially disabled until a video is loaded
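                    # (preprocess_video_in re-enables all three via gr.update(interactive=True)
                    # once a video has been loaded successfully.)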
                    propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False)
                    clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False)
                    reset_btn = gr.Button("Reset", scale=1, interactive=False)

                # points_map is where users click to add points. Needs to be interactive.
                # Shows the first frame with points drawn on it.
                points_map = gr.Image(
                    label="Click on the First Frame to Add Points",  # Clearer label
                    type="numpy",
                    interactive=True,  # Make interactive to capture clicks
                    height=400,  # Set a fixed height for better UI
                    width="auto",  # Let width adjust
                    show_share_button=False,
                    show_download_button=False,
                    # show_label=False  # Can hide label if space is tight
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[video_in],
                    examples_per_page=8,
                    cache_examples=False,  # Do not cache processed examples, as state is involved
                )

                # Add padding/space
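                # If extra vertical space is wanted here, a (hypothetical) spacer such as the
                # commented line below could be used:
                # gr.HTML("<br>")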
") # output_image shows the segmentation mask prediction on the *first* frame output_image = gr.Image( label="Segmentation Mask on First Frame", # Clearer label type="numpy", interactive=False, # Not interactive, just displays the mask height=400, # Match height of points_map width="auto", # Let width adjust show_share_button=False, show_download_button=False, # show_label=False # Can hide label ) # output_video shows the final tracking result output_video = gr.Video(visible=False, label="Tracking Result") # --- Event Handlers --- # When a new video file is uploaded via the file browser video_in.upload( fn=preprocess_video_in, inputs=[video_in, session_state], outputs=[ video_in_drawer, # Close accordion points_map, # Show first frame in points_map output_image, # Clear output image output_video, # Hide output video propagate_btn, # Enable Track button clear_points_btn,# Enable Clear Points button reset_btn, # Enable Reset button session_state, # Update session state ], queue=False, # Process immediately ) # When an example video is selected (change event) video_in.change( fn=preprocess_video_in, inputs=[video_in, session_state], outputs=[ video_in_drawer, # Close accordion points_map, # Show first frame in points_map output_image, # Clear output image output_video, # Hide output video propagate_btn, # Enable Track button clear_points_btn,# Enable Clear Points button reset_btn, # Enable Reset button session_state, # Update session state ], queue=False, # Process immediately ) # Triggered when a user clicks on the points_map image points_map.select( fn=segment_with_points, inputs=[ point_type, # "include" or "exclude" radio button value session_state, # Pass session state ], outputs=[ points_map, # Updated image with points drawn output_image, # Updated image with first frame segmentation mask session_state, # Updated session state (points/labels added) ], queue=False, # Process clicks immediately ) # Button to clear all selected points and reset the first frame mask clear_points_btn.click( fn=clear_points, inputs=[session_state], # Pass session state outputs=[ points_map, # points_map shows original first frame without points output_image, # output_image cleared (or shows original first frame without mask) output_video, # Hide output video session_state, # Updated session state (points/labels cleared, inference state reset) ], queue=False, # Process immediately ) # Button to reset the entire demo state and UI reset_btn.click( fn=reset, inputs=[session_state], # Pass session state outputs=[ video_in, # Clear video input video_in_drawer, # Open video accordion points_map, # Clear points_map output_image, # Clear output_image output_video, # Hide output_video propagate_btn, # Disable buttons clear_points_btn,# Disable buttons reset_btn, # Disable buttons session_state, # Reset session state ], queue=False, # Process immediately ) # Button to start mask propagation through the video propagate_btn.click( fn=update_output_video_visibility, # First, make the output video player visible inputs=[], outputs=[output_video], queue=False, # Process this UI update immediately ).then( # Then, run the propagation function fn=propagate_to_all, inputs=[ video_in, # Get the input video path (can also get from session_state["video_path"]) session_state, # Pass session state (contains frames, points, inference_state, video_path) ], outputs=[ output_video, # Update output video player with result session_state, # Update session state (currently, propagate doesn't modify state much, but good practice) ], # CPU 
        concurrency_limit=1,
        queue=True,
    )


# Launch the Gradio demo
demo.queue()  # Enable queuing for sequential processing under concurrency limits
print("Gradio demo starting...")
# Removed share=True for local debugging unless you specifically need a public link
demo.launch()
print("Gradio demo launched.")
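
# A minimal sketch of an alternative launch call, e.g. for serving the demo on a LAN host
# (assumed values; server_name, server_port, and share are standard gr.Blocks.launch() kwargs):
# demo.launch(server_name="0.0.0.0", server_port=7860, share=False)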