RF: memset and batch size optimization for computing splits #4001

Merged
2 changes: 1 addition & 1 deletion cpp/include/cuml/tree/decisiontree.hpp
@@ -92,7 +92,7 @@ void set_tree_params(DecisionTreeParams &params, int cfg_max_depth = -1,
int cfg_min_samples_split = 2,
float cfg_min_impurity_decrease = 0.0f,
CRITERION cfg_split_criterion = CRITERION_END,
-                     int cfg_max_batch_size = 128);
+                     int cfg_max_batch_size = 4096);

/**
* @brief Check validity of all decision tree hyper-parameters.
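Back-of-the-envelope illustration of the new default (a sketch, not cuML code; the 10,000-node level is a made-up example): each tree level is processed in ceil(N / max_batch_size) node batches, so the larger default means far fewer batch rounds on wide levels.

#include <cstdio>
#include <initializer_list>

int main() {
  const int nodes_at_level = 10000;  // hypothetical wide tree level
  for (int max_batch_size : {128, 4096}) {
    // one round of split kernels is needed per batch of nodes
    int batches = (nodes_at_level + max_batch_size - 1) / max_batch_size;
    std::printf("max_batch_size=%d -> %d batch rounds\n", max_batch_size,
                batches);
  }
  return 0;
}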
14 changes: 11 additions & 3 deletions cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
@@ -327,13 +327,17 @@ struct Builder {
raft::update_device(curr_nodes, h_nodes.data() + node_start, batchSize, s);

int total_samples_in_curr_batch = 0;
+    int n_large_nodes_in_curr_batch =
+      0;  // large nodes are nodes with more training instances than the
+          // block size, and hence require global memory for histogram
+          // construction
total_num_blocks = 0;
for (int n = 0; n < batchSize; n++) {
total_samples_in_curr_batch += h_nodes[node_start + n].count;
int num_blocks = raft::ceildiv(h_nodes[node_start + n].count,
SAMPLES_PER_THREAD * TPB_DEFAULT);
num_blocks = std::max(1, num_blocks);

+      if (num_blocks > 1) ++n_large_nodes_in_curr_batch;

bool is_leaf = leafBasedOnParams<DataT, IdxT>(
h_nodes[node_start + n].depth, params.max_depth,
params.min_samples_split, params.max_leaves, h_n_leaves,
@@ -342,6 +346,8 @@

for (int b = 0; b < num_blocks; b++) {
h_workload_info[total_num_blocks + b].nodeid = n;
+        h_workload_info[total_num_blocks + b].large_nodeid =
+          n_large_nodes_in_curr_batch - 1;
h_workload_info[total_num_blocks + b].offset_blockid = b;
h_workload_info[total_num_blocks + b].num_blocks = num_blocks;
}
@@ -353,7 +359,8 @@
auto n_col_blks = n_blks_for_cols;
if (total_num_blocks) {
for (IdxT c = 0; c < input.nSampledCols; c += n_col_blks) {
-        computeSplit(c, batchSize, params.split_criterion, s);
+        computeSplit(c, batchSize, params.split_criterion,
+                     n_large_nodes_in_curr_batch, s);
CUDA_CHECK(cudaGetLastError());
}
}
@@ -387,7 +394,7 @@
* @param[in] s cuda stream
*/
void computeSplit(IdxT col, IdxT batchSize, CRITERION splitType,
-                    cudaStream_t s) {
+                    const int n_large_nodes_in_curr_batch, cudaStream_t s) {
ML::PUSH_RANGE(
"Builder::computeSplit @builder_base.cuh [batched-levelalgo]");
auto nbins = params.n_bins;
@@ -407,7 +414,8 @@
// Pick the max of two
size_t smemSize = std::max(smemSize1, smemSize2);
dim3 grid(total_num_blocks, colBlks, 1);
-    CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(int) * nHistBins, s));
+    int nHistBins = n_large_nodes_in_curr_batch * nbins * colBlks * nclasses;
+    CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(BinT) * nHistBins, s));
ML::PUSH_RANGE(
"computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]");
ObjectiveT objective(input.numOutputs, params.min_impurity_decrease,
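The combined effect of the two changes above, sketched on the host (not cuML code; the batch composition, the BinT size, and the pre-PR batchSize-based sizing are assumptions for contrast): the memset now covers only the histogram slices that large nodes will actually write, and it is sized in BinT units rather than int.

#include <cstdio>

int main() {
  const int batchSize = 4096;                  // nodes in the current batch
  const int n_large_nodes_in_curr_batch = 64;  // assumed: few nodes spill to
                                               // global-memory histograms
  const int nbins = 128, colBlks = 3, nclasses = 2;
  const size_t binBytes = 8;                   // assumed sizeof(BinT)

  // pre-PR style sizing over the whole batch vs. the new large-node count
  size_t oldBins = (size_t)batchSize * nbins * colBlks * nclasses;
  size_t newBins =
    (size_t)n_large_nodes_in_curr_batch * nbins * colBlks * nclasses;
  std::printf("bytes zeroed per computeSplit call: %zu -> %zu\n",
              oldBins * binBytes, newBins * binBytes);
  return 0;
}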
5 changes: 4 additions & 1 deletion cpp/src/decisiontree/batched-levelalgo/kernels.cuh
@@ -36,6 +36,8 @@ namespace DecisionTree {
template <typename IdxT>
struct WorkloadInfo {
IdxT nodeid; // Node in the batch on which the threadblock needs to work
+  IdxT large_nodeid;   // counts only large nodes (nodes that require more
+                       // than one block along x-dim for histogram calculation)
IdxT offset_blockid; // Offset threadblock id among all the blocks that are
// working on this node
IdxT num_blocks; // Total number of blocks that are working on the node
@@ -305,6 +307,7 @@ __global__ void computeSplitKernel(
// Read workload info for this block
WorkloadInfo<IdxT> workload_info_cta = workload_info[blockIdx.x];
IdxT nid = workload_info_cta.nodeid;
+  IdxT large_nid = workload_info_cta.large_nodeid;
auto node = nodes[nid];
auto range_start = node.start;
auto range_len = node.count;
@@ -358,7 +361,7 @@
__syncthreads();
if (num_blocks > 1) {
// update the corresponding global location
-    auto histOffset = ((nid * gridDim.y) + blockIdx.y) * pdf_shist_len;
+    auto histOffset = ((large_nid * gridDim.y) + blockIdx.y) * pdf_shist_len;
for (IdxT i = threadIdx.x; i < pdf_shist_len; i += blockDim.x) {
BinT::AtomicAdd(hist + histOffset + i, pdf_shist[i]);
}
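With hist indexed by large_nid instead of nid, the per-node global histogram slices are packed densely over large nodes only; small nodes (num_blocks == 1) finish entirely in shared memory and never touch hist. A host-side sketch of the indexing (not cuML code; the workload entries are invented):

#include <cstdio>

struct Entry { int nid, large_nid, num_blocks; };

int main() {
  const int gridDimY = 3;         // column blocks (blockIdx.y range)
  const int pdf_shist_len = 256;  // histogram length per (node, col-block)
  const Entry w[] = {
    {0, -1, 1},  // small node: one block, large_nid never read
    {1, 0, 4},   // first large node  -> slice 0
    {2, 1, 2},   // second large node -> slice 1
  };
  for (const Entry& e : w) {
    if (e.num_blocks > 1) {  // mirrors the kernel's num_blocks > 1 branch
      for (int by = 0; by < gridDimY; ++by) {
        int histOffset = ((e.large_nid * gridDimY) + by) * pdf_shist_len;
        std::printf("node %d, col-block %d -> hist offset %d\n",
                    e.nid, by, histOffset);
      }
    }
  }
  return 0;
}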
4 changes: 2 additions & 2 deletions python/cuml/ensemble/randomforest_common.pyx
@@ -55,7 +55,7 @@ class BaseRandomForestModel(Base):

classes_ = CumlArrayDescriptor()

-    def __init__(self, *, split_criterion, n_streams=8, n_estimators=100,
+    def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
max_depth=16, handle=None, max_features='auto', n_bins=128,
split_algo=1, bootstrap=True,
verbose=False, min_samples_leaf=1, min_samples_split=2,
@@ -65,7 +65,7 @@
min_impurity_split=None, oob_score=None, random_state=None,
warm_start=None, class_weight=None,
criterion=None, use_experimental_backend=True,
-                 max_batch_size=128):
+                 max_batch_size=4096):

sklearn_params = {"criterion": criterion,
"min_weight_fraction_leaf": min_weight_fraction_leaf,
4 changes: 3 additions & 1 deletion python/cuml/ensemble/randomforestclassifier.pyx
@@ -225,6 +225,8 @@ class RandomForestClassifier(BaseRandomForestModel,
Number of bins used by the split algorithm.
For large problems, particularly those with highly-skewed input data,
increasing the number of bins may improve accuracy.
+    n_streams : int (default = 4)
+        Number of parallel streams used for forest building.
min_samples_leaf : int or float (default = 1)
The minimum number of samples (rows) in each leaf node.
If int, then min_samples_leaf represents the minimum number.
Expand All @@ -243,7 +245,7 @@ class RandomForestClassifier(BaseRandomForestModel,
use_experimental_backend : boolean (default = True)
        Deprecated and currently has no effect.
.. deprecated:: 21.08
-    max_batch_size: int (default = 128)
+    max_batch_size: int (default = 4096)
Maximum number of nodes that can be processed in a given batch. This is
used only when 'use_experimental_backend' is true. Does not currently
fully guarantee the exact same results.
5 changes: 3 additions & 2 deletions python/cuml/ensemble/randomforestregressor.pyx
@@ -205,6 +205,8 @@ class RandomForestRegressor(BaseRandomForestModel,
Number of bins used by the split algorithm.
For large problems, particularly those with highly-skewed input data,
increasing the number of bins may improve accuracy.
+    n_streams : int (default = 4)
+        Number of parallel streams used for forest building.
min_samples_leaf : int or float (default = 1)
The minimum number of samples (rows) in each leaf node.
If int, then min_samples_leaf represents the minimum number.
@@ -230,14 +232,13 @@
use_experimental_backend : boolean (default = True)
        Deprecated and currently has no effect.
.. deprecated:: 21.08
-    max_batch_size: int (default = 128)
+    max_batch_size: int (default = 4096)
Maximum number of nodes that can be processed in a given batch. This is
used only when 'use_experimental_backend' is true.
random_state : int (default = None)
Seed for the random number generator. Unseeded by default. Does not
currently fully guarantee the exact same results. **Note: Parameter
`seed` is removed since release 0.19.**

handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the CUDA