diff --git a/cpp/include/cuml/tree/decisiontree.hpp b/cpp/include/cuml/tree/decisiontree.hpp
index c483783679..8bd49729e2 100644
--- a/cpp/include/cuml/tree/decisiontree.hpp
+++ b/cpp/include/cuml/tree/decisiontree.hpp
@@ -92,7 +92,7 @@ void set_tree_params(DecisionTreeParams &params, int cfg_max_depth = -1,
                      int cfg_min_samples_split = 2,
                      float cfg_min_impurity_decrease = 0.0f,
                      CRITERION cfg_split_criterion = CRITERION_END,
-                     int cfg_max_batch_size = 128);
+                     int cfg_max_batch_size = 4096);
 
 /**
  * @brief Check validity of all decision tree hyper-parameters.
diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
index ef9cb04e2a..8cf52c76b5 100644
--- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
@@ -327,6 +327,8 @@ struct Builder {
     raft::update_device(curr_nodes, h_nodes.data() + node_start, batchSize,
                         s);
     int total_samples_in_curr_batch = 0;
+    int n_large_nodes_in_curr_batch =
+      0;  // large nodes are nodes with more training instances than the block size, and hence require global memory for histogram construction
     total_num_blocks = 0;
     for (int n = 0; n < batchSize; n++) {
       total_samples_in_curr_batch += h_nodes[node_start + n].count;
@@ -334,6 +336,8 @@ struct Builder {
       int num_blocks = raft::ceildiv(h_nodes[node_start + n].count,
                                      SAMPLES_PER_THREAD * TPB_DEFAULT);
       num_blocks = std::max(1, num_blocks);
+      if (num_blocks > 1) ++n_large_nodes_in_curr_batch;
+
       bool is_leaf = leafBasedOnParams(
         h_nodes[node_start + n].depth, params.max_depth,
         params.min_samples_split, params.max_leaves, h_n_leaves,
@@ -342,6 +346,8 @@ struct Builder {
       for (int b = 0; b < num_blocks; b++) {
         h_workload_info[total_num_blocks + b].nodeid = n;
+        h_workload_info[total_num_blocks + b].large_nodeid =
+          n_large_nodes_in_curr_batch - 1;
         h_workload_info[total_num_blocks + b].offset_blockid = b;
         h_workload_info[total_num_blocks + b].num_blocks = num_blocks;
       }
@@ -353,7 +359,8 @@ struct Builder {
     auto n_col_blks = n_blks_for_cols;
     if (total_num_blocks) {
       for (IdxT c = 0; c < input.nSampledCols; c += n_col_blks) {
-        computeSplit(c, batchSize, params.split_criterion, s);
+        computeSplit(c, batchSize, params.split_criterion,
+                     n_large_nodes_in_curr_batch, s);
         CUDA_CHECK(cudaGetLastError());
       }
     }
@@ -387,7 +394,7 @@ struct Builder {
    * @param[in] s cuda stream
    */
   void computeSplit(IdxT col, IdxT batchSize, CRITERION splitType,
-                    cudaStream_t s) {
+                    const int n_large_nodes_in_curr_batch, cudaStream_t s) {
     ML::PUSH_RANGE(
       "Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
     auto nbins = params.n_bins;
@@ -407,7 +414,8 @@ struct Builder {
     // Pick the max of two
     size_t smemSize = std::max(smemSize1, smemSize2);
     dim3 grid(total_num_blocks, colBlks, 1);
-    CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(int) * nHistBins, s));
+    int nHistBins = n_large_nodes_in_curr_batch * nbins * colBlks * nclasses;
+    CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(BinT) * nHistBins, s));
     ML::PUSH_RANGE(
       "computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]");
     ObjectiveT objective(input.numOutputs, params.min_impurity_decrease,
diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
index dba371f18c..1fbe5694d3 100644
--- a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
@@ -36,6 +36,8 @@ namespace DecisionTree {
 template <typename IdxT>
 struct WorkloadInfo {
   IdxT nodeid;  // Node in the batch on which the threadblock needs to work
+  IdxT
+    large_nodeid;  // counts only large nodes (nodes that require more than one block along the x-dim for histogram calculation)
   IdxT offset_blockid;  // Offset threadblock id among all the blocks that are
                         // working on this node
   IdxT num_blocks;  // Total number of blocks that are working on the node
@@ -305,6 +307,7 @@ __global__ void computeSplitKernel(
   // Read workload info for this block
   WorkloadInfo<IdxT> workload_info_cta = workload_info[blockIdx.x];
   IdxT nid = workload_info_cta.nodeid;
+  IdxT large_nid = workload_info_cta.large_nodeid;
   auto node = nodes[nid];
   auto range_start = node.start;
   auto range_len = node.count;
@@ -358,7 +361,7 @@ __global__ void computeSplitKernel(
   __syncthreads();
   if (num_blocks > 1) {
     // update the corresponding global location
-    auto histOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_shist_len;
+    auto histOffset = ((large_nid * gridDim.y) + blockIdx.y) * pdf_shist_len;
     for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x) {
       BinT::AtomicAdd(hist + histOffset + i, pdf_shist[i]);
     }
diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx
index 1a8a70ba32..34be5ed68f 100644
--- a/python/cuml/ensemble/randomforest_common.pyx
+++ b/python/cuml/ensemble/randomforest_common.pyx
@@ -55,7 +55,7 @@ class BaseRandomForestModel(Base):
 
     classes_ = CumlArrayDescriptor()
 
-    def __init__(self, *, split_criterion, n_streams=8, n_estimators=100,
+    def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
                  max_depth=16, handle=None, max_features='auto', n_bins=128,
                  split_algo=1, bootstrap=True,
                  verbose=False, min_samples_leaf=1, min_samples_split=2,
@@ -65,7 +65,7 @@ class BaseRandomForestModel(Base):
                  min_impurity_split=None, oob_score=None, random_state=None,
                  warm_start=None, class_weight=None, criterion=None,
                  use_experimental_backend=True,
-                 max_batch_size=128):
+                 max_batch_size=4096):
 
         sklearn_params = {"criterion": criterion,
                           "min_weight_fraction_leaf": min_weight_fraction_leaf,
diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx
index b84cdb4690..dd5945ccfe 100644
--- a/python/cuml/ensemble/randomforestclassifier.pyx
+++ b/python/cuml/ensemble/randomforestclassifier.pyx
@@ -225,6 +225,8 @@ class RandomForestClassifier(BaseRandomForestModel,
         Number of bins used by the split algorithm. For large problems,
         particularly those with highly-skewed input data, increasing the
         number of bins may improve accuracy.
+    n_streams : int (default = 4)
+        Number of parallel streams used for forest building
     min_samples_leaf : int or float (default = 1)
         The minimum number of samples (rows) in each leaf node.
         If int, then min_samples_leaf represents the minimum number.
@@ -243,7 +245,7 @@ class RandomForestClassifier(BaseRandomForestModel,
     use_experimental_backend : boolean (default = True)
         Deprecated and currently has no effect.
         .. deprecated:: 21.08
-    max_batch_size: int (default = 128)
+    max_batch_size: int (default = 4096)
         Maximum number of nodes that can be processed in a given batch. This
         is used only when 'use_experimental_backend' is true.
         Does not currently fully guarantee the exact same results.
diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx
index 794eadbcb4..422c707bdf 100644
--- a/python/cuml/ensemble/randomforestregressor.pyx
+++ b/python/cuml/ensemble/randomforestregressor.pyx
@@ -205,6 +205,8 @@ class RandomForestRegressor(BaseRandomForestModel,
         Number of bins used by the split algorithm. For large problems,
         particularly those with highly-skewed input data, increasing the
         number of bins may improve accuracy.
+    n_streams : int (default = 4)
+        Number of parallel streams used for forest building
    min_samples_leaf : int or float (default = 1)
        The minimum number of samples (rows) in each leaf node.
        If int, then min_samples_leaf represents the minimum number.
@@ -230,14 +232,13 @@ class RandomForestRegressor(BaseRandomForestModel,
    use_experimental_backend : boolean (default = True)
        Deprecated and currently has no effect.
        .. deprecated:: 21.08
-   max_batch_size: int (default = 128)
+   max_batch_size: int (default = 4096)
        Maximum number of nodes that can be processed in a given batch. This
        is used only when 'use_experimental_backend' is true.
    random_state : int (default = None)
        Seed for the random number generator. Unseeded by default. Does not
        currently fully guarantee the exact same results.
        **Note: Parameter `seed` is removed since release 0.19.**
-
    handle : cuml.Handle
        Specifies the cuml.handle that holds internal CUDA state for
        computations in this model. Most importantly, this specifies the CUDA
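
For the Python API, the net effect of this diff is a pair of new defaults: n_streams drops from 8 to 4 and max_batch_size grows from 128 to 4096. The sketch below is not part of the diff; the synthetic data and the n_estimators value are illustrative assumptions, and it presumes a cuML build that already includes this change.

    import cupy as cp
    from cuml.ensemble import RandomForestClassifier

    # Illustrative toy data only.
    X = cp.random.random((1000, 20), dtype=cp.float32)
    y = (cp.random.random(1000) > 0.5).astype(cp.int32)

    # Equivalent after this change: the first relies on the new defaults,
    # the second spells them out explicitly.
    clf_default = RandomForestClassifier(n_estimators=10)
    clf_explicit = RandomForestClassifier(n_estimators=10, n_streams=4,
                                          max_batch_size=4096)

    clf_default.fit(X, y)
    clf_explicit.fit(X, y)

The larger default batch is made affordable by the C++ changes above: the global-memory histogram buffer is now sized by the number of "large" nodes in the batch (those needing more than one block along the x-dimension) rather than by the whole batch, and computeSplitKernel indexes it through the new large_nodeid field.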