Fix: Mamba2 norm_before_gate usage #32686

Conversation
Tbh, we could also just remove

Will rebase later when the other fix(es) are in main 😄
Great, thanks a lot! I'm pro-keeping norm_before_gate as it's a valid config option in the original code and people could want to experiment on it. Pinging @ArthurZucker for final review :)
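For context, the flag selects where the SiLU gate is applied relative to the RMS normalization, following the semantics of the reference implementation in the original mamba_ssm code. A minimal sketch with illustrative names (not the library's API):

```python
import torch
import torch.nn.functional as F

def gated_rmsnorm_ref(x, z, weight, eps=1e-6, norm_before_gate=False):
    # norm_before_gate=False (the Mamba2/Codestral behavior): gate first, then normalize.
    if not norm_before_gate:
        x = x * F.silu(z)
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    # norm_before_gate=True: normalize first, then gate.
    if norm_before_gate:
        x = x * F.silu(z)
    return x
```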
```diff
@@ -248,7 +254,9 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = MambaRMSNormGated(
+            self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=config.norm_before_gate
+        )
```
Great, maybe let's add this as an attribute of Mamba2Mixer in the init to get all config-derived args in the same place!
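A sketch of the suggested pattern, with hypothetical toy names; the point is just to read the config once in __init__:

```python
import torch.nn as nn

class TinyMixer(nn.Module):
    # Toy sketch: config-derived args are stored as attributes in __init__
    # so the rest of the module never re-reads the config object.
    def __init__(self, config):
        super().__init__()
        self.norm_before_gate = config.norm_before_gate
        # ... later code uses self.norm_before_gate, not config.norm_before_gate
```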
It's already in the Mixer class, I just used it from the passed config. Changed it to use self now instead 👍
```diff
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)

-        if gate is not None:
+        if gate is not None and not self.norm_before_gate:
```
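Pieced together, the full forward would look roughly like this. This is a sketch: the post-norm branch for norm_before_gate=True is inferred from the reference semantics rather than shown in the diff, and self.weight / self.variance_epsilon are assumed from the surrounding class:

```python
def forward(self, hidden_states, gate=None):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    # norm_before_gate=False: apply the SiLU gate before normalizing.
    if gate is not None and not self.norm_before_gate:
        hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))

    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)

    # norm_before_gate=True: normalize first, then apply the gate.
    if gate is not None and self.norm_before_gate:
        hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))

    return hidden_states.to(input_dtype)
```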
👌
Hey! Not entirely sure this is super transformers-friendly, as a researcher would just copy-paste the modeling code and add all the if-elses! Our usual motivation is to add a new model if there is a good pretrained checkpoint that demonstrates the use of this param!
If there is strong demand from the community, why not, but in general such changes go against the philosophy! 🤗
Hmm, there are currently two problems in the implementation though:

If I read it correctly, would you prefer removing this flag from the config and adjusting the code to only follow one path (i.e. norm_before_gate=False)?
Yes! 🤗 this would be less confusing IMO!
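The resulting simplification keeps only the False path, i.e. gate first, then normalize. A sketch under the same assumptions as above:

```python
def forward(self, hidden_states, gate=None):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    # Only path kept: the gate (if any) is applied before normalization.
    if gate is not None:
        hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))

    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return (self.weight * hidden_states).to(input_dtype)
```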
Commits 9cd9196 to 7d01af0 (compare)
@ArthurZucker Should be good now 👀
Thanks, let's update the PR title
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I guess the renaming of the PR is too late now 😓
* mamba2 uses norm_before_gate=False
* small nit
* remove norm_before_gate flag and follow False path only
No worries 🤗
What does this PR do?
Fixes the default value for norm_before_gate (False) in the default config of mamba2. Additionally, implements the other variation with norm_before_gate=True. Currently, this only affects continued training from Codestral on the fast path with the fused kernel:
transformers/src/transformers/models/mamba2/modeling_mamba2.py, lines 334 to 353 at 20a0449
Discovered in #32580 (comment)
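To see why the default matters, a quick self-contained check (illustrative only) showing that the two orderings produce different activations:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x, z = torch.randn(2, 8), torch.randn(2, 8)
eps = 1e-5

def rms(t):
    # RMS normalization without the learned weight, for brevity.
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)

gate_then_norm = rms(x * F.silu(z))  # norm_before_gate=False (Mamba2/Codestral)
norm_then_gate = rms(x) * F.silu(z)  # norm_before_gate=True

print(torch.allclose(gate_then_norm, norm_then_gate))  # prints False
```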
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@molbap @ArthurZucker