Skip to content

Commit

Permalink
New JAX build from Google (#1172)
Browse files Browse the repository at this point in the history
  • Loading branch information
DwarKapex authored Nov 26, 2024
1 parent b0e6753 commit c07ffe4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
36 changes: 18 additions & 18 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -274,34 +274,34 @@ if [[ ! -e "/usr/local/cuda/lib" ]]; then
fi

if ! grep 'try-import %workspace%/.local_cuda.bazelrc' "${SRC_PATH_JAX}/.bazelrc"; then
echo -e '\ntry-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc"
echo -e '\ntry-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc"
fi
cat > "${SRC_PATH_JAX}/.local_cuda.bazelrc" << EOF
build:cuda --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda"
build:cuda --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn"
build:cuda --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl"
build --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda"
build --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn"
build --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl"
EOF
time python "${SRC_PATH_JAX}/build/build.py" \

pushd ${SRC_PATH_JAX}
time python "${SRC_PATH_JAX}/build/build.py" build \
--editable \
--use_clang \
--enable_cuda \
--build_gpu_plugin \
--gpu_plugin_cuda_version=$TF_CUDA_MAJOR_VERSION \
--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt \
--cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \
--enable_nccl=true \
--bazel_options=--linkopt=-fuse-ld=lld \
--bazel_options=--override_repository=xla=$SRC_PATH_XLA \
--local_xla_path=$SRC_PATH_XLA \
--output_path=${BUILD_PATH_JAXLIB} \
$BUILD_PARAM
popd

# Make sure that JAX depends on the local jaxlib installation
# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
line="jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib"
if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then
pushd "${SRC_PATH_JAX}"
echo "${line}" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_gpu_plugin" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax-cuda-pjrt" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax-cuda-plugin" >> build/requirements.in
PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))')
bazel run --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
popd
Expand All @@ -316,13 +316,13 @@ else
fi

# install jax and jaxlib
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin -e "${SRC_PATH_JAX}"
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin -e "${SRC_PATH_JAX}"

# after installation (example)
# jax 0.4.32.dev20240808+9c2caedab /opt/jax
# jax-cuda12-pjrt 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_pjrt
# jax-cuda12-plugin 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_plugin
# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib
## after installation (example)
# jax 0.4.36.dev20241125+f828f2d7d /opt/jax
# jax-cuda12-pjrt 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-plugin
# jaxlib 0.4.36.dev20241125 /opt/jaxlibs/jaxlib
pip list | grep jax

# Ensure directories are readable by all for non-root users
Expand Down
4 changes: 2 additions & 2 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jax_source_dir() {

query_tests() {
cd `jax_source_dir`
python build/build.py --configure_only
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only
bazel query tests/... 2>&1 | grep -F '//tests:'
exit
}
Expand Down Expand Up @@ -191,5 +191,5 @@ pip install matplotlib
## Run tests

cd `jax_source_dir`
python build/build.py --configure_only
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only
bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}

0 comments on commit c07ffe4

Please sign in to comment.