
bla committed · Commit ea6a5ed · verified · 1 parent: 6e871ac

Update app.py

Files changed (1)
  1. app.py +388 -138
app.py CHANGED
@@ -1,178 +1,428 @@
- # The full rewritten version of the provided code with progress bar, error fixes, and proper Gradio integration

- import os
  import copy
- import tempfile
  from datetime import datetime
- import gc

  import cv2
- import numpy as np
- from PIL import Image
  import matplotlib.pyplot as plt
- import torch
  import gradio as gr
  from moviepy.editor import ImageSequenceClip
-
  from sam2.build_sam import build_sam2_video_predictor

- # Remove CUDA-related env var to force CPU-only mode
- os.environ.pop("TORCH_CUDNN_SDPA_ENABLED", None)

- # Config
  sam2_checkpoint = "checkpoints/edgetam.pt"
  model_cfg = "edgetam.yaml"
- examples = [[f"examples/{vid}"] for vid in ["01_dog.mp4", "02_cups.mp4", "03_blocks.mp4", "04_coffee.mp4", "05_default_juggle.mp4"]]
- OBJ_ID = 0

- # Model loader
- if os.path.exists(sam2_checkpoint) and os.path.exists(model_cfg):
-     try:
-         predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-     except Exception as e:
-         print("Error loading predictor:", e)
-         predictor = None
- else:
-     print("Model files missing.")
      predictor = None

- def get_fps(video_path):
      cap = cv2.VideoCapture(video_path)
-     if not cap.isOpened(): return 30.0
      fps = cap.get(cv2.CAP_PROP_FPS)
      cap.release()
      return fps

- def reset(session):
-     if session["inference_state"]:
-         predictor.reset_state(session["inference_state"])
-     session.update({"input_points": [], "input_labels": [], "first_frame": None, "all_frames": None, "inference_state": None})
-     return None, gr.update(open=True), None, None, gr.update(value=None, visible=False), session

- def clear_points(session):
-     session["input_points"] = []
-     session["input_labels"] = []
-     if session["inference_state"] and session["inference_state"].get("tracking_has_started"):
-         predictor.reset_state(session["inference_state"])
-     return session["first_frame"], None, gr.update(value=None, visible=False), session

- def preprocess_video(video_path, session):
      cap = cv2.VideoCapture(video_path)
-     if not cap.isOpened(): return gr.update(open=True), None, None, gr.update(value=None, visible=False), session
-
      total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-     stride = max(1, total_frames // 300)
-     frames, first_frame = [], None
-
-     w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     target_w = 640
-     scale = target_w / w if w > target_w else 1.0

-     frame_id = 0
      while True:
          ret, frame = cap.read()
-         if not ret: break
-         if frame_id % stride == 0:
-             if scale < 1.0:
-                 frame = cv2.resize(frame, (int(w*scale), int(h*scale)))
              frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-             if first_frame is None: first_frame = frame
-             frames.append(frame)
-         frame_id += 1
      cap.release()

-     session.update({"first_frame": first_frame, "all_frames": frames, "frame_stride": stride, "scale_factor": scale, "inference_state": predictor.init_state(video_path=video_path), "input_points": [], "input_labels": []})
-     return gr.update(open=False), first_frame, None, gr.update(value=None, visible=False), session
-
- def show_mask(mask, obj_id=None):
-     cmap = plt.get_cmap("tab10")
-     color = np.array([*cmap(0 if obj_id is None else obj_id)[:3], 0.6])
-     h, w = mask.shape
-     mask_rgba = (mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255).astype(np.uint8)
-     proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
-     proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
-     return Image.fromarray(proper_mask, "RGBA")
-
- def segment_with_points(ptype, session, evt):
-     session["input_points"].append(evt.index)
-     session["input_labels"].append(1 if ptype == "include" else 0)
-     first = session["first_frame"]
-     h, w = first.shape[:2]
-
-     layer = np.zeros((h, w, 4), dtype=np.uint8)
-     for idx, pt in enumerate(session["input_points"]):
-         color = (0, 255, 0, 255) if session["input_labels"][idx] == 1 else (255, 0, 0, 255)
-         cv2.circle(layer, pt, int(min(w, h)*0.01), color, -1)
-
-     overlay = Image.alpha_composite(Image.fromarray(first).convert("RGBA"), Image.fromarray(layer, "RGBA"))

      try:
-         _, _, logits = predictor.add_new_points(session["inference_state"], 0, OBJ_ID, np.array(session["input_points"]), np.array(session["input_labels"]))
-         mask = (logits[0] > 0.0).cpu().numpy()
-         mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
-         mask_img = show_mask(mask)
-         return overlay, Image.alpha_composite(Image.fromarray(first).convert("RGBA"), mask_img), session
      except Exception as e:
-         print("Segmentation error:", e)
-         return overlay, overlay, session

- def propagate(video_in, session, progress=gr.Progress()):
-     if not session["input_points"] or not session["inference_state"]: return None, session

-     masks = {}
-     for i, (idxs, obj_ids, logits) in enumerate(predictor.propagate_in_video(session["inference_state"])):
-         try:
-             masks[idxs] = {oid: (logits[j] > 0.0).cpu().numpy() for j, oid in enumerate(obj_ids)}
-             progress(i / 300, desc=f"Tracking frame {idxs}")
-         except: continue

-     frames_out, stride = [], max(1, len(masks) // 50)
-     for i in range(0, len(masks), stride):
-         if i not in masks or OBJ_ID not in masks[i]: continue
          try:
-             frame = session["all_frames"][i]
-             mask = masks[i][OBJ_ID]
-             h, w = frame.shape[:2]
-             mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
-             output = Image.alpha_composite(Image.fromarray(frame).convert("RGBA"), show_mask(mask))
-             frames_out.append(np.array(output))
-         except: continue
-
-     out_path = os.path.join(tempfile.gettempdir(), f"output_video_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4")
-     fps = min(15, get_fps(video_in))
-     ImageSequenceClip(frames_out, fps=fps).write_videofile(out_path, codec="libx264", bitrate="800k", threads=2, logger=None)
-     gc.collect()
-     return gr.update(value=out_path, visible=True), session

  with gr.Blocks() as demo:
-     state = gr.State({"first_frame": None, "all_frames": None, "input_points": [], "input_labels": [], "inference_state": None, "frame_stride": 1, "scale_factor": 1.0, "original_dimensions": None})
-
-     gr.Markdown("<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>")
-
-     with gr.Row():
-         with gr.Column():
-             gr.Markdown("""<ol><li>Upload a video or use an example</li><li>Select 'include' or 'exclude' and click points</li><li>Click 'Track' to segment and track</li></ol>""")
-             drawer = gr.Accordion("Input Video", open=True)
-             with drawer:
-                 video_in = gr.Video(label="Input Video", format="mp4")
-             ptype = gr.Radio(label="Point Type", choices=["include", "exclude"], value="include")
-             track_btn = gr.Button("Track", variant="primary")
-             clear_btn = gr.Button("Clear Points")
-             reset_btn = gr.Button("Reset")
-             points_map = gr.Image(label="Frame with Points", type="numpy", interactive=False)
-         with gr.Column():
-             gr.Markdown("# Try some examples ⬇️")
-             gr.Examples(examples, inputs=[video_in], examples_per_page=5)
-             output_img = gr.Image(label="Reference Mask")
-             output_vid = gr.Video(visible=False)
-
-     video_in.upload(preprocess_video, [video_in, state], [drawer, points_map, output_img, output_vid, state])
-     video_in.change(preprocess_video, [video_in, state], [drawer, points_map, output_img, output_vid, state])
-     points_map.select(segment_with_points, [ptype, state], [points_map, output_img, state])
-     clear_btn.click(clear_points, state, [points_map, output_img, output_vid, state])
-     reset_btn.click(reset, state, [video_in, drawer, points_map, output_img, output_vid, state])
-     track_btn.click(fn=propagate, inputs=[video_in, state], outputs=[output_vid, state])
-
- if __name__ == '__main__':
-     demo.queue()
-     demo.launch()

+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.

  import copy
+ import os
  from datetime import datetime
+ import tempfile

  import cv2
  import matplotlib.pyplot as plt
+ import numpy as np
  import gradio as gr
+ import torch
  from moviepy.editor import ImageSequenceClip
+ from PIL import Image
  from sam2.build_sam import build_sam2_video_predictor

+ # Remove CUDA environment variables
+ if 'TORCH_CUDNN_SDPA_ENABLED' in os.environ:
+     del os.environ["TORCH_CUDNN_SDPA_ENABLED"]
+
+ # UI Description
+ title = "<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"
+
+ description_p = """# Instructions
+ <ol>
+ <li>Upload one video or click one example video</li>
+ <li>Click 'include' point type, select the object to segment and track</li>
+ <li>Click 'exclude' point type (optional), select the area to avoid segmenting</li>
+ <li>Click the 'Track' button to obtain the masked video</li>
+ </ol>
+ """
+
+ # Example videos
+ examples = [
+     ["examples/01_dog.mp4"],
+     ["examples/02_cups.mp4"],
+     ["examples/03_blocks.mp4"],
+     ["examples/04_coffee.mp4"],
+     ["examples/05_default_juggle.mp4"],
+ ]
+
+ OBJ_ID = 0

+ # Initialize model on CPU
  sam2_checkpoint = "checkpoints/edgetam.pt"
  model_cfg = "edgetam.yaml"

+ def check_file_exists(filepath):
+     exists = os.path.exists(filepath)
+     if not exists:
+         print(f"WARNING: File not found: {filepath}")
+     return exists
+
+ # Verify model files
+ model_files_exist = check_file_exists(sam2_checkpoint) and check_file_exists(model_cfg)
+ try:
+     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+     print("Predictor loaded on CPU")
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     import traceback
+     traceback.print_exc()
      predictor = None

+ # Utility Functions
+ def get_video_fps(video_path):
      cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         print("Error: Could not open video.")
+         return 30.0
      fps = cap.get(cv2.CAP_PROP_FPS)
      cap.release()
      return fps

+ def reset(session_state):
+     session_state["input_points"] = []
+     session_state["input_labels"] = []
+     if session_state["inference_state"] is not None:
+         predictor.reset_state(session_state["inference_state"])
+     session_state["first_frame"] = None
+     session_state["all_frames"] = None
+     session_state["inference_state"] = None
+     return (
+         None,
+         gr.update(open=True),
+         None,
+         None,
+         gr.update(value=None, visible=False),
+         session_state,
+     )
+
+ def clear_points(session_state):
+     session_state["input_points"] = []
+     session_state["input_labels"] = []
+     if session_state["inference_state"] is not None and session_state["inference_state"].get("tracking_has_started", False):
+         predictor.reset_state(session_state["inference_state"])
+     return (
+         session_state["first_frame"],
+         None,
+         gr.update(value=None, visible=False),
+         session_state,
+     )

+ def preprocess_video_in(video_path, session_state):
+     if video_path is None:
+         return (
+             gr.update(open=True),
+             None,
+             None,
+             gr.update(value=None, visible=False),
+             session_state,
+         )

      cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         print("Error: Could not open video.")
+         return (
+             gr.update(open=True),
+             None,
+             None,
+             gr.update(value=None, visible=False),
+             session_state,
+         )
+
+     # Video properties
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
      total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     # Capture the original dimensions before resizing and before cap.release()
+     original_dimensions = (frame_width, frame_height)

+     # Resize for CPU performance
+     target_width = 640
+     scale_factor = 1.0
+     if frame_width > target_width:
+         scale_factor = target_width / frame_width
+         frame_width = target_width
+         frame_height = int(frame_height * scale_factor)
+
+     # Read frames with stride for CPU optimization
+     frame_number = 0
+     first_frame = None
+     all_frames = []
+     frame_stride = max(1, total_frames // 300)  # Limit to ~300 frames
+
      while True:
          ret, frame = cap.read()
+         if not ret:
+             break
+         if frame_number % frame_stride == 0:
+             if scale_factor != 1.0:
+                 frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_AREA)
              frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             if first_frame is None:
+                 first_frame = frame
+             all_frames.append(frame)
+         frame_number += 1
+
      cap.release()
+     session_state["first_frame"] = copy.deepcopy(first_frame)
+     session_state["all_frames"] = all_frames
+     session_state["frame_stride"] = frame_stride
+     session_state["scale_factor"] = scale_factor
+     session_state["original_dimensions"] = original_dimensions

+     session_state["inference_state"] = predictor.init_state(video_path=video_path)
+     session_state["input_points"] = []
+     session_state["input_labels"] = []
+
+     return [
+         gr.update(open=False),
+         first_frame,
+         None,
+         gr.update(value=None, visible=False),
+         session_state,
+     ]
+
+ def segment_with_points(point_type, session_state, evt: gr.SelectData):
+     session_state["input_points"].append(evt.index)
+     print(f"TRACKING INPUT POINT: {session_state['input_points']}")
+
+     if point_type == "include":
+         session_state["input_labels"].append(1)
+     elif point_type == "exclude":
+         session_state["input_labels"].append(0)
+     print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
+
+     first_frame = session_state["first_frame"]
+     h, w = first_frame.shape[:2]
+     transparent_background = Image.fromarray(first_frame).convert("RGBA")
+
+     # Draw points
+     fraction = 0.01
+     radius = int(fraction * min(w, h))
+     transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+
+     for index, track in enumerate(session_state["input_points"]):
+         color = (0, 255, 0, 255) if session_state["input_labels"][index] == 1 else (255, 0, 0, 255)
+         cv2.circle(transparent_layer, track, radius, color, -1)
+
+     transparent_layer = Image.fromarray(transparent_layer, "RGBA")
+     selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+     points = np.array(session_state["input_points"], dtype=np.float32)
+     labels = np.array(session_state["input_labels"], np.int32)

      try:
+         _, _, out_mask_logits = predictor.add_new_points(
+             inference_state=session_state["inference_state"],
+             frame_idx=0,
+             obj_id=OBJ_ID,
+             points=points,
+             labels=labels,
+         )
+         mask_array = (out_mask_logits[0] > 0.0).cpu().numpy()
+
+         # Ensure mask matches frame size
+         if mask_array.shape[:2] != (h, w):
+             mask_array = cv2.resize(mask_array.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
+
+         mask_image = show_mask(mask_array)
+         if mask_image.size != transparent_background.size:
+             mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
+
+         first_frame_output = Image.alpha_composite(transparent_background, mask_image)
      except Exception as e:
+         print(f"Error in segmentation: {e}")
+         first_frame_output = selected_point_map

+     return selected_point_map, first_frame_output, session_state

+ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         cmap = plt.get_cmap("tab10")
+         cmap_idx = 0 if obj_id is None else obj_id
+         color = np.array([*cmap(cmap_idx)[:3], 0.6])

+     h, w = mask.shape[-2:] if len(mask.shape) > 2 else mask.shape
+     mask_reshaped = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     mask_rgba = (mask_reshaped * 255).astype(np.uint8)
+
+     if convert_to_image:
          try:
+             if mask_rgba.shape[2] != 4:
+                 proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
+                 proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
+                 mask_rgba = proper_mask
+             return Image.fromarray(mask_rgba, "RGBA")
+         except Exception as e:
+             print(f"Error converting mask to image: {e}")
+             return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")
+
+     return mask_rgba
+
+ def propagate_to_all(video_in, session_state, progress=gr.Progress()):
+     if len(session_state["input_points"]) == 0 or video_in is None or session_state["inference_state"] is None:
+         return gr.update(value=None, visible=False), session_state
+
+     chunk_size = 3
+     try:
+         video_segments = {}
+         total_frames = len(session_state["all_frames"])
+         progress(0, desc="Propagating segmentation through video...")
+
+         for i, (out_frame_idx, out_obj_ids, out_mask_logit) in enumerate(predictor.propagate_in_video(session_state["inference_state"])):
+             try:
+                 video_segments[out_frame_idx] = {
+                     out_obj_id: (out_mask_logit[i] > 0.0).cpu().numpy()
+                     for i, out_obj_id in enumerate(out_obj_ids)
+                 }
+                 progress((i + 1) / total_frames, desc=f"Processed frame {out_frame_idx}/{total_frames}")
+                 if out_frame_idx % chunk_size == 0:
+                     del out_mask_logit
+                     import gc
+                     gc.collect()
+             except Exception as e:
+                 print(f"Error processing frame {out_frame_idx}: {e}")
+                 continue
+
+         max_output_frames = 50
+         vis_frame_stride = max(1, total_frames // max_output_frames)
+         first_frame = session_state["all_frames"][0]
+         h, w = first_frame.shape[:2]
+         output_frames = []
+
+         for out_frame_idx in range(0, total_frames, vis_frame_stride):
+             if out_frame_idx not in video_segments or OBJ_ID not in video_segments[out_frame_idx]:
+                 continue
+             try:
+                 frame = session_state["all_frames"][out_frame_idx]
+                 transparent_background = Image.fromarray(frame).convert("RGBA")
+                 out_mask = video_segments[out_frame_idx][OBJ_ID]
+
+                 # Validate mask dimensions
+                 if out_mask.shape[:2] != (h, w):
+                     if out_mask.size == 0:  # Skip empty masks
+                         print(f"Skipping empty mask for frame {out_frame_idx}")
+                         continue
+                     out_mask = cv2.resize(out_mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
+
+                 mask_image = show_mask(out_mask)
+                 if mask_image.size != transparent_background.size:
+                     mask_image = mask_image.resize(transparent_background.size, Image.NEAREST)
+
+                 output_frame = Image.alpha_composite(transparent_background, mask_image)
+                 output_frames.append(np.array(output_frame))
+
+                 if len(output_frames) % 10 == 0:
+                     import gc
+                     gc.collect()
+             except Exception as e:
+                 print(f"Error creating output frame {out_frame_idx}: {e}")
+                 import traceback
+                 traceback.print_exc()
+                 continue

+         original_fps = get_video_fps(video_in)
+         fps = min(original_fps, 15)  # Cap at 15 FPS for CPU
+
+         clip = ImageSequenceClip(output_frames, fps=fps)
+         unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+         final_vid_output_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")
+
+         clip.write_videofile(
+             final_vid_output_path,
+             codec="libx264",
+             bitrate="800k",
+             threads=2,
+             logger=None
+         )
+
+         del video_segments, output_frames
+         import gc
+         gc.collect()
+
+         return gr.update(value=final_vid_output_path, visible=True), session_state
+
+     except Exception as e:
+         print(f"Error in propagate_to_all: {e}")
+         return gr.update(value=None, visible=False), session_state
+
+ def update_ui():
+     return gr.update(visible=True)
+
348
+ # Gradio Interface
349
  with gr.Blocks() as demo:
350
+ session_state = gr.State({
351
+ "first_frame": None,
352
+ "all_frames": None,
353
+ "input_points": [],
354
+ "input_labels": [],
355
+ "inference_state": None,
356
+ "frame_stride": 1,
357
+ "scale_factor": 1.0,
358
+ "original_dimensions": None,
359
+ })
360
+
361
+ with gr.Column():
362
+ gr.Markdown(title)
363
+ with gr.Row():
364
+ with gr.Column():
365
+ gr.Markdown(description_p)
366
+ with gr.Accordion("Input Video", open=True) as video_in_drawer:
367
+ video_in = gr.Video(label="Input Video", format="mp4")
368
+ with gr.Row():
369
+ point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include", scale=2)
370
+ propagate_btn = gr.Button("Track", scale=1, variant="primary")
371
+ clear_points_btn = gr.Button("Clear Points", scale=1)
372
+ reset_btn = gr.Button("Reset", scale=1)
373
+ points_map = gr.Image(label="Frame with Point Prompt", type="numpy", interactive=False)
374
+ with gr.Column():
375
+ gr.Markdown("# Try some of the examples below ⬇️")
376
+ gr.Examples(examples=examples, inputs=[video_in], examples_per_page=5)
377
+ output_image = gr.Image(label="Reference Mask")
378
+ output_video = gr.Video(visible=False)
379
+
380
+ video_in.upload(
381
+ fn=preprocess_video_in,
382
+ inputs=[video_in, session_state],
383
+ outputs=[video_in_drawer, points_map, output_image, output_video, session_state],
384
+ queue=False,
385
+ )
386
+
387
+ video_in.change(
388
+ fn=preprocess_video_in,
389
+ inputs=[video_in, session_state],
390
+ outputs=[video_in_drawer, points_map, output_image, output_video, session_state],
391
+ queue=False,
392
+ )
393
+
394
+ points_map.select(
395
+ fn=segment_with_points,
396
+ inputs=[point_type, session_state],
397
+ outputs=[points_map, output_image, session_state],
398
+ queue=False,
399
+ )
400
+
401
+ clear_points_btn.click(
402
+ fn=clear_points,
403
+ inputs=session_state,
404
+ outputs=[points_map, output_image, output_video, session_state],
405
+ queue=False,
406
+ )
407
+
408
+ reset_btn.click(
409
+ fn=reset,
410
+ inputs=session_state,
411
+ outputs=[video_in, video_in_drawer, points_map, output_image, output_video, session_state],
412
+ queue=False,
413
+ )
414
+
415
+ propagate_btn.click(
416
+ fn=update_ui,
417
+ inputs=[],
418
+ outputs=output_video,
419
+ queue=False,
420
+ ).then(
421
+ fn=propagate_to_all,
422
+ inputs=[video_in, session_state],
423
+ outputs=[output_video, session_state],
424
+ queue=True,
425
+ )
426
+
427
+ demo.queue()
428
+ demo.launch()
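
Note for reviewers: the sketch below isolates the CPU-oriented frame handling that `preprocess_video_in` adds in this commit (read every Nth frame so at most ~300 are kept, and downscale to a 640-pixel width). It is a minimal standalone approximation for local testing, not part of the committed app; the helper name `subsample_frames` is illustrative, and it assumes one of the bundled example clips is available at the path shown.

```python
# Minimal sketch of the frame-stride + downscale logic used in preprocess_video_in.
# subsample_frames and the example path are illustrative only.
import cv2


def subsample_frames(video_path, max_frames=300, target_width=640):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open {video_path}")

    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    stride = max(1, total // max_frames)  # keep roughly max_frames frames
    scale = target_width / width if width > target_width else 1.0

    frames = []
    index = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if index % stride == 0:
            if scale != 1.0:
                frame = cv2.resize(
                    frame,
                    (int(width * scale), int(height * scale)),
                    interpolation=cv2.INTER_AREA,
                )
            # Convert BGR (OpenCV default) to RGB, as the app does before display
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        index += 1
    cap.release()
    return frames


if __name__ == "__main__":
    # Uses one of the example clips bundled with the Space.
    kept = subsample_frames("examples/01_dog.mp4")
    if kept:
        print(f"kept {len(kept)} frames at {kept[0].shape[1]}x{kept[0].shape[0]}")
```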