From 797d1575637ffe3ada4695621ea5f8f2a81e4758 Mon Sep 17 00:00:00 2001 From: venkywonka Date: Sat, 5 Jun 2021 14:12:12 +0530 Subject: [PATCH 1/8] writing workspace to file --- .../decisiontree/batched-levelalgo/builder.cuh | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/cpp/src/decisiontree/batched-levelalgo/builder.cuh b/cpp/src/decisiontree/batched-levelalgo/builder.cuh index c2405747a4..f2d19b01d5 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder.cuh @@ -24,7 +24,16 @@ #include "builder_base.cuh" #include - +// Usage example: filePutContents("./yourfile.txt", "content", true); +void filePutContents(const std::string& name, const std::string& content, bool append = false) { + std::ofstream outfile; + if (append) + outfile.open(name, std::ios_base::app); + else + outfile.open(name); + outfile << content << std::endl; + outfile.close(); +} namespace ML { namespace DecisionTree { @@ -63,6 +72,10 @@ void grow_tree(std::shared_ptr d_allocator, nrows, ncols, n_sampled_rows, IdxT(params.max_features * ncols), rowids, unique_labels, quantiles); + if(treeid == 0) { + CUML_LOG_WARN("device workspace allocated: %d kB", raft::ceildiv(d_wsize, 1000)); + filePutContents("workspace.txt", std::to_string(d_wsize)); + } MLCommon::device_buffer d_buff(d_allocator, stream, d_wsize); MLCommon::host_buffer h_buff(h_allocator, stream, h_wsize); From 93187f16276727c10221b58a0fd4896007241358 Mon Sep 17 00:00:00 2001 From: venkywonka Date: Tue, 8 Jun 2021 18:17:06 +0530 Subject: [PATCH 2/8] getenv to change blks_for_cols --- cpp/src/decisiontree/batched-levelalgo/builder.cuh | 5 +++-- cpp/src/decisiontree/batched-levelalgo/builder_base.cuh | 6 +++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/cpp/src/decisiontree/batched-levelalgo/builder.cuh b/cpp/src/decisiontree/batched-levelalgo/builder.cuh index f2d19b01d5..209b8fb777 100644 --- 
a/cpp/src/decisiontree/batched-levelalgo/builder.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder.cuh @@ -22,7 +22,7 @@ #include #include "builder_base.cuh" - +#include #include // Usage example: filePutContents("./yourfile.txt", "content", true); void filePutContents(const std::string& name, const std::string& content, bool append = false) { @@ -34,6 +34,7 @@ void filePutContents(const std::string& name, const std::string& content, bool a outfile << content << std::endl; outfile.close(); } + namespace ML { namespace DecisionTree { @@ -73,7 +74,7 @@ void grow_tree(std::shared_ptr d_allocator, IdxT(params.max_features * ncols), rowids, unique_labels, quantiles); if(treeid == 0) { - CUML_LOG_WARN("device workspace allocated: %d kB", raft::ceildiv(d_wsize, 1000)); + CUML_LOG_WARN("device workspace allocated: %d kB", raft::ceildiv((int)d_wsize, 1000)); filePutContents("workspace.txt", std::to_string(d_wsize)); } MLCommon::device_buffer d_buff(d_allocator, stream, d_wsize); diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh index 42883d1fdb..cb7f7989bd 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh @@ -147,7 +147,10 @@ struct Builder { params = p; this->treeid = treeid; this->seed = seed; - n_blks_for_cols = std::min(sampledCols, n_blks_for_cols); + // int env_blks_for_cols = (int)strtol(std::getenv("BLKS_FOR_COLS"), NULL, 10); + // n_blks_for_cols = std::min(sampledCols, env_blks_for_cols); + n_blks_for_cols = sampledCols; + // CUML_LOG_WARN("blocks for cols: %d, env_var: %d", n_blks_for_cols, env_blks_for_cols); input.data = data; input.labels = labels; input.M = totalRows; @@ -511,6 +514,7 @@ struct RegTraits { ML::PUSH_RANGE( "Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); auto colBlks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col); + // CUML_LOG_WARN("column blocks used: 
%d", colBlks); auto nbins = b.params.n_bins; // Compute shared memory size From c4b154821474d2df59ef2b1be586ed897048d29f Mon Sep 17 00:00:00 2001 From: venkywonka Date: Mon, 21 Jun 2021 15:25:36 +0530 Subject: [PATCH 3/8] memsets for only large nodes --- .../batched-levelalgo/builder.cuh | 15 ----------- .../batched-levelalgo/builder_base.cuh | 26 ++++++++++++------- .../batched-levelalgo/kernels.cuh | 9 ++++--- cpp/src/randomforest/randomforest_impl.cuh | 1 + 4 files changed, 23 insertions(+), 28 deletions(-) diff --git a/cpp/src/decisiontree/batched-levelalgo/builder.cuh b/cpp/src/decisiontree/batched-levelalgo/builder.cuh index 209b8fb777..af934de270 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder.cuh @@ -22,18 +22,7 @@ #include #include "builder_base.cuh" -#include #include -// Usage example: filePutContents("./yourfile.txt", "content", true); -void filePutContents(const std::string& name, const std::string& content, bool append = false) { - std::ofstream outfile; - if (append) - outfile.open(name, std::ios_base::app); - else - outfile.open(name); - outfile << content << std::endl; - outfile.close(); -} namespace ML { namespace DecisionTree { @@ -73,10 +62,6 @@ void grow_tree(std::shared_ptr d_allocator, nrows, ncols, n_sampled_rows, IdxT(params.max_features * ncols), rowids, unique_labels, quantiles); - if(treeid == 0) { - CUML_LOG_WARN("device workspace allocated: %d kB", raft::ceildiv((int)d_wsize, 1000)); - filePutContents("workspace.txt", std::to_string(d_wsize)); - } MLCommon::device_buffer d_buff(d_allocator, stream, d_wsize); MLCommon::host_buffer h_buff(h_allocator, stream, h_wsize); diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh index cb7f7989bd..66abba9a1f 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh @@ -147,10 +147,7 
@@ struct Builder { params = p; this->treeid = treeid; this->seed = seed; - // int env_blks_for_cols = (int)strtol(std::getenv("BLKS_FOR_COLS"), NULL, 10); - // n_blks_for_cols = std::min(sampledCols, env_blks_for_cols); - n_blks_for_cols = sampledCols; - // CUML_LOG_WARN("blocks for cols: %d, env_var: %d", n_blks_for_cols, env_blks_for_cols); + n_blks_for_cols = std::min(sampledCols, n_blks_for_cols); input.data = data; input.labels = labels; input.M = totalRows; @@ -341,6 +338,7 @@ struct Builder { raft::update_device(curr_nodes, h_nodes.data() + node_start, batchSize, s); int total_samples_in_curr_batch = 0; + int n_large_nodes_in_curr_batch = 0; total_num_blocks = 0; for (int n = 0; n < batchSize; n++) { total_samples_in_curr_batch += h_nodes[node_start + n].count; @@ -348,6 +346,8 @@ struct Builder { SAMPLES_PER_THREAD * Traits::TPB_DEFAULT); num_blocks = std::max(1, num_blocks); + if(num_blocks > 1) ++n_large_nodes_in_curr_batch; + bool is_leaf = leafBasedOnParams( h_nodes[node_start + n].depth, params.max_depth, params.min_samples_split, params.max_leaves, h_n_leaves, @@ -356,6 +356,7 @@ struct Builder { for (int b = 0; b < num_blocks; b++) { h_workload_info[total_num_blocks + b].nodeid = n; + h_workload_info[total_num_blocks + b].large_nodeid = n_large_nodes_in_curr_batch - 1; h_workload_info[total_num_blocks + b].offset_blockid = b; h_workload_info[total_num_blocks + b].num_blocks = num_blocks; } @@ -367,7 +368,7 @@ struct Builder { auto n_col_blks = n_blks_for_cols; if (total_num_blocks) { for (IdxT c = 0; c < input.nSampledCols; c += n_col_blks) { - Traits::computeSplit(*this, c, batchSize, params.split_criterion, s); + Traits::computeSplit(*this, c, batchSize, params.split_criterion, n_large_nodes_in_curr_batch, s); CUDA_CHECK(cudaGetLastError()); } } @@ -428,7 +429,7 @@ struct ClsTraits { * @param[in] s cuda stream */ static void computeSplit(Builder>& b, IdxT col, - IdxT batchSize, CRITERION splitType, + IdxT batchSize, CRITERION splitType, int 
&n_large_nodes_in_curr_batch, cudaStream_t s) { ML::PUSH_RANGE( "Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); @@ -449,7 +450,9 @@ struct ClsTraits { // Pick the max of two size_t smemSize = std::max(smemSize1, smemSize2); dim3 grid(b.total_num_blocks, colBlks, 1); - CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * b.nHistBins, s)); + int nHistBins = 0; + nHistBins = n_large_nodes_in_curr_batch * (1 + nbins) * colBlks * nclasses; + CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * nHistBins, s)); ML::PUSH_RANGE( "computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]"); computeSplitClassificationKernel @@ -509,7 +512,7 @@ struct RegTraits { * @param[in] s cuda stream */ static void computeSplit(Builder>& b, IdxT col, - IdxT batchSize, CRITERION splitType, + IdxT batchSize, CRITERION splitType, int &n_large_nodes_in_curr_batch, cudaStream_t s) { ML::PUSH_RANGE( "Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); @@ -533,10 +536,13 @@ struct RegTraits { size_t smemSize = std::max(smemSize1, smemSize2); dim3 grid(b.total_num_blocks, colBlks, 1); + int nPredCounts = 0; + nPredCounts = n_large_nodes_in_curr_batch * nbins * colBlks; + // std::cout<<"nPredCounts = "< struct WorkloadInfo { IdxT nodeid; // Node in the batch on which the threadblock needs to work + IdxT large_nodeid; // counts only large nodes IdxT offset_blockid; // Offset threadblock id among all the blocks that are // working on this node IdxT num_blocks; // Total number of blocks that are working on the node @@ -387,6 +388,7 @@ __global__ void computeSplitClassificationKernel( // Read workload info for this block WorkloadInfo workload_info_cta = workload_info[blockIdx.x]; IdxT nid = workload_info_cta.nodeid; + IdxT large_nid = workload_info_cta.large_nodeid; auto node = nodes[nid]; auto range_start = node.start; auto range_len = node.count; @@ -445,7 +447,7 @@ __global__ void computeSplitClassificationKernel( __syncthreads(); if (num_blocks > 1) { // 
update the corresponding global location - auto histOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_shist_len; + auto histOffset = ((large_nid * gridDim.y) + blockIdx.y) * pdf_shist_len; for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x) { atomicAdd(hist + histOffset + i, pdf_shist[i]); } @@ -530,6 +532,7 @@ __global__ void computeSplitRegressionKernel( // Read workload info for this block WorkloadInfo workload_info_cta = workload_info[blockIdx.x]; IdxT nid = workload_info_cta.nodeid; + IdxT large_nid = workload_info_cta.large_nodeid; auto node = nodes[nid]; auto range_start = node.start; @@ -598,13 +601,13 @@ __global__ void computeSplitRegressionKernel( if (num_blocks > 1) { // update the corresponding global location for counts - auto gcOffset = ((nid * gridDim.y) + blockIdx.y) * nbins; + auto gcOffset = ((large_nid * gridDim.y) + blockIdx.y) * nbins; for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) { atomicAdd(count + gcOffset + i, pdf_scount[i]); } // update the corresponding global location for preds - auto gOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_spred_len; + auto gOffset = ((large_nid * gridDim.y) + blockIdx.y) * pdf_spred_len; for (IdxT i = threadIdx.x; i < pdf_spred_len; i += blockDim.x) { atomicAdd(pred + gOffset + i, pdf_spred[i]); } diff --git a/cpp/src/randomforest/randomforest_impl.cuh b/cpp/src/randomforest/randomforest_impl.cuh index c87a06035b..d47fc2af56 100644 --- a/cpp/src/randomforest/randomforest_impl.cuh +++ b/cpp/src/randomforest/randomforest_impl.cuh @@ -442,6 +442,7 @@ void rfRegressor::fit(const raft::handle_t& user_handle, const T* input, n_sampled_rows = n_rows; } int n_streams = this->rf_params.n_streams; + CUML_LOG_WARN("\n rf_params.n_streams: %d \n", n_streams); ASSERT( n_streams <= handle.get_num_internal_streams(), "rf_params.n_streams (=%d) should be <= raft::handle_t.n_streams (=%d)", From 4e20c28ec720104d3802808dcf869af477aaaa3f Mon Sep 17 00:00:00 2001 From: venkywonka Date: Mon, 21 Jun 2021 19:32:18 
+0530 Subject: [PATCH 4/8] change default --- python/cuml/ensemble/randomforest_common.pyx | 2 +- python/cuml/ensemble/randomforestclassifier.pyx | 2 +- python/cuml/ensemble/randomforestregressor.pyx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx index 1a8a70ba32..e8f769afa4 100644 --- a/python/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/ensemble/randomforest_common.pyx @@ -65,7 +65,7 @@ class BaseRandomForestModel(Base): min_impurity_split=None, oob_score=None, random_state=None, warm_start=None, class_weight=None, criterion=None, use_experimental_backend=True, - max_batch_size=128): + max_batch_size=4096): sklearn_params = {"criterion": criterion, "min_weight_fraction_leaf": min_weight_fraction_leaf, diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index b84cdb4690..62bf53e584 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -243,7 +243,7 @@ class RandomForestClassifier(BaseRandomForestModel, use_experimental_backend : boolean (default = True) Deprecated and currrently has no effect. .. deprecated:: 21.08 - max_batch_size: int (default = 128) + max_batch_size: int (default = 4096) Maximum number of nodes that can be processed in a given batch. This is used only when 'use_experimental_backend' is true. Does not currently fully guarantee the exact same results. diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index 794eadbcb4..91fb954738 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -230,7 +230,7 @@ class RandomForestRegressor(BaseRandomForestModel, use_experimental_backend : boolean (default = True) Deprecated and currrently has no effect. .. 
deprecated:: 21.08 - max_batch_size: int (default = 128) + max_batch_size: int (default = 4096) Maximum number of nodes that can be processed in a given batch. This is used only when 'use_experimental_backend' is true. random_state : int (default = None) From 8a0762eb076e0d7b794906a8cf09e5ead3052c58 Mon Sep 17 00:00:00 2001 From: venkywonka Date: Tue, 22 Jun 2021 12:15:48 +0530 Subject: [PATCH 5/8] pruning prints and changing defaults in other places --- cpp/include/cuml/tree/decisiontree.hpp | 2 +- cpp/src/decisiontree/batched-levelalgo/builder.cuh | 1 + cpp/src/decisiontree/batched-levelalgo/builder_base.cuh | 2 -- cpp/src/randomforest/randomforest_impl.cuh | 1 - 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/include/cuml/tree/decisiontree.hpp b/cpp/include/cuml/tree/decisiontree.hpp index c483783679..8bd49729e2 100644 --- a/cpp/include/cuml/tree/decisiontree.hpp +++ b/cpp/include/cuml/tree/decisiontree.hpp @@ -92,7 +92,7 @@ void set_tree_params(DecisionTreeParams ¶ms, int cfg_max_depth = -1, int cfg_min_samples_split = 2, float cfg_min_impurity_decrease = 0.0f, CRITERION cfg_split_criterion = CRITERION_END, - int cfg_max_batch_size = 128); + int cfg_max_batch_size = 4096); /** * @brief Check validity of all decision tree hyper-parameters. 
diff --git a/cpp/src/decisiontree/batched-levelalgo/builder.cuh b/cpp/src/decisiontree/batched-levelalgo/builder.cuh index af934de270..c2405747a4 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder.cuh @@ -22,6 +22,7 @@ #include #include "builder_base.cuh" + #include namespace ML { diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh index 66abba9a1f..9de2d171ac 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh @@ -517,7 +517,6 @@ struct RegTraits { ML::PUSH_RANGE( "Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); auto colBlks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col); - // CUML_LOG_WARN("column blocks used: %d", colBlks); auto nbins = b.params.n_bins; // Compute shared memory size @@ -538,7 +537,6 @@ struct RegTraits { int nPredCounts = 0; nPredCounts = n_large_nodes_in_curr_batch * nbins * colBlks; - // std::cout<<"nPredCounts = "<::fit(const raft::handle_t& user_handle, const T* input, n_sampled_rows = n_rows; } int n_streams = this->rf_params.n_streams; - CUML_LOG_WARN("\n rf_params.n_streams: %d \n", n_streams); ASSERT( n_streams <= handle.get_num_internal_streams(), "rf_params.n_streams (=%d) should be <= raft::handle_t.n_streams (=%d)", From 23ff5da7cea77efc350f82cdab9e0a92439cdde5 Mon Sep 17 00:00:00 2001 From: venkywonka Date: Tue, 22 Jun 2021 12:25:18 +0530 Subject: [PATCH 6/8] FIX clang format --- .../batched-levelalgo/builder_base.cuh | 22 +++++++++---------- .../batched-levelalgo/kernels.cuh | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh index 9de2d171ac..e6cfeacaff 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh +++ 
b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh @@ -346,7 +346,7 @@ struct Builder { SAMPLES_PER_THREAD * Traits::TPB_DEFAULT); num_blocks = std::max(1, num_blocks); - if(num_blocks > 1) ++n_large_nodes_in_curr_batch; + if (num_blocks > 1) ++n_large_nodes_in_curr_batch; bool is_leaf = leafBasedOnParams( h_nodes[node_start + n].depth, params.max_depth, @@ -356,7 +356,8 @@ struct Builder { for (int b = 0; b < num_blocks; b++) { h_workload_info[total_num_blocks + b].nodeid = n; - h_workload_info[total_num_blocks + b].large_nodeid = n_large_nodes_in_curr_batch - 1; + h_workload_info[total_num_blocks + b].large_nodeid = + n_large_nodes_in_curr_batch - 1; h_workload_info[total_num_blocks + b].offset_blockid = b; h_workload_info[total_num_blocks + b].num_blocks = num_blocks; } @@ -368,7 +369,8 @@ struct Builder { auto n_col_blks = n_blks_for_cols; if (total_num_blocks) { for (IdxT c = 0; c < input.nSampledCols; c += n_col_blks) { - Traits::computeSplit(*this, c, batchSize, params.split_criterion, n_large_nodes_in_curr_batch, s); + Traits::computeSplit(*this, c, batchSize, params.split_criterion, + n_large_nodes_in_curr_batch, s); CUDA_CHECK(cudaGetLastError()); } } @@ -429,8 +431,8 @@ struct ClsTraits { * @param[in] s cuda stream */ static void computeSplit(Builder>& b, IdxT col, - IdxT batchSize, CRITERION splitType, int &n_large_nodes_in_curr_batch, - cudaStream_t s) { + IdxT batchSize, CRITERION splitType, + int& n_large_nodes_in_curr_batch, cudaStream_t s) { ML::PUSH_RANGE( "Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); auto nbins = b.params.n_bins; @@ -512,8 +514,8 @@ struct RegTraits { * @param[in] s cuda stream */ static void computeSplit(Builder>& b, IdxT col, - IdxT batchSize, CRITERION splitType, int &n_large_nodes_in_curr_batch, - cudaStream_t s) { + IdxT batchSize, CRITERION splitType, + int& n_large_nodes_in_curr_batch, cudaStream_t s) { ML::PUSH_RANGE( "Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); auto colBlks = 
std::min(b.n_blks_for_cols, b.input.nSampledCols - col); @@ -537,10 +539,8 @@ struct RegTraits { int nPredCounts = 0; nPredCounts = n_large_nodes_in_curr_batch * nbins * colBlks; - CUDA_CHECK( - cudaMemsetAsync(b.pred, 0, sizeof(DataT) * nPredCounts * 2, s)); - CUDA_CHECK( - cudaMemsetAsync(b.pred_count, 0, sizeof(IdxT) * nPredCounts, s)); + CUDA_CHECK(cudaMemsetAsync(b.pred, 0, sizeof(DataT) * nPredCounts * 2, s)); + CUDA_CHECK(cudaMemsetAsync(b.pred_count, 0, sizeof(IdxT) * nPredCounts, s)); ML::PUSH_RANGE( "computeSplitRegressionKernel @builder_base.cuh [batched-levelalgo]"); diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh index e70d189045..565d792f96 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh @@ -35,7 +35,7 @@ namespace DecisionTree { template struct WorkloadInfo { IdxT nodeid; // Node in the batch on which the threadblock needs to work - IdxT large_nodeid; // counts only large nodes + IdxT large_nodeid; // counts only large nodes IdxT offset_blockid; // Offset threadblock id among all the blocks that are // working on this node IdxT num_blocks; // Total number of blocks that are working on the node From 68edd49df993b23b44380fb447b74cee51752574 Mon Sep 17 00:00:00 2001 From: venkywonka Date: Wed, 23 Jun 2021 09:29:04 +0530 Subject: [PATCH 7/8] add comments, fix conflicts --- .../batched-levelalgo/builder_base.cuh | 17 ++++++++--------- .../decisiontree/batched-levelalgo/kernels.cuh | 3 ++- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh index d37c3a1b3c..8cf52c76b5 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh @@ -327,7 +327,8 @@ struct Builder { raft::update_device(curr_nodes, h_nodes.data() + node_start, 
batchSize, s); int total_samples_in_curr_batch = 0; - int n_large_nodes_in_curr_batch = 0; + int n_large_nodes_in_curr_batch = + 0; // large nodes are nodes having training instances larger than block size, hence require global memory for histogram construction total_num_blocks = 0; for (int n = 0; n < batchSize; n++) { total_samples_in_curr_batch += h_nodes[node_start + n].count; @@ -358,8 +359,8 @@ struct Builder { auto n_col_blks = n_blks_for_cols; if (total_num_blocks) { for (IdxT c = 0; c < input.nSampledCols; c += n_col_blks) { - Traits::computeSplit(c, batchSize, params.split_criterion, - n_large_nodes_in_curr_batch, s); + computeSplit(c, batchSize, params.split_criterion, + n_large_nodes_in_curr_batch, s); CUDA_CHECK(cudaGetLastError()); } } @@ -392,9 +393,8 @@ struct Builder { * @param[in] splitType split criterion * @param[in] s cuda stream */ - static void computeSplit(IdxT col, - IdxT batchSize, CRITERION splitType, - int& n_large_nodes_in_curr_batch, cudaStream_t s) { + void computeSplit(IdxT col, IdxT batchSize, CRITERION splitType, + const int n_large_nodes_in_curr_batch, cudaStream_t s) { ML::PUSH_RANGE( "Builder::computeSplit @builder_base.cuh [batched-levelalgo]"); auto nbins = params.n_bins; @@ -414,9 +414,8 @@ struct Builder { // Pick the max of two size_t smemSize = std::max(smemSize1, smemSize2); dim3 grid(total_num_blocks, colBlks, 1); - int nHistBins = 0; - nHistBins = n_large_nodes_in_curr_batch * nbins * colBlks * nclasses; - CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(BinT) * nHistBins, s)); + int nHistBins = n_large_nodes_in_curr_batch * nbins * colBlks * nclasses; + CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(BinT) * nHistBins, s)); ML::PUSH_RANGE( "computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]"); ObjectiveT objective(input.numOutputs, params.min_impurity_decrease, diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh index 5a5d20d7d4..1fbe5694d3 
100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh @@ -36,7 +36,8 @@ namespace DecisionTree { template struct WorkloadInfo { IdxT nodeid; // Node in the batch on which the threadblock needs to work - IdxT large_nodeid; // counts only large nodes + IdxT + large_nodeid; // counts only large nodes (nodes that require more than one block along x-dim for histogram calculation) IdxT offset_blockid; // Offset threadblock id among all the blocks that are // working on this node IdxT num_blocks; // Total number of blocks that are working on the node From 70665a35628e3bd8aa08a921a47239aa0f293236 Mon Sep 17 00:00:00 2001 From: venkywonka Date: Tue, 29 Jun 2021 05:25:21 +0000 Subject: [PATCH 8/8] change default, update docstrings --- python/cuml/ensemble/randomforest_common.pyx | 2 +- python/cuml/ensemble/randomforestclassifier.pyx | 2 ++ python/cuml/ensemble/randomforestregressor.pyx | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx index e8f769afa4..34be5ed68f 100644 --- a/python/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/ensemble/randomforest_common.pyx @@ -55,7 +55,7 @@ class BaseRandomForestModel(Base): classes_ = CumlArrayDescriptor() - def __init__(self, *, split_criterion, n_streams=8, n_estimators=100, + def __init__(self, *, split_criterion, n_streams=4, n_estimators=100, max_depth=16, handle=None, max_features='auto', n_bins=128, split_algo=1, bootstrap=True, verbose=False, min_samples_leaf=1, min_samples_split=2, diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 62bf53e584..dd5945ccfe 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -225,6 +225,8 @@ class RandomForestClassifier(BaseRandomForestModel, Number of bins used by the split 
algorithm. For large problems, particularly those with highly-skewed input data, increasing the number of bins may improve accuracy. + n_streams : int (default = 4 ) + Number of parallel streams used for forest building min_samples_leaf : int or float (default = 1) The minimum number of samples (rows) in each leaf node. If int, then min_samples_leaf represents the minimum number. diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index 91fb954738..422c707bdf 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -205,6 +205,8 @@ class RandomForestRegressor(BaseRandomForestModel, Number of bins used by the split algorithm. For large problems, particularly those with highly-skewed input data, increasing the number of bins may improve accuracy. + n_streams : int (default = 4 ) + Number of parallel streams used for forest building min_samples_leaf : int or float (default = 1) The minimum number of samples (rows) in each leaf node. If int, then min_samples_leaf represents the minimum number. @@ -237,7 +239,6 @@ class RandomForestRegressor(BaseRandomForestModel, Seed for the random number generator. Unseeded by default. Does not currently fully guarantee the exact same results. **Note: Parameter `seed` is removed since release 0.19.** - handle : cuml.Handle Specifies the cuml.handle that holds internal CUDA state for computations in this model. Most importantly, this specifies the CUDA