from typing import List

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import trunc_normal_

from ultralytics.nn.modules import MLP

from .blocks import SAM2TwoWayTransformer
from .decoders import MaskDecoder, SAM2MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
from .utils import get_1d_sine_pe, select_closest_cond_frames

NO_OBJ_SCORE = -1024.0


class SAMModel(nn.Module):
""" |
|
Segment Anything Model (SAM) for object segmentation tasks. |
|
|
|
This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images |
|
and input prompts. |
|
|
|
Attributes: |
|
mask_threshold (float): Threshold value for mask prediction. |
|
image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings. |
|
prompt_encoder (PromptEncoder): Encoder for various types of input prompts. |
|
mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings. |
|
|
|
Methods: |
|
__init__: Initializes the SAMModel with encoders, decoder, and normalization parameters. |
|
|
|
Examples: |
|
>>> image_encoder = ImageEncoderViT(...) |
|
>>> prompt_encoder = PromptEncoder(...) |
|
>>> mask_decoder = MaskDecoder(...) |
|
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder) |
|
>>> # Further usage depends on SAMPredictor class |
|
|
|
Notes: |
|
All forward() operations are implemented in the SAMPredictor class. |
|
""" |
|
|
|
mask_threshold: float = 0.0 |
|
|
|
def __init__( |
|
self, |
|
image_encoder: ImageEncoderViT, |
|
prompt_encoder: PromptEncoder, |
|
mask_decoder: MaskDecoder, |
|
pixel_mean: List[float] = (123.675, 116.28, 103.53), |
|
pixel_std: List[float] = (58.395, 57.12, 57.375), |
|
) -> None: |
|
""" |
|
Initialize the SAMModel class to predict object masks from an image and input prompts. |
|
|
|
Args: |
|
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings. |
|
prompt_encoder (PromptEncoder): Encodes various types of input prompts. |
|
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. |
|
pixel_mean (List[float]): Mean values for normalizing pixels in the input image. |
|
pixel_std (List[float]): Std values for normalizing pixels in the input image. |
|
|
|
Examples: |
|
>>> image_encoder = ImageEncoderViT(...) |
|
>>> prompt_encoder = PromptEncoder(...) |
|
>>> mask_decoder = MaskDecoder(...) |
|
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder) |
|
>>> # Further usage depends on SAMPredictor class |
|
|
|
Notes: |
|
All forward() operations moved to SAMPredictor. |
|
""" |
|
super().__init__() |
|
self.image_encoder = image_encoder |
|
self.prompt_encoder = prompt_encoder |
|
self.mask_decoder = mask_decoder |
|
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) |
|
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) |
|
|
|
def set_imgsz(self, imgsz): |
|
""" |
|
Set image size to make model compatible with different image sizes. |
|
|
|
Args: |
|
imgsz (Tuple[int, int]): The size of the input image. |
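
        Examples:
            >>> # Hypothetical usage, assuming `sam_model` is an already-constructed SAMModel
            >>> sam_model.set_imgsz((1024, 1024))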
        """
        if hasattr(self.image_encoder, "set_imgsz"):
            self.image_encoder.set_imgsz(imgsz)
        self.prompt_encoder.input_image_size = imgsz
        self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]
        self.image_encoder.img_size = imgsz[0]


class SAM2Model(torch.nn.Module):
""" |
|
SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities. |
|
|
|
This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms |
|
for temporal consistency and efficient tracking of objects across frames. |
|
|
|
Attributes: |
|
mask_threshold (float): Threshold value for mask prediction. |
|
image_encoder (ImageEncoderViT): Visual encoder for extracting image features. |
|
memory_attention (nn.Module): Module for attending to memory features. |
|
memory_encoder (nn.Module): Encoder for generating memory representations. |
|
num_maskmem (int): Number of accessible memory frames. |
|
image_size (int): Size of input images. |
|
backbone_stride (int): Stride of the backbone network output. |
|
sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings. |
|
sam_image_embedding_size (int): Size of SAM image embeddings. |
|
sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts. |
|
sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks. |
|
obj_ptr_proj (nn.Module): Projection layer for object pointers. |
|
obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers. |
|
|
|
Methods: |
|
forward_image: Processes image batch through encoder to extract multi-level features. |
|
track_step: Performs a single tracking step, updating object masks and memory features. |
|
|
|
Examples: |
|
>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) |
|
>>> image_batch = torch.rand(1, 3, 512, 512) |
|
>>> features = model.forward_image(image_batch) |
|
>>> track_results = model.track_step(0, True, features, None, None, None, {}) |
|
""" |
|
|
|
mask_threshold: float = 0.0 |
|
|
|
def __init__( |
|
self, |
|
image_encoder, |
|
memory_attention, |
|
memory_encoder, |
|
num_maskmem=7, |
|
image_size=512, |
|
backbone_stride=16, |
|
sigmoid_scale_for_mem_enc=1.0, |
|
sigmoid_bias_for_mem_enc=0.0, |
|
binarize_mask_from_pts_for_mem_enc=False, |
|
use_mask_input_as_output_without_sam=False, |
|
max_cond_frames_in_attn=-1, |
|
directly_add_no_mem_embed=False, |
|
use_high_res_features_in_sam=False, |
|
multimask_output_in_sam=False, |
|
multimask_min_pt_num=1, |
|
multimask_max_pt_num=1, |
|
multimask_output_for_tracking=False, |
|
use_multimask_token_for_obj_ptr: bool = False, |
|
iou_prediction_use_sigmoid=False, |
|
memory_temporal_stride_for_eval=1, |
|
non_overlap_masks_for_mem_enc=False, |
|
use_obj_ptrs_in_encoder=False, |
|
max_obj_ptrs_in_encoder=16, |
|
add_tpos_enc_to_obj_ptrs=True, |
|
proj_tpos_enc_in_obj_ptrs=False, |
|
use_signed_tpos_enc_to_obj_ptrs=False, |
|
only_obj_ptrs_in_the_past_for_eval=False, |
|
pred_obj_scores: bool = False, |
|
pred_obj_scores_mlp: bool = False, |
|
fixed_no_obj_ptr: bool = False, |
|
soft_no_obj_ptr: bool = False, |
|
use_mlp_for_obj_ptr_proj: bool = False, |
|
no_obj_embed_spatial: bool = False, |
|
sam_mask_decoder_extra_args=None, |
|
compile_image_encoder: bool = False, |
|
): |
|
""" |
|
Initializes the SAM2Model for video object segmentation with memory-based tracking. |
|
|
|
Args: |
|
image_encoder (nn.Module): Visual encoder for extracting image features. |
|
memory_attention (nn.Module): Module for attending to memory features. |
|
memory_encoder (nn.Module): Encoder for generating memory representations. |
|
num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames). |
|
image_size (int): Size of input images. |
|
backbone_stride (int): Stride of the image backbone output. |
|
sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability. |
|
sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability. |
|
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames |
|
with clicks during evaluation. |
|
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM |
|
prompt encoder and mask decoder on frames with mask input. |
|
max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention. |
|
-1 means no limit. |
|
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the |
|
first frame. |
|
use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder. |
|
multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial |
|
conditioning frames. |
|
multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM. |
|
multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM. |
|
multimask_output_for_tracking (bool): Whether to use multimask output for tracking. |
|
use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers. |
|
iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1]. |
|
memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation. |
|
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in |
|
memory encoder during evaluation. |
|
use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder. |
|
max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder |
|
cross-attention. |
|
add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in |
|
the encoder. |
|
proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional |
|
encoding in object pointers. |
|
use_signed_tpos_enc_to_obj_ptrs (bool): whether to use signed distance (instead of unsigned absolute distance) |
|
in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True` |
|
and `add_tpos_enc_to_obj_ptrs=True`. |
|
only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past |
|
during evaluation. |
|
pred_obj_scores (bool): Whether to predict if there is an object in the frame. |
|
pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores. |
|
fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present. |
|
soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation. |
|
use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection. |
|
no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames. |
|
sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder. |
|
compile_image_encoder (bool): Whether to compile the image encoder for faster inference. |
|
|
|
Examples: |
|
>>> image_encoder = ImageEncoderViT(...) |
|
>>> memory_attention = SAM2TwoWayTransformer(...) |
|
>>> memory_encoder = nn.Sequential(...) |
|
>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) |
|
>>> image_batch = torch.rand(1, 3, 512, 512) |
|
>>> features = model.forward_image(image_batch) |
|
>>> track_results = model.track_step(0, True, features, None, None, None, {}) |
|
""" |
|
super().__init__() |
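
        # Part 1: the image backbone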
        self.image_encoder = image_encoder

        self.use_high_res_features_in_sam = use_high_res_features_in_sam
        self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
        self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
        self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
        if use_obj_ptrs_in_encoder:
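            # A conv layer to downsample the mask prompt to stride 4 (the same stride as low-res SAM mask logits)
            # so that it can be fed into the SAM mask decoder to generate an object pointer.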
            self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
        self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
        if proj_tpos_enc_in_obj_ptrs:
            assert add_tpos_enc_to_obj_ptrs
        self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
        self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
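
        # Part 2: memory attention to condition the current frame's visual features with memories from past frames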
        self.memory_attention = memory_attention
        self.hidden_dim = memory_attention.d_model
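
        # Part 3: memory encoder for encoding the previous frame's outputs into memory features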
        self.memory_encoder = memory_encoder
        self.mem_dim = self.hidden_dim
        if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
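            # The memory encoder may compress the feature channels, in which case mem_dim differs from hidden_dim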
            self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
        self.num_maskmem = num_maskmem

        self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
        trunc_normal_(self.maskmem_tpos_enc, std=0.02)

        self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        trunc_normal_(self.no_mem_embed, std=0.02)
        trunc_normal_(self.no_mem_pos_enc, std=0.02)
        self.directly_add_no_mem_embed = directly_add_no_mem_embed
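
        # Scale/bias applied to the sigmoid of the mask logits before feeding them into the memory encoder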
        self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
        self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
        self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
        self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
        self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
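
        # On frames with mask input, whether to directly output the input mask without using the SAM prompt
        # encoder and mask decoder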
        self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
        self.multimask_output_in_sam = multimask_output_in_sam
        self.multimask_min_pt_num = multimask_min_pt_num
        self.multimask_max_pt_num = multimask_max_pt_num
        self.multimask_output_for_tracking = multimask_output_for_tracking
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
        self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
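
        # Part 4: SAM-style prompt encoder (for both point and mask input)
        # and SAM-style mask decoder for the final mask output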
        self.image_size = image_size
        self.backbone_stride = backbone_stride
        self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
        self.pred_obj_scores = pred_obj_scores
        self.pred_obj_scores_mlp = pred_obj_scores_mlp
        self.fixed_no_obj_ptr = fixed_no_obj_ptr
        self.soft_no_obj_ptr = soft_no_obj_ptr
        if self.fixed_no_obj_ptr:
            assert self.pred_obj_scores
            assert self.use_obj_ptrs_in_encoder
        if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
            self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
            trunc_normal_(self.no_obj_ptr, std=0.02)
        self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
        self.no_obj_embed_spatial = None
        if no_obj_embed_spatial:
            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
            trunc_normal_(self.no_obj_embed_spatial, std=0.02)

        self._build_sam_heads()
        self.max_cond_frames_in_attn = max_cond_frames_in_attn
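
        # Model compilation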
        if compile_image_encoder:
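            # Compile only the forward function (not the full module) so that checkpoints can still be loaded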
            print("Image encoder compilation is enabled. First forward pass will be slow.")
            self.image_encoder.forward = torch.compile(
                self.image_encoder.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=False,
            )

    @property
    def device(self):
        """Returns the device on which the model's parameters are stored."""
        return next(self.parameters()).device

    def forward(self, *args, **kwargs):
        """Processes image and prompt inputs to generate object masks and scores in video sequences."""
        raise NotImplementedError(
            "Please use the corresponding methods in SAM2VideoPredictor for inference. "
            "See notebooks/video_predictor_example.ipynb for an example."
        )

    def _build_sam_heads(self):
        """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
        self.sam_prompt_embed_dim = self.hidden_dim
        self.sam_image_embedding_size = self.image_size // self.backbone_stride
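
        # Build the SAM-style prompt encoder and mask decoder (hyperparameters like `mask_in_chans=16` follow SAM)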
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
                self.sam_image_embedding_size,
                self.sam_image_embedding_size,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
        self.sam_mask_decoder = SAM2MaskDecoder(
            num_multimask_outputs=3,
            transformer=SAM2TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=self.use_high_res_features_in_sam,
            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
            pred_obj_scores=self.pred_obj_scores,
            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
            **(self.sam_mask_decoder_extra_args or {}),
        )
        if self.use_obj_ptrs_in_encoder:
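            # A linear projection on SAM output tokens to turn them into object pointers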
            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
            if self.use_mlp_for_obj_ptr_proj:
                self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
        else:
            self.obj_ptr_proj = torch.nn.Identity()
        if self.proj_tpos_enc_in_obj_ptrs:
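            # A linear projection on the temporal positional encoding in object pointers to avoid potential
            # interference with the spatial positional encoding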
            self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
        else:
            self.obj_ptr_tpos_proj = torch.nn.Identity()

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Forward pass through SAM prompt encoders and mask heads.

        This method processes image features and optional point/mask inputs to generate object masks and scores.

        Args:
            backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
            point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
                    pixel-unit coordinates in (x, y) format for P input points.
                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
                    0 means negative clicks, and -1 means padding.
            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
                same spatial size as the image.
            high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
                (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
                for SAM decoder.
            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
                output only 1 mask and its IoU estimate.

        Returns:
            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
                low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
                high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
                ious: Tensor of shape (B, M) with estimated IoU for each output mask.
                low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
                high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
                obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
                object_score_logits: Tensor of shape (B) with object score logits.

            Where M is 3 if multimask_output=True, and 1 if multimask_output=False.

        Examples:
            >>> backbone_features = torch.rand(1, 256, 32, 32)
            >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
            >>> mask_inputs = torch.rand(1, 1, 512, 512)
            >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
            >>> (
            ...     low_res_multimasks,
            ...     high_res_multimasks,
            ...     ious,
            ...     low_res_masks,
            ...     high_res_masks,
            ...     obj_ptr,
            ...     object_score_logits,
            ... ) = results
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
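            # If no points are provided, pad with an empty point (with label -1)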
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        if mask_inputs is not None:
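            # If mask_inputs is provided, downsize it into a low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM prompt encoder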
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
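            # Otherwise, simply feed None (and SAM's prompt encoder will add a learned `no_mask_embed`
            # to indicate no mask input in this case)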
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,
            high_res_features=high_res_features,
        )
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)

        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
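            # Allow a *soft* no-object pointer, unlike the hard choice used for masks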
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
        """Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
        out_scale, out_bias = 20.0, -10.0
        mask_inputs_float = mask_inputs.float()
        high_res_masks = mask_inputs_float * out_scale + out_bias
        low_res_masks = F.interpolate(
            high_res_masks,
            size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
            align_corners=False,
            mode="bilinear",
            antialias=True,
        )

        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
        if not self.use_obj_ptrs_in_encoder:
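            # All zeros as a dummy object pointer (of shape [B, C])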
            obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
        else:
            _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
                backbone_features=backbone_features,
                mask_inputs=self.mask_downsample(mask_inputs_float),
                high_res_features=high_res_features,
            )

        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
        is_obj_appearing = is_obj_appearing[..., None]
        lambda_is_obj_appearing = is_obj_appearing.float()
        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
        if self.pred_obj_scores:
            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_masks,
            high_res_masks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def forward_image(self, img_batch: torch.Tensor):
        """Processes image batch through encoder to extract multi-level features for SAM model."""
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
        return backbone_out

    def _prepare_backbone_features(self, backbone_out):
        """Prepares and flattens visual features from the image backbone output for further processing."""
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def _prepare_memory_conditioned_features(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        output_dict,
        num_frames,
        track_in_reverse=False,
    ):
        """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
        B = current_vision_feats[-1].size(1)
        C = self.hidden_dim
        H, W = feat_sizes[-1]
        device = current_vision_feats[-1].device

        if self.num_maskmem == 0:
            return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        num_obj_ptr_tokens = 0
        tpos_sign_mul = -1 if track_in_reverse else 1

        if not is_init_cond_frame:
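            # Retrieve the memories encoded with the memory encoder from past frames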
            to_cat_memory, to_cat_memory_pos_embed = [], []

            assert len(output_dict["cond_frame_outputs"]) > 0

            cond_outputs = output_dict["cond_frame_outputs"]
            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
                frame_idx, cond_outputs, self.max_cond_frames_in_attn
            )
            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
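
            # Add the last (num_maskmem - 1) frames before the current frame as non-conditioning memories,
            # optionally with a temporal stride r during evaluation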
            r = 1 if self.training else self.memory_temporal_stride_for_eval
            for t_pos in range(1, self.num_maskmem):
                t_rel = self.num_maskmem - t_pos
                if t_rel == 1:
                    prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
                elif not track_in_reverse:
                    prev_frame_idx = ((frame_idx - 2) // r) * r
                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
                else:
                    prev_frame_idx = -(-(frame_idx + 2) // r) * r
                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                if out is None:
                    out = unselected_cond_outputs.get(prev_frame_idx, None)
                t_pos_and_prevs.append((t_pos, out))

            for t_pos, prev in t_pos_and_prevs:
                if prev is None:
                    continue

                feats = prev["maskmem_features"].to(device=device, non_blocking=True)
                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))

                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)

                maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
                to_cat_memory_pos_embed.append(maskmem_enc)

            if self.use_obj_ptrs_in_encoder:
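                # Collect object pointers from past frames as extra memory tokens for cross-attention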
                max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)

                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                    ptr_cond_outputs = {
                        t: out
                        for t, out in selected_cond_outputs.items()
                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                    }
                else:
                    ptr_cond_outputs = selected_cond_outputs
                pos_and_ptrs = [
                    (
                        (
                            (frame_idx - t) * tpos_sign_mul
                            if self.use_signed_tpos_enc_to_obj_ptrs
                            else abs(frame_idx - t)
                        ),
                        out["obj_ptr"],
                    )
                    for t, out in ptr_cond_outputs.items()
                ]

                for t_diff in range(1, max_obj_ptrs_in_encoder):
                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                    if t < 0 or (num_frames is not None and t >= num_frames):
                        break
                    out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
                    if out is not None:
                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))

                if pos_and_ptrs:
                    pos_list, ptrs_list = zip(*pos_and_ptrs)

                    obj_ptrs = torch.stack(ptrs_list, dim=0)

                    if self.add_tpos_enc_to_obj_ptrs:
                        t_diff_max = max_obj_ptrs_in_encoder - 1
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                        obj_pos = torch.tensor(pos_list, device=device)
                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                    else:
                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                    if self.mem_dim < C:
                        obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
                        obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                    to_cat_memory.append(obj_ptrs)
                    to_cat_memory_pos_embed.append(obj_pos)
                    num_obj_ptr_tokens = obj_ptrs.shape[0]
                else:
                    num_obj_ptr_tokens = 0
        else:
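            # For initial conditioning frames, encode them without using any previous memory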
            if self.directly_add_no_mem_embed:
                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
                return pix_feat_with_mem

            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

        memory = torch.cat(to_cat_memory, dim=0)
        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

        pix_feat_with_mem = self.memory_attention(
            curr=current_vision_feats,
            curr_pos=current_vision_pos_embeds,
            memory=memory,
            memory_pos=memory_pos_embed,
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )

        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
        return pix_feat_with_mem

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """Encodes frame features and masks into a new memory representation for video segmentation."""
        B = current_vision_feats[-1].size(1)
        C = self.hidden_dim
        H, W = feat_sizes[-1]

        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
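            # Optionally apply non-overlapping constraints to the masks (applied along the batch dimension,
            # so it should only be used during eval, where all objects come from the same video)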
            pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)

        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            mask_for_mem = torch.sigmoid(pred_masks_high_res)

        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]

        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
                ..., None, None
            ].expand(*maskmem_features.shape)

        return maskmem_features, maskmem_pos_enc

    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    ):
        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}

        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
        else:
            pix_feat = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
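
            # Apply the SAM-style segmentation head on the memory-conditioned features. Previously predicted
            # low-res SAM mask logits (e.g. from interactive correction) can be fed back in as the mask prompt.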
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        return current_out, sam_outputs, high_res_features, pix_feat

    def _encode_memory_in_output(
        self,
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        """Run the memory encoder on the predicted mask to encode it into a new memory feature, which can be used
        in future frames.
        """
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                object_score_logits=object_score_logits,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,
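        # Whether to run the memory encoder on the predicted masks; e.g. in interactive settings one might
        # skip it with run_mem_encoder=False until the user's clicks are finalized.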
        run_mem_encoder=True,
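        # The previously predicted SAM mask logits (which can be fed together with new clicks, e.g. in a demo).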
        prev_sam_mask_logits=None,
    ):
        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
        current_out, sam_outputs, _, _ = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )
        _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        if not self.training:
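            # Only keep object_score_logits in the output during inference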
            current_out["object_score_logits"] = object_score_logits

        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )

        return current_out

    def _use_multimask(self, is_init_cond_frame, point_inputs):
        """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
        num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
        return (
            self.multimask_output_in_sam
            and (is_init_cond_frame or self.multimask_output_for_tracking)
            and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
        )

    @staticmethod
    def _apply_non_overlapping_constraints(pred_masks):
        """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
        batch_size = pred_masks.size(0)
        if batch_size == 1:
            return pred_masks

        device = pred_masks.device

        max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)

        batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
        keep = max_obj_inds == batch_obj_inds
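
        # Suppress overlapping regions' scores below -10.0 so that the foreground regions don't overlap
        # (here sigmoid(-10.0) is roughly 4.5e-05)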
        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks

    def set_binarize(self, binarize=False):
        """Set binarize for VideoPredictor."""
        self.binarize_mask_from_pts_for_mem_enc = binarize

    def set_imgsz(self, imgsz):
        """
        Set image size to make model compatible with different image sizes.

        Args:
            imgsz (Tuple[int, int]): The size of the input image.
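
        Examples:
            >>> # Hypothetical usage, assuming `model` is an already-constructed SAM2Model
            >>> model.set_imgsz((1024, 1024))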
        """
        self.image_size = imgsz[0]
        self.sam_prompt_encoder.input_image_size = imgsz
        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]