
Conversion Script for Mamba checkpoints (mamba_ssm -> transformers) #29631

Closed
haileyschoelkopf opened this issue Mar 13, 2024 · 24 comments · Fixed by #29705 or #29851
Labels
Core: Modeling (Internals of the library; Models.) · Feature request (Request for a new feature)

Comments

@haileyschoelkopf
Contributor

Feature request

Thanks very much for the Mamba support (#28094), this interoperability is fantastic!

I wanted to ask if there were any utility (doesn't have to be clean, just functional) for converting checkpoints provided for use in the mamba_ssm library into the format provided in transformers.

This would be very helpful if it exists! Thanks 🤗

Motivation

I'd like to be able to convert newly trained Mamba models from the state-spaces/mamba repo into HF transformers without having to rewrite a conversion script myself.

Your contribution

I could write a utility for this if none exists but would probably not have the bandwidth to upstream it.

@amyeroberts added the Feature request label Mar 13, 2024
@amyeroberts
Collaborator

Hi @haileyschoelkopf, thanks for opening this feature request!

Normally we have conversion scripts under each model's folder (would be here).

I think adding a conversion script sounds like a good idea! @ArthurZucker contributed the model. He's off at the moment, but back soon - I'll let him reply in case there's a good reason we didn't add it alongside the model originally.

@haileyschoelkopf
Contributor Author

Thank you! I'll share here if I get the time to implement one myself!

@byi8220
Contributor

byi8220 commented Mar 13, 2024

If it's okay/not too complicated, could I try and give this a shot (as a new outside contributor)?

Admittedly very new to ML stuff, but at a very high level, would this entail implementing conversion scripts similar to something like what's found in other model dirs such as https://github.com/huggingface/transformers/tree/b340d90738fa14bd6f81b65e4148173cbec62ff6/src/transformers/models/bert ?

I.e. just two files, one for each conversion direction: convert_mamba_ssm_checkpoint_to_pytorch.py and convert_pytorch_to_mamba_ssm_checkpoint.py?

@haileyschoelkopf
Contributor Author

yep! Basically just load the checkpoint file -> convert to a format loadable by the other library, e.g. reshaping or renaming weights as needed -> run load_state_dict() on the transformers or mamba_ssm model -> call save_pretrained() to output a new loadable/uploadable model.
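
A minimal sketch of that flow (assuming the transformers MambaConfig / MambaForCausalLM classes; the paths are placeholders and rename_key is a hypothetical helper standing in for whatever remapping turns out to be needed):

import torch
from transformers import MambaConfig, MambaForCausalLM

def rename_key(key):
    # Placeholder: return the key unchanged, or remap/reshape here as needed.
    return key

# 1. Load the raw mamba_ssm checkpoint (path is a placeholder).
original_state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# 2. Convert to a format loadable by the other library.
converted_state_dict = {rename_key(k): v for k, v in original_state_dict.items()}

# 3. Load into a transformers model built from a matching config
#    (sizes here are the mamba-130m ones used elsewhere in this thread).
config = MambaConfig(hidden_size=768, num_hidden_layers=24, vocab_size=50280)
model = MambaForCausalLM(config)
model.load_state_dict(converted_state_dict)

# 4. Save in a format that transformers can load / upload.
model.save_pretrained("converted-mamba")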

@SrGonao
Contributor

SrGonao commented Mar 15, 2024

I've made this, in case it's helpful for anyone:
https://gist.github.com/SrGonao/33f373a13a6cad6b245450f3d6361598

@byi8220
Contributor

byi8220 commented Mar 17, 2024

Thanks, that's very helpful. I'll try to get a PR out soon modeled off that.

@ArthurZucker
Collaborator

ArthurZucker commented Mar 17, 2024

hey! The reason I did not add one is because the original checkpoints are compatible. This can still be added but only the config should be inferred / updated!
So your checkpoint should already be compatible, no renaming and no reshaping

@byi8220
Contributor

byi8220 commented Mar 17, 2024

So your checkpoint should already be compatible, no renaming and no reshaping

Hm, when I try to run a conversion, I get an error suggesting there needs to be a rename:

RuntimeError: Error(s) in loading state_dict for MambaForCausalLM:
        Missing key(s) in state_dict: "backbone.embeddings.weight". 
        Unexpected key(s) in state_dict: "backbone.embedding.weight". 

The unexpected key contains "embedding" (with no s at the end), while the missing key contains "embeddings" (with an s at the end).

I've attempted to create a PR which both converts the config and does the above renaming for the forward pass: #29705

(Huge thanks to @SrGonao , my PR does pretty much the same thing as his script, except on local files instead of interacting with the hub)
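
For reference, the rename itself boils down to one key in the state dict before calling load_state_dict (a sketch, not the exact PR code):

def fix_embedding_key(state_dict):
    # mamba_ssm names the layer "embedding" (no s); transformers expects "embeddings".
    if "backbone.embedding.weight" in state_dict:
        state_dict["backbone.embeddings.weight"] = state_dict.pop("backbone.embedding.weight")
    return state_dict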

@ArthurZucker
Collaborator

Oh no 😅
I am not getting this one on the original checkpoints, so maybe it was updated at some point?

@byi8220
Contributor

byi8220 commented Mar 21, 2024

Oh no 😅 I am not getting this one on the original checkpoints, so maybe it was updated at some point?

It might have, but I couldn't get it to work without the rename.

A quick printout of the original ssm model suggests that at least the current version of mamba_ssm uses backbone.embedding.weight (no s):

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
config = MambaConfig(d_model = 64, n_layer = 8)
model = MambaLMHeadModel(config)
print(model)

Outputs:

MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(50280, 64)
    (layers): ModuleList(
      (0-7): 8 x Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=64, out_features=256, bias=False)
          (conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
          (act): SiLU()
          (x_proj): Linear(in_features=128, out_features=36, bias=False)
          (dt_proj): Linear(in_features=4, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=64, bias=False)
        )
        (norm): RMSNorm()
      )
    )
    (norm_f): RMSNorm()
  )
  (lm_head): Linear(in_features=64, out_features=50280, bias=False)
)

@ArthurZucker
Collaborator

ArthurZucker commented Mar 25, 2024

You are right 😢 I have no idea why I did not get any warnings / maybe because the weights are tied it used the lm_head's weights and tied using them.

Just saw your PR to remove the tie_weights that were forced before in ssm-state, so let's try to fix this in transformers as well!

Arf that is really annoying

@byi8220
Contributor

byi8220 commented Mar 25, 2024

maybe because the weights are tied it used the lm_head's weights and tied using them.

I'm confused by what you mean here.

I thought that the problem was due to a difference in naming conventions between the two packages, where the transformers library and the mamba_ssm model library just chose to name their embedding layer differently.

Just saw your PR to remove the tie_weights that were forced before in ssm-state

Are you referring to @haileyschoelkopf's PR in state-spaces/mamba#211?

@ArthurZucker
Collaborator

I'll add a loading hook just this once! IMO it should be the cleanest way to fix this.
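
For illustration only (not the actual transformers change), the idea can be sketched with PyTorch's private _register_load_state_dict_pre_hook, which lets a module rewrite incoming keys before they are matched; the toy module below stands in for the real model:

import torch
import torch.nn as nn

class BackboneSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(50280, 64)

class MambaSketch(nn.Module):
    # Toy skeleton only, not the transformers implementation.
    def __init__(self):
        super().__init__()
        self.backbone = BackboneSketch()
        self._register_load_state_dict_pre_hook(self._rename_embedding_key)

    @staticmethod
    def _rename_embedding_key(state_dict, prefix, *args):
        # Remap the original mamba_ssm key to the expected name, in place.
        old_key = prefix + "backbone.embedding.weight"
        if old_key in state_dict:
            state_dict[prefix + "backbone.embeddings.weight"] = state_dict.pop(old_key)

# A checkpoint using the old key name now loads without a separate rename step:
model = MambaSketch()
model.load_state_dict({"backbone.embedding.weight": torch.zeros(50280, 64)})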

@ArthurZucker
Collaborator

I mean that on my side, when implementing mamba in transformers, I did not have a warning about the weights. I suppose that this is because the weights are tied by default (tie_word_embeddings=True). Thus this:

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280)

does not produce any warning.
while:

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280, tie_word_embeddings=False)
Some weights of MambaForCausalLM were not initialized from the model checkpoint at state-spaces/mamba-130m and are newly initialized: ['backbone.embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

does.

and:

>>> from transformers import AutoModel

>>> model = AutoModel.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280)
Some weights of MambaModel were not initialized from the model checkpoint at state-spaces/mamba-130m and are newly initialized: ['backbone.embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

@ArthurZucker
Collaborator

That's a really silent bug, and I got tricked by it...

@ArthurZucker added the Core: Modeling label Mar 25, 2024
@ArthurZucker
Collaborator

cc @amyeroberts, core issue!

@byi8220
Contributor

byi8220 commented Mar 25, 2024

Hm, just to make sure I understand:

  1. The lack of a warning about uninitialized weights is because, when tie_word_embeddings=True, the input embedding layer's weight name is somehow ignored at some step of loading a pretrained model?
  2. A converter which renames weights alongside config conversion should exist.
  3. This issue may exist in other models since it's due to behavior in weight tying and not something mamba specific?

@ArthurZucker
Collaborator

ArthurZucker commented Mar 25, 2024

  1. More like the lack of an error being raised!

  2. "A converter which renames weights alongside config conversion should exist."
     More like "could": what I am thinking about is just a simple hook triggered in from_pretrained, let me open a PR in a bit. This avoids having a conversion script.

  3. "This issue may exist in other models since it's due to behavior in weight tying and not something mamba specific?"
     Totally. But most probably we would have heard of that, as people use the AutoModel API a lot, and don't always tie weights.

@byi8220
Contributor

byi8220 commented Mar 25, 2024

  1. Just wondering, if it's an error, should these be logged at warning level or are these generally just warnings: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4011-L4030
  2. "This avoids having a conversion script" - Wouldn't there still need to be one to convert the MambaConfig? (If there doesn't need to be a converter at all, I guess I can discard Add a converter from mamba_ssm -> huggingface mamba #29705)
  3. If it causes issues immediately then yeah that makes sense too.

@ArthurZucker
Collaborator

About 2: I mean avoiding having to explicitly convert if you know the config. The config can be initialized first with from_pretrained!

But let's still have the conversion script! It will be beneficial to have a mapping between the names and the config explicitly!
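
As a rough sketch of that mapping (assuming the mamba_ssm config exposes d_model, n_layer, vocab_size and pad_vocab_size_multiple, as its current MambaConfig does; everything not listed is left at the transformers defaults):

from transformers import MambaConfig as HfMambaConfig

def convert_ssm_config(ssm_config):
    # ssm_config: a mamba_ssm MambaConfig (or an object with the same fields).
    vocab_size = ssm_config.vocab_size
    pad_multiple = getattr(ssm_config, "pad_vocab_size_multiple", 8)
    if vocab_size % pad_multiple != 0:
        # mamba_ssm pads the embedding matrix, so the checkpoint's effective
        # vocab size is the padded one (e.g. 50277 -> 50280).
        vocab_size += pad_multiple - vocab_size % pad_multiple
    return HfMambaConfig(
        hidden_size=ssm_config.d_model,
        num_hidden_layers=ssm_config.n_layer,
        vocab_size=vocab_size,
    )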

@ArthurZucker
Collaborator

For 1: yes, a warning should indeed be issued, sorry; we raise errors for mismatched sizes!

@byi8220
Contributor

byi8220 commented Mar 25, 2024

Sounds good, I removed the weight rename from my PR (although now my PR won't actually work until yours is checked in).

@byi8220
Contributor

byi8220 commented Mar 28, 2024

Just a quick plug: I think your PR fixes the checkpointing issue, but PR #29705 is still open for config->config conversion.

@ArthurZucker reopened this Mar 28, 2024
@ArthurZucker
Collaborator

Indeed opening again!
