[Feature request]: Add support for Linux ARM64 #125
Comments
Hi, contributions are welcome. See this initial PR for an attempt: #105. I can help/guide you if you would like to try to add it.
BTW, I verified that
@hawkinsp you're everywhere all of a sudden (in a good way!), and I am even getting a 1-question survey to fill out about you in my edu inbox :P (Relatedly, #105 was in reaction to the same people who invited you to give that talk wanting a ppc build for jax...) We have some options; I will explain the process in more detail, but briefly:
We also have CUDA builds on aarch64 and ppc now, so we could go all out and add those too... but we should probably take care of the CPU ones first 😅
Yes, I think it's a great idea to have Linux CUDA aarch64 builds, at least because of the upcoming Grace CPU (https://www.nvidia.com/en-us/data-center/grace-cpu/), which I'm sure someone will want to use with JAX...
Well... the bot failed, so I'm starting it manually: #127
Just a heads-up: I was able to cross-compile jaxlib for AArch64 easily enough, but the JIT compiler's target detection isn't correct without some upstream TensorFlow changes (tensorflow/tensorflow#57182). So we will not be able to get a working cross-compiled AArch64 build under 0.3.15 as is, and it will need a new
Thank you for the heads-up. Did you use our tooling (conda-forge) or something else for this? What about when you built the native aarch64 version? Did you use our setup here?
I will apply your method here later in the week to see if we can get this sorted. |
And you're correct, we can test if the build time is reasonable. |
Yup, that's what I did. I suspect the
BTW, there is a way to publish PyPI wheels here too. I am not sure whether the core team is okay with that, but if you want, we can build the PyPI wheels here as well. An example is numba, which publishes some of its wheels on anaconda.org: https://anaconda.org/numba/numba/files?type=pypi
I also need Jaxlib for Linux ARM64! |
I don't speak for the conda-forge jaxlib package maintainers, but jaxlib should work fine on ARM64 if you build it from source. So hopefully that can unblock you in the meantime! |
Let's see how #147 pans out (contributions welcome!) |
@hawkinsp here's where it stops in #155:
Link to PPC/arm64 builds: https://app.travis-ci.com/github/conda-forge/jaxlib-feedstock/builds/257816827 Note that we have our own toolchain, which may need a thorough update... I can work on that: https://github.com/conda-forge/bazel-toolchain-feedstock
It's now slightly clearer how we incorporate your customizations (see collapsed code below) from jax-ml/jax#7097 (comment) into our tooling in #157, but I get an error:
Thanks, @hawkinsp! That's true, but AlphaFold uses
and there are only
Just a few lines above in the Dockerfile, they install some more dependencies from conda-forge, but I'm afraid that even if we solve this issue, it won't help, because it won't depend on the correct CUDA version. I hope I am wrong, though!
Sorry for reviving this old issue. With the introduction of linux-aarch64 support in bioconda, my package colabfold should also work on ARM, except for the missing jaxlib on linux-aarch64. Even a build without CUDA would be quite useful to me, as I could point users to install ColabFold through conda on, e.g., a cloud ARM machine for the MSA generation part, and then run the GPU inference separately on a different machine. However, since I can't selectively disable conda dependencies, I would still need jaxlib to be installable on ARM.
I don't speak for the conda-forge maintainers, but upstream we ship a linux aarch64
I would still prefer to provide users with a single conda command for installation, since I have a few dependencies that are not
I am very thankful for all the jaxlib pip variants, though! They are super useful!
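Upstream pip wheels for different architectures are told apart by the platform tag at the end of the wheel filename (the naming scheme from PEP 427, with tags as in PEP 425). As a rough sketch that assumes nothing beyond standard wheel naming, the helper below checks whether a wheel filename targets Linux aarch64. The function names are hypothetical, and the example filenames are made up for illustration; they do not refer to real releases.

```python
from typing import NamedTuple, Optional


class WheelTags(NamedTuple):
    name: str
    version: str
    build: Optional[str]
    python: str
    abi: str
    platform: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split a wheel filename into its PEP 427 fields."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename!r}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:  # no optional build tag
        name, version, python, abi, plat = parts
        build = None
    elif len(parts) == 6:  # includes a build tag
        name, version, build, python, abi, plat = parts
    else:
        raise ValueError(f"unexpected number of fields: {filename!r}")
    return WheelTags(name, version, build, python, abi, plat)


def targets_linux_aarch64(filename: str) -> bool:
    """True if any tag in the (possibly '.'-compressed) platform tag set is Linux aarch64."""
    platform_tags = parse_wheel_filename(filename).platform.split(".")
    return any("linux" in tag and tag.endswith("_aarch64") for tag in platform_tags)


if __name__ == "__main__":
    # Hypothetical filenames, for illustration only.
    print(targets_linux_aarch64("jaxlib-0.4.1-cp310-cp310-manylinux2014_aarch64.whl"))
    print(targets_linux_aarch64("jaxlib-0.4.1-cp310-cp310-manylinux2014_x86_64.whl"))
```

This only inspects names; whether pip will actually accept a given wheel on a machine also depends on the interpreter's supported tag list, which this sketch does not compute.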
The main issue here is that we have currently reached the CI time limit. Cross-compiled builds, e.g. for linux-aarch64, will take even longer. Once this is fixed, we can look into enabling them here.
Am I missing something, or are conda packages for
They have been available for a year: #183
Indeed, that was my understanding, but this was not clear from @milot-mirdita in #125 (comment). Would it make sense to rename the issue to "[Feature request]: Add support for CUDA builds on Linux ARM64"?
Solution to issue cannot be found in the documentation.
Issue
According to https://anaconda.org/conda-forge/jaxlib, the currently supported OS+CPU architectures are:
I'd like to request adding Linux ARM64 to this list. At the moment, the AlphaFold project cannot be used on Linux ARM64 due to a missing jaxlib+CUDA Python wheel: google-deepmind/alphafold#528
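For context, conda-forge publishes packages per platform "subdir", such as `linux-64` or `linux-aarch64`. A minimal standard-library sketch for working out which subdir a given machine needs is below. The mapping table covers only a few common platforms and the function names are hypothetical; this is an illustrative assumption, not conda's own platform-detection logic.

```python
import platform
import sys
from typing import Optional

# (OS, CPU) -> conda subdir. Covers only a few common platforms; an
# illustrative table, not conda's own platform-detection logic.
_CONDA_SUBDIRS = {
    ("linux", "x86_64"): "linux-64",
    ("linux", "aarch64"): "linux-aarch64",
    ("linux", "ppc64le"): "linux-ppc64le",
    ("darwin", "x86_64"): "osx-64",
    ("darwin", "arm64"): "osx-arm64",
    ("win32", "AMD64"): "win-64",
}


def conda_subdir(os_name: str, machine: str) -> Optional[str]:
    """Return the conda subdir for an (OS, CPU) pair, or None if not in the table."""
    return _CONDA_SUBDIRS.get((os_name, machine))


def current_conda_subdir() -> Optional[str]:
    """Best-effort guess at the conda subdir of the running interpreter."""
    os_name = "linux" if sys.platform.startswith("linux") else sys.platform
    return conda_subdir(os_name, platform.machine())


if __name__ == "__main__":
    # On Linux ARM64, platform.machine() reports "aarch64",
    # so this prints "linux-aarch64" there.
    print(current_conda_subdir())
```

This only tells you which subdir to look for; whether jaxlib (and which CUDA variant) is actually published for it still has to be checked on the anaconda.org files page.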
Currently, AlphaFold uses pip3 to install jaxlib from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
It would be nice if it could use conda-forge instead!