pablovela5620 commited on
Commit
6c083d0
·
1 Parent(s): 1cd2fa5

more updates

Browse files
annotation_example/gradio_ui/sv_sam.py CHANGED
@@ -307,7 +307,7 @@ def _preprocess_video(
307
  first_bgr: np.ndarray = cv2.imread(str(first_frame_path))
308
 
309
  progress(0.5, desc="Initializing SAM")
310
- VIDEO_SAM_PREDICTOR: SAM2VideoPredictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-tiny", device="cpu")
311
  with torch.inference_mode():
312
  inference_state = VIDEO_SAM_PREDICTOR.init_state(video_path=tmp_frames_dir)
313
  VIDEO_SAM_PREDICTOR.reset_state(inference_state)
@@ -382,6 +382,7 @@ def get_initial_mask(
382
  f"Include points: {keypoint_container.include_points.shape}, Exclude points: {keypoint_container.exclude_points.shape}"
383
  )
384
 
 
385
  with torch.inference_mode():
386
  frame_idx: int
387
  object_ids: list
@@ -439,7 +440,8 @@ def propagate_mask(
439
  f"{log_paths['pinhole_path']}/exclude",
440
  rr.Clear(recursive=True),
441
  )
442
-
 
443
  with torch.inference_mode():
444
  frame_idx: int
445
  object_ids: list
 
307
  first_bgr: np.ndarray = cv2.imread(str(first_frame_path))
308
 
309
  progress(0.5, desc="Initializing SAM")
310
+ VIDEO_SAM_PREDICTOR: SAM2VideoPredictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-tiny")
311
  with torch.inference_mode():
312
  inference_state = VIDEO_SAM_PREDICTOR.init_state(video_path=tmp_frames_dir)
313
  VIDEO_SAM_PREDICTOR.reset_state(inference_state)
 
382
  f"Include points: {keypoint_container.include_points.shape}, Exclude points: {keypoint_container.exclude_points.shape}"
383
  )
384
 
385
+ VIDEO_SAM_PREDICTOR: SAM2VideoPredictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-tiny")
386
  with torch.inference_mode():
387
  frame_idx: int
388
  object_ids: list
 
440
  f"{log_paths['pinhole_path']}/exclude",
441
  rr.Clear(recursive=True),
442
  )
443
+
444
+ VIDEO_SAM_PREDICTOR: SAM2VideoPredictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-tiny")
445
  with torch.inference_mode():
446
  frame_idx: int
447
  object_ids: list