diff --git a/cpp/bench/sg/fil.cu b/cpp/bench/sg/fil.cu index 406bc4d9e0..d1a28f3a5b 100644 --- a/cpp/bench/sg/fil.cu +++ b/cpp/bench/sg/fil.cu @@ -143,29 +143,25 @@ std::vector getInputs() { .shuffle = false, .seed = 12345ULL}; - set_rf_params(p.rf, // Output RF parameters - 1, // n_trees, just a placeholder value, - // anyway changed below - true, // bootstrap - 1.f, // max_samples - 1234ULL, // seed - 8); // n_streams - - set_tree_params(p.rf.tree_params, // Output tree parameters - 10, // max_depth, just a placeholder value, - // anyway changed below - (1 << 20), // max_leaves - 1, // max_features - 32, // n_bins - 1, // split_algo - 3, // min_samples_leaf - 3, // min_samples_split - 0.0f, // min_impurity_decrease - true, // bootstrap_features - ML::CRITERION::MSE, // split_criterion - false, // quantile_per_tree - false, // use_experimental_backend - 128); // max_batch_size + p.rf = set_rf_params(10, /*max_depth */ + (1 << 20), /* max_leaves */ + 1.f, /* max_features */ + 32, /* n_bins */ + 1, /* split_algo */ + 3, /* min_samples_leaf */ + 3, /* min_samples_split */ + 0.0f, /* min_impurity_decrease */ + true, /* bootstrap_features */ + true, /* bootstrap */ + 1, /* n_trees */ + 1.f, /* max_samples */ + 1234ULL, /* seed */ + ML::CRITERION::MSE, /* split_criterion */ + false, /* quantile_per_tree */ + 8, /* n_streams */ + false, /* use_experimental_backend */ + 128 /* max_batch_size */ + ); using ML::fil::algo_t; using ML::fil::storage_type_t; diff --git a/cpp/bench/sg/rf_classifier.cu b/cpp/bench/sg/rf_classifier.cu index 29a9d0c371..769d7e7488 100644 --- a/cpp/bench/sg/rf_classifier.cu +++ b/cpp/bench/sg/rf_classifier.cu @@ -83,28 +83,25 @@ std::vector getInputs() { 10.0, // center_box_max 2152953ULL}; //seed - set_rf_params(p.rf, // Output RF parameters - 500, // n_trees - true, // bootstrap - 1.f, // max_samples - 1234ULL, // seed - 8); // n_streams - - set_tree_params(p.rf.tree_params, // Output tree parameters - 10, // max_depth, this is anyway changed below - (1 << 20), // max_leaves - 0.3, // max_features, just a placeholder value, - // anyway changed below - 32, // n_bins - 1, // split_algo - 3, // min_samples_leaf - 3, // min_samples_split - 0.0f, // min_impurity_decrease - true, // bootstrap_features - ML::CRITERION::GINI, // split_criterion - false, // quantile_per_tree - false, // use_experimental_backend - 128); // max_batch_size + p.rf = set_rf_params(10, /*max_depth */ + (1 << 20), /* max_leaves */ + 0.3, /* max_features */ + 32, /* n_bins */ + 1, /* split_algo */ + 3, /* min_samples_leaf */ + 3, /* min_samples_split */ + 0.0f, /* min_impurity_decrease */ + true, /* bootstrap_features */ + true, /* bootstrap */ + 500, /* n_trees */ + 1.f, /* max_samples */ + 1234ULL, /* seed */ + ML::CRITERION::GINI, /* split_criterion */ + false, /* quantile_per_tree */ + 8, /* n_streams */ + false, /* use_experimental_backend */ + 128 /* max_batch_size */ + ); std::vector rowcols = { {160000, 64, 2}, diff --git a/cpp/bench/sg/rf_regressor.cu b/cpp/bench/sg/rf_regressor.cu index a6771ae567..9a39446705 100644 --- a/cpp/bench/sg/rf_regressor.cu +++ b/cpp/bench/sg/rf_regressor.cu @@ -85,30 +85,25 @@ std::vector getInputs() { .noise = 1.0, .seed = 12345ULL}; - set_rf_params(p.rf, // Output RF parameters - 500, // n_trees - true, // bootstrap - 1.f, // max_samples - 1234ULL, // seed - 8); // n_streams - - set_tree_params(p.rf.tree_params, // Output tree parameters - 10, // max_depth, just a place holder value, - // anyway changed below - (1 << 20), // max_leaves - 0.3, // max_features, just a place holder value, - // anyway changed below - 32, // n_bins - 1, // split_algo - 3, // min_samples_leaf - 3, // min_samples_split - 0.0f, // min_impurity_decrease - true, // bootstrap_features - ML::CRITERION::MSE, // split_criterion - false, // quantile_per_tree - false, // use_experimental_backend - 128); // max_batch_size - + p.rf = set_rf_params(10, /*max_depth */ + (1 << 20), /* max_leaves */ + 0.3, /* max_features */ + 32, /* n_bins */ + 1, /* split_algo */ + 3, /* min_samples_leaf */ + 3, /* min_samples_split */ + 0.0f, /* min_impurity_decrease */ + true, /* bootstrap_features */ + true, /* bootstrap */ + 500, /* n_trees */ + 1.f, /* max_samples */ + 1234ULL, /* seed */ + ML::CRITERION::MSE, /* split_criterion */ + false, /* quantile_per_tree */ + 8, /* n_streams */ + false, /* use_experimental_backend */ + 128 /* max_batch_size */ + ); std::vector dim_info = {{500000, 500, 400}}; for (auto& di : dim_info) { // Let's run Bosch only for float type diff --git a/cpp/include/cuml/ensemble/randomforest.hpp b/cpp/include/cuml/ensemble/randomforest.hpp index 525d6e07ea..9321b93307 100644 --- a/cpp/include/cuml/ensemble/randomforest.hpp +++ b/cpp/include/cuml/ensemble/randomforest.hpp @@ -87,13 +87,6 @@ struct RF_params { DecisionTree::DecisionTreeParams tree_params; }; -void set_rf_params(RF_params& params, int cfg_n_trees = 1, - bool cfg_bootstrap = true, float cfg_max_samples = 1.0f, - uint64_t cfg_seed = 0, int cfg_n_streams = 8); -void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_max_samples, uint64_t cfg_seed, - int cfg_n_streams, - DecisionTree::DecisionTreeParams cfg_tree_params); void validity_check(const RF_params rf_params); void print(const RF_params rf_params); @@ -187,14 +180,14 @@ RF_metrics score(const raft::handle_t& user_handle, int n_rows, const int* predictions, int verbosity = CUML_LEVEL_INFO); -RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, - int n_bins, int split_algo, int min_samples_leaf, - int min_samples_split, float min_impurity_decrease, - bool bootstrap_features, bool bootstrap, int n_trees, - float max_samples, uint64_t seed, - CRITERION split_criterion, bool quantile_per_tree, - int cfg_n_streams, bool use_experimental_backend, - int max_batch_size); +RF_params set_rf_params(int max_depth, int max_leaves, float max_features, + int n_bins, int split_algo, int min_samples_leaf, + int min_samples_split, float min_impurity_decrease, + bool bootstrap_features, bool bootstrap, int n_trees, + float max_samples, uint64_t seed, + CRITERION split_criterion, bool quantile_per_tree, + int cfg_n_streams, bool use_experimental_backend, + int max_batch_size); // ----------------------------- Regression ----------------------------------- // diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu index 632e9a485d..83ffdbf2c1 100644 --- a/cpp/src/randomforest/randomforest.cu +++ b/cpp/src/randomforest/randomforest.cu @@ -157,53 +157,6 @@ void postprocess_labels(int n_rows, std::vector& labels, CUML_LOG_DEBUG("Finished postrocessing labels"); } -/** - * @brief Set RF_params parameters members; use default tree parameters. - * @param[in,out] params: update with random forest parameters - * @param[in] cfg_n_trees: number of trees; default 1 - * @param[in] cfg_bootstrap: bootstrapping; default true - * @param[in] cfg_max_samples: rows sample; default 1.0f - * @param[in] cfg_n_streams: No of parallel CUDA for training forest - */ -void set_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_max_samples, uint64_t cfg_seed, - int cfg_n_streams) { - params.n_trees = cfg_n_trees; - params.bootstrap = cfg_bootstrap; - params.max_samples = cfg_max_samples; - params.seed = cfg_seed; - params.n_streams = min(cfg_n_streams, omp_get_max_threads()); - if (params.n_streams == cfg_n_streams) { - CUML_LOG_WARN("Warning! Max setting Max streams to max openmp threads %d", - omp_get_max_threads()); - } - if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees; - set_tree_params(params.tree_params); // use default tree params -} - -/** - * @brief Set all RF_params parameters members, including tree parameters. - * @param[in,out] params: update with random forest parameters - * @param[in] cfg_n_trees: number of trees - * @param[in] cfg_bootstrap: bootstrapping - * @param[in] cfg_max_samples: rows sample - * @param[in] cfg_n_streams: No of parallel CUDA for training forest - * @param[in] cfg_tree_params: tree parameters - */ -void set_all_rf_params(RF_params& params, int cfg_n_trees, bool cfg_bootstrap, - float cfg_max_samples, uint64_t cfg_seed, - int cfg_n_streams, - DecisionTree::DecisionTreeParams cfg_tree_params) { - params.n_trees = cfg_n_trees; - params.bootstrap = cfg_bootstrap; - params.max_samples = cfg_max_samples; - params.seed = cfg_seed; - params.n_streams = min(cfg_n_streams, omp_get_max_threads()); - if (cfg_n_trees < params.n_streams) params.n_streams = cfg_n_trees; - set_tree_params(params.tree_params); // use input tree params - params.tree_params = cfg_tree_params; -} - /** * @brief Check validity of all random forest hyper-parameters. * @param[in] rf_params: random forest hyper-parameters @@ -657,14 +610,14 @@ RF_metrics score(const raft::handle_t& user_handle, return classification_score; } -RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, - int n_bins, int split_algo, int min_samples_leaf, - int min_samples_split, float min_impurity_decrease, - bool bootstrap_features, bool bootstrap, int n_trees, - float max_samples, uint64_t seed, - CRITERION split_criterion, bool quantile_per_tree, - int cfg_n_streams, bool use_experimental_backend, - int max_batch_size) { +RF_params set_rf_params(int max_depth, int max_leaves, float max_features, + int n_bins, int split_algo, int min_samples_leaf, + int min_samples_split, float min_impurity_decrease, + bool bootstrap_features, bool bootstrap, int n_trees, + float max_samples, uint64_t seed, + CRITERION split_criterion, bool quantile_per_tree, + int cfg_n_streams, bool use_experimental_backend, + int max_batch_size) { DecisionTree::DecisionTreeParams tree_params; DecisionTree::set_tree_params( tree_params, max_depth, max_leaves, max_features, n_bins, split_algo, @@ -672,8 +625,13 @@ RF_params set_rf_class_obj(int max_depth, int max_leaves, float max_features, bootstrap_features, split_criterion, quantile_per_tree, use_experimental_backend, max_batch_size); RF_params rf_params; - set_all_rf_params(rf_params, n_trees, bootstrap, max_samples, seed, - cfg_n_streams, tree_params); + rf_params.n_trees = n_trees; + rf_params.bootstrap = bootstrap; + rf_params.max_samples = max_samples; + rf_params.seed = seed; + rf_params.n_streams = min(cfg_n_streams, omp_get_max_threads()); + if (n_trees < rf_params.n_streams) rf_params.n_streams = n_trees; + rf_params.tree_params = tree_params; return rf_params; } diff --git a/cpp/test/sg/rf_accuracy_test.cu b/cpp/test/sg/rf_accuracy_test.cu index 7ab48d60b6..df31cf961c 100644 --- a/cpp/test/sg/rf_accuracy_test.cu +++ b/cpp/test/sg/rf_accuracy_test.cu @@ -81,27 +81,28 @@ class RFClassifierAccuracyTest : public ::testing::TestWithParam { private: void setRFParams() { - DecisionTree::DecisionTreeParams tree_params; auto algo = SPLIT_ALGO::GLOBAL_QUANTILE; auto sc = CRITERION::CRITERION_END; - set_tree_params(tree_params, 0, /* max_depth */ - -1, /* max_leaves */ - 1.0, /* max_features */ - 16, /* n_bins */ - algo, /* split_algo */ - 2, /* min_samples_leaf */ - 2, /* min_samples_split */ - 0.f, /* min_impurity_decrease */ - false, /* bootstrap_features */ - sc, /* split_criterion */ - false /* quantile_per_tree */ + + rfp = set_rf_params(0, /*max_depth */ + -1, /* max_leaves */ + 1.0, /* max_features */ + 16, /* n_bins */ + algo, /* split_algo */ + 2, /* min_samples_leaf */ + 2, /* min_samples_split */ + 0.f, /* min_impurity_decrease */ + false, /* bootstrap_features */ + true, /* bootstrap */ + 1, /* n_trees */ + 1.0, /* max_samples */ + 0, /* seed */ + sc, /* split_criterion */ + false, /* quantile_per_tree */ + 1, /* n_streams */ + false, /* use_experimental_backend */ + 128 /* max_batch_size */ ); - set_all_rf_params(rfp, 1, /* n_trees */ - true, /* bootstrap */ - 1.0, /* max_samples */ - 0, /* seed */ - 1, /* n_streams */ - tree_params); } void loadData(T *X, int *y, int nrows, int ncols) { diff --git a/cpp/test/sg/rf_batched_classification_test.cu b/cpp/test/sg/rf_batched_classification_test.cu index 4677d5732d..170b30644d 100644 --- a/cpp/test/sg/rf_batched_classification_test.cu +++ b/cpp/test/sg/rf_batched_classification_test.cu @@ -52,15 +52,13 @@ class RFBatchedClsTest : public ::testing::TestWithParam { void basicTest() { params = ::testing::TestWithParam::GetParam(); - DecisionTree::DecisionTreeParams tree_params; - set_tree_params(tree_params, params.max_depth, params.max_leaves, - params.max_features, params.n_bins, params.split_algo, - params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, - params.split_criterion, false, true); RF_params rf_params; - set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.max_samples, 0, params.n_streams, tree_params); + rf_params = set_rf_params( + params.max_depth, params.max_leaves, params.max_features, params.n_bins, + params.split_algo, params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, + params.n_trees, params.max_samples, 0, params.split_criterion, false, + params.n_streams, true, 128); CUDA_CHECK(cudaStreamCreate(&stream)); handle.reset(new raft::handle_t(rf_params.n_streams)); diff --git a/cpp/test/sg/rf_batched_regression_test.cu b/cpp/test/sg/rf_batched_regression_test.cu index 972b610cdd..b444e16a85 100644 --- a/cpp/test/sg/rf_batched_regression_test.cu +++ b/cpp/test/sg/rf_batched_regression_test.cu @@ -54,15 +54,13 @@ class RFBatchedRegTest : public ::testing::TestWithParam { void basicTest() { params = ::testing::TestWithParam::GetParam(); - DecisionTree::DecisionTreeParams tree_params; - set_tree_params(tree_params, params.max_depth, params.max_leaves, - params.max_features, params.n_bins, params.split_algo, - params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, - params.split_criterion, false, true); RF_params rf_params; - set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.max_samples, 0, params.n_streams, tree_params); + rf_params = set_rf_params( + params.max_depth, params.max_leaves, params.max_features, params.n_bins, + params.split_algo, params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, + params.n_trees, params.max_samples, 0, params.split_criterion, false, + params.n_streams, true, 128); CUDA_CHECK(cudaStreamCreate(&stream)); handle.reset(new raft::handle_t(rf_params.n_streams)); diff --git a/cpp/test/sg/rf_depth_test.cu b/cpp/test/sg/rf_depth_test.cu index d4204343fe..f5e29156e2 100644 --- a/cpp/test/sg/rf_depth_test.cu +++ b/cpp/test/sg/rf_depth_test.cu @@ -67,15 +67,13 @@ class RfClassifierDepthTest : public ::testing::TestWithParam { 2, CRITERION::ENTROPY}; - DecisionTree::DecisionTreeParams tree_params; - set_tree_params(tree_params, params.max_depth, params.max_leaves, - params.max_features, params.n_bins, params.split_algo, - params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, - params.split_criterion, false); RF_params rf_params; - set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.max_samples, 0, params.n_streams, tree_params); + rf_params = set_rf_params( + params.max_depth, params.max_leaves, params.max_features, params.n_bins, + params.split_algo, params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, + params.n_trees, params.max_samples, 0, params.split_criterion, false, + params.n_streams, false, 128); int data_len = params.n_rows * params.n_cols; raft::allocate(data, data_len); @@ -161,15 +159,13 @@ class RfRegressorDepthTest : public ::testing::TestWithParam { 2, CRITERION::MSE}; - DecisionTree::DecisionTreeParams tree_params; - set_tree_params(tree_params, params.max_depth, params.max_leaves, - params.max_features, params.n_bins, params.split_algo, - params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, - params.split_criterion, false); RF_params rf_params; - set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.max_samples, 0, params.n_streams, tree_params); + rf_params = set_rf_params( + params.max_depth, params.max_leaves, params.max_features, params.n_bins, + params.split_algo, params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, + params.n_trees, params.max_samples, 0, params.split_criterion, false, + params.n_streams, false, 128); int data_len = params.n_rows * params.n_cols; raft::allocate(data, data_len); diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index b4451cdd05..c9a672b85e 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -56,16 +56,13 @@ class RfClassifierTest : public ::testing::TestWithParam> { void basicTest() { params = ::testing::TestWithParam>::GetParam(); - DecisionTree::DecisionTreeParams tree_params; - set_tree_params(tree_params, params.max_depth, params.max_leaves, - params.max_features, params.n_bins, params.split_algo, - params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, - params.split_criterion, false); RF_params rf_params; - set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.max_samples, 0, params.n_streams, tree_params); - //print(rf_params); + rf_params = set_rf_params( + params.max_depth, params.max_leaves, params.max_features, params.n_bins, + params.split_algo, params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, + params.n_trees, params.max_samples, 0, params.split_criterion, false, + params.n_streams, false, 128); //-------------------------------------------------------- // Random Forest @@ -159,16 +156,13 @@ class RfRegressorTest : public ::testing::TestWithParam> { void basicTest() { params = ::testing::TestWithParam>::GetParam(); - DecisionTree::DecisionTreeParams tree_params; - set_tree_params(tree_params, params.max_depth, params.max_leaves, - params.max_features, params.n_bins, params.split_algo, - params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, - params.split_criterion, false); RF_params rf_params; - set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.max_samples, 0, params.n_streams, tree_params); - //print(rf_params); + rf_params = set_rf_params( + params.max_depth, params.max_leaves, params.max_features, params.n_bins, + params.split_algo, params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, + params.n_trees, params.max_samples, 0, params.split_criterion, false, + params.n_streams, false, 128); //-------------------------------------------------------- // Random Forest diff --git a/cpp/test/sg/rf_treelite_test.cu b/cpp/test/sg/rf_treelite_test.cu index 31a6f4f9d0..684c06f68c 100644 --- a/cpp/test/sg/rf_treelite_test.cu +++ b/cpp/test/sg/rf_treelite_test.cu @@ -183,14 +183,13 @@ class RfTreeliteTestCommon : public ::testing::TestWithParam> { void SetUp() override { params = ::testing::TestWithParam>::GetParam(); - DecisionTree::DecisionTreeParams tree_params; - set_tree_params(tree_params, params.max_depth, params.max_leaves, - params.max_features, params.n_bins, params.split_algo, - params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, - params.split_criterion, false); - set_all_rf_params(rf_params, params.n_trees, params.bootstrap, - params.max_samples, 0, params.n_streams, tree_params); + rf_params = set_rf_params( + params.max_depth, params.max_leaves, params.max_features, params.n_bins, + params.split_algo, params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, + params.n_trees, params.max_samples, 0, params.split_criterion, false, + params.n_streams, false, 128); + handle.reset(new raft::handle_t(rf_params.n_streams)); data_len = params.n_rows * params.n_cols; diff --git a/python/cuml/ensemble/randomforest_shared.pxd b/python/cuml/ensemble/randomforest_shared.pxd index c2b45cd8b4..ce9ce05a76 100644 --- a/python/cuml/ensemble/randomforest_shared.pxd +++ b/python/cuml/ensemble/randomforest_shared.pxd @@ -44,19 +44,6 @@ cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML": MAE, CRITERION_END -cdef extern from "cuml/tree/decisiontree.hpp" namespace "ML::DecisionTree": - cdef struct DecisionTreeParams: - int max_depth - int max_leaves - float max_features - int n_bins - int split_algo - int min_samples_leaf - int min_samples_split - bool bootstrap_features - bool quantile_per_tree - CRITERION split_criterion - cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML": cdef enum RF_type: @@ -103,24 +90,24 @@ cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML": ) except + cdef string get_rf_json[T, L](RandomForestMetaData[T, L]*) except + - cdef RF_params set_rf_class_obj(int, - int, - float, - int, - int, - int, - int, - float, - bool, - bool, - int, - float, - uint64_t, - CRITERION, - bool, - int, - bool, - int) except + + cdef RF_params set_rf_params(int, + int, + float, + int, + int, + int, + int, + float, + bool, + bool, + int, + float, + uint64_t, + CRITERION, + bool, + int, + bool, + int) except + cdef vector[unsigned char] save_model(ModelHandle) diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 058ec5b591..e3069d27ea 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -464,24 +464,24 @@ class RandomForestClassifier(BaseRandomForestModel, else: seed_val = self.random_state - rf_params = set_rf_class_obj( self.max_depth, - self.max_leaves, - max_feature_val, - self.n_bins, - self.split_algo, - self.min_samples_leaf, - self.min_samples_split, - self.min_impurity_decrease, - self.bootstrap_features, - self.bootstrap, - self.n_estimators, - self.max_samples, - seed_val, - self.split_criterion, - self.quantile_per_tree, - self.n_streams, - self.use_experimental_backend, - self.max_batch_size) + rf_params = set_rf_params( self.max_depth, + self.max_leaves, + max_feature_val, + self.n_bins, + self.split_algo, + self.min_samples_leaf, + self.min_samples_split, + self.min_impurity_decrease, + self.bootstrap_features, + self.bootstrap, + self.n_estimators, + self.max_samples, + seed_val, + self.split_criterion, + self.quantile_per_tree, + self.n_streams, + self.use_experimental_backend, + self.max_batch_size) if self.dtype == np.float32: fit(handle_[0], diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index e7cadb7eb3..f36143e157 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -444,24 +444,24 @@ class RandomForestRegressor(BaseRandomForestModel, else: seed_val = self.random_state - rf_params = set_rf_class_obj( self.max_depth, - self.max_leaves, - max_feature_val, - self.n_bins, - self.split_algo, - self.min_samples_leaf, - self.min_samples_split, - self.min_impurity_decrease, - self.bootstrap_features, - self.bootstrap, - self.n_estimators, - self.max_samples, - seed_val, - self.split_criterion, - self.quantile_per_tree, - self.n_streams, - self.use_experimental_backend, - self.max_batch_size) + rf_params = set_rf_params( self.max_depth, + self.max_leaves, + max_feature_val, + self.n_bins, + self.split_algo, + self.min_samples_leaf, + self.min_samples_split, + self.min_impurity_decrease, + self.bootstrap_features, + self.bootstrap, + self.n_estimators, + self.max_samples, + seed_val, + self.split_criterion, + self.quantile_per_tree, + self.n_streams, + self.use_experimental_backend, + self.max_batch_size) if self.dtype == np.float32: fit(handle_[0],