LPX55 commited on
Commit
ee9e061
·
verified ·
1 Parent(s): 8965f65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -8,6 +8,8 @@ from PIL import Image
8
  from sam2.build_sam import build_sam2
9
  from sam2.sam2_image_predictor import SAM2ImagePredictor
10
 
 
 
11
  # Remove all CUDA-specific configurations
12
  torch.autocast(device_type="cpu", dtype=torch.float32).__enter__()
13
 
@@ -51,7 +53,7 @@ def show_mask(mask, ax, random_color=False, borders=True):
51
  mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
52
  ax.imshow(mask_image)
53
 
54
- def show_points(coords, labels, ax, marker_size=375):
55
  pos_points = coords[labels==1]
56
  neg_points = coords[labels==0]
57
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
@@ -112,16 +114,18 @@ def process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_s
112
  def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
113
  image = Image.open(input_image)
114
  image = np.array(image.convert("RGB"))
115
- checkpoint_map = {
116
- "tiny": ("./checkpoints/sam2_hiera_tiny.pt", "sam2_hiera_t.yaml"),
117
- "small": ("./checkpoints/sam2_hiera_small.pt", "sam2_hiera_s.yaml"),
118
- "base-plus": ("./checkpoints/sam2_hiera_base_plus.pt", "sam2_hiera_b+.yaml"),
119
- "large": ("./checkpoints/sam2_hiera_large.pt", "sam2_hiera_l.yaml")
120
  }
121
- sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
122
  # Use CPU for both model and computations
123
- sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
124
- predictor = SAM2ImagePredictor(sam2_model)
 
 
125
  predictor.set_image(image)
126
  input_point = np.array(tracking_points.value)
127
  input_label = np.array(trackings_input_label.value)
 
8
  from sam2.build_sam import build_sam2
9
  from sam2.sam2_image_predictor import SAM2ImagePredictor
10
 
11
+
12
+
13
  # Remove all CUDA-specific configurations
14
  torch.autocast(device_type="cpu", dtype=torch.float32).__enter__()
15
 
 
53
  mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
54
  ax.imshow(mask_image)
55
 
56
+ def show_points(coords, labels, ax, marker_size=200):
57
  pos_points = coords[labels==1]
58
  neg_points = coords[labels==0]
59
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
 
114
  def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
115
  image = Image.open(input_image)
116
  image = np.array(image.convert("RGB"))
117
+ sam21_hfmap = {
118
+ "tiny": "facebook/sam2.1-hiera-tiny",
119
+ "small": "facebook/sam2.1-hiera-small",
120
+ "base-plus": "facebook/sam2.1-hiera-base-plus",
121
+ "large": "facebook/sam2.1-hiera-large",
122
  }
123
+ # sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
124
  # Use CPU for both model and computations
125
+ # sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
126
+ predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="cpu")
127
+
128
+ # predictor = SAM2ImagePredictor(sam2_model)
129
  predictor.set_image(image)
130
  input_point = np.array(tracking_points.value)
131
  input_label = np.array(trackings_input_label.value)