Skip to content

Commit

Permalink
Customize NCCL for base container
Browse files Browse the repository at this point in the history
  • Loading branch information
DwarKapex committed Oct 28, 2024
1 parent 277b9ef commit d541bd2
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 7 deletions.
12 changes: 12 additions & 0 deletions .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ FROM ${BASE_IMAGE}
ARG GIT_USER_EMAIL
ARG GIT_USER_NAME
ARG CLANG_VERSION
ARG JAX_NCCL_VERSION
ARG JAX_LIBNCCL_PACKAGE

###############################################################################
## Update NCCL version env variables
###############################################################################

ENV NV_LIBNCCL_DEV_PACKAGE=${NV_LIBNCCL_DEV_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE}
ENV NV_LIBNCCL_DEV_PACKAGE_VERSION=${JAX_NCCL_VERSION}
ENV NCCL_VERSION=${JAX_NCCL_VERSION}
ENV NV_LIBNCCL_PACKAGE=${NV_LIBNCCL_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE}
ENV NV_LIBNCCL_PACKAGE_VERSION=${JAX_NCCL_VERSION}

###############################################################################
## Install Python and essential tools
Expand Down
24 changes: 18 additions & 6 deletions .github/container/install-nccl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@ set -ex -o pipefail
export DEBIAN_FRONTEND=noninteractive
export TZ=America/Los_Angeles

# If NCCL is already installed, don't reinstall it. Print a message and exit
if dpkg -s libnccl2 libnccl-dev &> /dev/null; then
echo "NCCL is already installed. Skipping installation."
# Try to get NCCL_VERSION of installed libnccl-dev
if [[ -z $NCCL_VERSION ]]; then
NCCL_VERSION=$(dpkg -s libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1 | tr "+" "\n" | head -1)
fi

# Skip NCCL installation if both JAX_NCCL_VERSION (user defined) and
# NCCL_VERSION (defined in nvidia/cuda containers) are unset.
# This case means that the base container is built from a custom image with
# a custom network communicator or unset NCCL_VERSION env variable.
if [[ -z $JAX_NCCL_VERSION && -z $NCCL_VERSION ]]; then
echo "Skip NCCL installation"
else
JAX_NCCL_VERSION=${JAX_NCCL_VERSION:-$NCCL_VERSION}
apt-get update

# Extract CUDA version from `nvcc --version` output line
Expand All @@ -18,21 +27,24 @@ else

# Find latest NCCL version compatible with existing CUDA by matching
# ${cuda_version} in the package version string
libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(${JAX_NCCL_VERSION}.*+cuda.*\)$/\1/p" | head -n 1)
libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(${JAX_NCCL_VERSION}.*+cuda.*\)$/\1/p" | head -n 1)
if [[ -z "${libnccl2_version}" || -z "${libnccl_dev_version}" ]]; then
echo "Could not find compatible NCCL version for CUDA ${cuda_version}"
exit 1
fi

apt-get install -y \
apt-get install -y --allow-change-held-packages \
libnccl2=${libnccl2_version} \
libnccl-dev=${libnccl_dev_version}

apt-get clean
rm -rf /var/lib/apt/lists/*
fi

# Smoke test of installed NCCL packages
dpkg -s libnccl2 libnccl-dev

# Create a prefix with include/ and lib/ directories containing symlinks to the NCCL
# version installed at the system level; this is useful to pass to XLA to avoid it
# fetching its own copy.
Expand Down
45 changes: 44 additions & 1 deletion .github/workflows/_build_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ on:
description: Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch
default: ''
required: false
JAX_LIBNCCL_PACKAGE:
type: string
description: NCCL lib package version to be installed (in the format `2.19.3-1+cuda12.3`)
default: ''
required: false
outputs:
DOCKER_TAG:
description: "Tag of the image built"
Expand All @@ -56,8 +61,44 @@ permissions:
packages: write # to upload container

jobs:
nccl-version:
runs-on: ubuntu-22.04
outputs:
JAX_NCCL_VERSION: ${{ steps.get-nccl-version.outputs.JAX_NCCL_VERSION }}
JAX_LIBNCCL_PACKAGE: ${{ steps.get-nccl-version.outputs.JAX_LIBNCCL_PACKAGE }}
steps:
- name: Print environment variables
run: env

- name: Check out the repository under ${GITHUB_WORKSPACE}
uses: actions/checkout@v4

- name: Get NCCL version
id: get-nccl-version
shell: bash -x -e {0}
run: |
JAX_LIBNCCL_PACKAGE=${{ inputs.JAX_LIBNCCL_PACKAGE }}
if [[ -z $JAX_LIBNCCL_PACKAGE ]]; then
BASE_IMAGE=${{ inputs.BASE_IMAGE }}
if [[ $BASE_IMAGE == latest ]]; then
BASE_IMAGE=$(cat .github/container/Dockerfile.base | sed -n "s/^ARG BASE_IMAGE=\(.*\)$/\1/p")
fi
# try to get NCCL version from provided BASE_IMAGE of x86-arch
if [[ -z "$BASE_IMAGE" ]]; then
echo "Need to pass non-empty BASE_IMAGE variable"
exit 1
fi
source .github/workflows/scripts/get_remote_env.sh
JAX_LIBNCCL_PACKAGE=$(get_remote_env ${BASE_IMAGE} linux amd64 | jq -r '.[]' | egrep '^NV_LIBNCCL_PACKAGE')
JAX_NCCL_VERSION=$(get_remote_env ${BASE_IMAGE} linux amd64 | jq -r '.[]' | egrep '^NCCL_VERSION=' | cut -d= -f2-)
else
JAX_NCCL_VERSION=$(echo $JAX_LIBNCCL_PACKAGE | cut -d= -f2 | cut -d+ -f1)
fi
echo "JAX_NCCL_VERSION=$JAX_NCCL_VERSION" >> $GITHUB_OUTPUT
echo "JAX_LIBNCCL_PACKAGE=$JAX_LIBNCCL_PACKAGE" >> $GITHUB_OUTPUT
build-base:
needs: nccl-version
runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", small]
env:
BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME }}-${{ inputs.ARCHITECTURE }}.json
Expand Down Expand Up @@ -133,7 +174,9 @@ jobs:
GIT_USER_EMAIL=${{ inputs.GIT_USER_EMAIL }}
BUILD_DATE=${{ inputs.BUILD_DATE }}
${{ inputs.BASE_IMAGE != 'latest' && format('BASE_IMAGE={0}', inputs.BASE_IMAGE) || '' }}
JAX_NCCL_VERSION=${{ needs.nccl-version.outputs.JAX_NCCL_VERSION }}
JAX_LIBNCCL_PACKAGE=${{ needs.nccl-version.outputs.JAX_LIBNCCL_PACKAGE }}
- name: Generate sitrep
if: "!cancelled()"
shell: bash -x -e {0}
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ on:
description: 'A JSON object containing git url+refs for softwares to be built'
required: false
default: '{}'
JAX_LIBNCCL_PACKAGE:
type: string
description: NCCL version to be installed (for example, `2.20.3-1+cuda12.4`)
default: ''
required: false
outputs:
DOCKER_TAGS:
description: 'JSON object containing tags of all docker images built'
Expand All @@ -45,6 +50,7 @@ jobs:
BASE_IMAGE: ${{ inputs.CUDA_IMAGE }}
BUILD_DATE: ${{ inputs.BUILD_DATE }}
MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }}
JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
secrets: inherit

build-jax:
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ on:
PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,levanter,haliax,mujuco,mujuco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive)
default: ''
required: false
JAX_LIBNCCL_PACKAGE:
type: string
description: NCCL version to be installed (for example, 2.20.3-1+cuda12.4)
default: ''
required: false

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
Expand Down Expand Up @@ -197,6 +202,7 @@ jobs:
CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
secrets: inherit

arm64:
Expand All @@ -208,6 +214,7 @@ jobs:
CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
secrets: inherit

# Only merge if everything succeeds
Expand Down

0 comments on commit d541bd2

Please sign in to comment.