Update app.py
app.py CHANGED
@@ -8,6 +8,8 @@ from PIL import Image
 from sam2.build_sam import build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
+
+
 # Remove all CUDA-specific configurations
 torch.autocast(device_type="cpu", dtype=torch.float32).__enter__()
 
@@ -51,7 +53,7 @@ def show_mask(mask, ax, random_color=False, borders=True):
     mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
     ax.imshow(mask_image)
 
-def show_points(coords, labels, ax, marker_size=
+def show_points(coords, labels, ax, marker_size=200):
     pos_points = coords[labels==1]
     neg_points = coords[labels==0]
     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
@@ -112,16 +114,18 @@ def process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_size):
 def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
     image = Image.open(input_image)
     image = np.array(image.convert("RGB"))
-    checkpoint_map = {
-        "tiny":
-        "small":
-        "base-plus":
-        "large":
+    sam21_hfmap = {
+        "tiny": "facebook/sam2.1-hiera-tiny",
+        "small": "facebook/sam2.1-hiera-small",
+        "base-plus": "facebook/sam2.1-hiera-base-plus",
+        "large": "facebook/sam2.1-hiera-large",
     }
-    sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
+    # sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
     # Use CPU for both model and computations
-    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
-    predictor = SAM2ImagePredictor(sam2_model)
+    # sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
+    predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="cpu")
+
+    # predictor = SAM2ImagePredictor(sam2_model)
     predictor.set_image(image)
     input_point = np.array(tracking_points.value)
     input_label = np.array(trackings_input_label.value)
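In short, the commit stops building SAM 2 from a local config/checkpoint pair (build_sam2 followed by SAM2ImagePredictor(sam2_model)) and instead loads the published SAM 2.1 weights straight from the Hugging Face Hub via SAM2ImagePredictor.from_pretrained, keeping everything on CPU. A minimal sketch of the new loading path, assuming the sam2 package is installed and the facebook/sam2.1-hiera-* Hub repos are reachable; the image path and click coordinates below are made-up placeholders:

import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

# from_pretrained resolves the Hub repo, downloads config + weights, and
# builds the predictor; device="cpu" mirrors the app's CPU-only setup.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-tiny", device="cpu")

image = np.array(Image.open("example.jpg").convert("RGB"))  # hypothetical input image
predictor.set_image(image)

masks, scores, logits = predictor.predict(
    point_coords=np.array([[100, 150]]),  # one click at pixel (x=100, y=150)
    point_labels=np.array([1]),           # 1 = foreground, 0 = background
    multimask_output=True,                # return several candidate masks
)
print(masks.shape, scores)                # e.g. (3, H, W) plus a score per mask

One caveat on the autocast line retained in the first hunk: CPU autocast only accepts reduced-precision target dtypes (bfloat16, and float16 in recent PyTorch releases), so entering a context with dtype=torch.float32 typically just triggers a warning and disables autocast, which on CPU amounts to running in plain float32 anyway.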