Commit c86b3fe
tmp
SangbumChoi committed Nov 20, 2024
1 parent 0e64e85 commit c86b3fe
Showing 2 changed files with 51 additions and 43 deletions.
22 changes: 18 additions & 4 deletions src/transformers/models/sam2/configuration_sam2.py
@@ -157,13 +157,39 @@ class Sam2MemoryEncoderConfig(PretrainedConfig):
 
     def __init__(
         self,
-        in_dim=256,
-        out_dim=64,
+        hidden_size=256,
+        output_channels=64,
+        mask_downsampler_embed_dim=256,
+        mask_downsampler_kernel_size=4,
+        mask_downsampler_stride=4,
+        mask_downsampler_padding=0,
+        mask_downsampler_total_stride=16,
+        mask_downsampler_hidden_act="gelu",
+        memory_fuser_num_layers=2,
+        memory_fuser_embed_dim=256,
+        memory_fuser_input_projection=False,
+        memory_fuser_kernel_size=7,
+        memory_fuser_padding=3,
         **kwargs,
     ):
         super().__init__(**kwargs)
+        # The total stride must be an integer power of the per-layer stride.
+        num_layers = int(math.log2(mask_downsampler_total_stride) // math.log2(mask_downsampler_stride))
+        assert mask_downsampler_stride**num_layers == mask_downsampler_total_stride
+
+        self.hidden_size = hidden_size
+        self.output_channels = output_channels
+        self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
+        self.mask_downsampler_kernel_size = mask_downsampler_kernel_size
+        self.mask_downsampler_stride = mask_downsampler_stride
+        self.mask_downsampler_padding = mask_downsampler_padding
+        self.mask_downsampler_total_stride = mask_downsampler_total_stride
+        self.mask_downsampler_hidden_act = mask_downsampler_hidden_act
+        self.memory_fuser_num_layers = memory_fuser_num_layers
+        self.memory_fuser_embed_dim = memory_fuser_embed_dim
+        self.memory_fuser_input_projection = memory_fuser_input_projection
+        self.memory_fuser_kernel_size = memory_fuser_kernel_size
+        self.memory_fuser_padding = memory_fuser_padding
 
 
 class Sam2MaskDecoderConfig(PretrainedConfig):
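The flattened `mask_downsampler_*` / `memory_fuser_*` arguments replace the old `in_dim`/`out_dim` pair and are consumed by the modules changed below. A minimal usage sketch, not part of the commit, assuming the class is importable from the file path shown above:

# Hypothetical sketch (not in this commit): exercising the reworked config.
from transformers.models.sam2.configuration_sam2 import Sam2MemoryEncoderConfig

config = Sam2MemoryEncoderConfig(
    hidden_size=256,                   # width of the incoming pixel features
    output_channels=64,                # width of the emitted memory features
    mask_downsampler_stride=4,
    mask_downsampler_total_stride=16,  # must be a power of the per-layer stride: 4**2 == 16
)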
72 changes: 33 additions & 39 deletions src/transformers/models/sam2/modeling_sam2.py
@@ -1744,14 +1744,13 @@ class Sam2MemoryFuserCXBlock(nn.Module):
 
     def __init__(
         self,
-        dim,
-        kernel_size=7,
-        padding=3,
+        config,
         drop_path=0.0,
         layer_scale_init_value=1e-6,
         use_dwconv=True,
     ):
         super().__init__()
+        embed_dim = config.memory_fuser_embed_dim
         self.dwconv = nn.Conv2d(
-            dim,
-            dim,
+            embed_dim,
+            embed_dim,
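For context on the `dwconv` layer being built here: ConvNeXt-style blocks use a depthwise spatial convolution, i.e. `groups` equal to the channel count. A hypothetical reconstruction of the full call once the collapsed lines are expanded, using the new config fields (the `groups` argument is assumed, not visible in this diff):

# Hypothetical reconstruction of the depthwise conv above; groups == channels
# makes the 7x7 convolution operate on each channel independently.
self.dwconv = nn.Conv2d(
    embed_dim,
    embed_dim,
    kernel_size=config.memory_fuser_kernel_size,  # 7 by default
    padding=config.memory_fuser_padding,          # 3 keeps the spatial size
    groups=embed_dim if use_dwconv else 1,
)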
@@ -1787,19 +1786,19 @@ def forward(self, x):
 
 
 class Sam2MemoryFuser(nn.Module):
-    def __init__(self, num_layers, dim=None, input_projection=False):
+    def __init__(self, config):
         super().__init__()
-        self.proj = nn.Identity()
-        layer = Sam2MemoryFuserCXBlock(dim=256, kernel_size=7)
-        self.layers = get_clones(layer, num_layers)
-
-        if input_projection:
-            assert dim is not None
-            self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+        self.input_projection = nn.Identity()
+        layer = Sam2MemoryFuserCXBlock(config)
+        self.layers = get_clones(layer, config.memory_fuser_num_layers)
+        if config.memory_fuser_input_projection:
+            assert config.memory_fuser_embed_dim is not None
+            embed_dim = config.memory_fuser_embed_dim
+            self.input_projection = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
 
     def forward(self, x):
         # normally x: (N, C, H, W)
-        x = self.proj(x)
+        x = self.input_projection(x)
         for layer in self.layers:
             x = layer(x)
         return x
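`get_clones` is not shown in this diff; it is conventionally the helper that deep-copies a module N times so the stacked layers do not share weights. A sketch of that conventional definition, assumed rather than taken from this commit:

import copy
from torch import nn

def get_clones(module, n):
    # Deep copies: the n layers get independent parameters.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])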
@@ -1816,34 +1814,31 @@ class Sam2MaskDownSampler(nn.Module):
 
     def __init__(
         self,
-        embed_dim=256,
-        kernel_size=4,
-        stride=4,
-        padding=0,
-        total_stride=16,
-        activation=nn.GELU,
+        config,
     ):
         super().__init__()
-        num_layers = int(math.log2(total_stride) // math.log2(stride))
-        assert stride**num_layers == total_stride
-
+        num_layers = int(math.log2(config.mask_downsampler_total_stride) // math.log2(config.mask_downsampler_stride))
+
         self.encoder = nn.Sequential()
+        self.activation = ACT2FN[config.mask_downsampler_hidden_act]
         mask_in_chans, mask_out_chans = 1, 1
         for _ in range(num_layers):
-            mask_out_chans = mask_in_chans * (stride**2)
+            mask_out_chans = mask_in_chans * (config.mask_downsampler_stride**2)
             self.encoder.append(
                 nn.Conv2d(
                     mask_in_chans,
                     mask_out_chans,
-                    kernel_size=kernel_size,
-                    stride=stride,
-                    padding=padding,
+                    kernel_size=config.mask_downsampler_kernel_size,
+                    stride=config.mask_downsampler_stride,
+                    padding=config.mask_downsampler_padding,
                 )
             )
             self.encoder.append(Sam2LayerNorm(mask_out_chans))
-            self.encoder.append(activation())
+            self.encoder.append(self.activation)
             mask_in_chans = mask_out_chans
 
-        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+        self.encoder.append(nn.Conv2d(mask_out_chans, config.mask_downsampler_embed_dim, kernel_size=1))
 
     def forward(self, x):
         return self.encoder(x)
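The loop above grows the channel count by `stride**2` per layer while shrinking the spatial size by `stride`; with the config defaults this works out as follows (a worked example, not code from the commit):

import math

stride, total_stride = 4, 16                                    # config defaults
num_layers = int(math.log2(total_stride) // math.log2(stride))  # -> 2
chans = 1
for i in range(num_layers):
    chans *= stride**2
    print(f"layer {i + 1}: {chans} channels, spatial /{stride ** (i + 1)}")
# layer 1: 16 channels, spatial /4
# layer 2: 256 channels, spatial /16
# A final 1x1 conv then maps 256 -> mask_downsampler_embed_dim.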
@@ -1856,16 +1851,15 @@ def __init__(
     ):
         super().__init__()
 
-        out_dim = config.out_dim
-        in_dim = config.in_dim
-        self.mask_downsampler = Sam2MaskDownSampler(kernel_size=3, stride=2, padding=1)
-
-        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
-        self.fuser = Sam2MemoryFuser(num_layers=2)
-        self.position_encoding = Sam2PositionEmbeddingSine(num_pos_feats=out_dim)
-        self.out_proj = nn.Identity()
-        if out_dim != in_dim:
-            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+        hidden_size = config.hidden_size
+        output_channels = config.output_channels
+        self.mask_downsampler = Sam2MaskDownSampler(config)
+        self.feature_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+        self.memory_fuser = Sam2MemoryFuser(config)
+        self.position_encoding = Sam2PositionEmbeddingSine(num_pos_feats=output_channels)
+        self.projection = nn.Identity()
+        if output_channels != hidden_size:
+            self.projection = nn.Conv2d(hidden_size, output_channels, kernel_size=1)
 
     def forward(
         self,
@@ -1883,10 +1877,10 @@ def forward(
         # in case the visual features are on CPU, cast them to CUDA
         pix_feat = pix_feat.to(masks.device)
 
-        x = self.pix_feat_proj(pix_feat)
+        x = self.feature_projection(pix_feat)
         x = x + masks
-        x = self.fuser(x)
-        x = self.out_proj(x)
+        x = self.memory_fuser(x)
+        x = self.projection(x)
 
         pos = self.position_encoding(x).to(x.dtype)
 
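Shape-wise, the renamed forward path does the following under the default config (hidden_size=256, output_channels=64); a sketch, with the input shapes assumed rather than taken from the diff:

# Assumed shapes: pix_feat (B, 256, H, W); masks already downsampled to match.
# x = self.feature_projection(pix_feat)  -> (B, 256, H, W)  1x1 conv, width kept
# x = x + masks                          -> (B, 256, H, W)  inject mask evidence
# x = self.memory_fuser(x)               -> (B, 256, H, W)  ConvNeXt-style blocks
# x = self.projection(x)                 -> (B, 64,  H, W)  256 != 64, so 1x1 conv
# pos = self.position_encoding(x)        -> sine position embedding matching x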
