
Enforce having more threads than GPUs. #4162

Closed
trivialfis opened this issue Feb 19, 2019 · 8 comments

@trivialfis
Member

Brought up in #4076. Having fewer threads than GPUs leads to undefined behavior. Another example is a hang I encountered while working on #3974. #4095 might be able to decouple the number of threads from the number of GPUs, but until then I want to make a last-minute fix: when the user specifies nthread < n_gpus, we either decrease n_gpus to nthread or increase nthread to n_gpus. @RAMitchell WDYT?
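
Roughly what I have in mind (a sketch only; the function and its parameters are hypothetical, not the actual configuration code):

#include <iostream>

// Hypothetical sketch of reconciling the two parameters; the real fix
// would live in XGBoost's parameter validation, not a free function.
void ReconcileThreadsAndGpus(int* nthread, int* n_gpus) {
  if (*nthread < *n_gpus) {
    // Option A: shrink the GPU count to the thread count.
    std::cerr << "nthread (" << *nthread << ") < n_gpus (" << *n_gpus
              << "); reducing n_gpus to " << *nthread << "\n";
    *n_gpus = *nthread;
    // Option B would instead raise nthread: *nthread = *n_gpus;
  }
}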

@mtjrider
Contributor

mtjrider commented Feb 19, 2019

In an effort to provide some additional insight:

Currently #4095 creates two modes of operation:

  1. An NCCL communicator clique where a single process manages a collection of GPUs
  2. An NCCL communicator clique where there is a single process (PID) for each GPU involved in the clique

(2) is triggered by rabit::GetWorldSize(): if, within a node, you assign an MPI rank per physical GPU, you create a process for each GPU.

(1) is triggered by calling an XGBoost process with OpenMP threads, rather than MPI ranks.

(2) is well suited for multi-node operation, whereas (1) avoids unnecessary complication on single-machine setups.

In short, it is possible to adjust the code so that the NCCL communicator clique creates multiple processes and performs a check against nthread < n_gpus, but that may conflict with functionality we likely want in (1).

@trivialfis
Member Author

@mt-jones Thanks for the insight.

it is possible to adjust the code so that the NCCL communicator clique creates multiple processes and performs a check against nthread < n_gpus

Are you saying that, in non-distributed mode, when the user specifies nthread < n_gpus, we fall back to a slightly more complicated setup that uses processes instead of threads, but at any other time we default to using threads?

@mtjrider
Contributor

mtjrider commented Feb 19, 2019

@mt-jones Thanks for the insight.

it is possible to adjust the code so that the NCCL communicator clique creates multiple processes and performs a check against nthread < n_gpus

Are you saying that, in non-distributed mode, when the user specifies nthread < n_gpus, we fall back to a slightly more complicated setup that uses processes instead of threads, but at any other time we default to using threads?

I'm saying that I've removed the flag for distributed mode and instead use rabit::GetWorldSize() to infer whether XGB is being executed in a distributed manner.

Effectively, a check is performed to see if rabit::GetWorldSize() == 1. If so, create an NCCL communicator clique with one process; else, create an NCCL communicator clique with multiple processes. In short, the same code currently in XGB is executed GPU-side for communicator initialization.

More or less:

  void Init(const std::vector<int> &device_ordinals) {
#ifdef XGBOOST_USE_NCCL
    this->device_ordinals = device_ordinals;
    comms.resize(device_ordinals.size());

    if (1 < rabit::GetWorldSize()) {
      // Distributed: each rabit worker joins the clique with its own rank.
      auto id = GetUniqueId();
      dh::safe_nccl(ncclCommInitRank(
          &(comms[0]),
          rabit::GetWorldSize(),
          id, rabit::GetRank()));
    } else {
      // Single process: initialize one communicator per local device.
      dh::safe_nccl(ncclCommInitAll(
          comms.data(),
          static_cast<int>(device_ordinals.size()),
          device_ordinals.data()));
    }
...
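
Here GetUniqueId() is assumed to create the NCCL unique id on the root rank and broadcast it through rabit so every worker joins the same clique; a minimal sketch of that helper, under that assumption:

// Assumed helper: rank 0 generates the NCCL id, rabit broadcasts it
// so all workers initialize against the same communicator clique.
ncclUniqueId GetUniqueId() {
  static const int kRootRank = 0;
  ncclUniqueId id;
  if (rabit::GetRank() == kRootRank) {
    dh::safe_nccl(ncclGetUniqueId(&id));
  }
  rabit::Broadcast(static_cast<void*>(&id), sizeof(ncclUniqueId), kRootRank);
  return id;
}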

@trivialfis
Member Author

@mt-jones Ah, I see. Will look into the details. Thanks!

@mtjrider
Contributor

@mt-jones Ah, I see. Will look into the details. Thanks!

No problem! Let me know if you have questions. My original statement was simply that we could do the same thing based on OpenMP thread rank, but it may cause conflicts with how XGB is currently executed.

@RAMitchell
Member

@trivialfis Thanks, yes, a last-minute fix would be great for 0.82. Another option is to fail with an error; I'll leave it up to you.

@mtjrider
Contributor

@trivialfis See the code below. NCCL requires (see the quote below) that ncclCommInitRank be either

  1. encapsulated by ncclGroupStart and ncclGroupEnd to unblock the internal synchronous call that initializes the rank, or
  2. called by a distinct thread/process to avoid the block.

One potential solution is to use OpenMP threads to initialize each rank, eliminating the for-loop construction and the calls to GroupStart() and GroupEnd().

Code snippet (1)

GroupStart();
for (size_t i = 0; i < device_ordinals.size(); i++) {
  int dev = device_ordinals[i];
  int ndevs = device_ordinals.size();
  int nccl_rank = rabit::GetRank() * ndevs + dev;
  int nccl_nranks = rabit::GetWorldSize() * ndevs;

  dh::safe_cuda(cudaSetDevice(dev));
  dh::safe_nccl(ncclCommInitRank(
      &(comms[i]),
      nccl_nranks, id,
      nccl_rank));
}
GroupEnd();

Code snippet (2)

#pragma omp parallel num_threads(device_ordinals.size()) // ***
{
  int tid = omp_get_thread_num();
  int dev = device_ordinals[tid];
  int ndevs = device_ordinals.size();

  int nccl_rank = rabit::GetRank() * ndevs + dev;
  int nccl_nranks = rabit::GetWorldSize() * ndevs;

  dh::safe_cuda(cudaSetDevice(dev));
  dh::safe_nccl(ncclCommInitRank(
      &(comms[tid]),
      nccl_nranks, id,
      nccl_rank));
}

At the line marked // ***, we could implement a check to avoid initialization and error out, or we could let num_threads override nthread to initialize the communicator clique (a sketch of such a check follows at the end of this comment).

I have tested the above in PR #4095, and it does work (all tests pass).

ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank)

Creates a new communicator (multi thread/process version). rank must be between 0 and nranks-1 and unique within a communicator clique. Each rank is associated to a CUDA device, which has to be set before calling ncclCommInitRank. ncclCommInitRank implicitly synchronizes with other ranks, hence it must be called by different threads/processes or use ncclGroupStart/ncclGroupEnd.
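
The check at // *** could look something like this (illustrative only; the exact condition and error handling are up for discussion):

// Illustrative pre-check before the parallel region: fail fast when
// the OpenMP thread budget cannot cover one thread per device.
int const n_devices = static_cast<int>(device_ordinals.size());
if (omp_get_max_threads() < n_devices) {
  // Option A: error out with a clear message.
  LOG(FATAL) << "nthread must be >= the number of GPUs (" << n_devices
             << ") to initialize the NCCL communicator clique.";
  // Option B: let num_threads(device_ordinals.size()) on the pragma
  // override nthread and use one thread per device regardless.
}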

@trivialfis
Member Author

I dropped the idea of manipulating OpenMP threads. The GPU threads must be decoupled from the nthread parameter, so I will try std::thread. Please give me some time to learn its implications.
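
Roughly, the idea would be to spawn one std::thread per device for the blocking ncclCommInitRank call, independent of nthread (a sketch only, reusing the names from the snippets above):

#include <thread>
#include <vector>

// Sketch: one std::thread per device performs the blocking
// ncclCommInitRank, so initialization no longer touches the OpenMP
// thread pool or the nthread parameter.
void InitComms(const std::vector<int>& device_ordinals,
               std::vector<ncclComm_t>* comms, ncclUniqueId id) {
  int const ndevs = static_cast<int>(device_ordinals.size());
  std::vector<std::thread> workers;
  for (int i = 0; i < ndevs; ++i) {
    workers.emplace_back([&, i]() {
      int nccl_rank = rabit::GetRank() * ndevs + device_ordinals[i];
      int nccl_nranks = rabit::GetWorldSize() * ndevs;
      dh::safe_cuda(cudaSetDevice(device_ordinals[i]));
      dh::safe_nccl(ncclCommInitRank(&(*comms)[i], nccl_nranks, id, nccl_rank));
    });
  }
  for (auto& t : workers) { t.join(); }
}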

@lock bot locked as resolved and limited conversation to collaborators Sep 18, 2019