Add custom ops for compatibility with PT Compile #1139

Merged: 13 commits into Dao-AILab:main on Sep 18, 2024

Conversation

@ani300
Contributor

ani300 commented Aug 8, 2024

This PR adds basic support for torch.compile() for the non-varlen variants of Flash Attention.

This essentially allows models that use flash_attn_qkvpacked_func, flash_attn_kvpacked_func, and flash_attn_func to compile without graph breaks.

I can add unit tests if that makes sense. I'll test it with our own training pipeline for performance measurements and post the results later.

This uses the new custom operators API in PyTorch 2.4. I can move to the older APIs if needed, or look into making both coexist.
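
For readers new to the PyTorch 2.4 custom operators API, here is a minimal sketch of the wrapping pattern (the op name "flash_attn_demo::fwd" and the SDPA stand-in kernel are illustrative, not this PR's actual code): the real implementation is registered for CUDA, and a separate fake implementation reports output shapes and dtypes so torch.compile can trace the call without stepping into the C++ extension.

import torch

# Illustrative only: the real op would call the compiled flash-attn CUDA kernel.
@torch.library.custom_op("flash_attn_demo::fwd", mutates_args=(), device_types="cuda")
def demo_fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stand-in for the pybind'd kernel call; must return a fresh tensor.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

@torch.library.register_fake("flash_attn_demo::fwd")
def _(q, k, v):
    # The fake (meta) implementation only describes output metadata, which is
    # all torch.compile needs to trace through the op without a graph break.
    return torch.empty_like(q)

Registering the forward and backward calls this way is what lets the flash_attn_*_func wrappers compile end to end.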

@ani300
Contributor Author

ani300 commented Aug 8, 2024

Added varlen functions too

@tridao
Contributor

tridao commented Aug 8, 2024

Thank you! If this requires torch 2.4, is there some way to make it optional for older PyTorch versions? E.g., define torch.library.custom_op to be a no-op? I don't know what people usually do.

@ani300
Contributor Author

ani300 commented Aug 8, 2024

There's a different API that was used in previous PyTorch versions. It's more cumbersome to use, but let me think about how to use both APIs at the same time with a version selection. I can also do a no-op wrapper based on the version.

@ani300
Contributor Author

ani300 commented Aug 9, 2024

OK, I've updated the code to only wrap things if PyTorch is 2.4 or higher. Even though we're comparing strings, torch.__version__ implements semantic version comparisons, so it should always work correctly. As soon as I have the training performance comparison I'll update the PR.
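
For reference, a tiny check of that claim (not from the PR): torch.__version__ is a TorchVersion object, so comparing it against a plain string follows version semantics rather than lexicographic string order.

import torch

# TorchVersion overrides comparisons, so e.g. "2.10" still compares as newer
# than "2.4.0" even though it is "smaller" as a plain string.
if torch.__version__ >= "2.4.0":
    print("using torch.library.custom_op / register_fake")
else:
    print("falling back to no-op wrappers")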

@mayank31398

Aah, yikes.
@ani300 I had started working on the same thing in #1145 😓

I'll let you handle this 😃

@ani300
Contributor Author

ani300 commented Aug 13, 2024

@tridao I've updated the code after testing it with SFTTrainer to make sure everything was getting called correctly and compiling without graph breaks. I've also added a check for the PyTorch version so this only activates if the APIs are available.

One question I have is whether it's OK to change the C++ PyTorch interface so the out/out_padded padding/unpadding can be done outside the custom op. Right now, PyTorch custom ops don't allow outputs to be aliases of other outputs, which this pair is when head_dim % 8 == 0. This means I have to clone one of the outputs to unalias them, which wastes memory bandwidth.
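
A small illustration of the aliasing constraint being described, with made-up shapes (not the PR's code): when head_dim is already a multiple of 8, the "unpadded" output is just a view of the padded one, and returning both from a custom op requires a clone.

import torch

head_dim = 64                                   # already a multiple of 8
out_padded = torch.randn(2, 128, 8, head_dim)   # hypothetical kernel output
out = out_padded[..., :head_dim]                # full-width slice: a view, i.e. an alias
# torch.library.custom_op forbids one output aliasing another, so before
# returning (out, out_padded) one of them has to be cloned first; that clone
# is the extra memory traffic mentioned above.
out = out.clone()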

@tridao
Contributor

tridao commented Aug 14, 2024

Yes let's change the C++ code to not pad the output, and have the Python part handle that.

@GLivshits

There's a different API that was used in previous PyTorch versions. It's more cumbersome to use, but let me think about how to use both APIs at the same time with a version selection. I can also do a no-op wrapper based on the version.

For older versions, you can use:

1. torch.library.define - just for the definition of the ops.
2. Then use something like this:

from functools import partial
from packaging import version
import torch

OLD_TORCH_VERSION = version.parse(torch.__version__).base_version < "2.4.0"
if OLD_TORCH_VERSION:
    _torch_custom_op_wrapper = partial(torch.library.impl, types="cuda")
    _torch_register_fake_wrapper = torch.library.impl_abstract
else:
    _torch_custom_op_wrapper = partial(torch.library.custom_op, device_types="cuda", mutates_args=())
    _torch_register_fake_wrapper = torch.library.register_fake

Here I assume that mutates_args is empty. If it is not, I don't currently know how to handle that (and I didn't think much about it) without copy-pasting lots of code.

@mayank31398

@GLivshits I don't think it can be handled in older versions of torch.

@mayank31398

@tridao @ani300 Is there any progress/update on this?
It's a pretty neat feature to have Flash Attention fully traceable end-to-end natively.

@ani300
Contributor Author

ani300 commented Aug 24, 2024

Hey, I'm almost done with the C++ portion; hopefully by Monday I can update the PR and rerun all the CI to make sure it's still good with all the changes. I'll also try the suggestions from @GLivshits to make it work with previous versions.

@ani300
Contributor Author

ani300 commented Aug 27, 2024

@GLivshits The way the code is currently written, the backward functions need to declare mutates_args, or we need to do an extra memory copy on the GPU, which might significantly impact performance.
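
For context, a hedged sketch of what declaring mutated arguments looks like under the 2.4 API (hypothetical op name and a stand-in body, not the PR's backward; assumes a CUDA device): if the kernel writes gradients into preallocated buffers, those parameters go into mutates_args so the compiler knows about the in-place writes.

import torch

@torch.library.custom_op("flash_attn_demo::bwd", mutates_args=("dq",), device_types="cuda")
def demo_bwd(dout: torch.Tensor, q: torch.Tensor, dq: torch.Tensor) -> None:
    # Stand-in for the real backward kernel, which writes into the
    # preallocated gradient buffer instead of allocating a new one.
    dq.copy_(dout * q)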

@ani300
Contributor Author

ani300 commented Aug 27, 2024

@tridao This is ready for final review and/or merging. I tested with the whole unit test suite for both PyTorch 2.3 and PyTorch 2.4 on an A100.

@raghukiran1224

@tridao Any update on this?

@anijain2305

Thanks @ani300 for taking this on. I had started working on this in #1209, but this PR is way further along, so I closed mine.

cc @zou3519 for custom ops API usage review, and @Chillee

_torch_register_fake_wrapper = register_fake_wrapper


@_torch_custom_op_wrapper("flashattn::_flash_attn_forward", mutates_args=(), device_types="cuda")


Nit - s/flashattn/flash_attn/

return out, softmax_lse, p, rng_state


try:


Perhaps use the same if condition

if torch.__version__ >= "2.4.0":

Comment on lines 66 to 68
def inner(*args, **kwargs):
return func(*args, **kwargs)
return inner


Is this equivalent to return func?

Comment on lines 74 to 76
def inner(*args, **kwargs):
return func(*args, **kwargs)
return inner


Ditto - return func

_torch_custom_op_wrapper = torch.library.custom_op
_torch_register_fake_wrapper = torch.library.register_fake
else:
def custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):


Nit - a more descriptive name: noop_custom_op_wrapper.

And maybe a comment that we don't support versions < 2.4.

return inner
if fn is None:
return wrap
return wrap(fn)


return fn
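
Putting the review suggestions above together, the torch < 2.4 fallback ends up being roughly a pass-through decorator. This is a sketch using the noop_ naming suggested in the review, not the exact merged code:

# We don't support torch < 2.4 for torch.compile; these wrappers simply return
# the function unchanged so importing flash_attn still works on old versions.
def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
    def wrap(func):
        return func
    if fn is None:
        return wrap        # decorator-factory form: @noop_custom_op_wrapper("lib::op", mutates_args=())
    return fn              # direct form: noop_custom_op_wrapper("lib::op", fn, mutates_args=())

def noop_register_fake_wrapper(op_name, fn=None, /):
    def wrap(func):
        return func
    if fn is None:
        return wrap
    return fn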

@ani300
Contributor Author

ani300 commented Sep 11, 2024

@anijain2305 I just pushed all the updates/fixes you suggested. I'm running CI again to make sure it works for both 2.4 and 2.3, and will update the comments when it's done.

@anijain2305

Thanks @ani300

The custom ops changes look good to me. I hope @tridao can take a look at the other parts.

@ani300
Contributor Author

ani300 commented Sep 11, 2024

Tests all ran successfully again

tridao merged commit 83e41b3 into Dao-AILab:main on Sep 18, 2024
@tridao
Contributor

tridao commented Sep 18, 2024

Awesome, thanks!

@umarbutler

I just built the flash-attention main branch from source and sadly I'm getting the same error message during training with the Hugging Face Trainer with torch_compile=True:

.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:663: UserWarning: Graph break due to unsupported builtin flash_attn_2_cuda.PyCapsule.fwd. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
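
As an aside, a minimal way to check whether a given flash-attn build still graph-breaks (illustrative, assuming a CUDA device and flash-attn installed from a commit that includes this PR): compiling with fullgraph=True makes torch.compile raise on any graph break instead of silently splitting the graph.

import torch
from flash_attn import flash_attn_func

def attn(q, k, v):
    return flash_attn_func(q, k, v, causal=True)

# (batch, seqlen, nheads, headdim) in fp16 on GPU
q, k, v = (torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))
compiled = torch.compile(attn, fullgraph=True)  # raises if anything graph-breaks
out = compiled(q, k, v)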

@umarbutler

It looks like I somehow ran python setup.py bdist_wheel and forgot I was creating a wheel (which I also ended up mistakenly deleting moments ago) 😅 Hopefully second time's the charm.

@ani300
Contributor Author

ani300 commented Sep 26, 2024

I was worried for a second, as I had just used it successfully with the Hugging Face Trainer yesterday 😥

@umarbutler

@ani300 I can confirm that it works! And it's much faster than PyTorch SDPA; it shaved off 50 training hours (400h -> 350h) 😆
