Matching updates for RAFT comms updates (device_sendrecv, device_multicast_sendrecv, gather, gatherv) (#1391)

- [x] Update cuGraph to use RAFT::comms_t's newly added device_sendrecv & device_multicast_sendrecv
- [x] Update cuGraph to use RAFT::comms_t's newly added gather & gatherv
- [x] Update the RAFT git tag once rapidsai/raft#114 (already merged to 0.18 but not yet to 0.19) and rapidsai/raft#144 are merged to 0.19

Ready for review, but this cannot be merged until RAFT PRs 114 and 144 are merged to RAFT branch-0.19.
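
As a rough illustration of what the first two checklist items mean for callers (a sketch only, based on the raft::comms::comms_t calls visible in the diff below; the include path and the function/variable names here are assumptions, not cuGraph code):

```cpp
// Sketch only: the hand-rolled ncclGroupStart()/ncclSend()/ncclRecv()/ncclGroupEnd()
// sequence becomes a single RAFT wrapper call. Argument order follows the calls in
// this commit's diff; the raft/comms/comms.hpp include path is an assumption.
#include <raft/comms/comms.hpp>

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

void exchange_with_peer(raft::comms::comms_t const& comm,
                        int32_t* d_send, std::size_t tx_count, int dst,
                        int32_t* d_recv, std::size_t rx_count, int src,
                        cudaStream_t stream)
{
  // One call performs the paired send/receive; the NCCL grouping previously
  // needed in cuGraph to avoid deadlock now happens inside RAFT.
  comm.device_sendrecv(d_send, tx_count, dst, d_recv, rx_count, src, stream);
}
```

device_multicast_sendrecv follows the same pattern with per-peer count/offset/rank vectors, and the newly added gather/gatherv let the previously stubbed (`CUGRAPH_FAIL("Unimplemented.")`) code paths below be enabled.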

Authors:
  - Seunghwa Kang (@seunghwak)

Approvers:
  - Alex Fender (@afender)

URL: #1391
seunghwak authored Feb 25, 2021
1 parent 89bffa5 commit 06ac713
Showing 3 changed files with 17 additions and 50 deletions.
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
@@ -298,7 +298,7 @@ else(DEFINED ENV{RAFT_PATH})
   FetchContent_Declare(
     raft
     GIT_REPOSITORY https://github.com/rapidsai/raft.git
-    GIT_TAG 4a79adcb0c0e87964dcdc9b9122f242b5235b702
+    GIT_TAG a3461b201ea1c9f61571f1927274f739e775d2d2
     SOURCE_SUBDIR raft
   )

55 changes: 16 additions & 39 deletions cpp/include/utilities/device_comm.cuh
@@ -196,21 +196,13 @@ device_sendrecv_impl(raft::comms::comms_t const& comm,
   using value_type = typename std::iterator_traits<InputIterator>::value_type;
   static_assert(
     std::is_same<typename std::iterator_traits<OutputIterator>::value_type, value_type>::value);
-  // ncclSend/ncclRecv pair needs to be located inside ncclGroupStart/ncclGroupEnd to avoid deadlock
-  ncclGroupStart();
-  ncclSend(iter_to_raw_ptr(input_first),
-           tx_count * sizeof(value_type),
-           ncclUint8,
-           dst,
-           comm.get_nccl_comm(),
-           stream);
-  ncclRecv(iter_to_raw_ptr(output_first),
-           rx_count * sizeof(value_type),
-           ncclUint8,
-           src,
-           comm.get_nccl_comm(),
-           stream);
-  ncclGroupEnd();
+  comm.device_sendrecv(iter_to_raw_ptr(input_first),
+                       tx_count,
+                       dst,
+                       iter_to_raw_ptr(output_first),
+                       rx_count,
+                       src,
+                       stream);
 }

 template <typename InputIterator, typename OutputIterator, size_t I, size_t N>
@@ -288,25 +280,15 @@ device_multicast_sendrecv_impl(raft::comms::comms_t const& comm,
   using value_type = typename std::iterator_traits<InputIterator>::value_type;
   static_assert(
     std::is_same<typename std::iterator_traits<OutputIterator>::value_type, value_type>::value);
-  // ncclSend/ncclRecv pair needs to be located inside ncclGroupStart/ncclGroupEnd to avoid deadlock
-  ncclGroupStart();
-  for (size_t i = 0; i < tx_counts.size(); ++i) {
-    ncclSend(iter_to_raw_ptr(input_first + tx_offsets[i]),
-             tx_counts[i] * sizeof(value_type),
-             ncclUint8,
-             tx_dst_ranks[i],
-             comm.get_nccl_comm(),
-             stream);
-  }
-  for (size_t i = 0; i < rx_counts.size(); ++i) {
-    ncclRecv(iter_to_raw_ptr(output_first + rx_offsets[i]),
-             rx_counts[i] * sizeof(value_type),
-             ncclUint8,
-             rx_src_ranks[i],
-             comm.get_nccl_comm(),
-             stream);
-  }
-  ncclGroupEnd();
+  comm.device_multicast_sendrecv(iter_to_raw_ptr(input_first),
+                                 tx_counts,
+                                 tx_offsets,
+                                 tx_dst_ranks,
+                                 iter_to_raw_ptr(output_first),
+                                 rx_counts,
+                                 rx_offsets,
+                                 rx_src_ranks,
+                                 stream);
 }

 template <typename InputIterator, typename OutputIterator, size_t I, size_t N>
@@ -589,18 +571,13 @@ device_gatherv_impl(raft::comms::comms_t const& comm,
 {
   static_assert(std::is_same<typename std::iterator_traits<InputIterator>::value_type,
                              typename std::iterator_traits<OutputIterator>::value_type>::value);
-  // FIXME: should be enabled once the RAFT gather & gatherv PR is merged
-#if 1
-  CUGRAPH_FAIL("Unimplemented.");
-#else
   comm.gatherv(iter_to_raw_ptr(input_first),
                iter_to_raw_ptr(output_first),
                sendcount,
                recvcounts.data(),
                displacements.data(),
                root,
                stream);
-#endif
 }

 template <typename InputIterator, typename OutputIterator, size_t I, size_t N>
10 changes: 0 additions & 10 deletions cpp/include/utilities/host_scalar_comm.cuh
@@ -321,16 +321,11 @@ std::enable_if_t<std::is_arithmetic<T>::value, std::vector<T>> host_scalar_gathe
                       &input,
                       1,
                       stream);
-  // FIXME: should be enabled once the RAFT gather & gatherv PR is merged
-#if 1
-  CUGRAPH_FAIL("Unimplemented.");
-#else
   comm.gather(comm.get_rank() == root ? d_outputs.data() + comm.get_rank() : d_outputs.data(),
               d_outputs.data(),
               size_t{1},
               root,
               stream);
-#endif
   std::vector<T> h_outputs(comm.get_rank() == root ? comm.get_size() : 0);
   if (comm.get_rank() == root) {
     raft::update_host(h_outputs.data(), d_outputs.data(), comm.get_size(), stream);
@@ -358,18 +353,13 @@ host_scalar_gather(raft::comms::comms_t const& comm, T input, int root, cudaStre
                       h_tuple_scalar_elements.data(),
                       tuple_size,
                       stream);
-  // FIXME: should be enabled once the RAFT gather & gatherv PR is merged
-#if 1
-  CUGRAPH_FAIL("Unimplemented.");
-#else
   comm.gather(comm.get_rank() == root
                 ? d_gathered_tuple_scalar_elements.data() + comm.get_rank() * tuple_size
                 : d_gathered_tuple_scalar_elements.data(),
               d_gathered_tuple_scalar_elements.data(),
               tuple_size,
               root,
               stream);
-#endif
   std::vector<int64_t> h_gathered_tuple_scalar_elements(
     comm.get_rank() == root ? comm.get_size() * tuple_size : size_t{0});
   if (comm.get_rank() == root) {
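
For context on the host_scalar_comm.cuh changes above, here is a minimal usage sketch of the now-functional gather path. The wrapper's signature and return behavior come from the hunks above; the include path, namespace handling, and surrounding setup are assumptions:

```cpp
// Sketch only: gather one host scalar from every rank to `root` using the
// host_scalar_gather wrapper shown above. `comm` is assumed to be an already
// initialized raft::comms::comms_t; the include path mirrors the file location
// in this diff, and namespace qualification is omitted for brevity.
#include <utilities/host_scalar_comm.cuh>

#include <cstddef>

void gather_scalar_example(raft::comms::comms_t const& comm, cudaStream_t stream)
{
  int const root = 0;
  std::size_t my_value = 42;  // placeholder per-rank scalar

  // On the root rank the returned vector has comm.get_size() elements
  // (gathered[i] comes from rank i); on every other rank it is empty,
  // matching the h_outputs sizing visible in the hunk above.
  auto gathered = host_scalar_gather(comm, my_value, root, stream);

  if (comm.get_rank() == root) {
    std::size_t total = 0;
    for (auto v : gathered) { total += v; }  // e.g. aggregate the per-rank values
  }
}
```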
