383 lines
12 KiB
Python
383 lines
12 KiB
Python
|
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||
|
|
|
||
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||
|
|
|
||
|
|
import torch.nn as nn
|
||
|
|
|
||
|
|
from ultralytics.nn.modules.transformer import MLP
|
||
|
|
from ultralytics.utils.patches import torch_load
|
||
|
|
|
||
|
|
from .modules.blocks import PositionEmbeddingSine, RoPEAttention
|
||
|
|
from .modules.encoders import MemoryEncoder
|
||
|
|
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
|
||
|
|
from .modules.sam import SAM3Model
|
||
|
|
from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
|
||
|
|
from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
|
||
|
|
from .sam3.geometry_encoders import SequenceGeometryEncoder
|
||
|
|
from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
|
||
|
|
from .sam3.model_misc import DotProductScoring, TransformerWrapper
|
||
|
|
from .sam3.necks import Sam3DualViTDetNeck
|
||
|
|
from .sam3.sam3_image import SAM3SemanticModel
|
||
|
|
from .sam3.text_encoder_ve import VETextEncoder
|
||
|
|
from .sam3.vitdet import ViT
|
||
|
|
from .sam3.vl_combiner import SAM3VLBackbone
|
||
|
|
|
||
|
|
|
||
|
|
def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
    """Assemble the SAM3 vision backbone: a ViT trunk wrapped by a dual ViTDet neck.

    Args:
        compile_mode: Forwarded unchanged to the ViT trunk as its ``compile_mode`` argument.
        enable_inst_interactivity: Forwarded to the neck as its ``add_sam2_neck`` argument.

    Returns:
        (Sam3DualViTDetNeck): Neck built from the ViT trunk and a sine position encoding.
    """
    # The position encoding is created before the trunk, mirroring the original construction order.
    sine_pos_enc = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None, temperature=10000)

    trunk_kwargs = dict(
        # Image / patch geometry
        img_size=1008,
        pretrain_img_size=336,
        patch_size=14,
        # Transformer size
        embed_dim=1024,
        depth=32,
        num_heads=16,
        mlp_ratio=4.625,
        norm_layer="LayerNorm",
        drop_path_rate=0.1,
        qkv_bias=True,
        # Positional information
        use_abs_pos=True,
        tile_abs_pos=True,
        global_att_blocks=(7, 15, 23, 31),
        rel_pos_blocks=(),
        use_rope=True,
        use_interp_rope=True,
        window_size=24,
        # Class token and normalization handling
        pretrain_use_cls_token=True,
        retain_cls_token=False,
        ln_pre=True,
        ln_post=False,
        return_interm_layers=False,
        bias_patch_embed=False,
        compile_mode=compile_mode,
    )
    trunk = ViT(**trunk_kwargs)

    return Sam3DualViTDetNeck(
        position_encoding=sine_pos_enc,
        d_model=256,
        scale_factors=[4.0, 2.0, 1.0, 0.5],
        trunk=trunk,
        add_sam2_neck=enable_inst_interactivity,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _create_sam3_transformer() -> TransformerWrapper:
    """Build the SAM3 detector transformer: a fusion encoder and a decoder at hidden size 256.

    Returns:
        (TransformerWrapper): Wrapper around a 6-layer TransformerEncoderFusion and a 6-layer TransformerDecoder.
    """
    hidden_dim = 256
    ffn_dim = 2048
    drop = 0.1
    heads = 8

    # ---- Encoder -------------------------------------------------------------------------------------------------
    # Attention modules are instantiated in the same order as in the nested-call form (self, then cross).
    enc_self_attn = nn.MultiheadAttention(num_heads=heads, dropout=drop, embed_dim=hidden_dim, batch_first=True)
    enc_cross_attn = nn.MultiheadAttention(num_heads=heads, dropout=drop, embed_dim=hidden_dim, batch_first=True)
    enc_layer = TransformerEncoderLayer(
        d_model=hidden_dim,
        dim_feedforward=ffn_dim,
        dropout=drop,
        pos_enc_at_attn=True,
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=False,
        pre_norm=True,
        self_attention=enc_self_attn,
        cross_attention=enc_cross_attn,
    )
    encoder: TransformerEncoderFusion = TransformerEncoderFusion(
        layer=enc_layer,
        num_layers=6,
        d_model=hidden_dim,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )

    # ---- Decoder -------------------------------------------------------------------------------------------------
    # Unlike the encoder attention modules, this one is created without batch_first.
    dec_cross_attn = nn.MultiheadAttention(num_heads=heads, dropout=drop, embed_dim=hidden_dim)
    dec_layer = TransformerDecoderLayer(
        d_model=hidden_dim,
        dim_feedforward=ffn_dim,
        dropout=drop,
        cross_attention=dec_cross_attn,
        n_heads=heads,
        use_text_cross_attention=True,
    )
    decoder: TransformerDecoder = TransformerDecoder(
        layer=dec_layer,
        num_layers=6,
        num_queries=200,
        return_intermediate=True,
        box_refine=True,
        num_o2m_queries=0,
        dac=True,
        boxRPB="log",
        d_model=hidden_dim,
        frozen=False,
        interaction_layer=None,
        dac_use_selfatt_ln=True,
        use_act_checkpoint=True,
        presence_token=True,
    )

    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=hidden_dim)
|
||
|
|
|
||
|
|
|
||
|
|
def build_sam3_image_model(checkpoint_path: str, enable_segmentation: bool = True, compile: bool = False):
    """Build the SAM3 image (semantic) model and load its weights.

    Args:
        checkpoint_path: Path to the model checkpoint; it is always passed to `_load_checkpoint`.
        enable_segmentation: When truthy a segmentation head is created, otherwise the head is None.
        compile: When truthy, compile mode "default" is passed to the vision backbone and pixel decoder;
            otherwise None is passed.

    Returns:
        A SAM3SemanticModel with the checkpoint loaded, switched to eval mode.
    """
    # CLIP supplies the tokenizer; install it on demand if it is missing.
    try:
        import clip
    except ImportError:
        from ultralytics.utils.checks import check_requirements

        check_requirements("git+https://github.com/ultralytics/CLIP.git")
        import clip

    hidden_dim = 256
    compile_mode = "default" if compile else None

    # Vision + text encoders, combined into the visual-language backbone
    vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)
    tokenizer = clip.simple_tokenizer.SimpleTokenizer()
    text_encoder = VETextEncoder(tokenizer=tokenizer, d_model=hidden_dim, width=1024, heads=16, layers=24)
    backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)

    # Detector transformer (encoder + decoder)
    transformer = _create_sam3_transformer()

    # Dot-product scoring with its prompt MLP
    prompt_norm = nn.LayerNorm(256)
    prompt_mlp = MLP(
        input_dim=hidden_dim,
        hidden_dim=2048,
        output_dim=hidden_dim,
        num_layers=2,
        residual=True,
        out_norm=prompt_norm,
    )
    dot_prod_scoring = DotProductScoring(d_model=hidden_dim, d_proj=hidden_dim, prompt_mlp=prompt_mlp)

    # Segmentation head is only instantiated on request
    segmentation_head = None
    if enable_segmentation:
        prompt_cross_attn = nn.MultiheadAttention(num_heads=8, dropout=0, embed_dim=hidden_dim)
        pixel_decoder = PixelDecoder(
            num_upsampling_stages=3,
            interpolation_mode="nearest",
            hidden_dim=hidden_dim,
            compile_mode=compile_mode,
        )
        segmentation_head = UniversalSegmentationHead(
            hidden_dim=hidden_dim,
            upsampling_stages=3,
            aux_masks=False,
            presence_head=False,
            dot_product_scorer=None,
            act_ckpt=True,
            cross_attend_prompt=prompt_cross_attn,
            pixel_decoder=pixel_decoder,
        )

    # Geometry prompt encoder
    geometry_pos_enc = PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None, temperature=10000)
    geometry_layer = TransformerEncoderLayer(
        d_model=hidden_dim,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pre_norm=True,
        pos_enc_at_cross_attn_queries=False,
        pos_enc_at_cross_attn_keys=True,
    )
    input_geometry_encoder = SequenceGeometryEncoder(
        pos_enc=geometry_pos_enc,
        encode_boxes_as_points=False,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
        d_model=hidden_dim,
        num_layers=3,
        layer=geometry_layer,
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
    )

    # Full model
    model = SAM3SemanticModel(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
    )

    # Weights are loaded unconditionally, then the model is put in eval mode.
    model = _load_checkpoint(model, checkpoint_path)
    model.eval()
    return model
|
||
|
|
|
||
|
|
|
||
|
|
def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
    """Build the SAM3 Tracker module for video tracking.

    Args:
        checkpoint_path (str): Path to the model checkpoint; it is always passed to `_load_checkpoint`.
        compile (str | None): Forwarded as `compile_mode` to the vision backbone when one is built.
        with_backbone (bool): When truthy a vision-only SAM3VLBackbone is created as the image encoder;
            otherwise the image encoder is None.

    Returns:
        (SAM3Model): The configured model with the checkpoint loaded, switched to eval mode.
    """
    hidden_dim = 256

    # Memory encoder
    memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])

    # Memory attention: RoPE self-attention plus RoPE cross-attention (keys/values come in at dim 64)
    rope_self_attn = RoPEAttention(
        embedding_dim=hidden_dim,
        num_heads=1,
        downsample_rate=1,
        rope_theta=10000.0,
        feat_sizes=[72, 72],
    )
    rope_cross_attn = RoPEAttention(
        embedding_dim=hidden_dim,
        num_heads=1,
        downsample_rate=1,
        kv_in_dim=64,
        rope_theta=10000.0,
        feat_sizes=[72, 72],
        rope_k_repeat=True,
    )
    attention_layer = MemoryAttentionLayer(
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        self_attn=rope_self_attn,
        d_model=hidden_dim,
        cross_attn=rope_cross_attn,
    )
    memory_attention = MemoryAttention(
        batch_first=True,
        d_model=hidden_dim,
        pos_enc_at_input=True,
        layer=attention_layer,
        num_layers=4,
    )

    # Optional vision-only backbone (no text encoder)
    backbone = None
    if with_backbone:
        backbone = SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)

    mask_decoder_extra_args = {
        "dynamic_multimask_via_stability": True,
        "dynamic_multimask_stability_delta": 0.05,
        "dynamic_multimask_stability_thresh": 0.98,
    }
    model = SAM3Model(
        # Core components
        image_size=1008,
        image_encoder=backbone,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        backbone_stride=14,
        num_maskmem=7,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        # Behaviour flags
        use_mask_input_as_output_without_sam=True,
        directly_add_no_mem_embed=True,
        use_high_res_features_in_sam=True,
        multimask_output_in_sam=True,
        iou_prediction_use_sigmoid=True,
        use_obj_ptrs_in_encoder=True,
        add_tpos_enc_to_obj_ptrs=True,
        only_obj_ptrs_in_the_past_for_eval=True,
        pred_obj_scores=True,
        pred_obj_scores_mlp=True,
        fixed_no_obj_ptr=True,
        multimask_output_for_tracking=True,
        use_multimask_token_for_obj_ptr=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        use_mlp_for_obj_ptr_proj=True,
        compile_image_encoder=False,
        no_obj_embed_spatial=True,
        proj_tpos_enc_in_obj_ptrs=True,
        use_signed_tpos_enc_to_obj_ptrs=True,
        sam_mask_decoder_extra_args=mask_decoder_extra_args,
    )

    # Weights are loaded unconditionally, using the interactive key remapping.
    model = _load_checkpoint(model, checkpoint_path, interactive=True)

    # Switch to inference mode before handing the model back.
    model.eval()
    return model
|
||
|
|
|
||
|
|
|
||
|
|
def _load_checkpoint(model, checkpoint, interactive=False):
    """Read a SAM3 checkpoint file, remap its keys, and load them into `model` non-strictly.

    Args:
        model: Module whose `load_state_dict` receives the remapped weights (called with strict=False).
        checkpoint: Path of the checkpoint file; opened in binary mode and read with `torch_load`.
        interactive: When truthy, additional renamed copies of the vision-backbone and "tracker" entries are
            added to the state dict before loading.

    Returns:
        The same `model` object after `load_state_dict` has been called.
    """
    with open(checkpoint, "rb") as stream:
        weights = torch_load(stream)

    # Some checkpoints nest the actual state dict under a "model" entry.
    if "model" in weights:
        nested = weights["model"]
        if isinstance(nested, dict):
            weights = nested

    # Detector entries, with the "detector." prefix removed.
    state = {}
    for name, tensor in weights.items():
        if "detector" in name:
            state[name.replace("detector.", "")] = tensor

    if interactive:
        # Duplicate vision-backbone entries under the image-encoder name. Iterate over a snapshot because
        # `state` is being extended while we go.
        for name, tensor in list(state.items()):
            if "backbone.vision_backbone" in name:
                state[name.replace("backbone.vision_backbone", "image_encoder.vision_backbone")] = tensor

        # (substring selecting keys, text to replace, replacement); applied in order, later entries win.
        tracker_renames = (
            ("tracker.transformer", "tracker.transformer.encoder", "memory_attention"),
            ("tracker.maskmem_backbone", "tracker.maskmem_backbone", "memory_encoder"),
            ("tracker.", "tracker.", ""),
        )
        for marker, old, new in tracker_renames:
            for name, tensor in weights.items():
                if marker in name:
                    state[name.replace(old, new)] = tensor

    model.load_state_dict(state, strict=False)
    return model
|