Skip to content

Commit

Permalink
RF param initialization cython and C++ layer cleanup (#3358)
Browse files Browse the repository at this point in the history
* This PR partially solves the issue raised [here](#3089 (comment)).
* Removes unused `DecisionTreeParams` struct in `randomforest_shared.pxd`.
* Unifies the different APIs (namely `set_rf_params`, `set_all_rf_params`, `set_rf_class_obj`) into a single point of parameter initialization (as `set_rf_params`) in the C++ layer; and propagating the changes.

Authors:
  - Venkat (@venkywonka)
  - John Zedlewski (@JohnZed)

Approvers:
  - Philip Hyunsu Cho (@hcho3)
  - John Zedlewski (@JohnZed)
  - Thejaswi. N. S (@teju85)

URL: #3358
  • Loading branch information
venkywonka authored Mar 17, 2021
1 parent de42e7f commit 14bd6c1
Show file tree
Hide file tree
Showing 14 changed files with 196 additions and 284 deletions.
42 changes: 19 additions & 23 deletions cpp/bench/sg/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,29 +143,25 @@ std::vector<Params> getInputs() {
.shuffle = false,
.seed = 12345ULL};

set_rf_params(p.rf, // Output RF parameters
1, // n_trees, just a placeholder value,
// anyway changed below
true, // bootstrap
1.f, // max_samples
1234ULL, // seed
8); // n_streams

set_tree_params(p.rf.tree_params, // Output tree parameters
10, // max_depth, just a placeholder value,
// anyway changed below
(1 << 20), // max_leaves
1, // max_features
32, // n_bins
1, // split_algo
3, // min_samples_leaf
3, // min_samples_split
0.0f, // min_impurity_decrease
true, // bootstrap_features
ML::CRITERION::MSE, // split_criterion
false, // quantile_per_tree
false, // use_experimental_backend
128); // max_batch_size
p.rf = set_rf_params(10, /*max_depth */
(1 << 20), /* max_leaves */
1.f, /* max_features */
32, /* n_bins */
1, /* split_algo */
3, /* min_samples_leaf */
3, /* min_samples_split */
0.0f, /* min_impurity_decrease */
true, /* bootstrap_features */
true, /* bootstrap */
1, /* n_trees */
1.f, /* max_samples */
1234ULL, /* seed */
ML::CRITERION::MSE, /* split_criterion */
false, /* quantile_per_tree */
8, /* n_streams */
false, /* use_experimental_backend */
128 /* max_batch_size */
);

using ML::fil::algo_t;
using ML::fil::storage_type_t;
Expand Down
41 changes: 19 additions & 22 deletions cpp/bench/sg/rf_classifier.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,25 @@ std::vector<Params> getInputs() {
10.0, // center_box_max
2152953ULL}; //seed

set_rf_params(p.rf, // Output RF parameters
500, // n_trees
true, // bootstrap
1.f, // max_samples
1234ULL, // seed
8); // n_streams

set_tree_params(p.rf.tree_params, // Output tree parameters
10, // max_depth, this is anyway changed below
(1 << 20), // max_leaves
0.3, // max_features, just a placeholder value,
// anyway changed below
32, // n_bins
1, // split_algo
3, // min_samples_leaf
3, // min_samples_split
0.0f, // min_impurity_decrease
true, // bootstrap_features
ML::CRITERION::GINI, // split_criterion
false, // quantile_per_tree
false, // use_experimental_backend
128); // max_batch_size
p.rf = set_rf_params(10, /*max_depth */
(1 << 20), /* max_leaves */
0.3, /* max_features */
32, /* n_bins */
1, /* split_algo */
3, /* min_samples_leaf */
3, /* min_samples_split */
0.0f, /* min_impurity_decrease */
true, /* bootstrap_features */
true, /* bootstrap */
500, /* n_trees */
1.f, /* max_samples */
1234ULL, /* seed */
ML::CRITERION::GINI, /* split_criterion */
false, /* quantile_per_tree */
8, /* n_streams */
false, /* use_experimental_backend */
128 /* max_batch_size */
);

std::vector<Triplets> rowcols = {
{160000, 64, 2},
Expand Down
43 changes: 19 additions & 24 deletions cpp/bench/sg/rf_regressor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,30 +85,25 @@ std::vector<RegParams> getInputs() {
.noise = 1.0,
.seed = 12345ULL};

set_rf_params(p.rf, // Output RF parameters
500, // n_trees
true, // bootstrap
1.f, // max_samples
1234ULL, // seed
8); // n_streams

set_tree_params(p.rf.tree_params, // Output tree parameters
10, // max_depth, just a place holder value,
// anyway changed below
(1 << 20), // max_leaves
0.3, // max_features, just a place holder value,
// anyway changed below
32, // n_bins
1, // split_algo
3, // min_samples_leaf
3, // min_samples_split
0.0f, // min_impurity_decrease
true, // bootstrap_features
ML::CRITERION::MSE, // split_criterion
false, // quantile_per_tree
false, // use_experimental_backend
128); // max_batch_size

p.rf = set_rf_params(10, /*max_depth */
(1 << 20), /* max_leaves */
0.3, /* max_features */
32, /* n_bins */
1, /* split_algo */
3, /* min_samples_leaf */
3, /* min_samples_split */
0.0f, /* min_impurity_decrease */
true, /* bootstrap_features */
true, /* bootstrap */
500, /* n_trees */
1.f, /* max_samples */
1234ULL, /* seed */
ML::CRITERION::MSE, /* split_criterion */
false, /* quantile_per_tree */
8, /* n_streams */
false, /* use_experimental_backend */
128 /* max_batch_size */
);
std::vector<DimInfo> dim_info = {{500000, 500, 400}};
for (auto& di : dim_info) {
// Let's run Bosch only for float type
Expand Down
23 changes: 8 additions & 15 deletions cpp/include/cuml/ensemble/randomforest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,6 @@ struct RF_params {
DecisionTree::DecisionTreeParams tree_params;
};

void set_rf_params(RF_params& params, int cfg_n_trees = 1,
bool cfg_bootstrap = true, float cfg_max_samples = 1.0f,
uint64_t cfg_seed = 0, int cfg_n_streams = 8);
void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_max_samples, uint64_t cfg_seed,
int cfg_n_streams,
DecisionTree::DecisionTreeParams cfg_tree_params);
void validity_check(const RF_params rf_params);
void print(const RF_params rf_params);

Expand Down Expand Up @@ -187,14 +180,14 @@ RF_metrics score(const raft::handle_t& user_handle,
int n_rows, const int* predictions,
int verbosity = CUML_LEVEL_INFO);

RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_samples_leaf,
int min_samples_split, float min_impurity_decrease,
bool bootstrap_features, bool bootstrap, int n_trees,
float max_samples, uint64_t seed,
CRITERION split_criterion, bool quantile_per_tree,
int cfg_n_streams, bool use_experimental_backend,
int max_batch_size);
RF_params set_rf_params(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_samples_leaf,
int min_samples_split, float min_impurity_decrease,
bool bootstrap_features, bool bootstrap, int n_trees,
float max_samples, uint64_t seed,
CRITERION split_criterion, bool quantile_per_tree,
int cfg_n_streams, bool use_experimental_backend,
int max_batch_size);

// ----------------------------- Regression ----------------------------------- //

Expand Down
72 changes: 15 additions & 57 deletions cpp/src/randomforest/randomforest.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,53 +157,6 @@ void postprocess_labels(int n_rows, std::vector<int>& labels,
CUML_LOG_DEBUG("Finished postrocessing labels");
}

/**
* @brief Set RF_params parameters members; use default tree parameters.
* @param[in,out] params: update with random forest parameters
* @param[in] cfg_n_trees: number of trees; default 1
* @param[in] cfg_bootstrap: bootstrapping; default true
* @param[in] cfg_max_samples: rows sample; default 1.0f
* @param[in] cfg_n_streams: No of parallel CUDA for training forest
*/
void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_max_samples, uint64_t cfg_seed,
int cfg_n_streams) {
params.n_trees = cfg_n_trees;
params.bootstrap = cfg_bootstrap;
params.max_samples = cfg_max_samples;
params.seed = cfg_seed;
params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (params.n_streams == cfg_n_streams) {
CUML_LOG_WARN("Warning! Max setting Max streams to max openmp threads %d",
omp_get_max_threads());
}
if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees;
set_tree_params(params.tree_params); // use default tree params
}

/**
* @brief Set all RF_params parameters members, including tree parameters.
* @param[in,out] params: update with random forest parameters
* @param[in] cfg_n_trees: number of trees
* @param[in] cfg_bootstrap: bootstrapping
* @param[in] cfg_max_samples: rows sample
* @param[in] cfg_n_streams: No of parallel CUDA for training forest
* @param[in] cfg_tree_params: tree parameters
*/
void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap,
float cfg_max_samples, uint64_t cfg_seed,
int cfg_n_streams,
DecisionTree::DecisionTreeParams cfg_tree_params) {
params.n_trees = cfg_n_trees;
params.bootstrap = cfg_bootstrap;
params.max_samples = cfg_max_samples;
params.seed = cfg_seed;
params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees;
set_tree_params(params.tree_params); // use input tree params
params.tree_params = cfg_tree_params;
}

/**
* @brief Check validity of all random forest hyper-parameters.
* @param[in] rf_params: random forest hyper-parameters
Expand Down Expand Up @@ -657,23 +610,28 @@ RF_metrics score(const raft::handle_t& user_handle,
return classification_score;
}

RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_samples_leaf,
int min_samples_split, float min_impurity_decrease,
bool bootstrap_features, bool bootstrap, int n_trees,
float max_samples, uint64_t seed,
CRITERION split_criterion, bool quantile_per_tree,
int cfg_n_streams, bool use_experimental_backend,
int max_batch_size) {
RF_params set_rf_params(int max_depth, int max_leaves, float max_features,
int n_bins, int split_algo, int min_samples_leaf,
int min_samples_split, float min_impurity_decrease,
bool bootstrap_features, bool bootstrap, int n_trees,
float max_samples, uint64_t seed,
CRITERION split_criterion, bool quantile_per_tree,
int cfg_n_streams, bool use_experimental_backend,
int max_batch_size) {
DecisionTree::DecisionTreeParams tree_params;
DecisionTree::set_tree_params(
tree_params, max_depth, max_leaves, max_features, n_bins, split_algo,
min_samples_leaf, min_samples_split, min_impurity_decrease,
bootstrap_features, split_criterion, quantile_per_tree,
use_experimental_backend, max_batch_size);
RF_params rf_params;
set_all_rf_params(rf_params, n_trees, bootstrap, max_samples, seed,
cfg_n_streams, tree_params);
rf_params.n_trees = n_trees;
rf_params.bootstrap = bootstrap;
rf_params.max_samples = max_samples;
rf_params.seed = seed;
rf_params.n_streams = min(cfg_n_streams, omp_get_max_threads());
if (n_trees < rf_params.n_streams) rf_params.n_streams = n_trees;
rf_params.tree_params = tree_params;
return rf_params;
}

Expand Down
37 changes: 19 additions & 18 deletions cpp/test/sg/rf_accuracy_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,28 @@ class RFClassifierAccuracyTest : public ::testing::TestWithParam<RFInputs> {

private:
void setRFParams() {
DecisionTree::DecisionTreeParams tree_params;
auto algo = SPLIT_ALGO::GLOBAL_QUANTILE;
auto sc = CRITERION::CRITERION_END;
set_tree_params(tree_params, 0, /* max_depth */
-1, /* max_leaves */
1.0, /* max_features */
16, /* n_bins */
algo, /* split_algo */
2, /* min_samples_leaf */
2, /* min_samples_split */
0.f, /* min_impurity_decrease */
false, /* bootstrap_features */
sc, /* split_criterion */
false /* quantile_per_tree */

rfp = set_rf_params(0, /*max_depth */
-1, /* max_leaves */
1.0, /* max_features */
16, /* n_bins */
algo, /* split_algo */
2, /* min_samples_leaf */
2, /* min_samples_split */
0.f, /* min_impurity_decrease */
false, /* bootstrap_features */
true, /* bootstrap */
1, /* n_trees */
1.0, /* max_samples */
0, /* seed */
sc, /* split_criterion */
false, /* quantile_per_tree */
1, /* n_streams */
false, /* use_experimental_backend */
128 /* max_batch_size */
);
set_all_rf_params(rfp, 1, /* n_trees */
true, /* bootstrap */
1.0, /* max_samples */
0, /* seed */
1, /* n_streams */
tree_params);
}

void loadData(T *X, int *y, int nrows, int ncols) {
Expand Down
14 changes: 6 additions & 8 deletions cpp/test/sg/rf_batched_classification_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,13 @@ class RFBatchedClsTest : public ::testing::TestWithParam<RfInputs> {
void basicTest() {
params = ::testing::TestWithParam<RfInputs>::GetParam();

DecisionTree::DecisionTreeParams tree_params;
set_tree_params(tree_params, params.max_depth, params.max_leaves,
params.max_features, params.n_bins, params.split_algo,
params.min_samples_leaf, params.min_samples_split,
params.min_impurity_decrease, params.bootstrap_features,
params.split_criterion, false, true);
RF_params rf_params;
set_all_rf_params(rf_params, params.n_trees, params.bootstrap,
params.max_samples, 0, params.n_streams, tree_params);
rf_params = set_rf_params(
params.max_depth, params.max_leaves, params.max_features, params.n_bins,
params.split_algo, params.min_samples_leaf, params.min_samples_split,
params.min_impurity_decrease, params.bootstrap_features, params.bootstrap,
params.n_trees, params.max_samples, 0, params.split_criterion, false,
params.n_streams, true, 128);

CUDA_CHECK(cudaStreamCreate(&stream));
handle.reset(new raft::handle_t(rf_params.n_streams));
Expand Down
14 changes: 6 additions & 8 deletions cpp/test/sg/rf_batched_regression_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,13 @@ class RFBatchedRegTest : public ::testing::TestWithParam<RfInputs> {
void basicTest() {
params = ::testing::TestWithParam<RfInputs>::GetParam();

DecisionTree::DecisionTreeParams tree_params;
set_tree_params(tree_params, params.max_depth, params.max_leaves,
params.max_features, params.n_bins, params.split_algo,
params.min_samples_leaf, params.min_samples_split,
params.min_impurity_decrease, params.bootstrap_features,
params.split_criterion, false, true);
RF_params rf_params;
set_all_rf_params(rf_params, params.n_trees, params.bootstrap,
params.max_samples, 0, params.n_streams, tree_params);
rf_params = set_rf_params(
params.max_depth, params.max_leaves, params.max_features, params.n_bins,
params.split_algo, params.min_samples_leaf, params.min_samples_split,
params.min_impurity_decrease, params.bootstrap_features, params.bootstrap,
params.n_trees, params.max_samples, 0, params.split_criterion, false,
params.n_streams, true, 128);

CUDA_CHECK(cudaStreamCreate(&stream));
handle.reset(new raft::handle_t(rf_params.n_streams));
Expand Down
Loading

0 comments on commit 14bd6c1

Please sign in to comment.