Add Zamba #30950

Merged: 115 commits, merged Oct 4, 2024
Changes from 1 commit
Commits (115):
7eff1cc
Update index.md
pglorio Aug 6, 2024
14961a2
Rebase
pglorio Aug 6, 2024
b67ff24
Rebase
pglorio Aug 6, 2024
0aa1003
Updates from make fixup
pglorio Aug 6, 2024
5e88653
Update zamba.md
pglorio Aug 6, 2024
123d959
Batched inference
pglorio Aug 14, 2024
f35bdf9
Update
pglorio Aug 14, 2024
1ec90d1
Fix tests
Aug 16, 2024
4d3f8c0
Fix tests
pglorio Aug 16, 2024
e51113d
Fix tests
pglorio Aug 16, 2024
cf6ee16
Fix tests
pglorio Aug 16, 2024
f80b813
Update docs/source/en/model_doc/zamba.md
pglorio Sep 6, 2024
c010a68
Update docs/source/en/model_doc/zamba.md
pglorio Sep 6, 2024
9c3abc8
Update configuration_zamba.py
pglorio Sep 6, 2024
5d3d615
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 6, 2024
d245749
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 6, 2024
663343d
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 6, 2024
b3540ea
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 6, 2024
554c14c
Update modeling_zamba.py
pglorio Sep 6, 2024
c9c97fd
Update modeling_zamba.py
pglorio Sep 6, 2024
ec9edd7
Update modeling_zamba.py
pglorio Sep 6, 2024
939d6a9
Update configuration_zamba.py
pglorio Sep 6, 2024
c26addd
Update modeling_zamba.py
pglorio Sep 7, 2024
e3c93f0
Update modeling_zamba.py
pglorio Sep 7, 2024
58d8c2d
Merge branch 'main' of https://github.com/Zyphra/transformers_zamba
pglorio Sep 9, 2024
396ebff
Update ZambaForCausalLM
pglorio Sep 9, 2024
df8dfd3
Update ZambaForCausalLM
pglorio Sep 9, 2024
4ab88a2
Describe diffs with original mamba layer
pglorio Sep 9, 2024
1a521de
Moved mamba init into `_init_weights`
pglorio Sep 9, 2024
767a591
Moved mamba weight init into _init_weights
pglorio Sep 9, 2024
d5b2beb
Update index.md
pglorio Aug 6, 2024
029813b
Rebase
pglorio Aug 6, 2024
bec7dce
Rebase
pglorio Aug 6, 2024
db15348
Updates from make fixup
pglorio Aug 6, 2024
6c7f812
Update zamba.md
pglorio Aug 6, 2024
c3766ba
Batched inference
pglorio Aug 14, 2024
dff24b8
Update
pglorio Aug 14, 2024
0e9f3c9
Fix tests
Aug 16, 2024
245d9d9
Fix tests
pglorio Aug 16, 2024
8aedd30
Fix tests
pglorio Aug 16, 2024
f8ed17a
Fix tests
pglorio Aug 16, 2024
a5d5873
Update docs/source/en/model_doc/zamba.md
pglorio Sep 6, 2024
17cef25
Update docs/source/en/model_doc/zamba.md
pglorio Sep 6, 2024
f773f12
Update configuration_zamba.py
pglorio Sep 6, 2024
c5852aa
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 6, 2024
da64b36
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 6, 2024
7679578
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 6, 2024
85fe7cb
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 6, 2024
6c949c6
Update modeling_zamba.py
pglorio Sep 6, 2024
f78b627
Update modeling_zamba.py
pglorio Sep 6, 2024
3b3605a
Update modeling_zamba.py
pglorio Sep 6, 2024
f6fc1e8
Update configuration_zamba.py
pglorio Sep 6, 2024
d5b8d6e
Update modeling_zamba.py
pglorio Sep 7, 2024
c2428a4
Update modeling_zamba.py
pglorio Sep 7, 2024
bbc9a8e
Merge branch 'main' of https://github.com/Zyphra/transformers_zamba
pglorio Sep 9, 2024
b13fdde
Update ZambaForCausalLM
pglorio Sep 9, 2024
037b938
Moved mamba init into `_init_weights`
pglorio Sep 9, 2024
9a1ef16
Update ZambaForCausalLM
pglorio Sep 9, 2024
d9d436c
Describe diffs with original mamba layer
pglorio Sep 9, 2024
0bbb6c9
Merge branch 'main' of https://github.com/Zyphra/transformers_zamba
pglorio Sep 10, 2024
91bc076
make fixup fixes
Sep 10, 2024
8f0100f
quality test fixes
pglorio Sep 10, 2024
7478e25
Fix Zamba model path
pglorio Sep 10, 2024
a7c9d17
circleci fixes
pglorio Sep 10, 2024
c2d097f
circleci fixes
pglorio Sep 10, 2024
3788196
circleci fixes
pglorio Sep 10, 2024
1c6cca8
circleci fixes
pglorio Sep 10, 2024
911a78a
circleci fixes
pglorio Sep 10, 2024
c6f2b3f
circleci fixes
pglorio Sep 10, 2024
df93132
circleci fixes
pglorio Sep 10, 2024
211a5b5
circleci fixes
pglorio Sep 11, 2024
e0cb9fe
circleci fixes
pglorio Sep 11, 2024
1df30bb
Update
pglorio Sep 11, 2024
3d2800b
circleci fixes
pglorio Sep 11, 2024
1e6f38b
Merge branch 'huggingface:main' into main
Quentin-Anthony Sep 22, 2024
d93377d
fix zamba test from merge
Quentin-Anthony Sep 22, 2024
d01d80d
fix ValueError for disabling mamba kernels
Quentin-Anthony Sep 22, 2024
b9e86b0
add HF copyright
Quentin-Anthony Sep 23, 2024
4b0fb52
shared_transf --> shared_transformer
Quentin-Anthony Sep 23, 2024
66b72c8
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 24, 2024
d527a14
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 24, 2024
97c646c
Fixes
pglorio Sep 24, 2024
1e4ffe6
Move attention head dim to config
pglorio Sep 24, 2024
2c53db2
Fix circle/ci tests
pglorio Sep 24, 2024
9a1ad32
Merge branch 'huggingface:main' into main
Quentin-Anthony Sep 24, 2024
a7717f2
Update modeling_zamba.py
pglorio Sep 24, 2024
0fae398
apply GenerationMixin inheritance change from upstream
Quentin-Anthony Sep 24, 2024
0304440
apply import ordering
Quentin-Anthony Sep 24, 2024
3d9ec8e
Merge branch 'huggingface:main' into main
Quentin-Anthony Sep 24, 2024
cb1d1d9
Merge branch 'main' into main
Quentin-Anthony Sep 25, 2024
efcf16a
Merge branch 'huggingface:main' into main
Quentin-Anthony Sep 26, 2024
339d4cc
update needed transformers version for zamba
Quentin-Anthony Sep 26, 2024
a46a26b
add contribution author
Quentin-Anthony Sep 26, 2024
d0c1bc1
add @slow to avoid CI
Quentin-Anthony Sep 26, 2024
4fcd130
Merge branch 'huggingface:main' into main
Quentin-Anthony Sep 26, 2024
8d29964
Update src/transformers/models/zamba/modeling_zamba.py
pglorio Sep 27, 2024
0381c33
Define attention_hidden_size
pglorio Sep 27, 2024
75554d8
Merge branch 'huggingface:main' into main
Quentin-Anthony Sep 27, 2024
a109b3f
Added doc for attention_head_size
pglorio Sep 27, 2024
9c10afe
trigger CI
Quentin-Anthony Sep 27, 2024
1880455
Fix doc of attention_hidden_size
pglorio Sep 27, 2024
daef5b0
Merge branch 'huggingface:main' into main
Quentin-Anthony Sep 30, 2024
347f761
[run-slow] zamba
pglorio Sep 30, 2024
634837f
Merge branch 'huggingface:main' into main
Quentin-Anthony Sep 30, 2024
4e8db07
Merge branch 'huggingface:main' into main
Quentin-Anthony Oct 1, 2024
1504774
Fixed shared layer logic, swapped up<->gate in mlp
pglorio Oct 3, 2024
06e3a7a
fix shared layer logic, swap up<->gate in mlp
pglorio Oct 3, 2024
267530d
shared_transformer -> shared_transf
pglorio Oct 4, 2024
0a90fc7
reformat HybridLayer __init__
pglorio Oct 4, 2024
fabaaec
Merge branch 'huggingface:main' into main
Quentin-Anthony Oct 4, 2024
75f0d89
fix docstrings in zamba config
pglorio Oct 4, 2024
b9545eb
added definition of _get_input_ids_and_config
pglorio Oct 4, 2024
cdbd690
fixed formatting of _get_input_ids_and_config
pglorio Oct 4, 2024
6fabb6a
Merge branch 'huggingface:main' into main
Quentin-Anthony Oct 4, 2024
b9f6cce
Merge branch 'huggingface:main' into main
Quentin-Anthony Oct 4, 2024
Fixed shared layer logic, swapped up<->gate in mlp
pglorio committed Oct 3, 2024

Unverified: this commit is not signed, but one or more authors require that any commit attributed to them be signed.
commit 1504774123b7d56b64334593f58220f14acbf4a8
4 changes: 2 additions & 2 deletions src/transformers/models/zamba/configuration_zamba.py
@@ -50,7 +50,7 @@ class ZambaConfig(PretrainedConfig):
Number of hidden layers in the model.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
@@ -127,7 +127,7 @@ def __init__(
intermediate_size=14848,
num_hidden_layers=76,
num_attention_heads=16,
num_key_value_heads=None,
num_key_value_heads=16,
n_mamba_heads=2,
hidden_act="gelu",
hidden_mamba_act="silu",
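
As context for the `num_key_value_heads` docstring above: the ratio between `num_attention_heads` and `num_key_value_heads` is what selects MHA, GQA, or MQA, and the key/value heads are expanded at runtime by the `repeat_kv` helper copied from Llama (see the modeling diff below). A small self-contained sketch of that expansion, using arbitrary toy sizes rather than Zamba's real configuration:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq_len, head_dim) to (batch, num_kv_heads * n_rep, seq_len, head_dim)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


num_attention_heads = 16
for num_key_value_heads in (16, 4, 1):  # 16 -> MHA, 4 -> GQA, 1 -> MQA
    n_rep = num_attention_heads // num_key_value_heads
    keys = torch.randn(2, num_key_value_heads, 8, 64)  # (batch, kv_heads, seq_len, head_dim)
    print(num_key_value_heads, repeat_kv(keys, n_rep).shape)  # always 16 heads after expansion
```
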
59 changes: 36 additions & 23 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -39,6 +39,7 @@
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -150,6 +151,9 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


ALL_LAYERNORM_LAYERS.append(ZambaRMSNorm)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
@@ -841,24 +845,19 @@ def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache
# fmt: on


# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Zamba
class ZambaMLP(nn.Module):
"""
Adapted from transformers.models.gemma.modeling_gemma.GemmaMLP, with the flipped convention:
`up_proj` <-> `gate_proj`.
"""

def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
def forward(self, hidden_state):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


class ZambaAttentionDecoderLayer(nn.Module):
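
The two forward orderings touched by the MLP hunk above, written out side by side as a self-contained sketch (toy sizes, plain `nn.Linear` layers, not the PR's module). The computation is symmetric under swapping the two projections; what changes is which checkpoint tensor plays the "gate" role:

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 8, 16
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act_fn = nn.GELU()

x = torch.randn(2, 4, hidden_size)
# Ordering removed by this commit: activation on up_proj, gating by gate_proj.
out_old = down_proj(act_fn(up_proj(x)) * gate_proj(x))
# Ordering added by this commit: activation on gate_proj, gating by up_proj
# (the Gemma/Llama convention).
out_new = down_proj(act_fn(gate_proj(x)) * up_proj(x))
print(out_old.shape, out_new.shape)  # both (2, 4, hidden_size)
```
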
@@ -1006,7 +1005,7 @@ def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear,
super().__init__()
self.shared_transf = shared_transf
Collaborator (suggested change):
self.shared_transf = shared_transf
self.shared_transformer = shared_transformer

Contributor: done: 4b0fb52

self.linear = linear
self.mamba = mamba
self.mamba_decoder = mamba

def forward(
self,
@@ -1059,7 +1058,7 @@ def forward(

transformer_hidden_states = self.linear(transformer_hidden_states)

layer_outputs = self.mamba(
layer_outputs = self.mamba_decoder(
hidden_states,
transformer_hidden_states=transformer_hidden_states,
attention_mask=attention_mask,
@@ -1227,7 +1226,7 @@ def __init__(self, config: ZambaConfig):
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.block = ZambaAttentionDecoderLayer(config)
block = ZambaAttentionDecoderLayer(config)
mamba_layers = []
linear_layers = []
self.layers_block_type = config.layers_block_type
@@ -1237,17 +1236,32 @@ def __init__(self, config: ZambaConfig):
elif config.layers_block_type[i] == "hybrid":
linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False))
mamba_layers.append(ZambaMambaDecoderLayer(config, layer_idx=i))
self.mamba_layers = nn.ModuleList(mamba_layers)
self.linear_layers = nn.ModuleList(linear_layers)

mamba_layers = iter(self.mamba_layers)
linear_layers = iter(self.linear_layers)
self.layers = []
for layer_type in self.layers_block_type:
mamba_layers = iter(mamba_layers)
linear_layers = iter(linear_layers)
layers = []
self._tied_weights_keys = []
for layer_id, layer_type in enumerate(self.layers_block_type):
if layer_type == "hybrid":
self.layers.append(HybridLayer(self.block, next(linear_layers), next(mamba_layers)))
prefix_name = f"layers.{layer_id}."
tied_keys = [
"shared_transf.self_attn.q_proj.weight",
"shared_transf.self_attn.k_proj.weight",
"shared_transf.self_attn.v_proj.weight",
"shared_transf.self_attn.o_proj.weight",
"shared_transf.feed_forward.gate_proj.weight",
"shared_transf.feed_forward.up_proj.weight",
"shared_transf.feed_forward.down_proj.weight",
"shared_transf.input_layernorm.weight",
"shared_transf.pre_ff_layernorm.weight",
# 'linear.weight',
# 'mamba.input_layernorm.weight',
# *['mamba.mamba.' + m_layer for m_layer in mamba_layer_keys]
]
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers)))
else:
self.layers.append(next(mamba_layers))
layers.append(next(mamba_layers))
self.layers = nn.ModuleList(layers)

self._attn_implementation = config._attn_implementation
self.final_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
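
The constructor above builds a single `ZambaAttentionDecoderLayer` (`block`) and passes that same instance to every `HybridLayer`, so all hybrid positions share one set of attention/MLP parameters; the `_tied_weights_keys` entries record those aliased tensors so that saving and loading treat them as tied. A minimal sketch of the sharing pattern in plain PyTorch, with toy stand-in modules rather than the PR's classes:

```python
import torch.nn as nn


class ToyHybridLayer(nn.Module):
    def __init__(self, shared_block: nn.Module, linear: nn.Linear, mamba: nn.Module):
        super().__init__()
        self.shared_transf = shared_block   # the same object in every hybrid layer
        self.linear = linear                # per-layer projection
        self.mamba_decoder = mamba          # per-layer mamba block


shared = nn.Linear(8, 8)  # stands in for the shared attention block
layers = nn.ModuleList(
    ToyHybridLayer(shared, nn.Linear(8, 8, bias=False), nn.Linear(8, 8)) for _ in range(3)
)

# Every hybrid layer points at the same underlying parameter tensor ...
print(len({id(layer.shared_transf.weight) for layer in layers}))  # 1
# ... but the state dict lists it once per layer, which is why the tied-weight keys are needed.
print([k for k in layers.state_dict() if "shared_transf" in k])
```
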
@@ -1417,11 +1431,10 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):

# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA
class ZambaForCausalLM(ZambaPreTrainedModel):
Collaborator: this should be super similar to Gemma / can probably be copied from it entirely!

Contributor Author (@pglorio, Sep 9, 2024): Yes, that is right. More in detail, we originally copied this from Jamba and adapted it to Zamba by removing the lines related to expert routers (which are not present in Zamba; there is no mixture of experts). We then slightly updated the class to reflect recent changes in upstream transformers and added `# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA`:
Zyphra@91bc076#diff-0f4d89960530c068a10af906f1958ed46e3e5f2ff937d6be61517478f383b074R1349
The comment mentions Jamba rather than Gemma because there are a few differences from GemmaForCausalLM in the prepare_inputs_for_generation method, due to our use of HybridMambaAttentionDynamicCache.

Collaborator: Ok, that's good enough, thanks for the detailed explanation.
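
For readers landing on this thread later, a minimal end-to-end generation sketch. The checkpoint identifier is an assumption (check the Zyphra organization on the Hub for the released name); only the public `AutoModel`/`generate` API is used, and the hybrid cache is handled internally:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Zyphra/Zamba-7B-v1"  # assumed checkpoint name; verify on the Hub
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("The Zamba architecture combines", return_tensors="pt").to(model.device)
# generate() goes through ZambaForCausalLM.prepare_inputs_for_generation, which manages
# the HybridMambaAttentionDynamicCache discussed above rather than a plain dynamic cache.
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
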

_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config: ZambaConfig):
super().__init__(config)
self.model = ZambaModel(config)
self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys]
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

@@ -1619,12 +1632,12 @@ def prepare_inputs_for_generation(
""",
ZAMBA_START_DOCSTRING,
)
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->Zamba, MIXTRAL->ZAMBA
class ZambaForSequenceClassification(ZambaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = ZambaModel(config)
self._tied_weights_keys = self.model._tied_weights_keys
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

# Initialize weights and apply final processing