RF: memset and batch size optimization for computing splits (rapidsai#4001)
* **Optimization 1:** Increase the default maximum number of nodes that can be processed per batch (the `max_batch_size` hyperparameter).
    * This increases GPU memory usage, but for practical workloads the overhead rarely exceeds 200 MB (a back-of-the-envelope estimate follows this list).
* **Optimization 2:** Reduce the amount of memory touched by the memset operations in each kernel call.

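As a rough illustration of the memory cost of optimization 1, the sketch below applies the histogram sizing formula from `builder_base.cuh` in this diff (`nHistBins = n_nodes * nbins * colBlks * nclasses`) to the old and new batch sizes. This is not cuML code; the per-bin size and column-block count are assumed values for illustration.

```python
# Back-of-the-envelope sketch (not cuML code): device memory for the
# split histograms as a function of max_batch_size, mirroring the
# nHistBins formula in builder_base.cuh below. BYTES_PER_BIN and
# COL_BLOCKS are assumptions, not values read from the library.
BYTES_PER_BIN = 8  # assumed sizeof(BinT)
COL_BLOCKS = 10    # assumed n_blks_for_cols

def hist_buffer_bytes(max_batch_size, n_bins=128, n_classes=2):
    # Worst case: every node in the batch needs a global histogram slot.
    return max_batch_size * n_bins * COL_BLOCKS * n_classes * BYTES_PER_BIN

for batch in (128, 4096):
    print(f"max_batch_size={batch:5d}: {hist_buffer_bytes(batch) / 2**20:6.1f} MiB")
# max_batch_size=  128:    2.5 MiB
# max_batch_size= 4096:   80.0 MiB  (under the ~200 MB noted above)
```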
---
* This PR drastically reduces the total number of kernel invocations (while increasing the work per invocation) as well as the memsets required per kernel invocation. This can be seen in the following plot on the `year` dataset.
    * x-axis: with/without `optimization 1` × with/without `optimization 2`; y-axis: time (s)
    * `CSRK` = `computeSplitRegressionKernel`
    
    *  ![year-nsys-kernel-and-memset-times-lite_mode-max_bach_size](https://user-images.githubusercontent.com/23023424/122897144-5b319380-d367-11eb-995f-9c05a086fc0f.png)

---

* With `n_estimators: 10`, `n_streams: 4`, `max_depth: 32` (rest default), the gbm-bench plots are shown below; a usage sketch follows the figure.
    * main: branch-21.08, devel: current PR, skl: scikit-learn RF
    * Scores are accuracy for classification and MSE for regression.
    * Note: scikit-learn runs with `n_jobs=-1`, so it leverages all 24 CPUs on my machine.


![memset-batch-opt](https://user-images.githubusercontent.com/23023424/122897816-f88cc780-d367-11eb-9b0f-6384d4ef8cbb.png)
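
For reference, here is a minimal sketch of the benchmark configuration above using the public cuML API (dataset loading and scoring elided):

```python
# Illustrative only: reproduces the hyperparameters used for the
# gbm-bench runs above; data loading and scoring are elided.
from cuml.ensemble import RandomForestClassifier

clf = RandomForestClassifier(
    n_estimators=10,
    n_streams=4,   # new default introduced in this PR
    max_depth=32,  # max_batch_size is left at its new default of 4096
)
# clf.fit(X_train, y_train)
# preds = clf.predict(X_test)
```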

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - Rory Mitchell (https://github.com/RAMitchell)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4001
venkywonka authored Jun 29, 2021
1 parent a6f12f9 commit d372d3d
Showing 6 changed files with 24 additions and 10 deletions.
cpp/include/cuml/tree/decisiontree.hpp (2 changes: 1 addition & 1 deletion)

@@ -92,7 +92,7 @@ void set_tree_params(DecisionTreeParams &params, int cfg_max_depth = -1,
                      int cfg_min_samples_split = 2,
                      float cfg_min_impurity_decrease = 0.0f,
                      CRITERION cfg_split_criterion = CRITERION_END,
-                     int cfg_max_batch_size = 128);
+                     int cfg_max_batch_size = 4096);
 
 /**
  * @brief Check validity of all decision tree hyper-parameters.
cpp/src/decisiontree/batched-levelalgo/builder_base.cuh (14 changes: 11 additions & 3 deletions)

@@ -327,13 +327,17 @@ struct Builder {
     raft::update_device(curr_nodes, h_nodes.data() + node_start, batchSize, s);
 
     int total_samples_in_curr_batch = 0;
+    int n_large_nodes_in_curr_batch =
+      0;  // large nodes are nodes having training instances larger than block size, hence require global memory for histogram construction
     total_num_blocks = 0;
     for (int n = 0; n < batchSize; n++) {
       total_samples_in_curr_batch += h_nodes[node_start + n].count;
       int num_blocks = raft::ceildiv(h_nodes[node_start + n].count,
                                      SAMPLES_PER_THREAD * TPB_DEFAULT);
       num_blocks = std::max(1, num_blocks);
+
+      if (num_blocks > 1) ++n_large_nodes_in_curr_batch;
 
       bool is_leaf = leafBasedOnParams<DataT, IdxT>(
         h_nodes[node_start + n].depth, params.max_depth,
         params.min_samples_split, params.max_leaves, h_n_leaves,
@@ -342,6 +346,8 @@
 
       for (int b = 0; b < num_blocks; b++) {
         h_workload_info[total_num_blocks + b].nodeid = n;
+        h_workload_info[total_num_blocks + b].large_nodeid =
+          n_large_nodes_in_curr_batch - 1;
         h_workload_info[total_num_blocks + b].offset_blockid = b;
         h_workload_info[total_num_blocks + b].num_blocks = num_blocks;
       }
@@ -353,7 +359,8 @@
     auto n_col_blks = n_blks_for_cols;
     if (total_num_blocks) {
       for (IdxT c = 0; c < input.nSampledCols; c += n_col_blks) {
-        computeSplit(c, batchSize, params.split_criterion, s);
+        computeSplit(c, batchSize, params.split_criterion,
+                     n_large_nodes_in_curr_batch, s);
         CUDA_CHECK(cudaGetLastError());
       }
     }
@@ -387,7 +394,7 @@
    * @param[in] s cuda stream
    */
   void computeSplit(IdxT col, IdxT batchSize, CRITERION splitType,
-                    cudaStream_t s) {
+                    const int n_large_nodes_in_curr_batch, cudaStream_t s) {
     ML::PUSH_RANGE(
       "Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
     auto nbins = params.n_bins;
@@ -407,7 +414,8 @@
     // Pick the max of two
     size_t smemSize = std::max(smemSize1, smemSize2);
     dim3 grid(total_num_blocks, colBlks, 1);
-    CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(int) * nHistBins, s));
+    int nHistBins = n_large_nodes_in_curr_batch * nbins * colBlks * nclasses;
+    CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(BinT) * nHistBins, s));
     ML::PUSH_RANGE(
       "computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]");
     ObjectiveT objective(input.numOutputs, params.min_impurity_decrease,
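
To make the bookkeeping in the hunks above concrete, here is a minimal Python rendering of the large-node counting and workload assignment. Names mirror the C++ identifiers; the `SAMPLES_PER_THREAD` and `TPB_DEFAULT` values are assumptions for illustration.

```python
# Python sketch of the workload-info construction above. A node is
# "large" when it needs more than one threadblock; only large nodes
# receive a slot in the global histogram buffer, which is why that
# buffer can be sized by n_large rather than by batchSize.
import math

SAMPLES_PER_THREAD = 1  # assumed value
TPB_DEFAULT = 256       # assumed threads per block

def build_workload_info(node_counts):
    workload, n_large = [], 0
    for nodeid, count in enumerate(node_counts):
        num_blocks = max(1, math.ceil(count / (SAMPLES_PER_THREAD * TPB_DEFAULT)))
        if num_blocks > 1:
            n_large += 1
        for b in range(num_blocks):
            workload.append(dict(
                nodeid=nodeid,
                large_nodeid=n_large - 1,  # meaningful only when num_blocks > 1
                offset_blockid=b,
                num_blocks=num_blocks,
            ))
    return workload, n_large

# Example: two small nodes and one node spanning three blocks.
info, n_large = build_workload_info([100, 700, 50])
assert n_large == 1 and len(info) == 5
```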
cpp/src/decisiontree/batched-levelalgo/kernels.cuh (5 changes: 4 additions & 1 deletion)

@@ -36,6 +36,8 @@ namespace DecisionTree {
 template <typename IdxT>
 struct WorkloadInfo {
   IdxT nodeid;  // Node in the batch on which the threadblock needs to work
+  IdxT
+    large_nodeid;  // counts only large nodes (nodes that require more than one block along x-dim for histogram calculation)
   IdxT offset_blockid;  // Offset threadblock id among all the blocks that are
                         // working on this node
   IdxT num_blocks;  // Total number of blocks that are working on the node
@@ -305,6 +307,7 @@ __global__ void computeSplitKernel(
   // Read workload info for this block
   WorkloadInfo<IdxT> workload_info_cta = workload_info[blockIdx.x];
   IdxT nid = workload_info_cta.nodeid;
+  IdxT large_nid = workload_info_cta.large_nodeid;
   auto node = nodes[nid];
   auto range_start = node.start;
   auto range_len = node.count;
@@ -358,7 +361,7 @@
   __syncthreads();
   if (num_blocks > 1) {
     // update the corresponding global location
-    auto histOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_shist_len;
+    auto histOffset = ((large_nid * gridDim.y) + blockIdx.y) * pdf_shist_len;
     for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x) {
       BinT::AtomicAdd(hist + histOffset + i, pdf_shist[i]);
     }
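
The indexing change above is what shrinks the global buffer: offsets are computed from `large_nid`, which is dense over large nodes only, instead of `nid`, which is dense over the whole batch. A tiny illustrative sketch of the two addressing schemes:

```python
# Before: one global histogram slot per node in the batch.
def hist_offset_old(nid, block_y, grid_dim_y, pdf_shist_len):
    return ((nid * grid_dim_y) + block_y) * pdf_shist_len

# After: one slot per *large* node, so the buffer (and the memset that
# clears it) covers n_large_nodes_in_curr_batch slots instead of
# batchSize slots.
def hist_offset_new(large_nid, block_y, grid_dim_y, pdf_shist_len):
    return ((large_nid * grid_dim_y) + block_y) * pdf_shist_len
```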
python/cuml/ensemble/randomforest_common.pyx (4 changes: 2 additions & 2 deletions)

@@ -55,7 +55,7 @@ class BaseRandomForestModel(Base):
 
     classes_ = CumlArrayDescriptor()
 
-    def __init__(self, *, split_criterion, n_streams=8, n_estimators=100,
+    def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
                  max_depth=16, handle=None, max_features='auto', n_bins=128,
                  split_algo=1, bootstrap=True,
                  verbose=False, min_samples_leaf=1, min_samples_split=2,
@@ -65,7 +65,7 @@
                  min_impurity_split=None, oob_score=None, random_state=None,
                  warm_start=None, class_weight=None,
                  criterion=None, use_experimental_backend=True,
-                 max_batch_size=128):
+                 max_batch_size=4096):
 
         sklearn_params = {"criterion": criterion,
                           "min_weight_fraction_leaf": min_weight_fraction_leaf,
python/cuml/ensemble/randomforestclassifier.pyx (4 changes: 3 additions & 1 deletion)

@@ -225,6 +225,8 @@ class RandomForestClassifier(BaseRandomForestModel,
         Number of bins used by the split algorithm.
         For large problems, particularly those with highly-skewed input data,
         increasing the number of bins may improve accuracy.
+    n_streams : int (default = 4)
+        Number of parallel streams used for forest building
     min_samples_leaf : int or float (default = 1)
         The minimum number of samples (rows) in each leaf node.
         If int, then min_samples_leaf represents the minimum number.
@@ -243,7 +245,7 @@
     use_experimental_backend : boolean (default = True)
         Deprecated and currently has no effect.
         .. deprecated:: 21.08
-    max_batch_size: int (default = 128)
+    max_batch_size: int (default = 4096)
         Maximum number of nodes that can be processed in a given batch. This is
         used only when 'use_experimental_backend' is true. Does not currently
         fully guarantee the exact same results.
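
Since `max_batch_size` now defaults to 4096, users who want to bound device memory can still dial it back; a hypothetical tuning example:

```python
# Hypothetical example: max_batch_size trades device memory for fewer,
# larger kernel launches per tree level.
from cuml.ensemble import RandomForestClassifier

low_mem = RandomForestClassifier(max_batch_size=512)  # smaller histogram buffer
default = RandomForestClassifier()                    # max_batch_size=4096
```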
python/cuml/ensemble/randomforestregressor.pyx (5 changes: 3 additions & 2 deletions)

@@ -205,6 +205,8 @@ class RandomForestRegressor(BaseRandomForestModel,
         Number of bins used by the split algorithm.
         For large problems, particularly those with highly-skewed input data,
         increasing the number of bins may improve accuracy.
+    n_streams : int (default = 4)
+        Number of parallel streams used for forest building
     min_samples_leaf : int or float (default = 1)
         The minimum number of samples (rows) in each leaf node.
         If int, then min_samples_leaf represents the minimum number.
@@ -230,14 +232,13 @@
     use_experimental_backend : boolean (default = True)
         Deprecated and currently has no effect.
         .. deprecated:: 21.08
-    max_batch_size: int (default = 128)
+    max_batch_size: int (default = 4096)
         Maximum number of nodes that can be processed in a given batch. This is
         used only when 'use_experimental_backend' is true.
     random_state : int (default = None)
         Seed for the random number generator. Unseeded by default. Does not
         currently fully guarantee the exact same results. **Note: Parameter
         `seed` is removed since release 0.19.**
     handle : cuml.Handle
         Specifies the cuml.handle that holds internal CUDA state for
         computations in this model. Most importantly, this specifies the CUDA
