
Commit

refactor memoryattention
SangbumChoi committed Nov 13, 2024
1 parent 9003953 commit 0e64e85
Showing 2 changed files with 171 additions and 192 deletions.
50 changes: 42 additions & 8 deletions src/transformers/models/sam2/configuration_sam2.py
@@ -79,30 +79,64 @@ class Sam2MemoryAttentionConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
d_model (`int`, *optional*, defaults to 256):
The dimension of the model in the memory attention module.
pos_enc_at_input (`bool`, *optional*, defaults to `True`):
Whether to apply positional encoding at the input.
hidden_size (`int`, *optional*, defaults to 256):
Dimensionality of the hidden states in the memory attention module.
num_layers (`int`, *optional*, defaults to 4):
The number of layers in the memory attention module.
batch_first (`bool`, *optional*, defaults to `True`):
Whether the input and output tensors are provided in batch-first format.
apply_pe_at_input (`bool`, *optional*, defaults to `True`):
Whether to apply positional encoding to the inputs of the memory attention module.
hidden_act (`str`, *optional*, defaults to `"relu"`):
The non-linear activation function used in the feed-forward blocks of the memory attention layers.
dim_feedforward (`int`, *optional*, defaults to 2048):
Dimensionality of the feed-forward blocks in the memory attention layers.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability used in the memory attention layers.
rope_theta (`int`, *optional*, defaults to 10000):
The base period (theta) of the rotary position embeddings.
rope_feat_sizes (`List[int]`, *optional*, defaults to `[32, 32]`):
The feature map size (height, width) assumed for the rotary position embeddings.
rope_embedding_dim (`int`, *optional*, defaults to 256):
Dimensionality of the rotary position embeddings.
rope_num_heads (`int`, *optional*, defaults to 1):
Number of attention heads used in the rotary attention layers.
rope_downsample_rate (`int`, *optional*, defaults to 1):
The downsampling rate applied in the rotary attention layers.
rope_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability used in the rotary attention layers.
apply_pe_at_self_attn (`bool`, *optional*, defaults to `False`):
Whether to apply positional encoding in the self-attention layers.
apply_pe_at_cross_attn_keys (`bool`, *optional*, defaults to `True`):
Whether to apply positional encoding to the keys in the cross-attention layers.
apply_pe_at_cross_attn_queries (`bool`, *optional*, defaults to `False`):
Whether to apply positional encoding to the queries in the cross-attention layers.
"""

def __init__(
self,
d_model=256,
pos_enc_at_input=True,
hidden_size=256,
num_layers=4,
batch_first=True,
apply_pe_at_input=True,
hidden_act="relu",
dim_feedforward=2048,
dropout=0.1,
rope_theta=10000,
rope_feat_sizes=[32, 32],
rope_embedding_dim=256,
rope_num_heads=1,
rope_downsample_rate=1,
rope_dropout=0.1,
apply_pe_at_self_attn=False,
apply_pe_at_cross_attn_keys=True,
apply_pe_at_cross_attn_queries=False,
**kwargs,
):
super().__init__(**kwargs)
self.d_model = d_model
self.pos_enc_at_input = pos_enc_at_input
self.hidden_size = hidden_size
self.num_layers = num_layers
self.batch_first = batch_first
self.apply_pe_at_input = apply_pe_at_input
self.hidden_act = hidden_act
self.dim_feedforward = dim_feedforward
self.dropout = dropout
self.rope_theta = rope_theta
self.rope_feat_sizes = rope_feat_sizes
self.rope_embedding_dim = rope_embedding_dim
self.rope_num_heads = rope_num_heads
self.rope_downsample_rate = rope_downsample_rate
self.rope_dropout = rope_dropout
self.apply_pe_at_self_attn = apply_pe_at_self_attn
self.apply_pe_at_cross_attn_keys = apply_pe_at_cross_attn_keys
self.apply_pe_at_cross_attn_queries = apply_pe_at_cross_attn_queries


class Sam2MemoryEncoderConfig(PretrainedConfig):
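For reference, a minimal usage sketch (not part of the commit) of how the refactored configuration might be instantiated after this change. It assumes `Sam2MemoryAttentionConfig` is importable from `transformers.models.sam2.configuration_sam2`, i.e. from a checkout that includes this commit; the parameter names simply follow the diff above.

# Minimal sketch, assuming a transformers checkout that contains this commit.
from transformers.models.sam2.configuration_sam2 import Sam2MemoryAttentionConfig

# Defaults mirror the new signature: hidden_size replaces d_model,
# apply_pe_at_input replaces pos_enc_at_input, and batch_first is gone.
config = Sam2MemoryAttentionConfig()
print(config.hidden_size)        # 256
print(config.apply_pe_at_input)  # True

# Overriding some of the RoPE-related settings introduced in this refactor.
custom = Sam2MemoryAttentionConfig(
    num_layers=2,
    rope_feat_sizes=[64, 64],
    apply_pe_at_cross_attn_queries=True,
)
print(custom.rope_feat_sizes)  # [64, 64]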
