From 4bf0ba4da04d652c5650cbd85db64123f1dd7092 Mon Sep 17 00:00:00 2001
From: Venkat
Date: Sun, 4 Apr 2021 12:57:20 +0530
Subject: [PATCH] ENH Decision Tree new backend `computeSplitClassificationKernel` histogram calculation and occupancy optimization (#3616)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* This PR introduces:
  * A faster way to calculate the split histograms in
    `ML::DecisionTree::computeSplitClassificationKernel`. These histograms are
    used for node splitting in decision trees for the task of classification.
  * A change in the kernel's launch configuration: `gridDim.x` is no longer a
    fixed `4` but is derived from the CUDA occupancy calculator and the other
    grid dimensions, raising occupancy up to its theoretical limit.
* Previously, kernel time was limited by a large number of atomic adds to
  shared memory; these are now avoided by block-wide sum-scans that build the
  same histogram with far fewer atomic writes to shared memory (see the
  illustrative sketch after the diffstat below).
* The resulting kernel-time speedups are significant (up to 30x for some
  nodes).
* `computeSplitRegressionKernel` has different shared-memory write patterns
  that deserve their own PR for optimization 😬
* Tests will pass once #3690 is merged.

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Thejaswi. N. S (https://github.com/teju85)
  - John Zedlewski (https://github.com/JohnZed)

URL: https://github.com/rapidsai/cuml/pull/3616
---
 .../batched-levelalgo/builder_base.cuh |  58 +++++++--
 .../batched-levelalgo/kernels.cuh      | 112 +++++++++++++++---
 .../batched-levelalgo/metrics.cuh      |  14 +--
 3 files changed, 154 insertions(+), 30 deletions(-)
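The histogram change is easier to see in isolation. Below is a minimal,
standalone sketch (not taken from this patch; `pdfToCdfSketch`, `TPB`, `NBINS`
and its parameters are made-up illustrations) of the idea used by the kernel:
each sample issues a single atomic increment into its PDF bin, and a
`cub::BlockScan` then turns that PDF into the "number of samples <= bin edge"
counts needed for split evaluation, instead of issuing one atomic add per
sample per bin.

```cuda
#include <cub/cub.cuh>

constexpr int TPB   = 128;  // toy threads-per-block
constexpr int NBINS = 8;    // toy number of split candidates

// Launch with <<<1, TPB>>>. Assumes every sample falls into one of the NBINS
// bins, i.e. bin_edges[NBINS - 1] >= max(data), as the quantiles do.
__global__ void pdfToCdfSketch(const float* data, int n, const float* bin_edges,
                               int* left_counts) {
  __shared__ int pdf[NBINS];  // per-bin sample counts (the PDF)
  typedef cub::BlockScan<int, TPB> BlockScan;
  __shared__ typename BlockScan::TempStorage temp;

  for (int b = threadIdx.x; b < NBINS; b += blockDim.x) pdf[b] = 0;
  __syncthreads();

  // Build the PDF: one atomicAdd per sample (instead of one per sample per bin).
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    for (int b = 0; b < NBINS; ++b) {
      if (data[i] <= bin_edges[b]) {
        atomicAdd(&pdf[b], 1);
        break;
      }
    }
  }
  __syncthreads();

  // Block-wide inclusive sum turns the PDF into "left" counts:
  // left_counts[b] = number of samples with value <= bin_edges[b].
  int thread_pdf = threadIdx.x < NBINS ? pdf[threadIdx.x] : 0;
  int thread_cdf;
  BlockScan(temp).InclusiveSum(thread_pdf, thread_cdf);
  if (threadIdx.x < NBINS) left_counts[threadIdx.x] = thread_cdf;
  // The "right" count for bin b is then simply n - left_counts[b].
}
```

The kernel in this patch applies the same idea per class and per feature
block, and adds a second, right-to-left scan so that the greater-than-bin
counts are also available directly in shared memory.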
diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
index 9f537d5a63..f97f2291ec 100644
--- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
@@ -45,6 +45,8 @@ struct Builder {
   typedef typename Traits::SplitT SplitT;
   typedef typename Traits::InputT InputT;
 
+  /** default threads per block for most kernels in here */
+  static constexpr int TPB_DEFAULT = 256;
   /** DT params */
   DecisionTreeParams params;
   /** input dataset */
@@ -100,8 +102,6 @@ struct Builder {
   IdxT node_start, node_end;
   /** number of blocks used to parallelize column-wise computations. */
   int n_blks_for_cols = 10;
-  /** Number of blocks used to parallelize row-wise computations. */
-  int n_blks_for_rows = 4;
   /** Memory alignment value */
   const size_t alignValue = 512;
 
@@ -110,6 +110,36 @@
     return std::is_same::value;
   }
 
+  /**
+   * @brief Computes the number of blocks used to parallelize row-wise
+   * computations (gridDim.x) so as to maximize occupancy
+   *
+   * @param[in] gridDimy number of blocks assigned in the y-dimension
+   * (n_blks_for_cols)
+   * @param[in] func kernel function; needed by the occupancy calculator for
+   * finding the maximum active blocks per multiprocessor
+   * @param[in] blockSize threads per block, passed to the CUDA occupancy
+   * calculator API
+   * @param[in] dynamic_smem_size dynamic shared memory size, passed to the
+   * CUDA occupancy calculator API
+   * @param[in] gridDimz number of blocks along the z-dimension, based on the
+   * concurrent tree nodes available to be processed
+   * @return number of row-wise blocks (gridDim.x) that maximizes occupancy
+   */
+  int n_blks_for_rows(const int gridDimy, const void* func, const int blockSize,
+                      const size_t dynamic_smem_size, const int gridDimz) {
+    int devid;
+    CUDA_CHECK(cudaGetDevice(&devid));
+    int mpcount;
+    CUDA_CHECK(
+      cudaDeviceGetAttribute(&mpcount, cudaDevAttrMultiProcessorCount, devid));
+    int maxblks;
+    // get expected max blocks per multiprocessor
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+      &maxblks, func, blockSize, dynamic_smem_size));
+    // get the total number of blocks
+    int n_blks = maxblks * mpcount;
+    // return appropriate number of blocks in x-dimension
+    return raft::ceildiv(n_blks, gridDimy * gridDimz);
+  }
+
   size_t calculateAlignedBytes(const size_t actualSize) {
     return raft::alignTo(actualSize, alignValue);
   }
@@ -160,7 +190,7 @@ struct Builder {
     input.quantiles = quantiles;
     auto max_batch = params.max_batch_size;
     auto n_col_blks = n_blks_for_cols;
-    nHistBins = 2 * max_batch * params.n_bins * n_col_blks * nclasses;
+    nHistBins = max_batch * (1 + params.n_bins) * n_col_blks * nclasses;
     // x2 for mean and mean-of-square
     nPredCounts = max_batch * params.n_bins * n_col_blks;
     if (params.max_depth < 13) {
@@ -172,6 +202,11 @@ struct Builder {
     }
 
     if (isRegression()) {
+      int n_blks_for_rows = this->n_blks_for_rows(
+        n_col_blks,
+        (const void*)
+          computeSplitRegressionKernel,
+        TPB_DEFAULT, 0, max_batch);
       dim3 grid(n_blks_for_rows, n_col_blks, max_batch);
       block_sync_size = MLCommon::GridSync::computeWorkspaceSize(
         grid, MLCommon::SyncType::ACROSS_X, false);
@@ -409,15 +444,19 @@ struct ClsTraits {
       "Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
     auto nbins = b.params.n_bins;
     auto nclasses = b.input.nclasses;
-    auto binSize = nbins * 2 * nclasses;
+    auto binSize = (nbins * 3 + 1) * nclasses;
     auto colBlks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col);
-    dim3 grid(b.n_blks_for_rows, colBlks, batchSize);
     size_t smemSize = sizeof(int) * binSize + sizeof(DataT) * nbins;
     smemSize += sizeof(int);
     // Extra room for alignment (see alignPointer in
     // computeSplitClassificationKernel)
     smemSize += 2 * sizeof(DataT) + 1 * sizeof(int);
-
+    int n_blks_for_rows = b.n_blks_for_rows(
+      colBlks,
+      (const void*)
+        computeSplitClassificationKernel,
+      TPB_DEFAULT, smemSize, batchSize);
+    dim3 grid(n_blks_for_rows, colBlks, batchSize);
     CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * b.nHistBins, s));
     ML::PUSH_RANGE(
       "computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]");
@@ -483,7 +522,6 @@ struct RegTraits {
     ML::PUSH_RANGE(
       "Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
     auto n_col_blks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col);
-    dim3 grid(b.n_blks_for_rows, n_col_blks, batchSize);
     auto nbins = b.params.n_bins;
     size_t smemSize = 7 * nbins * sizeof(DataT) + nbins * sizeof(int);
     smemSize += sizeof(int);
+
     // Room for alignment in worst case (see alignPointer in
     // computeSplitRegressionKernel)
     smemSize += 5 * sizeof(DataT) + 2 * sizeof(int);
+    int n_blks_for_rows = b.n_blks_for_rows(
+      n_col_blks,
+      (const void*)
+        computeSplitRegressionKernel,
+      TPB_DEFAULT, smemSize, batchSize);
+    dim3 grid(n_blks_for_rows, n_col_blks, batchSize);
     CUDA_CHECK(
       cudaMemsetAsync(b.pred, 0, sizeof(DataT) * b.nPredCounts * 2, s));
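For context on the occupancy-driven launch configuration added above, here is
a standalone host-side sketch of the same recipe (not part of the patch;
`blocksForRows`, `dummyKernel`, and the toy numbers are invented for
illustration): query `cudaOccupancyMaxActiveBlocksPerMultiprocessor` for the
per-SM resident-block limit, multiply by the SM count, and split the budget
across the blocks already assigned to the y- and z-dimensions.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Toy kernel standing in for computeSplitClassificationKernel (illustration only).
__global__ void dummyKernel(int* out) { out[blockIdx.x] = threadIdx.x; }

// Same recipe as Builder::n_blks_for_rows above: total blocks that can be
// resident on the device, divided by the blocks spent on gridDim.y and gridDim.z.
int blocksForRows(int gridDimy, int gridDimz, const void* func, int blockSize,
                  size_t dynSmem) {
  int dev, smCount, maxBlocksPerSm;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, dev);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxBlocksPerSm, func,
                                                blockSize, dynSmem);
  int totalBlocks = maxBlocksPerSm * smCount;
  // ceiling division, like raft::ceildiv in the patch
  return (totalBlocks + gridDimy * gridDimz - 1) / (gridDimy * gridDimz);
}

int main() {
  int gridDimx = blocksForRows(/*gridDimy=*/10, /*gridDimz=*/128,
                               (const void*)dummyKernel, /*blockSize=*/256,
                               /*dynSmem=*/0);
  printf("gridDim.x = %d\n", gridDimx);
  return 0;
}
```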
diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
index 1a35eb9f40..b66c2d4144 100644
--- a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
@@ -341,19 +341,24 @@ __global__ void computeSplitClassificationKernel(
   auto node = nodes[nid];
   auto range_start = node.start;
   auto range_len = node.count;
+
+  // return if leaf
   if (leafBasedOnParams(node.depth, max_depth, min_samples_split,
                         max_leaves, n_leaves, range_len)) {
     return;
   }
   auto end = range_start + range_len;
   auto nclasses = input.nclasses;
-  auto len = nbins * 2 * nclasses;
-  auto* shist = alignPointer(smem);
-  auto* sbins = alignPointer(shist + len);
+  auto pdf_shist_len = (nbins + 1) * nclasses;
+  auto cdf_shist_len = nbins * 2 * nclasses;
+  auto* pdf_shist = alignPointer(smem);
+  auto* cdf_shist = alignPointer(pdf_shist + pdf_shist_len);
+  auto* sbins = alignPointer(cdf_shist + cdf_shist_len);
   auto* sDone = alignPointer(sbins + nbins);
   IdxT stride = blockDim.x * gridDim.x;
   IdxT tid = threadIdx.x + blockIdx.x * blockDim.x;
+  // obtaining the feature to test split on
   IdxT col;
   if (input.nSampledCols == input.N) {
     col = colStart + blockIdx.y;
@@ -362,50 +367,125 @@
     col = select(colIndex, treeid, node.info.unique_id, seed, input.N);
   }
-  for (IdxT i = threadIdx.x; i < len; i += blockDim.x) shist[i] = 0;
+  // populating shared memory with initial values
+  for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x)
+    pdf_shist[i] = 0;
+  for (IdxT j = threadIdx.x; j < cdf_shist_len; j += blockDim.x)
+    cdf_shist[j] = 0;
   for (IdxT b = threadIdx.x; b < nbins; b += blockDim.x)
     sbins[b] = input.quantiles[col * nbins + b];
+
+  // synchronizing above changes across block
   __syncthreads();
+
+  // compute pdf shared histogram for all bins for all classes in shared mem
   auto coloffset = col * input.M;
-  // compute class histogram for all bins for all classes in shared mem
   for (auto i = range_start + tid; i < end; i += stride) {
+    // each thread works over a data point and strides to the next
    auto row = input.rowids[i];
    auto d = input.data[row + coloffset];
    auto label = input.labels[row];
    for (IdxT b = 0; b < nbins; ++b) {
-      auto isRight = d > sbins[b];  // no divergence
-      auto offset = b * 2 * nclasses + isRight * nclasses + label;
-      atomicAdd(shist + offset, 1);  // class hist
+      if (d <= sbins[b]) {
+        auto offset = label * (1 + nbins) + b;
+        atomicAdd(pdf_shist + offset, 1);
+        break;
+      }
     }
   }
+
+  // synchronizing above changes across block
   __syncthreads();
+
   // update the corresponding global location
-  auto histOffset = ((nid * gridDim.y) + blockIdx.y) * len;
-  for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
-    atomicAdd(hist + histOffset + i, shist[i]);
+  auto histOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_shist_len;
+  for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x) {
+    atomicAdd(hist + histOffset + i, pdf_shist[i]);
   }
+  __threadfence();  // for commit guarantee
   __syncthreads();
+
+  // last threadblock will go ahead and compute the best split
   bool last = true;
   if (gridDim.x > 1) {
     last = MLCommon::signalDone(done_count + nid * gridDim.y + blockIdx.y,
                                 gridDim.x, blockIdx.x == 0, sDone);
   }
+  // if not the last threadblock, exit
   if (!last) return;
-  for (IdxT i = threadIdx.x; i < len; i += blockDim.x)
-    shist[i] = hist[histOffset + i];
+
+  // store the complete global histogram in shared memory of last block
+  for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x)
+    pdf_shist[i] = hist[histOffset + i];
+
+  __syncthreads();
+
+  // BlockScan instance preparation
+  typedef cub::BlockScan BlockScan;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+
+  /**
+   * Scanning code:
+   *   span: block-wide
+   *   Function: convert the PDF calculated in the previous steps to CDF
+   *   This CDF is done over 2 passes:
+   *   * one from left to right to sum-scan counts of left splits
+   *     for each split-point.
+   *   * second from right to left to sum-scan the right splits
+   *     for each split-point
+   */
+  for (IdxT tix = threadIdx.x; tix < max(TPB, nbins); tix += blockDim.x) {
+    for (IdxT c = 0; c < nclasses; ++c) {
+      // for each class, do inclusive block scan
+      int pdf_per_bin_per_class;
+      int cdf_per_bin_per_class;
+      // left to right scan operation for scanning lesser-than-or-equal-to-bin counts
+      // offset for left to right scan of pdf_shist
+      IdxT class_segment_offset = (1 + nbins) * c;
+      pdf_per_bin_per_class =
+        tix < nbins ? pdf_shist[class_segment_offset + tix] : 0;
+      BlockScan(temp_storage)
+        .InclusiveSum(pdf_per_bin_per_class, cdf_per_bin_per_class);
+      __syncthreads();  // synchronizing the scan
+      if (tix < nbins) {
+        auto histOffset = (2 * nbins * c + tix);
+        cdf_shist[histOffset] = cdf_per_bin_per_class;
+      }
+
+      // right to left scan operation for scanning greater-than-bin counts
+      // thread0 -> last class segment of pdf_shist
+      // thread(nbins - 1) -> 2nd class segment of pdf_shist
+      // offset for right to left scan of pdf_shist
+      pdf_per_bin_per_class =
+        tix < nbins ? pdf_shist[class_segment_offset + (nbins - tix)] : 0;
+      BlockScan(temp_storage)
+        .InclusiveSum(pdf_per_bin_per_class, cdf_per_bin_per_class);
+      __syncthreads();  // synchronizing the scan
+      if (tix < nbins) {
+        auto histOffset = (2 * nbins * c) + nbins + (nbins - tix - 1);
+        cdf_shist[histOffset] = cdf_per_bin_per_class;
+      }
+    }
+  }
+
+  // create a split instance to test current feature split
   Split sp;
   sp.init();
   __syncthreads();
+
+  // calculate the best candidate bins (one for each block-thread) in current
+  // feature and corresponding information gain for splitting
   if (splitType == CRITERION::GINI) {
-    giniGain(shist, sbins, sp, col, range_len, nbins, nclasses,
+    giniGain(cdf_shist, sbins, sp, col, range_len, nbins, nclasses,
              min_samples_leaf, min_impurity_decrease);
   } else {
-    entropyGain(shist, sbins, sp, col, range_len, nbins, nclasses,
-                min_samples_leaf, min_impurity_decrease);
+    entropyGain(cdf_shist, sbins, sp, col, range_len, nbins,
+                nclasses, min_samples_leaf, min_impurity_decrease);
   }
   __syncthreads();
+
+  // calculate best bins among candidate bins per feature using warp reduce
+  // then atomically update across features to get best split per node (in split[nid])
   sp.evalBestSplit(smem, splits + nid, mutex + nid);
 }
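To connect the kernel changes above with the index changes in metrics.cuh
below: after the two scans, `cdf_shist` stores, for class `c` and bin `b`, the
left count (samples with value <= the bin's edge) at `2 * nbins * c + b` and
the right count (samples with value > the bin's edge) at
`2 * nbins * c + nbins + b`. The small host-side sketch below is not from the
patch; the helper names are made up, and it assumes every sample falls into
one of the `nbins` bins. It checks the invariant `left + right == per-class
total` that `giniGain`/`entropyGain` rely on.

```cpp
#include <cassert>

// Sketch of the cdf_shist layout assumed by giniGain/entropyGain below.
// Helper names (leftCountIdx/rightCountIdx) are illustrative, not from the patch.
inline int leftCountIdx(int nbins, int c, int b) { return 2 * nbins * c + b; }
inline int rightCountIdx(int nbins, int c, int b) { return 2 * nbins * c + nbins + b; }

int main() {
  constexpr int nbins = 4, nclasses = 3;
  // Toy per-class PDF histogram: pdf[c][b] = #samples of class c falling in bin b.
  int pdf[nclasses][nbins] = {{1, 2, 0, 3}, {0, 1, 1, 1}, {2, 0, 0, 2}};
  int cdf[2 * nbins * nclasses] = {};
  for (int c = 0; c < nclasses; ++c) {
    int total = 0;
    for (int b = 0; b < nbins; ++b) total += pdf[c][b];
    int run = 0;
    for (int b = 0; b < nbins; ++b) {
      run += pdf[c][b];                               // left-to-right scan
      cdf[leftCountIdx(nbins, c, b)] = run;           // samples <= bin edge b
      cdf[rightCountIdx(nbins, c, b)] = total - run;  // samples > bin edge b
    }
    // Invariant used by the gain functions: left + right == per-class total.
    for (int b = 0; b < nbins; ++b)
      assert(cdf[leftCountIdx(nbins, c, b)] + cdf[rightCountIdx(nbins, c, b)] == total);
  }
  return 0;
}
```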
diff --git a/cpp/src/decisiontree/batched-levelalgo/metrics.cuh b/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
index 0f3faafbd5..9a5e101b54 100644
--- a/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -76,7 +76,7 @@ DI void giniGain(int* shist, DataT* sbins, Split& sp, IdxT col,
   for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
     int nLeft = 0;
     for (IdxT j = 0; j < nclasses; ++j) {
-      nLeft += shist[i * 2 * nclasses + j];
+      nLeft += shist[2 * nbins * j + i];
     }
     auto nRight = len - nLeft;
     auto gain = DataT(0.0);
@@ -88,12 +88,12 @@
       auto invRight = One / nRight;
       for (IdxT j = 0; j < nclasses; ++j) {
         int val_i = 0;
-        auto lval_i = shist[i * 2 * nclasses + j];
+        auto lval_i = shist[2 * nbins * j + i];
         auto lval = DataT(lval_i);
         gain += lval * invLeft * lval * invlen;
 
         val_i += lval_i;
-        auto rval_i = shist[i * 2 * nclasses + nclasses + j];
+        auto rval_i = shist[2 * nbins * j + nbins + i];
         auto rval = DataT(rval_i);
         gain += rval * invRight * rval * invlen;
@@ -142,7 +142,7 @@ DI void entropyGain(int* shist, DataT* sbins, Split& sp, IdxT col,
   for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
     int nLeft = 0;
     for (IdxT j = 0; j < nclasses; ++j) {
-      nLeft += shist[i * 2 * nclasses + j];
+      nLeft += shist[2 * nbins * j + i];
     }
     auto nRight = len - nLeft;
     auto gain = DataT(0.0);
@@ -154,7 +154,7 @@
       auto invRight = One / nRight;
       for (IdxT j = 0; j < nclasses; ++j) {
         int val_i = 0;
-        auto lval_i = shist[i * 2 * nclasses + j];
+        auto lval_i = shist[2 * nbins * j + i];
         if (lval_i != 0) {
           auto lval = DataT(lval_i);
           gain +=
@@ -162,7 +162,7 @@
         }
         val_i += lval_i;
 
-        auto rval_i = shist[i * 2 * nclasses + nclasses + j];
+        auto rval_i = shist[2 * nbins * j + nbins + i];
         if (rval_i != 0) {
           auto rval = DataT(rval_i);
           gain += raft::myLog(rval * invRight) / raft::myLog(DataT(2)) * rval *