from typing import Tuple

import torch
import torch.nn as nn

from ..modeling import Sam
from .amg import calculate_stability_score


class SamCoreMLModel(nn.Module):
    """
    Wrap a Sam model's point-prompt embedding and mask decoding in one module.

    This model should not be called directly, but is used in CoreML export.

    The forward pass takes image embeddings plus point prompts, embeds the
    points, runs the wrapped model's mask decoder, and returns
    ``(scores, masks)``.
    """

    def __init__(
        self,
        model: Sam,
        use_stability_score: bool = False,
    ) -> None:
        """
        Store the wrapped model and the export options.

        Arguments:
          model (Sam): The model whose ``prompt_encoder``, ``mask_decoder``,
            ``image_encoder.img_size`` and ``mask_threshold`` are used.
          use_stability_score (bool): If True, ``forward`` replaces the scores
            returned by the mask decoder with the result of
            ``calculate_stability_score``.
        """
        super().__init__()
        # Kept as a separate attribute in addition to ``self.model``; both
        # names refer to the same decoder object.
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.use_stability_score = use_stability_score
        # Offset handed to ``calculate_stability_score`` in ``forward``.
        self.stability_score_offset = 1.0

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """
        Turn point coordinates and labels into sparse prompt embeddings.

        Arguments:
          point_coords (torch.Tensor): Point coordinates, presumably in input
            image pixels with x/y in the last dimension -- TODO confirm
            against the caller.
          point_labels (torch.Tensor): One label per point (one dimension
            fewer than the embedding). ``-1`` selects the "not a point"
            embedding; a label ``i`` in
            ``range(prompt_encoder.num_point_embeddings)`` selects
            ``prompt_encoder.point_embeddings[i]``.

        Returns:
          torch.Tensor: The embedding tensor, same shape as the output of
            ``pe_layer._pe_encoding``.
        """
        prompt_encoder = self.model.prompt_encoder

        # NOTE(review): the 0.5 shift looks like a move to pixel centers --
        # confirm. Coordinates are then scaled by the encoder image size.
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = prompt_encoder.pe_layer._pe_encoding(point_coords)
        # Broadcast each label across the embedding dimension so the label
        # comparisons below can be used as element-wise masks.
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Selection is done with mask arithmetic rather than indexing or
        # branching: zero out label -1 entries, then add the "not a point"
        # embedding exactly there.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        # Add the learned embedding of each label type where the label matches.
        for i in range(prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks and their scores from image embeddings and point prompts.

        Arguments:
          image_embeddings (torch.Tensor): Passed unchanged to
            ``mask_decoder.predict_masks``; presumably the image encoder
            output -- TODO confirm.
          point_coords (torch.Tensor): Point coordinates, see ``_embed_points``.
          point_labels (torch.Tensor): Point labels, see ``_embed_points``.

        Returns:
          (torch.Tensor): The scores, either as returned by the mask decoder
            or, when ``use_stability_score`` is set, as computed by
            ``calculate_stability_score``.
          (torch.Tensor): The masks returned by the mask decoder.

        Note that scores come first in the returned tuple, the reverse of the
        ``(masks, scores)`` order returned by ``predict_masks``.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        # No mask prompt is accepted here: the dense prompt is always the
        # "no mask" embedding, reshaped so it can broadcast over batch and
        # spatial dimensions.
        dense_embedding = self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        return scores, masks
|