import torch
import numpy as np
from skimage import transform
# from sam2_train.build_sam import build_sam2
# from sam2_train.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
class MedSAM2:
    """Box-prompted segmentation with a SAM2 Hiera-tiny checkpoint."""

    def __init__(self, model_path, device="cpu"):
        self.device = device
        self.model = build_sam2("sam2_hiera_t", model_path, device=device)
        self.predictor = SAM2ImagePredictor(self.model)

    def predict(self, image: np.ndarray, box: list[float]) -> np.ndarray:
        # Ensure a 3-channel image; a grayscale (H, W) input is replicated across channels.
        image_3c = image if image.ndim == 3 else np.repeat(image[:, :, None], 3, axis=-1)
        # Resize to the 1024x1024 input resolution expected by SAM2.
        img_1024 = transform.resize(image_3c, (1024, 1024), preserve_range=True).astype(np.uint8)
        # Rescale the prompt box (x_min, y_min, x_max, y_max) from original image
        # coordinates to the 1024x1024 frame, and add a batch dimension.
        box_np = np.array(box)
        box_1024 = box_np / np.array([image.shape[1], image.shape[0], image.shape[1], image.shape[0]]) * 1024
        box_1024 = box_1024[None, :]
        with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
            self.predictor.set_image(img_1024)
            masks, _, _ = self.predictor.predict(
                point_coords=None,
                point_labels=None,
                box=box_1024,
                multimask_output=False,
            )
        # Binary mask at 1024x1024 resolution.
        return masks[0].astype(np.uint8)
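
# Hedged usage sketch (illustrative only): the checkpoint path, image file, and
# bounding box below are assumptions, not part of the original module. The box is
# given as (x_min, y_min, x_max, y_max) in original image coordinates.
if __name__ == "__main__":
    from skimage import io

    medsam2 = MedSAM2("checkpoints/sam2_hiera_tiny.pt", device="cpu")
    img = io.imread("example_slice.png")  # hypothetical input image
    mask_1024 = medsam2.predict(img, [120.0, 80.0, 400.0, 360.0])
    # predict() returns the mask at 1024x1024; map it back to the input size if needed.
    mask = transform.resize(
        mask_1024, img.shape[:2], order=0, preserve_range=True, anti_aliasing=False
    ).astype(np.uint8)
    print("mask shape:", mask.shape, "foreground pixels:", int(mask.sum()))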