Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable pivoted QR on GPU via MAGMA. #25955

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

tttc3
Copy link
Contributor

@tttc3 tttc3 commented Jan 17, 2025

Originally noted in #20282, this commit provides a GPU compatible implementation of geqp3 via MAGMA.
MAGMA implementation is based on @dfm's implementation of eig in ccb3317.

Maybe closes #12897?

To reduce code duplication I've moved AllocateWorkspace from solver_kernels_ffi.cc into ffi_helpers.h.

@dfm dfm self-assigned this Jan 17, 2025
@tttc3 tttc3 force-pushed the magma_qr branch 2 times, most recently from 47dc68d to ac4aeb0 Compare January 18, 2025 10:07
Copy link
Collaborator

@dfm dfm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I've approved so that we can run all the tests, but also left some small inline comments.

tests/magma_linalg_test.py Outdated Show resolved Hide resolved
tests/magma_linalg_test.py Outdated Show resolved Hide resolved
tests/magma_linalg_test.py Outdated Show resolved Hide resolved
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jan 27, 2025
jax/_src/lax/linalg.py Outdated Show resolved Hide resolved
@dfm
Copy link
Collaborator

dfm commented Jan 27, 2025

The TPU failures are unrelated, but can you rebase onto the current main branch after making your edits to fix them? Thanks!

tests/magma_linalg_test.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@dfm dfm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the work on this. I think we're there! One tiny last comment. Then can you rebase onto main (there's a merge conflict with the CHANGELOG)? Then, I can merge. Thanks again!

tests/magma_linalg_test.py Outdated Show resolved Hide resolved
@tttc3
Copy link
Contributor Author

tttc3 commented Jan 29, 2025

MAGMA 2.9.0 was released yesterday and now supports a workspace query for lwork (in addition to an expert interface to geqp3_gpu). The workspace requirements appear unchanged (marked backward compatible) so we can continue to use the manual calculation implemented here (with the added benefit that it will work for MAGMA versions below 2.9.0).

Would you prefer I update this pull to make a workspace query and require magma>=2.9.0 or are you happy to leave it as it stands?

@dfm
Copy link
Collaborator

dfm commented Jan 29, 2025

Would you prefer I update this pull to make a workspace query and require magma>=2.9.0 or are you happy to leave it as it stands?

Let's leave this as is!

So, we're getting quite a few JVP failures in testScipyQrModes with complex64 dtypes in our internal CI, and I can reproduce those when I run myself. The errors seem large enough that they don't seem to just be numerics. Can you try running those tests yourself to see if you can reproduce and debug?

@tttc3
Copy link
Contributor Author

tttc3 commented Jan 29, 2025

So, we're getting quite a few JVP failures in testScipyQrModes with complex64 dtypes in our internal CI, and I can reproduce those when I run myself. The errors seem large enough that they don't seem to just be numerics. Can you try running those tests yourself to see if you can reproduce and debug?

I can reproduce these errors. I'm pretty sure I've implemented the pivot inversion in qr_and_mul incorrectly, if you set inverted_pivots = jnp.argsort(p[0]) instead of inverted_pivots = p[0][p[0]], it should fix any JVP specific issues (I verified this by checking the output of qr_and_mul is indeed the identity, which it turns out it wasn't before).

Originally noted in jax-ml#20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
@tttc3 tttc3 requested a review from dfm January 29, 2025 22:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Column-Pivoted QR Decomposition
3 participants