"""Gradio demo: run SAM mask generation on the first frames of a video."""

import time
import uuid
from typing import Tuple

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from transformers import pipeline

# Half-open frame window [START_FRAME, END_FRAME) that gets processed.
START_FRAME = 0
END_FRAME = 10
TOTAL = END_FRAME - START_FRAME

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Both objects are created once at import time and shared by every request.
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-base",
    device=DEVICE)
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX)


def run_sam(frame: np.ndarray) -> sv.Detections:
    """Run the SAM mask-generation pipeline on a single frame.

    Args:
        frame: Image array with the channel axis last, in BGR order.

    Returns:
        Detections carrying one mask (and the box derived from it) per
        generated mask, or empty detections when no mask was generated.
    """
    # convert from Numpy BGR to PIL RGB
    image = Image.fromarray(frame[:, :, ::-1])
    outputs = SAM_GENERATOR(image)
    mask = np.array(outputs['masks'])
    if mask.size == 0:
        # An empty result is not a valid (N, H, W) mask stack for
        # sv.Detections; return explicit empty detections instead so the
        # frame is written out unannotated rather than failing the request.
        return sv.Detections.empty()
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)


def mask_video(source_video: str, prompt: str, confidence: float, name: str) -> str:
    """Annotate frames START_FRAME..END_FRAME of a video with SAM masks.

    Args:
        source_video: Path of the input video.
        prompt: Currently unused by this function.
        confidence: Currently unused by this function.
        name: Output file stem; the result is written to ``f"{name}.mp4"``.

    Returns:
        Path of the annotated video that was written.
    """
    video_info = sv.VideoInfo.from_video_path(source_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=source_video, start=START_FRAME, end=END_FRAME)

    # Progress-bar length only: clamp to the frames the video reports so the
    # bar can reach 100% on clips shorter than END_FRAME.
    expected_frames = TOTAL
    if video_info.total_frames is not None:
        expected_frames = max(
            0, min(END_FRAME, video_info.total_frames) - START_FRAME)

    target_path = f"{name}.mp4"
    with sv.VideoSink(target_path, video_info=video_info) as sink:
        # Iterate the generator itself rather than calling next() a fixed
        # number of times: a video with fewer than END_FRAME frames then
        # simply ends the loop instead of raising StopIteration.
        for frame in tqdm(
                frame_generator, total=expected_frames, desc="Masking frames"):
            detections = run_sam(frame)
            annotated_frame = MASK_ANNOTATOR.annotate(
                scene=frame.copy(), detections=detections)
            sink.write_frame(annotated_frame)
    return target_path


def process(
    source_video: str,
    prompt: str,
    confidence: float,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
    """Gradio click handler: mask the uploaded video.

    Args:
        source_video: Path of the uploaded video.
        prompt: Forwarded to ``mask_video``.
        confidence: Forwarded to ``mask_video``.
        progress: Gradio progress tracker configured to mirror tqdm bars.

    Returns:
        The masked video path twice, one entry per output player.
    """
    name = str(uuid.uuid4())
    masked_video = mask_video(source_video, prompt, confidence, name)
    # NOTE(review): both outputs receive the same file; no separate
    # "painted" video is produced here.
    return masked_video, masked_video


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            source_video_player = gr.Video(
                label="Source video", source="upload", format="mp4")
            prompt_text = gr.Textbox(
                label="Prompt", value="person")
            confidence_slider = gr.Slider(
                label="Confidence", minimum=0.5, maximum=1.0, step=0.05, value=0.6)
            submit_button = gr.Button("Submit")
        with gr.Column():
            masked_video_player = gr.Video(label="Masked video")
            painted_video_player = gr.Video(label="Painted video")

    submit_button.click(
        process,
        inputs=[source_video_player, prompt_text, confidence_slider],
        outputs=[masked_video_player, painted_video_player])

demo.queue().launch(debug=False, show_error=True)