ENH Decision Tree new backend computeSplitClassificationKernel histogram calculation and occupancy optimization (#3616)

* This PR introduces:
    * A faster way to calculate the split histograms in `ML::DecisionTree::computeSplitClassificationKernel`. These histograms are used for node-splitting in decision trees for the task of classification.
    * A change in the default `gridDim.x` of the above kernel's launch configuration: instead of a fixed `4`, it is now derived from the CUDA occupancy calculator and the other grid dimensions, raising occupancy to its theoretical limit.
* Previously, the large number of atomic adds to shared memory dominated the kernel times; this is now avoided by block-wide sum-scans that build the same histogram with far fewer atomic writes to shared memory (see the sketch after this list).
* The resulting kernel-time speedups are significant (up to 30x for some nodes).
* `computeSplitRegressionKernel` has different shared-memory write patterns that deserve their own PR for optimization 😬
* Tests will pass once #3690 is merged.
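
To make the scan trick concrete, here is a minimal, self-contained CUDA sketch of the idea (the `pdf_to_cdf` kernel, `TPB`, and the toy data are hypothetical illustrations, not the PR's actual kernel): per-bin counts (a PDF) are converted into running "samples ≤ bin" totals (a CDF) with a single `cub::BlockScan`, rather than building per-bin left/right counts with repeated atomics:

```cuda
// Sketch only: PDF -> CDF for one histogram via a block-wide inclusive scan.
#include <cub/cub.cuh>
#include <cstdio>

constexpr int TPB = 256;  // threads per block (mirrors TPB_DEFAULT in the diff below)

__global__ void pdf_to_cdf(const int* pdf, int* cdf, int nbins) {
  typedef cub::BlockScan<int, TPB> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  // threads beyond nbins contribute zero so the whole block can join the scan
  int count = threadIdx.x < nbins ? pdf[threadIdx.x] : 0;
  int running;
  BlockScan(temp_storage).InclusiveSum(count, running);
  if (threadIdx.x < nbins) cdf[threadIdx.x] = running;  // samples <= bin b
}

int main() {
  constexpr int nbins = 8;
  int h_pdf[nbins] = {3, 1, 4, 1, 5, 9, 2, 6};
  int *d_pdf, *d_cdf;
  cudaMalloc(&d_pdf, nbins * sizeof(int));
  cudaMalloc(&d_cdf, nbins * sizeof(int));
  cudaMemcpy(d_pdf, h_pdf, nbins * sizeof(int), cudaMemcpyHostToDevice);
  pdf_to_cdf<<<1, TPB>>>(d_pdf, d_cdf, nbins);
  int h_cdf[nbins];
  cudaMemcpy(h_cdf, d_cdf, nbins * sizeof(int), cudaMemcpyDeviceToHost);
  for (int b = 0; b < nbins; ++b) printf("%d ", h_cdf[b]);  // 3 4 8 9 14 23 25 31
  printf("\n");
  cudaFree(d_pdf);
  cudaFree(d_cdf);
  return 0;
}
```

The left counts for every candidate split fall out of one scan; the actual kernel below additionally runs a mirrored right-to-left pass per class to obtain the greater-than-bin counts.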

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Thejaswi. N. S (https://github.com/teju85)
  - John Zedlewski (https://github.com/JohnZed)

URL: #3616
venkywonka authored Apr 4, 2021
1 parent 1554f14 commit 4bf0ba4
Showing 3 changed files with 154 additions and 30 deletions.
58 changes: 51 additions & 7 deletions cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
@@ -45,6 +45,8 @@ struct Builder {
typedef typename Traits::SplitT SplitT;
typedef typename Traits::InputT InputT;

/** default threads per block for most kernels in here */
static constexpr int TPB_DEFAULT = 256;
/** DT params */
DecisionTreeParams params;
/** input dataset */
@@ -100,8 +102,6 @@ struct Builder {
IdxT node_start, node_end;
/** number of blocks used to parallelize column-wise computations. */
int n_blks_for_cols = 10;
/** Number of blocks used to parallelize row-wise computations. */
int n_blks_for_rows = 4;
/** Memory alignment value */
const size_t alignValue = 512;

@@ -110,6 +110,36 @@ struct Builder {
return std::is_same<DataT, LabelT>::value;
}

/**
* @brief Computes the number of blocks used to parallelize row-wise
* computations (gridDim.x) so as to maximize occupancy
*
* @param[in] gridDimy number of blocks assigned in the y-dimension (n_blks_for_cols)
* @param[in] func kernel function; needed by the occupancy calculator for finding
* the maximum active blocks per multiprocessor
* @param[in] blockSize threads per block, passed to the CUDA occupancy calculator API
* @param[in] dynamic_smem_size dynamic shared memory size, passed to the CUDA
* occupancy calculator API
* @param[in] gridDimz number of blocks along the z-dimension, based
* on the concurrent nodes of the tree available to be processed
* @return number of blocks for rows (gridDim.x) that maximizes occupancy
*/
int n_blks_for_rows(const int gridDimy, const void* func, const int blockSize,
const size_t dynamic_smem_size, const int gridDimz) {
int devid;
CUDA_CHECK(cudaGetDevice(&devid));
int mpcount;
CUDA_CHECK(
cudaDeviceGetAttribute(&mpcount, cudaDevAttrMultiProcessorCount, devid));
int maxblks;
// get expected max blocks per multiprocessor
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&maxblks, func, blockSize, dynamic_smem_size));
// get the total number of blocks
int n_blks = maxblks * mpcount;
// return appropriate number of blocks in x-dimension
return raft::ceildiv(n_blks, gridDimy * gridDimz);
}
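// (Illustrative arithmetic, not part of the diff: on a hypothetical GPU with
// 80 SMs where the occupancy calculator reports maxblks = 8, n_blks is 640;
// with gridDimy = 10 and gridDimz = 32, this returns ceildiv(640, 320) = 2
// blocks along the x-dimension.)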

size_t calculateAlignedBytes(const size_t actualSize) {
return raft::alignTo(actualSize, alignValue);
}
@@ -160,7 +190,7 @@ struct Builder {
input.quantiles = quantiles;
auto max_batch = params.max_batch_size;
auto n_col_blks = n_blks_for_cols;
nHistBins = 2 * max_batch * params.n_bins * n_col_blks * nclasses;
nHistBins = max_batch * (1 + params.n_bins) * n_col_blks * nclasses;
// x2 for mean and mean-of-square
nPredCounts = max_batch * params.n_bins * n_col_blks;
if (params.max_depth < 13) {
@@ -172,6 +202,11 @@
}

if (isRegression()) {
int n_blks_for_rows = this->n_blks_for_rows(
n_col_blks,
(const void*)
computeSplitRegressionKernel<DataT, LabelT, IdxT, TPB_DEFAULT>,
TPB_DEFAULT, 0, max_batch);
dim3 grid(n_blks_for_rows, n_col_blks, max_batch);
block_sync_size = MLCommon::GridSync::computeWorkspaceSize(
grid, MLCommon::SyncType::ACROSS_X, false);
@@ -409,15 +444,19 @@ struct ClsTraits {
"Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
auto nbins = b.params.n_bins;
auto nclasses = b.input.nclasses;
auto binSize = nbins * 2 * nclasses;
auto binSize = (nbins * 3 + 1) * nclasses;
auto colBlks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col);
dim3 grid(b.n_blks_for_rows, colBlks, batchSize);
size_t smemSize = sizeof(int) * binSize + sizeof(DataT) * nbins;
smemSize += sizeof(int);

// Extra room for alignment (see alignPointer in computeSplitClassificationKernel)
smemSize += 2 * sizeof(DataT) + 1 * sizeof(int);

int n_blks_for_rows = b.n_blks_for_rows(
colBlks,
(const void*)
computeSplitClassificationKernel<DataT, LabelT, IdxT, TPB_DEFAULT>,
TPB_DEFAULT, smemSize, batchSize);
dim3 grid(n_blks_for_rows, colBlks, batchSize);
CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * b.nHistBins, s));
ML::PUSH_RANGE(
"computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]");
@@ -483,14 +522,19 @@ struct RegTraits {
ML::PUSH_RANGE(
"Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
auto n_col_blks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col);
dim3 grid(b.n_blks_for_rows, n_col_blks, batchSize);
auto nbins = b.params.n_bins;
size_t smemSize = 7 * nbins * sizeof(DataT) + nbins * sizeof(int);
smemSize += sizeof(int);

// Room for alignment in worst case (see alignPointer in
// computeSplitRegressionKernel)
smemSize += 5 * sizeof(DataT) + 2 * sizeof(int);
int n_blks_for_rows = b.n_blks_for_rows(
n_col_blks,
(const void*)
computeSplitRegressionKernel<DataT, LabelT, IdxT, TPB_DEFAULT>,
TPB_DEFAULT, smemSize, batchSize);
dim3 grid(n_blks_for_rows, n_col_blks, batchSize);

CUDA_CHECK(
cudaMemsetAsync(b.pred, 0, sizeof(DataT) * b.nPredCounts * 2, s));
112 changes: 96 additions & 16 deletions cpp/src/decisiontree/batched-levelalgo/kernels.cuh
@@ -341,19 +341,24 @@ __global__ void computeSplitClassificationKernel(
auto node = nodes[nid];
auto range_start = node.start;
auto range_len = node.count;

// return if leaf
if (leafBasedOnParams<DataT, IdxT>(node.depth, max_depth, min_samples_split,
max_leaves, n_leaves, range_len)) {
return;
}
auto end = range_start + range_len;
auto nclasses = input.nclasses;
auto len = nbins * 2 * nclasses;
auto* shist = alignPointer<int>(smem);
auto* sbins = alignPointer<DataT>(shist + len);
auto pdf_shist_len = (nbins + 1) * nclasses;
auto cdf_shist_len = nbins * 2 * nclasses;
auto* pdf_shist = alignPointer<int>(smem);
auto* cdf_shist = alignPointer<int>(pdf_shist + pdf_shist_len);
auto* sbins = alignPointer<DataT>(cdf_shist + cdf_shist_len);
auto* sDone = alignPointer<int>(sbins + nbins);
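// (Layout note, not part of the diff: per class, pdf_shist holds nbins + 1
// counts, one per bin plus an extra slot read by the right-to-left scan
// further below, while cdf_shist holds 2 * nbins counts: nbins
// less-than-or-equal-to-bin counts followed by nbins greater-than-bin counts.)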
IdxT stride = blockDim.x * gridDim.x;
IdxT tid = threadIdx.x + blockIdx.x * blockDim.x;

// obtaining the feature to test split on
IdxT col;
if (input.nSampledCols == input.N) {
col = colStart + blockIdx.y;
@@ -362,50 +367,125 @@
col = select(colIndex, treeid, node.info.unique_id, seed, input.N);
}

for (IdxT i = threadIdx.x; i < len; i += blockDim.x) shist[i] = 0;
// populating shared memory with initial values
for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x)
pdf_shist[i] = 0;
for (IdxT j = threadIdx.x; j < cdf_shist_len; j += blockDim.x)
cdf_shist[j] = 0;
for (IdxT b = threadIdx.x; b < nbins; b += blockDim.x)
sbins[b] = input.quantiles[col * nbins + b];

// synchronizing above changes across block
__syncthreads();

// compute pdf shared histogram for all bins for all classes in shared mem
auto coloffset = col * input.M;
// compute class histogram for all bins for all classes in shared mem
for (auto i = range_start + tid; i < end; i += stride) {
// each thread works over a data point and strides to the next
auto row = input.rowids[i];
auto d = input.data[row + coloffset];
auto label = input.labels[row];
for (IdxT b = 0; b < nbins; ++b) {
auto isRight = d > sbins[b]; // no divergence
auto offset = b * 2 * nclasses + isRight * nclasses + label;
atomicAdd(shist + offset, 1); // class hist
if (d <= sbins[b]) {
auto offset = label * (1 + nbins) + b;
atomicAdd(pdf_shist + offset, 1);
break;
}
}
}
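// (Note, not part of the diff: each sample now issues at most one atomicAdd,
// into the single pdf bin it falls in, instead of the nbins atomicAdds per
// sample performed by the old left/right histogram code above.)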

// synchronizing above changes across block
__syncthreads();

// update the corresponding global location
auto histOffset = ((nid * gridDim.y) + blockIdx.y) * len;
for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
atomicAdd(hist + histOffset + i, shist[i]);
auto histOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_shist_len;
for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x) {
atomicAdd(hist + histOffset + i, pdf_shist[i]);
}

__threadfence(); // for commit guarantee
__syncthreads();

// last threadblock will go ahead and compute the best split
bool last = true;
if (gridDim.x > 1) {
last = MLCommon::signalDone(done_count + nid * gridDim.y + blockIdx.y,
gridDim.x, blockIdx.x == 0, sDone);
}
// if not the last threadblock, exit
if (!last) return;
for (IdxT i = threadIdx.x; i < len; i += blockDim.x)
shist[i] = hist[histOffset + i];

// store the complete global histogram in shared memory of last block
for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x)
pdf_shist[i] = hist[histOffset + i];

__syncthreads();

// Blockscan instance preparation
typedef cub::BlockScan<int, TPB> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;

/**
* Scanning code:
* span: block-wide
* Function: convert the PDF calculated in the previous steps to a CDF.
* The conversion is done in two passes:
* * one from left to right, sum-scanning the counts of left splits
* for each split-point
* * a second from right to left, sum-scanning the counts of right splits
* for each split-point
*/
for (IdxT tix = threadIdx.x; tix < max(TPB, nbins); tix += blockDim.x) {
for (IdxT c = 0; c < nclasses; ++c) {
// for each class, do inclusive block scan
int pdf_per_bin_per_class;
int cdf_per_bin_per_class;
// left-to-right scan operation for scanning less-than-or-equal-to-bin counts
// offset for left-to-right scan of pdf_shist
IdxT class_segment_offset = (1 + nbins) * c;
pdf_per_bin_per_class =
tix < nbins ? pdf_shist[class_segment_offset + tix] : 0;
BlockScan(temp_storage)
.InclusiveSum(pdf_per_bin_per_class, cdf_per_bin_per_class);
__syncthreads(); // synchronizing the scan
if (tix < nbins) {
auto histOffset = (2 * nbins * c + tix);
cdf_shist[histOffset] = cdf_per_bin_per_class;
}

// right-to-left scan operation for scanning greater-than-bin counts
// thread 0 -> reads the last element of this class's pdf_shist segment
// thread (nbins - 1) -> reads the 2nd element of this class's pdf_shist segment
// offset for right-to-left scan of pdf_shist
pdf_per_bin_per_class =
tix < nbins ? pdf_shist[class_segment_offset + (nbins - tix)] : 0;
BlockScan(temp_storage)
.InclusiveSum(pdf_per_bin_per_class, cdf_per_bin_per_class);
__syncthreads(); // synchronizing the scan
if (tix < nbins) {
auto histOffset = (2 * nbins * c) + nbins + (nbins - tix - 1);
cdf_shist[histOffset] = cdf_per_bin_per_class;
}
}
}
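// (Worked example, not part of the diff: for nbins = 3 and a per-class pdf
// segment [2, 5, 3, e], where e is the extra slot, the left-to-right pass
// writes less-than-or-equal-to-bin counts [2, 7, 10] and the right-to-left
// pass writes greater-than-bin counts [e + 8, e + 3, e].)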

// create a split instance to test current feature split
Split<DataT, IdxT> sp;
sp.init();
__syncthreads();

// calculate the best candidate bins (one for each block-thread) in current feature and corresponding information gain for splitting
if (splitType == CRITERION::GINI) {
giniGain<DataT, IdxT>(shist, sbins, sp, col, range_len, nbins, nclasses,
giniGain<DataT, IdxT>(cdf_shist, sbins, sp, col, range_len, nbins, nclasses,
min_samples_leaf, min_impurity_decrease);
} else {
entropyGain<DataT, IdxT>(shist, sbins, sp, col, range_len, nbins, nclasses,
min_samples_leaf, min_impurity_decrease);
entropyGain<DataT, IdxT>(cdf_shist, sbins, sp, col, range_len, nbins,
nclasses, min_samples_leaf, min_impurity_decrease);
}
__syncthreads();

// calculate best bins among candidate bins per feature using warp reduce
// then atomically update across features to get best split per node (in split[nid])
sp.evalBestSplit(smem, splits + nid, mutex + nid);
}

14 changes: 7 additions & 7 deletions cpp/src/decisiontree/batched-levelalgo/metrics.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -76,7 +76,7 @@ DI void giniGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
int nLeft = 0;
for (IdxT j = 0; j < nclasses; ++j) {
nLeft += shist[i * 2 * nclasses + j];
nLeft += shist[2 * nbins * j + i];
}
auto nRight = len - nLeft;
auto gain = DataT(0.0);
@@ -88,12 +88,12 @@ DI void giniGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
auto invRight = One / nRight;
for (IdxT j = 0; j < nclasses; ++j) {
int val_i = 0;
auto lval_i = shist[i * 2 * nclasses + j];
auto lval_i = shist[2 * nbins * j + i];
auto lval = DataT(lval_i);
gain += lval * invLeft * lval * invlen;

val_i += lval_i;
auto rval_i = shist[i * 2 * nclasses + nclasses + j];
auto rval_i = shist[2 * nbins * j + nbins + i];
auto rval = DataT(rval_i);
gain += rval * invRight * rval * invlen;
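
For reference, with the new class-major layout (left counts at `shist[2 * nbins * j + i]`, right counts at `shist[2 * nbins * j + nbins + i]`), the visible lines of `giniGain` accumulate, for candidate bin `i` (a reading of the shown lines only; the surrounding function is collapsed above):

$$
\text{gain} \mathrel{+}= \sum_{j=0}^{\text{nclasses}-1} \left( \frac{l_j^2}{n_L \cdot \text{len}} + \frac{r_j^2}{n_R \cdot \text{len}} \right), \qquad n_L = \sum_j l_j, \quad n_R = \text{len} - n_L
$$

where `l_j` and `r_j` are the per-class sample counts to the left and right of bin `i`. The `entropyGain` changes below apply the same re-indexing.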

@@ -142,7 +142,7 @@ DI void entropyGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
int nLeft = 0;
for (IdxT j = 0; j < nclasses; ++j) {
nLeft += shist[i * 2 * nclasses + j];
nLeft += shist[2 * nbins * j + i];
}
auto nRight = len - nLeft;
auto gain = DataT(0.0);
@@ -154,15 +154,15 @@ DI void entropyGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
auto invRight = One / nRight;
for (IdxT j = 0; j < nclasses; ++j) {
int val_i = 0;
auto lval_i = shist[i * 2 * nclasses + j];
auto lval_i = shist[2 * nbins * j + i];
if (lval_i != 0) {
auto lval = DataT(lval_i);
gain +=
raft::myLog(lval * invLeft) / raft::myLog(DataT(2)) * lval * invlen;
}

val_i += lval_i;
auto rval_i = shist[i * 2 * nclasses + nclasses + j];
auto rval_i = shist[2 * nbins * j + nbins + i];
if (rval_i != 0) {
auto rval = DataT(rval_i);
gain += raft::myLog(rval * invRight) / raft::myLog(DataT(2)) * rval *
