[REVIEW] Consistent renaming of CHECK_CUDA and *_TRY macros #410

Merged · 5 commits · Dec 9, 2021
Changes from 3 commits
2 changes: 1 addition & 1 deletion cpp/include/raft/common/scatter.cuh
@@ -46,7 +46,7 @@ void scatterImpl(
{
const IdxT nblks = raft::ceildiv(VecLen ? len / VecLen : len, (IdxT)TPB);
scatterKernel<DataT, VecLen, Lambda, IdxT><<<nblks, TPB, 0, stream>>>(out, in, idx, len, op);
-CUDA_CHECK(cudaGetLastError());
+RAFT_CHECK_CUDA(cudaGetLastError());
}

/**
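The hunk above swaps the old CUDA_CHECK spelling for the RAFT-prefixed name after a kernel launch. The macro definition itself is not part of this file; as a rough, illustrative sketch only (not RAFT's actual implementation, which reports through its own error machinery in raft/error.hpp), macros of this kind wrap the standard cudaGetLastError idiom:

  #include <cuda_runtime_api.h>
  #include <cstdio>
  #include <cstdlib>

  // Illustrative helper only: checks the error state left by the most recent
  // kernel launch and aborts with a message including the call site.
  inline void check_last_cuda_error(const char* file, int line)
  {
    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
      std::fprintf(stderr, "CUDA error '%s' at %s:%d\n", cudaGetErrorString(status), file, line);
      std::abort();
    }
  }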
158 changes: 85 additions & 73 deletions cpp/include/raft/comms/mpi_comms.hpp
@@ -32,7 +32,7 @@
#include <raft/error.hpp>
#include <raft/handle.hpp>

-#define MPI_TRY(call) \
+#define RAFT_MPI_TRY(call) \
do { \
int status = call; \
if (MPI_SUCCESS != status) { \
@@ -44,7 +44,12 @@
} \
} while (0)

-#define MPI_TRY_NO_THROW(call) \
+// FIXME: Remove after consumer rename
+#ifndef MPI_TRY
+#define MPI_TRY(call) RAFT_MPI_TRY(call)
+#endif
+
+#define RAFT_MPI_TRY_NO_THROW(call) \
do { \
int status = call; \
if (MPI_SUCCESS != status) { \
@@ -59,6 +64,11 @@
} \
} while (0)

+// FIXME: Remove after consumer rename
+#ifndef MPI_TRY_NO_THROW
+#define MPI_TRY_NO_THROW(call) RAFT_MPI_TRY_NO_THROW(call)
+#endif
+
namespace raft {
namespace comms {

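The two FIXME blocks added above are temporary compatibility shims: the old macro names are kept as aliases for the new RAFT_-prefixed ones so that downstream consumers keep compiling until they are renamed on their side, after which the aliases can be deleted. From a consumer's point of view the transition looks roughly like this (the function names below are hypothetical, for illustration only):

  #include <mpi.h>
  #include <raft/comms/mpi_comms.hpp>

  // Still compiles today: MPI_TRY is now just an alias for RAFT_MPI_TRY,
  // provided by the #ifndef guard above.
  void legacy_barrier(MPI_Comm comm) { MPI_TRY(MPI_Barrier(comm)); }

  // Preferred spelling going forward; once all consumers use it, the alias
  // (and the FIXME) can be removed.
  void renamed_barrier(MPI_Comm comm) { RAFT_MPI_TRY(MPI_Barrier(comm)); }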
@@ -98,24 +108,24 @@ class mpi_comms : public comms_iface {
: owns_mpi_comm_(owns_mpi_comm), mpi_comm_(comm), size_(0), rank_(1), next_request_id_(0)
{
int mpi_is_initialized = 0;
-MPI_TRY(MPI_Initialized(&mpi_is_initialized));
+RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized));
RAFT_EXPECTS(mpi_is_initialized, "ERROR: MPI is not initialized!");
-MPI_TRY(MPI_Comm_size(mpi_comm_, &size_));
-MPI_TRY(MPI_Comm_rank(mpi_comm_, &rank_));
+RAFT_MPI_TRY(MPI_Comm_size(mpi_comm_, &size_));
+RAFT_MPI_TRY(MPI_Comm_rank(mpi_comm_, &rank_));
// get NCCL unique ID at rank 0 and broadcast it to all others
ncclUniqueId id;
-if (0 == rank_) NCCL_TRY(ncclGetUniqueId(&id));
-MPI_TRY(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, mpi_comm_));
+if (0 == rank_) RAFT_NCCL_TRY(ncclGetUniqueId(&id));
+RAFT_MPI_TRY(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, mpi_comm_));

// initializing NCCL
-NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_));
+RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_));
}

virtual ~mpi_comms()
{
// finalizing NCCL
-NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_));
-if (owns_mpi_comm_) { MPI_TRY_NO_THROW(MPI_Comm_free(&mpi_comm_)); }
+RAFT_NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_));
+if (owns_mpi_comm_) { RAFT_MPI_TRY_NO_THROW(MPI_Comm_free(&mpi_comm_)); }
}

int get_size() const { return size_; }
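Aside from the macro rename, the constructor above is the standard NCCL-over-MPI bootstrap: rank 0 generates a ncclUniqueId, MPI broadcasts it to every rank, and each rank then joins the NCCL communicator with ncclCommInitRank. Stripped of the RAFT wrappers, the pattern is roughly as follows (a sketch that assumes MPI is already initialized; error checking omitted):

  #include <mpi.h>
  #include <nccl.h>

  ncclComm_t init_nccl_over_mpi(MPI_Comm mpi_comm)
  {
    int rank = 0, size = 0;
    MPI_Comm_rank(mpi_comm, &rank);
    MPI_Comm_size(mpi_comm, &size);

    ncclUniqueId id;
    if (rank == 0) { ncclGetUniqueId(&id); }                            // only rank 0 creates the id
    MPI_Bcast(&id, static_cast<int>(sizeof(id)), MPI_BYTE, 0, mpi_comm); // share it with all ranks

    ncclComm_t nccl_comm = nullptr;
    ncclCommInitRank(&nccl_comm, size, id, rank);                       // collective: every rank must call
    return nccl_comm;
  }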
@@ -125,11 +135,11 @@
std::unique_ptr<comms_iface> comm_split(int color, int key) const
{
MPI_Comm new_comm;
-MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm));
+RAFT_MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm));
return std::unique_ptr<comms_iface>(new mpi_comms(new_comm, true));
}

-void barrier() const { MPI_TRY(MPI_Barrier(mpi_comm_)); }
+void barrier() const { RAFT_MPI_TRY(MPI_Barrier(mpi_comm_)); }

void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const
{
@@ -142,7 +152,7 @@
req_id = *it;
free_requests_.erase(it);
}
-MPI_TRY(MPI_Isend(buf, size, MPI_BYTE, dest, tag, mpi_comm_, &mpi_req));
+RAFT_MPI_TRY(MPI_Isend(buf, size, MPI_BYTE, dest, tag, mpi_comm_, &mpi_req));
requests_in_flight_.insert(std::make_pair(req_id, mpi_req));
*request = req_id;
}
@@ -159,7 +169,7 @@
free_requests_.erase(it);
}

-MPI_TRY(MPI_Irecv(buf, size, MPI_BYTE, source, tag, mpi_comm_, &mpi_req));
+RAFT_MPI_TRY(MPI_Irecv(buf, size, MPI_BYTE, source, tag, mpi_comm_, &mpi_req));
requests_in_flight_.insert(std::make_pair(req_id, mpi_req));
*request = req_id;
}
@@ -177,7 +187,7 @@
free_requests_.insert(req_it->first);
requests_in_flight_.erase(req_it);
}
-MPI_TRY(MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE));
+RAFT_MPI_TRY(MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE));
}

void allreduce(const void* sendbuff,
Expand All @@ -187,13 +197,13 @@ class mpi_comms : public comms_iface {
op_t op,
cudaStream_t stream) const
{
-NCCL_TRY(ncclAllReduce(
+RAFT_NCCL_TRY(ncclAllReduce(
sendbuff, recvbuff, count, get_nccl_datatype(datatype), get_nccl_op(op), nccl_comm_, stream));
}

void bcast(void* buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const
{
-NCCL_TRY(
+RAFT_NCCL_TRY(
ncclBroadcast(buff, buff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream));
}

@@ -204,7 +214,7 @@
int root,
cudaStream_t stream) const
{
-NCCL_TRY(ncclBroadcast(
+RAFT_NCCL_TRY(ncclBroadcast(
sendbuff, recvbuff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream));
}

@@ -216,14 +226,14 @@
int root,
cudaStream_t stream) const
{
-NCCL_TRY(ncclReduce(sendbuff,
-recvbuff,
-count,
-get_nccl_datatype(datatype),
-get_nccl_op(op),
-root,
-nccl_comm_,
-stream));
+RAFT_NCCL_TRY(ncclReduce(sendbuff,
+recvbuff,
+count,
+get_nccl_datatype(datatype),
+get_nccl_op(op),
+root,
+nccl_comm_,
+stream));
}

void allgather(const void* sendbuff,
Expand All @@ -232,7 +242,7 @@ class mpi_comms : public comms_iface {
datatype_t datatype,
cudaStream_t stream) const
{
-NCCL_TRY(ncclAllGather(
+RAFT_NCCL_TRY(ncclAllGather(
sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), nccl_comm_, stream));
}

@@ -246,7 +256,7 @@
// From: "An Empirical Evaluation of Allgatherv on Multi-GPU Systems" -
// https://arxiv.org/pdf/1812.05964.pdf Listing 1 on page 4.
for (int root = 0; root < size_; ++root) {
-NCCL_TRY(
+RAFT_NCCL_TRY(
ncclBroadcast(sendbuf,
static_cast<char*>(recvbuf) + displs[root] * get_datatype_size(datatype),
recvcounts[root],
@@ -265,19 +275,20 @@
cudaStream_t stream) const
{
size_t dtype_size = get_datatype_size(datatype);
-NCCL_TRY(ncclGroupStart());
+RAFT_NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
-NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + sendcount * r * dtype_size,
-sendcount,
-get_nccl_datatype(datatype),
-r,
-nccl_comm_,
-stream));
+RAFT_NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + sendcount * r * dtype_size,
+sendcount,
+get_nccl_datatype(datatype),
+r,
+nccl_comm_,
+stream));
}
}
-NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
-NCCL_TRY(ncclGroupEnd());
+RAFT_NCCL_TRY(
+ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
+RAFT_NCCL_TRY(ncclGroupEnd());
}

void gatherv(const void* sendbuff,
Expand All @@ -290,19 +301,20 @@ class mpi_comms : public comms_iface {
cudaStream_t stream) const
{
size_t dtype_size = get_datatype_size(datatype);
-NCCL_TRY(ncclGroupStart());
+RAFT_NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
-NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + displs[r] * dtype_size,
-recvcounts[r],
-get_nccl_datatype(datatype),
-r,
-nccl_comm_,
-stream));
+RAFT_NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + displs[r] * dtype_size,
+recvcounts[r],
+get_nccl_datatype(datatype),
+r,
+nccl_comm_,
+stream));
}
}
-NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
-NCCL_TRY(ncclGroupEnd());
+RAFT_NCCL_TRY(
+ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
+RAFT_NCCL_TRY(ncclGroupEnd());
}

void reducescatter(const void* sendbuff,
Expand All @@ -312,13 +324,13 @@ class mpi_comms : public comms_iface {
op_t op,
cudaStream_t stream) const
{
-NCCL_TRY(ncclReduceScatter(sendbuff,
-recvbuff,
-recvcount,
-get_nccl_datatype(datatype),
-get_nccl_op(op),
-nccl_comm_,
-stream));
+RAFT_NCCL_TRY(ncclReduceScatter(sendbuff,
+recvbuff,
+recvcount,
+get_nccl_datatype(datatype),
+get_nccl_op(op),
+nccl_comm_,
+stream));
}

status_t sync_stream(cudaStream_t stream) const
@@ -357,13 +369,13 @@ class mpi_comms : public comms_iface {
// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const
{
-NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream));
+RAFT_NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream));
}

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const
{
-NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream));
+RAFT_NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream));
}

void device_sendrecv(const void* sendbuf,
Expand All @@ -375,10 +387,10 @@ class mpi_comms : public comms_iface {
cudaStream_t stream) const
{
// ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock
-NCCL_TRY(ncclGroupStart());
-NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream));
-NCCL_TRY(ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream));
-NCCL_TRY(ncclGroupEnd());
+RAFT_NCCL_TRY(ncclGroupStart());
+RAFT_NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream));
+RAFT_NCCL_TRY(ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream));
+RAFT_NCCL_TRY(ncclGroupEnd());
}
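As the comment in the hunk above says, when a rank both sends and receives in the same operation, the ncclSend/ncclRecv pair must be enclosed in ncclGroupStart()/ncclGroupEnd(); otherwise two ranks can each block in their send, neither posts its receive, and the exchange deadlocks. A bare-NCCL sketch of that grouped pattern (the function and parameter names here are placeholders for illustration):

  #include <cuda_runtime_api.h>
  #include <nccl.h>

  // Exchange payloads with a single peer without deadlocking: both directions
  // are grouped into one NCCL operation, so the library can progress the send
  // and the receive together instead of serializing them.
  void exchange_with_peer(const void* send_buf, size_t send_bytes,
                          void* recv_buf, size_t recv_bytes,
                          int peer_rank, ncclComm_t nccl_comm, cudaStream_t stream)
  {
    ncclGroupStart();
    ncclSend(send_buf, send_bytes, ncclUint8, peer_rank, nccl_comm, stream);
    ncclRecv(recv_buf, recv_bytes, ncclUint8, peer_rank, nccl_comm, stream);
    ncclGroupEnd();
  }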

void device_multicast_sendrecv(const void* sendbuf,
Expand All @@ -392,24 +404,24 @@ class mpi_comms : public comms_iface {
cudaStream_t stream) const
{
// ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock
-NCCL_TRY(ncclGroupStart());
+RAFT_NCCL_TRY(ncclGroupStart());
for (size_t i = 0; i < sendsizes.size(); ++i) {
-NCCL_TRY(ncclSend(static_cast<const char*>(sendbuf) + sendoffsets[i],
-sendsizes[i],
-ncclUint8,
-dests[i],
-nccl_comm_,
-stream));
+RAFT_NCCL_TRY(ncclSend(static_cast<const char*>(sendbuf) + sendoffsets[i],
+sendsizes[i],
+ncclUint8,
+dests[i],
+nccl_comm_,
+stream));
}
for (size_t i = 0; i < recvsizes.size(); ++i) {
-NCCL_TRY(ncclRecv(static_cast<char*>(recvbuf) + recvoffsets[i],
-recvsizes[i],
-ncclUint8,
-sources[i],
-nccl_comm_,
-stream));
+RAFT_NCCL_TRY(ncclRecv(static_cast<char*>(recvbuf) + recvoffsets[i],
+recvsizes[i],
+ncclUint8,
+sources[i],
+nccl_comm_,
+stream));
}
-NCCL_TRY(ncclGroupEnd());
+RAFT_NCCL_TRY(ncclGroupEnd());
}

private: