diff --git a/.github/action/Dockerfile b/.github/action/Dockerfile index 9279506..2b9f297 100644 --- a/.github/action/Dockerfile +++ b/.github/action/Dockerfile @@ -1,10 +1,10 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 RUN apt-get update && \ - DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip cmake + DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip RUN pip install --upgrade pip && \ - pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html COPY entrypoint.sh /entrypoint.sh