chongzhou committed
Commit dd6a79c · Parent(s): 1f52c1d

move inference_state to gr.state

Files changed (1): app.py (+249 -180)
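For readers skimming the diff: the commit replaces the module-level `global_inference_states` dict (keyed by `request.session_hash`) and the shared module-level `predictor` with per-session `gr.State` values that are passed through every event handler's `inputs` and returned through its `outputs`. Below is a minimal, self-contained sketch of that pattern under simplified assumptions — it is not taken from app.py, and the handler and component names are made up for illustration:

# Sketch of the gr.State pattern applied in this commit (hypothetical example).
import gradio as gr


def load_model(state):
    # Lazily create the per-session object the first time it is needed.
    if state is None:
        state = {"predictor": "expensive-object", "clicks": 0}
    return state, "model ready"


def add_click(state):
    # Mutate and return the per-session state so Gradio stores the new value.
    if state is None:
        return state, "click 'init' first"
    state["clicks"] += 1
    return state, f"clicks so far: {state['clicks']}"


with gr.Blocks() as demo:
    session_state = gr.State(None)  # one copy per browser session
    status = gr.Textbox(label="status")
    init_btn = gr.Button("init")
    click_btn = gr.Button("click")

    init_btn.click(load_model, inputs=[session_state], outputs=[session_state, status])
    click_btn.click(add_click, inputs=[session_state], outputs=[session_state, status])

if __name__ == "__main__":
    demo.launch()

Because gr.State holds a separate value per browser session, concurrent users no longer share or overwrite each other's predictor and inference state, and no manual bookkeeping keyed by session hash is needed.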
app.py CHANGED
@@ -17,7 +17,6 @@ import cv2
 import matplotlib.pyplot as plt
 import numpy as np
 
-import spaces
 import torch
 
 from moviepy.editor import ImageSequenceClip
@@ -70,11 +69,8 @@ examples = [
 ]
 
 OBJ_ID = 0
-
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
-predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-global_inference_states = {}
 
 
 def get_video_fps(video_path):
@@ -92,75 +88,82 @@ def get_video_fps(video_path):
 
 
 def reset(
-    session_first_frame,
-    session_all_frames,
-    session_input_points,
-    session_input_labels,
-    request: gr.Request,
+    first_frame,
+    all_frames,
+    input_points,
+    input_labels,
+    inference_state,
+    predictor,
 ):
-    session_id = request.session_hash
-    predictor.to("cpu")
-    session_input_points = []
-    session_input_labels = []
-
-    if global_inference_states[session_id] is not None:
-        predictor.reset_state(global_inference_states[session_id])
-    session_first_frame = None
-    session_all_frames = None
-    global_inference_states[session_id] = None
+    first_frame = None
+    all_frames = None
+    input_points = []
+    input_labels = []
+
+    if inference_state and predictor:
+        predictor.reset_state(inference_state)
+    inference_state = None
     return (
         None,
         gr.update(open=True),
         None,
         None,
         gr.update(value=None, visible=False),
-        session_first_frame,
-        session_all_frames,
-        session_input_points,
-        session_input_labels,
+        first_frame,
+        all_frames,
+        input_points,
+        input_labels,
+        inference_state,
+        predictor,
     )
 
 
 def clear_points(
-    session_input_points,
-    session_input_labels,
-    request: gr.Request,
+    first_frame,
+    all_frames,
+    input_points,
+    input_labels,
+    inference_state,
+    predictor,
 ):
-    session_id = request.session_hash
-    predictor.to("cpu")
-    session_input_points = []
-    session_input_labels = []
-    if global_inference_states[session_id]["tracking_has_started"]:
-        predictor.reset_state(global_inference_states[session_id])
+    input_points = []
+    input_labels = []
+    if inference_state and predictor and inference_state["tracking_has_started"]:
+        predictor.reset_state(inference_state)
     return (
-        session_first_frame,
+        first_frame,
         None,
         gr.update(value=None, visible=False),
-        session_input_points,
-        session_input_labels,
+        first_frame,
+        all_frames,
+        input_points,
+        input_labels,
+        inference_state,
+        predictor,
     )
 
 
 def preprocess_video_in(
     video_path,
-    session_first_frame,
-    session_all_frames,
-    session_input_points,
-    session_input_labels,
-    request: gr.Request,
+    first_frame,
+    all_frames,
+    input_points,
+    input_labels,
+    inference_state,
+    predictor,
 ):
-    session_id = request.session_hash
-    predictor.to("cpu")
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
             None,  # points_map
             None,  # output_image
             gr.update(value=None, visible=False),  # output_video
-            session_first_frame,
-            session_all_frames,
-            session_input_points,
-            session_input_labels,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
+            inference_state,
+            predictor,
         )
 
     # Read the first frame
@@ -172,14 +175,19 @@ def preprocess_video_in(
             None,  # points_map
             None,  # output_image
             gr.update(value=None, visible=False),  # output_video
-            session_first_frame,
-            session_all_frames,
-            session_input_points,
-            session_input_labels,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
+            inference_state,
+            predictor,
         )
 
+    if predictor is None:
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+
     frame_number = 0
-    first_frame = None
+    _first_frame = None
     all_frames = []
 
     while True:
@@ -192,100 +200,107 @@ def preprocess_video_in(
 
         # Store the first frame
         if frame_number == 0:
-            first_frame = frame
+            _first_frame = frame
        all_frames.append(frame)
 
         frame_number += 1
 
     cap.release()
-    session_first_frame = copy.deepcopy(first_frame)
-    session_all_frames = all_frames
-
-    global_inference_states[session_id] = predictor.init_state(video_path=video_path)
-
-    session_input_points = []
-    session_input_labels = []
+    first_frame = copy.deepcopy(_first_frame)
+    inference_state = predictor.init_state(video_path=video_path)
+    input_points = []
+    input_labels = []
 
     return [
         gr.update(open=False),  # video_in_drawer
         first_frame,  # points_map
         None,  # output_image
         gr.update(value=None, visible=False),  # output_video
-        session_first_frame,
-        session_all_frames,
-        session_input_points,
-        session_input_labels,
+        first_frame,
+        all_frames,
+        input_points,
+        input_labels,
+        inference_state,
+        predictor,
     ]
 
 
-@spaces.GPU
 def segment_with_points(
     point_type,
-    session_input_points,
-    session_input_labels,
+    first_frame,
+    all_frames,
+    input_points,
+    input_labels,
+    inference_state,
+    predictor,
     evt: gr.SelectData,
-    request: gr.Request,
 ):
-    session_id = request.session_hash
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+    if torch.cuda.is_available():
         predictor.to("cuda")
-        session_input_points.append(evt.index)
-        print(f"TRACKING INPUT POINT: {session_input_points}")
-
-        if point_type == "include":
-            session_input_labels.append(1)
-        elif point_type == "exclude":
-            session_input_labels.append(0)
-        print(f"TRACKING INPUT LABEL: {session_input_labels}")
-
-        # Open the image and get its dimensions
-        transparent_background = Image.fromarray(session_first_frame).convert("RGBA")
-        w, h = transparent_background.size
-
-        # Define the circle radius as a fraction of the smaller dimension
-        fraction = 0.01  # You can adjust this value as needed
-        radius = int(fraction * min(w, h))
-
-        # Create a transparent layer to draw on
-        transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
-
-        for index, track in enumerate(session_input_points):
-            if session_input_labels[index] == 1:
-                cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
-            else:
-                cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
-
-        # Convert the transparent layer back to an image
-        transparent_layer = Image.fromarray(transparent_layer, "RGBA")
-        selected_point_map = Image.alpha_composite(
-            transparent_background, transparent_layer
-        )
+        inference_state["device"] = "cuda"
+        if torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+
+    input_points.append(evt.index)
+    print(f"TRACKING INPUT POINT: {input_points}")
+
+    if point_type == "include":
+        input_labels.append(1)
+    elif point_type == "exclude":
+        input_labels.append(0)
+    print(f"TRACKING INPUT LABEL: {input_labels}")
+
+    # Open the image and get its dimensions
+    transparent_background = Image.fromarray(first_frame).convert("RGBA")
+    w, h = transparent_background.size
+
+    # Define the circle radius as a fraction of the smaller dimension
+    fraction = 0.01  # You can adjust this value as needed
+    radius = int(fraction * min(w, h))
+
+    # Create a transparent layer to draw on
+    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+
+    for index, track in enumerate(input_points):
+        if input_labels[index] == 1:
+            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
+        else:
+            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
+
+    # Convert the transparent layer back to an image
+    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
+    selected_point_map = Image.alpha_composite(
+        transparent_background, transparent_layer
+    )
 
-        # Let's add a positive click at (x, y) = (210, 350) to get started
-        points = np.array(session_input_points, dtype=np.float32)
-        # for labels, `1` means positive click and `0` means negative click
-        labels = np.array(session_input_labels, dtype=np.int32)
-        _, _, out_mask_logits = predictor.add_new_points(
-            inference_state=global_inference_states[session_id],
-            frame_idx=0,
-            obj_id=OBJ_ID,
-            points=points,
-            labels=labels,
-        )
+    # Let's add a positive click at (x, y) = (210, 350) to get started
+    points = np.array(input_points, dtype=np.float32)
+    # for labels, `1` means positive click and `0` means negative click
+    labels = np.array(input_labels, dtype=np.int32)
+    _, _, out_mask_logits = predictor.add_new_points(
+        inference_state=inference_state,
+        frame_idx=0,
+        obj_id=OBJ_ID,
+        points=points,
+        labels=labels,
+    )
 
-        mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
-        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
+    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
+    first_frame_output = Image.alpha_composite(transparent_background, mask_image)
 
-        torch.cuda.empty_cache()
-        return (
-            selected_point_map,
-            first_frame_output,
-            session_input_points,
-            session_input_labels,
-        )
+    torch.cuda.empty_cache()
+    return (
+        selected_point_map,
+        first_frame_output,
+        first_frame,
+        all_frames,
+        input_points,
+        input_labels,
+        inference_state,
+        predictor,
+    )
 
 
 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
@@ -303,69 +318,82 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     return mask
 
 
-@spaces.GPU
 def propagate_to_all(
     video_in,
-    session_all_frames,
-    request: gr.Request,
+    first_frame,
+    all_frames,
+    input_points,
+    input_labels,
+    inference_state,
+    predictor,
 ):
-    session_id = request.session_hash
-    predictor.to("cuda")
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-        if (
-            len(session_input_points) == 0
-            or video_in is None
-            or global_inference_states[session_id] is None
-        ):
-            return None
-
-        # run propagation throughout the video and collect the results in a dict
-        video_segments = (
-            {}
-        )  # video_segments contains the per-frame segmentation results
-        print("starting propagate_in_video")
-        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-            global_inference_states[session_id]
-        ):
-            video_segments[out_frame_idx] = {
-                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-                for i, out_obj_id in enumerate(out_obj_ids)
-            }
-
-        # obtain the segmentation results every few frames
-        vis_frame_stride = 1
-
-        output_frames = []
-        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-            transparent_background = Image.fromarray(
-                session_all_frames[out_frame_idx]
-            ).convert("RGBA")
-            out_mask = video_segments[out_frame_idx][OBJ_ID]
-            mask_image = show_mask(out_mask)
-            output_frame = Image.alpha_composite(transparent_background, mask_image)
-            output_frame = np.array(output_frame)
-            output_frames.append(output_frame)
-
-        torch.cuda.empty_cache()
-
-        # Create a video clip from the image sequence
-        original_fps = get_video_fps(video_in)
-        fps = original_fps  # Frames per second
-        clip = ImageSequenceClip(output_frames, fps=fps)
-        # Write the result to a file
-        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-        final_vid_output_path = f"output_video_{unique_id}.mp4"
-        final_vid_output_path = os.path.join(
-            tempfile.gettempdir(), final_vid_output_path
+    if torch.cuda.is_available():
+        predictor.to("cuda")
+        inference_state["device"] = "cuda"
+        if torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+
+    if len(input_points) == 0 or video_in is None or inference_state is None:
+        return None
+    # run propagation throughout the video and collect the results in a dict
+    video_segments = {}  # video_segments contains the per-frame segmentation results
+    print("starting propagate_in_video")
+    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+        inference_state
+    ):
+        video_segments[out_frame_idx] = {
+            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+            for i, out_obj_id in enumerate(out_obj_ids)
+        }
+
+    # obtain the segmentation results every few frames
+    vis_frame_stride = 1
+
+    output_frames = []
+    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
+        transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
+            "RGBA"
         )
+        out_mask = video_segments[out_frame_idx][OBJ_ID]
+        mask_image = show_mask(out_mask)
+        output_frame = Image.alpha_composite(transparent_background, mask_image)
+        output_frame = np.array(output_frame)
+        output_frames.append(output_frame)
+
+    torch.cuda.empty_cache()
+
+    # Create a video clip from the image sequence
+    original_fps = get_video_fps(video_in)
+    fps = original_fps  # Frames per second
+    clip = ImageSequenceClip(output_frames, fps=fps)
+    # Write the result to a file
+    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+    final_vid_output_path = f"output_video_{unique_id}.mp4"
+    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
+
+    # Write the result to a file
+    clip.write_videofile(final_vid_output_path, codec="libx264")
+
+    return (
+        gr.update(value=final_vid_output_path),
+        first_frame,
+        all_frames,
+        input_points,
+        input_labels,
+        inference_state,
+        predictor,
+    )
 
-        # Write the result to a file
-        clip.write_videofile(final_vid_output_path, codec="libx264")
 
-        return gr.update(value=final_vid_output_path)
+try:
+    from spaces import GPU
+
+    segment_with_points = GPU(segment_with_points)
+    propagate_to_all = GPU(propagate_to_all)
+except:
+    print("spaces unavailable")
 
 
 def update_ui():
@@ -377,6 +405,8 @@ with gr.Blocks() as demo:
     all_frames = gr.State(None)
     input_points = gr.State([])
    input_labels = gr.State([])
+    inference_state = gr.State(None)
+    predictor = gr.State(None)
 
     with gr.Column():
         # Title
@@ -430,6 +460,8 @@ with gr.Blocks() as demo:
             all_frames,
             input_points,
             input_labels,
+            inference_state,
+            predictor,
         ],
         outputs=[
             video_in_drawer,  # Accordion to hide uploaded video player
@@ -440,6 +472,8 @@ with gr.Blocks() as demo:
             all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
         ],
         queue=False,
     )
@@ -452,6 +486,8 @@ with gr.Blocks() as demo:
            all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
         ],
         outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
@@ -462,6 +498,8 @@ with gr.Blocks() as demo:
            all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
         ],
         queue=False,
     )
@@ -471,14 +509,22 @@ with gr.Blocks() as demo:
        fn=segment_with_points,
        inputs=[
            point_type,  # "include" or "exclude"
+            first_frame,
+            all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
        ],
        outputs=[
            points_map,  # updated image with points
            output_image,
+            first_frame,
+            all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
        ],
        queue=False,
    )
@@ -487,15 +533,23 @@ with gr.Blocks() as demo:
    clear_points_btn.click(
        fn=clear_points,
        inputs=[
+            first_frame,
+            all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
        ],
        outputs=[
            points_map,
            output_image,
            output_video,
+            first_frame,
+            all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
        ],
        queue=False,
    )
@@ -507,6 +561,8 @@ with gr.Blocks() as demo:
            all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
        ],
        outputs=[
            video_in,
@@ -518,6 +574,8 @@ with gr.Blocks() as demo:
            all_frames,
            input_points,
            input_labels,
+            inference_state,
+            predictor,
        ],
        queue=False,
    )
@@ -531,10 +589,21 @@ with gr.Blocks() as demo:
        fn=propagate_to_all,
        inputs=[
            video_in,
+            first_frame,
            all_frames,
+            input_points,
+            input_labels,
+            inference_state,
+            predictor,
        ],
        outputs=[
            output_video,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
+            inference_state,
+            predictor,
        ],
        concurrency_limit=10,
        queue=False,
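One detail worth noting near the end of the new file: with the module-level `import spaces` gone, the `spaces.GPU` decorator is now applied conditionally, so the same app.py can run both on ZeroGPU Spaces and locally without the `spaces` package. A rough sketch of that guard pattern follows; the handler name is a placeholder, and `except ImportError` is used here where the committed code uses a bare `except`:

# Sketch of optionally wrapping a GPU-bound handler with spaces.GPU (hypothetical names).
def heavy_handler(x):
    # placeholder for a function that needs GPU time on Spaces
    return x


try:
    from spaces import GPU  # only available on Hugging Face Spaces ZeroGPU hardware

    heavy_handler = GPU(heavy_handler)
except ImportError:
    # local run or non-ZeroGPU hardware: keep the undecorated function
    print("spaces unavailable; running without the GPU decorator")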