Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] commSplit Implementation #18

Merged
merged 10 commits into from
Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## New Features
- PR #7: Migrating cuml comms -> raft comms_t
- PR #18: Adding commsplit to cuml communicator

## Improvements
- PR #13: Add RMM_INCLUDE and RMM_LIBRARY options to allow linking to non-conda RMM
Expand Down
4 changes: 0 additions & 4 deletions cpp/include/raft/comms/comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ constexpr datatype_t get_type<double>() {

class comms_iface {
public:
virtual ~comms_iface();

virtual int get_size() const = 0;
virtual int get_rank() const = 0;

Expand Down Expand Up @@ -323,7 +321,5 @@ class comms_t {
std::unique_ptr<comms_iface> impl_;
};

comms_iface::~comms_iface() {}

} // namespace comms
} // namespace raft
89 changes: 84 additions & 5 deletions cpp/include/raft/comms/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#include <ucp/api/ucp.h>
#include <ucp/api/ucp_def.h>
#include "ucp_helper.hpp"
#include <raft/comms/ucp_helper.hpp>
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

#include <nccl.h>

Expand All @@ -38,6 +38,7 @@
#include <exception>
#include <memory>
#include <raft/handle.hpp>
#include <raft/mr/device/buffer.hpp>

#include <thread>

Expand Down Expand Up @@ -188,11 +189,87 @@ class std_comms : public comms_iface {

int get_rank() const { return rank_; }

void scatter_nccluniqueid(ncclUniqueId &id, int color, int key, int root,
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
int new_size, std::vector<int> &colors) const {
// root rank of new comm generates NCCL unique id and sends to other ranks of color
int request_idx = 0;
std::vector<request_t> requests;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

if (key == root) {
NCCL_CHECK(ncclGetUniqueId(&id));
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
requests.resize(new_size);
for (int i = 0; i < get_size(); i++) {
if (colors.data()[i] == color) {
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
isend(&id.internal, 128, i, color, requests.data() + request_idx);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
++request_idx;
}
}
} else {
requests.resize(1);
}

// non-root ranks of new comm recv unique id
irecv(&id.internal, 128, root, color, requests.data() + request_idx);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another hard-coded 128 here


waitall(requests.size(), requests.data());
barrier();
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
}

std::unique_ptr<comms_iface> comm_split(int color, int key) const {
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
// Not supported by NCCL
ASSERT(false,
"ERROR: commSplit called but not yet supported in this comms "
"implementation.");
mr::device::buffer<int> colors(device_allocator_, stream_, get_size());
mr::device::buffer<int> keys(device_allocator_, stream_, get_size());

mr::device::buffer<int> color_buf(device_allocator_, stream_, 1);
mr::device::buffer<int> key_buf(device_allocator_, stream_, 1);

update_device(color_buf.data(), &color, 1, stream_);
update_device(key_buf.data(), &key, 1, stream_);

allgather(color_buf.data(), colors.data(), 1, datatype_t::INT32, stream_);
allgather(key_buf.data(), keys.data(), 1, datatype_t::INT32, stream_);
this->sync_stream(stream_);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

// find all ranks with same color and lowest key of that color
std::vector<int> colors_host(get_size());
std::vector<int> keys_host(get_size());

update_host(colors_host.data(), colors.data(), get_size(), stream_);
update_host(keys_host.data(), keys.data(), get_size(), stream_);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

CUDA_CHECK(cudaStreamSynchronize(stream_));

std::vector<int> ranks_with_color;
std::vector<ucp_ep_h> new_ucx_ptrs;
int min_rank = key;
for (int i = 0; i < get_size(); i++) {
if (colors_host[i] == color) {
ranks_with_color.push_back(keys_host[i]);
if (keys_host[i] < min_rank) min_rank = keys_host[i];
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

if (ucp_worker_ != nullptr) new_ucx_ptrs.push_back((*ucp_eps_)[i]);
}
}

ncclUniqueId id;
scatter_nccluniqueid(id, color, key, min_rank, ranks_with_color.size(),
colors_host);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

ncclComm_t nccl_comm;
NCCL_CHECK(ncclCommInitRank(&nccl_comm, ranks_with_color.size(), id,
keys_host[get_rank()]));

std_comms *raft_comm;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
if (ucp_worker_ != nullptr) {
auto eps_sp = std::make_shared<ucp_ep_h *>(new_ucx_ptrs.data());
raft_comm =
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
new std_comms(nccl_comm, (ucp_worker_h)ucp_worker_, eps_sp,
ranks_with_color.size(), key, device_allocator_, stream_);
} else {
raft_comm = new std_comms(nccl_comm, ranks_with_color.size(), key,
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
device_allocator_, stream_);
}

return std::unique_ptr<comms_iface>(raft_comm);
}

void barrier() const {
Expand Down Expand Up @@ -307,6 +384,7 @@ class std_comms : public comms_iface {
restart = true;

// perform cleanup
std::cout << "Freeing request" << std::endl;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
ucp_handler_.free_ucp_request(req);

// remove from pending requests
Expand All @@ -324,6 +402,7 @@ class std_comms : public comms_iface {

void allreduce(const void *sendbuff, void *recvbuff, size_t count,
datatype_t datatype, op_t op, cudaStream_t stream) const {
std::cout << "Inside allreduce" << std::endl;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
NCCL_CHECK(ncclAllReduce(sendbuff, recvbuff, count,
get_nccl_datatype(datatype), get_nccl_op(op),
nccl_comm_, stream));
Expand Down
37 changes: 34 additions & 3 deletions cpp/include/raft/comms/test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ bool test_collective_allreduce(const handle_t &handle, int root) {
temp_d.resize(1, stream);
CUDA_CHECK(
cudaMemcpyAsync(temp_d.data(), &send, 1, cudaMemcpyHostToDevice, stream));

communicator.allreduce(temp_d.data(), temp_d.data(), 1, op_t::SUM, stream);

int temp_h = 0;
CUDA_CHECK(
cudaMemcpyAsync(&temp_h, temp_d.data(), 1, cudaMemcpyDeviceToHost, stream));
Expand Down Expand Up @@ -121,7 +123,7 @@ bool test_collective_reduce(const handle_t &handle, int root) {
bool test_collective_allgather(const handle_t &handle, int root) {
const comms_t &communicator = handle.get_comms();

const int send = root;
const int send = communicator.get_rank();

cudaStream_t stream = handle.get_stream();

Expand All @@ -138,7 +140,7 @@ bool test_collective_allgather(const handle_t &handle, int root) {
communicator.sync_stream(stream);
int
temp_h[communicator.get_size()]; // Verify more than one byte is being sent
CUDA_CHECK(cudaMemcpyAsync(&temp_h, temp_d.data(),
CUDA_CHECK(cudaMemcpyAsync(&temp_h, recv_d.data(),
sizeof(int) * communicator.get_size(),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
Expand All @@ -147,8 +149,9 @@ bool test_collective_allgather(const handle_t &handle, int root) {
std::cout << "Clique size: " << communicator.get_size() << std::endl;
std::cout << "final_size: " << temp_h << std::endl;

for (int i = 0; i < communicator.get_size(); i++)
for (int i = 0; i < communicator.get_size(); i++) {
if (temp_h[i] != i) return false;
}
return true;
}

Expand Down Expand Up @@ -246,5 +249,33 @@ bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) {

return ret;
}

/**
* A simple test that the comms can be split into 2 separate subcommunicators
*
* @param h the raft handle to use. This is expected to already have an
* initialized comms instance.
*/
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
bool test_commsplit(const handle_t &h, int n_colors) {
const comms_t &communicator = h.get_comms();
const int rank = communicator.get_rank();
const int size = communicator.get_size();
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

if (n_colors > size) n_colors = size;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

int n_ranks_per_color = communicator.get_size() / n_colors;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

// first we need to assign to a color, then assign the rank within the color
int color = rank % n_colors;
int key = rank % n_ranks_per_color;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

handle_t new_handle(1);
auto shared_comm =
std::make_shared<comms_t>(communicator.comm_split(color, key));
new_handle.set_comms(shared_comm);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

return test_collective_allreduce(new_handle, 0);
}

} // namespace comms
}; // namespace raft
6 changes: 6 additions & 0 deletions cpp/include/raft/comms/ucp_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class comms_ucp_handler {
void free_ucp_request(ucp_request *request) const {
if (request->needs_release) {
request->req->completed = 0;
std::cout << "FREEING REQUEST" << std::endl;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
(*(req_free_func))(request->req);
}
free(request);
Expand All @@ -179,6 +180,9 @@ class comms_ucp_handler {
ucs_status_ptr_t send_result = (*(send_func))(
ep_ptr, buf, size, ucp_dt_make_contig(1), ucp_tag, send_callback);
struct ucx_context *ucp_req = (struct ucx_context *)send_result;

std::cout << "REQ: " << ucp_req << std::endl;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

if (UCS_PTR_IS_ERR(send_result)) {
ASSERT(!UCS_PTR_IS_ERR(send_result),
"unable to send UCX data message (%d)\n",
Expand Down Expand Up @@ -222,6 +226,8 @@ class comms_ucp_handler {
req->is_send_request = false;
req->other_rank = sender_rank;

std::cout << "REQ: " << ucp_req << std::endl;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

ASSERT(!UCS_PTR_IS_ERR(recv_result),
"unable to receive UCX data message (%d)\n",
UCS_PTR_STATUS(recv_result));
Expand Down
3 changes: 3 additions & 0 deletions python/raft/dask/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@
from .comms_utils import perform_test_comms_bcast
from .comms_utils import perform_test_comms_reduce
from .comms_utils import perform_test_comms_reducescatter
from .comms_utils import perform_test_comm_split

from .ucx import UCX
13 changes: 4 additions & 9 deletions python/raft/dask/common/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,21 +375,16 @@ async def _func_ucp_create_endpoints(sessionId, worker_info):
worker_info : dict
Maps worker addresses to NCCL ranks & UCX ports
"""
dask_worker = get_worker()
local_address = dask_worker.address

eps = [None] * len(worker_info)
count = 1

for k in worker_info:
if str(k) != str(local_address):

ip, port = parse_host_port(k)
ip, port = parse_host_port(k)

ep = await get_ucx().get_endpoint(ip, worker_info[k]["port"])
ep = await get_ucx().get_endpoint(ip, worker_info[k]["port"])

eps[worker_info[k]["rank"]] = ep
count += 1
eps[worker_info[k]["rank"]] = ep
count += 1

worker_state(sessionId)["ucp_eps"] = eps

Expand Down
15 changes: 14 additions & 1 deletion python/raft/dask/common/comms_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ cdef extern from "nccl.h":
cdef struct ncclComm
ctypedef ncclComm *ncclComm_t


cdef extern from "raft/handle.hpp" namespace "raft":
cdef cppclass handle_t:
handle_t() except +
Expand Down Expand Up @@ -64,6 +63,7 @@ cdef extern from "raft/comms/test.hpp" namespace "raft::comms":
bool test_collective_reducescatter(const handle_t &h, int root) except +
bool test_pointToPoint_simple_send_recv(const handle_t &h,
int numTrials) except +
bool test_commsplit(const handle_t &h, int n_colors) except +


def perform_test_comms_allreduce(handle, root):
Expand Down Expand Up @@ -144,6 +144,19 @@ def perform_test_comms_send_recv(handle, n_trials):
return test_pointToPoint_simple_send_recv(deref(h), <int>n_trials)


def perform_test_comm_split(handle, n_colors):
"""
Performs a comm split on the current worker

Parameters
----------
handle : raft.common.Handle
handle containing comms_t to use
"""
cdef const handle_t * h = < handle_t * > < size_t > handle.getHandle()
return test_commsplit(deref(h), < int > n_colors)


def inject_comms_on_handle_coll_only(handle, nccl_inst, size, rank, verbose):
"""
Given a handle and initialized nccl comm, creates a comms_t
Expand Down
4 changes: 4 additions & 0 deletions python/raft/dask/common/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ async def get_endpoint(self, ip, port):

return ep

async def close_endpoints(self):
for k, ep in self._endpoints.items():
await ep.close()

def __del__(self):
for ip_port, ep in self._endpoints.items():
if not ep.closed():
Expand Down
13 changes: 8 additions & 5 deletions python/raft/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@
from dask_cuda import initialize
from dask_cuda import LocalCUDACluster

import os
os.environ["UCX_LOG_LEVEL"] = "error"


enable_tcp_over_ucx = True
enable_nvlink = False
enable_infiniband = False


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def cluster():
cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0)
yield cluster
cluster.close()


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def ucx_cluster():
initialize.initialize(create_cuda_context=True,
enable_tcp_over_ucx=enable_tcp_over_ucx,
Expand All @@ -26,13 +30,12 @@ def ucx_cluster():
cluster = LocalCUDACluster(protocol="ucx",
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_nvlink=enable_nvlink,
enable_infiniband=enable_infiniband,
ucx_net_devices="auto")
enable_infiniband=enable_infiniband)
yield cluster
cluster.close()


@pytest.fixture()
@pytest.fixture(scope="session")
def client(cluster):
client = Client(cluster)
yield client
Expand Down
Loading