From 9182af60b166f1fc52b9c59175c10bbd3f3133a4 Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Mon, 28 Oct 2024 14:36:40 +0000 Subject: [PATCH] TO DO : fix the image_embeddings and sparse_embeddings part --- .../models/sam2/configuration_sam2.py | 22 +++--- .../models/sam2/convert_sam2_to_hf.py | 62 ++++++++++++++-- src/transformers/models/sam2/modeling_sam2.py | 72 +++++++++++-------- 3 files changed, 111 insertions(+), 45 deletions(-) diff --git a/src/transformers/models/sam2/configuration_sam2.py b/src/transformers/models/sam2/configuration_sam2.py index 4702a2079b0fa2..30db298accf669 100644 --- a/src/transformers/models/sam2/configuration_sam2.py +++ b/src/transformers/models/sam2/configuration_sam2.py @@ -178,6 +178,7 @@ def __init__( pred_obj_scores=True, pred_obj_scores_mlp=True, use_multimask_token_for_obj_ptr=True, + feed_forward_hidden_act="relu", two_way_transformer_depth=2, two_way_transformer_embedding_dim=256, two_way_transformer_num_heads=8, @@ -202,6 +203,7 @@ def __init__( self.pred_obj_scores = pred_obj_scores self.pred_obj_scores_mlp = pred_obj_scores_mlp self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.feed_forward_hidden_act = feed_forward_hidden_act # TwoWayTransformer configuration self.two_way_transformer_depth = two_way_transformer_depth @@ -223,8 +225,8 @@ class Sam2ImageEncoderConfig(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Args: - scalp (`int`, *optional*, defaults to 1): - The scalp parameter for the image encoder. + skip_lowest_resolutions (`int`, *optional*, defaults to 1): + The skip_lowest_resolutions parameter for the image encoder. hidden_size (``, *optional*, defaults to 96): num_heads (`int`, *optional*, defaults to 1): Initial number of attention heads. @@ -245,11 +247,11 @@ class Sam2ImageEncoderConfig(PretrainedConfig): Dimension multiplier factor at stage shift. head_mul (`float`, *optional*, defaults to 2.0): Head multiplier factor at stage shift. - window_pos_embed_bkg_spatial_size (`Tuple[int, int]`, *optional*, defaults to `(7, 7)`): + window_positional_embedding_background_size (`Tuple[int, int]`, *optional*, defaults to `(7, 7)`): Window size per stage when not using global attention. window_spec (`Tuple[int, ...]`, *optional*, defaults to `(8, 4, 14, 7)`): Window specifications for each stage. - global_att_blocks (`Tuple[int, ...]`, *optional*, defaults to `(5, 7, 9)`): + global_attention_blocks (`Tuple[int, ...]`, *optional*, defaults to `(5, 7, 9)`): Blocks where global attention is used. d_model (`int`, *optional*, defaults to 256): Dimension of the model in the neck. 
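
For reference, the renames introduced in this config file (scalp -> skip_lowest_resolutions, window_pos_embed_bkg_spatial_size -> window_positional_embedding_background_size, global_att_blocks -> global_attention_blocks) can be applied mechanically to a config written against the old names. A minimal sketch, assuming the old-style config is available as a plain dict; migrate_image_encoder_config and OLD_TO_NEW_IMAGE_ENCODER_KEYS are hypothetical helpers for illustration, not part of this patch:

# Mapping of the renamed Sam2ImageEncoderConfig fields (old name -> new name).
OLD_TO_NEW_IMAGE_ENCODER_KEYS = {
    "scalp": "skip_lowest_resolutions",
    "window_pos_embed_bkg_spatial_size": "window_positional_embedding_background_size",
    "global_att_blocks": "global_attention_blocks",
}

def migrate_image_encoder_config(old_config: dict) -> dict:
    # Return a copy of the config with the renamed keys applied;
    # keys that were not renamed pass through unchanged.
    return {OLD_TO_NEW_IMAGE_ENCODER_KEYS.get(key, key): value for key, value in old_config.items()}

print(migrate_image_encoder_config({"scalp": 1, "global_att_blocks": (5, 7, 9)}))
# {'skip_lowest_resolutions': 1, 'global_attention_blocks': (5, 7, 9)}
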
@@ -274,7 +276,6 @@ class Sam2ImageEncoderConfig(PretrainedConfig): def __init__( self, - scalp=1, hidden_size=96, num_heads=1, num_channels=3, @@ -288,9 +289,10 @@ def __init__( stages=(1, 2, 7, 2), dim_mul=2.0, head_mul=2.0, - window_pos_embed_bkg_spatial_size=(7, 7), + window_positional_embedding_background_size=(7, 7), window_spec=(8, 4, 14, 7), - global_att_blocks=(5, 7, 9), + global_attention_blocks=(5, 7, 9), + skip_lowest_resolutions=1, backbone_channel_list=[768, 384, 192, 96], fpn_hidden_size=256, fpn_kernel_size=1, @@ -308,7 +310,6 @@ def __init__( assert len(stages) == len(window_spec) == len(backbone_channel_list) assert fuse_type in ["sum", "avg"] - self.scalp = scalp self.hidden_size = hidden_size self.num_heads = num_heads self.num_channels = num_channels @@ -322,9 +323,10 @@ def __init__( self.stages = stages self.dim_mul = dim_mul self.head_mul = head_mul - self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.window_positional_embedding_background_size = window_positional_embedding_background_size self.window_spec = window_spec - self.global_att_blocks = global_att_blocks + self.global_attention_blocks = global_attention_blocks + self.skip_lowest_resolutions = skip_lowest_resolutions # Neck self.backbone_channel_list = backbone_channel_list diff --git a/src/transformers/models/sam2/convert_sam2_to_hf.py b/src/transformers/models/sam2/convert_sam2_to_hf.py index d7698d68cdd444..1e3424c94085b1 100644 --- a/src/transformers/models/sam2/convert_sam2_to_hf.py +++ b/src/transformers/models/sam2/convert_sam2_to_hf.py @@ -82,15 +82,15 @@ def get_config(model_name): "mask_downscaling.6": "mask_embed.conv3", "point_embeddings": "point_embed", "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding", - "image_encoder": "vision_encoder", + "vision_encoder": "image_encoder", + "sam_prompt_encoder": "prompt_encoder", + "sam_mask_decoder": "mask_decoder", "neck.0": "neck.conv1", "neck.1": "neck.layer_norm1", "neck.2": "neck.conv2", "neck.3": "neck.layer_norm2", "patch_embed.proj": "patch_embed.projection", ".norm": ".layer_norm", - "blocks": "layers", - "trunk.layers": "blocks", "trunk.": "", } @@ -101,15 +101,41 @@ def replace_keys(state_dict): state_dict.pop("pixel_std", None) output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*" - output_image_encoder_pattern = r"patch_embed.*.*" + output_mask_decoder_mlps_pattern = r"mask_decoder.transformer.layers.(\d+).mlp.layers.(\d+).*" + output_mask_decoder_score_head_pattern = r"mask_decoder.pred_obj_score_head.layers.(\d+).*" + output_image_encoder_mlps_pattern = r"image_encoder.blocks.(\d+).mlp.layers.(\d+).*" + output_image_encoder_neck_pattern = r"image_encoder.neck.convs.(\d+).conv" for key, value in state_dict.items(): for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in key: key = key.replace(key_to_modify, new_key) - if re.match(output_image_encoder_pattern, key): - key = key.replace("projection", "proj") + # image_encoder.blocks.0.mlp.layers.1.weight -> image_encoder.blocks.0.mlp.proj_out.weight + if re.match(output_image_encoder_mlps_pattern, key): + layer_nb = int(re.match(output_image_encoder_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "proj_out") + + # mask_decoder.transformer.layers.0.mlp.layers.1.weight -> mask_decoder.transformer.layers.1.mlp.proj_out.weight + if re.match(output_mask_decoder_mlps_pattern, 
key): + layer_nb = int(re.match(output_mask_decoder_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("mlp.layers.0", "mlp.proj_in") + elif layer_nb == 1: + key = key.replace("mlp.layers.1", "mlp.proj_out") + + # mask_decoder.pred_obj_score_head.layers.1.weight -> mask_decoder.pred_obj_score_head.proj_in.weight + if re.match(output_mask_decoder_score_head_pattern, key): + layer_nb = int(re.match(output_mask_decoder_score_head_pattern, key).group(1)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") if re.match(output_hypernetworks_mlps_pattern, key): layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2)) @@ -120,6 +146,10 @@ def replace_keys(state_dict): elif layer_nb == 2: key = key.replace("layers.2", "proj_out") + # image_encoder.neck.convs.1.conv.bias -> image_encoder.neck.convs.1.bias + if re.match(output_image_encoder_neck_pattern, key): + key = key.replace(".conv.", ".") + model_state_dict[key] = value model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[ @@ -135,6 +165,26 @@ def convert_sam2_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, pu state_dict = torch.load(checkpoint_path, map_location="cpu") state_dict = replace_keys(state_dict) + # TO DO : This is temp code for pass video part. + def should_delete_key(key: str) -> bool: + # Define pattern prefixes to match + patterns = { + "maskmem_tpos_enc", + "no_mem_embed", + "no_mem_pos_enc", + "no_obj_ptr", + "mask_downsample", + "obj_ptr_proj", + "memory_attention", + "memory_encoder.fuser", + } + + # Quick check using startswith for any pattern + return any(key.startswith(pattern) for pattern in patterns) + + # Usage: + state_dict = {key: value for key, value in state_dict.items() if not should_delete_key(key)} + image_processor = Sam2ImageProcessor() processor = Sam2Processor(image_processor=image_processor) hf_model = Sam2Model(config) diff --git a/src/transformers/models/sam2/modeling_sam2.py b/src/transformers/models/sam2/modeling_sam2.py index 26b786ede501c4..a4f361371be1fa 100644 --- a/src/transformers/models/sam2/modeling_sam2.py +++ b/src/transformers/models/sam2/modeling_sam2.py @@ -300,13 +300,15 @@ def __init__(self, config: Sam2ImageEncoderConfig): # Patch embdding self.patch_embed = Sam2PatchEmbeddings(config) # Windowed positional embedding (https://arxiv.org/abs/2311.05613) - self.pos_embed = nn.Parameter(torch.zeros(1, config.hidden_size, *config.window_pos_embed_bkg_spatial_size)) + self.pos_embed = nn.Parameter( + torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size) + ) self.pos_embed_window = nn.Parameter( torch.zeros(1, config.hidden_size, config.window_spec[0], config.window_spec[0]) ) self.stage_ends = [sum(config.stages[:i]) - 1 for i in range(1, len(config.stages) + 1)] - self.global_att_blocks = config.global_att_blocks + self.global_attention_blocks = config.global_attention_blocks self.blocks = nn.ModuleList() embed_dim = config.hidden_size @@ -323,8 +325,8 @@ def __init__(self, config: Sam2ImageEncoderConfig): # of previous stage and final window size of current stage window_size = config.window_spec[cur_stage - 1] - if self.global_att_blocks is not None: - window_size = 0 if i in self.global_att_blocks else window_size + if self.global_attention_blocks is not None: + window_size = 0 if i in self.global_attention_blocks else window_size if i - 1 
in self.stage_ends: dim_out = int(embed_dim * config.dim_mul) @@ -345,7 +347,7 @@ def __init__(self, config: Sam2ImageEncoderConfig): self.blocks.append(block) self.neck = Sam2VisionNeck(config) - self.scalp = config.scalp + self.skip_lowest_resolutions = config.skip_lowest_resolutions def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: h, w = hw @@ -396,11 +398,11 @@ def forward( # Forward through backbone fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states) - if self.scalp > 0: + if self.skip_lowest_resolutions > 0: # Discard the lowest resolution features fpn_hidden_states, fpn_position_encoding = ( - fpn_hidden_states[: -self.scalp], - fpn_position_encoding[: -self.scalp], + fpn_hidden_states[: -self.skip_lowest_resolutions], + fpn_position_encoding[: -self.skip_lowest_resolutions], ) if not return_dict: @@ -602,12 +604,13 @@ def __init__(self, config: Sam2MaskDecoderConfig): super().__init__() self.config = config - self.transformer = Sam2TwoWayTransformer(config) + self.num_mask_tokens = config.num_multimask_outputs + 1 self.iou_token = nn.Embedding(1, config.hidden_size) - self.num_mask_tokens = config.num_multimask_outputs + 1 self.mask_tokens = nn.Embedding(self.num_mask_tokens, config.hidden_size) + self.transformer = Sam2TwoWayTransformer(config) + self.pred_obj_scores = config.pred_obj_scores if self.pred_obj_scores: self.obj_score_token = nn.Embedding(1, config.hidden_size) @@ -627,23 +630,31 @@ def __init__(self, config: Sam2MaskDecoderConfig): self.output_hypernetworks_mlps = nn.ModuleList( [ - Sam2MLP(config.hidden_size, config.hidden_size, config.hidden_size // 8, 3, activation="relu") - for i in range(self.num_mask_tokens) + Sam2FeedForward( + config.hidden_size, + config.hidden_size, + config.hidden_size // 8, + 3, + activation=config.feed_forward_hidden_act, + ) + for _ in range(self.num_mask_tokens) ] ) - self.iou_prediction_head = Sam2MLP( + self.iou_prediction_head = Sam2FeedForward( config.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth, - activation="relu", + activation=config.feed_forward_hidden_act, sigmoid_output=config.iou_prediction_use_sigmoid, ) if config.pred_obj_scores: self.pred_obj_score_head = nn.Linear(config.hidden_size, 1) if config.pred_obj_scores_mlp: - self.pred_obj_score_head = Sam2MLP(config.hidden_size, config.hidden_size, 1, 3, activation="relu") + self.pred_obj_score_head = Sam2FeedForward( + config.hidden_size, config.hidden_size, 1, 3, activation="relu" + ) # When outputting a single mask, optionally we can dynamically fall back to the best # multimask output token if the single mask output token gives low stability scores. 
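
The key remapping added to replace_keys() above exists because the old Sam2MLP registered every linear layer as layers.{i}, while the new Sam2FeedForward (defined later in modeling_sam2.py) registers the first linear as proj_in, the last as proj_out, and only the middle ones as layers.{i}. A minimal, standalone sketch of that index shift, assuming the final layers.{i} segment of a checkpoint key is the feed-forward linear (the conversion script above disambiguates with full-path regexes instead); remap_feed_forward_key is a hypothetical helper for illustration, not part of this patch:

import re

def remap_feed_forward_key(key: str, num_layers: int) -> str:
    # Shift old Sam2MLP layer indices to the Sam2FeedForward layout:
    # layers.0 -> proj_in, layers.{num_layers - 1} -> proj_out,
    # any middle linear layers.{i} -> layers.{i - 1}.
    matches = list(re.finditer(r"layers\.(\d+)", key))
    if not matches:
        return key
    match = matches[-1]  # assume the last "layers.{i}" is the feed-forward linear
    index = int(match.group(1))
    if index == 0:
        replacement = "proj_in"
    elif index == num_layers - 1:
        replacement = "proj_out"
    else:
        replacement = f"layers.{index - 1}"
    return key[: match.start()] + replacement + key[match.end():]

# A two-layer feed-forward (e.g. the two-way transformer MLP) only has proj_in / proj_out:
print(remap_feed_forward_key("mask_decoder.transformer.layers.0.mlp.layers.1.weight", num_layers=2))
# mask_decoder.transformer.layers.0.mlp.proj_out.weight

# A three-layer head keeps one hidden linear, shifted down by one index:
print(remap_feed_forward_key("mask_decoder.pred_obj_score_head.layers.1.weight", num_layers=3))
# mask_decoder.pred_obj_score_head.layers.0.weight
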
@@ -845,7 +856,7 @@ def __init__( ) self.layer_norm2 = nn.LayerNorm(config.two_way_transformer_embedding_dim) - self.mlp = Sam2MLP( + self.mlp = Sam2FeedForward( config.two_way_transformer_embedding_dim, config.two_way_transformer_mlp_dim, config.two_way_transformer_embedding_dim, @@ -1173,9 +1184,7 @@ def get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) -# Lightly adapted from -# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa -class Sam2MLP(nn.Module): +class Sam2FeedForward(nn.Module): def __init__( self, input_dim: int, @@ -1184,20 +1193,25 @@ def __init__( num_layers: int, activation: str = "gelu", sigmoid_output: bool = False, - ) -> None: + ): super().__init__() self.num_layers = num_layers - h = [hidden_dim] * (num_layers - 1) - self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) - self.sigmoid_output = sigmoid_output self.activation = ACT2FN[activation] + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output - def forward(self, x): - for i, layer in enumerate(self.layers): - x = self.activation(layer(x)) if i < self.num_layers - 1 else layer(x) + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) if self.sigmoid_output: - x = F.sigmoid(x) - return x + hidden_states = F.sigmoid(hidden_states) + return hidden_states # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam2 @@ -1371,7 +1385,7 @@ def __init__( self.drop_path = Sam2DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.layer_norm2 = nn.LayerNorm(dim_out, eps=config.layer_norm_eps) - self.mlp = Sam2MLP( + self.mlp = Sam2FeedForward( dim_out, int(dim_out * mlp_ratio), dim_out,
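
The parameter names produced by Sam2FeedForward above are exactly what the conversion script targets. A minimal, self-contained sketch mirroring that module, with ACT2FN swapped for nn.ReLU so the example runs outside the library; FeedForward here is an illustrative stand-in, not the class from the patch:

import torch
import torch.nn as nn

class FeedForward(nn.Module):
    # Same layout as Sam2FeedForward: proj_in, (num_layers - 2) hidden linears, proj_out.
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False):
        super().__init__()
        self.activation = nn.ReLU()
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
        self.sigmoid_output = sigmoid_output

    def forward(self, hidden_states):
        hidden_states = self.activation(self.proj_in(hidden_states))
        for layer in self.layers:
            hidden_states = self.activation(layer(hidden_states))
        hidden_states = self.proj_out(hidden_states)
        return torch.sigmoid(hidden_states) if self.sigmoid_output else hidden_states

mlp = FeedForward(256, 256, 32, num_layers=3)
print(sorted(mlp.state_dict().keys()))
# ['layers.0.bias', 'layers.0.weight', 'proj_in.bias', 'proj_in.weight', 'proj_out.bias', 'proj_out.weight']
print(mlp(torch.randn(2, 5, 256)).shape)  # torch.Size([2, 5, 32])
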