From 19842a28a323a79d9aded4b4369e5eb889678258 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 7 May 2024 04:22:02 +0200 Subject: [PATCH] Add UCXX support (#1983) Add support for [UCXX](https://github.com/rapidsai/ucxx). It is our intention to soon switch from UCX-Py to UCXX and archive the former. This PR adds support for UCXX on the C++ backend but retains the original UCX implementation for now (based on the UCP layer), moving to UCXX will simplify RAFT code a bit given the UCXX implementation requires fewer lines of boilerplate code. On the Python front raft-dask tests are added for both UCX-Py (which there weren't any) and UCXX. The UCX-Py tests continue to use the UCX (UCP layer) implementation, whereas the UCXX tests use the UCXX C++ implementation. When the switch is complete we can remove all previous UCX/UCX-Py code from the RAFT codebase. If for some reason using the UCX (UCP layer) is preferred on the C++ backend instead of the UCXX C++ implementation this is possible, but UCX-Py code will be archived and dropped in favor of the UCXX Python backend. Authors: - Peter Andreas Entschev (https://github.com/pentschev) - Akira Naruse (https://github.com/anaruse) - Vyas Ramasubramani (https://github.com/vyasr) - Bradley Dice (https://github.com/bdice) Approvers: - Jake Awe (https://github.com/AyodeAwe) - Divye Gala (https://github.com/divyegala) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/raft/pull/1983 --- ci/build_wheel.sh | 10 +- ci/release/update-version.sh | 9 +- ci/test_python.sh | 18 ++ ci/test_wheel_raft_dask.sh | 12 +- .../all_cuda-118_arch-aarch64.yaml | 2 + .../all_cuda-118_arch-x86_64.yaml | 2 + .../all_cuda-122_arch-aarch64.yaml | 2 + .../all_cuda-122_arch-x86_64.yaml | 2 + .../bench_ann_cuda-118_arch-aarch64.yaml | 1 + .../bench_ann_cuda-118_arch-x86_64.yaml | 1 + .../bench_ann_cuda-120_arch-aarch64.yaml | 1 + .../bench_ann_cuda-120_arch-x86_64.yaml | 1 + .../recipes/raft-dask/conda_build_config.yaml | 6 +- conda/recipes/raft-dask/meta.yaml | 7 +- cpp/CMakeLists.txt | 15 +- cpp/include/raft/comms/detail/std_comms.hpp | 274 +++++++++++++----- cpp/include/raft/comms/detail/ucp_helper.hpp | 27 +- cpp/include/raft/comms/std_comms.hpp | 54 +++- dependencies.yaml | 82 +++++- python/raft-dask/CMakeLists.txt | 9 +- .../raft-dask/cmake/thirdparty/get_ucxx.cmake | 55 ++++ python/raft-dask/pyproject.toml | 3 + python/raft-dask/pytest.ini | 3 +- python/raft-dask/raft_dask/__init__.py | 12 +- python/raft-dask/raft_dask/common/comms.py | 19 +- .../raft_dask/common/comms_utils.pyx | 11 +- python/raft-dask/raft_dask/common/ucx.py | 32 +- python/raft-dask/raft_dask/test/conftest.py | 64 +++- python/raft-dask/raft_dask/test/test_comms.py | 158 +++++++++- 29 files changed, 735 insertions(+), 157 deletions(-) create mode 100644 python/raft-dask/cmake/thirdparty/get_ucxx.cmake diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh index 5d06e46303..e3e7ce9c89 100755 --- a/ci/build_wheel.sh +++ b/ci/build_wheel.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. set -euo pipefail @@ -7,6 +7,10 @@ package_name=$1 package_dir=$2 underscore_package_name=$(echo "${package_name}" | tr "-" "_") +# Clear out system ucx files to ensure that we're getting ucx from the wheel. +rm -rf /usr/lib64/ucx +rm -rf /usr/lib64/libuc* + source rapids-configure-sccache source rapids-date-string @@ -38,9 +42,11 @@ fi if [[ ${package_name} == "raft-dask" ]]; then sed -r -i "s/pylibraft==(.*)\"/pylibraft${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file} + sed -r -i "s/libucx(.*)\"/libucx${PACKAGE_CUDA_SUFFIX}\1${alpha_spec}\"/g" ${pyproject_file} sed -r -i "s/ucx-py==(.*)\"/ucx-py${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file} sed -r -i "s/rapids-dask-dependency==(.*)\"/rapids-dask-dependency==\1${alpha_spec}\"/g" ${pyproject_file} sed -r -i "s/dask-cuda==(.*)\"/dask-cuda==\1${alpha_spec}\"/g" ${pyproject_file} + sed -r -i "s/distributed-ucxx==(.*)\"/distributed-ucxx${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file} else sed -r -i "s/rmm(.*)\"/rmm${PACKAGE_CUDA_SUFFIX}\1${alpha_spec}\"/g" ${pyproject_file} fi @@ -56,6 +62,6 @@ cd "${package_dir}" python -m pip wheel . -w dist -vvv --no-deps --disable-pip-version-check mkdir -p final_dist -python -m auditwheel repair -w final_dist dist/* +python -m auditwheel repair -w final_dist --exclude "libucp.so.0" dist/* RAPIDS_PY_WHEEL_NAME="${underscore_package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 final_dist diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 46b992392c..ef9b3e4b83 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -37,6 +37,8 @@ function sed_runner() { } sed_runner "s/set(RAPIDS_VERSION .*)/set(RAPIDS_VERSION \"${NEXT_SHORT_TAG}\")/g" cpp/template/cmake/thirdparty/fetch_rapids.cmake +sed_runner 's/'"find_and_configure_ucxx(VERSION .*"'/'"find_and_configure_ucxx(VERSION ${NEXT_UCX_PY_SHORT_TAG_PEP440}"'/g' python/raft-dask/cmake/thirdparty/get_ucxx.cmake +sed_runner 's/'"branch-.*"'/'"branch-${NEXT_UCX_PY_SHORT_TAG_PEP440}"'/g' python/raft-dask/cmake/thirdparty/get_ucxx.cmake # Centralized version file update echo "${NEXT_FULL_TAG}" > VERSION @@ -50,7 +52,7 @@ DEPENDENCIES=( rmm-cu11 rmm-cu12 rapids-dask-dependency - # ucx-py is handled separately below + # ucx-py and ucxx are handled separately below ) for FILE in dependencies.yaml conda/environments/*.yaml; do for DEP in "${DEPENDENCIES[@]}"; do @@ -59,6 +61,10 @@ for FILE in dependencies.yaml conda/environments/*.yaml; do sed_runner "/-.* ucx-py==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; sed_runner "/-.* ucx-py-cu11==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; sed_runner "/-.* ucx-py-cu12==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; + sed_runner "/-.* libucxx==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; + sed_runner "/-.* distributed-ucxx==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; + sed_runner "/-.* distributed-ucxx-cu11==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; + sed_runner "/-.* distributed-ucxx-cu12==/ s/==.*/==${NEXT_UCX_PY_SHORT_TAG_PEP440}\.*/g" ${FILE}; done for FILE in python/*/pyproject.toml; do for DEP in "${DEPENDENCIES[@]}"; do @@ -68,6 +74,7 @@ for FILE in python/*/pyproject.toml; do done sed_runner "/^ucx_py_version:$/ {n;s/.*/ - \"${NEXT_UCX_PY_VERSION}\"/}" conda/recipes/raft-dask/conda_build_config.yaml +sed_runner "/^ucxx_version:$/ {n;s/.*/ - \"${NEXT_UCX_PY_VERSION}\"/}" conda/recipes/raft-dask/conda_build_config.yaml for FILE in .github/workflows/*.yaml; do sed_runner "/shared-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" diff --git a/ci/test_python.sh b/ci/test_python.sh index f5b188ca0b..59da1f0bc4 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -59,5 +59,23 @@ rapids-logger "pytest raft-dask" --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/raft-dask-coverage.xml" \ --cov-report=term +rapids-logger "pytest raft-dask (ucx-py only)" +./ci/run_raft_dask_pytests.sh \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-raft-dask-ucx.xml" \ + --cov-config=../.coveragerc \ + --cov=raft_dask \ + --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/raft-dask-ucx-coverage.xml" \ + --cov-report=term \ + --run_ucx + +rapids-logger "pytest raft-dask (ucxx only)" +./ci/run_raft_dask_pytests.sh \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-raft-dask-ucxx.xml" \ + --cov-config=../.coveragerc \ + --cov=raft_dask \ + --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/raft-dask-ucxx-coverage.xml" \ + --cov-report=term \ + --run_ucxx + rapids-logger "Test script exiting with value: $EXITCODE" exit ${EXITCODE} diff --git a/ci/test_wheel_raft_dask.sh b/ci/test_wheel_raft_dask.sh index 76bb62e859..fe2d44f2b3 100755 --- a/ci/test_wheel_raft_dask.sh +++ b/ci/test_wheel_raft_dask.sh @@ -11,7 +11,13 @@ RAPIDS_PY_WHEEL_NAME="raft_dask_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-pylibraft-dep python -m pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl -# echo to expand wildcard before adding `[extra]` requires for pip -python -m pip install $(echo ./dist/raft_dask*.whl)[test] +python -m pip install "raft_dask-${RAPIDS_PY_CUDA_SUFFIX}[test]>=0.0.0a0" --find-links dist/ -python -m pytest ./python/raft-dask/raft_dask/test +# rapids-logger "pytest raft-dask" +# python -m pytest ./python/raft-dask/raft_dask/test + +# rapids-logger "pytest raft-dask (ucx-py only)" +# python -m pytest ./python/raft-dask/raft_dask/test --run_ucx + +rapids-logger "pytest raft-dask (ucxx only)" +python -m pytest ./python/raft-dask/raft_dask/test --run_ucxx diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index 189f8268df..7453df2593 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -21,6 +21,7 @@ dependencies: - cxx-compiler - cython>=3.0.0 - dask-cuda==24.6.* +- distributed-ucxx==0.38.* - doxygen>=1.8.20 - gcc_linux-aarch64=11.* - graphviz @@ -34,6 +35,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libucxx==0.38.* - nccl>=2.9.9 - ninja - numba>=0.57 diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index e604705112..b983eb0388 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -21,6 +21,7 @@ dependencies: - cxx-compiler - cython>=3.0.0 - dask-cuda==24.6.* +- distributed-ucxx==0.38.* - doxygen>=1.8.20 - gcc_linux-64=11.* - graphviz @@ -34,6 +35,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libucxx==0.38.* - nccl>=2.9.9 - ninja - numba>=0.57 diff --git a/conda/environments/all_cuda-122_arch-aarch64.yaml b/conda/environments/all_cuda-122_arch-aarch64.yaml index 49c53b4cfe..7dacfc2d2b 100644 --- a/conda/environments/all_cuda-122_arch-aarch64.yaml +++ b/conda/environments/all_cuda-122_arch-aarch64.yaml @@ -22,6 +22,7 @@ dependencies: - cxx-compiler - cython>=3.0.0 - dask-cuda==24.6.* +- distributed-ucxx==0.38.* - doxygen>=1.8.20 - gcc_linux-aarch64=11.* - graphviz @@ -31,6 +32,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libucxx==0.38.* - nccl>=2.9.9 - ninja - numba>=0.57 diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml index 6f782175dd..1c16d2ea93 100644 --- a/conda/environments/all_cuda-122_arch-x86_64.yaml +++ b/conda/environments/all_cuda-122_arch-x86_64.yaml @@ -22,6 +22,7 @@ dependencies: - cxx-compiler - cython>=3.0.0 - dask-cuda==24.6.* +- distributed-ucxx==0.38.* - doxygen>=1.8.20 - gcc_linux-64=11.* - graphviz @@ -31,6 +32,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libucxx==0.38.* - nccl>=2.9.9 - ninja - numba>=0.57 diff --git a/conda/environments/bench_ann_cuda-118_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-118_arch-aarch64.yaml index b5f662ebc1..7315f82c13 100644 --- a/conda/environments/bench_ann_cuda-118_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-118_arch-aarch64.yaml @@ -30,6 +30,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libucxx==0.38.* - matplotlib - nccl>=2.9.9 - ninja diff --git a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml index 6c56cb688c..ff973acc0c 100644 --- a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml @@ -30,6 +30,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- libucxx==0.38.* - matplotlib - nccl>=2.9.9 - ninja diff --git a/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml index 7f3107e5d6..056550fc07 100644 --- a/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml @@ -27,6 +27,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libucxx==0.38.* - matplotlib - nccl>=2.9.9 - ninja diff --git a/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml index 62739354a5..41a48f4a12 100644 --- a/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml @@ -27,6 +27,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libucxx==0.38.* - matplotlib - nccl>=2.9.9 - ninja diff --git a/conda/recipes/raft-dask/conda_build_config.yaml b/conda/recipes/raft-dask/conda_build_config.yaml index 345cef49a1..b157e41753 100644 --- a/conda/recipes/raft-dask/conda_build_config.yaml +++ b/conda/recipes/raft-dask/conda_build_config.yaml @@ -16,11 +16,11 @@ c_stdlib: c_stdlib_version: - "2.17" -ucx_version: - - ">=1.15.0,<1.16.0" - ucx_py_version: - "0.38.*" +ucxx_version: + - "0.38.*" + cmake_version: - ">=3.26.4" diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml index 7c2fb257b1..50042780b4 100644 --- a/conda/recipes/raft-dask/meta.yaml +++ b/conda/recipes/raft-dask/meta.yaml @@ -56,9 +56,9 @@ requirements: - rmm ={{ minor_version }} - scikit-build-core >=0.7.0 - setuptools - - ucx {{ ucx_version }} - - ucx-proc=*=gpu - ucx-py {{ ucx_py_version }} + - libucxx {{ ucxx_version }} + - ucxx {{ ucxx_version }} run: {% if cuda_major == "11" %} - cudatoolkit @@ -73,9 +73,8 @@ requirements: - pylibraft {{ version }} - python x.x - rmm ={{ minor_version }} - - ucx {{ ucx_version }} - - ucx-proc=*=gpu - ucx-py {{ ucx_py_version }} + - ucxx {{ ucxx_version }} tests: requirements: diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index eaab637338..259d9fe428 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -650,12 +650,21 @@ rapids_find_generate_module( INSTALL_EXPORT_SET raft-distributed-exports ) -rapids_export_package(BUILD ucx raft-distributed-exports) -rapids_export_package(INSTALL ucx raft-distributed-exports) +rapids_export_package( + BUILD ucxx raft-distributed-exports COMPONENTS ucxx python GLOBAL_TARGETS ucxx::ucxx ucxx::python +) +rapids_export_package( + INSTALL ucxx raft-distributed-exports COMPONENTS ucxx python GLOBAL_TARGETS ucxx::ucxx + ucxx::python +) rapids_export_package(BUILD NCCL raft-distributed-exports) rapids_export_package(INSTALL NCCL raft-distributed-exports) -target_link_libraries(raft_distributed INTERFACE ucx::ucp NCCL::NCCL) +# ucx is a requirement for raft_distributed, but its config is not safe to be found multiple times, +# so rather than exporting a package dependency on it above we rely on consumers to find it +# themselves. Once https://github.com/rapidsai/ucxx/issues/173 is resolved we can export it above +# again. +target_link_libraries(raft_distributed INTERFACE ucx::ucp ucxx::ucxx NCCL::NCCL) # ################################################################################################## # * install targets----------------------------------------------------------- diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 6e7ff7106f..cb1accc95e 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -49,6 +50,17 @@ namespace raft { namespace comms { namespace detail { +using ucp_endpoint_array_t = std::shared_ptr; +using ucxx_endpoint_array_t = std::shared_ptr; +using ucp_worker_t = ucp_worker_h; +using ucxx_worker_t = ucxx::Worker*; + +struct ucx_objects_t { + public: + std::variant endpoints; + std::variant worker; +}; + class std_comms : public comms_iface { public: std_comms() = delete; @@ -64,8 +76,7 @@ class std_comms : public comms_iface { * @param subcomms_ucp use ucp for subcommunicators */ std_comms(ncclComm_t nccl_comm, - ucp_worker_h ucp_worker, - std::shared_ptr eps, + ucx_objects_t ucx_objects, int num_ranks, int rank, rmm::cuda_stream_view stream, @@ -76,9 +87,8 @@ class std_comms : public comms_iface { num_ranks_(num_ranks), rank_(rank), subcomms_ucp_(subcomms_ucp), + ucx_objects_(ucx_objects), own_nccl_comm_(false), - ucp_worker_(ucp_worker), - ucp_eps_(eps), next_request_id_(0) { initialize(); @@ -205,96 +215,209 @@ class std_comms : public comms_iface { void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + if (std::holds_alternative(ucx_objects_.worker)) { + get_request_id(request); - get_request_id(request); - ucp_ep_h ep_ptr = (*ucp_eps_)[dest]; + ucxx::Endpoint* ep_ptr = (*std::get(ucx_objects_.endpoints))[dest]; - ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); + ucp_tag_t ucp_tag = build_message_tag(get_rank(), tag); + auto ucxx_req = ep_ptr->tagSend(const_cast(buf), size, ucxx::Tag(ucp_tag)); - this->ucp_handler_.ucp_isend(ucp_req, ep_ptr, buf, size, tag, default_tag_mask, get_rank()); + requests_in_flight_.insert(std::make_pair(*request, ucxx_req)); + } else { + ASSERT(std::get(ucx_objects_.worker) != nullptr, + "ERROR: UCX comms not initialized on communicator."); - requests_in_flight_.insert(std::make_pair(*request, ucp_req)); - } + get_request_id(request); + ucp_ep_h ep_ptr = (*std::get(ucx_objects_.endpoints))[dest]; - void irecv(void* buf, size_t size, int source, int tag, request_t* request) const - { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); - get_request_id(request); + this->ucp_handler_.ucp_isend(ucp_req, ep_ptr, buf, size, tag, default_tag_mask, get_rank()); - ucp_ep_h ep_ptr = (*ucp_eps_)[source]; - - ucp_tag_t tag_mask = default_tag_mask; - - ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); - ucp_handler_.ucp_irecv(ucp_req, ucp_worker_, ep_ptr, buf, size, tag, tag_mask, source); - - requests_in_flight_.insert(std::make_pair(*request, ucp_req)); + requests_in_flight_.insert(std::make_pair(*request, ucp_req)); + } } - void waitall(int count, request_t array_of_requests[]) const + void irecv(void* buf, size_t size, int source, int tag, request_t* request) const { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + if (std::holds_alternative(ucx_objects_.worker)) { + get_request_id(request); - std::vector requests; - requests.reserve(count); + ucxx::Endpoint* ep_ptr = (*std::get(ucx_objects_.endpoints))[source]; - time_t start = time(NULL); + ucp_tag_t ucp_tag = build_message_tag(get_rank(), tag); + auto ucxx_req = + ep_ptr->tagRecv(buf, size, ucxx::Tag(ucp_tag), ucxx::TagMask(default_tag_mask)); - for (int i = 0; i < count; ++i) { - auto req_it = requests_in_flight_.find(array_of_requests[i]); - ASSERT(requests_in_flight_.end() != req_it, - "ERROR: waitall on invalid request: %d", - array_of_requests[i]); - requests.push_back(req_it->second); - free_requests_.insert(req_it->first); - requests_in_flight_.erase(req_it); - } - - while (requests.size() > 0) { - time_t now = time(NULL); + requests_in_flight_.insert(std::make_pair(*request, ucxx_req)); + } else { + ASSERT(std::get(ucx_objects_.worker) != nullptr, + "ERROR: UCX comms not initialized on communicator."); - // Timeout if we have not gotten progress or completed any requests - // in 10 or more seconds. - ASSERT(now - start < 10, "Timed out waiting for requests."); + get_request_id(request); - for (std::vector::iterator it = requests.begin(); it != requests.end();) { - bool restart = false; // resets the timeout when any progress was made + ucp_ep_h ep_ptr = (*std::get(ucx_objects_.endpoints))[source]; - // Causes UCP to progress through the send/recv message queue - while (ucp_worker_progress(ucp_worker_) != 0) { - restart = true; - } + ucp_tag_t tag_mask = default_tag_mask; - auto req = *it; + ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); + ucp_handler_.ucp_irecv(ucp_req, + std::get(ucx_objects_.worker), + ep_ptr, + buf, + size, + tag, + tag_mask, + source); - // If the message needs release, we know it will be sent/received - // asynchronously, so we will need to track and verify its state - if (req->needs_release) { - ASSERT(UCS_PTR_IS_PTR(req->req), "UCX Request Error. Request is not valid UCX pointer"); - ASSERT(!UCS_PTR_IS_ERR(req->req), "UCX Request Error: %d\n", UCS_PTR_STATUS(req->req)); - ASSERT(req->req->completed == 1 || req->req->completed == 0, - "request->completed not a valid value: %d\n", - req->req->completed); - } + requests_in_flight_.insert(std::make_pair(*request, ucp_req)); + } + } - // If a message was sent synchronously (eg. completed before - // `isend`/`irecv` completed) or an asynchronous message - // is complete, we can go ahead and clean it up. - if (!req->needs_release || req->req->completed == 1) { - restart = true; + void waitall(int count, request_t array_of_requests[]) const + { + if (std::holds_alternative(ucx_objects_.worker)) { + ucxx_worker_t worker = std::get(ucx_objects_.worker); + + std::vector> requests; + requests.reserve(count); + + time_t start = time(NULL); + + for (int i = 0; i < count; ++i) { + auto req_it = requests_in_flight_.find(array_of_requests[i]); + ASSERT(requests_in_flight_.end() != req_it, + "ERROR: waitall on invalid request: %d", + array_of_requests[i]); + requests.push_back(std::get>(req_it->second)); + free_requests_.insert(req_it->first); + requests_in_flight_.erase(req_it); + } - // perform cleanup - ucp_handler_.free_ucp_request(req); + while (requests.size() > 0) { + time_t now = time(NULL); + + // Timeout if we have not gotten progress or completed any requests + // in 10 or more seconds. + ASSERT(now - start < 10, "Timed out waiting for requests."); + + for (std::vector>::iterator it = requests.begin(); + it != requests.end();) { + bool restart = false; // resets the timeout when any progress was made + + if (worker->isProgressThreadRunning()) { + // Wait for a UCXX progress thread roundtrip + ucxx::utils::CallbackNotifier callbackNotifierPre{}; + worker->registerGenericPre([&callbackNotifierPre]() { callbackNotifierPre.set(); }); + callbackNotifierPre.wait(); + + ucxx::utils::CallbackNotifier callbackNotifierPost{}; + worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); }); + callbackNotifierPost.wait(); + } else { + // Causes UCXX to progress through the send/recv message queue + while (!worker->progress()) { + restart = true; + } + } + + auto req = *it; + + // If the message needs release, we know it will be sent/received + // asynchronously, so we will need to track and verify its state + if (req->isCompleted()) { + auto status = req->getStatus(); + ASSERT(req->getStatus() == UCS_OK, + "UCX Request Error: %d (%s)\n", + status, + ucs_status_string(status)); + } + + // If a message was sent synchronously (eg. completed before + // `isend`/`irecv` completed) or an asynchronous message + // is complete, we can go ahead and clean it up. + if (req->isCompleted()) { + restart = true; + + auto status = req->getStatus(); + ASSERT(req->getStatus() == UCS_OK, + "UCX Request Error: %d (%s)\n", + status, + ucs_status_string(status)); + + // remove from pending requests + it = requests.erase(it); + } else { + ++it; + } + // if any progress was made, reset the timeout start time + if (restart) { start = time(NULL); } + } + } + } else { + ucp_worker_t worker = std::get(ucx_objects_.worker); + ASSERT(worker != nullptr, "ERROR: UCX comms not initialized on communicator."); + + std::vector requests; + requests.reserve(count); + + time_t start = time(NULL); + + for (int i = 0; i < count; ++i) { + auto req_it = requests_in_flight_.find(array_of_requests[i]); + ASSERT(requests_in_flight_.end() != req_it, + "ERROR: waitall on invalid request: %d", + array_of_requests[i]); + requests.push_back(std::get(req_it->second)); + free_requests_.insert(req_it->first); + requests_in_flight_.erase(req_it); + } - // remove from pending requests - it = requests.erase(it); - } else { - ++it; + while (requests.size() > 0) { + time_t now = time(NULL); + + // Timeout if we have not gotten progress or completed any requests + // in 10 or more seconds. + ASSERT(now - start < 10, "Timed out waiting for requests."); + + for (std::vector::iterator it = requests.begin(); it != requests.end();) { + bool restart = false; // resets the timeout when any progress was made + + // Causes UCP to progress through the send/recv message queue + while (ucp_worker_progress(worker) != 0) { + restart = true; + } + + auto req = *it; + + // If the message needs release, we know it will be sent/received + // asynchronously, so we will need to track and verify its state + if (req->needs_release) { + ASSERT(UCS_PTR_IS_PTR(req->req), "UCX Request Error. Request is not valid UCX pointer"); + ASSERT(!UCS_PTR_IS_ERR(req->req), "UCX Request Error: %d\n", UCS_PTR_STATUS(req->req)); + ASSERT(req->req->completed == 1 || req->req->completed == 0, + "request->completed not a valid value: %d\n", + req->req->completed); + } + + // If a message was sent synchronously (eg. completed before + // `isend`/`irecv` completed) or an asynchronous message + // is complete, we can go ahead and clean it up. + if (!req->needs_release || req->req->completed == 1) { + restart = true; + + // perform cleanup + ucp_handler_.free_ucp_request(req); + + // remove from pending requests + it = requests.erase(it); + } else { + ++it; + } + // if any progress was made, reset the timeout start time + if (restart) { start = time(NULL); } } - // if any progress was made, reset the timeout start time - if (restart) { start = time(NULL); } } } } @@ -524,10 +647,11 @@ class std_comms : public comms_iface { bool own_nccl_comm_; comms_ucp_handler ucp_handler_; - ucp_worker_h ucp_worker_; - std::shared_ptr ucp_eps_; + ucx_objects_t ucx_objects_; mutable request_t next_request_id_; - mutable std::unordered_map requests_in_flight_; + mutable std::unordered_map>> + requests_in_flight_; mutable std::unordered_set free_requests_; }; } // namespace detail diff --git a/cpp/include/raft/comms/detail/ucp_helper.hpp b/cpp/include/raft/comms/detail/ucp_helper.hpp index 5896248c1d..65e1957e54 100644 --- a/cpp/include/raft/comms/detail/ucp_helper.hpp +++ b/cpp/include/raft/comms/detail/ucp_helper.hpp @@ -46,9 +46,7 @@ struct ucx_context { class ucp_request { public: struct ucx_context* req; - bool needs_release = true; - int other_rank = -1; - bool is_send_request = false; + bool needs_release = true; }; // by default, match the whole tag @@ -72,17 +70,16 @@ static void recv_callback(void* request, ucs_status_t status, ucp_tag_recv_info_ context->completed = 1; } +ucp_tag_t build_message_tag(int rank, int tag) +{ + // keeping the rank in the lower bits enables debugging. + return ((uint32_t)tag << 31) | (uint32_t)rank; +} + /** * Helper class for interacting with ucp. */ class comms_ucp_handler { - private: - ucp_tag_t build_message_tag(int rank, int tag) const - { - // keeping the rank in the lower bits enables debugging. - return ((uint32_t)tag << 31) | (uint32_t)rank; - } - public: /** * @brief Frees any memory underlying the given ucp request object @@ -132,9 +129,7 @@ class comms_ucp_handler { req->needs_release = false; } - req->other_rank = rank; - req->is_send_request = true; - req->req = ucp_req; + req->req = ucp_req; } /** @@ -156,10 +151,8 @@ class comms_ucp_handler { struct ucx_context* ucp_req = (struct ucx_context*)recv_result; - req->req = ucp_req; - req->needs_release = true; - req->is_send_request = false; - req->other_rank = sender_rank; + req->req = ucp_req; + req->needs_release = true; ASSERT(!UCS_PTR_IS_ERR(recv_result), "unable to receive UCX data message (%d)\n", diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index c81b19c9ba..667c8be285 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -24,6 +24,7 @@ #include #include +#include #include @@ -81,6 +82,8 @@ void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_rank * * @param handle raft::resources for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives + * @param is_ucxx whether `ucp_worker` and `eps` objects are UCXX (true) or + * pure UCX (false). * @param ucp_worker of local process * Note: This is purposefully left as void* so that the ucp_worker_h * doesn't need to be exposed through the cython layer @@ -112,30 +115,55 @@ void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_rank * comm.sync_stream(resource::get_cuda_stream(handle)); * @endcode */ -void build_comms_nccl_ucx( - resources* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) +void build_comms_nccl_ucx(resources* handle, + ncclComm_t nccl_comm, + bool is_ucxx, + void* ucp_worker, + void* eps, + int num_ranks, + int rank) { - auto eps_sp = std::make_shared(new ucp_ep_h[num_ranks]); + detail::ucx_objects_t ucx_objects; + if (is_ucxx) { + ucx_objects.endpoints = std::make_shared(new ucxx::Endpoint*[num_ranks]); + ucx_objects.worker = static_cast(ucp_worker); + } else { + ucx_objects.endpoints = std::make_shared(new ucp_ep_h[num_ranks]); + ucx_objects.worker = static_cast(ucp_worker); + } auto size_t_ep_arr = reinterpret_cast(eps); for (int i = 0; i < num_ranks; i++) { - size_t ptr = size_t_ep_arr[i]; - auto ucp_ep_v = reinterpret_cast(*eps_sp); - - if (ptr != 0) { - auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); - ucp_ep_v[i] = eps_ptr; + size_t ptr = size_t_ep_arr[i]; + + if (is_ucxx) { + auto ucp_ep_v = reinterpret_cast( + *std::get(ucx_objects.endpoints)); + + if (ptr != 0) { + auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); + ucp_ep_v[i] = eps_ptr; + } else { + ucp_ep_v[i] = nullptr; + } } else { - ucp_ep_v[i] = nullptr; + auto ucp_ep_v = + reinterpret_cast(*std::get(ucx_objects.endpoints)); + + if (ptr != 0) { + auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); + ucp_ep_v[i] = eps_ptr; + } else { + ucp_ep_v[i] = nullptr; + } } } cudaStream_t stream = resource::get_cuda_stream(*handle); - auto communicator = - std::make_shared(std::unique_ptr(new raft::comms::std_comms( - nccl_comm, (ucp_worker_h)ucp_worker, eps_sp, num_ranks, rank, stream))); + auto communicator = std::make_shared(std::unique_ptr( + new raft::comms::std_comms(nccl_comm, ucx_objects, num_ranks, rank, stream))); resource::set_comms(*handle, communicator); } diff --git a/dependencies.yaml b/dependencies.yaml index a83cd003d6..a336aa1577 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -10,6 +10,8 @@ files: - build_pylibraft - cuda - cuda_version + - depends_on_cupy + - depends_on_distributed_ucxx - develop - checks - build_wheels @@ -19,7 +21,6 @@ files: - run_pylibraft - test_python_common - test_pylibraft - - cupy bench_ann: output: conda matrix: @@ -44,7 +45,8 @@ files: - py_version - test_python_common - test_pylibraft - - cupy + - depends_on_cupy + - depends_on_distributed_ucxx checks: output: none includes: @@ -54,7 +56,7 @@ files: output: none includes: - cuda_version - - cupy + - depends_on_cupy - docs - py_version - test_pylibraft @@ -82,7 +84,7 @@ files: includes: - test_python_common - test_pylibraft - - cupy + - depends_on_cupy py_build_raft_dask: output: pyproject pyproject_dir: python/raft-dask @@ -90,6 +92,7 @@ files: table: build-system includes: - build + - depends_on_ucx_build py_run_raft_dask: output: pyproject pyproject_dir: python/raft-dask @@ -105,6 +108,8 @@ files: key: test includes: - test_python_common + - depends_on_distributed_ucxx + - depends_on_ucx_run py_build_raft_ann_bench: output: pyproject pyproject_dir: python/raft-ann-bench @@ -138,6 +143,7 @@ dependencies: - c-compiler - cxx-compiler - nccl>=2.9.9 + - libucxx==0.38.* - scikit-build-core>=0.7.0 - output_types: [requirements, pyproject] packages: @@ -337,7 +343,7 @@ dependencies: - *libcusparse_dev114 - *libcusparse114 - cupy: + depends_on_cupy: common: - output_types: conda packages: @@ -477,3 +483,69 @@ dependencies: packages: - scikit-learn - scipy + depends_on_distributed_ucxx: + common: + - output_types: conda + packages: + # UCXX is not currently a hard-dependency thus only installed during tests, + # this will change in the future. + - &distributed_ucxx_conda distributed-ucxx==0.38.* + - output_types: requirements + packages: + # pip recognizes the index as a global option for the requirements.txt file + - --extra-index-url=https://pypi.nvidia.com + - --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple + specific: + - output_types: [requirements, pyproject] + matrices: + - matrix: {cuda: "12.*"} + packages: + - distributed-ucxx-cu12==0.38.* + - matrix: {cuda: "11.*"} + packages: + - distributed-ucxx-cu11==0.38.* + - {matrix: null, packages: [*distributed_ucxx_conda]} + depends_on_ucx_build: + common: + - output_types: conda + packages: + - &ucx_conda_build ucx==1.15.0 + - output_types: requirements + packages: + # pip recognizes the index as a global option for the requirements.txt file + - --extra-index-url=https://pypi.nvidia.com + - --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple + specific: + - output_types: [requirements, pyproject] + matrices: + - matrix: {cuda: "12.*"} + packages: + - libucx-cu12==1.15.0 + - matrix: {cuda: "11.*"} + packages: + - libucx-cu11==1.15.0 + - matrix: null + packages: + - libucx==1.15.0 + depends_on_ucx_run: + common: + - output_types: conda + packages: + - &ucx_conda_run ucx>=1.15.0 + - output_types: requirements + packages: + # pip recognizes the index as a global option for the requirements.txt file + - --extra-index-url=https://pypi.nvidia.com + - --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple + specific: + - output_types: [requirements, pyproject] + matrices: + - matrix: {cuda: "12.*"} + packages: + - libucx-cu12>=1.15.0 + - matrix: {cuda: "11.*"} + packages: + - libucx-cu11>=1.15.0 + - matrix: null + packages: + - libucx>=1.15.0 diff --git a/python/raft-dask/CMakeLists.txt b/python/raft-dask/CMakeLists.txt index 58e5ae8104..2c629f3b73 100644 --- a/python/raft-dask/CMakeLists.txt +++ b/python/raft-dask/CMakeLists.txt @@ -15,6 +15,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) include(../../rapids_config.cmake) +include(rapids-cpm) include(rapids-cuda) rapids_cuda_init_architectures(raft-dask-python) @@ -28,6 +29,11 @@ option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulti OFF ) +rapids_cpm_init() +# Once https://github.com/rapidsai/ucxx/issues/173 is resolved we can remove this. +find_package(ucx REQUIRED) +include(cmake/thirdparty/get_ucxx.cmake) + # If the user requested it we attempt to find RAFT. if(FIND_RAFT_CPP) find_package(raft "${RAPIDS_VERSION}" REQUIRED COMPONENTS distributed) @@ -36,8 +42,6 @@ else() endif() if(NOT raft_FOUND) - find_package(ucx REQUIRED) - # raft-dask doesn't actually use raft libraries, it just needs the headers, so we can turn off all # library compilation and we don't need to install anything here. set(BUILD_TESTS OFF) @@ -47,6 +51,7 @@ if(NOT raft_FOUND) set(RAFT_COMPILE_DIST_LIBRARY OFF) set(RAFT_COMPILE_NN_LIBRARY OFF) set(CUDA_STATIC_RUNTIME ON) + set(RAFT_DASK_UCXX_STATIC ON) add_subdirectory(../../cpp raft-cpp EXCLUDE_FROM_ALL) list(APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}/cmake/find_modules) diff --git a/python/raft-dask/cmake/thirdparty/get_ucxx.cmake b/python/raft-dask/cmake/thirdparty/get_ucxx.cmake new file mode 100644 index 0000000000..8e340eec73 --- /dev/null +++ b/python/raft-dask/cmake/thirdparty/get_ucxx.cmake @@ -0,0 +1,55 @@ +#============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#============================================================================= + +function(find_and_configure_ucxx) + set(oneValueArgs VERSION FORK PINNED_TAG EXCLUDE_FROM_ALL) + set(options UCXX_STATIC) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + set(BUILD_UCXX_SHARED ON) + if(PKG_UCXX_STATIC) + set(BUILD_UCXX_SHARED OFF) + endif() + + rapids_cpm_find(ucxx ${PKG_VERSION} + GLOBAL_TARGETS ucxx::ucxx ucxx::python + BUILD_EXPORT_SET raft-distributed-exports + INSTALL_EXPORT_SET raft-distributed-exports + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/ucxx.git + GIT_TAG ${PKG_PINNED_TAG} + SOURCE_SUBDIR cpp + EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} + OPTIONS + "BUILD_TESTS OFF" + "BUILD_BENCH OFF" + "UCXX_ENABLE_PYTHON ON" + "UCXX_ENABLE_RMM ON" + "BUILD_SHARED_LIBS ${BUILD_UCXX_SHARED}" + ) + +endfunction() + +# Change pinned tag here to test a commit in CI +# To use a different RAFT locally, set the CMake variable +# CPM_raft_SOURCE=/path/to/local/raft +find_and_configure_ucxx(VERSION 0.38 + FORK rapidsai + PINNED_TAG branch-0.38 + EXCLUDE_FROM_ALL YES + UCXX_STATIC ${RAFT_DASK_UCXX_STATIC} + ) diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index 815f6b277c..0181bef4ce 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -18,6 +18,7 @@ build-backend = "scikit_build_core.build" requires = [ "cmake>=3.26.4", "cython>=3.0.0", + "libucx==1.15.0", "ninja", "scikit-build-core[pyproject]>=0.7.0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. @@ -51,6 +52,8 @@ classifiers = [ [project.optional-dependencies] test = [ + "distributed-ucxx==0.38.*", + "libucx>=1.15.0", "pytest-cov", "pytest==7.*", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. diff --git a/python/raft-dask/pytest.ini b/python/raft-dask/pytest.ini index 5559bb08c8..fcb18fe412 100644 --- a/python/raft-dask/pytest.ini +++ b/python/raft-dask/pytest.ini @@ -6,4 +6,5 @@ markers = mg: marks a test as multi-GPU memleak: marks a test as a memory leak test nccl: marks a test as using NCCL - ucx: marks a test as using ucx-py + ucx: marks a test as using UCX-Py + ucxx: marks a test as using UCXX diff --git a/python/raft-dask/raft_dask/__init__.py b/python/raft-dask/raft_dask/__init__.py index fbbaee4118..19a037ae75 100644 --- a/python/raft-dask/raft_dask/__init__.py +++ b/python/raft-dask/raft_dask/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,3 +14,13 @@ # from raft_dask._version import __git_commit__, __version__ + +# If libucx was installed as a wheel, we must request it to load the library symbols. +# Otherwise, we assume that the library was installed in a system path that ld can find. +try: + import libucx +except ModuleNotFoundError: + pass +else: + libucx.load_library() + del libucx diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index b2f7d1fb74..c67170342f 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -327,11 +327,15 @@ def get_ucx(dask_worker=None): (Note: if called by client.run(), this is supplied by Dask and not the client) """ + protocol = ( + "ucxx" if dask_worker._protocol.split("://")[0] == "ucxx" else "ucx" + ) + raft_comm_state = get_raft_comm_state( sessionId="ucp", state_object=dask_worker ) if "ucx" not in raft_comm_state: - raft_comm_state["ucx"] = UCX.get() + raft_comm_state["ucx"] = UCX.get(protocol=protocol) return raft_comm_state["ucx"] @@ -535,7 +539,9 @@ def _func_build_handle_p2p( if verbose: dask_worker.log_event(topic="info", msg="Building p2p handle.") - ucp_worker = get_ucx(dask_worker).get_worker() + ucx = get_ucx(dask_worker) + is_ucxx = ucx._protocol == "ucxx" + ucx_worker = ucx.get_worker() raft_comm_state = get_raft_comm_state( sessionId=sessionId, state_object=dask_worker ) @@ -550,7 +556,14 @@ def _func_build_handle_p2p( dask_worker.log_event(topic="info", msg="Injecting comms on handle.") inject_comms_on_handle( - handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose + handle, + nccl_comm, + is_ucxx, + ucx_worker, + eps, + nWorkers, + workerId, + verbose, ) if verbose: diff --git a/python/raft-dask/raft_dask/common/comms_utils.pyx b/python/raft-dask/raft_dask/common/comms_utils.pyx index 768ba0e422..2d4d2cc83b 100644 --- a/python/raft-dask/raft_dask/common/comms_utils.pyx +++ b/python/raft-dask/raft_dask/common/comms_utils.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,6 +41,7 @@ cdef extern from "raft/comms/std_comms.hpp" namespace "raft::comms": void build_comms_nccl_ucx(device_resources *handle, ncclComm_t comm, + bint is_ucxx, void *ucp_worker, void *eps, int size, @@ -285,7 +286,7 @@ def inject_comms_on_handle_coll_only(handle, nccl_inst, size, rank, verbose): rank) -def inject_comms_on_handle(handle, nccl_inst, ucp_worker, eps, size, +def inject_comms_on_handle(handle, nccl_inst, is_ucxx, ucp_worker, eps, size, rank, verbose): """ Given a handle and initialized comms, creates a comms_t instance @@ -308,7 +309,10 @@ def inject_comms_on_handle(handle, nccl_inst, ucp_worker, eps, size, for i in range(len(eps)): if eps[i] is not None: - ep_st = eps[i].get_ucp_endpoint() + if is_ucxx: + ep_st = eps[i].ucxx_endpoint + else: + ep_st = eps[i].get_ucp_endpoint() ucp_eps[i] = ep_st else: ucp_eps[i] = 0 @@ -323,6 +327,7 @@ def inject_comms_on_handle(handle, nccl_inst, ucp_worker, eps, size, build_comms_nccl_ucx(handle_, deref(nccl_comm_), + is_ucxx, ucp_worker_st, ucp_eps, size, diff --git a/python/raft-dask/raft_dask/common/ucx.py b/python/raft-dask/raft_dask/common/ucx.py index eb246853f4..423e6f4692 100644 --- a/python/raft-dask/raft_dask/common/ucx.py +++ b/python/raft-dask/raft_dask/common/ucx.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,6 @@ # limitations under the License. # -import ucp - async def _connection_func(ep): UCX.get().add_server_endpoint(ep) @@ -29,10 +27,20 @@ class UCX: __instance = None - def __init__(self, listener_callback): + def __init__(self, listener_callback, protocol): self.listener_callback = listener_callback + self._protocol = protocol + if self._protocol == "ucxx": + import ucxx + + self.ucx_api = ucxx + else: + import ucp + + self.ucx_api = ucp + self._create_listener() self._endpoints = {} self._server_endpoints = [] @@ -42,22 +50,28 @@ def __init__(self, listener_callback): UCX.__instance = self @staticmethod - def get(listener_callback=_connection_func): + def get(listener_callback=_connection_func, protocol="ucx"): if UCX.__instance is None: - UCX(listener_callback) + UCX(listener_callback, protocol) return UCX.__instance + def get_protocol(self): + return self._protocol + def get_worker(self): - return ucp.get_ucp_worker() + if self._protocol == "ucxx": + return self.ucx_api.get_ucxx_worker() + else: + return self.ucx_api.get_ucp_worker() def _create_listener(self): - self._listener = ucp.create_listener(self.listener_callback) + self._listener = self.ucx_api.create_listener(self.listener_callback) def listener_port(self): return self._listener.port async def _create_endpoint(self, ip, port): - ep = await ucp.create_endpoint(ip, port) + ep = await self.ucx_api.create_endpoint(ip, port) self._endpoints[(ip, port)] = ep return ep diff --git a/python/raft-dask/raft_dask/test/conftest.py b/python/raft-dask/raft_dask/test/conftest.py index d1baa684d4..a60e4d995f 100644 --- a/python/raft-dask/raft_dask/test/conftest.py +++ b/python/raft-dask/raft_dask/test/conftest.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. import os @@ -34,6 +34,21 @@ def ucx_cluster(): cluster.close() +@pytest.fixture(scope="session") +def ucxx_cluster(): + pytest.importorskip("distributed_ucxx") + + scheduler_file = os.environ.get("SCHEDULER_FILE") + if scheduler_file: + yield scheduler_file + else: + cluster = LocalCUDACluster( + protocol="ucxx", + ) + yield cluster + cluster.close() + + @pytest.fixture(scope="session") def client(cluster): client = create_client(cluster) @@ -48,6 +63,13 @@ def ucx_client(ucx_cluster): client.close() +@pytest.fixture() +def ucxx_client(ucxx_cluster): + client = create_client(ucxx_cluster) + yield client + client.close() + + def create_client(cluster): """ Create a Dask distributed client for a specified cluster. @@ -69,3 +91,43 @@ def create_client(cluster): return Client(cluster) else: return Client(scheduler_file=cluster) + + +def pytest_addoption(parser): + group = parser.getgroup("Dask RAFT Custom Options") + + group.addoption( + "--run_ucx", action="store_true", help="run _only_ UCX-Py tests" + ) + + group.addoption( + "--run_ucxx", action="store_true", help="run _only_ UCXX tests" + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--run_ucx"): + skip_others = pytest.mark.skip( + reason="only runs when --run_ucx is not specified" + ) + for item in items: + if "ucx" not in item.keywords: + item.add_marker(skip_others) + else: + skip_ucx = pytest.mark.skip(reason="requires --run_ucx to run") + for item in items: + if "ucx" in item.keywords: + item.add_marker(skip_ucx) + + if config.getoption("--run_ucxx"): + skip_others = pytest.mark.skip( + reason="only runs when --run_ucxx is not specified" + ) + for item in items: + if "ucxx" not in item.keywords: + item.add_marker(skip_others) + else: + skip_ucxx = pytest.mark.skip(reason="requires --run_ucxx to run") + for item in items: + if "ucxx" in item.keywords: + item.add_marker(skip_ucxx) diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index b62d7185b2..109dd12b5e 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -66,6 +66,10 @@ def create_client(cluster): return Client(scheduler_file=cluster) +def _get_client(dask_client, request): + return request.getfixturevalue(dask_client) + + def test_comms_init_no_p2p(cluster): client = create_client(cluster) try: @@ -179,8 +183,7 @@ def _has_handle(sessionId): functions = [None] -@pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) -def test_nccl_root_placement(client, root_location): +def _test_nccl_root_placement(client, root_location): cb = None try: @@ -214,10 +217,31 @@ def test_nccl_root_placement(client, root_location): cb.destroy() -@pytest.mark.parametrize("func", functions) @pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) @pytest.mark.nccl -def test_collectives(client, func, root_location): +def test_nccl_root_placement(root_location, request): + _test_nccl_root_placement(_get_client("client", request), root_location) + + +@pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) +@pytest.mark.nccl +@pytest.mark.ucx +def test_nccl_root_placement_ucx(root_location, request): + _test_nccl_root_placement( + _get_client("ucx_client", request), root_location + ) + + +@pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) +@pytest.mark.nccl +@pytest.mark.ucxx +def test_nccl_root_placement_ucxx(root_location, request): + _test_nccl_root_placement( + _get_client("ucxx_client", request), root_location + ) + + +def _test_collectives(client, func, root_location): try: cb = Comms( @@ -246,8 +270,30 @@ def test_collectives(client, func, root_location): cb.destroy() +@pytest.mark.parametrize("func", functions) +@pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) @pytest.mark.nccl -def test_comm_split(client): +def test_collectives(func, root_location, request): + _test_collectives(_get_client("client", request), func, root_location) + + +@pytest.mark.parametrize("func", functions) +@pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) +@pytest.mark.nccl +@pytest.mark.ucx +def test_collectives_ucx(func, root_location, request): + _test_collectives(_get_client("ucx_client", request), func, root_location) + + +@pytest.mark.parametrize("func", functions) +@pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) +@pytest.mark.nccl +@pytest.mark.ucxx +def test_collectives_ucxx(func, root_location, request): + _test_collectives(_get_client("ucxx_client", request), func, root_location) + + +def _test_comm_split(client): cb = Comms(comms_p2p=True, verbose=True) cb.init() @@ -264,9 +310,24 @@ def test_comm_split(client): assert all([x.result() for x in dfs]) +@pytest.mark.nccl +def test_comm_split(request): + _test_comm_split(_get_client("client", request)) + + +@pytest.mark.nccl @pytest.mark.ucx -@pytest.mark.parametrize("n_trials", [1, 5]) -def test_send_recv(n_trials, client): +def test_comm_split_ucx(request): + _test_comm_split(_get_client("ucx_client", request)) + + +@pytest.mark.nccl +@pytest.mark.ucxx +def test_comm_split_ucxx(request): + _test_comm_split(_get_client("ucxx_client", request)) + + +def _test_send_recv_protocol(n_trials, client): cb = Comms(comms_p2p=True, verbose=True) cb.init() @@ -287,9 +348,24 @@ def test_send_recv(n_trials, client): assert list(map(lambda x: x.result(), dfs)) -@pytest.mark.nccl @pytest.mark.parametrize("n_trials", [1, 5]) -def test_device_send_or_recv(n_trials, client): +def test_send_recv_protocol(n_trials, request): + _test_send_recv_protocol(n_trials, _get_client("client", request)) + + +@pytest.mark.parametrize("n_trials", [1, 5]) +@pytest.mark.ucx +def test_send_recv_protocol_ucx(n_trials, request): + _test_send_recv_protocol(n_trials, _get_client("ucx_client", request)) + + +@pytest.mark.parametrize("n_trials", [1, 5]) +@pytest.mark.ucxx +def test_send_recv_protocol_ucxx(n_trials, request): + _test_send_recv_protocol(n_trials, _get_client("ucxx_client", request)) + + +def _test_device_send_or_recv(n_trials, client): cb = Comms(comms_p2p=True, verbose=True) cb.init() @@ -310,9 +386,27 @@ def test_device_send_or_recv(n_trials, client): assert list(map(lambda x: x.result(), dfs)) +@pytest.mark.parametrize("n_trials", [1, 5]) +@pytest.mark.nccl +def test_device_send_or_recv(n_trials, request): + _test_device_send_or_recv(n_trials, _get_client("client", request)) + + +@pytest.mark.parametrize("n_trials", [1, 5]) @pytest.mark.nccl +@pytest.mark.ucx +def test_device_send_or_recv_ucx(n_trials, request): + _test_device_send_or_recv(n_trials, _get_client("ucx_client", request)) + + @pytest.mark.parametrize("n_trials", [1, 5]) -def test_device_sendrecv(n_trials, client): +@pytest.mark.nccl +@pytest.mark.ucxx +def test_device_send_or_recv_ucxx(n_trials, request): + _test_device_send_or_recv(n_trials, _get_client("ucxx_client", request)) + + +def _test_device_sendrecv(n_trials, client): cb = Comms(comms_p2p=True, verbose=True) cb.init() @@ -333,9 +427,27 @@ def test_device_sendrecv(n_trials, client): assert list(map(lambda x: x.result(), dfs)) +@pytest.mark.parametrize("n_trials", [1, 5]) +@pytest.mark.nccl +def test_device_sendrecv(n_trials, request): + _test_device_sendrecv(n_trials, _get_client("client", request)) + + +@pytest.mark.parametrize("n_trials", [1, 5]) @pytest.mark.nccl +@pytest.mark.ucx +def test_device_sendrecv_ucx(n_trials, request): + _test_device_sendrecv(n_trials, _get_client("ucx_client", request)) + + @pytest.mark.parametrize("n_trials", [1, 5]) -def test_device_multicast_sendrecv(n_trials, client): +@pytest.mark.nccl +@pytest.mark.ucxx +def test_device_sendrecv_ucxx(n_trials, request): + _test_device_sendrecv(n_trials, _get_client("ucxx_client", request)) + + +def _test_device_multicast_sendrecv(n_trials, client): cb = Comms(comms_p2p=True, verbose=True) cb.init() @@ -356,6 +468,30 @@ def test_device_multicast_sendrecv(n_trials, client): assert list(map(lambda x: x.result(), dfs)) +@pytest.mark.parametrize("n_trials", [1, 5]) +@pytest.mark.nccl +def test_device_multicast_sendrecv(n_trials, request): + _test_device_multicast_sendrecv(n_trials, _get_client("client", request)) + + +@pytest.mark.parametrize("n_trials", [1, 5]) +@pytest.mark.nccl +@pytest.mark.ucx +def test_device_multicast_sendrecv_ucx(n_trials, request): + _test_device_multicast_sendrecv( + n_trials, _get_client("ucx_client", request) + ) + + +@pytest.mark.parametrize("n_trials", [1, 5]) +@pytest.mark.nccl +@pytest.mark.ucxx +def test_device_multicast_sendrecv_ucxx(n_trials, request): + _test_device_multicast_sendrecv( + n_trials, _get_client("ucxx_client", request) + ) + + @pytest.mark.nccl @pytest.mark.parametrize( "subset", [slice(-1, None), slice(1), slice(None, None, -2)]