diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index a5fea5365..3a14c7a72 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -26,7 +26,7 @@ jax_source_dir() { query_tests() { cd `jax_source_dir` - python build/build.py --configure_only + python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only bazel query tests/... 2>&1 | grep -F '//tests:' exit } @@ -191,5 +191,5 @@ pip install matplotlib ## Run tests cd `jax_source_dir` -python build/build.py --configure_only +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}