|
import base64
import io
import os
from typing import Any, Dict, List

import numpy as np
from PIL import Image

from medsam2_model import MedSAM2
|
|
|
def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as PNG and wrap the bytes in a ``data:`` URI.

    Args:
        image: The image to serialise; it is written via
            ``image.save(..., format="PNG")``.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    payload = base64.b64encode(png_bytes).decode()
    return f"data:image/png;base64,{payload}"
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that runs MedSAM2 segmentation on one image.

    The request payload is a dict whose ``"image"`` entry is a
    ``data:image/...;base64,`` URI, with an optional ``"box"`` prompt. The
    response is a one-element list describing the predicted mask.
    """

    # Checkpoint filename, resolved relative to the ``path`` given to __init__.
    _CHECKPOINT_NAME = "MedSAM2_pretrain_10ep_b1_AMD-SD_sam2_hiera_t.pth"

    def __init__(self, path=""):
        """Load the MedSAM2 model once for this handler instance.

        Args:
            path: Directory that contains the checkpoint file. The default
                ``""`` resolves the checkpoint against the current working
                directory, as before.
        """
        # Store the model on the instance. Previously it was a local variable,
        # so __call__ failed with NameError on every request.
        self.model = MedSAM2(os.path.join(path, self._CHECKPOINT_NAME))

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Segment the image in *data* and return the mask as a PNG data URI.

        Args:
            data: Request payload. Only dicts whose ``"image"`` value is a
                string starting with ``"data:image"`` are processed. An
                optional ``"box"`` entry is forwarded to the model as the
                prompt.

        Returns:
            A one-element list holding a dict with ``"label"``, ``"mask"`` and
            ``"score"``. ``"mask"`` is a ``data:image/png;base64,...`` string,
            or ``None`` when the payload carried no usable image.
        """
        image = self._decode_image(data)
        if image is None:
            return [{"label": "mock-segmentation", "mask": None, "score": 0.0}]

        # NOTE(review): assumes MedSAM2.predict expects an HxWx3 RGB array;
        # converting normalises L/RGBA/palette inputs. Confirm against
        # medsam2_model.
        image_np = np.array(image.convert("RGB"))

        box = data.get("box")
        if box is None:
            # Fallback prompt covering the whole image. Assumes predict takes
            # [x_min, y_min, x_max, y_max] in pixels (TODO: confirm).
            width, height = image.size
            box = [0, 0, width, height]

        mask_array = self.model.predict(image_np, box)

        # Drop singleton axes (e.g. a leading (1, H, W) channel) so that
        # Image.fromarray receives a 2-D array.
        mask = np.squeeze(np.asarray(mask_array))
        mask_pil = Image.fromarray((mask * 255).astype(np.uint8))

        return [{
            "label": "mock-segmentation",
            "mask": image_to_base64(mask_pil),
            # Fixed placeholder, not a confidence produced by the model.
            "score": 0.99,
        }]

    @staticmethod
    def _decode_image(data: Any):
        """Return the PIL image encoded in ``data["image"]``, or None if absent.

        Only ``data:image...`` URIs are decoded. Anything else is treated as
        "no image": a non-dict payload, a missing key, a non-string value or a
        raw base64 string.
        """
        if not isinstance(data, dict):
            return None
        image_data = data.get("image")
        if not isinstance(image_data, str) or not image_data.startswith("data:image"):
            return None
        # Everything after the first comma is the base64 payload.
        _, _, base64_data = image_data.partition(",")
        return Image.open(io.BytesIO(base64.b64decode(base64_data)))
|
|