
Commit

TO DO : fix the image_embeddings and sparse_embeddings part
SangbumChoi committed Oct 28, 2024
1 parent ab46f71 commit 9182af6
Showing 3 changed files with 111 additions and 45 deletions.
22 changes: 12 additions & 10 deletions src/transformers/models/sam2/configuration_sam2.py
@@ -178,6 +178,7 @@ def __init__(
pred_obj_scores=True,
pred_obj_scores_mlp=True,
use_multimask_token_for_obj_ptr=True,
feed_forward_hidden_act="relu",
two_way_transformer_depth=2,
two_way_transformer_embedding_dim=256,
two_way_transformer_num_heads=8,
@@ -202,6 +203,7 @@ def __init__(
self.pred_obj_scores = pred_obj_scores
self.pred_obj_scores_mlp = pred_obj_scores_mlp
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.feed_forward_hidden_act = feed_forward_hidden_act

# TwoWayTransformer configuration
self.two_way_transformer_depth = two_way_transformer_depth
@@ -223,8 +225,8 @@ class Sam2ImageEncoderConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
scalp (`int`, *optional*, defaults to 1):
The scalp parameter for the image encoder.
skip_lowest_resolutions (`int`, *optional*, defaults to 1):
Number of lowest-resolution feature levels discarded from the image encoder's FPN outputs.
hidden_size (`int`, *optional*, defaults to 96):
Initial embedding dimension of the image encoder trunk, scaled by `dim_mul` at each stage shift.
num_heads (`int`, *optional*, defaults to 1):
Initial number of attention heads.
@@ -245,11 +247,11 @@ class Sam2ImageEncoderConfig(PretrainedConfig):
Dimension multiplier factor at stage shift.
head_mul (`float`, *optional*, defaults to 2.0):
Head multiplier factor at stage shift.
window_pos_embed_bkg_spatial_size (`Tuple[int, int]`, *optional*, defaults to `(7, 7)`):
window_positional_embedding_background_size (`Tuple[int, int]`, *optional*, defaults to `(7, 7)`):
Spatial size of the background window positional embedding.
window_spec (`Tuple[int, ...]`, *optional*, defaults to `(8, 4, 14, 7)`):
Window size used at each stage when global attention is not applied.
global_att_blocks (`Tuple[int, ...]`, *optional*, defaults to `(5, 7, 9)`):
global_attention_blocks (`Tuple[int, ...]`, *optional*, defaults to `(5, 7, 9)`):
Blocks where global attention is used.
d_model (`int`, *optional*, defaults to 256):
Dimension of the model in the neck.
@@ -274,7 +276,6 @@ class Sam2ImageEncoderConfig(PretrainedConfig):

def __init__(
self,
scalp=1,
hidden_size=96,
num_heads=1,
num_channels=3,
@@ -288,9 +289,10 @@ def __init__(
stages=(1, 2, 7, 2),
dim_mul=2.0,
head_mul=2.0,
window_pos_embed_bkg_spatial_size=(7, 7),
window_positional_embedding_background_size=(7, 7),
window_spec=(8, 4, 14, 7),
global_att_blocks=(5, 7, 9),
global_attention_blocks=(5, 7, 9),
skip_lowest_resolutions=1,
backbone_channel_list=[768, 384, 192, 96],
fpn_hidden_size=256,
fpn_kernel_size=1,
@@ -308,7 +310,6 @@ def __init__(
assert len(stages) == len(window_spec) == len(backbone_channel_list)
assert fuse_type in ["sum", "avg"]

self.scalp = scalp
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_channels = num_channels
@@ -322,9 +323,10 @@ def __init__(
self.stages = stages
self.dim_mul = dim_mul
self.head_mul = head_mul
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
self.window_positional_embedding_background_size = window_positional_embedding_background_size
self.window_spec = window_spec
self.global_att_blocks = global_att_blocks
self.global_attention_blocks = global_attention_blocks
self.skip_lowest_resolutions = skip_lowest_resolutions

# Neck
self.backbone_channel_list = backbone_channel_list
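For reference, a minimal usage sketch (not part of this commit) showing the renamed Sam2ImageEncoderConfig arguments; the values simply mirror the defaults in the diff, and the import path assumes the in-progress src/transformers/models/sam2 layout:

from transformers.models.sam2.configuration_sam2 import Sam2ImageEncoderConfig

# Renamed arguments: scalp -> skip_lowest_resolutions,
# window_pos_embed_bkg_spatial_size -> window_positional_embedding_background_size,
# global_att_blocks -> global_attention_blocks.
image_encoder_config = Sam2ImageEncoderConfig(
    hidden_size=96,
    window_positional_embedding_background_size=(7, 7),
    window_spec=(8, 4, 14, 7),
    global_attention_blocks=(5, 7, 9),
    skip_lowest_resolutions=1,  # drop the lowest-resolution FPN level
)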
62 changes: 56 additions & 6 deletions src/transformers/models/sam2/convert_sam2_to_hf.py
@@ -82,15 +82,15 @@ def get_config(model_name):
"mask_downscaling.6": "mask_embed.conv3",
"point_embeddings": "point_embed",
"pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
"image_encoder": "vision_encoder",
"vision_encoder": "image_encoder",
"sam_prompt_encoder": "prompt_encoder",
"sam_mask_decoder": "mask_decoder",
"neck.0": "neck.conv1",
"neck.1": "neck.layer_norm1",
"neck.2": "neck.conv2",
"neck.3": "neck.layer_norm2",
"patch_embed.proj": "patch_embed.projection",
".norm": ".layer_norm",
"blocks": "layers",
"trunk.layers": "blocks",
"trunk.": "",
}

@@ -101,15 +101,41 @@ def replace_keys(state_dict):
state_dict.pop("pixel_std", None)

output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*"
output_image_encoder_pattern = r"patch_embed.*.*"
output_mask_decoder_mlps_pattern = r"mask_decoder.transformer.layers.(\d+).mlp.layers.(\d+).*"
output_mask_decoder_score_head_pattern = r"mask_decoder.pred_obj_score_head.layers.(\d+).*"
output_image_encoder_mlps_pattern = r"image_encoder.blocks.(\d+).mlp.layers.(\d+).*"
output_image_encoder_neck_pattern = r"image_encoder.neck.convs.(\d+).conv"

for key, value in state_dict.items():
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)

if re.match(output_image_encoder_pattern, key):
key = key.replace("projection", "proj")
# image_encoder.blocks.0.mlp.layers.1.weight -> image_encoder.blocks.0.mlp.proj_out.weight
if re.match(output_image_encoder_mlps_pattern, key):
layer_nb = int(re.match(output_image_encoder_mlps_pattern, key).group(2))
if layer_nb == 0:
key = key.replace("layers.0", "proj_in")
elif layer_nb == 1:
key = key.replace("layers.1", "proj_out")

# mask_decoder.transformer.layers.0.mlp.layers.1.weight -> mask_decoder.transformer.layers.0.mlp.proj_out.weight
if re.match(output_mask_decoder_mlps_pattern, key):
layer_nb = int(re.match(output_mask_decoder_mlps_pattern, key).group(2))
if layer_nb == 0:
key = key.replace("mlp.layers.0", "mlp.proj_in")
elif layer_nb == 1:
key = key.replace("mlp.layers.1", "mlp.proj_out")

# mask_decoder.pred_obj_score_head.layers.0.weight -> mask_decoder.pred_obj_score_head.proj_in.weight
if re.match(output_mask_decoder_score_head_pattern, key):
layer_nb = int(re.match(output_mask_decoder_score_head_pattern, key).group(1))
if layer_nb == 0:
key = key.replace("layers.0", "proj_in")
elif layer_nb == 1:
key = key.replace("layers.1", "layers.0")
elif layer_nb == 2:
key = key.replace("layers.2", "proj_out")

if re.match(output_hypernetworks_mlps_pattern, key):
layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
@@ -120,6 +146,10 @@ def replace_keys(state_dict):
elif layer_nb == 2:
key = key.replace("layers.2", "proj_out")

# image_encoder.neck.convs.1.conv.bias -> image_encoder.neck.convs.1.bias
if re.match(output_image_encoder_neck_pattern, key):
key = key.replace(".conv.", ".")

model_state_dict[key] = value

model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[
@@ -135,6 +165,26 @@ def convert_sam2_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, pu
state_dict = torch.load(checkpoint_path, map_location="cpu")
state_dict = replace_keys(state_dict)

# TODO: temporary code to skip the video-specific weights for now.
def should_delete_key(key: str) -> bool:
# Define pattern prefixes to match
patterns = {
"maskmem_tpos_enc",
"no_mem_embed",
"no_mem_pos_enc",
"no_obj_ptr",
"mask_downsample",
"obj_ptr_proj",
"memory_attention",
"memory_encoder.fuser",
}

# Quick check using startswith for any pattern
return any(key.startswith(pattern) for pattern in patterns)

# Drop the video-specific keys listed above for now:
state_dict = {key: value for key, value in state_dict.items() if not should_delete_key(key)}

image_processor = Sam2ImageProcessor()
processor = Sam2Processor(image_processor=image_processor)
hf_model = Sam2Model(config)
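To make the new renaming rules above concrete, here is a standalone sketch (illustration only, using a hypothetical helper name rename_image_encoder_mlp_key) of how the MLP sub-layer indices are remapped onto proj_in/proj_out:

import re

# Hypothetical helper mirroring the image-encoder MLP branch of replace_keys.
def rename_image_encoder_mlp_key(key: str) -> str:
    pattern = r"image_encoder.blocks.(\d+).mlp.layers.(\d+).*"
    match = re.match(pattern, key)
    if match is not None:
        layer_nb = int(match.group(2))
        if layer_nb == 0:
            key = key.replace("layers.0", "proj_in")
        elif layer_nb == 1:
            key = key.replace("layers.1", "proj_out")
    return key

print(rename_image_encoder_mlp_key("image_encoder.blocks.0.mlp.layers.1.weight"))
# image_encoder.blocks.0.mlp.proj_out.weight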
72 changes: 43 additions & 29 deletions src/transformers/models/sam2/modeling_sam2.py
@@ -300,13 +300,15 @@ def __init__(self, config: Sam2ImageEncoderConfig):
# Patch embedding
self.patch_embed = Sam2PatchEmbeddings(config)
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
self.pos_embed = nn.Parameter(torch.zeros(1, config.hidden_size, *config.window_pos_embed_bkg_spatial_size))
self.pos_embed = nn.Parameter(
torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
)
self.pos_embed_window = nn.Parameter(
torch.zeros(1, config.hidden_size, config.window_spec[0], config.window_spec[0])
)

self.stage_ends = [sum(config.stages[:i]) - 1 for i in range(1, len(config.stages) + 1)]
self.global_att_blocks = config.global_att_blocks
self.global_attention_blocks = config.global_attention_blocks

self.blocks = nn.ModuleList()
embed_dim = config.hidden_size
@@ -323,8 +325,8 @@ def __init__(self, config: Sam2ImageEncoderConfig):
# of previous stage and final window size of current stage
window_size = config.window_spec[cur_stage - 1]

if self.global_att_blocks is not None:
window_size = 0 if i in self.global_att_blocks else window_size
if self.global_attention_blocks is not None:
window_size = 0 if i in self.global_attention_blocks else window_size

if i - 1 in self.stage_ends:
dim_out = int(embed_dim * config.dim_mul)
@@ -345,7 +347,7 @@ def __init__(self, config: Sam2ImageEncoderConfig):
self.blocks.append(block)

self.neck = Sam2VisionNeck(config)
self.scalp = config.scalp
self.skip_lowest_resolutions = config.skip_lowest_resolutions

def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
h, w = hw
@@ -396,11 +398,11 @@ def forward(

# Forward through backbone
fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
if self.scalp > 0:
if self.skip_lowest_resolutions > 0:
# Discard the lowest resolution features
fpn_hidden_states, fpn_position_encoding = (
fpn_hidden_states[: -self.scalp],
fpn_position_encoding[: -self.scalp],
fpn_hidden_states[: -self.skip_lowest_resolutions],
fpn_position_encoding[: -self.skip_lowest_resolutions],
)

if not return_dict:
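As a toy illustration (not from the source) of the scalp -> skip_lowest_resolutions rename in this hunk: with a value of 1, the last, lowest-resolution FPN level is sliced off both lists.

# Illustration only: generic level names stand in for the real FPN feature maps.
fpn_hidden_states = ["level_0", "level_1", "level_2", "level_3"]
skip_lowest_resolutions = 1
print(fpn_hidden_states[:-skip_lowest_resolutions])  # ['level_0', 'level_1', 'level_2']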
@@ -602,12 +604,13 @@ def __init__(self, config: Sam2MaskDecoderConfig):
super().__init__()
self.config = config

self.transformer = Sam2TwoWayTransformer(config)
self.num_mask_tokens = config.num_multimask_outputs + 1

self.iou_token = nn.Embedding(1, config.hidden_size)
self.num_mask_tokens = config.num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, config.hidden_size)

self.transformer = Sam2TwoWayTransformer(config)

self.pred_obj_scores = config.pred_obj_scores
if self.pred_obj_scores:
self.obj_score_token = nn.Embedding(1, config.hidden_size)
@@ -627,23 +630,31 @@ def __init__(self, config: Sam2MaskDecoderConfig):

self.output_hypernetworks_mlps = nn.ModuleList(
[
Sam2MLP(config.hidden_size, config.hidden_size, config.hidden_size // 8, 3, activation="relu")
for i in range(self.num_mask_tokens)
Sam2FeedForward(
config.hidden_size,
config.hidden_size,
config.hidden_size // 8,
3,
activation=config.feed_forward_hidden_act,
)
for _ in range(self.num_mask_tokens)
]
)

self.iou_prediction_head = Sam2MLP(
self.iou_prediction_head = Sam2FeedForward(
config.hidden_size,
config.iou_head_hidden_dim,
self.num_mask_tokens,
config.iou_head_depth,
activation="relu",
activation=config.feed_forward_hidden_act,
sigmoid_output=config.iou_prediction_use_sigmoid,
)
if config.pred_obj_scores:
self.pred_obj_score_head = nn.Linear(config.hidden_size, 1)
if config.pred_obj_scores_mlp:
self.pred_obj_score_head = Sam2MLP(config.hidden_size, config.hidden_size, 1, 3, activation="relu")
self.pred_obj_score_head = Sam2FeedForward(
config.hidden_size, config.hidden_size, 1, 3, activation="relu"
)

# When outputting a single mask, optionally we can dynamically fall back to the best
# multimask output token if the single mask output token gives low stability scores.
@@ -845,7 +856,7 @@ def __init__(
)
self.layer_norm2 = nn.LayerNorm(config.two_way_transformer_embedding_dim)

self.mlp = Sam2MLP(
self.mlp = Sam2FeedForward(
config.two_way_transformer_embedding_dim,
config.two_way_transformer_mlp_dim,
config.two_way_transformer_embedding_dim,
@@ -1173,9 +1184,7 @@ def get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class Sam2MLP(nn.Module):
class Sam2FeedForward(nn.Module):
def __init__(
self,
input_dim: int,
@@ -1184,20 +1193,25 @@ def __init__(
num_layers: int,
activation: str = "gelu",
sigmoid_output: bool = False,
) -> None:
):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.sigmoid_output = sigmoid_output
self.activation = ACT2FN[activation]
self.proj_in = nn.Linear(input_dim, hidden_dim)
self.proj_out = nn.Linear(hidden_dim, output_dim)
self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
self.sigmoid_output = sigmoid_output

def forward(self, x):
for i, layer in enumerate(self.layers):
x = self.activation(layer(x)) if i < self.num_layers - 1 else layer(x)
def forward(self, hidden_states):
hidden_states = self.proj_in(hidden_states)
hidden_states = self.activation(hidden_states)
for layer in self.layers:
hidden_states = self.activation(layer(hidden_states))

hidden_states = self.proj_out(hidden_states)
if self.sigmoid_output:
x = F.sigmoid(x)
return x
hidden_states = F.sigmoid(hidden_states)
return hidden_states


# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam2
@@ -1371,7 +1385,7 @@ def __init__(
self.drop_path = Sam2DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

self.layer_norm2 = nn.LayerNorm(dim_out, eps=config.layer_norm_eps)
self.mlp = Sam2MLP(
self.mlp = Sam2FeedForward(
dim_out,
int(dim_out * mlp_ratio),
dim_out,
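As a quick sanity check on the Sam2MLP -> Sam2FeedForward refactor above (a sketch, assuming the class is importable from the in-progress modeling_sam2 module): with num_layers=3 the block builds proj_in, one hidden linear layer, and proj_out, matching the three-layer heads used by the hypernetwork and object-score MLPs.

import torch

ffn = Sam2FeedForward(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3, activation="relu")
hidden_states = torch.randn(2, 5, 256)
print(ffn(hidden_states).shape)  # torch.Size([2, 5, 32])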
