File size: 1,321 Bytes
2d2347f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import numpy as np
from skimage import transform
# from sam2_train.build_sam import build_sam2
# from sam2_train.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

class MedSAM2:
    """Wrapper around a SAM2 image predictor for box-prompted 2D segmentation.

    The model is built once in ``__init__``. Each ``predict`` call resizes the
    input image to ``INPUT_SIZE`` x ``INPUT_SIZE``, encodes it with
    ``set_image`` and decodes a single mask for one bounding-box prompt.
    """

    # Side length of the square image handed to the predictor; box prompts are
    # rescaled into this coordinate frame before prediction.
    INPUT_SIZE = 1024

    def __init__(self, model_path, device="cpu", config="sam2_hiera_t"):
        """Build the SAM2 model and wrap it in an image predictor.

        Args:
            model_path: Checkpoint path forwarded to ``build_sam2``.
            device: Torch device string, e.g. ``"cpu"``, ``"cuda"`` or
                ``"cuda:0"``.
            config: Model config name forwarded to ``build_sam2``; defaults to
                ``"sam2_hiera_t"``, the value that used to be hard-coded.
        """
        self.device = device
        self.model = build_sam2(config, model_path, device=device)
        self.predictor = SAM2ImagePredictor(self.model)

    @staticmethod
    def _to_three_channel(image: np.ndarray) -> np.ndarray:
        """Return *image* as an ``(H, W, 3)`` array.

        ``(H, W)`` and ``(H, W, 1)`` inputs are replicated across three
        channels; ``(H, W, 3)`` inputs are returned unchanged.

        Raises:
            ValueError: if *image* has any other shape.
        """
        if image.ndim == 2:
            return np.repeat(image[:, :, None], 3, axis=-1)
        if image.ndim == 3 and image.shape[2] == 1:
            return np.repeat(image, 3, axis=-1)
        if image.ndim == 3 and image.shape[2] == 3:
            return image
        raise ValueError(
            f"expected an (H, W), (H, W, 1) or (H, W, 3) image, got shape {image.shape}"
        )

    @staticmethod
    def _scale_box(box, height: int, width: int, size: int) -> np.ndarray:
        """Rescale an ``[x0, y0, x1, y1]`` box from an image of the given
        height/width into a ``size`` x ``size`` frame.

        Returns:
            Float array of shape ``(1, 4)`` (a batch of one box).

        Raises:
            ValueError: if *box* does not have exactly four entries.
        """
        box_np = np.asarray(box, dtype=float)
        if box_np.shape != (4,):
            raise ValueError(f"box must have 4 values [x0, y0, x1, y1], got {box!r}")
        # x-coordinates (indices 0, 2) scale with width, y (1, 3) with height.
        scale = np.array([width, height, width, height], dtype=float)
        return (box_np / scale * size)[None, :]

    def predict(self, image: np.ndarray, box: list[float]) -> np.ndarray:
        """Segment the object inside *box* in *image*.

        Args:
            image: ``(H, W)``, ``(H, W, 1)`` or ``(H, W, 3)`` image. Intensities
                are cast to ``uint8`` after resizing, so values are expected
                in 0..255; anything outside is clipped (NOTE(review): images in
                other ranges, e.g. [0, 1] floats, should be rescaled by the
                caller — confirm with callers).
            box: ``[x0, y0, x1, y1]`` in pixel coordinates of *image*.

        Returns:
            ``uint8`` mask taken from the first predicted mask. It is computed
            on the ``INPUT_SIZE`` x ``INPUT_SIZE`` resized image, not on
            *image* at its original size (assumption: the predictor returns
            masks at ``set_image`` resolution — confirm if callers need
            ``(H, W)``).

        Raises:
            ValueError: on an unsupported image shape or a malformed box.
        """
        size = self.INPUT_SIZE
        image_3c = self._to_three_channel(image)
        resized = transform.resize(image_3c, (size, size), preserve_range=True)
        # Saturate instead of letting the uint8 cast wrap out-of-range values.
        img_input = np.clip(resized, 0, 255).astype(np.uint8)
        box_input = self._scale_box(box, image.shape[0], image.shape[1], size)

        # torch.autocast takes a bare device type ("cuda"), not "cuda:0".
        device_type = torch.device(self.device).type
        with torch.inference_mode(), torch.autocast(device_type, dtype=torch.bfloat16):
            self.predictor.set_image(img_input)
            masks, _, _ = self.predictor.predict(
                point_coords=None,
                point_labels=None,
                box=box_input,
                multimask_output=False,
            )

        return masks[0].astype(np.uint8)