diff --git a/.github/eks-workflow-files/job.yml b/.github/eks-workflow-files/job.yml new file mode 100644 index 000000000..463f0ee31 --- /dev/null +++ b/.github/eks-workflow-files/job.yml @@ -0,0 +1,80 @@ +apiVersion: v1 +kind: Service +metadata: + name: jax-headless-svc +spec: + clusterIP: None # clusterIP must be None to create a headless service + selector: + job-name: PLACEHOLDER # must match Job name +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: PLACEHOLDER +spec: + completions: 2 # number of nodes + parallelism: 2 # number of nodes + completionMode: Indexed + template: + spec: + subdomain: jax-headless-svc # has to match Service name + restartPolicy: Never + containers: + - name: jax + image: PLACEHOLDER + ports: + - containerPort: 3389 + command: + - sh + - -c + - | + install-efa.sh + nsys-jax \ + --output=/opt/output/${JOB_NAME}-rank%q{JOB_COMPLETION_INDEX}.zip \ + -- \ + jax-nccl-test \ + --coordinator-address \ + ${JOB_NAME}-0.jax-headless-svc:3389 \ + --distributed \ + --gpus-per-process=8 \ + --process-count=2 \ + --process-id=$JOB_COMPLETION_INDEX + touch /opt/output/.done + env: + - name: JOB_NAME + value: PLACEHOLDER + - name: XLA_FLAGS + value: --xla_gpu_enable_command_buffer= + resources: + limits: + nvidia.com/gpu: 8 + vpc.amazonaws.com/efa: 32 + volumeMounts: + - mountPath: /dev/shm + name: shmem + - mountPath: /opt/output + name: output + - name: upload + image: amazon/aws-cli + command: + - sh + - -c + - | + while [[ ! -f /opt/output/.done ]]; do + sleep 1 + done + aws s3 cp \ + /opt/output/*rank${JOB_COMPLETION_INDEX}.zip \ + s3://jax-toolbox-eks-output/ + volumeMounts: + - mountPath: /opt/output + name: output + imagePullSecrets: + - name: PLACEHOLDER + volumes: + - name: output + emptyDir: {} + - name: shmem + emptyDir: + medium: Memory + sizeLimit: 8Gi diff --git a/.github/eks-workflow-files/post-process-job.yml b/.github/eks-workflow-files/post-process-job.yml new file mode 100644 index 000000000..989ddebe2 --- /dev/null +++ b/.github/eks-workflow-files/post-process-job.yml @@ -0,0 +1,46 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: PLACEHOLDER +spec: + template: + spec: + restartPolicy: Never + initContainers: + - name: download + image: amazon/aws-cli + command: + - aws + - s3 + - cp + - --recursive + - --exclude + - "*" + - --include + - PLACEHOLDER + - s3://jax-toolbox-eks-output/ + - /opt/output + volumeMounts: + - mountPath: /opt/output + name: output + containers: + - name: jax + image: PLACEHOLDER + command: + - bash + - -exo + - pipefail + - -c + - nsys-jax-combine -o /opt/output/combined.zip /opt/output/*.zip --analysis communication + # FIXME: GPU not actually needed, but the test cluster doesn't have appropriate non-GPU nodes + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - mountPath: /opt/output + name: output + imagePullSecrets: + - name: PLACEHOLDER + volumes: + - name: output + emptyDir: {} diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 2ca5d1a1e..a5f685ae1 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -440,6 +440,84 @@ jobs: popd done + test-nsys-jax-eks: + needs: build-jax + if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + runs-on: eks + env: + JAX_DOCKER_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} + JOB_NAME: ${{ github.run_id }}-${{ github.run_attempt }}-jax + POSTPROCESS_JOB_NAME: ${{ github.run_id }}-${{ github.run_attempt }}-postprocess + TOKEN_NAME: ${{ github.run_id }}-${{ github.run_attempt }}-token + steps: + - name: Check out the repository + uses: actions/checkout@v4 + - name: Install yq + run: | + mkdir local_bin/ + curl -L -o ./local_bin/yq https://github.com/mikefarah/yq/releases/latest/download/yq_linux_$(dpkg --print-architecture) + chmod 777 ./local_bin/yq + echo "${PWD}/local_bin" >> "${GITHUB_PATH}" + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Store GitHub Container Registry token as Kubernetes secret + run: | + kubectl create secret generic \ + ${{ github.run_id }}-${{ github.run_attempt }}-token \ + --from-file=.dockerconfigjson=$HOME/.docker/config.json \ + --type=kubernetes.io/dockerconfigjson + - name: Configure Kubernetes job + run: | + yq -i ea 'select(di == 0).spec.selector.job-name = strenv(JOB_NAME) + | select(di == 1).metadata.name = strenv(JOB_NAME) + | select(di == 1).spec.template.spec.imagePullSecrets[].name = strenv(TOKEN_NAME) + | select(di == 1).spec.template.spec.containers[0].image = strenv(JAX_DOCKER_IMAGE) + | select(di == 1).spec.template.spec.containers[0].env[0].value = strenv(JOB_NAME)' \ + .github/eks-workflow-files/job.yml + git diff .github/eks-workflow-files/job.yml + - name: Submit Kubernetes job + run: kubectl apply -f .github/eks-workflow-files/job.yml + - name: Wait for Kubernetes job to start + run: | + while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${{ github.run_id }}-${{ github.run_attempt }}-jax --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do + sleep 2 + done + - name: Stream Kubernetes job output + run: kubectl logs --all-containers=true --all-pods=true --follow job/${{ github.run_id }}-${{ github.run_attempt }}-jax + # Clean up in case of errors as well as success + - name: Delete Kubernetes job + if: always() + run: kubectl delete job ${{ github.run_id }}-${{ github.run_attempt }}-jax + - name: Configure post-processing job + run: | + export JOB_OUTPUT_PATTERN="${JOB_NAME}-rank*.zip" + yq -i '.metadata.name = strenv(POSTPROCESS_JOB_NAME) + | .spec.template.spec.containers[].image = strenv(JAX_DOCKER_IMAGE) + | .spec.template.spec.imagePullSecrets[].name = strenv(TOKEN_NAME) + | .spec.template.spec.initContainers[].command[7] = strenv(JOB_OUTPUT_PATTERN)' \ + .github/eks-workflow-files/post-process-job.yml + git diff .github/eks-workflow-files/post-process-job.yml + - name: Submit post-processing Kubernetes job + run: kubectl apply -f .github/eks-workflow-files/post-process-job.yml + - name: Wait for post-processing Kubernetes job to start + run: | + while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${{ github.run_id }}-${{ github.run_attempt }}-postprocess --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do + sleep 2 + done + - name: Stream post-processing Kubernetes job output + run: kubectl logs --all-containers=true --all-pods=true --follow job/${{ github.run_id }}-${{ github.run_attempt }}-postprocess + # Clean up in case of errors as well as success + - name: Delete post-processing Kubernetes job + if: always() + run: kubectl delete job ${{ github.run_id }}-${{ github.run_attempt }}-postprocess + - name: Delete GitHub Container Registry token + if: always() + run: kubectl delete secret ${{ github.run_id }}-${{ github.run_attempt }}-token + # test-equinox: # needs: build-equinox # if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a