diff --git a/src/transformers/models/sam2/configuration_sam2.py b/src/transformers/models/sam2/configuration_sam2.py
index ba99ca8ec032ce..e61b145989ddd2 100644
--- a/src/transformers/models/sam2/configuration_sam2.py
+++ b/src/transformers/models/sam2/configuration_sam2.py
@@ -157,13 +157,41 @@ class Sam2MemoryEncoderConfig(PretrainedConfig):
 
     def __init__(
         self,
-        in_dim=256,
-        out_dim=64,
+        hidden_size=256,
+        output_channels=64,
+        mask_downsampler_embed_dim=256,
+        mask_downsampler_kernel_size=4,
+        mask_downsampler_stride=4,
+        mask_downsampler_padding=0,
+        mask_downsampler_total_stride=16,
+        mask_downsampler_hidden_act="gelu",
+        memory_fuser_num_layers=2,
+        memory_fuser_embed_dim=256,
+        memory_fuser_input_projection=False,
+        memory_fuser_kernel_size=7,
+        memory_fuser_padding=3,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.in_dim = in_dim
-        self.out_dim = out_dim
+        num_layers = int(math.log2(mask_downsampler_total_stride) // math.log2(mask_downsampler_stride))
+        if mask_downsampler_stride**num_layers != mask_downsampler_total_stride:
+            raise ValueError("mask_downsampler_total_stride must be a power of mask_downsampler_stride")
+
+        self.hidden_size = hidden_size
+        self.output_channels = output_channels
+        self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
+        self.mask_downsampler_kernel_size = mask_downsampler_kernel_size
+        self.mask_downsampler_stride = mask_downsampler_stride
+        self.mask_downsampler_padding = mask_downsampler_padding
+        self.mask_downsampler_total_stride = mask_downsampler_total_stride
+        self.mask_downsampler_hidden_act = mask_downsampler_hidden_act
+        self.memory_fuser_num_layers = memory_fuser_num_layers
+        self.memory_fuser_embed_dim = memory_fuser_embed_dim
+        self.memory_fuser_input_projection = memory_fuser_input_projection
+        self.memory_fuser_kernel_size = memory_fuser_kernel_size
+        self.memory_fuser_padding = memory_fuser_padding
 
 
 class Sam2MaskDecoderConfig(PretrainedConfig):
diff --git a/src/transformers/models/sam2/modeling_sam2.py b/src/transformers/models/sam2/modeling_sam2.py
index 4defe1722283c9..b8ecf723bc4b73 100644
--- a/src/transformers/models/sam2/modeling_sam2.py
+++ b/src/transformers/models/sam2/modeling_sam2.py
@@ -1744,14 +1744,13 @@ class Sam2MemoryFuserCXBlock(nn.Module):
 
     def __init__(
         self,
-        dim,
-        kernel_size=7,
-        padding=3,
+        config,
         drop_path=0.0,
         layer_scale_init_value=1e-6,
         use_dwconv=True,
     ):
         super().__init__()
+        embed_dim = config.memory_fuser_embed_dim
         self.dwconv = nn.Conv2d(
-            dim,
-            dim,
+            embed_dim,
+            embed_dim,
@@ -1787,19 +1786,18 @@ def forward(self, x):
 
 
 class Sam2MemoryFuser(nn.Module):
-    def __init__(self, num_layers, dim=None, input_projection=False):
+    def __init__(self, config):
         super().__init__()
-        self.proj = nn.Identity()
-        layer = Sam2MemoryFuserCXBlock(dim=256, kernel_size=7)
-        self.layers = get_clones(layer, num_layers)
-
-        if input_projection:
-            assert dim is not None
-            self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+        self.input_projection = nn.Identity()
+        layer = Sam2MemoryFuserCXBlock(config)
+        self.layers = get_clones(layer, config.memory_fuser_num_layers)
+        if config.memory_fuser_input_projection:
+            embed_dim = config.memory_fuser_embed_dim
+            self.input_projection = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
 
     def forward(self, x):
         # normally x: (N, C, H, W)
-        x = self.proj(x)
+        x = self.input_projection(x)
         for layer in self.layers:
             x = layer(x)
         return x
@@ -1816,34 +1814,31 @@ class Sam2MaskDownSampler(nn.Module):
 
     def __init__(
         self,
-        embed_dim=256,
-        kernel_size=4,
-        stride=4,
-        padding=0,
-        total_stride=16,
-        activation=nn.GELU,
+        config,
     ):
         super().__init__()
-        num_layers = int(math.log2(total_stride) // math.log2(stride))
-        assert stride**num_layers == total_stride
+
+        num_layers = int(math.log2(config.mask_downsampler_total_stride) // math.log2(config.mask_downsampler_stride))
+
         self.encoder = nn.Sequential()
+        self.activation = ACT2FN[config.mask_downsampler_hidden_act]
         mask_in_chans, mask_out_chans = 1, 1
         for _ in range(num_layers):
-            mask_out_chans = mask_in_chans * (stride**2)
+            mask_out_chans = mask_in_chans * (config.mask_downsampler_stride**2)
             self.encoder.append(
                 nn.Conv2d(
                     mask_in_chans,
                     mask_out_chans,
-                    kernel_size=kernel_size,
-                    stride=stride,
-                    padding=padding,
+                    kernel_size=config.mask_downsampler_kernel_size,
+                    stride=config.mask_downsampler_stride,
+                    padding=config.mask_downsampler_padding,
                 )
             )
             self.encoder.append(Sam2LayerNorm(mask_out_chans))
-            self.encoder.append(activation())
+            self.encoder.append(self.activation)
             mask_in_chans = mask_out_chans
 
-        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+        self.encoder.append(nn.Conv2d(mask_out_chans, config.mask_downsampler_embed_dim, kernel_size=1))
 
     def forward(self, x):
         return self.encoder(x)
@@ -1856,16 +1851,15 @@ def __init__(
         self,
         config,
     ):
         super().__init__()
 
-        out_dim = config.out_dim
-        in_dim = config.in_dim
-        self.mask_downsampler = Sam2MaskDownSampler(kernel_size=3, stride=2, padding=1)
-
-        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
-        self.fuser = Sam2MemoryFuser(num_layers=2)
-        self.position_encoding = Sam2PositionEmbeddingSine(num_pos_feats=out_dim)
-        self.out_proj = nn.Identity()
-        if out_dim != in_dim:
-            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+        hidden_size = config.hidden_size
+        output_channels = config.output_channels
+        self.mask_downsampler = Sam2MaskDownSampler(config)
+        self.feature_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+        self.memory_fuser = Sam2MemoryFuser(config)
+        self.position_encoding = Sam2PositionEmbeddingSine(num_pos_feats=output_channels)
+        self.projection = nn.Identity()
+        if output_channels != hidden_size:
+            self.projection = nn.Conv2d(hidden_size, output_channels, kernel_size=1)
 
     def forward(
         self,
@@ -1883,10 +1877,10 @@ def forward(
         # in case the visual features are on CPU, cast them to CUDA
         pix_feat = pix_feat.to(masks.device)
 
-        x = self.pix_feat_proj(pix_feat)
+        x = self.feature_projection(pix_feat)
         x = x + masks
-        x = self.fuser(x)
-        x = self.out_proj(x)
+        x = self.memory_fuser(x)
+        x = self.projection(x)
 
         pos = self.position_encoding(x).to(x.dtype)
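
A note on the `Sam2MaskDownSampler` rework: the number of downsampling stages is now derived from the config, so `mask_downsampler_total_stride` must be an exact power of `mask_downsampler_stride` (the validation moved into `Sam2MemoryEncoderConfig.__init__` relies on `math`, so configuration_sam2.py needs `import math` at module level). Below is a minimal standalone sketch of that wiring under the defaults from this diff, using only released PyTorch and `transformers` APIs; `Sam2LayerNorm` and the memory-fuser path are simplified away, so this is an illustration, not the branch's actual module.

```python
import math

import torch
import torch.nn as nn
from transformers.activations import ACT2FN

stride, total_stride = 4, 16
# Stage count implied by the config: total_stride must be a power of stride.
num_layers = int(math.log2(total_stride) // math.log2(stride))
assert stride**num_layers == total_stride  # 4**2 == 16

# ACT2FN maps activation names to module instances, so it is indexed with
# brackets rather than called (hence ACT2FN[...] in the diff above).
activation = ACT2FN["gelu"]

# Each stage multiplies the channel count by stride**2, mirroring
# Sam2MaskDownSampler (the Sam2LayerNorm between conv and activation is
# omitted here for brevity).
encoder = nn.Sequential()
mask_in_chans = 1
for _ in range(num_layers):
    mask_out_chans = mask_in_chans * stride**2
    encoder.append(nn.Conv2d(mask_in_chans, mask_out_chans, kernel_size=4, stride=stride, padding=0))
    encoder.append(activation)
    mask_in_chans = mask_out_chans
encoder.append(nn.Conv2d(mask_in_chans, 256, kernel_size=1))  # project to embed_dim

masks = torch.randn(1, 1, 64, 64)
print(encoder(masks).shape)  # -> torch.Size([1, 256, 4, 4])
```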
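Since the modeling code reads every `mask_downsampler_*` / `memory_fuser_*` value off the config, `Sam2MemoryEncoderConfig.__init__` has to assign each named kwarg to `self`: parameters captured by the signature are not forwarded to `PretrainedConfig.__init__`, and `to_dict()` only serializes attributes actually set on the instance. A hypothetical round-trip check, assuming this branch is installed (the class and import path exist only on this branch, not in a released `transformers`):

```python
# Hypothetical usage; Sam2MemoryEncoderConfig is introduced by this diff.
from transformers.models.sam2.configuration_sam2 import Sam2MemoryEncoderConfig

config = Sam2MemoryEncoderConfig(mask_downsampler_stride=2, mask_downsampler_total_stride=16)
# The stride check passes: 2**4 == 16.
serialized = config.to_dict()
assert serialized["mask_downsampler_stride"] == 2
assert serialized["memory_fuser_num_layers"] == 2  # default survives the round trip
```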