CUDA wheels #88

Open
lgarrison opened this issue Apr 22, 2024 · 2 comments

@lgarrison
Member

Just starting to write down my thoughts on how we could build and distribute GPU wheels.

For background on GPU wheels, this is the best summary of the current state of affairs I've found: https://pypackaging-native.github.io/key-issues/gpus/

It seems there are two broadly different approaches we could take:

  1. bundle CUDA in the wheel, or
  2. use "external" CUDA.

(1) is the traditional method and results in large wheels. (2) is what JAX does, but is pretty cutting-edge and under-documented.

For (1), we would either statically link the CUDA libraries (which I think is what we're currently doing), or link dynamically and let auditwheel copy the libraries into the wheel. There are a few parts of this I still don't understand, such as how it would work with JAX linking against one CUDA runtime while jax-finufft potentially has another. Would that result in two CUDA contexts? Clearly it's already working somehow!
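To see how large approach (1) makes a wheel, and which CUDA libraries actually ended up inside it, the wheel can be opened as a zip archive. This is a minimal sketch; it assumes the grafted libraries keep their usual name prefixes (auditwheel may append a hash, e.g. `libcudart-<hash>.so.12`). The prefix list is a hypothetical starting point, not the project's actual link set.

```python
import zipfile
from pathlib import PurePosixPath

# Hypothetical list of CUDA library name prefixes to look for; adjust to
# whatever the GPU extension actually links against.
CUDA_LIB_PREFIXES = ("libcudart", "libcufft", "libcublas", "libnvrtc")


def bundled_cuda_libs(wheel_path):
    """Return {archive_member: uncompressed_size_in_bytes} for CUDA .so files in a wheel."""
    found = {}
    with zipfile.ZipFile(wheel_path) as whl:
        for info in whl.infolist():
            name = PurePosixPath(info.filename).name
            # Match on the prefix only, since auditwheel may rename grafted
            # libraries with a hash suffix before the ".so".
            if name.startswith(CUDA_LIB_PREFIXES) and ".so" in name:
                found[info.filename] = info.file_size
    return found
```

Summing the returned sizes gives roughly how much of the (uncompressed) wheel is CUDA itself, which is the main cost of bundling.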

With (2), we would use the NVIDIA CUDA wheels on PyPI. I can't find any official documentation on them, but the Python CUDA tech lead did write this nice tutorial in the cuQuantum repo: https://github.com/NVIDIA/cuQuantum/tree/main/extra/demo_build_with_wheels
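As far as I can tell, the NVIDIA wheels install their shared libraries under a common `nvidia/` directory in site-packages (for example, `nvidia-cufft-cu12` puts `libcufft` in `nvidia/cufft/lib/`). A minimal sketch that lists those directories, assuming that layout:

```python
import sysconfig
from pathlib import Path


def nvidia_lib_dirs(site_packages=None):
    """Map each pip-installed NVIDIA component (e.g. "cufft") to its lib/ directory.

    Assumes the nvidia-*-cuXX wheel layout <site-packages>/nvidia/<component>/lib/.
    Defaults to the current interpreter's purelib (site-packages) directory.
    """
    root = Path(site_packages or sysconfig.get_paths()["purelib"]) / "nvidia"
    if not root.is_dir():
        return {}  # no NVIDIA wheels installed
    return {
        comp.name: lib
        for comp in sorted(root.iterdir())
        if (lib := comp / "lib").is_dir()
    }
```

This is mostly useful for checking, in a given environment, which CUDA components pip actually provided before deciding what the rpath should point at.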

It's somewhat hacky, but the basic ideas are clear. We would use some helper scripts to set the rpath so the extension can find the pip-installed CUDA libraries, and pass auditwheel --exclude so that those specific shared libraries are not copied into the wheel. At runtime, the dynamic loader will look for the pip-installed CUDA, falling back to user/system installations if that fails.
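The rpath trick relies on `$ORIGIN`, which the dynamic loader expands to the directory containing the `.so` being loaded, so a relative path from the extension to `site-packages/nvidia/<component>/lib` keeps working wherever the environment lives. A minimal sketch that builds such a string, assuming the extension sits in a package directory directly under site-packages; the component list is hypothetical:

```python
from pathlib import PurePosixPath

# Hypothetical set of NVIDIA wheel components the GPU extension would need.
DEFAULT_COMPONENTS = ("cuda_runtime", "cufft")


def pip_cuda_rpath(ext_subdir, components=DEFAULT_COMPONENTS):
    """Build a colon-separated rpath from an extension module to pip's CUDA libs.

    ext_subdir is the extension's directory relative to site-packages,
    e.g. "jax_finufft". Each entry climbs back up to site-packages and then
    descends into nvidia/<component>/lib.
    """
    up = "../" * len(PurePosixPath(ext_subdir).parts)
    return ":".join(f"$ORIGIN/{up}nvidia/{c}/lib" for c in components)
```

For `pip_cuda_rpath("jax_finufft")` this gives `$ORIGIN/../nvidia/cuda_runtime/lib:$ORIGIN/../nvidia/cufft/lib`. In a CMake build, a string like this could go in the target's `INSTALL_RPATH` property, with the matching sonames passed to auditwheel --exclude.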

Either way, I think the build itself can be done with cibuildwheel, probably just with a yum install of the CUDA development libraries (like this project does: https://github.com/OpenNMT/CTranslate2/blob/master/python/tools/prepare_build_environment_linux.sh).
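The yum step could be generated per CUDA version, for example to feed cibuildwheel's `CIBW_BEFORE_ALL_LINUX`. This is a hypothetical helper: the package names follow the `<name>-<major>-<minor>` pattern used in NVIDIA's RHEL CUDA repository (e.g. `cuda-nvcc-12-4`) and should be checked against that repository, which also has to be enabled in the container first.

```python
# Hypothetical minimal package set: compiler, runtime headers, and cuFFT.
BASE_PACKAGES = ("cuda-nvcc", "cuda-cudart-devel", "libcufft-devel")


def yum_cuda_install_cmd(cuda_version, packages=BASE_PACKAGES):
    """Return a yum command installing CUDA dev packages for a version like "12.4"."""
    major, minor = cuda_version.split(".")[:2]
    suffix = f"{major}-{minor}"  # NVIDIA repo packages use dashes, e.g. 12-4
    return "yum install -y " + " ".join(f"{p}-{suffix}" for p in packages)
```

Keeping the package list minimal matters here, since each CUDA component pulled into the build container adds download time to every CI run.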

For PyPI distribution, thanks to CUDA minor version compatibility, I think we can just do what cupy does and publish jax-finufft-cuda12x and jax-finufft-cuda11x (if we want to support CUDA 11), with no need for a custom package index URL. With (2), we would use [cuda_local] and [cuda_pip] extras. I don't think we need a full matrix of cuDNN versions like JAX does, but I could be wrong.
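A sketch of the cupy-style scheme: one distribution per CUDA major version, plus the two extras for approach (2). The distribution and extra names come from the proposal above; the exact dependency lists are illustrative assumptions, using NVIDIA's `-cuXX` wheel names on PyPI.

```python
SUPPORTED_MAJORS = (11, 12)


def dist_name(cuda_version):
    """Map a CUDA toolkit version like "12.4" to its per-major distribution name."""
    major = int(cuda_version.split(".")[0])
    if major not in SUPPORTED_MAJORS:
        raise ValueError(f"unsupported CUDA major version: {major}")
    # Minor version compatibility lets one build serve every 12.x (or 11.x) install.
    return f"jax-finufft-cuda{major}x"


def extras(cuda_major):
    """Illustrative extras_require mapping for approach (2)."""
    suffix = f"cu{cuda_major}"
    return {
        # Pull the CUDA libraries from PyPI.
        "cuda_pip": [f"nvidia-cuda-runtime-{suffix}", f"nvidia-cufft-{suffix}"],
        # Rely on a user/system CUDA installation: nothing extra to install.
        "cuda_local": [],
    }
```

A user would then run something like `pip install jax-finufft-cuda12x[cuda_pip]`, mirroring how JAX's extras select where CUDA comes from.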

This is all a bit experimental, since this isn't a "vanilla" CUDA extension but one that has to work with JAX! For that reason, (2) seems more appealing: it seems more likely to pick up the same CUDA libraries that JAX uses.
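One way to check whether jax-finufft and JAX end up sharing a CUDA runtime is to look at which library files are mapped into the process. A Linux-only sketch, assuming the relevant libraries are `libcudart` and `libcufft` (including auditwheel-renamed copies); it parses the text of `/proc/self/maps`:

```python
import re
from pathlib import PurePosixPath

# Matches e.g. libcudart.so.12 or an auditwheel-renamed libcudart-1a2b3c4d.so.12,
# but not libcufftw.so (assumed naming; adjust as needed).
_CUDA_LIB = re.compile(r"^(libcudart|libcufft)(-[0-9a-f]+)?\.so")


def loaded_cuda_libs(maps_text):
    """Group the paths of mapped CUDA libraries by library name.

    More than one distinct path for the same library means two copies
    of it are loaded in the process.
    """
    found = {}
    for line in maps_text.splitlines():
        fields = line.split()
        if len(fields) < 6:
            continue  # anonymous mapping with no backing file
        path = " ".join(fields[5:])
        m = _CUDA_LIB.match(PurePosixPath(path).name)
        if m:
            found.setdefault(m.group(1), set()).add(path)
    return found
```

After importing JAX and the GPU extension and running a GPU computation, `loaded_cuda_libs(open("/proc/self/maps").read())` shows whether, say, `libcudart` was loaded once (shared) or twice (one copy per package).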

@f0uriest

f0uriest commented Oct 7, 2024

Is there any update on this? We'd love to use this in DESC (PlasmaControl/DESC#1294) but having to build from source to get GPU support is limiting, as all of our other dependencies can be handled with pip install ...

@lgarrison
Member Author

Sorry, not yet. I'll update this space if/when there's anything to share!
