-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jetson (aarch64) support #724
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: Aaron Gokaslan <[email protected]>
Co-authored-by: Aaron Gokaslan <[email protected]>
Co-authored-by: Aaron Gokaslan <[email protected]>
Hello, thanks for the amazing job. I installed flash attention from source using your committed setup.py with commit hash 0097ec4 for jetson Orin. The installation was completed without error and I can successfully import it in python. However, it returns all Failures when I run the unit tests with test_flash_attn.py. I don' t know if this is normal? Do we have other ways to test/check if flash attention works on Orin? Thank you. |
Which version of Jetpack you're using? I just tried on JP 6.0 DP |
I'm waiting for the JP 6.0 production release, I guess we just need to let the |
Hello, Thank you for the quick reply. Mine is JP 5.1.2. Couple months ago, just back to the moment of release of flash attention 2, I tried to install it with setting compute_87 or sm_87 but both attmpts were failed with the same JP. Do you have any ideas about what's wrong here? Thank you again. Best regards, |
I haven't tried on 5.1.x. I guess the reason is that the CUDA is too old. I have to upgrade to 6.0 because Ubuntu 18.04, CUDA 11.4, and Python 3.7 are too old to run recent versions of LLM and Stable Diffusion |
Hello again, I upgraded Orin to JP 6.0 DP today and tried to install flash_attn 2 again with your fork (branch aarch64). The upgraded JP eventually did not help for the correct installation. I noticed that the CUDA gencode was with compute_90 and sm_90 while compiling instead of 87 for Orin. Could you please share more info how you install the package from source? Thank you. |
I don't want to make it complex (Jetson isn't popular) so the PR actually introduces an env You can use this command
|
I refactored
setup.py
to make it work on my Jetson AGX Orin, I think it also helps for future ARM + GPU platformsI don't want to make it complex so I just allow to set CUDA gencode from ENV, Jetson is compute_87, sm_87