File size: 2,594 Bytes
b643479
5b163f1
 
 
7b4534e
 
5b163f1
b643479
5b163f1
b643479
 
7b4534e
5b163f1
 
 
7b4534e
b643479
 
 
 
 
 
 
 
7b4534e
b643479
 
 
 
 
 
 
 
 
 
 
5b163f1
 
 
 
 
 
 
b643479
 
 
 
 
 
5b163f1
b643479
 
 
 
 
 
 
 
 
7b4534e
 
 
 
 
5b163f1
 
 
 
 
 
7b4534e
5b163f1
 
 
7b4534e
 
5b163f1
 
 
7b4534e
5b163f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import time
import uuid
from typing import Tuple

import gradio as gr
import supervision as sv
import numpy as np
from tqdm import tqdm
from transformers import pipeline
from PIL import Image

# Frame window read from every uploaded video; TOTAL is how many frames
# mask_video pulls from the generator.
START_FRAME = 0
END_FRAME = 10
TOTAL = END_FRAME - START_FRAME

# Run SAM on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face mask-generation pipeline; run_sam calls it with only the image,
# so no point/box prompts are supplied.
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-base",
    device=DEVICE,
)

# Draws every detection mask with a single red colour (index-based lookup).
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX,
)


def run_sam(frame: np.ndarray) -> sv.Detections:
    """Segment one video frame with the SAM mask-generation pipeline.

    Args:
        frame: A frame as yielded by ``sv.get_video_frames_generator``. It is
            treated as BGR: the channel axis is reversed before the frame is
            handed to PIL.

    Returns:
        ``sv.Detections`` with one entry per mask returned by
        ``SAM_GENERATOR``, with bounding boxes derived from the masks. Returns
        empty detections when the pipeline yields no masks.
    """
    # Reverse the channel axis (BGR -> RGB) before building the PIL image.
    image = Image.fromarray(frame[:, :, ::-1])

    outputs = SAM_GENERATOR(image)
    masks = np.array(outputs["masks"])
    # With no masks, np.array([]) is a 1-D (0,) array rather than an
    # (N, H, W) stack, which sv.Detections rejects. Return an explicit empty
    # result so a blank frame does not abort the whole video.
    if masks.size == 0:
        return sv.Detections.empty()
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)


def mask_video(source_video: str, prompt: str, confidence: float, name: str) -> str:
    """Annotate the SAM masks of a frame window and write them to ``<name>.mp4``.

    Frames are read from ``source_video`` between ``START_FRAME`` and
    ``END_FRAME``. Each frame is segmented with ``run_sam``, painted with
    ``MASK_ANNOTATOR``, and written to a new video that reuses the source
    video's ``VideoInfo``.

    Args:
        source_video: Path of the input video.
        prompt: Currently unused by this function.
        confidence: Currently unused by this function.
        name: Output file stem; ``.mp4`` is appended.

    Returns:
        The path of the written video.
    """
    # NOTE(review): `prompt` and `confidence` are accepted (and passed in from
    # the UI) but not used here; presumably reserved for prompt-based
    # filtering of masks. Confirm the intended behaviour.
    target_path = f"{name}.mp4"
    video_info = sv.VideoInfo.from_video_path(source_video)
    frames = sv.get_video_frames_generator(
        source_path=source_video, start=START_FRAME, end=END_FRAME)

    with sv.VideoSink(target_path, video_info=video_info) as sink:
        # Iterate the generator directly so that a video shorter than
        # END_FRAME simply ends the loop. Calling next() a fixed TOTAL times
        # would raise StopIteration. `total` keeps the progress bar sized.
        for frame in tqdm(frames, total=TOTAL, desc="Masking frames"):
            detections = run_sam(frame)
            annotated_frame = MASK_ANNOTATOR.annotate(
                scene=frame.copy(), detections=detections)
            sink.write_frame(annotated_frame)
    return target_path


def process(
    source_video: str,
    prompt: str,
    confidence: float,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
    """Gradio click handler: mask the uploaded video under a random file name.

    ``progress`` is the Gradio progress tracker; with ``track_tqdm=True`` it
    mirrors the tqdm bar used inside ``mask_video``.

    Returns:
        The same output path twice, one for each of the two output players.
    """
    output_path = mask_video(
        source_video, prompt, confidence, str(uuid.uuid4()))
    return output_path, output_path


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            # NOTE(review): `source="upload"` is the Gradio 3.x keyword;
            # newer Gradio releases spell it `sources=[...]`. Confirm the pinned version.
            source_video_player = gr.Video(
                label="Source video", source="upload", format="mp4")
            prompt_text = gr.Textbox(label="Prompt", value="person")
            confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.5,
                maximum=1.0,
                step=0.05,
                value=0.6,
            )
            submit_button = gr.Button("Submit")
        # Right column: results; `process` sends the same file to both players.
        with gr.Column():
            masked_video_player = gr.Video(label="Masked video")
            painted_video_player = gr.Video(label="Painted video")

    handler_inputs = [source_video_player, prompt_text, confidence_slider]
    handler_outputs = [masked_video_player, painted_video_player]
    submit_button.click(
        fn=process, inputs=handler_inputs, outputs=handler_outputs)

# Enable the request queue, then start the server (runs at import time).
demo.queue().launch(debug=False, show_error=True)