-
Notifications
You must be signed in to change notification settings - Fork 531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
8-bit LION, take 2 #514
8-bit LION, take 2 #514
Conversation
High level design point Should we make sure that the state dict of Yes this point should have been brought up a LONG time ago. I'm not saying we should do this, just opening the discussion. Also can the state dict of |
@vchiley How would you feel about just making this be DecoupledLionW in a future PR? So we'd have one optimizer with a Asking because this informs whether I should rip out the flag. I like the idea of avoiding redundant code, but maybe we have reasons to not do this. |
@bmosaicml originally wrote lion implementation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incredible work Davis!
@bmosaicml so are you good with "yes, leave the quantization flag and make this the single LION implementation"? |
Adds an 8-bit version of the LION optimizer. Also features 1 byte of (optional) auxiliary error correction state for each parameter to make pure bf16 training work.
Code changes:
lion8b.py
tollm-foundry/optim
DecoupledLionW_8bit
tollm-foundry/optim/__init__.py
lion8b
as an option inllm-foundry/optim/builders.py
test_lion8b.py
to the tests.mosaicml-turbo
to the GPU dependencies insetup.py
. This is the repo that currently holds all the CUDA kernels. These are in a separate repo for now to avoid complicating LLM foundry {install, deps, source code}.master_weight_dtype
field intrain.py
. If set to bf16 or fp16, the script doesmodel.to(dtype=<that dtype>)
before training. This works when we have error correction turned on.config_utils.py
to set FSDP's param_dtype to None if the master weights are already fp16/bf16.Non-obvious design choices:
_fused
arg that's kind of needed for testing and maybe (?) should be part of the API in case someone wants to check whether it's causing issues. I kind of want to get rid of this though once we trust the kernel logic fully.There's enough test coverage here that I'm not super worried about these choices, but wanted to highlight them in case someone has strong opinions.
WandB report