Skip to content

Commit

Permalink
Merge pull request #18 from cjnolet/fea-015-comms_split
Browse files Browse the repository at this point in the history
[REVIEW] commSplit Implementation
  • Loading branch information
cjnolet authored Jul 29, 2020
2 parents a7f6e88 + 96d2dc7 commit b8c46c9
Show file tree
Hide file tree
Showing 11 changed files with 240 additions and 100 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## New Features
- PR #12: Spectral clustering.
- PR #7: Migrating cuml comms -> raft comms_t
- PR #18: Adding commsplit to cuml communicator
- PR #15: add exception based error handling macros
- PR #29: Add ceildiv functionality

Expand Down
20 changes: 8 additions & 12 deletions cpp/include/raft/comms/helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

#pragma once

#include <nccl.h>
#include <ucp/api/ucp.h>
#include <iostream>
#include <raft/comms/std_comms.hpp>
#include <raft/handle.hpp>
#include <raft/mr/device/buffer.hpp>

#include <nccl.h>
#include <ucp/api/ucp.h>
#include <iostream>

namespace raft {
namespace comms {

Expand All @@ -39,11 +40,8 @@ void build_comms_nccl_only(handle_t *handle, ncclComm_t nccl_comm,
int num_ranks, int rank) {
auto d_alloc = handle->get_device_allocator();
cudaStream_t stream = handle->get_stream();
comms_iface *raft_comm =
new raft::comms::std_comms(nccl_comm, num_ranks, rank, d_alloc, stream);

auto communicator =
std::make_shared<comms_t>(std::unique_ptr<comms_iface>(raft_comm));
auto communicator = std::make_shared<comms_t>(std::unique_ptr<comms_iface>(
new raft::comms::std_comms(nccl_comm, num_ranks, rank, d_alloc, stream)));
handle->set_comms(communicator);
}

Expand Down Expand Up @@ -84,11 +82,9 @@ void build_comms_nccl_ucx(handle_t *handle, ncclComm_t nccl_comm,
auto d_alloc = handle->get_device_allocator();
cudaStream_t stream = handle->get_stream();

auto *raft_comm =
auto communicator = std::make_shared<comms_t>(std::unique_ptr<comms_iface>(
new raft::comms::std_comms(nccl_comm, (ucp_worker_h)ucp_worker, eps_sp,
num_ranks, rank, d_alloc, stream);
auto communicator =
std::make_shared<comms_t>(std::unique_ptr<comms_iface>(raft_comm));
num_ranks, rank, d_alloc, stream)));
handle->set_comms(communicator);
}

Expand Down
122 changes: 101 additions & 21 deletions cpp/include/raft/comms/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,36 @@

#pragma once

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <raft/comms/comms.hpp>

#include <nccl.h>
#include <raft/comms/ucp_helper.hpp>
#include <raft/handle.hpp>
#include <raft/mr/device/buffer.hpp>

#include <raft/comms/comms.hpp>
#include <raft/error.hpp>

#include <raft/cudart_utils.h>

#include <cuda_runtime.h>

#include <ucp/api/ucp.h>
#include <ucp/api/ucp_def.h>
#include "ucp_helper.hpp"

#include <nccl.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <stdlib.h>
#include <time.h>
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <exception>
#include <memory>
#include <raft/handle.hpp>

#include <thread>

#include <cuda_runtime.h>

#include <raft/cudart_utils.h>
#include <raft/error.hpp>

namespace raft {

/**
Expand Down Expand Up @@ -164,15 +164,16 @@ class std_comms : public comms_iface {
std_comms(ncclComm_t nccl_comm, ucp_worker_h ucp_worker,
std::shared_ptr<ucp_ep_h *> eps, int num_ranks, int rank,
const std::shared_ptr<mr::device::allocator> device_allocator,
cudaStream_t stream)
cudaStream_t stream, bool subcomms_ucp = true)
: nccl_comm_(nccl_comm),
stream_(stream),
num_ranks_(num_ranks),
rank_(rank),
ucp_worker_(ucp_worker),
ucp_eps_(eps),
device_allocator_(device_allocator),
next_request_id_(0),
device_allocator_(device_allocator) {
subcomms_ucp_(subcomms_ucp),
ucp_worker_(ucp_worker),
ucp_eps_(eps) {
initialize();
};

Expand All @@ -189,6 +190,7 @@ class std_comms : public comms_iface {
stream_(stream),
num_ranks_(num_ranks),
rank_(rank),
subcomms_ucp_(false),
device_allocator_(device_allocator) {
initialize();
};
Expand All @@ -209,11 +211,86 @@ class std_comms : public comms_iface {

// Returns the rank of the calling process within this communicator, as
// supplied to the constructor and stored in rank_.
int get_rank() const { return rank_; }

/**
 * Distributes a freshly generated NCCL unique id to every rank that shares
 * `color`.
 *
 * The rank whose `key` equals `root` generates the id with ncclGetUniqueId
 * and posts one isend for each rank whose entry in `colors` equals `color`
 * (its own entry included). Every rank, the root included, then posts a
 * single irecv for the id, waits on all of its outstanding requests and
 * finally enters barrier().
 *
 * @param id       on the root, filled by ncclGetUniqueId; on every rank it is
 *                 also the receive buffer for the id
 * @param color    color of the calling rank; also used as the message tag
 * @param key      key of the calling rank within its color
 * @param root     key identifying the rank that generates the id; the same
 *                 value is passed to irecv as the source
 * @param new_size number of entries of `colors` equal to `color`
 * @param colors   color of every rank, indexed by rank in this communicator
 */
void scatter_nccluniqueid(ncclUniqueId &id, int color, int key, int root,
                          int new_size, std::vector<int> &colors) const {
  int request_idx = 0;
  std::vector<request_t> requests;

  if (key == root) {
    // root rank of new comm generates NCCL unique id and sends to other ranks
    // of color
    NCCL_TRY(ncclGetUniqueId(&id));

    // One slot per send plus a trailing slot for the receive posted below.
    // Sizing this to new_size alone let the irecv write one element past the
    // end of the vector and left that request out of waitall.
    requests.resize(new_size + 1);
    for (int i = 0; i < get_size(); i++) {
      if (colors[i] == color) {
        isend(&id, sizeof(ncclUniqueId), i, color,
              requests.data() + request_idx);
        ++request_idx;
      }
    }

    // Shrinking never reallocates; this keeps waitall from touching slots that
    // were never handed to isend/irecv if new_size over-counted.
    requests.resize(request_idx + 1);
  } else {
    // non-root ranks only have the single receive outstanding
    requests.resize(1);
  }

  // NOTE(review): `root` is compared against `key` above but is used as the
  // source rank here; these presumably only agree when the lowest key of a
  // color equals that rank's index in this communicator -- confirm against
  // irecv and the caller.
  irecv(&id, sizeof(ncclUniqueId), root, color,
        requests.data() + request_idx);

  waitall(requests.size(), requests.data());
  barrier();
}

std::unique_ptr<comms_iface> comm_split(int color, int key) const {
// Not supported by NCCL
ASSERT(false,
"ERROR: commSplit called but not yet supported in this comms "
"implementation.");
mr::device::buffer<int> colors(device_allocator_, stream_, get_size());
mr::device::buffer<int> keys(device_allocator_, stream_, get_size());

mr::device::buffer<int> color_buf(device_allocator_, stream_, 1);
mr::device::buffer<int> key_buf(device_allocator_, stream_, 1);

update_device(color_buf.data(), &color, 1, stream_);
update_device(key_buf.data(), &key, 1, stream_);

allgather(color_buf.data(), colors.data(), 1, datatype_t::INT32, stream_);
allgather(key_buf.data(), keys.data(), 1, datatype_t::INT32, stream_);
this->sync_stream(stream_);

// find all ranks with same color and lowest key of that color
std::vector<int> colors_host(get_size());
std::vector<int> keys_host(get_size());

update_host(colors_host.data(), colors.data(), get_size(), stream_);
update_host(keys_host.data(), keys.data(), get_size(), stream_);

CUDA_CHECK(cudaStreamSynchronize(stream_));

std::vector<int> ranks_with_color;
std::vector<ucp_ep_h> new_ucx_ptrs;
int min_rank = key;
for (int i = 0; i < get_size(); i++) {
if (colors_host[i] == color) {
ranks_with_color.push_back(keys_host[i]);
if (keys_host[i] < min_rank) min_rank = keys_host[i];
if (ucp_worker_ != nullptr && subcomms_ucp_)
new_ucx_ptrs.push_back((*ucp_eps_)[i]);
}
}

ncclUniqueId id;
scatter_nccluniqueid(id, color, key, min_rank, ranks_with_color.size(),
colors_host);

ncclComm_t nccl_comm;
NCCL_TRY(ncclCommInitRank(&nccl_comm, ranks_with_color.size(), id,
keys_host[get_rank()]));

std_comms *raft_comm;
if (ucp_worker_ != nullptr && subcomms_ucp_) {
auto eps_sp = std::make_shared<ucp_ep_h *>(new_ucx_ptrs.data());
return std::unique_ptr<comms_iface>(new std_comms(
nccl_comm, (ucp_worker_h)ucp_worker_, eps_sp, ranks_with_color.size(),
key, device_allocator_, stream_, subcomms_ucp_));
} else {
return std::unique_ptr<comms_iface>(new std_comms(
nccl_comm, ranks_with_color.size(), key, device_allocator_, stream_));
}
}

void barrier() const {
Expand Down Expand Up @@ -328,6 +405,7 @@ class std_comms : public comms_iface {
restart = true;

// perform cleanup
std::cout << "Freeing request" << std::endl;
ucp_handler_.free_ucp_request(req);

// remove from pending requests
Expand Down Expand Up @@ -431,6 +509,8 @@ class std_comms : public comms_iface {
int num_ranks_;
int rank_;

bool subcomms_ucp_;

comms_ucp_handler ucp_handler_;
ucp_worker_h ucp_worker_;
std::shared_ptr<ucp_ep_h *> ucp_eps_;
Expand Down
58 changes: 44 additions & 14 deletions cpp/include/raft/comms/test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@ namespace comms {
* initialized comms instance.
*/
bool test_collective_allreduce(const handle_t &handle, int root) {
const comms_t &communicator = handle.get_comms();
comms_t const &communicator = handle.get_comms();

const int send = 1;
int const send = 1;

cudaStream_t stream = handle.get_stream();

raft::mr::device::buffer<int> temp_d(handle.get_device_allocator(), stream);
temp_d.resize(1, stream);
CUDA_CHECK(
cudaMemcpyAsync(temp_d.data(), &send, 1, cudaMemcpyHostToDevice, stream));

communicator.allreduce(temp_d.data(), temp_d.data(), 1, op_t::SUM, stream);

int temp_h = 0;
CUDA_CHECK(
cudaMemcpyAsync(&temp_h, temp_d.data(), 1, cudaMemcpyDeviceToHost, stream));
Expand All @@ -61,9 +63,9 @@ bool test_collective_allreduce(const handle_t &handle, int root) {
* initialized comms instance.
*/
bool test_collective_broadcast(const handle_t &handle, int root) {
const comms_t &communicator = handle.get_comms();
comms_t const &communicator = handle.get_comms();

const int send = root;
int const send = root;

cudaStream_t stream = handle.get_stream();

Expand All @@ -89,9 +91,9 @@ bool test_collective_broadcast(const handle_t &handle, int root) {
}

bool test_collective_reduce(const handle_t &handle, int root) {
const comms_t &communicator = handle.get_comms();
comms_t const &communicator = handle.get_comms();

const int send = root;
int const send = root;

cudaStream_t stream = handle.get_stream();

Expand Down Expand Up @@ -119,9 +121,9 @@ bool test_collective_reduce(const handle_t &handle, int root) {
}

bool test_collective_allgather(const handle_t &handle, int root) {
const comms_t &communicator = handle.get_comms();
comms_t const &communicator = handle.get_comms();

const int send = root;
int const send = communicator.get_rank();

cudaStream_t stream = handle.get_stream();

Expand All @@ -138,7 +140,7 @@ bool test_collective_allgather(const handle_t &handle, int root) {
communicator.sync_stream(stream);
int
temp_h[communicator.get_size()]; // Verify more than one byte is being sent
CUDA_CHECK(cudaMemcpyAsync(&temp_h, temp_d.data(),
CUDA_CHECK(cudaMemcpyAsync(&temp_h, recv_d.data(),
sizeof(int) * communicator.get_size(),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
Expand All @@ -147,15 +149,16 @@ bool test_collective_allgather(const handle_t &handle, int root) {
std::cout << "Clique size: " << communicator.get_size() << std::endl;
std::cout << "final_size: " << temp_h << std::endl;

for (int i = 0; i < communicator.get_size(); i++)
for (int i = 0; i < communicator.get_size(); i++) {
if (temp_h[i] != i) return false;
}
return true;
}

bool test_collective_reducescatter(const handle_t &handle, int root) {
const comms_t &communicator = handle.get_comms();
comms_t const &communicator = handle.get_comms();

const int send = 1;
int const send = 1;

cudaStream_t stream = handle.get_stream();

Expand Down Expand Up @@ -190,8 +193,8 @@ bool test_collective_reducescatter(const handle_t &handle, int root) {
* @param number of iterations of all-to-all messaging to perform
*/
bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) {
const comms_t &communicator = h.get_comms();
const int rank = communicator.get_rank();
comms_t const &communicator = h.get_comms();
int const rank = communicator.get_rank();

bool ret = true;
for (int i = 0; i < numTrials; i++) {
Expand Down Expand Up @@ -246,5 +249,32 @@ bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) {

return ret;
}

/**
 * A simple test that the comms can be split into `n_colors` separate
 * subcommunicators and that an allreduce succeeds on the subcommunicator the
 * calling rank ends up in.
 *
 * @param h the raft handle to use. This is expected to already have an
 * initialized comms instance.
 * @param n_colors number of different colors to test; values outside
 * [1, communicator size] are clamped into that range
 * @return the result of test_collective_allreduce run on a fresh handle that
 * holds the subcommunicator
 */
bool test_commsplit(const handle_t &h, int n_colors) {
  comms_t const &communicator = h.get_comms();
  int const rank = communicator.get_rank();
  int const size = communicator.get_size();

  if (n_colors > size) n_colors = size;
  // A zero or negative color count would make the modulo/division below
  // undefined or meaningless; fall back to a single color.
  if (n_colors < 1) n_colors = 1;

  // first we need to assign to a color, then assign the rank within the color
  int const color = rank % n_colors;
  int const key = rank / n_colors;

  handle_t new_handle(1);
  auto shared_comm =
    std::make_shared<comms_t>(communicator.comm_split(color, key));
  new_handle.set_comms(shared_comm);

  return test_collective_allreduce(new_handle, 0);
}

} // namespace comms
}; // namespace raft
6 changes: 6 additions & 0 deletions cpp/include/raft/comms/ucp_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class comms_ucp_handler {
void free_ucp_request(ucp_request *request) const {
if (request->needs_release) {
request->req->completed = 0;
std::cout << "FREEING REQUEST" << std::endl;
(*(req_free_func))(request->req);
}
free(request);
Expand All @@ -179,6 +180,9 @@ class comms_ucp_handler {
ucs_status_ptr_t send_result = (*(send_func))(
ep_ptr, buf, size, ucp_dt_make_contig(1), ucp_tag, send_callback);
struct ucx_context *ucp_req = (struct ucx_context *)send_result;

std::cout << "REQ: " << ucp_req << std::endl;

if (UCS_PTR_IS_ERR(send_result)) {
ASSERT(!UCS_PTR_IS_ERR(send_result),
"unable to send UCX data message (%d)\n",
Expand Down Expand Up @@ -222,6 +226,8 @@ class comms_ucp_handler {
req->is_send_request = false;
req->other_rank = sender_rank;

std::cout << "REQ: " << ucp_req << std::endl;

ASSERT(!UCS_PTR_IS_ERR(recv_result),
"unable to receive UCX data message (%d)\n",
UCS_PTR_STATUS(recv_result));
Expand Down
Loading

0 comments on commit b8c46c9

Please sign in to comment.