Add custom ops for compatibility with PT Compile #1139
Conversation
Added varlen functions too
Thank you! If this requires torch 2.4, is there some way to make it optional for older PyTorch versions? e.g. define
There's a different API that was used in previous PyTorch versions. It's more cumbersome to use, but let me think about how to use both APIs at the same time with a version selection. I can also do a no-op wrapper based on the version.
Ok, I've updated the code to only wrap things if PyTorch is 2.4 or higher. Even though we're comparing strings,
@tridao I've updated the code after testing it with SFTTrainer to make sure everything was getting called correctly and compiling without graph breaks. I've also added a check for the PyTorch version so this only activates if the APIs are available. One question I have is whether it's OK to change the C++ PyTorch interface so the
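As an aside, a version gate that doesn't rely on raw string ordering could look roughly like this. This is a minimal sketch assuming the packaging module is available; the names here are illustrative and not the PR's actual code:

```python
import torch
from packaging import version  # assumed available; commonly installed alongside pip/setuptools

# torch.__version__ may carry suffixes like "2.4.0+cu121", so compare the
# parsed release tuple rather than the raw string.
_IS_TORCH_2_4_PLUS = version.parse(torch.__version__).release >= (2, 4)

if _IS_TORCH_2_4_PLUS:
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
```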
Yes let's change the C++ code to not pad the output, and have the Python part handle that. |
For older versions, you can use:
Here I assume that mutates_args is empty. If it is not, I don't currently know how to handle that (and I didn't think about it much) without copy-pasting lots of code.
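One possible shape for such a pre-2.4 fallback, under the assumption that mutates_args is empty, is the older torch.library.Library / impl_abstract API. This is an illustrative sketch; the helper name below is made up and is not code from this thread:

```python
import torch

# Illustrative pre-2.4 registration path (a sketch, not the code from this PR).
# Assumes the op does not mutate any of its inputs (i.e. mutates_args is empty).
_lib = torch.library.Library("flash_attn", "FRAGMENT")

def register_cuda_op(schema, fn):
    """Define a schema such as "my_op(Tensor q, Tensor k, Tensor v) -> Tensor"
    and register fn as its CUDA implementation; returns the bound op."""
    name = schema.split("(")[0]
    _lib.define(schema)
    _lib.impl(name, fn, "CUDA")
    return getattr(torch.ops.flash_attn, name)

# The pre-2.4 spelling of register_fake is impl_abstract, e.g.:
#   @torch.library.impl_abstract("flash_attn::my_op")
#   def my_op_fake(q, k, v): ...
```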
@GLivshits I don't think it can be handled in older versions of torch
Hey, I'm almost done with the C++ portion, hopefully by Monday I can update the PR and rerun all the CI to make sure it's still good with all the changes. I'll also try the suggestions from @GLivshits to make it work with previous versions |
@GLivshits the way the code is currently written, the backward functions need to have
@tridao this is ready for final review and/or merging. I tested with the whole unit test suite for both PyTorch 2.3 and PyTorch 2.4 on an A100.
@tridao Any update on this?
flash_attn/flash_attn_interface.py
Outdated
_torch_register_fake_wrapper = register_fake_wrapper
@_torch_custom_op_wrapper("flashattn::_flash_attn_forward", mutates_args=(), device_types="cuda")
Nit - s/flashattn/flash_attn/
flash_attn/flash_attn_interface.py
Outdated
return out, softmax_lse, p, rng_state
try:
Perhaps use the same if condition
if torch.__version__ >= "2.4.0":
flash_attn/flash_attn_interface.py
Outdated
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner
Is this equivalent to return func?
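For a pure pass-through the two spellings are interchangeable in behavior; here is a minimal standalone illustration, unrelated to the PR's actual functions:

```python
# A no-op decorator can simply hand back the original function. Wrapping it in
# an inner(*args, **kwargs) forwarder behaves the same when called, at the cost
# of an extra call frame and a changed __name__/identity.
def noop_decorator(func):
    return func

@noop_decorator
def add(a, b):
    return a + b

assert add(1, 2) == 3
```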
flash_attn/flash_attn_interface.py
Outdated
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner
Ditto - return func
flash_attn/flash_attn_interface.py
Outdated
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    def custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
Nit - a more descriptive name - noop_custom_op_wrapper
And maybe a comment that we don't support < 2.4 versions.
flash_attn/flash_attn_interface.py
Outdated
        return inner
    if fn is None:
        return wrap
    return wrap(fn)
return fn
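Taken together, these review suggestions point at something like the following pre-2.4 fallback. This is a sketch under the assumption that below PyTorch 2.4 the wrappers are pure pass-throughs, not necessarily the final code merged in the PR:

```python
import torch

if torch.__version__ >= "2.4.0":
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    # torch.compile custom-op registration is not supported below PyTorch 2.4,
    # so these wrappers simply return the function unchanged.
    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn

    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn

    _torch_custom_op_wrapper = noop_custom_op_wrapper
    _torch_register_fake_wrapper = noop_register_fake_wrapper
```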
@anijain2305 just pushed all the updates/fixes you suggested. I am running CI again to make sure it works for both 2.4 and 2.3. Will update the comments when it's done.
Tests all ran successfully again
Awesome, thanks!
I just built the flash-attention main branch from source and sadly I'm getting the same error message during training with the Hugging Face Trainer with
It looks like I somehow ran
I was worried for a second, as I just used it with the Hugging Face Trainer yesterday successfully 😥
@ani300 I can confirm that it works! And much faster than PyTorch SDPA, it shaved off 50 training hours (400h -> 350h) 😆
This PR adds basic support for torch.compile() for the non-varlen variants of Flash Attention.
This essentially allows models that use flash_attn_qkvpacked_func, flash_attn_kvpacked_func, and flash_attn_func to compile without graph breaks. I can add unit tests if it makes sense. I'll test it with our own training pipeline for performance measurements and post them later.
This uses the new custom operators API in PyTorch 2.4. I can move to the older APIs if needed, or look into how to make both coexist.
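As a rough usage sketch of what this enables (assuming an Ampere-or-newer GPU with flash-attn installed; shapes and dtypes below are illustrative):

```python
import torch
from flash_attn import flash_attn_func

class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # q, k, v: (batch, seqlen, nheads, headdim), fp16/bf16 tensors on CUDA
        return flash_attn_func(q, k, v, causal=True)

model = TinyAttention().cuda()
# fullgraph=True would previously hit a graph break inside flash_attn_func;
# with the custom-op registration it should compile end to end.
compiled = torch.compile(model, fullgraph=True)

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = compiled(q, k, v)
print(out.shape)  # torch.Size([2, 1024, 8, 64])
```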