Add Mamba model #28086
Comments
Thanks for opening this issue! Given the sensitivity of this model, the HF team will take it over; we'll have a look at your fork and add you as a co-author 🤗
Thanks a lot! My fork is largely inspired by the original Mamba repo, with the differences mostly consisting of boilerplate code, so don’t hesitate to start from the upstream repo. I (and the linter) have noticed a couple of bugs and pieces of dead code upstream (some of which remain in my fork), so keep an eye out for them!
I did a similar study: https://github.com/LegallyCoder/mamba-hf
I've seen a CPU-only implementation fork mentioned somewhere in the source repo issues. The author of the fork removed the Triton and CUDA dependencies. Found it: https://github.com/kroggen/mamba-cpu
Model description
Mamba is a new architecture proposed in arXiv:2312.00752 by Albert Gu (CMU) and Tri Dao (Princeton).
It is inspired by structured state space models (SSMs), but with the addition of a selection mechanism that allows it to combine the ability of transformers to perform content-based reasoning with the performance of SSMs on long sequences. Mamba can be trained efficiently in parallel while also enjoying efficient inference by running recurrently.
The paper claims SoTA performance on various modalities, with performance tested up to 2.8B parameters. Crucially, the model cannot be implemented efficiently using only PyTorch operations; instead, it relies on optimised CUDA and Triton kernels. The original implementation by the authors is available at https://github.com/state-spaces/mamba/tree/main under an Apache 2.0 license.
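For intuition, here is a minimal, naive PyTorch sketch of the kind of selective recurrence involved (the function name, tensor names and shapes are my own assumptions, not the reference implementation; it is exactly this per-timestep loop that the fused CUDA/Triton kernels replace with a hardware-aware scan):

```python
import torch

def selective_ssm(x, A, delta, B, C):
    """Naive selective state-space recurrence (illustration only, not the fused kernel).

    x:     (batch, length, d_inner)   input sequence
    A:     (d_inner, d_state)         state transition matrix
    delta: (batch, length, d_inner)   input-dependent step sizes
    B, C:  (batch, length, d_state)   input-dependent projections (the "selection")
    """
    batch, length, d_inner = x.shape
    d_state = A.shape[-1]
    h = x.new_zeros(batch, d_inner, d_state)                  # recurrent state
    ys = []
    for t in range(length):
        # Zero-order-hold discretisation of (A, B) with the input-dependent step size
        dA = torch.exp(delta[:, t, :, None] * A)              # (batch, d_inner, d_state)
        dB = delta[:, t, :, None] * B[:, t, None, :]          # (batch, d_inner, d_state)
        h = dA * h + dB * x[:, t, :, None]                    # state update
        ys.append((h * C[:, t, None, :]).sum(-1))             # per-channel readout
    return torch.stack(ys, dim=1)                             # (batch, length, d_inner)

# Toy usage with made-up dimensions
batch, length, d_inner, d_state = 2, 16, 64, 8
x = torch.randn(batch, length, d_inner)
A = -torch.rand(d_inner, d_state)                             # negative values for stability
delta = torch.rand(batch, length, d_inner)
B = torch.randn(batch, length, d_state)
C = torch.randn(batch, length, d_state)
y = selective_ssm(x, A, delta, B, C)                          # (2, 16, 64)
```

In the actual model, this computation is parallelised across the sequence during training and reduced to a single state update per token during inference, which is where the custom kernels come in.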
Starting from their implementation, I have started porting the model to 🤗 Transformers. This is work in progress 🚧, and can be found in my fork at https://github.com/JLTastet/transformers/tree/mamba.
I can open a PR, but in its current state my branch is not ready to be merged. I will also open an issue in the original repo to let the authors know about this, in case they want to chime in.
What I got working:
- Loading the model, including via AutoModel.

What still needs some work:
- There are some issues when training with Trainer, and I still don’t understand what causes them.
- When calling generate, we should check that the optimised recurrent inference is used instead of the slower autoregressive inference (see the usage sketch below).

I am opening this issue to avoid duplicating work, since I saw some mention of Mamba today by @ArthurZucker.
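For the generate point above, the kind of check I have in mind looks roughly like this (the local checkpoint path and the assumption that the fork registers Mamba with the Auto classes are mine, not something documented in the branch):

```python
# Sketch only: assumes the fork registers Mamba with the Auto classes and that a
# converted checkpoint is available locally at ./mamba-converted (both assumptions).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./mamba-converted")
model = AutoModelForCausalLM.from_pretrained("./mamba-converted")

inputs = tokenizer("Mamba is a state space model that", return_tensors="pt")
# use_cache=True is the usual switch for reusing state between decoding steps; the
# thing to verify is that this path performs a single recurrent state update per new
# token instead of re-running the whole prefix at every step.
outputs = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```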
My main motivation for porting this model is to learn a bit more about it (and about the internals of 🤗 Transformers) and to run more evals. Some of you probably know this library much better than me, so feel free to write your own implementation if you can do it better or quicker. Otherwise, don’t hesitate to build on top of my fork.
Open source status
Provide useful links for the implementation