Update app.py
app.py
CHANGED
@@ -1,178 +1,428 @@
(previous version of app.py, 178 lines, removed; the updated file follows)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
from datetime import datetime
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

# Remove CUDA environment variables
if 'TORCH_CUDNN_SDPA_ENABLED' in os.environ:
    del os.environ["TORCH_CUDNN_SDPA_ENABLED"]

# UI Description
title = "<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"

description_p = """# Instructions
<ol>
<li>Upload one video or click one example video</li>
<li>Click 'include' point type, select the object to segment and track</li>
<li>Click 'exclude' point type (optional), select the area to avoid segmenting</li>
<li>Click the 'Track' button to obtain the masked video</li>
</ol>
"""

# Example videos
examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
]

OBJ_ID = 0

# Initialize model on CPU
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"

def check_file_exists(filepath):
    exists = os.path.exists(filepath)
    if not exists:
        print(f"WARNING: File not found: {filepath}")
    return exists

# Verify model files
model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
try:
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    print("Predictor loaded on CPU")
except Exception as e:
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()
    predictor = None

# Utility Functions
def get_video_fps(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return 30.0
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps

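# Reset all prompts, cached frames, and tracker state, and reopen the video accordion.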
def reset(session_state):
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["inference_state"] = None
    return (
        None,
        gr.update(open=True),
        None,
        None,
        gr.update(value=None, visible=False),
        session_state,
    )

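# Clear the clicked points (and any started tracking) while keeping the loaded video and its first frame.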
def clear_points(session_state):
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None and session_state["inference_state"].get("tracking_has_started", False):
        predictor.reset_state(session_state["inference_state"])
    return (
        session_state["first_frame"],
        None,
        gr.update(value=None, visible=False),
        session_state,
    )

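# Load a video for annotation: downscale frames wider than 640 px, subsample to roughly 300 frames for CPU speed, and initialize the predictor state.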
def preprocess_video_in(video_path, session_state):
    if video_path is None:
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            session_state,
        )

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            session_state,
        )

    # Video properties
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    original_dimensions = (frame_width, frame_height)  # recorded before resizing and before the capture is released

    # Resize for CPU performance
    target_width = 640
    scale_factor = 1.0
    if frame_width > target_width:
        scale_factor = target_width / frame_width
        frame_width = target_width
        frame_height = int(frame_height * scale_factor)

    # Read frames with stride for CPU optimization
    frame_number = 0
    first_frame = None
    all_frames = []
    frame_stride = max(1, total_frames // 300)  # Limit to ~300 frames

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_number % frame_stride == 0:
            if scale_factor != 1.0:
                frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if first_frame is None:
                first_frame = frame
            all_frames.append(frame)
        frame_number += 1

    cap.release()
    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["frame_stride"] = frame_stride
    session_state["scale_factor"] = scale_factor
    session_state["original_dimensions"] = original_dimensions

    session_state["inference_state"] = predictor.init_state(video_path=video_path)
    session_state["input_points"] = []
    session_state["input_labels"] = []

    return [
        gr.update(open=False),
        first_frame,
        None,
        gr.update(value=None, visible=False),
        session_state,
    ]

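# Handle a click on the first frame: record the point and its include/exclude label, draw the prompt markers, and run the predictor on frame 0 to preview the mask.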
def segment_with_points(point_type, session_state, evt: gr.SelectData):
    session_state["input_points"].append(evt.index)
    print(f"TRACKING INPUT POINT: {session_state['input_points']}")

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)
    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

    first_frame = session_state["first_frame"]
    h, w = first_frame.shape[:2]
    transparent_background = Image.fromarray(first_frame).convert("RGBA")

    # Draw points
    fraction = 0.01
    radius = int(fraction * min(w, h))
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)

    for index, track in enumerate(session_state["input_points"]):
        color = (0, 255, 0, 255) if session_state["input_labels"][index] == 1 else (255, 0, 0, 255)
        cv2.circle(transparent_layer, track, radius, color, -1)

    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)

    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)

    try:
        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,
            obj_id=OBJ_ID,
            points=points,
            labels=labels,
        )
        mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()

        # Ensure mask matches frame size
        if mask_array.shape[:2] != (h, w):
            mask_array = cv2.resize(mask_array.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)

        mask_image = show_mask(mask_array)
        if mask_image.size != transparent_background.size:
            mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)

        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
    except Exception as e:
        print(f"Error in segmentation: {e}")
        first_frame_output = selected_point_map

    return selected_point_map, first_frame_output, session_state

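# Render a mask as a semi-transparent colored RGBA overlay (returned as a PIL Image unless convert_to_image is False).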
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    h, w = mask.shape[-2:] if len(mask.shape) > 2 else mask.shape
    mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    mask_rgba = (mask_reshaped * 255).astype(np.uint8)

    if convert_to_image:
        try:
            if mask_rgba.shape[2] != 4:
                proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
                proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
                mask_rgba = proper_mask
            return Image.fromarray(mask_rgba, "RGBA")
        except Exception as e:
            print(f"Error converting mask to image: {e}")
            return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")

    return mask_rgba

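# Propagate the segmentation through the video, collecting garbage periodically, then composite at most 50 overlay frames and write an MP4 capped at 15 fps.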
def propagate_to_all(video_in, session_state, progress=gr.Progress()):
    if len(session_state["input_points"]) == 0 or video_in is None or session_state["inference_state"] is None:
        return gr.update(value=None, visible=False), session_state

    chunk_size = 3
    try:
        video_segments = {}
        total_frames = len(session_state["all_frames"])
        progress(0, desc="Propagating segmentation through video...")

        for i, (out_frame_idx, out_obj_ids, out_mask_logit) in enumerate(predictor.propagate_in_video(session_state["inference_state"])):
            try:
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logit[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
                progress((i + 1) / total_frames, desc=f"Processed frame {out_frame_idx}/{total_frames}")
                if out_frame_idx % chunk_size == 0:
                    del out_mask_logit
                    import gc
                    gc.collect()
            except Exception as e:
                print(f"Error processing frame {out_frame_idx}: {e}")
                continue

        max_output_frames = 50
        vis_frame_stride = max(1, total_frames // max_output_frames)
        first_frame = session_state["all_frames"][0]
        h, w = first_frame.shape[:2]
        output_frames = []

        for out_frame_idx in range(0, total_frames, vis_frame_stride):
            if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
                continue
            try:
                frame = session_state["all_frames"][out_frame_idx]
                transparent_background = Image.fromarray(frame).convert("RGBA")
                out_mask = video_segments[out_frame_idx][OBJ_ID]

                # Validate mask dimensions
                if out_mask.shape[:2] != (h, w):
                    if out_mask.size == 0:  # Skip empty masks
                        print(f"Skipping empty mask for frame {out_frame_idx}")
                        continue
                    out_mask = cv2.resize(out_mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)

                mask_image = show_mask(out_mask)
                if mask_image.size != transparent_background.size:
                    mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)

                output_frame = Image.alpha_composite(transparent_background, mask_image)
                output_frames.append(np.array(output_frame))

                if len(output_frames) % 10 == 0:
                    import gc
                    gc.collect()
            except Exception as e:
                print(f"Error creating output frame {out_frame_idx}: {e}")
                import traceback
                traceback.print_exc()
                continue

        original_fps = get_video_fps(video_in)
        fps = min(original_fps, 15)  # Cap at 15 FPS for CPU

        clip = ImageSequenceClip(output_frames, fps=fps)
        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
        final_vid_output_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")

        clip.write_videofile(
            final_vid_output_path,
            codec="libx264",
            bitrate="800k",
            threads=2,
            logger=None
        )

        del video_segments, output_frames
        import gc
        gc.collect()

        return gr.update(value=final_vid_output_path, visible=True), session_state

    except Exception as e:
        print(f"Error in propagate_to_all: {e}")
        return gr.update(value=None, visible=False), session_state

def update_ui():
    return gr.update(visible=True)

# Gradio Interface
with gr.Blocks() as demo:
    session_state = gr.State({
        "first_frame": None,
        "all_frames": None,
        "input_points": [],
        "input_labels": [],
        "inference_state": None,
        "frame_stride": 1,
        "scale_factor": 1.0,
        "original_dimensions": None,
    })

    with gr.Column():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_p)
                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")
                with gr.Row():
                    point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include", scale=2)
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)
                points_map = gr.Image(label="Frame with Point Prompt", type="numpy", interactive=False)
            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(examples=examples, inputs=[video_in], examples_per_page=5)
                output_image = gr.Image(label="Reference Mask")
                output_video = gr.Video(visible=False)

    video_in.upload(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[video_in_drawer, points_map, output_image, output_video, session_state],
        queue=False,
    )

    video_in.change(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[video_in_drawer, points_map, output_image, output_video, session_state],
        queue=False,
    )

    points_map.select(
        fn=segment_with_points,
        inputs=[point_type, session_state],
        outputs=[points_map, output_image, session_state],
        queue=False,
    )

    clear_points_btn.click(
        fn=clear_points,
        inputs=session_state,
        outputs=[points_map, output_image, output_video, session_state],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=session_state,
        outputs=[video_in, video_in_drawer, points_map, output_image, output_video, session_state],
        queue=False,
    )

    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=output_video,
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[video_in, session_state],
        outputs=[output_video, session_state],
        queue=True,
    )

demo.queue()
demo.launch()