
Fix: Mamba2 norm_before_gate usage #32686

Merged 3 commits into huggingface:main on Aug 20, 2024

Conversation

@vasqu (Contributor) commented Aug 14, 2024

What does this PR do?

Fixes the default value of norm_before_gate (it should be False) in the default Mamba2 config. Additionally, this implements the other variant, norm_before_gate=True. Currently, the flag only affects continued training from Codestral on the fast path with the fused kernel:

out, ssm_state = mamba_split_conv1d_scan_combined(
    projected_states,
    self.conv1d.weight.squeeze(1),
    self.conv1d.bias,
    self.dt_bias,
    A,
    D=self.D,
    chunk_size=self.chunk_size,
    seq_idx=None,  # was seq_idx
    activation=self.activation,
    rmsnorm_weight=self.norm.weight,
    rmsnorm_eps=self.norm.variance_epsilon,
    outproj_weight=self.out_proj.weight,
    outproj_bias=self.out_proj.bias,
    headdim=self.head_dim,
    ngroups=self.n_groups,
    norm_before_gate=self.norm_before_gate,
    return_final_states=True,
    **dt_limit_kwargs,
)

Discovered in #32580 (comment)
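For context on what the flag changes: norm_before_gate selects whether the RMS norm is applied before or after the SiLU gate. Below is a minimal, self-contained sketch of the two orderings in plain PyTorch (illustrative names and simplified dtype handling, not the actual transformers module), assuming the kernel convention that norm_before_gate=True means norm(x) * silu(z) and False means norm(x * silu(z)):

import torch
import torch.nn.functional as F
from torch import nn

class GatedRMSNormSketch(nn.Module):
    """Illustrative gated RMSNorm; simplified stand-in for MambaRMSNormGated."""

    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.norm_before_gate = norm_before_gate

    def _rmsnorm(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)

    def forward(self, hidden_states, gate=None):
        hidden_states = hidden_states.to(torch.float32)
        if gate is None:
            return self._rmsnorm(hidden_states)
        gate = F.silu(gate.to(torch.float32))
        if self.norm_before_gate:
            return self._rmsnorm(hidden_states) * gate  # norm(x) * silu(z)
        return self._rmsnorm(hidden_states * gate)      # norm(x * silu(z))

The False ordering (gate first, then norm) is what the existing slow path computes, which is why the config default should be False and why the same value has to reach the fused kernel on the fast path.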

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@molbap @ArthurZucker

@vasqu (Contributor, Author) commented Aug 14, 2024

Tbh, we could also just remove norm_before_gate from the config and only allow the False path. I don't mind either way.

@vasqu (Contributor, Author) commented Aug 14, 2024

Will rebase later when the other fix(es) are in main 😄

@molbap (Contributor) left a comment

Great, thanks a lot! I'm in favor of keeping norm_before_gate, as it's a valid config option in the original code and people might want to experiment with it. Pinging @ArthurZucker for final review :)

@@ -248,7 +254,9 @@ def __init__(self, config: Mamba2Config, layer_idx: int):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = MambaRMSNormGated(
+            self.intermediate_size, eps=self.layer_norm_epsilon, norm_before_gate=config.norm_before_gate
+        )
Contributor:

great, maybe let's add this as an attribute of Mamba2Mixer in the init to get all config-derived args in the same place!

Contributor (Author):

It was already in the Mixer class; I had just read it from the passed config. Changed it to use self now instead 👍
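For readers skimming the thread, the suggested pattern is roughly the following (a toy sketch with made-up names such as TinyConfig and MixerSketch, not the actual Mamba2Mixer diff): read every config-derived argument once in __init__ and let the rest of the module use self.

from dataclasses import dataclass
from torch import nn

@dataclass
class TinyConfig:
    # illustrative stand-in for Mamba2Config
    hidden_size: int = 16
    norm_before_gate: bool = False

class MixerSketch(nn.Module):
    """Toy stand-in for Mamba2Mixer, only showing where the flag is read."""

    def __init__(self, config: TinyConfig):
        super().__init__()
        # config-derived arguments are stored once, in the constructor
        self.norm_before_gate = config.norm_before_gate
        self.norm = nn.LayerNorm(config.hidden_size)  # placeholder for the gated RMSNorm

    def forward(self, hidden_states):
        # downstream code (e.g. the fused-kernel call) reads self.norm_before_gate
        # instead of reaching back into the config
        return self.norm(hidden_states)

mixer = MixerSketch(TinyConfig())
print(mixer.norm_before_gate)  # False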


def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)

if gate is not None:
if gate is not None and not self.norm_before_gate:
Contributor:

👌

@ArthurZucker (Collaborator) left a comment

Hey! Not entirely sure this is super transformers-friendly, as a researcher would just copy-paste the modeling code and add all the if-else branches!
Our usual motivation is to add a new model only when there is a good pretrained checkpoint that demonstrates the use of this param!

If there is strong demand from the community, why not, but in general such changes go against the philosophy! 🤗

@vasqu (Contributor, Author) commented Aug 19, 2024

Hmm, there are two problems with the current implementation though:

  1. The norm_before_gate flag in the config was incorrectly set to True by default, but the implementation follows the path as if it were False. This is problematic because the config suggests otherwise, and it enables incorrect Codestral training via the fast path, where the flag is passed directly to the fused kernel (see the initial description and the code snippet); a small sketch of how the two orderings diverge follows below.
  2. Now that the flag exists, it would be odd for it to support only one path while suggesting otherwise.

If I read it correctly, would you prefer removing this flag in the config and adjusting the code to only follow one path (i.e. norm_before_gate=False)?
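To make point 1 concrete, here is a small self-contained sketch (plain PyTorch, not repository code) showing that the two orderings produce different outputs, so a config claiming norm_before_gate=True while the slow path hard-codes the other ordering silently diverges from the fast path:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x, z = torch.randn(2, 8), torch.randn(2, 8)  # dummy hidden states and gate
weight, eps = torch.ones(8), 1e-6

def rmsnorm(h):
    return weight * h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)

norm_then_gate = rmsnorm(x) * F.silu(z)  # what norm_before_gate=True would compute
gate_then_norm = rmsnorm(x * F.silu(z))  # what the existing slow path computes

print(torch.allclose(norm_then_gate, gate_then_norm))  # prints False: the orderings disagree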

@ArthurZucker (Collaborator):

> If I read it correctly, would you prefer removing this flag in the config and adjusting the code to only follow one path (i.e. norm_before_gate=False)?

yes! 🤗 this would be less confusing IMO!

@vasqu (Contributor, Author) commented Aug 20, 2024

@ArthurZucker Should be good now 👀

@ArthurZucker (Collaborator) left a comment

Thanks, let's update the PR title.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker merged commit c63a3d0 into huggingface:main on Aug 20, 2024
21 checks passed
@vasqu deleted the mamba2-gated-norm-fix branch on August 20, 2024 at 19:54
@vasqu (Contributor, Author) commented Aug 20, 2024

I guess the renaming of the PR is too late now 😓

Titus-von-Koeller pushed a commit to jiqing-feng/transformers that referenced this pull request Aug 21, 2024
* mamba2 uses norm_before_gate=False

* small nit

* remove norm_before_gate flag and follow False path only
@ArthurZucker (Collaborator):

No worries 🤗
