Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

register custom op for flash attn and use from torch.ops #7536

Merged
merged 18 commits into from
Aug 16, 2024

Conversation

youkaichao
Copy link
Member

No description provided.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache


@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confirmed with @WoosukKwon , these two functions do not mutate the input.

cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=("test_faketensor", ))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can add test_schema later, after solving the OOM issue.

currently, I will get OOM when I test the schema, even though I'm using H100 80GB.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think part of testing the schema involves copying all the inputs and doubling checking that they are (or aren't) mutated in agreement with the op schema. I don't know if that's what is causing the OOMs here tho.

Copy link
Contributor

@bnellnm bnellnm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

@youkaichao youkaichao merged commit 54bd9a0 into vllm-project:main Aug 16, 2024
27 of 29 checks passed
@youkaichao youkaichao deleted the fa_registration branch August 16, 2024 05:38
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants