From 06ac713c5e5700185abe28fbc261c84e2b7165a8 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang <45857425+seunghwak@users.noreply.github.com> Date: Thu, 25 Feb 2021 17:08:15 -0500 Subject: [PATCH] Matching updates for RAFT comms updates (device_sendrecv, device_multicast_sendrecv, gather, gatherv) (#1391) - [x] Update cuGraph to use RAFT::comms_t's newly added device_sendrecv & device_multicast_sendrecv) - [x] Update cuGraph to use RAFT::comms_t's newly added gather & gatherv - [x] Update RAFT git tag once https://github.com/rapidsai/raft/pull/114 (currently merged in 0.18 but is not merged to 0.19) and https://github.com/rapidsai/raft/pull/144 are merged to 0.19 Ready for review but cannot be merged till RAFT PR 114 and 144 are merged to RAFT branch-0.19. Authors: - Seunghwa Kang (@seunghwak) Approvers: - Alex Fender (@afender) URL: https://github.com/rapidsai/cugraph/pull/1391 --- cpp/CMakeLists.txt | 2 +- cpp/include/utilities/device_comm.cuh | 55 +++++++--------------- cpp/include/utilities/host_scalar_comm.cuh | 10 ---- 3 files changed, 17 insertions(+), 50 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index b2d537edaa2..d211fe9ed5a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -298,7 +298,7 @@ else(DEFINED ENV{RAFT_PATH}) FetchContent_Declare( raft GIT_REPOSITORY https://github.com/rapidsai/raft.git - GIT_TAG 4a79adcb0c0e87964dcdc9b9122f242b5235b702 + GIT_TAG a3461b201ea1c9f61571f1927274f739e775d2d2 SOURCE_SUBDIR raft ) diff --git a/cpp/include/utilities/device_comm.cuh b/cpp/include/utilities/device_comm.cuh index 8c3b0f86a47..24b9147ce3c 100644 --- a/cpp/include/utilities/device_comm.cuh +++ b/cpp/include/utilities/device_comm.cuh @@ -196,21 +196,13 @@ device_sendrecv_impl(raft::comms::comms_t const& comm, using value_type = typename std::iterator_traits::value_type; static_assert( std::is_same::value_type, value_type>::value); - // ncclSend/ncclRecv pair needs to be located inside ncclGroupStart/ncclGroupEnd to avoid deadlock - ncclGroupStart(); - ncclSend(iter_to_raw_ptr(input_first), - tx_count * sizeof(value_type), - ncclUint8, - dst, - comm.get_nccl_comm(), - stream); - ncclRecv(iter_to_raw_ptr(output_first), - rx_count * sizeof(value_type), - ncclUint8, - src, - comm.get_nccl_comm(), - stream); - ncclGroupEnd(); + comm.device_sendrecv(iter_to_raw_ptr(input_first), + tx_count, + dst, + iter_to_raw_ptr(output_first), + rx_count, + src, + stream); } template @@ -288,25 +280,15 @@ device_multicast_sendrecv_impl(raft::comms::comms_t const& comm, using value_type = typename std::iterator_traits::value_type; static_assert( std::is_same::value_type, value_type>::value); - // ncclSend/ncclRecv pair needs to be located inside ncclGroupStart/ncclGroupEnd to avoid deadlock - ncclGroupStart(); - for (size_t i = 0; i < tx_counts.size(); ++i) { - ncclSend(iter_to_raw_ptr(input_first + tx_offsets[i]), - tx_counts[i] * sizeof(value_type), - ncclUint8, - tx_dst_ranks[i], - comm.get_nccl_comm(), - stream); - } - for (size_t i = 0; i < rx_counts.size(); ++i) { - ncclRecv(iter_to_raw_ptr(output_first + rx_offsets[i]), - rx_counts[i] * sizeof(value_type), - ncclUint8, - rx_src_ranks[i], - comm.get_nccl_comm(), - stream); - } - ncclGroupEnd(); + comm.device_multicast_sendrecv(iter_to_raw_ptr(input_first), + tx_counts, + tx_offsets, + tx_dst_ranks, + iter_to_raw_ptr(output_first), + rx_counts, + rx_offsets, + rx_src_ranks, + stream); } template @@ -589,10 +571,6 @@ device_gatherv_impl(raft::comms::comms_t const& comm, { static_assert(std::is_same::value_type, typename std::iterator_traits::value_type>::value); - // FIXME: should be enabled once the RAFT gather & gatherv PR is merged -#if 1 - CUGRAPH_FAIL("Unimplemented."); -#else comm.gatherv(iter_to_raw_ptr(input_first), iter_to_raw_ptr(output_first), sendcount, @@ -600,7 +578,6 @@ device_gatherv_impl(raft::comms::comms_t const& comm, displacements.data(), root, stream); -#endif } template diff --git a/cpp/include/utilities/host_scalar_comm.cuh b/cpp/include/utilities/host_scalar_comm.cuh index dda0ce1f091..2ecfd913813 100644 --- a/cpp/include/utilities/host_scalar_comm.cuh +++ b/cpp/include/utilities/host_scalar_comm.cuh @@ -321,16 +321,11 @@ std::enable_if_t::value, std::vector> host_scalar_gathe &input, 1, stream); - // FIXME: should be enabled once the RAFT gather & gatherv PR is merged -#if 1 - CUGRAPH_FAIL("Unimplemented."); -#else comm.gather(comm.get_rank() == root ? d_outputs.data() + comm.get_rank() : d_outputs.data(), d_outputs.data(), size_t{1}, root, stream); -#endif std::vector h_outputs(comm.get_rank() == root ? comm.get_size() : 0); if (comm.get_rank() == root) { raft::update_host(h_outputs.data(), d_outputs.data(), comm.get_size(), stream); @@ -358,10 +353,6 @@ host_scalar_gather(raft::comms::comms_t const& comm, T input, int root, cudaStre h_tuple_scalar_elements.data(), tuple_size, stream); - // FIXME: should be enabled once the RAFT gather & gatherv PR is merged -#if 1 - CUGRAPH_FAIL("Unimplemented."); -#else comm.gather(comm.get_rank() == root ? d_gathered_tuple_scalar_elements.data() + comm.get_rank() * tuple_size : d_gathered_tuple_scalar_elements.data(), @@ -369,7 +360,6 @@ host_scalar_gather(raft::comms::comms_t const& comm, T input, int root, cudaStre tuple_size, root, stream); -#endif std::vector h_gathered_tuple_scalar_elements( comm.get_rank() == root ? comm.get_size() * tuple_size : size_t{0}); if (comm.get_rank() == root) {