chongzhou committed
Commit 0a7fba1 · Parent(s): 113b7b2

move inference_states out of gr.State

Files changed (1)
  1. app.py +18 -11
app.py CHANGED
@@ -73,6 +73,7 @@ OBJ_ID = 0
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+global_inference_states = {}
 
 
 def get_video_fps(video_path):
@@ -89,15 +90,17 @@ def get_video_fps(video_path):
     return fps
 
 
-def reset(session_state):
+def reset():
     predictor.to("cpu")
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    if session_state["inference_state"] is not None:
-        predictor.reset_state(session_state["inference_state"])
+
+    session_id = id(session_state)
+    if global_inference_states[session_id] is not None:
+        predictor.reset_state(global_inference_states[session_id])
     session_state["first_frame"] = None
     session_state["all_frames"] = None
-    session_state["inference_state"] = None
+    global_inference_states[session_id] = None
     return (
         None,
         gr.update(open=True),
@@ -112,8 +115,9 @@ def clear_points(session_state):
     predictor.to("cpu")
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    if session_state["inference_state"]["tracking_has_started"]:
-        predictor.reset_state(session_state["inference_state"])
+    session_id = id(session_state)
+    if global_inference_states[session_id]["tracking_has_started"]:
+        predictor.reset_state(global_inference_states[session_id])
     return (
         session_state["first_frame"],
         None,
@@ -168,7 +172,9 @@ def preprocess_video_in(video_path, session_state):
     session_state["first_frame"] = copy.deepcopy(first_frame)
     session_state["all_frames"] = all_frames
 
-    session_state["inference_state"] = predictor.init_state(video_path=video_path)
+    session_id = id(session_state)
+    global_inference_states[session_id] = predictor.init_state(video_path=video_path)
+
     session_state["input_points"] = []
     session_state["input_labels"] = []
 
@@ -230,8 +236,9 @@ def segment_with_points(
     points = np.array(session_state["input_points"], dtype=np.float32)
     # for labels, `1` means positive click and `0` means negative click
     labels = np.array(session_state["input_labels"], np.int32)
+    session_id = id(session_state)
     _, _, out_mask_logits = predictor.add_new_points(
-        inference_state=session_state["inference_state"],
+        inference_state=global_inference_states[session_id],
         frame_idx=0,
         obj_id=OBJ_ID,
         points=points,
@@ -270,10 +277,11 @@ def propagate_to_all(
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        session_id = id(session_state)
         if (
             len(session_state["input_points"]) == 0
             or video_in is None
-            or session_state["inference_state"] is None
+            or global_inference_states[session_id] is None
         ):
             return (
                 None,
@@ -286,7 +294,7 @@
         )  # video_segments contains the per-frame segmentation results
         print("starting propagate_in_video")
         for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-            session_state["inference_state"]
+            global_inference_states[session_id]
        ):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
@@ -340,7 +348,6 @@ with gr.Blocks() as demo:
             "all_frames": None,
             "input_points": [],
             "input_labels": [],
-            "inference_state": None,
         }
     )
 
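The pattern this commit adopts — keeping the heavy per-session object in a module-level dict keyed by the id() of the session's gr.State dict, while gr.State itself holds only lightweight values — can be sketched in isolation. The snippet below is a minimal, hypothetical illustration, not the app's code: global_heavy_state, init_session, and teardown are invented names, and object() stands in for predictor.init_state(...). Like the diff above, it assumes Gradio hands the same dict object back to every event handler within a session, so its id() works as a stable key.

import gradio as gr

# Hypothetical sketch: heavy per-session objects live outside gr.State in a
# module-level dict, keyed by the identity of the session's state dict.
global_heavy_state = {}  # id(session_state) -> heavy object


def init_session(session_state):
    session_id = id(session_state)  # stable per session if Gradio reuses the dict
    global_heavy_state[session_id] = object()  # stand-in for predictor.init_state(...)
    session_state["clicks"] = []  # lightweight, copy-safe data stays in gr.State
    return "initialized"


def teardown(session_state):
    # Drop the heavy object explicitly; nothing else will evict it.
    global_heavy_state.pop(id(session_state), None)
    session_state["clicks"] = []
    return "reset"


with gr.Blocks() as demo:
    session_state = gr.State({"clicks": []})
    status = gr.Textbox(label="status")
    gr.Button("Init").click(init_session, inputs=session_state, outputs=status)
    gr.Button("Reset").click(teardown, inputs=session_state, outputs=status)

if __name__ == "__main__":
    demo.launch()

One likely motivation for the move is that gr.State deep-copies its initial value for each session, which is wasteful or outright broken for a large tensor-holding inference state; a plain module-level dict sidesteps that. The trade-off is manual lifecycle management: entries must be cleared explicitly, and an id() value can be reused once the original dict is garbage-collected.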