Update app.py
app.py
CHANGED
@@ -46,11 +46,30 @@ examples = [
 
 OBJ_ID = 0
 
-# Initialize model on CPU
+# Initialize model on CPU - add error handling for file paths
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
-predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-print("predictor loaded")
+
+# Check if model files exist
+def check_file_exists(filepath):
+    import os
+    exists = os.path.exists(filepath)
+    if not exists:
+        print(f"WARNING: File not found: {filepath}")
+    return exists
+
+# Verify files exist
+model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
+try:
+    # Load model with more careful error handling
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+    print("predictor loaded on CPU")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    import traceback
+    traceback.print_exc()
+    # Still create a predictor variable to avoid NameError
+    predictor = None
 
 # Function to get video frame rate
 def get_video_fps(video_path):
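
Note on the fallback: because the except branch assigns predictor = None instead of re-raising, every handler that uses the model needs a guard before touching it. A minimal sketch of such a guard (the helper name require_predictor is hypothetical, not part of this commit):

def require_predictor():
    # Fail fast with a readable message instead of an AttributeError
    # surfacing later from inside a None predictor.
    if predictor is None:
        raise RuntimeError(
            "Model not loaded; check that checkpoints/edgetam.pt "
            "and edgetam.yaml exist."
        )
    return predictor

Calling require_predictor() at the top of segment_with_points and propagate_to_all would turn the silent None into an explicit, user-visible error.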
@@ -196,10 +215,9 @@ def segment_with_points(
     print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
 
     # Open the image and get its dimensions
-    transparent_background = Image.fromarray(
-        session_state["first_frame"]
-    ).convert("RGBA")
-    w, h = transparent_background.size
+    first_frame = session_state["first_frame"]
+    h, w = first_frame.shape[:2]
+    transparent_background = Image.fromarray(first_frame).convert("RGBA")
 
     # Define the circle radius as a fraction of the smaller dimension
     fraction = 0.01 # You can adjust this value as needed
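
A detail worth keeping in mind in this change: a NumPy frame reports shape as (height, width), while the PIL image built from it reports size as (width, height). A self-contained sketch of the relationship:

import numpy as np
from PIL import Image

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # 480 rows (h), 640 columns (w)
h, w = frame.shape[:2]

image = Image.fromarray(frame).convert("RGBA")
assert image.size == (w, h)  # PIL reverses the order: (width, height)

This is why the old code read w, h from transparent_background.size while the new code reads h, w from first_frame.shape[:2]; both describe the same frame.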
@@ -225,17 +243,38 @@ def segment_with_points(
     # for labels, `1` means positive click and `0` means negative click
     labels = np.array(session_state["input_labels"], np.int32)
 
-    _, _, out_mask_logits = predictor.add_new_points(
-        inference_state=session_state["inference_state"],
-        frame_idx=0,
-        obj_id=OBJ_ID,
-        points=points,
-        labels=labels,
-    )
-
-    # Create the mask
-    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
-    first_frame_output = Image.alpha_composite(transparent_background, mask_image)
+    try:
+        # For CPU optimization, we'll process with smaller batch size
+        _, _, out_mask_logits = predictor.add_new_points(
+            inference_state=session_state["inference_state"],
+            frame_idx=0,
+            obj_id=OBJ_ID,
+            points=points,
+            labels=labels,
+        )
+
+        # Create the mask
+        mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
+
+        # Ensure the mask has the same size as the frame
+        if mask_array.shape[:2] != (h, w):
+            mask_array = cv2.resize(
+                mask_array.astype(np.uint8),
+                (w, h),
+                interpolation=cv2.INTER_NEAREST
+            ).astype(bool)
+
+        mask_image = show_mask(mask_array)
+
+        # Make sure mask_image has the same size as the background
+        if mask_image.size != transparent_background.size:
+            mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
+
+        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
+    except Exception as e:
+        print(f"Error in segmentation: {e}")
+        # Return just the points as fallback
+        first_frame_output = selected_point_map
 
     return selected_point_map, first_frame_output, session_state
 
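
The resize guard added above goes through uint8 because cv2.resize does not accept boolean arrays, and INTER_NEAREST keeps the output binary, so the cast back to bool loses nothing. The same pattern in isolation (dimensions are arbitrary examples):

import cv2
import numpy as np

mask = np.zeros((100, 100), dtype=bool)
mask[20:60, 30:80] = True

target_h, target_w = 480, 640
resized = cv2.resize(
    mask.astype(np.uint8),           # cv2.resize rejects dtype=bool
    (target_w, target_h),            # OpenCV expects (width, height)
    interpolation=cv2.INTER_NEAREST, # nearest neighbor: no gray in-between values
).astype(bool)

assert resized.shape == (target_h, target_w)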
@@ -247,12 +286,36 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     cmap = plt.get_cmap("tab10")
     cmap_idx = 0 if obj_id is None else obj_id
     color = np.array([*cmap(cmap_idx)[:3], 0.6])
-    h, w = mask.shape[-2:]
-    mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-    mask = (mask * 255).astype(np.uint8)
+
+    # Handle different mask shapes properly
+    if len(mask.shape) == 2:
+        h, w = mask.shape
+    else:
+        h, w = mask.shape[-2:]
+
+    # Ensure correct reshaping based on mask dimensions
+    mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+    mask_rgba = (mask_reshaped * 255).astype(np.uint8)
+
     if convert_to_image:
-        mask = Image.fromarray(mask, "RGBA")
-    return mask
+        try:
+            # Ensure the mask has correct RGBA shape (h, w, 4)
+            if mask_rgba.shape[2] != 4:
+                # If not RGBA, create a proper RGBA array
+                proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
+                # Copy available channels
+                proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
+                mask_rgba = proper_mask
+
+            # Create the PIL image
+            return Image.fromarray(mask_rgba, "RGBA")
+        except Exception as e:
+            print(f"Error converting mask to image: {e}")
+            # Fallback: create a blank transparent image of correct size
+            blank = np.zeros((h, w, 4), dtype=np.uint8)
+            return Image.fromarray(blank, "RGBA")
+
+    return mask_rgba
 
 
 def propagate_to_all(
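
The broadcast mask.reshape(h, w, 1) * color.reshape(1, 1, -1) is what turns a boolean mask into a semi-transparent overlay: every True pixel gets the tab10 color with alpha 0.6, every False pixel stays fully transparent. The same arithmetic in isolation, assuming a boolean (h, w) mask:

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

color = np.array([*plt.get_cmap("tab10")(0)[:3], 0.6])  # RGB plus alpha
rgba = (mask.reshape(4, 4, 1) * color.reshape(1, 1, -1) * 255).astype(np.uint8)
overlay = Image.fromarray(rgba, "RGBA")

assert rgba[2, 2, 3] == 153  # masked pixel: alpha = 0.6 * 255
assert rgba[0, 0, 3] == 0    # unmasked pixel: fully transparent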
@@ -270,62 +333,130 @@ def propagate_to_all(
     )
 
     # For CPU optimization: process in smaller batches
-    chunk_size =
+    chunk_size = 3 # Process 3 frames at a time to avoid memory issues on CPU
 
-    # run propagation throughout the video and collect the results in a dict
-    video_segments = {}  # video_segments contains the per-frame segmentation results
-    print("starting propagate_in_video on CPU")
-    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-        session_state["inference_state"]
-    ):
-        video_segments[out_frame_idx] = {
-            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-            for i, out_obj_id in enumerate(out_obj_ids)
-        }
-
-
+    try:
+        # run propagation throughout the video and collect the results in a dict
+        video_segments = {}  # video_segments contains the per-frame segmentation results
+        print("starting propagate_in_video on CPU")
+
+        # Get the frames in chunks for CPU memory optimization
+        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+            session_state["inference_state"]
+        ):
+            try:
+                # Store the masks for each object ID
+                video_segments[out_frame_idx] = {
+                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                    for i, out_obj_id in enumerate(out_obj_ids)
+                }
+
+                print(f"Processed frame {out_frame_idx}")
+
+                # Release memory periodically
+                if out_frame_idx % chunk_size == 0:
+                    # Explicitly clear any tensors
+                    del out_mask_logits
+                    import gc
+                    gc.collect()
+            except Exception as e:
+                print(f"Error processing frame {out_frame_idx}: {e}")
+                continue
+
+        # For CPU optimization: increase stride to reduce processing
+        # Create a more aggressive stride to limit to fewer frames in output
+        total_frames = len(video_segments)
+        print(f"Total frames processed: {total_frames}")
+
+        # Limit to max 50 frames for CPU processing
+        max_output_frames = 50
+        vis_frame_stride = max(1, total_frames // max_output_frames)
+
+        # Get dimensions of the frames
+        first_frame = session_state["all_frames"][0]
+        h, w = first_frame.shape[:2]
+
+        output_frames = []
+        for out_frame_idx in range(0, total_frames, vis_frame_stride):
+            if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
+                continue
+
+            try:
+                frame = session_state["all_frames"][out_frame_idx]
+                transparent_background = Image.fromarray(frame).convert("RGBA")
+
+                # Get the mask and ensure it's the right size
+                out_mask = video_segments[out_frame_idx][OBJ_ID]
+
+                # Resize mask if dimensions don't match
+                if out_mask.shape[:2] != (h, w):
+                    out_mask = cv2.resize(
+                        out_mask.astype(np.uint8),
+                        (w, h),
+                        interpolation=cv2.INTER_NEAREST
+                    ).astype(bool)
+
+                mask_image = show_mask(out_mask)
+
+                # Make sure mask has same dimensions as background
+                if mask_image.size != transparent_background.size:
+                    mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
+
+                output_frame = Image.alpha_composite(transparent_background, mask_image)
+                output_frame = np.array(output_frame)
+                output_frames.append(output_frame)
+
+                # Clear memory periodically
+                if len(output_frames) % 10 == 0:
+                    import gc
+                    gc.collect()
+
+            except Exception as e:
+                print(f"Error creating output frame {out_frame_idx}: {e}")
+                continue
+
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        fps = original_fps
 
-    #
-    if
-
+        # For CPU optimization - lower FPS if original is high
+        if fps > 15:
+            fps = 15 # Lower fps for CPU processing
+
+        print(f"Creating video with {len(output_frames)} frames at {fps} FPS")
+        clip = ImageSequenceClip(output_frames, fps=fps)
+
+        # Write the result to a file - use lower quality for CPU
+        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+        final_vid_output_path = f"output_video_{unique_id}.mp4"
+        final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
+
+        # Lower bitrate for CPU processing
+        clip.write_videofile(
+            final_vid_output_path,
+            codec="libx264",
+            bitrate="800k",
+            threads=2, # Use fewer threads for CPU
+            logger=None # Disable logger to reduce console output
+        )
+
+        # Free memory
+        del video_segments
+        del output_frames
+        import gc
+        gc.collect()
 
-
-
-
-    output_frames = []
-    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-        transparent_background = Image.fromarray(
-            session_state["all_frames"][out_frame_idx]
-        ).convert("RGBA")
-        out_mask = video_segments[out_frame_idx][OBJ_ID]
-        mask_image = show_mask(out_mask)
-        output_frame = Image.alpha_composite(transparent_background, mask_image)
-        output_frame = np.array(output_frame)
-        output_frames.append(output_frame)
-
-    # Create a video clip from the image sequence
-    original_fps = get_video_fps(video_in)
-    fps = original_fps # Frames per second
-
-    # For CPU optimization - lower FPS if original is high
-    if fps > 24:
-        fps = 24
-
-    clip = ImageSequenceClip(output_frames, fps=fps)
+        return (
+            gr.update(value=final_vid_output_path, visible=True),
+            session_state,
+        )
 
-
-
-
-
-
-
-    clip.write_videofile(final_vid_output_path, codec="libx264", bitrate="1000k")
-
-    return (
-        gr.update(value=final_vid_output_path),
-        session_state,
-    )
+    except Exception as e:
+        print(f"Error in propagate_to_all: {e}")
+        return (
+            gr.update(value=None, visible=False),
+            session_state,
+        )
 
 
 def update_ui():
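
The stride arithmetic in this hunk caps the output length rather than fixing it: max(1, total_frames // max_output_frames) keeps every frame for short clips and samples roughly every Nth frame for long ones. A quick check with the commit's constant of 50:

max_output_frames = 50

for total_frames in (40, 200, 1000):
    stride = max(1, total_frames // max_output_frames)
    kept = len(range(0, total_frames, stride))
    print(total_frames, stride, kept)  # -> 40 1 40, 200 4 50, 1000 20 50

Note that // rounds down, so the count can land slightly above the cap (for example, 220 frames give stride 4 and 55 output frames); for thinning video output an approximate cap is enough.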