[REVIEW] Enable Multi-Node Multi-GPU functionality #4095

Merged
merged 49 commits on Mar 1, 2019

Changes from 34 commits

Commits (49)
c950cb0
Initial commit to support multi-node multi-gpu xgboost using dask
teju85 Aug 29, 2018
1cad3f8
Fixed NCCL initialization by not ignoring the opg parameter.
canonizer Sep 4, 2018
b8a4b48
At the root node, perform a rabit::Allreduce to get initial sum_gradi…
teju85 Sep 5, 2018
061a9e8
Synchronizing in a couple of more places.
canonizer Sep 6, 2018
abd3933
Added another missing max-allreduce operation inside BuildHistLeftRight
teju85 Sep 7, 2018
e969d97
Removed unnecessary collective operations.
canonizer Sep 7, 2018
bfe9b46
Simplified rabit::Allreduce() sync of gradient sums.
canonizer Sep 7, 2018
3b3846b
Removed unnecessary rabit syncs around ncclAllReduce.
canonizer Sep 7, 2018
28fe834
pulling in latest xgboost
mtjrider Feb 18, 2019
3feb460
Merge branch 'master' into mnmg
Feb 18, 2019
df4c78f
removing changes to updater_quantile_hist.cc
Feb 18, 2019
13f85ac
changing use_nccl_opg initialization, removing unnecessary if statements
mtjrider Feb 19, 2019
021bca0
Merge branch 'mnmg' of github.com:rapidsai/xgboost into mnmg
mtjrider Feb 19, 2019
a60840f
added definition for opaque ncclUniqueId struct to properly encapsula…
mtjrider Feb 19, 2019
d0a3598
placing struct definition in guard to avoid duplicate code errors
mtjrider Feb 19, 2019
29aede6
addressing linting errors
mtjrider Feb 19, 2019
03471b1
removing
mtjrider Feb 19, 2019
201604c
removing additional arguments to AllReducer initialization
mtjrider Feb 19, 2019
b897632
removing distributed flag
mtjrider Feb 19, 2019
985d822
making comm init symmetric
mtjrider Feb 20, 2019
8e477a6
removing distributed flag
mtjrider Feb 20, 2019
8cf29b8
changing ncclCommInit to support multiple modalities
mtjrider Feb 20, 2019
efefb70
fix indenting
mtjrider Feb 20, 2019
ed95106
updating ncclCommInitRank block with necessary group calls
mtjrider Feb 20, 2019
fdf4c77
fix indenting
mtjrider Feb 20, 2019
2663164
adding print statement, and updating accessor in vector
mtjrider Feb 20, 2019
fb73405
improving print statement to end-line
mtjrider Feb 21, 2019
30fc31d
Merge branch 'master' into mnmg
mtjrider Feb 21, 2019
358fb49
generalizing nccl_rank construction using rabit
mtjrider Feb 21, 2019
b73e25d
assume device_ordinals is the same for every node
mtjrider Feb 21, 2019
6b1eca1
test, assume device_ordinals is identical for all nodes
mtjrider Feb 22, 2019
319a124
test, assume device_ordinals is unique for all nodes
mtjrider Feb 22, 2019
dae22df
changing names of offset variable to be more descriptive, editing ind…
mtjrider Feb 22, 2019
f4c2fe3
Merge branch 'master' into mnmg
mtjrider Feb 24, 2019
8462f05
wrapping ncclUniqueId GetUniqueId() and aesthetic changes
mtjrider Feb 25, 2019
8ba25fa
adding synchronization, and tests for distributed
mtjrider Feb 26, 2019
822d0b1
adding to tests
mtjrider Feb 26, 2019
b5d20f5
Merge branch 'master' into mnmg
mtjrider Feb 26, 2019
ac63db6
fixing broken #endif
mtjrider Feb 27, 2019
1f06d43
fixing initialization of gpu histograms, correcting errors in tests
mtjrider Feb 27, 2019
14e589a
adding to contributors list
mtjrider Feb 27, 2019
1813d48
adding distributed tests to jenkins
mtjrider Feb 27, 2019
489e499
fixing bad path in distributed test
mtjrider Feb 27, 2019
7187ea3
debugging
mtjrider Feb 27, 2019
0725839
adding kubernetes for distributed tests
mtjrider Feb 27, 2019
2f9827e
adding proper import for OrderedDict
mtjrider Feb 27, 2019
be0cd1a
adding urllib3==1.22 to address ordered_dict import error
mtjrider Feb 27, 2019
1bfd668
added sleep to allow workers to save their models for comparison
mtjrider Feb 27, 2019
ff6dc61
adding name to GPU contributors under docs
mtjrider Feb 27, 2019
72 changes: 65 additions & 7 deletions src/common/device_helpers.cuh
@@ -23,6 +23,10 @@

#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
#include "../common/io.h"
#else
#define NCCL_UNIQUE_ID_BYTES 128
typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId;
#endif

// Uncomment to enable
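Side note (not part of the diff): when XGBOOST_USE_NCCL is not defined, the opaque placeholder above keeps code that merely declares an ncclUniqueId compiling. A minimal sketch of the idea, with a hypothetical class standing in for AllReducer:

```cpp
// Sketch only: a 128-byte stand-in type mirrors NCCL's documented unique-id
// size, so a member declaration compiles even in CPU-only builds.
#define NCCL_UNIQUE_ID_BYTES 128
typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId;

class AllReducerSketch {   // hypothetical stand-in for the real AllReducer
  ncclUniqueId id_;        // valid without including nccl.h
};
```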
@@ -853,6 +857,8 @@ class AllReducer {
std::vector<ncclComm_t> comms;
std::vector<cudaStream_t> streams;
std::vector<int> device_ordinals; // device id from CUDA
std::vector<int> device_counts; // device count from CUDA
ncclUniqueId id;
#endif

public:
@@ -872,14 +878,43 @@
#ifdef XGBOOST_USE_NCCL
/** \brief this >monitor . init. */
this->device_ordinals = device_ordinals;
comms.resize(device_ordinals.size());
dh::safe_nccl(ncclCommInitAll(comms.data(),
static_cast<int>(device_ordinals.size()),
device_ordinals.data()));
streams.resize(device_ordinals.size());
this->device_counts.resize(rabit::GetWorldSize());
this->comms.resize(device_ordinals.size());
this->streams.resize(device_ordinals.size());
this->id = GetUniqueId();

device_counts.at(rabit::GetRank()) = device_ordinals.size();
for (size_t i = 0; i < device_counts.size(); i++) {
rabit::Broadcast(
(void*)&(device_counts.at(rabit::GetRank())),
(size_t)sizeof(device_counts.at(rabit::GetRank())),
(int)rabit::GetRank());
}

int nccl_rank = 0;
int nccl_rank_offset = std::accumulate(device_counts.begin(),
device_counts.begin() + rabit::GetRank(), 0);
int nccl_nranks = std::accumulate(device_counts.begin(),
device_counts.end(), 0);
nccl_rank += nccl_rank_offset;

GroupStart();
for (size_t i = 0; i < device_ordinals.size(); i++) {
safe_cuda(cudaSetDevice(device_ordinals[i]));
safe_cuda(cudaStreamCreate(&streams[i]));
int dev = device_ordinals.at(i);

dh::safe_cuda(cudaSetDevice(dev));
dh::safe_nccl(ncclCommInitRank(
&(comms.at(i)),
nccl_nranks, id,
nccl_rank));

nccl_rank++;
}
GroupEnd();

for (size_t i = 0; i < device_ordinals.size(); i++) {
safe_cuda(cudaSetDevice(device_ordinals.at(i)));
safe_cuda(cudaStreamCreate(&(streams.at(i))));
}
initialised_ = true;
#else
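For context (not part of the diff): each worker contributes device_ordinals.size() GPUs, the per-worker counts are shared through rabit, and the global NCCL rank of local device i on worker r is the total GPU count of workers 0..r-1 plus i. A standalone sketch of that arithmetic, with hypothetical names and inputs:

```cpp
#include <numeric>
#include <vector>

// Sketch: worker `worker_rank` owns the contiguous NCCL ranks
// [offset, offset + device_counts[worker_rank]), where offset is the
// number of GPUs on all lower-ranked workers.
int NcclRankOf(const std::vector<int>& device_counts,
               int worker_rank, int local_device) {
  int offset = std::accumulate(device_counts.begin(),
                               device_counts.begin() + worker_rank, 0);
  return offset + local_device;
}
// Example: device_counts = {2, 2} -> worker 1, local device 0 has NCCL rank 2,
// and the communicator size is 4 (the sum of all counts), as in Init() above.
```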
@@ -1009,6 +1044,29 @@ class AllReducer {
dh::safe_cuda(cudaSetDevice(device_ordinals[i]));
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
}
#endif
};

/**
* \fn ncclUniqueId GetUniqueId()
*
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
* communication
*
* \return the Unique ID
*/
ncclUniqueId GetUniqueId() {
Review comment from the PR author: Necessary addition for specifying the ncclUniqueId in ncclCommInitRank.

#ifdef XGBOOST_USE_NCCL
static const int RootRank = 0;
ncclUniqueId id;
if (rabit::GetRank() == RootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
rabit::Broadcast(
(void*)&id,
(size_t)sizeof(ncclUniqueId),
(int)RootRank);
return id;
#endif
}
};
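For readers unfamiliar with the pattern used by GetUniqueId(): one process creates the id with ncclGetUniqueId, every process receives the same bytes, and each process then joins the communicator with ncclCommInitRank. A minimal sketch assuming an MPI broadcast instead of rabit (the PR itself uses rabit::Broadcast), with error checking omitted:

```cpp
#include <mpi.h>
#include <nccl.h>

// Sketch only: rank 0 generates the NCCL unique id, all ranks receive it,
// and every rank joins the same communicator with its global rank.
ncclComm_t InitNcclComm(int world_size, int world_rank) {
  ncclUniqueId id;
  if (world_rank == 0) {
    ncclGetUniqueId(&id);                               // only the root creates the id
  }
  MPI_Bcast(&id, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);
  ncclComm_t comm;
  ncclCommInitRank(&comm, world_size, id, world_rank);  // collective join
  return comm;
}
```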
7 changes: 7 additions & 0 deletions src/tree/updater_gpu_hist.cu
@@ -1050,6 +1050,7 @@ class GPUHistMakerSpecialised{

void AllReduceHist(int nidx) {
if (shards_.size() == 1) return;
dh::safe_cuda(cudaDeviceSynchronize());
monitor_.Start("AllReduce");

reducer_.GroupStart();
@@ -1080,6 +1081,9 @@ class GPUHistMakerSpecialised{
right_node_max_elements, shard->ridx_segments[nidx_right].Size());
}

rabit::Allreduce<rabit::op::Max, size_t>(&left_node_max_elements, 1);
rabit::Allreduce<rabit::op::Max, size_t>(&right_node_max_elements, 1);

auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
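Aside (not part of the diff): the two Max all-reduces above make every worker agree on the larger child before deciding which histogram to build directly and which to derive by the subtraction trick. A hedged sketch of the rabit call, assuming rabit has already been initialized on each worker:

```cpp
#include <rabit/rabit.h>

// Sketch: after these calls every worker holds the global maximum of the
// locally observed element counts, so all workers make the same choice.
void SyncMaxElements(size_t* left_elems, size_t* right_elems) {
  rabit::Allreduce<rabit::op::Max, size_t>(left_elems, 1);
  rabit::Allreduce<rabit::op::Max, size_t>(right_elems, 1);
}
```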

@@ -1142,6 +1146,9 @@ class GPUHistMakerSpecialised{
tmp_sums[i] = dh::SumReduction(
shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
});

rabit::Allreduce<rabit::op::Sum>((GradientPair::ValueT*)&tmp_sums[0], 2);

GradientPair sum_gradient =
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
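Aside (not part of the diff): the final addition sums the per-GPU gradient pairs across all workers, so every node starts from the same global root-node statistics. A minimal sketch of that collective, treating the gradient pair as two floats (an assumption mirroring the GradientPair::ValueT cast above):

```cpp
#include <rabit/rabit.h>

// Sketch: element-wise sum across workers; index 0 holds the gradient sum,
// index 1 the hessian sum for the root node.
void SyncRootGradientSum(float* grad_and_hess /* length 2 */) {
  rabit::Allreduce<rabit::op::Sum, float>(grad_and_hess, 2);
}
```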

Expand Down