Update app.py
app.py (CHANGED)
@@ -10,16 +10,16 @@ from datetime import datetime
 
 import gradio as gr
 
-#
-
-
+# This line might be related to GPU, kept from original
+os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"
 import tempfile
 
 import cv2
 import matplotlib.pyplot as plt
+# spaces import and decorators are for Hugging Face Spaces GPU allocation,
+# if running locally without spaces, these can be removed or will be ignored.
+import spaces
 import numpy as np
-# Removed spaces decorator import for CPU-only demo
-# import spaces
 import torch
 
 from moviepy.editor import ImageSequenceClip
@@ -38,7 +38,7 @@ description_p = """# Instructions
 </ol>
 """
 
-# examples
+# examples
 examples = [
     ["examples/01_dog.mp4"],
     ["examples/02_cups.mp4"],
@@ -75,77 +75,133 @@ OBJ_ID = 0
 
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
-#
-# The device is set here and with .to("cpu")
+# Model built for CPU but immediately moved to CUDA in original code
 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-
-
-
-
-#
-
-#
-
+# *** Original code moves to CUDA ***
+predictor.to("cuda")
+print("predictor loaded on CUDA")
+
+# use bfloat16 for the entire demo - Original code uses CUDA bfloat16
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+# Original CUDA settings
+if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+elif not torch.cuda.is_available():
+    print("Warning: CUDA not available. The original code is configured for GPU.")
+    # Note: Without a GPU, the .to("cuda") calls will likely cause errors.
 
 
 def get_video_fps(video_path):
-
-    if video_path is None or not os.path.exists(video_path):
-        print(f"Warning: Video file not found at {video_path}")
-        return None
+    # Open the video file
     cap = cv2.VideoCapture(video_path)
+
     if not cap.isOpened():
-        print(
+        print("Error: Could not open video.")
         return None
+
+    # Get the FPS of the video
     fps = cap.get(cv2.CAP_PROP_FPS)
-    cap.release()
+    cap.release()  # Release the capture object
     return fps
 
-
+
+def reset(session_state):
+    """Resets the UI and session state."""
+    print("Resetting demo.")
+    session_state["input_points"] = []
+    session_state["input_labels"] = []
+    # Reset the predictor state if it exists
+    if session_state["inference_state"] is not None:
+        # Assuming predictor.reset_state handles None or invalid states gracefully
+        # Or you might need to explicitly pass the state object if required
+        try:
+            predictor.reset_state(session_state["inference_state"])
+            # Explicitly delete or re-init the state object if a full reset is intended
+            # This depends on how predictor.reset_state works
+            # session_state["inference_state"] = None  # Example if reset_state doesn't fully clear
+        except Exception as e:
+            print(f"Error resetting predictor state: {e}")
+            # If reset fails, perhaps force-clear the state object
+            session_state["inference_state"] = None
+
+    session_state["first_frame"] = None
+    session_state["all_frames"] = None
+    session_state["inference_state"] = None  # Ensure state is None after a full reset
+    # Also reset video path if stored
+    session_state["video_path"] = None
+
+    # Resetting UI components
+    return (
+        None,  # video_in (clears the video player)
+        gr.update(open=True),  # video_in_drawer (opens accordion)
+        None,  # points_map (clears the image)
+        None,  # output_image (clears the image)
+        gr.update(value=None, visible=False),  # output_video (hides and clears)
+        session_state,  # return updated session state
+    )
+
+
+def clear_points(session_state):
+    """Clears selected points and resets segmentation on the first frame."""
+    print("Clearing points.")
+    session_state["input_points"] = []
+    session_state["input_labels"] = []
+
+    # Reset the predictor state to clear internal masks/features
+    # This typically doesn't remove the video context, just the mask predictions
+    if session_state["inference_state"] is not None:
+        try:
+            # Assuming reset_state handles clearing current masks/features
+            predictor.reset_state(session_state["inference_state"])
+            print("Predictor state reset for clearing points.")
+            # If you need to re-initialize the state for the *same* video after clearing points,
+            # you might need to call predictor.init_state again here, using the stored video_path.
+            # session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"], device="cuda")  # Or device="cpu" if modified earlier
+        except Exception as e:
+            print(f"Error resetting predictor state during clear_points: {e}")
+            # If reset fails, this might leave old masks. Depending on SAM2's behavior,
+            # you might need a more aggressive state clear or re-initialization.
+
+    # Return the original first frame image for points_map and clear the output_image
+    first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None
+
+    return (
+        first_frame_img,  # points_map shows original first frame (no points yet)
+        None,  # output_image cleared (no mask)
+        gr.update(value=None, visible=False),  # output_video hidden
+        session_state,  # return updated session state
+    )
+
+
+# Added @spaces.GPU decorator back as it was in the original code
+@spaces.GPU
 def preprocess_video_in(video_path, session_state):
     """Loads video frames and initializes the predictor state."""
     print(f"Processing video: {video_path}")
     if video_path is None or not os.path.exists(video_path):
         print("No video path provided or file not found.")
         # Reset state and UI elements if input is invalid
+        # Need to return updates for the buttons as well
         return (
-            gr.update(open=True),
-            None,
-            None,
-            gr.update(value=None, visible=False),
-            gr.update(interactive=False),  # propagate_btn
-            gr.update(interactive=False),  # clear_points_btn
-            gr.update(interactive=False),  # reset_btn
-            {  # Reset session state
-                "first_frame": None,
-                "all_frames": None,
-                "input_points": [],
-                "input_labels": [],
-                "inference_state": None,
-                "video_path": None,
+            gr.update(open=True), None, None, gr.update(value=None, visible=False),
+            gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
+            {  # Reset session state
+                "first_frame": None, "all_frames": None, "input_points": [],
+                "input_labels": [], "inference_state": None, "video_path": None,
            }
        )
 
-    # Read the first frame and all frames
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         print(f"Error: Could not open video file {video_path}.")
-        # Reset state and UI elements on error
         return (
-            gr.update(open=True),
-            None,
-            None,
-            gr.update(value=None, visible=False),
-            gr.update(interactive=False),  # propagate_btn
-            gr.update(interactive=False),  # clear_points_btn
-            gr.update(interactive=False),  # reset_btn
+            gr.update(open=True), None, None, gr.update(value=None, visible=False),
+            gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
             {  # Reset session state
-                "first_frame": None,
-                "all_frames": None,
-                "input_points": [],
-                "input_labels": [],
-                "inference_state": None,
-                "video_path": None,
+                "first_frame": None, "all_frames": None, "input_points": [],
+                "input_labels": [], "inference_state": None, "video_path": None,
            }
        )
 
@@ -156,139 +212,65 @@ def preprocess_video_in(video_path, session_state):
         ret, frame = cap.read()
         if not ret:
             break
-        # Convert BGR to RGB
         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         all_frames.append(frame)
         if first_frame is None:
-            first_frame = frame
+            first_frame = frame
 
     cap.release()
 
     if not all_frames:
         print(f"Error: No frames read from video file {video_path}.")
-        # Reset state and UI elements if no frames are read
         return (
-            gr.update(open=True),
-            None,
-            None,
-            gr.update(value=None, visible=False),
-            gr.update(interactive=False),  # propagate_btn
-            gr.update(interactive=False),  # clear_points_btn
-            gr.update(interactive=False),  # reset_btn
+            gr.update(open=True), None, None, gr.update(value=None, visible=False),
+            gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
             {  # Reset session state
-                "first_frame": None,
-                "all_frames": None,
-                "input_points": [],
-                "input_labels": [],
-                "inference_state": None,
-                "video_path": None,
+                "first_frame": None, "all_frames": None, "input_points": [],
+                "input_labels": [], "inference_state": None, "video_path": None,
            }
        )
 
-
-    session_state["first_frame"] = copy.deepcopy(first_frame)  # Store a copy
+    session_state["first_frame"] = copy.deepcopy(first_frame)
     session_state["all_frames"] = all_frames
-    session_state["video_path"] = video_path  # Store
+    session_state["video_path"] = video_path  # Store video path
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    #
+    # Original code did NOT pass device here. It uses the device the predictor is on.
     session_state["inference_state"] = predictor.init_state(video_path=video_path)
     print("Video loaded and predictor state initialized.")
 
+    # Enable buttons after successful load
     return [
         gr.update(open=False),  # video_in_drawer
-        first_frame,  # points_map
-        None,  # output_image
-        gr.update(value=None, visible=False),  # output_video
-        gr.update(interactive=True),  #
-        gr.update(interactive=True),  #
-        gr.update(interactive=True),  #
-        session_state,  #
+        first_frame,  # points_map
+        None,  # output_image
+        gr.update(value=None, visible=False),  # output_video
+        gr.update(interactive=True),  # propagate_btn
+        gr.update(interactive=True),  # clear_points_btn
+        gr.update(interactive=True),  # reset_btn
+        session_state,  # session_state
    ]
 
 
-def reset(session_state):
-
-    print("Resetting demo.")
-    # Clear points and labels
-    session_state["input_points"] = []
-    session_state["input_labels"] = []
-    # Reset the predictor state if it exists
-    if session_state["inference_state"] is not None:
-        predictor.reset_state(session_state["inference_state"])
-        # After reset, we also discard the state object as a new video might be loaded
-        session_state["inference_state"] = None
-    # Clear frames and video path
-    session_state["first_frame"] = None
-    session_state["all_frames"] = None
-    session_state["video_path"] = None
-
-    # Update UI elements to their initial state
-    return (
-        None,  # video_in
-        gr.update(open=True),  # video_in_drawer open
-        None,  # points_map cleared
-        None,  # output_image cleared
-        gr.update(value=None, visible=False),  # output_video hidden
-        gr.update(interactive=False),  # Disable buttons
-        gr.update(interactive=False),  # Disable buttons
-        gr.update(interactive=False),  # Disable buttons
-        session_state,  # Updated session state
-    )
-
-
-def clear_points(session_state):
-    """Clears selected points and resets segmentation on the first frame."""
-    print("Clearing points.")
-    # Clear points and labels lists
-    session_state["input_points"] = []
-    session_state["input_labels"] = []
-
-    # Reset the predictor state if it exists. This clears internal masks/features
-    # but keeps the video context initialized by preprocess_video_in.
-    if session_state["inference_state"] is not None:
-        predictor.reset_state(session_state["inference_state"])
-        # After resetting the state, if we still have the video path, re-initialize the state
-        # to be ready for new points on the same video.
-        if session_state["video_path"] is not None:
-            # Re-initialize state *without* the device argument
-            session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"])
-            print("Predictor state re-initialized after clearing points.")
-        else:
-            print("Warning: Could not re-initialize state after clear_points (video_path missing).")
-            session_state["inference_state"] = None  # Ensure state is None if video_path is gone
-
-
-    # Re-render the points_map with no points drawn (just the first frame)
-    # Re-render the output_image with no mask (just the first frame)
-    first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None
-
-    return (
-        first_frame_img,  # points_map shows original first frame
-        None,  # output_image cleared
-        gr.update(value=None, visible=False),  # Hide output video
-        session_state,  # Updated session state
-    )
-
-
-# Removed @spaces.GPU decorator
+# Added @spaces.GPU decorator back as it was in the original code
+@spaces.GPU
 def segment_with_points(
     point_type,
     session_state,
     evt: gr.SelectData,
 ):
     """Adds a point prompt and performs segmentation on the first frame."""
-    # Ensure we have
+    # Ensure we have state and first frame
     if session_state["first_frame"] is None or session_state["inference_state"] is None:
         print("Error: Cannot segment. No video loaded or inference state missing.")
-        # Return current
+        # Return current images and state without changes
         return (
-            session_state
-            None,  # output_image
+            session_state.get("first_frame"),  # points_map (show first frame if exists)
+            None,  # output_image (keep cleared)
             session_state,
        )
 
-    # evt.index
+    # evt.index is the (x, y) coordinate tuple
     click_coords = evt.index
     print(f"Clicked at: {click_coords} ({point_type})")
 
@@ -314,12 +296,11 @@ def segment_with_points(
     for index, track in enumerate(session_state["input_points"]):
         # Ensure coordinates are integers for cv2.circle
         point_coords = (int(track[0]), int(track[1]))
+        # Ensure color is RGBA (0-255)
         if session_state["input_labels"][index] == 1:
-            # Green
-            cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1)
+            cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1)  # Green for include
         else:
-            # Red
-            cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1)
+            cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1)  # Red for exclude
 
     # Convert the transparent layer back to an image and composite onto the first frame
     transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
@@ -329,13 +310,14 @@ def segment_with_points(
         first_frame_pil.copy(), transparent_layer_points_pil
    )
 
-    # Prepare points and labels as tensors on
+    # Prepare points and labels as tensors on the correct device (CUDA in original code)
     points = np.array(session_state["input_points"], dtype=np.float32)
     labels = np.array(session_state["input_labels"], np.int32)
 
-    # Ensure tensors are on
-
-
+    # Ensure tensors are on the correct device (CUDA as per original code setup)
+    device = next(predictor.parameters()).device  # Get the device the model is on
+    points_tensor = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0)  # Add batch dim
+    labels_tensor = torch.tensor(labels, dtype=torch.int32, device=device).unsqueeze(0)  # Add batch dim
 
     # Add new points to the predictor's state and get the mask for the first frame
     # This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
@@ -351,8 +333,9 @@ def segment_with_points(
        )
 
         # Process logits: detach from graph, move to CPU, apply threshold
-        # out_mask_logits is a list of tensors [tensor([H, W])] for the requested obj_id
-
+        # out_mask_logits is a list of tensors [tensor([batch_size, H, W])] for the requested obj_id
+        # Access the result for the first object (index 0) and the first item in batch (index 0)
+        mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0)  # Move to CPU before converting to numpy
         mask_numpy = mask_tensor.numpy()  # Convert to numpy
 
         # Get the mask image (RGBA)
@@ -366,6 +349,9 @@ def segment_with_points(
         print(f"Error during segmentation on first frame: {e}")
         # On error, first_frame_output_img remains None
 
+    # Original code clears CUDA cache here
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     return selected_point_map_img, first_frame_output_img, session_state
 
@@ -416,21 +402,22 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     return colored_mask_uint8
 
 
-#
+# Added @spaces.GPU decorator back as it was in the original code
+@spaces.GPU
 def propagate_to_all(
-    #
-    # but keeping it is fine. Accessing session_state["video_path"] is more robust.
-    video_in,
+    video_in,  # Keep video_in path as in original
     session_state,
 ):
     """Runs mask propagation through the video and generates the output video."""
     print("Starting propagation...")
     # Ensure state is ready
+    # Using session_state.get("video_path") is safer than video_in directly
+    current_video_path = session_state.get("video_path")
     if (
         len(session_state["input_points"]) == 0  # Need at least one point
         or session_state["all_frames"] is None
         or session_state["inference_state"] is None
-        or
+        or current_video_path is None  # Ensure we have the original video path
    ):
         print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
         return (
@@ -439,7 +426,6 @@ def propagate_to_all(
        )
 
     # run propagation throughout the video and collect the results
-    # The generator yields (frame_idx, obj_ids, mask_logits)
     video_segments = {}
     try:
         # This loop performs the core tracking prediction frame by frame
@@ -451,7 +437,7 @@ def propagate_to_all(
             video_segments[out_frame_idx] = {
                 # out_mask_logits is a list of tensors (one per object tracked in this frame)
                 # Each tensor is [batch_size, H, W]. Batch size is 1 here.
-                # Access the first
+                # Access the result for the first object (index i) and the first item in batch (index 0)
                 out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
                 for i, out_obj_id in enumerate(out_obj_ids)
            }
@@ -492,9 +478,11 @@ def propagate_to_all(
 
         output_frames.append(output_frame_np)
 
+    # Original code clears CUDA cache here
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     # Define output path in a temporary directory
-    # Use os.path.join for cross-platform compatibility
     unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f")  # Use microseconds for more uniqueness
     final_vid_filename = f"output_video_{unique_id}.mp4"
     final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
@@ -502,9 +490,8 @@ def propagate_to_all(
 
 
     # Create a video clip from the image sequence
-    # Get original FPS
-
-    original_fps = get_video_fps(session_state["video_path"])
+    # Get original FPS from the stored video path
+    original_fps = get_video_fps(current_video_path)
     fps = original_fps if original_fps is not None and original_fps > 0 else 30  # Default to 30 if detection fails or is zero
     print(f"Creating output video with FPS: {fps}")
 
@@ -526,20 +513,11 @@ def propagate_to_all(
             session_state,
        )
 
-
     # Write the result to a file. Use 'libx264' codec for broad compatibility.
-    # `preset` and `threads` for CPU optimization.
-    # `logger=None` prevents moviepy from printing progress to stdout/stderr, which can clutter the Gradio logs.
     try:
-        print(f"Writing video file with codec='libx264', fps={fps}
-        clip.write_videofile(
-            final_vid_output_path,
-            codec="libx264",
-            fps=fps,  # Ensure correct FPS is used during writing
-            preset="medium",  # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed vs size
-            threads="auto",  # CPU optimization: Use multiple cores
-            logger=None  # Suppress moviepy output
-        )
+        print(f"Writing video file with codec='libx264', fps={fps}")
+        # Added basic moviepy writing parameters back, similar to original intent
+        clip.write_videofile(final_vid_output_path, codec="libx264", fps=fps)
         print("Video writing complete.")
         # Return the path and make the video player visible
         return (
@@ -563,7 +541,7 @@ def propagate_to_all(
    )
 
 
-def
+def update_ui():
     """Simply returns a Gradio update to make the output video visible."""
     return gr.update(visible=True)
 
@@ -611,12 +589,11 @@ with gr.Blocks() as demo:
             points_map = gr.Image(
                 label="Click on the First Frame to Add Points",  # Clearer label
                 type="numpy",
-                interactive=True,  #
+                interactive=True,  # <--- THIS WAS CHANGED FROM False TO True
                 height=400,  # Set a fixed height for better UI
                 width="auto",  # Let width adjust
                 show_share_button=False,
                 show_download_button=False,
-                # show_label=False  # Can hide label if space is tight
            )
 
         with gr.Column():
@@ -627,7 +604,7 @@ with gr.Blocks() as demo:
                 examples_per_page=8,
                 cache_examples=False,  # Do not cache processed examples, as state is involved
            )
-            # Add padding/space
+            # Add padding/space - removed extra lines as they take up a lot of space
             # gr.Markdown("<br>")
 
             # output_image shows the segmentation mask prediction on the *first* frame
@@ -639,7 +616,6 @@ with gr.Blocks() as demo:
                 width="auto",  # Let width adjust
                 show_share_button=False,
                 show_download_button=False,
-                # show_label=False  # Can hide label
            )
 
             # output_video shows the final tracking result
@@ -649,35 +625,25 @@ with gr.Blocks() as demo:
     # --- Event Handlers ---
 
     # When a new video file is uploaded via the file browser
+    # Added postprocess to update button interactivity based on whether video loaded
     video_in.upload(
         fn=preprocess_video_in,
         inputs=[video_in, session_state],
         outputs=[
-            video_in_drawer,
-            points_map,
-            output_image,  # Clear output image
-            output_video,  # Hide output video
-            propagate_btn,  # Enable Track button
-            clear_points_btn,  # Enable Clear Points button
-            reset_btn,  # Enable Reset button
-            session_state,  # Update session state
+            video_in_drawer, points_map, output_image, output_video,
+            propagate_btn, clear_points_btn, reset_btn, session_state,
        ],
         queue=False,  # Process immediately
    )
 
     # When an example video is selected (change event)
+    # Added postprocess to update button interactivity
     video_in.change(
         fn=preprocess_video_in,
         inputs=[video_in, session_state],
-        outputs=[
-            video_in_drawer,
-            points_map,
-            output_image,  # Clear output image
-            output_video,  # Hide output video
-            propagate_btn,  # Enable Track button
-            clear_points_btn,  # Enable Clear Points button
-            reset_btn,  # Enable Reset button
-            session_state,  # Update session state
+        outputs=[
+            video_in_drawer, points_map, output_image, output_video,
+            propagate_btn, clear_points_btn, reset_btn, session_state,
        ],
         queue=False,  # Process immediately
    )
@@ -716,15 +682,8 @@ with gr.Blocks() as demo:
         fn=reset,
         inputs=[session_state],  # Pass session state
         outputs=[
-            video_in,
-            video_in_drawer,
-            points_map,  # Clear points_map
-            output_image,  # Clear output_image
-            output_video,  # Hide output_video
-            propagate_btn,  # Disable buttons
-            clear_points_btn,  # Disable buttons
-            reset_btn,  # Disable buttons
-            session_state,  # Reset session state
+            video_in, video_in_drawer, points_map, output_image, output_video,
+            propagate_btn, clear_points_btn, reset_btn, session_state,
        ],
         queue=False,  # Process immediately
    )
@@ -743,18 +702,16 @@ with gr.Blocks() as demo:
        ],
         outputs=[
             output_video,  # Update output video player with result
-            session_state,  # Update session state
+            session_state,  # Update session state
        ],
-        #
-
-
-        queue=True,
+        # concurrency_limit from original code (may need adjustment based on your hardware/GPU)
+        concurrency_limit=10,
+        queue=False,  # queue from original code
    )
 
 
 # Launch the Gradio demo
-demo.queue()  # Enable queuing
+demo.queue()  # Enable queuing
 print("Gradio demo starting...")
-# Removed share=True for local debugging unless you specifically need a public link
 demo.launch()
 print("Gradio demo launched.")
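A note on the condensed outputs= lists above: Gradio assigns the values returned by a handler to the components in outputs= by position, so the order of the list must match the order of the function's return values. A small self-contained sketch (hypothetical component and function names, not taken from this app):

import gradio as gr

def load_clip(state):
    # The returned tuple lines up positionally with the outputs list wired below.
    state["loaded"] = True
    return gr.update(visible=False), "ready", state

with gr.Blocks() as demo:
    state = gr.State({})
    status = gr.Textbox(label="status")
    note = gr.Markdown("upload a clip to begin")
    btn = gr.Button("Load")
    btn.click(fn=load_clip, inputs=[state], outputs=[note, status, state])

# demo.launch()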
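The new setup hard-codes CUDA (predictor.to("cuda"), torch.autocast(device_type="cuda", ...)) and, as its own warning notes, will likely fail on a CPU-only machine. A minimal sketch of a device-agnostic variant, assuming the same build_sam2_video_predictor call used above (the import path is an assumption, not shown in this diff):

import torch
# Import path assumed from the SAM2/EdgeTAM codebase; adjust to match this repo.
from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"

# Pick the device at runtime instead of hard-coding "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
predictor.to(device)

# Apply the CUDA-specific optimizations only when a GPU is actually present.
if device == "cuda":
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True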